Explorar o código

Fix for construction of messages in the C++ Python implementation. (#2299)

* Fix construction of messages using the C++ Python implementation when a map field is passed from one message to another.

* Add a test on message map field construction

* python 3 support

* review comments

* add test

* Collapse code into one
Peter Ebden %!s(int64=6) %!d(string=hai) anos
pai
achega
2adde396ed

+ 18 - 0
python/google/protobuf/internal/message_test.py

@@ -2214,6 +2214,24 @@ class Proto3Test(BaseTestCase):
         map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
     self.assertEqual(5, msg.map_int32_foreign_message[3].c)
 
+  def testMapScalarFieldConstruction(self):
+    msg1 = map_unittest_pb2.TestMap()
+    msg1.map_int32_int32[1] = 42
+    msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32)
+    self.assertEqual(42, msg2.map_int32_int32[1])
+
+  def testMapMessageFieldConstruction(self):
+    msg1 = map_unittest_pb2.TestMap()
+    msg1.map_string_foreign_message['test'].c = 42
+    msg2 = map_unittest_pb2.TestMap(
+      map_string_foreign_message=msg1.map_string_foreign_message)
+    self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
+
+  def testMapFieldRaisesCorrectError(self):
+    # Should raise a TypeError when given a non-iterable.
+    with self.assertRaises(TypeError):
+      map_unittest_pb2.TestMap(map_string_foreign_message=1)
+
   def testMapValidAfterFieldCleared(self):
     # Map needs to work even if field is cleared.
     # For the C++ implementation this tests the correctness of

+ 11 - 8
python/google/protobuf/pyext/message.cc

@@ -1236,17 +1236,20 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
       const FieldDescriptor* value_descriptor =
           descriptor->message_type()->FindFieldByName("value");
       if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
-        Py_ssize_t map_pos = 0;
-        PyObject* map_key;
-        PyObject* map_value;
-        while (PyDict_Next(value, &map_pos, &map_key, &map_value)) {
-          ScopedPyObjectPtr function_return;
-          function_return.reset(PyObject_GetItem(map.get(), map_key));
-          if (function_return.get() == NULL) {
+        ScopedPyObjectPtr iter(PyObject_GetIter(value));
+        if (iter == NULL) {
+          PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name));
+          return -1;
+        }
+        ScopedPyObjectPtr next;
+        while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
+          ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
+          ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
+          if (source_value.get() == NULL || dest_value.get() == NULL) {
             return -1;
           }
           ScopedPyObjectPtr ok(PyObject_CallMethod(
-              function_return.get(), "MergeFrom", "O", map_value));
+              dest_value.get(), "MergeFrom", "O", source_value.get()));
           if (ok.get() == NULL) {
             return -1;
           }