Эх сурвалжийг харах

Add wrapper type helpers for Ruby (#5739)

* add wrapper type helpers

* add check for _as_value suffix
Joe Bolinger 6 жил өмнө
parent
commit
e4bbca1fc5

+ 1 - 0
ruby/.gitignore

@@ -6,3 +6,4 @@ protobuf-jruby.iml
 target/
 target/
 pkg/
 pkg/
 tmp/
 tmp/
+tests/google/

+ 5 - 0
ruby/Rakefile

@@ -93,6 +93,7 @@ genproto_output << "tests/test_ruby_package.rb"
 genproto_output << "tests/test_ruby_package_proto2.rb"
 genproto_output << "tests/test_ruby_package_proto2.rb"
 genproto_output << "tests/basic_test.rb"
 genproto_output << "tests/basic_test.rb"
 genproto_output << "tests/basic_test_proto2.rb"
 genproto_output << "tests/basic_test_proto2.rb"
+genproto_output << "tests/wrappers.rb"
 file "tests/generated_code.rb" => "tests/generated_code.proto" do |file_task|
 file "tests/generated_code.rb" => "tests/generated_code.proto" do |file_task|
   sh "../src/protoc --ruby_out=. tests/generated_code.proto"
   sh "../src/protoc --ruby_out=. tests/generated_code.proto"
 end
 end
@@ -125,6 +126,10 @@ file "tests/basic_test_proto2.rb" => "tests/basic_test_proto2.proto" do |file_ta
   sh "../src/protoc -I../src -I. --ruby_out=. tests/basic_test_proto2.proto"
   sh "../src/protoc -I../src -I. --ruby_out=. tests/basic_test_proto2.proto"
 end
 end
 
 
+file "tests/wrappers.rb" => "../src/google/protobuf/wrappers.proto" do |file_task|
+  sh "../src/protoc -I../src -I. --ruby_out=tests ../src/google/protobuf/wrappers.proto"
+end
+
 task :genproto => genproto_output
 task :genproto => genproto_output
 
 
 task :clean do
 task :clean do

+ 68 - 2
ruby/ext/google/protobuf_c/message.c

@@ -119,9 +119,37 @@ enum {
   METHOD_SETTER = 2,
   METHOD_SETTER = 2,
   METHOD_CLEAR = 3,
   METHOD_CLEAR = 3,
   METHOD_PRESENCE = 4,
   METHOD_PRESENCE = 4,
-  METHOD_ENUM_GETTER = 5
+  METHOD_ENUM_GETTER = 5,
+  METHOD_WRAPPER_GETTER = 6,
+  METHOD_WRAPPER_SETTER = 7
 };
 };
 
 
+// Check if the field is a well known wrapper type
+static bool is_wrapper_type_field(const upb_fielddef* field) {
+  char* field_type_name = rb_class2name(field_type_class(field));
+
+  return strcmp(field_type_name, "Google::Protobuf::DoubleValue") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::FloatValue") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::Int32Value") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::Int64Value") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::UInt32Value") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::UInt64Value") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::BoolValue") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::StringValue") == 0 ||
+         strcmp(field_type_name, "Google::Protobuf::BytesValue") == 0;
+}
+
+// Get a new Ruby wrapper type and set the initial value
+static VALUE ruby_wrapper_type(const upb_fielddef* field, const VALUE* value) {
+  if (is_wrapper_type_field(field) && value != Qnil) {
+    VALUE hash = rb_hash_new();
+    rb_hash_aset(hash, rb_str_new2("value"), value);
+    VALUE args[1] = { hash };
+    return rb_class_new_instance(1, args, field_type_class(field));
+  }
+  return Qnil;
+}
+
 static int extract_method_call(VALUE method_name, MessageHeader* self,
 static int extract_method_call(VALUE method_name, MessageHeader* self,
 			       const upb_fielddef **f, const upb_oneofdef **o) {
 			       const upb_fielddef **f, const upb_oneofdef **o) {
   Check_Type(method_name, T_SYMBOL);
   Check_Type(method_name, T_SYMBOL);
@@ -157,6 +185,34 @@ static int extract_method_call(VALUE method_name, MessageHeader* self,
   bool has_field = upb_msgdef_lookupname(self->descriptor->msgdef, name, name_len,
   bool has_field = upb_msgdef_lookupname(self->descriptor->msgdef, name, name_len,
 			                                   &test_f, &test_o);
 			                                   &test_f, &test_o);
 
 
+  // Look for wrapper type accessor of the form <field_name>_as_value
+  if (!has_field &&
+      (accessor_type == METHOD_GETTER || accessor_type == METHOD_SETTER) &&
+      name_len > 9 && strncmp(name + name_len - 9, "_as_value", 9) == 0) {
+    // Find the field name
+    char wrapper_field_name[name_len - 8];
+    strncpy(wrapper_field_name, name, name_len - 9);
+    wrapper_field_name[name_len - 7] = '\0';
+
+    // Check if field exists and is a wrapper type
+    const upb_oneofdef* test_o_wrapper;
+    const upb_fielddef* test_f_wrapper;
+    if (upb_msgdef_lookupname(self->descriptor->msgdef, wrapper_field_name, name_len - 9,
+			                        &test_f_wrapper, &test_o_wrapper) &&
+        upb_fielddef_type(test_f_wrapper) == UPB_TYPE_MESSAGE &&
+        is_wrapper_type_field(test_f_wrapper)) {
+      // It does exist!
+      has_field = true;
+      if (accessor_type == METHOD_SETTER) {
+        accessor_type = METHOD_WRAPPER_SETTER;
+      } else {
+        accessor_type = METHOD_WRAPPER_GETTER;
+      }
+      test_o = test_o_wrapper;
+      test_f = test_f_wrapper;
+    }
+  }
+
   // Look for enum accessor of the form <enum_name>_const
   // Look for enum accessor of the form <enum_name>_const
   if (!has_field && accessor_type == METHOD_GETTER &&
   if (!has_field && accessor_type == METHOD_GETTER &&
       name_len > 6 && strncmp(name + name_len - 6, "_const", 6) == 0) {
       name_len > 6 && strncmp(name + name_len - 6, "_const", 6) == 0) {
@@ -238,7 +294,7 @@ VALUE Message_method_missing(int argc, VALUE* argv, VALUE _self) {
   int accessor_type = extract_method_call(argv[0], self, &f, &o);
   int accessor_type = extract_method_call(argv[0], self, &f, &o);
   if (accessor_type == METHOD_UNKNOWN || (o == NULL && f == NULL) ) {
   if (accessor_type == METHOD_UNKNOWN || (o == NULL && f == NULL) ) {
     return rb_call_super(argc, argv);
     return rb_call_super(argc, argv);
-  } else if (accessor_type == METHOD_SETTER) {
+  } else if (accessor_type == METHOD_SETTER || accessor_type == METHOD_WRAPPER_SETTER) {
     if (argc != 2) {
     if (argc != 2) {
       rb_raise(rb_eArgError, "Expected 2 arguments, received %d", argc);
       rb_raise(rb_eArgError, "Expected 2 arguments, received %d", argc);
     }
     }
@@ -275,6 +331,16 @@ VALUE Message_method_missing(int argc, VALUE* argv, VALUE _self) {
     return Qnil;
     return Qnil;
   } else if (accessor_type == METHOD_PRESENCE) {
   } else if (accessor_type == METHOD_PRESENCE) {
     return layout_has(self->descriptor->layout, Message_data(self), f);
     return layout_has(self->descriptor->layout, Message_data(self), f);
+  } else if (accessor_type == METHOD_WRAPPER_GETTER) {
+    VALUE value = layout_get(self->descriptor->layout, Message_data(self), f);
+    if (value != Qnil) {
+      value = rb_funcall(value, rb_intern("value"), 0);
+    }
+    return value;
+  } else if (accessor_type == METHOD_WRAPPER_SETTER) {
+    VALUE wrapper = ruby_wrapper_type(f, argv[1]);
+    layout_set(self->descriptor->layout, Message_data(self), f, wrapper);
+    return Qnil;
   } else if (accessor_type == METHOD_ENUM_GETTER) {
   } else if (accessor_type == METHOD_ENUM_GETTER) {
     VALUE enum_type = field_type_class(f);
     VALUE enum_type = field_type_class(f);
     VALUE method = rb_intern("const_get");
     VALUE method = rb_intern("const_get");

+ 17 - 0
ruby/tests/basic_test.proto

@@ -2,6 +2,7 @@ syntax = "proto3";
 
 
 package basic_test;
 package basic_test;
 
 
+import "google/protobuf/wrappers.proto";
 import "google/protobuf/timestamp.proto";
 import "google/protobuf/timestamp.proto";
 import "google/protobuf/duration.proto";
 import "google/protobuf/duration.proto";
 import "google/protobuf/struct.proto";
 import "google/protobuf/struct.proto";
@@ -112,6 +113,22 @@ message Outer {
 message Inner {
 message Inner {
 }
 }
 
 
+message Wrapper {
+  google.protobuf.DoubleValue double = 1;
+  google.protobuf.FloatValue float = 2;
+  google.protobuf.Int32Value int32 = 3;
+  google.protobuf.Int64Value int64 = 4;
+  google.protobuf.UInt32Value uint32 = 5;
+  google.protobuf.UInt64Value uint64 = 6;
+  google.protobuf.BoolValue bool = 7;
+  google.protobuf.StringValue string = 8;
+  google.protobuf.BytesValue bytes = 9;
+  string real_string = 100;
+  oneof a_oneof {
+    string oneof_string = 10;
+  }
+}
+
 message TimeMessage {
 message TimeMessage {
   google.protobuf.Timestamp timestamp = 1;
   google.protobuf.Timestamp timestamp = 1;
   google.protobuf.Duration duration = 2;
   google.protobuf.Duration duration = 2;

+ 17 - 0
ruby/tests/basic_test_proto2.proto

@@ -2,6 +2,7 @@ syntax = "proto2";
 
 
 package basic_test_proto2;
 package basic_test_proto2;
 
 
+import "google/protobuf/wrappers.proto";
 import "google/protobuf/timestamp.proto";
 import "google/protobuf/timestamp.proto";
 import "google/protobuf/duration.proto";
 import "google/protobuf/duration.proto";
 import "google/protobuf/struct.proto";
 import "google/protobuf/struct.proto";
@@ -120,6 +121,22 @@ message OneofMessage {
   }
   }
 }
 }
 
 
+message Wrapper {
+  optional google.protobuf.DoubleValue double = 1;
+  optional google.protobuf.FloatValue float = 2;
+  optional google.protobuf.Int32Value int32 = 3;
+  optional google.protobuf.Int64Value int64 = 4;
+  optional google.protobuf.UInt32Value uint32 = 5;
+  optional google.protobuf.UInt64Value uint64 = 6;
+  optional google.protobuf.BoolValue bool = 7;
+  optional google.protobuf.StringValue string = 8;
+  optional google.protobuf.BytesValue bytes = 9;
+  optional string real_string = 100;
+  oneof a_oneof {
+    string oneof_string = 10;
+  }
+}
+
 message TimeMessage {
 message TimeMessage {
   optional google.protobuf.Timestamp timestamp = 1;
   optional google.protobuf.Timestamp timestamp = 1;
   optional google.protobuf.Duration duration = 2;
   optional google.protobuf.Duration duration = 2;

+ 181 - 0
ruby/tests/common_tests.rb

@@ -1,3 +1,5 @@
+require 'google/protobuf/wrappers_pb.rb'
+
 # Defines tests which are common between proto2 and proto3 syntax.
 # Defines tests which are common between proto2 and proto3 syntax.
 #
 #
 # Requires that the proto messages are exactly the same in proto2 and proto3 syntax
 # Requires that the proto messages are exactly the same in proto2 and proto3 syntax
@@ -1267,6 +1269,185 @@ module CommonTests
     assert proto_module::TestMessage.new != nil
     assert proto_module::TestMessage.new != nil
   end
   end
 
 
+  def test_wrapper_getters
+    m = proto_module::Wrapper.new(
+      double: Google::Protobuf::DoubleValue.new(value: 2.0),
+      float: Google::Protobuf::FloatValue.new(value: 4.0),
+      int32: Google::Protobuf::Int32Value.new(value: 3),
+      int64: Google::Protobuf::Int64Value.new(value: 4),
+      uint32: Google::Protobuf::UInt32Value.new(value: 5),
+      uint64: Google::Protobuf::UInt64Value.new(value: 6),
+      bool: Google::Protobuf::BoolValue.new(value: true),
+      string: Google::Protobuf::StringValue.new(value: 'str'),
+      bytes: Google::Protobuf::BytesValue.new(value: 'fun'),
+      real_string: '100'
+    )
+
+    assert_equal 2.0, m.double_as_value
+    assert_equal 2.0, m.double.value
+    assert_equal 4.0, m.float_as_value
+    assert_equal 4.0, m.float.value
+    assert_equal 3, m.int32_as_value
+    assert_equal 3, m.int32.value
+    assert_equal 4, m.int64_as_value
+    assert_equal 4, m.int64.value
+    assert_equal 5, m.uint32_as_value
+    assert_equal 5, m.uint32.value
+    assert_equal 6, m.uint64_as_value
+    assert_equal 6, m.uint64.value
+    assert_equal true, m.bool_as_value
+    assert_equal true, m.bool.value
+    assert_equal 'str', m.string_as_value
+    assert_equal 'str', m.string.value
+    assert_equal 'fun', m.bytes_as_value
+    assert_equal 'fun', m.bytes.value
+  end
+
+  def test_wrapper_setters_as_value
+    m = proto_module::Wrapper.new
+
+    m.double_as_value = 4.8
+    assert_equal 4.8, m.double_as_value
+    assert_equal Google::Protobuf::DoubleValue.new(value: 4.8), m.double
+    m.float_as_value = 2.4
+    assert_in_delta 2.4, m.float_as_value
+    assert_in_delta Google::Protobuf::FloatValue.new(value: 2.4).value, m.float.value
+    m.int32_as_value = 5
+    assert_equal 5, m.int32_as_value
+    assert_equal Google::Protobuf::Int32Value.new(value: 5), m.int32
+    m.int64_as_value = 15
+    assert_equal 15, m.int64_as_value
+    assert_equal Google::Protobuf::Int64Value.new(value: 15), m.int64
+    m.uint32_as_value = 50
+    assert_equal 50, m.uint32_as_value
+    assert_equal Google::Protobuf::UInt32Value.new(value: 50), m.uint32
+    m.uint64_as_value = 500
+    assert_equal 500, m.uint64_as_value
+    assert_equal Google::Protobuf::UInt64Value.new(value: 500), m.uint64
+    m.bool_as_value = false
+    assert_equal false, m.bool_as_value
+    assert_equal Google::Protobuf::BoolValue.new(value: false), m.bool
+    m.string_as_value = 'xy'
+    assert_equal 'xy', m.string_as_value
+    assert_equal Google::Protobuf::StringValue.new(value: 'xy'), m.string
+    m.bytes_as_value = '123'
+    assert_equal '123', m.bytes_as_value
+    assert_equal Google::Protobuf::BytesValue.new(value: '123'), m.bytes
+
+    m.double_as_value = nil
+    assert_nil m.double
+    assert_nil m.double_as_value
+    m.float_as_value = nil
+    assert_nil m.float
+    assert_nil m.float_as_value
+    m.int32_as_value = nil
+    assert_nil m.int32
+    assert_nil m.int32_as_value
+    m.int64_as_value = nil
+    assert_nil m.int64
+    assert_nil m.int64_as_value
+    m.uint32_as_value = nil
+    assert_nil m.uint32
+    assert_nil m.uint32_as_value
+    m.uint64_as_value = nil
+    assert_nil m.uint64
+    assert_nil m.uint64_as_value
+    m.bool_as_value = nil
+    assert_nil m.bool
+    assert_nil m.bool_as_value
+    m.string_as_value = nil
+    assert_nil m.string
+    assert_nil m.string_as_value
+    m.bytes_as_value = nil
+    assert_nil m.bytes
+    assert_nil m.bytes_as_value
+  end
+
+  def test_wrapper_setters
+    m = proto_module::Wrapper.new
+
+    m.double = Google::Protobuf::DoubleValue.new(value: 4.8)
+    assert_equal 4.8, m.double_as_value
+    assert_equal Google::Protobuf::DoubleValue.new(value: 4.8), m.double
+    m.float = Google::Protobuf::FloatValue.new(value: 2.4)
+    assert_in_delta 2.4, m.float_as_value
+    assert_in_delta Google::Protobuf::FloatValue.new(value: 2.4).value, m.float.value
+    m.int32 = Google::Protobuf::Int32Value.new(value: 5)
+    assert_equal 5, m.int32_as_value
+    assert_equal Google::Protobuf::Int32Value.new(value: 5), m.int32
+    m.int64 = Google::Protobuf::Int64Value.new(value: 15)
+    assert_equal 15, m.int64_as_value
+    assert_equal Google::Protobuf::Int64Value.new(value: 15), m.int64
+    m.uint32 = Google::Protobuf::UInt32Value.new(value: 50)
+    assert_equal 50, m.uint32_as_value
+    assert_equal Google::Protobuf::UInt32Value.new(value: 50), m.uint32
+    m.uint64 = Google::Protobuf::UInt64Value.new(value: 500)
+    assert_equal 500, m.uint64_as_value
+    assert_equal Google::Protobuf::UInt64Value.new(value: 500), m.uint64
+    m.bool = Google::Protobuf::BoolValue.new(value: false)
+    assert_equal false, m.bool_as_value
+    assert_equal Google::Protobuf::BoolValue.new(value: false), m.bool
+    m.string = Google::Protobuf::StringValue.new(value: 'xy')
+    assert_equal 'xy', m.string_as_value
+    assert_equal Google::Protobuf::StringValue.new(value: 'xy'), m.string
+    m.bytes = Google::Protobuf::BytesValue.new(value: '123')
+    assert_equal '123', m.bytes_as_value
+    assert_equal Google::Protobuf::BytesValue.new(value: '123'), m.bytes
+
+    m.double = nil
+    assert_nil m.double
+    assert_nil m.double_as_value
+    m.float = nil
+    assert_nil m.float
+    assert_nil m.float_as_value
+    m.int32 = nil
+    assert_nil m.int32
+    assert_nil m.int32_as_value
+    m.int64 = nil
+    assert_nil m.int64
+    assert_nil m.int64_as_value
+    m.uint32 = nil
+    assert_nil m.uint32
+    assert_nil m.uint32_as_value
+    m.uint64 = nil
+    assert_nil m.uint64
+    assert_nil m.uint64_as_value
+    m.bool = nil
+    assert_nil m.bool
+    assert_nil m.bool_as_value
+    m.string = nil
+    assert_nil m.string
+    assert_nil m.string_as_value
+    m.bytes = nil
+    assert_nil m.bytes
+    assert_nil m.bytes_as_value
+  end
+
+  def test_wrappers_only
+    m = proto_module::Wrapper.new(real_string: 'hi', oneof_string: 'there')
+
+    assert_raise(NoMethodError) { m.real_string_as_value }
+    assert_raise(NoMethodError) { m.as_value }
+    assert_raise(NoMethodError) { m._as_value }
+    assert_raise(NoMethodError) { m.oneof_string_as_value }
+
+    m = proto_module::Wrapper.new
+    m.string_as_value = 'you'
+    assert_equal 'you', m.string.value
+    assert_equal 'you', m.string_as_value
+    assert_raise(NoMethodError) { m.string_ }
+    assert_raise(NoMethodError) { m.string_X }
+    assert_raise(NoMethodError) { m.string_XX }
+    assert_raise(NoMethodError) { m.string_XXX }
+    assert_raise(NoMethodError) { m.string_XXXX }
+    assert_raise(NoMethodError) { m.string_XXXXX }
+    assert_raise(NoMethodError) { m.string_XXXXXX }
+    assert_raise(NoMethodError) { m.string_XXXXXXX }
+    assert_raise(NoMethodError) { m.string_XXXXXXXX }
+    assert_raise(NoMethodError) { m.string_XXXXXXXXX }
+    assert_raise(NoMethodError) { m.string_XXXXXXXXXX }
+  end
+  
   def test_converts_time
   def test_converts_time
     m = proto_module::TimeMessage.new
     m = proto_module::TimeMessage.new