Преглед на файлове

Merge pull request #1371 from keveman/oversize_protos

Added an API to allow oversize protos when using C++ extension in Python
Joshua Haberman преди 9 години
родител
ревизия
81eb84c029
променени са 2 файла, в които са добавени 94 реда и са изтрити 11 реда
  1. 59 0
      python/google/protobuf/internal/message_test.py
  2. 35 11
      python/google/protobuf/pyext/message.cc

+ 59 - 0
python/google/protobuf/internal/message_test.py

@@ -57,7 +57,11 @@ try:
 except ImportError:
   import unittest
 from google.protobuf.internal import _parameterized
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor_pool
 from google.protobuf import map_unittest_pb2
+from google.protobuf import message_factory
+from google.protobuf import text_format
 from google.protobuf import unittest_pb2
 from google.protobuf import unittest_proto3_arena_pb2
 from google.protobuf.internal import any_test_pb2
@@ -1776,5 +1780,60 @@ class PackedFieldTest(unittest.TestCase):
                    b'\x70\x01')
     self.assertEqual(golden_data, message.SerializeToString())
 
+
+@unittest.skipIf(api_implementation.Type() != 'cpp',
+                 'explicit tests of the C++ implementation')
+class OversizeProtosTest(unittest.TestCase):
+
+  def setUp(self):
+    self.file_desc = """
+      name: "f/f.msg2"
+      package: "f"
+      message_type {
+        name: "msg1"
+        field {
+          name: "payload"
+          number: 1
+          label: LABEL_OPTIONAL
+          type: TYPE_STRING
+        }
+      }
+      message_type {
+        name: "msg2"
+        field {
+          name: "field"
+          number: 1
+          label: LABEL_OPTIONAL
+          type: TYPE_MESSAGE
+          type_name: "msg1"
+        }
+      }
+    """
+    pool = descriptor_pool.DescriptorPool()
+    desc = descriptor_pb2.FileDescriptorProto()
+    text_format.Parse(self.file_desc, desc)
+    pool.Add(desc)
+    self.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
+        pool.FindMessageTypeByName('f.msg2'))
+    self.p = self.proto_cls()
+    self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
+    self.p_serialized = self.p.SerializeToString()
+
+  def testAssertOversizeProto(self):
+    from google.protobuf.pyext._message import SetAllowOversizeProtos
+    SetAllowOversizeProtos(False)
+    q = self.proto_cls()
+    try:
+      q.ParseFromString(self.p_serialized)
+    except message.DecodeError as e:
+      self.assertEqual(str(e), 'Error parsing message')
+
+  def testSucceedOversizeProto(self):
+    from google.protobuf.pyext._message import SetAllowOversizeProtos
+    SetAllowOversizeProtos(True)
+    q = self.proto_cls()
+    q.ParseFromString(self.p_serialized)
+    self.assertEqual(self.p.field.payload, q.field.payload)
+
 if __name__ == '__main__':
   unittest.main()

+ 35 - 11
python/google/protobuf/pyext/message.cc

@@ -1911,6 +1911,30 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
   Py_RETURN_NONE;
 }
 
+// Protobuf has a 64MB limit built in, this variable will override this. Please
+// do not enable this unless you fully understand the implications: protobufs
+// must all be kept in memory at the same time, so if they grow too big you may
+// get OOM errors. The protobuf APIs do not provide any tools for processing
+// protobufs in chunks.  If you have protos this big you should break them up if
+// it is at all convenient to do so.
+static bool allow_oversize_protos = false;
+
+// Provide a method in the module to set allow_oversize_protos to a boolean
+// value. This method returns the newly value of allow_oversize_protos.
+static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
+  if (!arg || !PyBool_Check(arg)) {
+    PyErr_SetString(PyExc_TypeError,
+                    "Argument to SetAllowOversizeProtos must be boolean");
+    return NULL;
+  }
+  allow_oversize_protos = PyObject_IsTrue(arg);
+  if (allow_oversize_protos) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+}
+
 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
   const void* data;
   Py_ssize_t data_length;
@@ -1921,15 +1945,9 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
   AssureWritable(self);
   io::CodedInputStream input(
       reinterpret_cast<const uint8*>(data), data_length);
-#if PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
-  // Protobuf has a 64MB limit built in, this code will override this. Please do
-  // not enable this unless you fully understand the implications: protobufs
-  // must all be kept in memory at the same time, so if they grow too big you
-  // may get OOM errors. The protobuf APIs do not provide any tools for
-  // processing protobufs in chunks.  If you have protos this big you should
-  // break them up if it is at all convenient to do so.
-  input.SetTotalBytesLimit(INT_MAX, INT_MAX);
-#endif  // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
+  if (allow_oversize_protos) {
+    input.SetTotalBytesLimit(INT_MAX, INT_MAX);
+  }
   PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
   input.SetExtensionRegistry(pool->pool, pool->message_factory);
   bool success = self->message->MergePartialFromCodedStream(&input);
@@ -3046,6 +3064,11 @@ bool InitProto2MessageModule(PyObject *m) {
 }  // namespace python
 }  // namespace protobuf
 
+static PyMethodDef ModuleMethods[] = {
+    {"SetAllowOversizeProtos",
+     (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
+     METH_O, "Enable/disable oversize proto parsing."},
+};
 
 #if PY_MAJOR_VERSION >= 3
 static struct PyModuleDef _module = {
@@ -3053,7 +3076,7 @@ static struct PyModuleDef _module = {
   "_message",
   google::protobuf::python::module_docstring,
   -1,
-  NULL,
+  ModuleMethods,  /* m_methods */
   NULL,
   NULL,
   NULL,
@@ -3072,7 +3095,8 @@ extern "C" {
 #if PY_MAJOR_VERSION >= 3
     m = PyModule_Create(&_module);
 #else
-    m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring);
+    m = Py_InitModule3("_message", ModuleMethods,
+                       google::protobuf::python::module_docstring);
 #endif
     if (m == NULL) {
       return INITFUNC_ERRORVAL;