reflection_test.py 74 KB


  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Protocol Buffers - Google's data interchange format
  5. # Copyright 2008 Google Inc. All rights reserved.
  6. # http://code.google.com/p/protobuf/
  7. #
  8. # Redistribution and use in source and binary forms, with or without
  9. # modification, are permitted provided that the following conditions are
  10. # met:
  11. #
  12. # * Redistributions of source code must retain the above copyright
  13. # notice, this list of conditions and the following disclaimer.
  14. # * Redistributions in binary form must reproduce the above
  15. # copyright notice, this list of conditions and the following disclaimer
  16. # in the documentation and/or other materials provided with the
  17. # distribution.
  18. # * Neither the name of Google Inc. nor the names of its
  19. # contributors may be used to endorse or promote products derived from
  20. # this software without specific prior written permission.
  21. #
  22. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  23. # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  24. # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  25. # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  26. # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  27. # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  28. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  29. # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  30. # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  31. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  32. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  33. """Unittest for reflection.py, which also indirectly tests the output of the
  34. pure-Python protocol compiler.
  35. """
  36. __author__ = 'robinson@google.com (Will Robinson)'
  37. import operator
  38. import unittest
  39. # TODO(robinson): When we split this test in two, only some of these imports
  40. # will be necessary in each test.
  41. from google.protobuf import unittest_import_pb2
  42. from google.protobuf import unittest_mset_pb2
  43. from google.protobuf import unittest_pb2
  44. from google.protobuf import descriptor_pb2
  45. from google.protobuf import descriptor
  46. from google.protobuf import message
  47. from google.protobuf import reflection
  48. from google.protobuf.internal import more_extensions_pb2
  49. from google.protobuf.internal import more_messages_pb2
  50. from google.protobuf.internal import wire_format
  51. from google.protobuf.internal import test_util
  52. from google.protobuf.internal import decoder
  53. class ReflectionTest(unittest.TestCase):
  def assertIs(self, values, others):
    """Asserts two sequences hold the very same objects, position by position.

    NOTE: this shadows unittest.TestCase.assertIs (Python 2.7+), which instead
    compares two single objects for identity.  Every call visible in this
    class passes a pair of sequences.

    Args:
      values: Sequence of expected objects; must support len() and iteration.
      others: Sequence to compare against (e.g. a repeated message container).
    """
    self.assertEqual(len(values), len(others))
    for index, (expected, actual) in enumerate(zip(values, others)):
      # Build the (repr-heavy) failure message only when a mismatch occurs.
      if expected is not actual:
        self.fail('element %d: %r is not %r' % (index, expected, actual))
  def testSimpleHasBits(self):
    """A scalar's has-bit follows explicit sets and clears, not reads."""
    message_obj = unittest_pb2.TestAllTypes()
    field = 'optional_int32'
    self.assertFalse(message_obj.HasField(field))
    self.assertEqual(0, message_obj.optional_int32)
    # Merely reading the default value must not mark the field as present.
    self.assertFalse(message_obj.HasField(field))
    # An explicit assignment does mark it present...
    message_obj.optional_int32 = 1
    self.assertTrue(message_obj.HasField(field))
    # ...and ClearField() unmarks it again.
    message_obj.ClearField(field)
    self.assertFalse(message_obj.HasField(field))
  def testHasBitsWithSinglyNestedScalar(self):
    """Has-bits propagate from a nested scalar to its composite parent."""

    def CheckCompositeHasBits(composite_field_name, scalar_field_name):
      """Exercises has-bits for one singly nested scalar.

      composite_field_name names a non-repeated composite (foreign message or
      group) field of TestAllTypes; scalar_field_name names an integer-valued
      scalar field inside that composite.  Semantically this checks:

        proto.composite.scalar == 0, with neither has-bit set;
        after proto.composite.scalar = 20, both has-bits are set;
        after proto.ClearField('composite'), both has-bits are clear and the
        scalar reads 0 again;
        the old composite object is disconnected, so writing to it leaves
        proto untouched.
      """
      proto = unittest_pb2.TestAllTypes()
      # Reading the default must not flip either has-bit.
      child = getattr(proto, composite_field_name)
      self.assertEqual(0, getattr(child, scalar_field_name))
      self.assertFalse(child.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))

      # Writing the scalar sets the has-bit on the child and on its parent.
      value = 20
      setattr(child, scalar_field_name, value)
      self.assertEqual(value, getattr(child, scalar_field_name))
      detached_child = child
      self.assertTrue(child.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Clearing the composite resets both has-bits and the scalar default.
      proto.ClearField(composite_field_name)
      child = getattr(proto, composite_field_name)
      self.assertFalse(child.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(child, scalar_field_name))

      # ClearField() does not recursively clear the old child; it disconnects
      # it from the tree, so mutating that old object must not affect proto.
      self.assertTrue(detached_child is not child)
      setattr(detached_child, scalar_field_name, value)
      self.assertFalse(child.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(child, scalar_field_name))

    # Simple, single-level nesting for each kind of singular composite.
    for composite, scalar in (('optionalgroup', 'a'),
                              ('optional_nested_message', 'bb'),
                              ('optional_foreign_message', 'c'),
                              ('optional_import_message', 'd')):
      CheckCompositeHasBits(composite, scalar)
  def testReferencesToNestedMessage(self):
    """A child message stays writable after its parent has been released."""
    parent = unittest_pb2.TestAllTypes()
    child = parent.optional_nested_message
    # Drop the only reference to the parent.  An earlier version raised an
    # exception on the write below when it hit a now-dead weak reference.
    del parent
    child.bb = 23
  def testDisconnectingNestedMessageBeforeSettingField(self):
    """A child fetched before ClearField() no longer feeds its parent."""
    parent = unittest_pb2.TestAllTypes()
    stale_child = parent.optional_nested_message
    # Clearing the field should disconnect the previously returned child.
    parent.ClearField('optional_nested_message')
    self.assertTrue(stale_child is not parent.optional_nested_message)
    # Writes through the stale handle must not reach the parent.
    stale_child.bb = 23
    self.assertFalse(parent.HasField('optional_nested_message'))
    self.assertEqual(0, parent.optional_nested_message.bb)
  def testHasBitsWhenModifyingRepeatedFields(self):
    """Growing a repeated field inside a submessage sets the parent has-bit."""
    proto = unittest_pb2.TestNestedMessageHasBits()
    # Case 1: a repeated scalar inside the submessage.
    submessage = proto.optional_nested_message
    submessage.nestedmessage_repeated_int32.append(5)
    self.assertEqual([5], submessage.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Case 2: the same check with a repeated composite inside the submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))
  def testHasBitsForManyLevelsOfNesting(self):
    """Setting a deeply nested scalar sets has-bits on every ancestor."""
    root = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(root.HasField('bb'))
    # Reading through five levels of children must not mark anything present.
    self.assertEqual(0, root.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(root.HasField('bb'))

    root.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, root.bb.a.bb.a.bb.optional_int32)

    # Walk the path from the root; each link must now be marked present.
    node = root
    for child_name in ('bb', 'a', 'bb', 'a', 'bb'):
      self.assertTrue(node.HasField(child_name))
      node = getattr(node, child_name)
    # At the leaf: its own 'a' was never written, but the scalar was.
    self.assertFalse(node.HasField('a'))
    self.assertTrue(node.HasField('optional_int32'))
  def testSingularListFields(self):
    """ListFields() lists set singular fields in a fixed order.

    The expected order differs from the order of assignment below.
    """
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    by_name = proto.DESCRIPTOR.fields_by_name
    expected = [
        (by_name['optional_int32'], 5),
        (by_name['optional_fixed32'], 1),
        (by_name['optional_string'], 'foo'),
    ]
    self.assertEqual(expected, proto.ListFields())
  def testRepeatedListFields(self):
    """ListFields() reports populated repeated fields next to singular ones."""
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    for number in (5, 11):
      proto.repeated_int32.append(number)
    proto.repeated_string.extend(['foo', 'bar'])
    proto.repeated_string.extend([])  # Extending by nothing adds nothing.
    proto.repeated_string.append('baz')
    proto.optional_int32 = 21
    by_name = proto.DESCRIPTOR.fields_by_name
    expected = [
        (by_name['optional_int32'], 21),
        (by_name['repeated_int32'], [5, 11]),
        (by_name['repeated_fixed32'], [1]),
        (by_name['repeated_string'], ['foo', 'bar', 'baz']),
    ]
    self.assertEqual(expected, proto.ListFields())
  def testSingularListExtensions(self):
    """ListFields() reports set singular extensions keyed by their handles."""
    proto = unittest_pb2.TestAllExtensions()
    extensions = proto.Extensions
    extensions[unittest_pb2.optional_fixed32_extension] = 1
    extensions[unittest_pb2.optional_int32_extension] = 5
    extensions[unittest_pb2.optional_string_extension] = 'foo'
    expected = [
        (unittest_pb2.optional_int32_extension, 5),
        (unittest_pb2.optional_fixed32_extension, 1),
        (unittest_pb2.optional_string_extension, 'foo'),
    ]
    self.assertEqual(expected, proto.ListFields())
  def testRepeatedListExtensions(self):
    """ListFields() reports populated repeated extensions with all elements."""
    proto = unittest_pb2.TestAllExtensions()
    extensions = proto.Extensions
    extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    for number in (5, 11):
      extensions[unittest_pb2.repeated_int32_extension].append(number)
    for text in ('foo', 'bar', 'baz'):
      extensions[unittest_pb2.repeated_string_extension].append(text)
    extensions[unittest_pb2.optional_int32_extension] = 21
    expected = [
        (unittest_pb2.optional_int32_extension, 21),
        (unittest_pb2.repeated_int32_extension, [5, 11]),
        (unittest_pb2.repeated_fixed32_extension, [1]),
        (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz']),
    ]
    self.assertEqual(expected, proto.ListFields())
  def testListFieldsAndExtensions(self):
    """ListFields() interleaves regular fields and extensions in one list.

    TestFieldOrderings is populated by test_util.SetAllFieldsAndExtensions; the
    expected list below pins both the values and their relative order.
    """
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    by_name = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(by_name['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (by_name['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (by_name['my_float'], 1.0)],
        proto.ListFields())
  def testDefaultValues(self):
    """Unset fields read back zero values or their declared defaults."""
    proto = unittest_pb2.TestAllTypes()
    expectations = [
        # Fields without an explicit default read as the type's zero value.
        (0, 'optional_int32'),
        (0, 'optional_int64'),
        (0, 'optional_uint32'),
        (0, 'optional_uint64'),
        (0, 'optional_sint32'),
        (0, 'optional_sint64'),
        (0, 'optional_fixed32'),
        (0, 'optional_fixed64'),
        (0, 'optional_sfixed32'),
        (0, 'optional_sfixed64'),
        (0.0, 'optional_float'),
        (0.0, 'optional_double'),
        (False, 'optional_bool'),
        ('', 'optional_string'),
        ('', 'optional_bytes'),
        # Fields with an explicit default read as that default.
        (41, 'default_int32'),
        (42, 'default_int64'),
        (43, 'default_uint32'),
        (44, 'default_uint64'),
        (-45, 'default_sint32'),
        (46, 'default_sint64'),
        (47, 'default_fixed32'),
        (48, 'default_fixed64'),
        (49, 'default_sfixed32'),
        (-50, 'default_sfixed64'),
        (51.5, 'default_float'),
        (52e3, 'default_double'),
        (True, 'default_bool'),
        ('hello', 'default_string'),
        ('world', 'default_bytes'),
        (unittest_pb2.TestAllTypes.BAR, 'default_nested_enum'),
        (unittest_pb2.FOREIGN_BAR, 'default_foreign_enum'),
        (unittest_import_pb2.IMPORT_BAR, 'default_import_enum'),
    ]
    for expected, field_name in expectations:
      actual = getattr(proto, field_name)
      self.assertEqual(expected, actual, '%s: expected %r, got %r' %
                       (field_name, expected, actual))

    extreme = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', extreme.utf8_string)
  def testHasFieldWithUnknownFieldName(self):
    """HasField() rejects a name that is not a field of the message."""
    has_field = unittest_pb2.TestAllTypes().HasField
    self.assertRaises(ValueError, has_field, 'nonexistent_field')
  def testClearFieldWithUnknownFieldName(self):
    """ClearField() rejects a name that is not a field of the message."""
    clear_field = unittest_pb2.TestAllTypes().ClearField
    self.assertRaises(ValueError, clear_field, 'nonexistent_field')
  def testDisallowedAssignments(self):
    """Direct assignment to repeated or composite fields raises AttributeError."""
    proto = unittest_pb2.TestAllTypes()
    forbidden = [
        # Repeated fields accept neither a scalar nor a list.
        (proto, 'repeated_int32', 10),
        (proto, 'repeated_int32', [10]),
        # Non-repeated composite fields cannot be assigned wholesale.
        (proto, 'optional_nested_message', 23),
        # Setting 'bb' on the repeated container itself, without an index
        # selecting one of its messages.
        (proto.repeated_nested_message, 'bb', 34),
        # Setting an arbitrary attribute on a repeated container.
        (proto.repeated_float, 'some_attribute', 34),
        # proto.nonexistent_field = 23 must fail as well.
        (proto, 'nonexistent_field', 23),
    ]
    for target, attribute, value in forbidden:
      self.assertRaises(AttributeError, setattr, target, attribute, value)
  322. # TODO(robinson): Add type-safety check for enums.
  def testSingleScalarTypeSafety(self):
    """Assigning a value of the wrong Python type raises TypeError."""
    proto = unittest_pb2.TestAllTypes()
    wrong_types = (('optional_int32', 1.1),
                   ('optional_int32', 'foo'),
                   ('optional_string', 10),
                   ('optional_bytes', 10))
    for field_name, bad_value in wrong_types:
      self.assertRaises(TypeError, setattr, proto, field_name, bad_value)
  def testSingleScalarBoundsChecking(self):
    """Integer fields accept their full range and reject values just outside."""
    int32_range = (-(1 << 31), (1 << 31) - 1)
    cases = [
        ('optional_int32', int32_range),
        ('optional_uint32', (0, 0xffffffff)),
        ('optional_int64', (-(1 << 63), (1 << 63) - 1)),
        ('optional_uint64', (0, 0xffffffffffffffff)),
        # The enum field is checked against the int32 range.
        ('optional_nested_enum', int32_range),
    ]
    for field_name, (lowest, highest) in cases:
      pb = unittest_pb2.TestAllTypes()
      setattr(pb, field_name, lowest)
      setattr(pb, field_name, highest)
      self.assertRaises(ValueError, setattr, pb, field_name, lowest - 1)
      self.assertRaises(ValueError, setattr, pb, field_name, highest + 1)
  def testRepeatedScalarTypeSafety(self):
    """Repeated scalar containers reject elements of the wrong type."""
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # Go through append(): passing the container object itself to
    # assertRaises would "pass" only because the container isn't callable,
    # without exercising any type checking.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)
    # Element assignment is type- and bounds-checked too.
    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertEqual([23], proto.repeated_int32)
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
  def testSingleScalarGettersAndSetters(self):
    """A scalar reads its default until assigned, then the assigned value."""
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    new_value = 1
    proto.optional_int32 = new_value
    self.assertEqual(new_value, proto.optional_int32)
    # TODO(robinson): Test all other scalar field types.
  def testSingleScalarClearField(self):
    """ClearField() on a scalar resets value and has-bit, even when unset."""
    proto = unittest_pb2.TestAllTypes()
    name = 'optional_int32'
    # Clearing a field that was never set is allowed (a no-op).
    proto.ClearField(name)
    proto.optional_int32 = 1
    self.assertTrue(proto.HasField(name))
    proto.ClearField(name)
    self.assertEqual(0, proto.optional_int32)
    self.assertFalse(proto.HasField(name))
    # TODO(robinson): Test all other scalar field types.
  def testEnums(self):
    """Nested enum values are reachable from instances and from the class."""
    proto = unittest_pb2.TestAllTypes()
    for number, label in ((1, 'FOO'), (2, 'BAR'), (3, 'BAZ')):
      self.assertEqual(number, getattr(proto, label))
      self.assertEqual(number, getattr(unittest_pb2.TestAllTypes, label))
  def testRepeatedScalars(self):
    """Repeated scalar containers support the list operations below."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    for value in (5, 10, 15):
      proto.repeated_int32.append(value)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(3, len(proto.repeated_int32))
    self.assertEqual([5, 10, 15], proto.repeated_int32)

    # Single-element reads, including a negative index.
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(15, proto.repeated_int32[-1])
    # Out-of-range indices raise IndexError; non-integer ones raise TypeError.
    getitem = proto.repeated_int32.__getitem__
    for bad_index in (1234, -1234):
      self.assertRaises(IndexError, getitem, bad_index)
    for bad_index in ('foo', None):
      self.assertRaises(TypeError, getitem, bad_index)

    # Single-element assignment, then insertion.
    proto.repeated_int32[1] = 20
    self.assertEqual([5, 20, 15], proto.repeated_int32)
    proto.repeated_int32.insert(1, 25)
    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)

    # Slice reads, then slice assignment.
    proto.repeated_int32.append(30)
    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
    proto.repeated_int32[1:4] = [35, 40, 45]
    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)

    # The container can be iterated directly.
    self.assertEqual([5, 35, 40, 45, 30],
                     [item for item in proto.repeated_int32])

    # Single-element deletion, then slice deletion.
    del proto.repeated_int32[2]
    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
    del proto.repeated_int32[2:]
    self.assertEqual([5, 35], proto.repeated_int32)

    # ClearField() leaves the field empty.
    proto.ClearField('repeated_int32')
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
  def testRepeatedScalarsRemove(self):
    """remove() deletes one matching element at a time, first match first."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    for value in (5, 10, 5, 5):
      proto.repeated_int32.append(value)
    self.assertEqual(4, len(proto.repeated_int32))

    # Each step removes one value, then checks the survivors in order.
    steps = ((5, [10, 5, 5]),
             (5, [10, 5]),
             (10, [5]))
    for victim, survivors in steps:
      proto.repeated_int32.remove(victim)
      self.assertEqual(len(survivors), len(proto.repeated_int32))
      for position, expected in enumerate(survivors):
        self.assertEqual(expected, proto.repeated_int32[position])

    # Removing a value that is not present raises ValueError.
    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
  def testRepeatedComposites(self):
    """Repeated message containers support add(), indexing, slicing and del."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertIs([m0, m1], proto.repeated_nested_message)
    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))

    # Out-of-range indices raise IndexError; non-integer ones raise TypeError.
    getitem = proto.repeated_nested_message.__getitem__
    for bad_index in (1234, -1234):
      self.assertRaises(IndexError, getitem, bad_index)
    for bad_index in ('foo', None):
      self.assertRaises(TypeError, getitem, bad_index)

    # Slices hand back the very objects returned by add().
    m2, m3, m4 = [proto.repeated_nested_message.add() for _ in range(3)]
    self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
    self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
    # So does plain iteration, in order.
    self.assertIs([m0, m1, m2, m3, m4],
                  [item for item in proto.repeated_nested_message])

    # Single-element deletion, then slice deletion.
    del proto.repeated_nested_message[2]
    self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
    del proto.repeated_nested_message[2:]
    self.assertIs([m0, m1], proto.repeated_nested_message)

    # ClearField() leaves the field empty.
    proto.ClearField('repeated_nested_message')
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
  def testHandWrittenReflection(self):
    """A message class can be generated from hand-constructed descriptors."""
    # TODO(robinson): We probably need a better way to specify protocol types
    # by hand.  But then again, this isn't something we expect many people to
    # do.  Hmm.
    FieldDescriptor = descriptor.FieldDescriptor
    int64_field = FieldDescriptor(
        name='foo_field', full_name='MyProto.foo_field',
        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
        cpp_type=FieldDescriptor.CPPTYPE_INT64,
        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
        containing_type=None, message_type=None, enum_type=None,
        is_extension=False, extension_scope=None,
        options=descriptor_pb2.FieldOptions())
    message_descriptor = descriptor.Descriptor(
        name='MyProto', full_name='MyProto', filename='ignored',
        containing_type=None, nested_types=[], enum_types=[],
        fields=[int64_field], extensions=[],
        options=descriptor_pb2.MessageOptions())

    class MyProtoClass(message.Message):
      DESCRIPTOR = message_descriptor
      __metaclass__ = reflection.GeneratedProtocolMessageType

    instance = MyProtoClass()
    # Before assignment: default value and no has-bit.
    self.assertEqual(0, instance.foo_field)
    self.assertFalse(instance.HasField('foo_field'))
    # After assignment: the new value, with the has-bit set.
    instance.foo_field = 23
    self.assertEqual(23, instance.foo_field)
    self.assertTrue(instance.HasField('foo_field'))
  def testTopLevelExtensionsForOptionalScalar(self):
    """An optional scalar extension tracks its value and has-bit."""
    extendee = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.optional_int32_extension
    self.assertFalse(extendee.HasExtension(handle))
    self.assertEqual(0, extendee.Extensions[handle])
    # As with ordinary scalar fields, a read alone leaves the has-bit unset.
    self.assertFalse(extendee.HasExtension(handle))
    # Assignment sets both the value and the has-bit.
    extendee.Extensions[handle] = 23
    self.assertEqual(23, extendee.Extensions[handle])
    self.assertTrue(extendee.HasExtension(handle))
    # ClearExtension() restores the default and clears the has-bit.
    extendee.ClearExtension(handle)
    self.assertEqual(0, extendee.Extensions[handle])
    self.assertFalse(extendee.HasExtension(handle))
  def testTopLevelExtensionsForRepeatedScalar(self):
    """A repeated scalar extension is list-like and replaced on clear."""
    extendee = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.repeated_string_extension
    self.assertEqual(0, len(extendee.Extensions[handle]))
    extendee.Extensions[handle].append('foo')
    self.assertEqual(['foo'], extendee.Extensions[handle])
    old_list = extendee.Extensions[handle]
    extendee.ClearExtension(handle)
    # After clearing: empty, and a different container object than before.
    self.assertEqual(0, len(extendee.Extensions[handle]))
    self.assertTrue(old_list is not extendee.Extensions[handle])
    # Extensions[handle] = 'a' must be rejected.
    self.assertRaises(TypeError, operator.setitem, extendee.Extensions,
                      handle, 'a')
  def testTopLevelExtensionsForOptionalMessage(self):
    """An optional message extension tracks presence and detaches on clear."""
    extendee = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.optional_foreign_message_extension
    self.assertFalse(extendee.HasExtension(handle))
    self.assertEqual(0, extendee.Extensions[handle].c)
    # As with normal fields, reading from the submessage sets no has-bit.
    self.assertFalse(extendee.HasExtension(handle))
    extendee.Extensions[handle].c = 23
    self.assertEqual(23, extendee.Extensions[handle].c)
    self.assertTrue(extendee.HasExtension(handle))

    # Keep the current submessage, then clear the extension.
    detached = extendee.Extensions[handle]
    extendee.ClearExtension(handle)
    self.assertTrue(detached is not extendee.Extensions[handle])
    # The old submessage still works on its own, but writing to it must not
    # set any has-bit on the extendee.
    detached.c = 42
    self.assertEqual(42, detached.c)
    self.assertTrue(detached.HasField('c'))
    self.assertFalse(extendee.HasExtension(handle))
    # Extensions[handle] = 'a' must be rejected.
    self.assertRaises(TypeError, operator.setitem, extendee.Extensions,
                      handle, 'a')
  def testTopLevelExtensionsForRepeatedMessage(self):
    """A repeated group extension supports add() and is replaced on clear."""
    extendee = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.repeatedgroup_extension
    self.assertEqual(0, len(extendee.Extensions[handle]))
    group = extendee.Extensions[handle].add()
    # Writes through the handle returned by add() are visible via indexing.
    for value in (23, 42):
      group.a = value
      self.assertEqual(value, extendee.Extensions[handle][0].a)
    old_list = extendee.Extensions[handle]
    extendee.ClearExtension(handle)
    # After clearing: empty, and a different container object than before.
    self.assertEqual(0, len(extendee.Extensions[handle]))
    self.assertTrue(old_list is not extendee.Extensions[handle])
    # Extensions[handle] = 'a' must be rejected.
    self.assertRaises(TypeError, operator.setitem, extendee.Extensions,
                      handle, 'a')
  def testNestedExtensions(self):
    """An extension scoped inside a message (TestRequired.single) works too."""
    extendee = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.TestRequired.single
    # Only the non-repeated case is exercised here.
    self.assertFalse(extendee.HasExtension(handle))
    required = extendee.Extensions[handle]
    self.assertEqual(0, required.a)
    self.assertFalse(extendee.HasExtension(handle))
    # Writing through the fetched submessage marks the extension present.
    required.a = 23
    self.assertEqual(23, extendee.Extensions[handle].a)
    self.assertTrue(extendee.HasExtension(handle))
    # Clearing swaps in a fresh submessage and clears the has-bit.
    extendee.ClearExtension(handle)
    self.assertTrue(required is not extendee.Extensions[handle])
    self.assertFalse(extendee.HasExtension(handle))
  596. # If message A directly contains message B, and
  597. # a.HasField('b') is currently False, then mutating any
  598. # extension in B should change a.HasField('b') to True
  599. # (and so on up the object tree).
  def testHasBitsForAncestorsOfExtendedMessage(self):
    """Mutating an extension of a submessage sets the parent's has-bit.

    Each section starts from a fresh TopLevelMessage. It checks that merely
    reading the extension leaves 'submessage' unset, then that writing it
    marks 'submessage' as present.
    """
    ext = more_extensions_pb2

    # Optional scalar extension.
    top = ext.TopLevelMessage()
    self.assertFalse(top.HasField('submessage'))
    self.assertEqual(
        0, top.submessage.Extensions[ext.optional_int_extension])
    self.assertFalse(top.HasField('submessage'))
    top.submessage.Extensions[ext.optional_int_extension] = 23
    self.assertEqual(
        23, top.submessage.Extensions[ext.optional_int_extension])
    self.assertTrue(top.HasField('submessage'))

    # Repeated scalar extension.
    top = ext.TopLevelMessage()
    self.assertFalse(top.HasField('submessage'))
    self.assertEqual(
        [], top.submessage.Extensions[ext.repeated_int_extension])
    self.assertFalse(top.HasField('submessage'))
    top.submessage.Extensions[ext.repeated_int_extension].append(23)
    self.assertEqual(
        [23], top.submessage.Extensions[ext.repeated_int_extension])
    self.assertTrue(top.HasField('submessage'))

    # Optional message extension.
    top = ext.TopLevelMessage()
    self.assertFalse(top.HasField('submessage'))
    self.assertEqual(0, top.submessage.Extensions[
        ext.optional_message_extension].foreign_message_int)
    self.assertFalse(top.HasField('submessage'))
    top.submessage.Extensions[
        ext.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, top.submessage.Extensions[
        ext.optional_message_extension].foreign_message_int)
    self.assertTrue(top.HasField('submessage'))

    # Repeated message extension.
    top = ext.TopLevelMessage()
    self.assertFalse(top.HasField('submessage'))
    self.assertEqual(0, len(top.submessage.Extensions[
        ext.repeated_message_extension]))
    self.assertFalse(top.HasField('submessage'))
    added = top.submessage.Extensions[ext.repeated_message_extension].add()
    self.assertTrue(added is top.submessage.Extensions[
        ext.repeated_message_extension][0])
    self.assertTrue(top.HasField('submessage'))
  def testDisconnectionAfterClearingEmptyMessage(self):
    """A cleared, never-set extension must not propagate later writes."""
    top = more_extensions_pb2.TopLevelMessage()
    parent = top.submessage
    handle = more_extensions_pb2.optional_message_extension
    stale = parent.Extensions[handle]
    parent.ClearExtension(handle)
    # Writing through the stale reference must not mark the ancestor present.
    stale.foreign_message_int = 23
    self.assertFalse(top.HasField('submessage'))
    self.assertTrue(stale is not parent.Extensions[handle])
  def testExtensionFailureModes(self):
    """Invalid extension handles raise KeyError from every accessor."""
    msg = unittest_pb2.TestAllExtensions()
    # The first candidate is not an extension handle at all. The second is
    # a real handle, but for a different message type. Both must be
    # rejected by HasExtension(), ClearExtension() and Extensions[].
    for bad_handle in (1234, more_extensions_pb2.optional_int_extension):
      self.assertRaises(KeyError, msg.HasExtension, bad_handle)
      self.assertRaises(KeyError, msg.ClearExtension, bad_handle)
      self.assertRaises(KeyError, msg.Extensions.__getitem__, bad_handle)
      self.assertRaises(KeyError, msg.Extensions.__setitem__, bad_handle, 5)
    # A valid handle for a *repeated* extension is also rejected by
    # HasExtension(). As with non-extension repeated fields, Has*() is not
    # supported for repeated extensions.
    self.assertRaises(KeyError, msg.HasExtension,
                      unittest_pb2.repeated_string_extension)
  def testStaticParseFrom(self):
    """FromString() round-trips a fully-populated message."""
    original = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(original)
    wire = original.SerializeToString()
    parsed = unittest_pb2.TestAllTypes.FromString(wire)
    # The parsed copy must compare equal to the source message.
    self.assertEqual(parsed, original)
  def testMergeFromSingularField(self):
    """MergeFrom() copies set singular fields and keeps the target's others."""
    source = unittest_pb2.TestAllTypes()
    source.optional_int32 = 1
    target = unittest_pb2.TestAllTypes()
    target.optional_string = 'value'  # Not set in source; must survive.
    target.MergeFrom(source)
    self.assertEqual(1, target.optional_int32)
    self.assertEqual('value', target.optional_string)
  def testMergeFromRepeatedField(self):
    """MergeFrom() appends repeated scalar elements after existing ones."""
    proto1 = unittest_pb2.TestAllTypes()
    proto1.repeated_int32.append(1)
    proto1.repeated_int32.append(2)
    proto2 = unittest_pb2.TestAllTypes()
    proto2.repeated_int32.append(0)
    proto2.MergeFrom(proto1)
    # Pin the exact length so that dropped or duplicated elements are caught,
    # not just the values at the first three indices.
    self.assertEqual(3, len(proto2.repeated_int32))
    self.assertEqual(0, proto2.repeated_int32[0])
    self.assertEqual(1, proto2.repeated_int32[1])
    self.assertEqual(2, proto2.repeated_int32[2])
    # Merging must not modify the source message.
    self.assertEqual(2, len(proto1.repeated_int32))
  def testMergeFromOptionalGroup(self):
    """MergeFrom() carries over a set optional group."""
    source = unittest_pb2.TestAllTypes()
    source.optionalgroup.a = 12
    target = unittest_pb2.TestAllTypes()
    target.MergeFrom(source)
    self.assertEqual(12, target.optionalgroup.a)
  def testMergeFromRepeatedNestedMessage(self):
    """MergeFrom() appends repeated message elements after existing ones."""
    proto1 = unittest_pb2.TestAllTypes()
    m = proto1.repeated_nested_message.add()
    m.bb = 123
    m = proto1.repeated_nested_message.add()
    m.bb = 321
    proto2 = unittest_pb2.TestAllTypes()
    m = proto2.repeated_nested_message.add()
    m.bb = 999
    proto2.MergeFrom(proto1)
    # Pin the exact length, as testMergeFromExtensionsNestedMessage does,
    # so that dropped or duplicated elements are caught.
    self.assertEqual(3, len(proto2.repeated_nested_message))
    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
    # Merging must not modify the source message.
    self.assertEqual(2, len(proto1.repeated_nested_message))
  def testMergeFromAllFields(self):
    """Merging a fully-populated message into an empty one copies it."""
    source = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(source)
    target = unittest_pb2.TestAllTypes()
    target.MergeFrom(source)
    # The result compares equal to the source...
    self.assertEqual(target, source)
    # ...and serializes to exactly the same bytes.
    self.assertEqual(source.SerializeToString(), target.SerializeToString())
  def testMergeFromExtensionsSingular(self):
    """MergeFrom() copies a set singular extension."""
    handle = unittest_pb2.optional_int32_extension
    source = unittest_pb2.TestAllExtensions()
    source.Extensions[handle] = 1
    target = unittest_pb2.TestAllExtensions()
    target.MergeFrom(source)
    self.assertEqual(1, target.Extensions[handle])
  def testMergeFromExtensionsRepeated(self):
    """MergeFrom() appends repeated extension values after existing ones."""
    handle = unittest_pb2.repeated_int32_extension
    source = unittest_pb2.TestAllExtensions()
    for value in (1, 2):
      source.Extensions[handle].append(value)
    target = unittest_pb2.TestAllExtensions()
    target.Extensions[handle].append(0)
    target.MergeFrom(source)
    self.assertEqual(3, len(target.Extensions[handle]))
    # The target's own element comes first, then the source's in order.
    for index, expected in enumerate((0, 1, 2)):
      self.assertEqual(expected, target.Extensions[handle][index])
  def testMergeFromExtensionsNestedMessage(self):
    """MergeFrom() appends repeated message extensions after existing ones."""
    handle = unittest_pb2.repeated_nested_message_extension
    source = unittest_pb2.TestAllExtensions()
    source_list = source.Extensions[handle]
    for bb in (222, 333):
      source_list.add().bb = bb
    target = unittest_pb2.TestAllExtensions()
    target.Extensions[handle].add().bb = 111
    target.MergeFrom(source)
    merged = target.Extensions[handle]
    self.assertEqual(3, len(merged))
    self.assertEqual([111, 222, 333], [merged[i].bb for i in range(3)])
  def testCopyFromSingularField(self):
    """CopyFrom() overwrites the target's singular fields with the source's."""
    source = unittest_pb2.TestAllTypes()
    source.optional_int32 = 1
    source.optional_string = 'important-text'
    target = unittest_pb2.TestAllTypes()
    target.optional_string = 'value'  # Must be overwritten by the copy.
    target.CopyFrom(source)
    self.assertEqual(1, target.optional_int32)
    self.assertEqual('important-text', target.optional_string)
  def testCopyFromRepeatedField(self):
    """CopyFrom() replaces, rather than extends, repeated scalar fields."""
    proto1 = unittest_pb2.TestAllTypes()
    proto1.repeated_int32.append(1)
    proto1.repeated_int32.append(2)
    proto2 = unittest_pb2.TestAllTypes()
    proto2.repeated_int32.append(0)  # Must be discarded by the copy.
    proto2.CopyFrom(proto1)
    # Pin the length: the pre-existing element is gone and nothing extra
    # was left behind or appended.
    self.assertEqual(2, len(proto2.repeated_int32))
    self.assertEqual(1, proto2.repeated_int32[0])
    self.assertEqual(2, proto2.repeated_int32[1])
  def testCopyFromAllFields(self):
    """Copying a fully-populated message yields an identical message."""
    original = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(original)
    copy = unittest_pb2.TestAllTypes()
    copy.CopyFrom(original)
    # The copy compares equal...
    self.assertEqual(copy, original)
    # ...and serializes to exactly the same bytes.
    original_bytes = original.SerializeToString()
    self.assertEqual(original_bytes, copy.SerializeToString())
  def testCopyFromSelf(self):
    """Copying a message onto itself must leave its contents intact."""
    msg = unittest_pb2.TestAllTypes()
    msg.repeated_int32.append(1)
    msg.optional_int32 = 2
    msg.optional_string = 'important-text'
    msg.CopyFrom(msg)
    self.assertEqual(1, msg.repeated_int32[0])
    self.assertEqual(2, msg.optional_int32)
    self.assertEqual('important-text', msg.optional_string)
  def testClear(self):
    """Clear() resets both regular fields and extensions.

    Uses assertEqual rather than the deprecated assertEquals alias, which
    was removed in Python 3.12.
    """
    # Regular fields.
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)
    proto.Clear()
    self.assertEqual(0, proto.ByteSize())
    empty_proto = unittest_pb2.TestAllTypes()
    self.assertEqual(proto, empty_proto)

    # Extensions that were set must be cleared too.
    proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(proto)
    proto.Clear()
    self.assertEqual(0, proto.ByteSize())
    empty_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(proto, empty_proto)
  def testIsInitialized(self):
    """IsInitialized() tracks required fields at every nesting level."""
    # Trivial cases: every field and extension is optional.
    for message_class in (unittest_pb2.TestAllTypes,
                          unittest_pb2.TestAllExtensions):
      self.assertTrue(message_class().IsInitialized())

    # Required fields directly on the message.
    msg = unittest_pb2.TestRequired()
    self.assertFalse(msg.IsInitialized())
    msg.a = msg.b = msg.c = 2
    self.assertTrue(msg.IsInitialized())

    # A singular submessage that is only partly filled in.
    msg = unittest_pb2.TestRequiredForeign()
    self.assertTrue(msg.IsInitialized())
    msg.optional_message.a = 1
    self.assertFalse(msg.IsInitialized())
    msg.optional_message.b = 0
    msg.optional_message.c = 0
    self.assertTrue(msg.IsInitialized())

    # A repeated submessage element with no required fields set.
    element = msg.repeated_message.add()
    self.assertFalse(msg.IsInitialized())
    element.a = element.b = element.c = 0
    self.assertTrue(msg.IsInitialized())

    # A repeated extension: every element must be complete.
    msg = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.TestRequired.multi
    first = msg.Extensions[handle].add()
    second = msg.Extensions[handle].add()
    self.assertFalse(msg.IsInitialized())
    for name in ('a', 'b', 'c'):
      setattr(first, name, 1)
    self.assertFalse(msg.IsInitialized())
    for name in ('a', 'b', 'c'):
      setattr(second, name, 2)
    self.assertTrue(msg.IsInitialized())

    # A singular message extension that is only partly filled in.
    msg = unittest_pb2.TestAllExtensions()
    handle = unittest_pb2.TestRequired.single
    msg.Extensions[handle].a = 1
    self.assertFalse(msg.IsInitialized())
    msg.Extensions[handle].b = 2
    msg.Extensions[handle].c = 3
    self.assertTrue(msg.IsInitialized())
  def testStringUTF8Encoding(self):
    """Check which Python 2 string types 'string' and 'bytes' fields accept."""
    proto = unittest_pb2.TestAllTypes()

    # A 'bytes' field refuses unicode objects.
    self.assertRaises(TypeError,
                      setattr, proto, 'optional_bytes', u'unicode object')

    # The default value of a 'string' field is a unicode object.
    self.assertEqual(type(proto.optional_string), unicode)

    # A unicode assignment reads back equal to the equivalent str.
    proto.optional_string = unicode('Testing')
    self.assertEqual(proto.optional_string, str('Testing'))

    # An ASCII str is accepted, reads back equal to the unicode text, and is
    # kept as a str.
    proto.optional_string = str('Testing')
    self.assertEqual(proto.optional_string, unicode('Testing'))
    self.assertEqual(type(proto.optional_string), str)

    # str values holding bytes outside 7-bit ASCII are rejected. The first is
    # an arbitrary high byte; the second is UTF-8-encoded text.
    for rejected in (str('a\x80a'), 'Тест'):
      self.assertRaises(ValueError,
                        setattr, proto, 'optional_string', rejected)

    # A plain ASCII str is accepted without raising.
    proto.optional_string = 'abc'
  def testStringUTF8Serialization(self):
    """Non-ASCII text in a MessageSet extension survives a round trip.

    Also checks that invalid UTF-8 on the wire is rejected when parsing.
    """
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message = unittest_mset_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    # 'Test' in another language, using UTF-8 charset.
    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    proto.Extensions[extension].str = test_utf8
    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()
    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    raw = unittest_mset_pb2.RawMessageSet()
    raw.MergeFromString(serialized)
    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 1547769)

    # Check the actual bytes on the wire.
    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.MergeFromString(raw.item[0].message)
    self.assertEqual(type(message2.str), unicode)
    self.assertEqual(message2.str, test_utf8)

    # Bytes on the wire that are not valid UTF-8 must be rejected on parse.
    # (Named so as not to shadow the builtin 'bytes'.)
    invalid_bytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * '\xff')
    self.assertRaises(UnicodeDecodeError, message2.MergeFromString,
                      invalid_bytes)
  936. # Since we had so many tests for protocol buffer equality, we broke these out
  937. # into separate TestCase classes.
  938. class TestAllTypesEqualityTest(unittest.TestCase):
  939. def setUp(self):
  940. self.first_proto = unittest_pb2.TestAllTypes()
  941. self.second_proto = unittest_pb2.TestAllTypes()
  942. def testSelfEquality(self):
  943. self.assertEqual(self.first_proto, self.first_proto)
  944. def testEmptyProtosEqual(self):
  945. self.assertEqual(self.first_proto, self.second_proto)
  946. class FullProtosEqualityTest(unittest.TestCase):
  947. """Equality tests using completely-full protos as a starting point."""
  948. def setUp(self):
  949. self.first_proto = unittest_pb2.TestAllTypes()
  950. self.second_proto = unittest_pb2.TestAllTypes()
  951. test_util.SetAllFields(self.first_proto)
  952. test_util.SetAllFields(self.second_proto)
  953. def testAllFieldsFilledEquality(self):
  954. self.assertEqual(self.first_proto, self.second_proto)
  955. def testNonRepeatedScalar(self):
  956. # Nonrepeated scalar field change should cause inequality.
  957. self.first_proto.optional_int32 += 1
  958. self.assertNotEqual(self.first_proto, self.second_proto)
  959. # ...as should clearing a field.
  960. self.first_proto.ClearField('optional_int32')
  961. self.assertNotEqual(self.first_proto, self.second_proto)
  962. def testNonRepeatedComposite(self):
  963. # Change a nonrepeated composite field.
  964. self.first_proto.optional_nested_message.bb += 1
  965. self.assertNotEqual(self.first_proto, self.second_proto)
  966. self.first_proto.optional_nested_message.bb -= 1
  967. self.assertEqual(self.first_proto, self.second_proto)
  968. # Clear a field in the nested message.
  969. self.first_proto.optional_nested_message.ClearField('bb')
  970. self.assertNotEqual(self.first_proto, self.second_proto)
  971. self.first_proto.optional_nested_message.bb = (
  972. self.second_proto.optional_nested_message.bb)
  973. self.assertEqual(self.first_proto, self.second_proto)
  974. # Remove the nested message entirely.
  975. self.first_proto.ClearField('optional_nested_message')
  976. self.assertNotEqual(self.first_proto, self.second_proto)
  977. def testRepeatedScalar(self):
  978. # Change a repeated scalar field.
  979. self.first_proto.repeated_int32.append(5)
  980. self.assertNotEqual(self.first_proto, self.second_proto)
  981. self.first_proto.ClearField('repeated_int32')
  982. self.assertNotEqual(self.first_proto, self.second_proto)
  983. def testRepeatedComposite(self):
  984. # Change value within a repeated composite field.
  985. self.first_proto.repeated_nested_message[0].bb += 1
  986. self.assertNotEqual(self.first_proto, self.second_proto)
  987. self.first_proto.repeated_nested_message[0].bb -= 1
  988. self.assertEqual(self.first_proto, self.second_proto)
  989. # Add a value to a repeated composite field.
  990. self.first_proto.repeated_nested_message.add()
  991. self.assertNotEqual(self.first_proto, self.second_proto)
  992. self.second_proto.repeated_nested_message.add()
  993. self.assertEqual(self.first_proto, self.second_proto)
  994. def testNonRepeatedScalarHasBits(self):
  995. # Ensure that we test "has" bits as well as value for
  996. # nonrepeated scalar field.
  997. self.first_proto.ClearField('optional_int32')
  998. self.second_proto.optional_int32 = 0
  999. self.assertNotEqual(self.first_proto, self.second_proto)
  1000. def testNonRepeatedCompositeHasBits(self):
  1001. # Ensure that we test "has" bits as well as value for
  1002. # nonrepeated composite field.
  1003. self.first_proto.ClearField('optional_nested_message')
  1004. self.second_proto.optional_nested_message.ClearField('bb')
  1005. self.assertNotEqual(self.first_proto, self.second_proto)
  1006. # TODO(robinson): Replace next two lines with method
  1007. # to set the "has" bit without changing the value,
  1008. # if/when such a method exists.
  1009. self.first_proto.optional_nested_message.bb = 0
  1010. self.first_proto.optional_nested_message.ClearField('bb')
  1011. self.assertEqual(self.first_proto, self.second_proto)
  1012. class ExtensionEqualityTest(unittest.TestCase):
  1013. def testExtensionEquality(self):
  1014. first_proto = unittest_pb2.TestAllExtensions()
  1015. second_proto = unittest_pb2.TestAllExtensions()
  1016. self.assertEqual(first_proto, second_proto)
  1017. test_util.SetAllExtensions(first_proto)
  1018. self.assertNotEqual(first_proto, second_proto)
  1019. test_util.SetAllExtensions(second_proto)
  1020. self.assertEqual(first_proto, second_proto)
  1021. # Ensure that we check value equality.
  1022. first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
  1023. self.assertNotEqual(first_proto, second_proto)
  1024. first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
  1025. self.assertEqual(first_proto, second_proto)
  1026. # Ensure that we also look at "has" bits.
  1027. first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
  1028. second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
  1029. self.assertNotEqual(first_proto, second_proto)
  1030. first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
  1031. self.assertEqual(first_proto, second_proto)
  1032. # Ensure that differences in cached values
  1033. # don't matter if "has" bits are both false.
  1034. first_proto = unittest_pb2.TestAllExtensions()
  1035. second_proto = unittest_pb2.TestAllExtensions()
  1036. self.assertEqual(
  1037. 0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
  1038. self.assertEqual(first_proto, second_proto)
  1039. class MutualRecursionEqualityTest(unittest.TestCase):
  1040. def testEqualityWithMutualRecursion(self):
  1041. first_proto = unittest_pb2.TestMutualRecursionA()
  1042. second_proto = unittest_pb2.TestMutualRecursionA()
  1043. self.assertEqual(first_proto, second_proto)
  1044. first_proto.bb.a.bb.optional_int32 = 23
  1045. self.assertNotEqual(first_proto, second_proto)
  1046. second_proto.bb.a.bb.optional_int32 = 23
  1047. self.assertEqual(first_proto, second_proto)
  1048. class ByteSizeTest(unittest.TestCase):
  1049. def setUp(self):
  1050. self.proto = unittest_pb2.TestAllTypes()
  1051. self.extended_proto = more_extensions_pb2.ExtendedMessage()
  1052. self.packed_proto = unittest_pb2.TestPackedTypes()
  1053. self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
  1054. def Size(self):
  1055. return self.proto.ByteSize()
  1056. def testEmptyMessage(self):
  1057. self.assertEqual(0, self.proto.ByteSize())
  1058. def testVarints(self):
  1059. def Test(i, expected_varint_size):
  1060. self.proto.Clear()
  1061. self.proto.optional_int64 = i
  1062. # Add one to the varint size for the tag info
  1063. # for tag 1.
  1064. self.assertEqual(expected_varint_size + 1, self.Size())
  1065. Test(0, 1)
  1066. Test(1, 1)
  1067. for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
  1068. Test((1 << i) - 1, num_bytes)
  1069. Test(-1, 10)
  1070. Test(-2, 10)
  1071. Test(-(1 << 63), 10)
  1072. def testStrings(self):
  1073. self.proto.optional_string = ''
  1074. # Need one byte for tag info (tag #14), and one byte for length.
  1075. self.assertEqual(2, self.Size())
  1076. self.proto.optional_string = 'abc'
  1077. # Need one byte for tag info (tag #14), and one byte for length.
  1078. self.assertEqual(2 + len(self.proto.optional_string), self.Size())
  1079. self.proto.optional_string = 'x' * 128
  1080. # Need one byte for tag info (tag #14), and TWO bytes for length.
  1081. self.assertEqual(3 + len(self.proto.optional_string), self.Size())
  1082. def testOtherNumerics(self):
  1083. self.proto.optional_fixed32 = 1234
  1084. # One byte for tag and 4 bytes for fixed32.
  1085. self.assertEqual(5, self.Size())
  1086. self.proto = unittest_pb2.TestAllTypes()
  1087. self.proto.optional_fixed64 = 1234
  1088. # One byte for tag and 8 bytes for fixed64.
  1089. self.assertEqual(9, self.Size())
  1090. self.proto = unittest_pb2.TestAllTypes()
  1091. self.proto.optional_float = 1.234
  1092. # One byte for tag and 4 bytes for float.
  1093. self.assertEqual(5, self.Size())
  1094. self.proto = unittest_pb2.TestAllTypes()
  1095. self.proto.optional_double = 1.234
  1096. # One byte for tag and 8 bytes for float.
  1097. self.assertEqual(9, self.Size())
  1098. self.proto = unittest_pb2.TestAllTypes()
  1099. self.proto.optional_sint32 = 64
  1100. # One byte for tag and 2 bytes for zig-zag-encoded 64.
  1101. self.assertEqual(3, self.Size())
  1102. self.proto = unittest_pb2.TestAllTypes()
  1103. def testComposites(self):
  1104. # 3 bytes.
  1105. self.proto.optional_nested_message.bb = (1 << 14)
  1106. # Plus one byte for bb tag.
  1107. # Plus 1 byte for optional_nested_message serialized size.
  1108. # Plus two bytes for optional_nested_message tag.
  1109. self.assertEqual(3 + 1 + 1 + 2, self.Size())
  1110. def testGroups(self):
  1111. # 4 bytes.
  1112. self.proto.optionalgroup.a = (1 << 21)
  1113. # Plus two bytes for |a| tag.
  1114. # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
  1115. self.assertEqual(4 + 2 + 2*2, self.Size())
  1116. def testRepeatedScalars(self):
  1117. self.proto.repeated_int32.append(10) # 1 byte.
  1118. self.proto.repeated_int32.append(128) # 2 bytes.
  1119. # Also need 2 bytes for each entry for tag.
  1120. self.assertEqual(1 + 2 + 2*2, self.Size())
  1121. def testRepeatedScalarsExtend(self):
  1122. self.proto.repeated_int32.extend([10, 128]) # 3 bytes.
  1123. # Also need 2 bytes for each entry for tag.
  1124. self.assertEqual(1 + 2 + 2*2, self.Size())
  1125. def testRepeatedScalarsRemove(self):
  1126. self.proto.repeated_int32.append(10) # 1 byte.
  1127. self.proto.repeated_int32.append(128) # 2 bytes.
  1128. # Also need 2 bytes for each entry for tag.
  1129. self.assertEqual(1 + 2 + 2*2, self.Size())
  1130. self.proto.repeated_int32.remove(128)
  1131. self.assertEqual(1 + 2, self.Size())
  1132. def testRepeatedComposites(self):
  1133. # Empty message. 2 bytes tag plus 1 byte length.
  1134. foreign_message_0 = self.proto.repeated_nested_message.add()
  1135. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1136. foreign_message_1 = self.proto.repeated_nested_message.add()
  1137. foreign_message_1.bb = 7
  1138. self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
  1139. def testRepeatedCompositesDelete(self):
  1140. # Empty message. 2 bytes tag plus 1 byte length.
  1141. foreign_message_0 = self.proto.repeated_nested_message.add()
  1142. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1143. foreign_message_1 = self.proto.repeated_nested_message.add()
  1144. foreign_message_1.bb = 9
  1145. self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
  1146. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1147. del self.proto.repeated_nested_message[0]
  1148. self.assertEqual(2 + 1 + 1 + 1, self.Size())
  1149. # Now add a new message.
  1150. foreign_message_2 = self.proto.repeated_nested_message.add()
  1151. foreign_message_2.bb = 12
  1152. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1153. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1154. self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
  1155. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1156. del self.proto.repeated_nested_message[1]
  1157. self.assertEqual(2 + 1 + 1 + 1, self.Size())
  1158. del self.proto.repeated_nested_message[0]
  1159. self.assertEqual(0, self.Size())
  1160. def testRepeatedGroups(self):
  1161. # 2-byte START_GROUP plus 2-byte END_GROUP.
  1162. group_0 = self.proto.repeatedgroup.add()
  1163. # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
  1164. # plus 2-byte END_GROUP.
  1165. group_1 = self.proto.repeatedgroup.add()
  1166. group_1.a = 7
  1167. self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
  1168. def testExtensions(self):
  1169. proto = unittest_pb2.TestAllExtensions()
  1170. self.assertEqual(0, proto.ByteSize())
  1171. extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte.
  1172. proto.Extensions[extension] = 23
  1173. # 1 byte for tag, 1 byte for value.
  1174. self.assertEqual(2, proto.ByteSize())
  1175. def testCacheInvalidationForNonrepeatedScalar(self):
  1176. # Test non-extension.
  1177. self.proto.optional_int32 = 1
  1178. self.assertEqual(2, self.proto.ByteSize())
  1179. self.proto.optional_int32 = 128
  1180. self.assertEqual(3, self.proto.ByteSize())
  1181. self.proto.ClearField('optional_int32')
  1182. self.assertEqual(0, self.proto.ByteSize())
  1183. # Test within extension.
  1184. extension = more_extensions_pb2.optional_int_extension
  1185. self.extended_proto.Extensions[extension] = 1
  1186. self.assertEqual(2, self.extended_proto.ByteSize())
  1187. self.extended_proto.Extensions[extension] = 128
  1188. self.assertEqual(3, self.extended_proto.ByteSize())
  1189. self.extended_proto.ClearExtension(extension)
  1190. self.assertEqual(0, self.extended_proto.ByteSize())
  1191. def testCacheInvalidationForRepeatedScalar(self):
  1192. # Test non-extension.
  1193. self.proto.repeated_int32.append(1)
  1194. self.assertEqual(3, self.proto.ByteSize())
  1195. self.proto.repeated_int32.append(1)
  1196. self.assertEqual(6, self.proto.ByteSize())
  1197. self.proto.repeated_int32[1] = 128
  1198. self.assertEqual(7, self.proto.ByteSize())
  1199. self.proto.ClearField('repeated_int32')
  1200. self.assertEqual(0, self.proto.ByteSize())
  1201. # Test within extension.
  1202. extension = more_extensions_pb2.repeated_int_extension
  1203. repeated = self.extended_proto.Extensions[extension]
  1204. repeated.append(1)
  1205. self.assertEqual(2, self.extended_proto.ByteSize())
  1206. repeated.append(1)
  1207. self.assertEqual(4, self.extended_proto.ByteSize())
  1208. repeated[1] = 128
  1209. self.assertEqual(5, self.extended_proto.ByteSize())
  1210. self.extended_proto.ClearExtension(extension)
  1211. self.assertEqual(0, self.extended_proto.ByteSize())
  1212. def testCacheInvalidationForNonrepeatedMessage(self):
  1213. # Test non-extension.
  1214. self.proto.optional_foreign_message.c = 1
  1215. self.assertEqual(5, self.proto.ByteSize())
  1216. self.proto.optional_foreign_message.c = 128
  1217. self.assertEqual(6, self.proto.ByteSize())
  1218. self.proto.optional_foreign_message.ClearField('c')
  1219. self.assertEqual(3, self.proto.ByteSize())
  1220. self.proto.ClearField('optional_foreign_message')
  1221. self.assertEqual(0, self.proto.ByteSize())
  1222. child = self.proto.optional_foreign_message
  1223. self.proto.ClearField('optional_foreign_message')
  1224. child.c = 128
  1225. self.assertEqual(0, self.proto.ByteSize())
  1226. # Test within extension.
  1227. extension = more_extensions_pb2.optional_message_extension
  1228. child = self.extended_proto.Extensions[extension]
  1229. self.assertEqual(0, self.extended_proto.ByteSize())
  1230. child.foreign_message_int = 1
  1231. self.assertEqual(4, self.extended_proto.ByteSize())
  1232. child.foreign_message_int = 128
  1233. self.assertEqual(5, self.extended_proto.ByteSize())
  1234. self.extended_proto.ClearExtension(extension)
  1235. self.assertEqual(0, self.extended_proto.ByteSize())
  def testCacheInvalidationForRepeatedMessage(self):
    """ByteSize() follows edits made through repeated sub-messages."""
    msg = self.proto
    msg_size = msg.ByteSize
    # Regular (non-extension) repeated message field.
    first = msg.repeated_foreign_message.add()
    self.assertEqual(3, msg_size())
    msg.repeated_foreign_message.add()
    self.assertEqual(6, msg_size())
    # Writing into an element that is already in the list must be reflected
    # in the parent's size.
    first.c = 1
    self.assertEqual(8, msg_size())
    msg.ClearField('repeated_foreign_message')
    self.assertEqual(0, msg_size())
    # The same checks for a repeated message extension.
    ext_msg = self.extended_proto
    ext_size = ext_msg.ByteSize
    ext_field = more_extensions_pb2.repeated_message_extension
    items = ext_msg.Extensions[ext_field]
    first = items.add()
    self.assertEqual(2, ext_size())
    items.add()
    self.assertEqual(4, ext_size())
    first.foreign_message_int = 1
    self.assertEqual(6, ext_size())
    first.ClearField('foreign_message_int')
    self.assertEqual(4, ext_size())
    ext_msg.ClearExtension(ext_field)
    self.assertEqual(0, ext_size())
  def testPackedRepeatedScalars(self):
    """ByteSize() of packed repeated scalars: payload plus tag and length."""
    packed = self.packed_proto
    self.assertEqual(0, packed.ByteSize())
    packed.packed_int32.append(10)   # Varint payload: 1 byte.
    packed.packed_int32.append(128)  # Varint payload: 2 bytes.
    # Payload (1 + 2) plus overhead: the tag is 2 bytes (the field number is
    # 90) and the varint storing the length is 1 byte.
    int_size = (1 + 2) + (2 + 1)
    self.assertEqual(int_size, packed.ByteSize())
    packed.packed_double.append(4.2)   # 8 bytes.
    packed.packed_double.append(3.25)  # 8 bytes.
    # Payload (8 + 8) plus 2 more tag bytes and 1 more length byte.
    double_size = (8 + 8) + (2 + 1)
    self.assertEqual(int_size + double_size, packed.ByteSize())
    # Clearing one packed field removes exactly its contribution.
    packed.ClearField('packed_int32')
    self.assertEqual(double_size, packed.ByteSize())
  def testPackedExtensions(self):
    """ByteSize() accounts for a packed repeated extension field."""
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    # This is the repeated value container for the extension, not the
    # extension descriptor itself.
    values = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    values.extend([1, 2, 3, 4])  # 4 fixed32 values * 4 bytes = 16 bytes.
    # A packed field adds one tag and one length prefix on top of the
    # payload: the tag is 2 bytes and the varint holding the length (16)
    # is 1 byte.
    self.assertEqual(16 + 2 + 1, self.packed_extended_proto.ByteSize())
  1281. # TODO(robinson): We need cross-language serialization consistency tests.
  1282. # Issues to be sure to cover include:
  1283. # * Handling of unrecognized tags ("uninterpreted_bytes").
  1284. # * Handling of MessageSets.
  1285. # * Consistent ordering of tags in the wire format,
  1286. # including ordering between extensions and non-extension
  1287. # fields.
  1288. # * Consistent serialization of negative numbers, especially
  1289. # negative int32s.
  1290. # * Handling of empty submessages (with and without "has"
  1291. # bits set).
  1292. class SerializationTest(unittest.TestCase):
  1293. def testSerializeEmtpyMessage(self):
  1294. first_proto = unittest_pb2.TestAllTypes()
  1295. second_proto = unittest_pb2.TestAllTypes()
  1296. serialized = first_proto.SerializeToString()
  1297. self.assertEqual(first_proto.ByteSize(), len(serialized))
  1298. second_proto.MergeFromString(serialized)
  1299. self.assertEqual(first_proto, second_proto)
  1300. def testSerializeAllFields(self):
  1301. first_proto = unittest_pb2.TestAllTypes()
  1302. second_proto = unittest_pb2.TestAllTypes()
  1303. test_util.SetAllFields(first_proto)
  1304. serialized = first_proto.SerializeToString()
  1305. self.assertEqual(first_proto.ByteSize(), len(serialized))
  1306. second_proto.MergeFromString(serialized)
  1307. self.assertEqual(first_proto, second_proto)
  1308. def testSerializeAllExtensions(self):
  1309. first_proto = unittest_pb2.TestAllExtensions()
  1310. second_proto = unittest_pb2.TestAllExtensions()
  1311. test_util.SetAllExtensions(first_proto)
  1312. serialized = first_proto.SerializeToString()
  1313. second_proto.MergeFromString(serialized)
  1314. self.assertEqual(first_proto, second_proto)
  1315. def testCanonicalSerializationOrder(self):
  1316. proto = more_messages_pb2.OutOfOrderFields()
  1317. # These are also their tag numbers. Even though we're setting these in
  1318. # reverse-tag order AND they're listed in reverse tag-order in the .proto
  1319. # file, they should nonetheless be serialized in tag order.
  1320. proto.optional_sint32 = 5
  1321. proto.Extensions[more_messages_pb2.optional_uint64] = 4
  1322. proto.optional_uint32 = 3
  1323. proto.Extensions[more_messages_pb2.optional_int64] = 2
  1324. proto.optional_int32 = 1
  1325. serialized = proto.SerializeToString()
  1326. self.assertEqual(proto.ByteSize(), len(serialized))
  1327. d = decoder.Decoder(serialized)
  1328. ReadTag = d.ReadFieldNumberAndWireType
  1329. self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
  1330. self.assertEqual(1, d.ReadInt32())
  1331. self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
  1332. self.assertEqual(2, d.ReadInt64())
  1333. self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
  1334. self.assertEqual(3, d.ReadUInt32())
  1335. self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
  1336. self.assertEqual(4, d.ReadUInt64())
  1337. self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
  1338. self.assertEqual(5, d.ReadSInt32())
  1339. def testCanonicalSerializationOrderSameAsCpp(self):
  1340. # Copy of the same test we use for C++.
  1341. proto = unittest_pb2.TestFieldOrderings()
  1342. test_util.SetAllFieldsAndExtensions(proto)
  1343. serialized = proto.SerializeToString()
  1344. test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
  1345. def testMergeFromStringWhenFieldsAlreadySet(self):
  1346. first_proto = unittest_pb2.TestAllTypes()
  1347. first_proto.repeated_string.append('foobar')
  1348. first_proto.optional_int32 = 23
  1349. first_proto.optional_nested_message.bb = 42
  1350. serialized = first_proto.SerializeToString()
  1351. second_proto = unittest_pb2.TestAllTypes()
  1352. second_proto.repeated_string.append('baz')
  1353. second_proto.optional_int32 = 100
  1354. second_proto.optional_nested_message.bb = 999
  1355. second_proto.MergeFromString(serialized)
  1356. # Ensure that we append to repeated fields.
  1357. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
  1358. # Ensure that we overwrite nonrepeatd scalars.
  1359. self.assertEqual(23, second_proto.optional_int32)
  1360. # Ensure that we recursively call MergeFromString() on
  1361. # submessages.
  1362. self.assertEqual(42, second_proto.optional_nested_message.bb)
  1363. def testMessageSetWireFormat(self):
  1364. proto = unittest_mset_pb2.TestMessageSet()
  1365. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
  1366. extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
  1367. extension1 = extension_message1.message_set_extension
  1368. extension2 = extension_message2.message_set_extension
  1369. proto.Extensions[extension1].i = 123
  1370. proto.Extensions[extension2].str = 'foo'
  1371. # Serialize using the MessageSet wire format (this is specified in the
  1372. # .proto file).
  1373. serialized = proto.SerializeToString()
  1374. raw = unittest_mset_pb2.RawMessageSet()
  1375. self.assertEqual(False,
  1376. raw.DESCRIPTOR.GetOptions().message_set_wire_format)
  1377. raw.MergeFromString(serialized)
  1378. self.assertEqual(2, len(raw.item))
  1379. message1 = unittest_mset_pb2.TestMessageSetExtension1()
  1380. message1.MergeFromString(raw.item[0].message)
  1381. self.assertEqual(123, message1.i)
  1382. message2 = unittest_mset_pb2.TestMessageSetExtension2()
  1383. message2.MergeFromString(raw.item[1].message)
  1384. self.assertEqual('foo', message2.str)
  1385. # Deserialize using the MessageSet wire format.
  1386. proto2 = unittest_mset_pb2.TestMessageSet()
  1387. proto2.MergeFromString(serialized)
  1388. self.assertEqual(123, proto2.Extensions[extension1].i)
  1389. self.assertEqual('foo', proto2.Extensions[extension2].str)
  1390. # Check byte size.
  1391. self.assertEqual(proto2.ByteSize(), len(serialized))
  1392. self.assertEqual(proto.ByteSize(), len(serialized))
  1393. def testMessageSetWireFormatUnknownExtension(self):
  1394. # Create a message using the message set wire format with an unknown
  1395. # message.
  1396. raw = unittest_mset_pb2.RawMessageSet()
  1397. # Add an item.
  1398. item = raw.item.add()
  1399. item.type_id = 1545008
  1400. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
  1401. message1 = unittest_mset_pb2.TestMessageSetExtension1()
  1402. message1.i = 12345
  1403. item.message = message1.SerializeToString()
  1404. # Add a second, unknown extension.
  1405. item = raw.item.add()
  1406. item.type_id = 1545009
  1407. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
  1408. message1 = unittest_mset_pb2.TestMessageSetExtension1()
  1409. message1.i = 12346
  1410. item.message = message1.SerializeToString()
  1411. # Add another unknown extension.
  1412. item = raw.item.add()
  1413. item.type_id = 1545010
  1414. message1 = unittest_mset_pb2.TestMessageSetExtension2()
  1415. message1.str = 'foo'
  1416. item.message = message1.SerializeToString()
  1417. serialized = raw.SerializeToString()
  1418. # Parse message using the message set wire format.
  1419. proto = unittest_mset_pb2.TestMessageSet()
  1420. proto.MergeFromString(serialized)
  1421. # Check that the message parsed well.
  1422. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
  1423. extension1 = extension_message1.message_set_extension
  1424. self.assertEquals(12345, proto.Extensions[extension1].i)
  1425. def testUnknownFields(self):
  1426. proto = unittest_pb2.TestAllTypes()
  1427. test_util.SetAllFields(proto)
  1428. serialized = proto.SerializeToString()
  1429. # The empty message should be parsable with all of the fields
  1430. # unknown.
  1431. proto2 = unittest_pb2.TestEmptyMessage()
  1432. # Parsing this message should succeed.
  1433. proto2.MergeFromString(serialized)
  1434. # Now test with a int64 field set.
  1435. proto = unittest_pb2.TestAllTypes()
  1436. proto.optional_int64 = 0x0fffffffffffffff
  1437. serialized = proto.SerializeToString()
  1438. # The empty message should be parsable with all of the fields
  1439. # unknown.
  1440. proto2 = unittest_pb2.TestEmptyMessage()
  1441. # Parsing this message should succeed.
  1442. proto2.MergeFromString(serialized)
  1443. def _CheckRaises(self, exc_class, callable_obj, exception):
  1444. """This method checks if the excpetion type and message are as expected."""
  1445. try:
  1446. callable_obj()
  1447. except exc_class, ex:
  1448. # Check if the exception message is the right one.
  1449. self.assertEqual(exception, str(ex))
  1450. return
  1451. else:
  1452. raise self.failureException('%s not raised' % str(exc_class))
  1453. def testSerializeUninitialized(self):
  1454. proto = unittest_pb2.TestRequired()
  1455. self._CheckRaises(
  1456. message.EncodeError,
  1457. proto.SerializeToString,
  1458. 'Required field protobuf_unittest.TestRequired.a is not set.')
  1459. # Shouldn't raise exceptions.
  1460. partial = proto.SerializePartialToString()
  1461. proto.a = 1
  1462. self._CheckRaises(
  1463. message.EncodeError,
  1464. proto.SerializeToString,
  1465. 'Required field protobuf_unittest.TestRequired.b is not set.')
  1466. # Shouldn't raise exceptions.
  1467. partial = proto.SerializePartialToString()
  1468. proto.b = 2
  1469. self._CheckRaises(
  1470. message.EncodeError,
  1471. proto.SerializeToString,
  1472. 'Required field protobuf_unittest.TestRequired.c is not set.')
  1473. # Shouldn't raise exceptions.
  1474. partial = proto.SerializePartialToString()
  1475. proto.c = 3
  1476. serialized = proto.SerializeToString()
  1477. # Shouldn't raise exceptions.
  1478. partial = proto.SerializePartialToString()
  1479. proto2 = unittest_pb2.TestRequired()
  1480. proto2.MergeFromString(serialized)
  1481. self.assertEqual(1, proto2.a)
  1482. self.assertEqual(2, proto2.b)
  1483. self.assertEqual(3, proto2.c)
  1484. proto2.ParseFromString(partial)
  1485. self.assertEqual(1, proto2.a)
  1486. self.assertEqual(2, proto2.b)
  1487. self.assertEqual(3, proto2.c)
  1488. def testSerializedAllPackedFields(self):
  1489. first_proto = unittest_pb2.TestPackedTypes()
  1490. second_proto = unittest_pb2.TestPackedTypes()
  1491. test_util.SetAllPackedFields(first_proto)
  1492. serialized = first_proto.SerializeToString()
  1493. self.assertEqual(first_proto.ByteSize(), len(serialized))
  1494. second_proto.MergeFromString(serialized)
  1495. self.assertEqual(first_proto, second_proto)
  1496. def testSerializeAllPackedExtensions(self):
  1497. first_proto = unittest_pb2.TestPackedExtensions()
  1498. second_proto = unittest_pb2.TestPackedExtensions()
  1499. test_util.SetAllPackedExtensions(first_proto)
  1500. serialized = first_proto.SerializeToString()
  1501. second_proto.MergeFromString(serialized)
  1502. self.assertEqual(first_proto, second_proto)
  1503. def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
  1504. first_proto = unittest_pb2.TestPackedTypes()
  1505. first_proto.packed_int32.extend([1, 2])
  1506. first_proto.packed_double.append(3.0)
  1507. serialized = first_proto.SerializeToString()
  1508. second_proto = unittest_pb2.TestPackedTypes()
  1509. second_proto.packed_int32.append(3)
  1510. second_proto.packed_double.extend([1.0, 2.0])
  1511. second_proto.packed_sint32.append(4)
  1512. second_proto.MergeFromString(serialized)
  1513. self.assertEqual([3, 1, 2], second_proto.packed_int32)
  1514. self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
  1515. self.assertEqual([4], second_proto.packed_sint32)
  1516. def testPackedFieldsWireFormat(self):
  1517. proto = unittest_pb2.TestPackedTypes()
  1518. proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes
  1519. proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes
  1520. proto.packed_float.append(2.0) # 4 bytes, will be before double
  1521. serialized = proto.SerializeToString()
  1522. self.assertEqual(proto.ByteSize(), len(serialized))
  1523. d = decoder.Decoder(serialized)
  1524. ReadTag = d.ReadFieldNumberAndWireType
  1525. self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  1526. self.assertEqual(1+1+1+2, d.ReadInt32())
  1527. self.assertEqual(1, d.ReadInt32())
  1528. self.assertEqual(2, d.ReadInt32())
  1529. self.assertEqual(150, d.ReadInt32())
  1530. self.assertEqual(3, d.ReadInt32())
  1531. self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  1532. self.assertEqual(4, d.ReadInt32())
  1533. self.assertEqual(2.0, d.ReadFloat())
  1534. self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  1535. self.assertEqual(8+8, d.ReadInt32())
  1536. self.assertEqual(1.0, d.ReadDouble())
  1537. self.assertEqual(1000.0, d.ReadDouble())
  1538. self.assertTrue(d.EndOfStream())
  class OptionsTest(unittest.TestCase):
    """Tests for descriptor options (message_set_wire_format, packed)."""

    def testMessageOptions(self):
      proto = unittest_mset_pb2.TestMessageSet()
      self.assertEqual(True,
                       proto.DESCRIPTOR.GetOptions().message_set_wire_format)
      proto = unittest_pb2.TestAllTypes()
      self.assertEqual(False,
                       proto.DESCRIPTOR.GetOptions().message_set_wire_format)

    def testPackedOptions(self):
      proto = unittest_pb2.TestAllTypes()
      proto.optional_int32 = 1
      proto.optional_double = 3.0
      fields = proto.ListFields()
      # Both fields set above must be listed; otherwise the loop below would
      # pass without checking anything.
      self.assertEqual(2, len(fields))
      for field_descriptor, _ in fields:
        self.assertEqual(False, field_descriptor.GetOptions().packed)

      proto = unittest_pb2.TestPackedTypes()
      proto.packed_int32.append(1)
      proto.packed_double.append(3.0)
      fields = proto.ListFields()
      # Same guard against a vacuously passing loop.
      self.assertEqual(2, len(fields))
      for field_descriptor, _ in fields:
        self.assertEqual(True, field_descriptor.GetOptions().packed)
        self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
                         field_descriptor.label)
  class UtilityTest(unittest.TestCase):
    """Tests for private helpers exposed by the reflection module."""

    def testImergeSorted(self):
      ImergeSorted = reflection._ImergeSorted
      # Each case pairs the expected merged output with the positional
      # iterables handed to ImergeSorted.
      cases = [
          # Various types of emptiness.
          ([], ()),
          ([], ([],)),
          ([], ([], [])),
          # One nonempty list.
          ([1, 2, 3], ([1, 2, 3],)),
          ([1, 2, 3], ([1, 2, 3], [])),
          ([1, 2, 3], ([], [1, 2, 3])),
          # Merging some nonempty lists together.
          ([1, 2, 3], ([1, 3], [2])),
          ([1, 2, 3], ([1], [3], [2])),
          ([1, 2, 3], ([1], [3], [2], [])),
          # Elements repeated across component iterators.
          ([1, 2, 2, 3, 3], ([1, 2], [3], [2, 3])),
          # Elements repeated within an iterator.
          ([1, 2, 2, 3, 3], ([1, 2, 2], [3], [3])),
      ]
      for expected, inputs in cases:
        self.assertEqual(expected, list(ImergeSorted(*inputs)))
  if __name__ == '__main__':
    # Run the TestCase classes in this module when it is executed directly.
    unittest.main()