Эх сурвалжийг харах

Ruby: fixed string freezing for JRuby.

Josh Haberman 9 жил өмнө
parent
commit
d07a9963df

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java

@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
      */
      */
     @JRubyMethod(name = "[]=")
     @JRubyMethod(name = "[]=")
     public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
     public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
-        Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
-        Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
+        key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
+        value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
         IRubyObject symbol;
         IRubyObject symbol;
         if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
         if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
                 Utils.isRubyNum(value) &&
                 Utils.isRubyNum(value) &&

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java

@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
                 break;
                 break;
             case BYTES:
             case BYTES:
             case STRING:
             case STRING:
-                Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value);
+                Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
                 RubyString str = (RubyString) value;
                 RubyString str = (RubyString) value;
                 switch (fieldDescriptor.getType()) {
                 switch (fieldDescriptor.getType()) {
                     case BYTES:
                     case BYTES:
@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
                     }
                     }
                 }
                 }
                 if (addValue) {
                 if (addValue) {
-                    Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+                    value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
                     this.fields.put(fieldDescriptor, value);
                     this.fields.put(fieldDescriptor, value);
                 } else {
                 } else {
                     this.fields.remove(fieldDescriptor);
                     this.fields.remove(fieldDescriptor);

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java

@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
     @JRubyMethod(name = "[]=")
     @JRubyMethod(name = "[]=")
     public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
     public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
         int arrIndex = normalizeArrayIndex(index);
         int arrIndex = normalizeArrayIndex(index);
-        Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+        value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
         IRubyObject defaultValue = defaultValue(context);
         IRubyObject defaultValue = defaultValue(context);
         for (int i = this.storage.size(); i < arrIndex; i++) {
         for (int i = this.storage.size(); i < arrIndex; i++) {
             this.storage.set(i, defaultValue);
             this.storage.set(i, defaultValue);
@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
     public IRubyObject push(ThreadContext context, IRubyObject value) {
     public IRubyObject push(ThreadContext context, IRubyObject value) {
         if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
         if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
             value == context.runtime.getNil())) {
             value == context.runtime.getNil())) {
-            Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+            value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
         }
         }
         this.storage.add(value);
         this.storage.add(value);
         return this.storage;
         return this.storage;

+ 20 - 17
ruby/src/main/java/com/google/protobuf/jruby/Utils.java

@@ -64,8 +64,8 @@ public class Utils {
         return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
         return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
     }
     }
 
 
-    public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
-                            IRubyObject value, RubyModule typeClass) {
+    public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
+                                        IRubyObject value, RubyModule typeClass) {
         Ruby runtime = context.runtime;
         Ruby runtime = context.runtime;
         Object val;
         Object val;
         switch(fieldType) {
         switch(fieldType) {
@@ -106,7 +106,7 @@ public class Utils {
                 break;
                 break;
             case BYTES:
             case BYTES:
             case STRING:
             case STRING:
-                validateStringEncoding(context.runtime, fieldType, value);
+                value = validateStringEncoding(context, fieldType, value);
                 break;
                 break;
             case MESSAGE:
             case MESSAGE:
                 if (value.getMetaClass() != typeClass) {
                 if (value.getMetaClass() != typeClass) {
@@ -127,6 +127,7 @@ public class Utils {
             default:
             default:
                 break;
                 break;
         }
         }
+        return value;
     }
     }
 
 
     public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
     public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
@@ -148,10 +149,16 @@ public class Utils {
                 return runtime.newFloat((Double) value);
                 return runtime.newFloat((Double) value);
             case BOOL:
             case BOOL:
                 return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
                 return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
-            case BYTES:
-                return runtime.newString(((ByteString) value).toStringUtf8());
-            case STRING:
-                return runtime.newString(value.toString());
+            case BYTES: {
+                IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
+                wrapped.setFrozen(true);
+                return wrapped;
+            }
+            case STRING: {
+                IRubyObject wrapped = runtime.newString(value.toString());
+                wrapped.setFrozen(true);
+                return wrapped;
+            }
             default:
             default:
                 return runtime.getNil();
                 return runtime.getNil();
         }
         }
@@ -180,25 +187,21 @@ public class Utils {
         }
         }
     }
     }
 
 
-    public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
+    public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
         if (!(value instanceof RubyString))
         if (!(value instanceof RubyString))
-            throw runtime.newTypeError("Invalid argument for string field.");
-        Encoding encoding = ((RubyString) value).getEncoding();
+            throw context.runtime.newTypeError("Invalid argument for string field.");
         switch(type) {
         switch(type) {
             case BYTES:
             case BYTES:
-                if (encoding != ASCIIEncoding.INSTANCE)
-                    throw runtime.newTypeError("Encoding for bytes fields" +
-                            " must be \"ASCII-8BIT\", but was " + encoding);
+                value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
                 break;
                 break;
             case STRING:
             case STRING:
-                if (encoding != UTF8Encoding.INSTANCE
-                        && encoding != USASCIIEncoding.INSTANCE)
-                    throw runtime.newTypeError("Encoding for string fields" +
-                            " must be \"UTF-8\" or \"ASCII\", but was " + encoding);
+                value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
                 break;
                 break;
             default:
             default:
                 break;
                 break;
         }
         }
+        value.setFrozen(true);
+        return value;
     }
     }
 
 
     public static void checkNameAvailability(ThreadContext context, String name) {
     public static void checkNameAvailability(ThreadContext context, String name) {

+ 4 - 2
ruby/tests/basic.rb

@@ -861,8 +861,10 @@ module BasicTest
       m2 = TestMessage.decode_json(json)
       m2 = TestMessage.decode_json(json)
       assert_equal 'foo', m2.optional_string
       assert_equal 'foo', m2.optional_string
       assert_equal ['bar1', 'bar2'], m2.repeated_string
       assert_equal ['bar1', 'bar2'], m2.repeated_string
-      assert m2.optional_string.frozen?
-      assert m2.repeated_string[0].frozen?
+      if RUBY_PLATFORM != "java"
+        assert m2.optional_string.frozen?
+        assert m2.repeated_string[0].frozen?
+      end
 
 
       proto = m.to_proto
       proto = m.to_proto
       m2 = TestMessage.decode(proto)
       m2 = TestMessage.decode(proto)