Sfoglia il codice sorgente

Merge pull request #1821 from haberman/rubyfreezestr

Ruby: encode and freeze strings when they are assigned or decoded.
Joshua Haberman 9 anni fa
parent
commit
44bd6bda58

+ 42 - 11
ruby/ext/google/protobuf_c/encode_decode.c

@@ -54,7 +54,7 @@ VALUE noleak_rb_str_cat(VALUE rb_str, const char *str, long len) {
 static const void* newhandlerdata(upb_handlers* h, uint32_t ofs) {
   size_t* hd_ofs = ALLOC(size_t);
   *hd_ofs = ofs;
-  upb_handlers_addcleanup(h, hd_ofs, free);
+  upb_handlers_addcleanup(h, hd_ofs, xfree);
   return hd_ofs;
 }
 
@@ -69,7 +69,7 @@ static const void *newsubmsghandlerdata(upb_handlers* h, uint32_t ofs,
   submsg_handlerdata_t *hd = ALLOC(submsg_handlerdata_t);
   hd->ofs = ofs;
   hd->md = upb_fielddef_msgsubdef(f);
-  upb_handlers_addcleanup(h, hd, free);
+  upb_handlers_addcleanup(h, hd, xfree);
   return hd;
 }
 
@@ -99,7 +99,7 @@ static const void *newoneofhandlerdata(upb_handlers *h,
   } else {
     hd->md = NULL;
   }
-  upb_handlers_addcleanup(h, hd, free);
+  upb_handlers_addcleanup(h, hd, xfree);
   return hd;
 }
 
@@ -135,7 +135,7 @@ static void* appendstr_handler(void *closure,
   VALUE ary = (VALUE)closure;
   VALUE str = rb_str_new2("");
   rb_enc_associate(str, kRubyStringUtf8Encoding);
-  RepeatedField_push(ary, str);
+  RepeatedField_push_native(ary, &str);
   return (void*)str;
 }
 
@@ -146,7 +146,7 @@ static void* appendbytes_handler(void *closure,
   VALUE ary = (VALUE)closure;
   VALUE str = rb_str_new2("");
   rb_enc_associate(str, kRubyString8bitEncoding);
-  RepeatedField_push(ary, str);
+  RepeatedField_push_native(ary, &str);
   return (void*)str;
 }
 
@@ -182,6 +182,23 @@ static size_t stringdata_handler(void* closure, const void* hd,
   return len;
 }
 
+static bool stringdata_end_handler(void* closure, const void* hd) {
+  MessageHeader* msg = closure;
+  const size_t *ofs = hd;
+  VALUE rb_str = DEREF(msg, *ofs, VALUE);
+  rb_obj_freeze(rb_str);
+  return true;
+}
+
+static bool appendstring_end_handler(void* closure, const void* hd) {
+  VALUE ary = (VALUE)closure;
+  int size = RepeatedField_size(ary);
+  VALUE* last = RepeatedField_index_native(ary, size - 1);
+  VALUE rb_str = *last;
+  rb_obj_freeze(rb_str);
+  return true;
+}
+
 // Appends a submessage to a repeated field (a regular Ruby array for now).
 static void *appendsubmsg_handler(void *closure, const void *hd) {
   VALUE ary = (VALUE)closure;
@@ -281,7 +298,7 @@ static bool endmap_handler(void *closure, const void *hd, upb_status* s) {
       &frame->value_storage);
 
   Map_index_set(frame->map, key, value);
-  free(frame);
+  xfree(frame);
 
   return true;
 }
@@ -360,6 +377,13 @@ static void *oneofbytes_handler(void *closure,
   return (void*)str;
 }
 
+static bool oneofstring_end_handler(void* closure, const void* hd) {
+  MessageHeader* msg = closure;
+  const oneof_handlerdata_t *oneofdata = hd;
+  rb_obj_freeze(DEREF(msg, oneofdata->ofs, VALUE));
+  return true;
+}
+
 // Handler for a submessage field in a oneof.
 static void *oneofsubmsg_handler(void *closure,
                                  const void *hd) {
@@ -426,6 +450,7 @@ static void add_handlers_for_repeated_field(upb_handlers *h,
                                appendbytes_handler : appendstr_handler,
                                NULL);
       upb_handlers_setstring(h, f, stringdata_handler, NULL);
+      upb_handlers_setendstr(h, f, appendstring_end_handler, NULL);
       break;
     }
     case UPB_TYPE_MESSAGE: {
@@ -462,6 +487,7 @@ static void add_handlers_for_singular_field(upb_handlers *h,
                                is_bytes ? bytes_handler : str_handler,
                                &attr);
       upb_handlers_setstring(h, f, stringdata_handler, &attr);
+      upb_handlers_setendstr(h, f, stringdata_end_handler, &attr);
       upb_handlerattr_uninit(&attr);
       break;
     }
@@ -484,7 +510,7 @@ static void add_handlers_for_mapfield(upb_handlers* h,
   map_handlerdata_t* hd = new_map_handlerdata(offset, map_msgdef, desc);
   upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER;
 
-  upb_handlers_addcleanup(h, hd, free);
+  upb_handlers_addcleanup(h, hd, xfree);
   upb_handlerattr_sethandlerdata(&attr, hd);
   upb_handlers_setstartsubmsg(h, fielddef, startmapentry_handler, &attr);
   upb_handlerattr_uninit(&attr);
@@ -499,7 +525,7 @@ static void add_handlers_for_mapentry(const upb_msgdef* msgdef,
   map_handlerdata_t* hd = new_map_handlerdata(0, msgdef, desc);
   upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER;
 
-  upb_handlers_addcleanup(h, hd, free);
+  upb_handlers_addcleanup(h, hd, xfree);
   upb_handlerattr_sethandlerdata(&attr, hd);
   upb_handlers_setendmsg(h, endmap_handler, &attr);
 
@@ -546,6 +572,7 @@ static void add_handlers_for_oneof_field(upb_handlers *h,
                                oneofbytes_handler : oneofstr_handler,
                                &attr);
       upb_handlers_setstring(h, f, stringdata_handler, NULL);
+      upb_handlers_setendstr(h, f, oneofstring_end_handler, &attr);
       break;
     }
     case UPB_TYPE_MESSAGE: {
@@ -863,9 +890,13 @@ static void putstr(VALUE str, const upb_fielddef *f, upb_sink *sink) {
 
   assert(BUILTIN_TYPE(str) == RUBY_T_STRING);
 
-  // Ensure that the string has the correct encoding. We also check at field-set
-  // time, but the user may have mutated the string object since then.
-  native_slot_validate_string_encoding(upb_fielddef_type(f), str);
+  // We should be guaranteed that the string has the correct encoding because
+  // we ensured this at assignment time and then froze the string.
+  if (upb_fielddef_type(f) == UPB_TYPE_STRING) {
+    assert(rb_enc_from_index(ENCODING_GET(value)) == kRubyStringUtf8Encoding);
+  } else {
+    assert(rb_enc_from_index(ENCODING_GET(value)) == kRubyString8bitEncoding);
+  }
 
   upb_sink_startstr(sink, getsel(f, UPB_HANDLER_STARTSTR), RSTRING_LEN(str),
                     &subsink);

+ 11 - 9
ruby/ext/google/protobuf_c/map.c

@@ -63,16 +63,16 @@
 // construct a key byte sequence if needed. |out_key| and |out_length| provide
 // the resulting key data/length.
 #define TABLE_KEY_BUF_LENGTH 8  // sizeof(uint64_t)
-static void table_key(Map* self, VALUE key,
-                      char* buf,
-                      const char** out_key,
-                      size_t* out_length) {
+static VALUE table_key(Map* self, VALUE key,
+                       char* buf,
+                       const char** out_key,
+                       size_t* out_length) {
   switch (self->key_type) {
     case UPB_TYPE_BYTES:
     case UPB_TYPE_STRING:
       // Strings: use string content directly.
       Check_Type(key, T_STRING);
-      native_slot_validate_string_encoding(self->key_type, key);
+      key = native_slot_encode_and_freeze_string(self->key_type, key);
       *out_key = RSTRING_PTR(key);
       *out_length = RSTRING_LEN(key);
       break;
@@ -93,6 +93,8 @@ static void table_key(Map* self, VALUE key,
       assert(false);
       break;
   }
+
+  return key;
 }
 
 static VALUE table_key_to_ruby(Map* self, const char* buf, size_t length) {
@@ -357,7 +359,7 @@ VALUE Map_index(VALUE _self, VALUE key) {
   const char* keyval = NULL;
   size_t length = 0;
   upb_value v;
-  table_key(self, key, keybuf, &keyval, &length);
+  key = table_key(self, key, keybuf, &keyval, &length);
 
   if (upb_strtable_lookup2(&self->table, keyval, length, &v)) {
     void* mem = value_memory(&v);
@@ -383,7 +385,7 @@ VALUE Map_index_set(VALUE _self, VALUE key, VALUE value) {
   size_t length = 0;
   upb_value v;
   void* mem;
-  table_key(self, key, keybuf, &keyval, &length);
+  key = table_key(self, key, keybuf, &keyval, &length);
 
   mem = value_memory(&v);
   native_slot_set(self->value_type, self->value_type_class, mem, value);
@@ -411,7 +413,7 @@ VALUE Map_has_key(VALUE _self, VALUE key) {
   char keybuf[TABLE_KEY_BUF_LENGTH];
   const char* keyval = NULL;
   size_t length = 0;
-  table_key(self, key, keybuf, &keyval, &length);
+  key = table_key(self, key, keybuf, &keyval, &length);
 
   if (upb_strtable_lookup2(&self->table, keyval, length, NULL)) {
     return Qtrue;
@@ -434,7 +436,7 @@ VALUE Map_delete(VALUE _self, VALUE key) {
   const char* keyval = NULL;
   size_t length = 0;
   upb_value v;
-  table_key(self, key, keybuf, &keyval, &length);
+  key = table_key(self, key, keybuf, &keyval, &length);
 
   if (upb_strtable_remove2(&self->table, keyval, length, &v)) {
     void* mem = value_memory(&v);

+ 2 - 1
ruby/ext/google/protobuf_c/protobuf.h

@@ -313,7 +313,7 @@ void native_slot_dup(upb_fieldtype_t type, void* to, void* from);
 void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from);
 bool native_slot_eq(upb_fieldtype_t type, void* mem1, void* mem2);
 
-void native_slot_validate_string_encoding(upb_fieldtype_t type, VALUE value);
+VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value);
 void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE value);
 
 extern rb_encoding* kRubyStringUtf8Encoding;
@@ -366,6 +366,7 @@ RepeatedField* ruby_to_RepeatedField(VALUE value);
 VALUE RepeatedField_each(VALUE _self);
 VALUE RepeatedField_index(int argc, VALUE* argv, VALUE _self);
 void* RepeatedField_index_native(VALUE _self, int index);
+int RepeatedField_size(VALUE _self);
 VALUE RepeatedField_index_set(VALUE _self, VALUE _index, VALUE val);
 void RepeatedField_reserve(RepeatedField* self, int new_size);
 VALUE RepeatedField_push(VALUE _self, VALUE val);

+ 5 - 0
ruby/ext/google/protobuf_c/repeated_field.c

@@ -244,6 +244,11 @@ void* RepeatedField_index_native(VALUE _self, int index) {
   return RepeatedField_memoryat(self, index, element_size);
 }
 
+int RepeatedField_size(VALUE _self) {
+  RepeatedField* self = ruby_to_RepeatedField(_self);
+  return self->size;
+}
+
 /*
  * Private ruby method, used by RepeatedField.pop
  */

+ 19 - 20
ruby/ext/google/protobuf_c/storage.c

@@ -117,25 +117,24 @@ void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE val) {
   }
 }
 
-void native_slot_validate_string_encoding(upb_fieldtype_t type, VALUE value) {
-  bool bad_encoding = false;
-  rb_encoding* string_encoding = rb_enc_from_index(ENCODING_GET(value));
-  if (type == UPB_TYPE_STRING) {
-    bad_encoding =
-        string_encoding != kRubyStringUtf8Encoding &&
-        string_encoding != kRubyStringASCIIEncoding;
-  } else {
-    bad_encoding =
-        string_encoding != kRubyString8bitEncoding;
-  }
-  // Check that encoding is UTF-8 or ASCII (for string fields) or ASCII-8BIT
-  // (for bytes fields).
-  if (bad_encoding) {
-    rb_raise(rb_eTypeError, "Encoding for '%s' fields must be %s (was %s)",
-             (type == UPB_TYPE_STRING) ? "string" : "bytes",
-             (type == UPB_TYPE_STRING) ? "UTF-8 or ASCII" : "ASCII-8BIT",
-             rb_enc_name(string_encoding));
+VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value) {
+  rb_encoding* desired_encoding = (type == UPB_TYPE_STRING) ?
+      kRubyStringUtf8Encoding : kRubyString8bitEncoding;
+  VALUE desired_encoding_value = rb_enc_from_encoding(desired_encoding);
+
+  // Note: this will not duplicate underlying string data unless necessary.
+  value = rb_str_encode(value, desired_encoding_value, 0, Qnil);
+
+  if (type == UPB_TYPE_STRING &&
+      rb_enc_str_coderange(value) == ENC_CODERANGE_BROKEN) {
+    rb_raise(rb_eEncodingError, "String is invalid UTF-8");
   }
+
+  // Ensure the data remains valid.  Since we called #encode a moment ago,
+  // this does not freeze the string the user assigned.
+  rb_obj_freeze(value);
+
+  return value;
 }
 
 void native_slot_set(upb_fieldtype_t type, VALUE type_class,
@@ -181,8 +180,8 @@ void native_slot_set_value_and_case(upb_fieldtype_t type, VALUE type_class,
       if (CLASS_OF(value) != rb_cString) {
         rb_raise(rb_eTypeError, "Invalid argument for string field.");
       }
-      native_slot_validate_string_encoding(type, value);
-      DEREF(memory, VALUE) = value;
+
+      DEREF(memory, VALUE) = native_slot_encode_and_freeze_string(type, value);
       break;
     }
     case UPB_TYPE_MESSAGE: {

+ 1 - 1
ruby/ext/google/protobuf_c/upb.c

@@ -11076,8 +11076,8 @@ static bool end_stringval(upb_json_parser *p) {
 
     case UPB_TYPE_STRING: {
       upb_selector_t sel = getsel_for_handlertype(p, UPB_HANDLER_ENDSTR);
-      upb_sink_endstr(&p->top->sink, sel);
       p->top--;
+      upb_sink_endstr(&p->top->sink, sel);
       break;
     }
 

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java

@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
      */
     @JRubyMethod(name = "[]=")
     public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
-        Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
-        Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
+        key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
+        value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
         IRubyObject symbol;
         if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
                 Utils.isRubyNum(value) &&

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java

@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
                 break;
             case BYTES:
             case STRING:
-                Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value);
+                Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
                 RubyString str = (RubyString) value;
                 switch (fieldDescriptor.getType()) {
                     case BYTES:
@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
                     }
                 }
                 if (addValue) {
-                    Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+                    value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
                     this.fields.put(fieldDescriptor, value);
                 } else {
                     this.fields.remove(fieldDescriptor);

+ 2 - 2
ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java

@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
     @JRubyMethod(name = "[]=")
     public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
         int arrIndex = normalizeArrayIndex(index);
-        Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+        value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
         IRubyObject defaultValue = defaultValue(context);
         for (int i = this.storage.size(); i < arrIndex; i++) {
             this.storage.set(i, defaultValue);
@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
     public IRubyObject push(ThreadContext context, IRubyObject value) {
         if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
             value == context.runtime.getNil())) {
-            Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+            value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
         }
         this.storage.add(value);
         return this.storage;

+ 20 - 17
ruby/src/main/java/com/google/protobuf/jruby/Utils.java

@@ -64,8 +64,8 @@ public class Utils {
         return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
     }
 
-    public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
-                            IRubyObject value, RubyModule typeClass) {
+    public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
+                                        IRubyObject value, RubyModule typeClass) {
         Ruby runtime = context.runtime;
         Object val;
         switch(fieldType) {
@@ -106,7 +106,7 @@ public class Utils {
                 break;
             case BYTES:
             case STRING:
-                validateStringEncoding(context.runtime, fieldType, value);
+                value = validateStringEncoding(context, fieldType, value);
                 break;
             case MESSAGE:
                 if (value.getMetaClass() != typeClass) {
@@ -127,6 +127,7 @@ public class Utils {
             default:
                 break;
         }
+        return value;
     }
 
     public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
@@ -148,10 +149,16 @@ public class Utils {
                 return runtime.newFloat((Double) value);
             case BOOL:
                 return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
-            case BYTES:
-                return runtime.newString(((ByteString) value).toStringUtf8());
-            case STRING:
-                return runtime.newString(value.toString());
+            case BYTES: {
+                IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
+                wrapped.setFrozen(true);
+                return wrapped;
+            }
+            case STRING: {
+                IRubyObject wrapped = runtime.newString(value.toString());
+                wrapped.setFrozen(true);
+                return wrapped;
+            }
             default:
                 return runtime.getNil();
         }
@@ -180,25 +187,21 @@ public class Utils {
         }
     }
 
-    public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
+    public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
         if (!(value instanceof RubyString))
-            throw runtime.newTypeError("Invalid argument for string field.");
-        Encoding encoding = ((RubyString) value).getEncoding();
+            throw context.runtime.newTypeError("Invalid argument for string field.");
         switch(type) {
             case BYTES:
-                if (encoding != ASCIIEncoding.INSTANCE)
-                    throw runtime.newTypeError("Encoding for bytes fields" +
-                            " must be \"ASCII-8BIT\", but was " + encoding);
+                value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
                 break;
             case STRING:
-                if (encoding != UTF8Encoding.INSTANCE
-                        && encoding != USASCIIEncoding.INSTANCE)
-                    throw runtime.newTypeError("Encoding for string fields" +
-                            " must be \"UTF-8\" or \"ASCII\", but was " + encoding);
+                value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
                 break;
             default:
                 break;
         }
+        value.setFrozen(true);
+        return value;
     }
 
     public static void checkNameAvailability(ThreadContext context, String name) {

+ 26 - 18
ruby/tests/basic.rb

@@ -255,14 +255,17 @@ module BasicTest
       m = TestMessage.new
 
       # Assigning a normal (ASCII or UTF8) string to a bytes field, or
-      # ASCII-8BIT to a string field, raises an error.
-      assert_raise TypeError do
-        m.optional_bytes = "Test string ASCII".encode!('ASCII')
-      end
-      assert_raise TypeError do
+      # ASCII-8BIT to a string field will convert to the proper encoding.
+      m.optional_bytes = "Test string ASCII".encode!('ASCII')
+      assert m.optional_bytes.frozen?
+      assert_equal Encoding::ASCII_8BIT, m.optional_bytes.encoding
+      assert_equal "Test string ASCII", m.optional_bytes
+
+      assert_raise Encoding::UndefinedConversionError do
         m.optional_bytes = "Test string UTF-8 \u0100".encode!('UTF-8')
       end
-      assert_raise TypeError do
+
+      assert_raise Encoding::UndefinedConversionError do
         m.optional_string = ["FFFF"].pack('H*')
       end
 
@@ -270,11 +273,10 @@ module BasicTest
       m.optional_bytes = ["FFFF"].pack('H*')
       m.optional_string = "\u0100"
 
-      # strings are mutable so we can do this, but serialize should catch it.
+      # strings are frozen on assignment, so mutating one in place must raise.
       m.optional_string = "asdf".encode!('UTF-8')
-      m.optional_string.encode!('ASCII-8BIT')
-      assert_raise TypeError do
-        data = TestMessage.encode(m)
+      assert_raise RuntimeError do
+        m.optional_string.encode!('ASCII-8BIT')
       end
     end
 
@@ -558,7 +560,7 @@ module BasicTest
       assert_raise TypeError do
         m[1] = 1
       end
-      assert_raise TypeError do
+      assert_raise Encoding::UndefinedConversionError do
         bytestring = ["FFFF"].pack("H*")
         m[bytestring] = 1
       end
@@ -566,9 +568,8 @@ module BasicTest
       m = Google::Protobuf::Map.new(:bytes, :int32)
       bytestring = ["FFFF"].pack("H*")
       m[bytestring] = 1
-      assert_raise TypeError do
-        m["asdf"] = 1
-      end
+      # Allowed -- we will automatically convert to ASCII-8BIT.
+      m["asdf"] = 1
       assert_raise TypeError do
         m[1] = 1
       end
@@ -853,15 +854,22 @@ module BasicTest
 
     def test_encode_decode_helpers
       m = TestMessage.new(:optional_string => 'foo', :repeated_string => ['bar1', 'bar2'])
+      assert_equal 'foo', m.optional_string
+      assert_equal ['bar1', 'bar2'], m.repeated_string
+
       json = m.to_json
       m2 = TestMessage.decode_json(json)
-      assert m2.optional_string == 'foo'
-      assert m2.repeated_string == ['bar1', 'bar2']
+      assert_equal 'foo', m2.optional_string
+      assert_equal ['bar1', 'bar2'], m2.repeated_string
+      if RUBY_PLATFORM != "java"
+        assert m2.optional_string.frozen?
+        assert m2.repeated_string[0].frozen?
+      end
 
       proto = m.to_proto
       m2 = TestMessage.decode(proto)
-      assert m2.optional_string == 'foo'
-      assert m2.repeated_string == ['bar1', 'bar2']
+      assert_equal 'foo', m2.optional_string
+      assert_equal ['bar1', 'bar2'], m2.repeated_string
     end
 
     def test_protobuf_encode_decode_helpers