Răsfoiți Sursa

cherrypick descriptor_pool.FindFileContainingSymbol by extensions (#2962)

* Use PyUnicode_AsEncodedString() instead of PyUnicode_AsEncodedObject()

* Cherrypick the fix descriptor_pool.FindFileContainingSymbol by extensions.
Jie Luo 8 ani în urmă
părinte
comite
899460c9cb

+ 31 - 5
python/google/protobuf/descriptor_pool.py

@@ -127,6 +127,9 @@ class DescriptorPool(object):
     self._service_descriptors = {}
     self._service_descriptors = {}
     self._file_descriptors = {}
     self._file_descriptors = {}
     self._toplevel_extensions = {}
     self._toplevel_extensions = {}
+    # TODO(jieluo): Remove _file_desc_by_toplevel_extension when
+    # FieldDescriptor.file is added in code gen.
+    self._file_desc_by_toplevel_extension = {}
     # We store extensions in two two-level mappings: The first key is the
     # We store extensions in two two-level mappings: The first key is the
     # descriptor of the message being extended, the second key is the extension
     # descriptor of the message being extended, the second key is the extension
     # full name or its tag number.
     # full name or its tag number.
@@ -170,7 +173,7 @@ class DescriptorPool(object):
       raise TypeError('Expected instance of descriptor.Descriptor.')
       raise TypeError('Expected instance of descriptor.Descriptor.')
 
 
     self._descriptors[desc.full_name] = desc
     self._descriptors[desc.full_name] = desc
-    self.AddFileDescriptor(desc.file)
+    self._AddFileDescriptor(desc.file)
 
 
   def AddEnumDescriptor(self, enum_desc):
   def AddEnumDescriptor(self, enum_desc):
     """Adds an EnumDescriptor to the pool.
     """Adds an EnumDescriptor to the pool.
@@ -185,7 +188,7 @@ class DescriptorPool(object):
       raise TypeError('Expected instance of descriptor.EnumDescriptor.')
       raise TypeError('Expected instance of descriptor.EnumDescriptor.')
 
 
     self._enum_descriptors[enum_desc.full_name] = enum_desc
     self._enum_descriptors[enum_desc.full_name] = enum_desc
-    self.AddFileDescriptor(enum_desc.file)
+    self._AddFileDescriptor(enum_desc.file)
 
 
   def AddServiceDescriptor(self, service_desc):
   def AddServiceDescriptor(self, service_desc):
     """Adds a ServiceDescriptor to the pool.
     """Adds a ServiceDescriptor to the pool.
@@ -251,6 +254,23 @@ class DescriptorPool(object):
       file_desc: A FileDescriptor.
       file_desc: A FileDescriptor.
     """
     """
 
 
+    self._AddFileDescriptor(file_desc)
+    # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
+    # Remove it when FieldDescriptor.file is added in code gen.
+    for extension in file_desc.extensions_by_name.itervalues():
+      self._file_desc_by_toplevel_extension[
+          extension.full_name] = file_desc
+
+  def _AddFileDescriptor(self, file_desc):
+    """Adds a FileDescriptor to the pool, non-recursively.
+
+    If the FileDescriptor contains messages or enums, the caller must explicitly
+    register them.
+
+    Args:
+      file_desc: A FileDescriptor.
+    """
+
     if not isinstance(file_desc, descriptor.FileDescriptor):
     if not isinstance(file_desc, descriptor.FileDescriptor):
       raise TypeError('Expected instance of descriptor.FileDescriptor.')
       raise TypeError('Expected instance of descriptor.FileDescriptor.')
     self._file_descriptors[file_desc.name] = file_desc
     self._file_descriptors[file_desc.name] = file_desc
@@ -313,12 +333,18 @@ class DescriptorPool(object):
     except KeyError:
     except KeyError:
       pass
       pass
 
 
+    try:
+      return self._file_desc_by_toplevel_extension[symbol]
+    except KeyError:
+      pass
+
     # Try nested extensions inside a message.
     # Try nested extensions inside a message.
     message_name, _, extension_name = symbol.rpartition('.')
     message_name, _, extension_name = symbol.rpartition('.')
     try:
     try:
-      scope = self.FindMessageTypeByName(message_name)
-      assert scope.extensions_by_name[extension_name]
-      return scope.file
+      message = self.FindMessageTypeByName(message_name)
+      assert message.extensions_by_name[extension_name]
+      return message.file
+
     except KeyError:
     except KeyError:
       raise KeyError('Cannot find a file containing %s' % symbol)
       raise KeyError('Cannot find a file containing %s' % symbol)
 
 

+ 9 - 0
python/google/protobuf/internal/descriptor_pool_test.py

@@ -63,6 +63,9 @@ from google.protobuf import symbol_database
 class DescriptorPoolTest(unittest.TestCase):
 class DescriptorPoolTest(unittest.TestCase):
 
 
   def setUp(self):
   def setUp(self):
+    # TODO(jieluo): Make the pool created from serialized_pb behave the
+    # same as the generated pool.
+    # TODO(jieluo): More test coverage for the generated pool.
     self.pool = descriptor_pool.DescriptorPool()
     self.pool = descriptor_pool.DescriptorPool()
     self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
     self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
         factory_test1_pb2.DESCRIPTOR.serialized_pb)
         factory_test1_pb2.DESCRIPTOR.serialized_pb)
@@ -128,6 +131,12 @@ class DescriptorPoolTest(unittest.TestCase):
     self.assertEqual('google/protobuf/internal/factory_test2.proto',
     self.assertEqual('google/protobuf/internal/factory_test2.proto',
                      file_desc4.name)
                      file_desc4.name)
 
 
+    # Tests the generated pool.
+    assert descriptor_pool.Default().FindFileContainingSymbol(
+        'google.protobuf.python.internal.Factory2Message.one_more_field')
+    assert descriptor_pool.Default().FindFileContainingSymbol(
+        'google.protobuf.python.internal.another_field')
+
   def testFindFileContainingSymbolFailure(self):
   def testFindFileContainingSymbolFailure(self):
     with self.assertRaises(KeyError):
     with self.assertRaises(KeyError):
       self.pool.FindFileContainingSymbol('Does not exist')
       self.pool.FindFileContainingSymbol('Does not exist')

+ 1 - 1
python/google/protobuf/pyext/message.cc

@@ -779,7 +779,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
       encoded_string = arg;  // Already encoded.
       encoded_string = arg;  // Already encoded.
       Py_INCREF(encoded_string);
       Py_INCREF(encoded_string);
     } else {
     } else {
-      encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
+      encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
     }
     }
   } else {
   } else {
     // In this case field type is "bytes".
     // In this case field type is "bytes".

+ 4 - 2
src/google/protobuf/compiler/python/python_generator.cc

@@ -445,8 +445,6 @@ void Generator::PrintFileDescriptor() const {
 
 
   printer_->Outdent();
   printer_->Outdent();
   printer_->Print(")\n");
   printer_->Print(")\n");
-  printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
-                  kDescriptorKey);
   printer_->Print("\n");
   printer_->Print("\n");
 }
 }
 
 
@@ -999,6 +997,10 @@ void Generator::FixForeignFieldsInDescriptors() const {
   for (int i = 0; i < file_->extension_count(); ++i) {
   for (int i = 0; i < file_->extension_count(); ++i) {
     AddExtensionToFileDescriptor(*file_->extension(i));
     AddExtensionToFileDescriptor(*file_->extension(i));
   }
   }
+  // TODO(jieluo): Move this registration to PrintFileDescriptor() once
+  // FieldDescriptor.file is added to the generated file.
+  printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
+                  kDescriptorKey);
   printer_->Print("\n");
   printer_->Print("\n");
 }
 }