Browse Source

Merge pull request #1558 from haberman/rubyoneof

Ruby oneofs: return default instead of nil for unset fields.
Joshua Haberman 9 years ago
parent
commit
cc5296b869

+ 32 - 1
ruby/ext/google/protobuf_c/storage.c

@@ -57,6 +57,37 @@ size_t native_slot_size(upb_fieldtype_t type) {
   }
   }
 }
 }
 
 
+static VALUE value_from_default(const upb_fielddef *field) {
+  switch (upb_fielddef_type(field)) {
+    case UPB_TYPE_FLOAT:   return DBL2NUM(upb_fielddef_defaultfloat(field));
+    case UPB_TYPE_DOUBLE:  return DBL2NUM(upb_fielddef_defaultdouble(field));
+    case UPB_TYPE_BOOL:
+      return upb_fielddef_defaultbool(field) ? Qtrue : Qfalse;
+    case UPB_TYPE_MESSAGE: return Qnil;
+    case UPB_TYPE_ENUM: {
+      const upb_enumdef *enumdef = upb_fielddef_enumsubdef(field);
+      int32_t num = upb_fielddef_defaultint32(field);
+      const char *label = upb_enumdef_iton(enumdef, num);
+      if (label) {
+        return ID2SYM(rb_intern(label));
+      } else {
+        return INT2NUM(num);
+      }
+    }
+    case UPB_TYPE_INT32:   return INT2NUM(upb_fielddef_defaultint32(field));
+    case UPB_TYPE_INT64:   return LL2NUM(upb_fielddef_defaultint64(field));;
+    case UPB_TYPE_UINT32:  return UINT2NUM(upb_fielddef_defaultuint32(field));
+    case UPB_TYPE_UINT64:  return ULL2NUM(upb_fielddef_defaultuint64(field));
+    case UPB_TYPE_STRING:
+    case UPB_TYPE_BYTES: {
+      size_t size;
+      const char *str = upb_fielddef_defaultstr(field, &size);
+      return rb_str_new(str, size);
+    }
+    default: return Qnil;
+  }
+}
+
 static bool is_ruby_num(VALUE value) {
 static bool is_ruby_num(VALUE value) {
   return (TYPE(value) == T_FLOAT ||
   return (TYPE(value) == T_FLOAT ||
           TYPE(value) == T_FIXNUM ||
           TYPE(value) == T_FIXNUM ||
@@ -537,7 +568,7 @@ VALUE layout_get(MessageLayout* layout,
 
 
   if (upb_fielddef_containingoneof(field)) {
   if (upb_fielddef_containingoneof(field)) {
     if (*oneof_case != upb_fielddef_number(field)) {
     if (*oneof_case != upb_fielddef_number(field)) {
-      return Qnil;
+      return value_from_default(field);
     }
     }
     return native_slot_get(upb_fielddef_type(field),
     return native_slot_get(upb_fielddef_type(field),
                            field_type_class(field),
                            field_type_class(field),

+ 8 - 4
ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java

@@ -592,13 +592,17 @@ public class RubyMessage extends RubyObject {
     protected IRubyObject getField(ThreadContext context, Descriptors.FieldDescriptor fieldDescriptor) {
     protected IRubyObject getField(ThreadContext context, Descriptors.FieldDescriptor fieldDescriptor) {
         Descriptors.OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof();
         Descriptors.OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof();
         if (oneofDescriptor != null) {
         if (oneofDescriptor != null) {
-            if (oneofCases.containsKey(oneofDescriptor)) {
-                if (oneofCases.get(oneofDescriptor) != fieldDescriptor)
-                    return context.runtime.getNil();
+            if (oneofCases.get(oneofDescriptor) == fieldDescriptor) {
                 return fields.get(fieldDescriptor);
                 return fields.get(fieldDescriptor);
             } else {
             } else {
                 Descriptors.FieldDescriptor oneofCase = builder.getOneofFieldDescriptor(oneofDescriptor);
                 Descriptors.FieldDescriptor oneofCase = builder.getOneofFieldDescriptor(oneofDescriptor);
-                if (oneofCase != fieldDescriptor) return context.runtime.getNil();
+                if (oneofCase != fieldDescriptor) {
+                  if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
+                    return context.runtime.getNil();
+                  } else {
+                    return wrapField(context, fieldDescriptor, fieldDescriptor.getDefaultValue());
+                  }
+                }
                 IRubyObject value = wrapField(context, oneofCase, builder.getField(oneofCase));
                 IRubyObject value = wrapField(context, oneofCase, builder.getField(oneofCase));
                 fields.put(fieldDescriptor, value);
                 fields.put(fieldDescriptor, value);
                 return value;
                 return value;

+ 19 - 19
ruby/tests/basic.rb

@@ -703,36 +703,36 @@ module BasicTest
 
 
     def test_oneof
     def test_oneof
       d = OneofMessage.new
       d = OneofMessage.new
-      assert d.a == nil
-      assert d.b == nil
+      assert d.a == ""
+      assert d.b == 0
       assert d.c == nil
       assert d.c == nil
-      assert d.d == nil
+      assert d.d == :Default
       assert d.my_oneof == nil
       assert d.my_oneof == nil
 
 
       d.a = "hi"
       d.a = "hi"
       assert d.a == "hi"
       assert d.a == "hi"
-      assert d.b == nil
+      assert d.b == 0
       assert d.c == nil
       assert d.c == nil
-      assert d.d == nil
+      assert d.d == :Default
       assert d.my_oneof == :a
       assert d.my_oneof == :a
 
 
       d.b = 42
       d.b = 42
-      assert d.a == nil
+      assert d.a == ""
       assert d.b == 42
       assert d.b == 42
       assert d.c == nil
       assert d.c == nil
-      assert d.d == nil
+      assert d.d == :Default
       assert d.my_oneof == :b
       assert d.my_oneof == :b
 
 
       d.c = TestMessage2.new(:foo => 100)
       d.c = TestMessage2.new(:foo => 100)
-      assert d.a == nil
-      assert d.b == nil
+      assert d.a == ""
+      assert d.b == 0
       assert d.c.foo == 100
       assert d.c.foo == 100
-      assert d.d == nil
+      assert d.d == :Default
       assert d.my_oneof == :c
       assert d.my_oneof == :c
 
 
       d.d = :C
       d.d = :C
-      assert d.a == nil
-      assert d.b == nil
+      assert d.a == ""
+      assert d.b == 0
       assert d.c == nil
       assert d.c == nil
       assert d.d == :C
       assert d.d == :C
       assert d.my_oneof == :d
       assert d.my_oneof == :d
@@ -748,23 +748,23 @@ module BasicTest
 
 
       d3 = OneofMessage.decode(
       d3 = OneofMessage.decode(
         encoded_field_c + encoded_field_a + encoded_field_d)
         encoded_field_c + encoded_field_a + encoded_field_d)
-      assert d3.a == nil
-      assert d3.b == nil
+      assert d3.a == ""
+      assert d3.b == 0
       assert d3.c == nil
       assert d3.c == nil
       assert d3.d == :B
       assert d3.d == :B
 
 
       d4 = OneofMessage.decode(
       d4 = OneofMessage.decode(
         encoded_field_c + encoded_field_a + encoded_field_d +
         encoded_field_c + encoded_field_a + encoded_field_d +
         encoded_field_c)
         encoded_field_c)
-      assert d4.a == nil
-      assert d4.b == nil
+      assert d4.a == ""
+      assert d4.b == 0
       assert d4.c.foo == 1
       assert d4.c.foo == 1
-      assert d4.d == nil
+      assert d4.d == :Default
 
 
       d5 = OneofMessage.new(:a => "hello")
       d5 = OneofMessage.new(:a => "hello")
-      assert d5.a != nil
+      assert d5.a == "hello"
       d5.a = nil
       d5.a = nil
-      assert d5.a == nil
+      assert d5.a == ""
       assert OneofMessage.encode(d5) == ''
       assert OneofMessage.encode(d5) == ''
       assert d5.my_oneof == nil
       assert d5.my_oneof == nil
     end
     end