reflection_test.py 127 KB


  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Protocol Buffers - Google's data interchange format
  5. # Copyright 2008 Google Inc. All rights reserved.
  6. # https://developers.google.com/protocol-buffers/
  7. #
  8. # Redistribution and use in source and binary forms, with or without
  9. # modification, are permitted provided that the following conditions are
  10. # met:
  11. #
  12. # * Redistributions of source code must retain the above copyright
  13. # notice, this list of conditions and the following disclaimer.
  14. # * Redistributions in binary form must reproduce the above
  15. # copyright notice, this list of conditions and the following disclaimer
  16. # in the documentation and/or other materials provided with the
  17. # distribution.
  18. # * Neither the name of Google Inc. nor the names of its
  19. # contributors may be used to endorse or promote products derived from
  20. # this software without specific prior written permission.
  21. #
  22. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  23. # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  24. # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  25. # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  26. # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  27. # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  28. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  29. # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  30. # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  31. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  32. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  33. """Unittest for reflection.py, which also indirectly tests the output of the
  34. pure-Python protocol compiler.
  35. """
  36. import copy
  37. import gc
  38. import operator
  39. import six
  40. import struct
  41. try:
  42. import unittest2 as unittest #PY26
  43. except ImportError:
  44. import unittest
  45. from google.protobuf import unittest_import_pb2
  46. from google.protobuf import unittest_mset_pb2
  47. from google.protobuf import unittest_pb2
  48. from google.protobuf import descriptor_pb2
  49. from google.protobuf import descriptor
  50. from google.protobuf import message
  51. from google.protobuf import reflection
  52. from google.protobuf import text_format
  53. from google.protobuf.internal import api_implementation
  54. from google.protobuf.internal import more_extensions_pb2
  55. from google.protobuf.internal import more_messages_pb2
  56. from google.protobuf.internal import message_set_extensions_pb2
  57. from google.protobuf.internal import wire_format
  58. from google.protobuf.internal import test_util
  59. from google.protobuf.internal import testing_refleaks
  60. from google.protobuf.internal import decoder
if six.PY3:
  # Python 3 folded `long` into `int`; define the alias so the integer-type
  # checks below can name `long` on either major version.
  long = int  # pylint: disable=redefined-builtin,invalid-name

# Base class for the test cases in this file, taken from the
# testing_refleaks helper module.
BaseTestCase = testing_refleaks.BaseTestCase
  64. class _MiniDecoder(object):
  65. """Decodes a stream of values from a string.
  66. Once upon a time we actually had a class called decoder.Decoder. Then we
  67. got rid of it during a redesign that made decoding much, much faster overall.
  68. But a couple tests in this file used it to check that the serialized form of
  69. a message was correct. So, this class implements just the methods that were
  70. used by said tests, so that we don't have to rewrite the tests.
  71. """
  72. def __init__(self, bytes):
  73. self._bytes = bytes
  74. self._pos = 0
  75. def ReadVarint(self):
  76. result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
  77. return result
  78. ReadInt32 = ReadVarint
  79. ReadInt64 = ReadVarint
  80. ReadUInt32 = ReadVarint
  81. ReadUInt64 = ReadVarint
  82. def ReadSInt64(self):
  83. return wire_format.ZigZagDecode(self.ReadVarint())
  84. ReadSInt32 = ReadSInt64
  85. def ReadFieldNumberAndWireType(self):
  86. return wire_format.UnpackTag(self.ReadVarint())
  87. def ReadFloat(self):
  88. result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
  89. self._pos += 4
  90. return result
  91. def ReadDouble(self):
  92. result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
  93. self._pos += 8
  94. return result
  95. def EndOfStream(self):
  96. return self._pos == len(self._bytes)
  97. class ReflectionTest(BaseTestCase):
  def assertListsEqual(self, values, others):
    """Asserts two sequences have the same length and pairwise-equal items.

    Args:
      values: The first sequence.
      others: The second sequence.
    """
    self.assertEqual(len(values), len(others))
    # Lengths are known to match here, so zip() visits every element pair.
    for value, other in zip(values, others):
      self.assertEqual(value, other)
  def testScalarConstructor(self):
    # Passing only singular scalar fields to the constructor should succeed;
    # a keyword explicitly given as None must leave that field unset.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_double=54.321,
        optional_string='optional_string',
        optional_float=None)
    expected_values = (
        ('optional_int32', 24),
        ('optional_double', 54.321),
        ('optional_string', 'optional_string'),
    )
    for field_name, expected in expected_values:
      self.assertEqual(expected, getattr(proto, field_name))
    self.assertFalse(proto.HasField('optional_float'))
  def testRepeatedScalarConstructor(self):
    # Passing only repeated scalar fields to the constructor should succeed;
    # None for a repeated field leaves it empty.
    proto = unittest_pb2.TestAllTypes(
        repeated_int32=[1, 2, 3, 4],
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_string=['optional_string'],
        repeated_float=None)
    expected_contents = (
        ('repeated_int32', [1, 2, 3, 4]),
        ('repeated_double', [1.23, 54.321]),
        ('repeated_bool', [True, False, False]),
        ('repeated_string', ['optional_string']),
        ('repeated_float', []),
    )
    for field_name, expected in expected_contents:
      self.assertEqual(expected, list(getattr(proto, field_name)))
  def testRepeatedCompositeConstructor(self):
    # Passing only repeated message/group fields to the constructor should
    # succeed. Each factory builds a fresh list, so the expected values are
    # distinct objects from the ones handed to the constructor.
    all_types = unittest_pb2.TestAllTypes

    def MakeNestedMessages():
      return [all_types.NestedMessage(bb=all_types.FOO),
              all_types.NestedMessage(bb=all_types.BAR)]

    def MakeForeignMessages():
      return [unittest_pb2.ForeignMessage(c=c) for c in (-43, 45324, 12)]

    def MakeRepeatedGroups():
      return [all_types.RepeatedGroup(),
              all_types.RepeatedGroup(a=1),
              all_types.RepeatedGroup(a=2)]

    proto = all_types(
        repeated_nested_message=MakeNestedMessages(),
        repeated_foreign_message=MakeForeignMessages(),
        repeatedgroup=MakeRepeatedGroups())
    self.assertEqual(MakeNestedMessages(),
                     list(proto.repeated_nested_message))
    self.assertEqual(MakeForeignMessages(),
                     list(proto.repeated_foreign_message))
    self.assertEqual(MakeRepeatedGroups(), list(proto.repeatedgroup))
  def testMixedConstructor(self):
    # The constructor should accept a mix of singular scalars, repeated
    # scalars and repeated messages in one call; None leaves a singular
    # message field unset.
    all_types = unittest_pb2.TestAllTypes

    def MakeNestedMessages():
      return [all_types.NestedMessage(bb=all_types.FOO),
              all_types.NestedMessage(bb=all_types.BAR)]

    def MakeForeignMessages():
      return [unittest_pb2.ForeignMessage(c=c) for c in (-43, 45324, 12)]

    proto = all_types(
        optional_int32=24,
        optional_string='optional_string',
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_nested_message=MakeNestedMessages(),
        repeated_foreign_message=MakeForeignMessages(),
        optional_nested_message=None)
    self.assertEqual(24, proto.optional_int32)
    self.assertEqual('optional_string', proto.optional_string)
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(MakeNestedMessages(),
                     list(proto.repeated_nested_message))
    self.assertEqual(MakeForeignMessages(),
                     list(proto.repeated_foreign_message))
    self.assertFalse(proto.HasField('optional_nested_message'))
  def testConstructorTypeError(self):
    # Each keyword below carries a value of the wrong type (or a bare scalar
    # where a list is required), so construction must raise TypeError.
    invalid_kwargs = (
        {'optional_int32': 'foo'},
        {'optional_string': 1234},
        {'optional_nested_message': 1234},
        {'repeated_int32': 1234},
        {'repeated_int32': ['foo']},
        {'repeated_string': 1234},
        {'repeated_string': [1234]},
        {'repeated_nested_message': 1234},
        {'repeated_nested_message': [1234]},
    )
    for kwargs in invalid_kwargs:
      self.assertRaises(TypeError, unittest_pb2.TestAllTypes, **kwargs)
  def testConstructorInvalidatesCachedByteSize(self):
    # ByteSize() of a freshly constructed message must account for the fields
    # that were passed to the constructor.
    nested_type = unittest_pb2.TestAllTypes.NestedMessage

    def SizeAfterConstruction(**kwargs):
      return unittest_pb2.TestAllTypes(**kwargs).ByteSize()

    self.assertEqual(2, SizeAfterConstruction(optional_int32=12))
    self.assertEqual(
        3, SizeAfterConstruction(optional_nested_message=nested_type()))
    self.assertEqual(3, SizeAfterConstruction(repeated_int32=[12]))
    self.assertEqual(
        3, SizeAfterConstruction(repeated_nested_message=[nested_type()]))
  def testSimpleHasBits(self):
    # Presence ("has" bit) tracking for a singular scalar field.
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is
    # read the default value.
    self.assertFalse(proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertFalse(proto.HasField('optional_int32'))
  def testHasBitsWithSinglyNestedScalar(self):
    # Checks presence tracking when a scalar is set inside a singular
    # composite (message or group) field of TestAllTypes.

    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      """Exercises has-bits for one composite field and one scalar inside it.

      Semantically this is:

        assert proto.composite_field.scalar_field == 0
        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')

        proto.composite_field.scalar_field = 20
        old_composite_field = proto.composite_field

        assert proto.composite_field.scalar_field == 20
        assert proto.composite_field.HasField('scalar_field')
        assert proto.HasField('composite_field')

        proto.ClearField('composite_field')

        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')
        assert proto.composite_field.scalar_field == 0

        # ClearField('composite_field') must have disconnected the old field
        # object from the object tree...
        assert old_composite_field is not proto.composite_field
        old_composite_field.scalar_field = 20
        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')
        assert proto.composite_field.scalar_field == 0

      Args:
        composite_field_name: Name of a non-repeated composite (message or
          group) field in TestAllTypes.
        scalar_field_name: Name of an integer-valued scalar field within that
          composite.
      """
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value and that it is the
      # default (0), while proto.HasField('composite') and
      # proto.composite.HasField('scalar') still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # The composite object does not "have" the scalar...
      self.assertFalse(composite_field.HasField(scalar_field_name))
      # ...and proto does not "have" the composite field.
      self.assertFalse(proto.HasField(composite_field_name))

      # Now set the scalar within the composite field. Ensure that the setting
      # is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Clearing the composite must reset all the "has" bits...
      proto.ClearField(composite_field_name)
      composite_field = getattr(proto, composite_field_name)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      # ...and the scalar must read back as its default again.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # The old object is detached: writing to it must not affect proto.
      self.assertIsNot(old_composite_field, composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')
  def testReferencesToNestedMessage(self):
    # Regression test: writing to a sub-message must keep working after the
    # test drops its reference to the parent. A previous version raised here
    # because it hit a now-dead weak reference.
    parent = unittest_pb2.TestAllTypes()
    child = parent.optional_nested_message
    del parent
    child.bb = 23
  def testDisconnectingNestedMessageBeforeSettingField(self):
    # A sub-message fetched before ClearField() is detached from the parent,
    # so later writes to it must not mark the parent's field as present.
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertIsNot(nested, proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)
  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
    # Detach a never-written sub-message, drop every reference to it and its
    # parent, then check that a new message can still hand out its default
    # sub-message.
    first = unittest_pb2.TestAllTypes()
    detached = first.optional_nested_message
    first.ClearField('optional_nested_message')
    del first
    del detached
    # Force a garbage collection so that the underlying CMessages are freed
    # along with the Messages they point to; this makes sure we're not
    # deleting default message instances.
    gc.collect()
    second = unittest_pb2.TestAllTypes()
    fresh_nested = second.optional_nested_message  # pylint: disable=unused-variable
  def testDisconnectingNestedMessageAfterSettingField(self):
    # Clearing a populated sub-message detaches the existing object: it keeps
    # its value, while the parent reads back an unset default in its place.
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    nested.bb = 5
    self.assertTrue(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertEqual(5, nested.bb)
    self.assertEqual(0, proto.optional_nested_message.bb)
    self.assertIsNot(nested, proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)
  def testDisconnectingNestedMessageBeforeGettingField(self):
    # ClearField() on a sub-message that was never accessed must leave the
    # field unset.
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
  def testDisconnectingNestedMessageAfterMerge(self):
    # This test exercises the code path that does not use ReleaseMessage().
    # The underlying fear is that if we use ReleaseMessage() incorrectly,
    # we will have memory leaks. It's hard to check that that doesn't happen,
    # but at least we can exercise that code path to make sure it works.
    proto1 = unittest_pb2.TestAllTypes()
    proto2 = unittest_pb2.TestAllTypes()
    proto2.optional_nested_message.bb = 5
    proto1.MergeFrom(proto2)
    self.assertTrue(proto1.HasField('optional_nested_message'))
    proto1.ClearField('optional_nested_message')
    self.assertFalse(proto1.HasField('optional_nested_message'))
  def testDisconnectingLazyNestedMessage(self):
    # Releases a nested message that is lazy. Lazy parsing only exists in the
    # C++ implementation, where this sequence currently results in memory
    # corruption and a crash, so the body only runs under the pure-Python
    # implementation (which does not parse lazily, so it just exercises the
    # ordinary release path there).
    if api_implementation.Type() == 'python':
      proto = unittest_pb2.TestAllTypes()
      proto.optional_lazy_message.bb = 5
      proto.ClearField('optional_lazy_message')
      del proto
      gc.collect()
  def testHasBitsWhenModifyingRepeatedFields(self):
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))
    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))
  def testHasBitsForManyLevelsOfNesting(self):
    # Setting a scalar deep inside a chain of sub-messages must mark every
    # message along that path as present, and nothing past it.
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(recursive_proto.HasField('bb'))
    # Merely reading a deeply nested default must not set any has-bits.
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
  def testSingularListFields(self):
    # ListFields() reports only the singular fields that are set; reading a
    # sub-message without writing to it must not make it appear.
    proto = unittest_pb2.TestAllTypes()
    fields = proto.DESCRIPTOR.fields_by_name
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    # Access sub-message but don't set it yet.
    nested_message = proto.optional_nested_message
    expected = [
        (fields['optional_int32'], 5),
        (fields['optional_fixed32'], 1),
        (fields['optional_string'], 'foo'),
    ]
    self.assertEqual(expected, proto.ListFields())
    # Writing into the sub-message makes it show up as well.
    proto.optional_nested_message.bb = 123
    expected.append((fields['optional_nested_message'], nested_message))
    self.assertEqual(expected, proto.ListFields())
  def testRepeatedListFields(self):
    # ListFields() includes non-empty repeated fields with their contents but
    # skips a repeated field that was only read.
    proto = unittest_pb2.TestAllTypes()
    fields = proto.DESCRIPTOR.fields_by_name
    proto.repeated_fixed32.append(1)
    for value in (5, 11):
      proto.repeated_int32.append(value)
    proto.repeated_string.extend(['foo', 'bar'])
    proto.repeated_string.extend([])
    proto.repeated_string.append('baz')
    proto.repeated_string.extend(str(x) for x in range(2))
    proto.optional_int32 = 21
    # Access but don't set anything; should not be listed.
    _ = proto.repeated_bool
    self.assertEqual(
        [(fields['optional_int32'], 21),
         (fields['repeated_int32'], [5, 11]),
         (fields['repeated_fixed32'], [1]),
         (fields['repeated_string'], ['foo', 'bar', 'baz', '0', '1'])],
        proto.ListFields())
  def testSingularListExtensions(self):
    # ListFields() on an extendable message reports the singular extensions
    # that have been set, keyed by their extension handles.
    proto = unittest_pb2.TestAllExtensions()
    extensions = proto.Extensions
    extensions[unittest_pb2.optional_fixed32_extension] = 1
    extensions[unittest_pb2.optional_int32_extension] = 5
    extensions[unittest_pb2.optional_string_extension] = 'foo'
    expected = [
        (unittest_pb2.optional_int32_extension, 5),
        (unittest_pb2.optional_fixed32_extension, 1),
        (unittest_pb2.optional_string_extension, 'foo'),
    ]
    self.assertEqual(expected, proto.ListFields())
  def testRepeatedListExtensions(self):
    # ListFields() reports non-empty repeated extensions (with contents)
    # alongside the singular extensions that are set.
    proto = unittest_pb2.TestAllExtensions()
    extensions = proto.Extensions
    extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    for value in (5, 11):
      extensions[unittest_pb2.repeated_int32_extension].append(value)
    for value in ('foo', 'bar', 'baz'):
      extensions[unittest_pb2.repeated_string_extension].append(value)
    extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())
  def testListFieldsAndExtensions(self):
    # ListFields() returns regular fields and extensions interleaved in one
    # list, in the order spelled out below.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (proto.DESCRIPTOR.fields_by_name['my_float'], 1.0)],
        proto.ListFields())
  def testDefaultValues(self):
    # Unset singular fields read back as their zero value, or as the declared
    # default for the `default_*` fields.
    proto = unittest_pb2.TestAllTypes()
    expected_defaults = [
        ('optional_int32', 0),
        ('optional_int64', 0),
        ('optional_uint32', 0),
        ('optional_uint64', 0),
        ('optional_sint32', 0),
        ('optional_sint64', 0),
        ('optional_fixed32', 0),
        ('optional_fixed64', 0),
        ('optional_sfixed32', 0),
        ('optional_sfixed64', 0),
        ('optional_float', 0.0),
        ('optional_double', 0.0),
        ('optional_bool', False),
        ('optional_string', ''),
        ('optional_bytes', b''),
        ('default_int32', 41),
        ('default_int64', 42),
        ('default_uint32', 43),
        ('default_uint64', 44),
        ('default_sint32', -45),
        ('default_sint64', 46),
        ('default_fixed32', 47),
        ('default_fixed64', 48),
        ('default_sfixed32', 49),
        ('default_sfixed64', -50),
        ('default_float', 51.5),
        ('default_double', 52e3),
        ('default_bool', True),
        ('default_string', 'hello'),
        ('default_bytes', b'world'),
        ('default_nested_enum', unittest_pb2.TestAllTypes.BAR),
        ('default_foreign_enum', unittest_pb2.FOREIGN_BAR),
        ('default_import_enum', unittest_import_pb2.IMPORT_BAR),
    ]
    for field_name, expected in expected_defaults:
      self.assertEqual(expected, getattr(proto, field_name))

    extreme = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', extreme.utf8_string)
  def testHasFieldWithUnknownFieldName(self):
    # Asking about a field the message type does not define is a ValueError.
    proto = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
      proto.HasField('nonexistent_field')
  def testClearFieldWithUnknownFieldName(self):
    # Clearing a field the message type does not define is a ValueError.
    proto = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
      proto.ClearField('nonexistent_field')
  def testClearRemovesChildren(self):
    # Make sure there aren't any implementation bugs that only partially
    # clear the message (which can happen in the more complex C++
    # implementation, which has parallel message lists): copying from an
    # empty message must drop every repeated child.
    proto = unittest_pb2.TestRequiredForeign()
    for _ in range(10):
      proto.repeated_message.add()
    proto.CopyFrom(unittest_pb2.TestRequiredForeign())
    with self.assertRaises(IndexError):
      proto.repeated_message[5]  # pylint: disable=pointless-statement
  def testDisallowedAssignments(self):
    # Direct assignment to repeated fields or singular composite fields is
    # illegal, as are stray attributes on repeated containers and unknown
    # field names; every attempt below must raise AttributeError.
    proto = unittest_pb2.TestAllTypes()
    forbidden_assignments = [
        # A repeated field accepts neither a scalar...
        (proto, 'repeated_int32', 10),
        # ...nor a list.
        (proto, 'repeated_int32', [10]),
        # A singular composite field.
        (proto, 'optional_nested_message', 23),
        # A nested-message field set on the repeated container itself,
        # without indexing an element.
        (proto.repeated_nested_message, 'bb', 34),
        # An arbitrary attribute on a repeated scalar container.
        (proto.repeated_float, 'some_attribute', 34),
        # A field name that does not exist.
        (proto, 'nonexistent_field', 23),
    ]
    for target, attribute_name, value in forbidden_assignments:
      self.assertRaises(AttributeError, setattr, target, attribute_name, value)
  def testSingleScalarTypeSafety(self):
    # Assigning a value of the wrong Python type to a singular scalar field
    # must raise TypeError.
    proto = unittest_pb2.TestAllTypes()
    wrongly_typed = (
        ('optional_int32', 1.1),
        ('optional_int32', 'foo'),
        ('optional_string', 10),
        ('optional_bytes', 10),
        ('optional_bool', 'foo'),
        ('optional_float', 'foo'),
        ('optional_double', 'foo'),
    )
    for field_name, bad_value in wrongly_typed:
      self.assertRaises(TypeError, setattr, proto, field_name, bad_value)
    # TODO(jieluo): Fix type checking difference for python and c extension
    # (a float assigned to a bool field is rejected only by pure Python).
    if api_implementation.Type() != 'python':
      proto.optional_bool = 1.1
    else:
      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
  def assertIntegerTypes(self, integer_fn):
    """Verifies the Python types produced by integer scalar fields.

    For each case the value is wrapped with `integer_fn` and assigned; the
    type read back is checked both directly and after a serialize/parse
    round trip.

    Args:
      integer_fn: A function to wrap the integers that will be assigned.
    """
    def CheckTypeAfterSetAndRoundTrip(field_name, raw_value, expected_type):
      original = unittest_pb2.TestAllTypes()
      setattr(original, field_name, integer_fn(raw_value))
      self.assertIsInstance(getattr(original, field_name), expected_type)
      reparsed = unittest_pb2.TestAllTypes()
      reparsed.ParseFromString(original.SerializeToString())
      self.assertIsInstance(getattr(reparsed, field_name), expected_type)

    integer_64 = long
    # Python only has signed ints: where a C long is 4 bytes, an `int` cannot
    # hold the uint32 value 1 << 31, so it comes back as `long`. With wider
    # longs it still fits in `int`.
    if struct.calcsize('L') == 4:
      uint32_high_bit_type = integer_64
    else:
      uint32_high_bit_type = int

    cases = (
        ('optional_int32', 1, int),
        ('optional_int32', 1 << 30, int),
        ('optional_uint32', 1 << 30, int),
        ('optional_uint32', 1 << 31, uint32_high_bit_type),
        ('optional_int64', 1 << 30, integer_64),
        ('optional_int64', 1 << 60, integer_64),
        ('optional_uint64', 1 << 30, integer_64),
        ('optional_uint64', 1 << 60, integer_64),
    )
    for field_name, raw_value, expected_type in cases:
      CheckTypeAfterSetAndRoundTrip(field_name, raw_value, expected_type)
  def testIntegerTypes(self):
    # Plain Python integers, assigned without any wrapping.
    self.assertIntegerTypes(lambda value: value)
  def testNonStandardIntegerTypes(self):
    # Same checks, with every value wrapped by test_util.NonStandardInteger
    # before it is assigned.
    integer_wrapper = test_util.NonStandardInteger
    self.assertIntegerTypes(integer_wrapper)
  def testIllegalValuesForIntegers(self):
    # Integer fields must reject values that are not integers.
    pb = unittest_pb2.TestAllTypes()
    # Strings are illegal, even when they represent an integer.
    with self.assertRaises(TypeError):
      pb.optional_uint64 = '2'
    # The exact error should propagate with a poorly written custom integer.
    # six.assertRaisesRegex dispatches to assertRaisesRegex on Python 3 and
    # assertRaisesRegexp on Python 2; the latter alias no longer exists on
    # Python 3.12+.
    with six.assertRaisesRegex(self, RuntimeError, 'my_error'):
      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
  def assertIntegerBoundsChecking(self, integer_fn):
    """Verifies bounds checking for scalar integer fields.

    Each integer field must accept its minimum and maximum values, and
    reject values one past either end with ValueError or TypeError.

    Args:
      integer_fn: A function to wrap the integers that will be assigned.
    """
    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
      pb = unittest_pb2.TestAllTypes()
      expected_min = integer_fn(expected_min)
      expected_max = integer_fn(expected_max)
      setattr(pb, field_name, expected_min)
      self.assertEqual(expected_min, getattr(pb, field_name))
      setattr(pb, field_name, expected_max)
      self.assertEqual(expected_max, getattr(pb, field_name))
      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
                        expected_min - 1)
      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
                        expected_max + 1)

    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
    # A bit of white-box testing: a large negative value must also be
    # rejected by an unsigned 64-bit field. Such a value may take a different
    # conversion path in the C++ implementation than the boundary cases above.
    pb = unittest_pb2.TestAllTypes()
    with self.assertRaises((ValueError, TypeError)):
      pb.optional_uint64 = integer_fn(-(1 << 63))
    # Enum fields accept a (possibly wrapped) integer value.
    pb = unittest_pb2.TestAllTypes()
    pb.optional_nested_enum = integer_fn(1)
    self.assertEqual(1, pb.optional_nested_enum)

  # The original, misspelled name; kept so existing callers keep working.
  assetIntegerBoundsChecking = assertIntegerBoundsChecking
  637. def testSingleScalarBoundsChecking(self):
  638. self.assetIntegerBoundsChecking(lambda x: x)
  639. def testNonStandardSingleScalarBoundsChecking(self):
  640. self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
  641. def testRepeatedScalarTypeSafety(self):
  642. proto = unittest_pb2.TestAllTypes()
  643. self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
  644. self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
  645. self.assertRaises(TypeError, proto.repeated_string, 10)
  646. self.assertRaises(TypeError, proto.repeated_bytes, 10)
  647. proto.repeated_int32.append(10)
  648. proto.repeated_int32[0] = 23
  649. self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
  650. self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
  651. self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
  652. self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
  653. 'index', 23)
  654. proto.repeated_string.append('2')
  655. self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
  656. # Repeated enums tests.
  657. #proto.repeated_nested_enum.append(0)
  658. def testSingleScalarGettersAndSetters(self):
  659. proto = unittest_pb2.TestAllTypes()
  660. self.assertEqual(0, proto.optional_int32)
  661. proto.optional_int32 = 1
  662. self.assertEqual(1, proto.optional_int32)
  663. proto.optional_uint64 = 0xffffffffffff
  664. self.assertEqual(0xffffffffffff, proto.optional_uint64)
  665. proto.optional_uint64 = 0xffffffffffffffff
  666. self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
  667. # TODO(robinson): Test all other scalar field types.
  668. def testSingleScalarClearField(self):
  669. proto = unittest_pb2.TestAllTypes()
  670. # Should be allowed to clear something that's not there (a no-op).
  671. proto.ClearField('optional_int32')
  672. proto.optional_int32 = 1
  673. self.assertTrue(proto.HasField('optional_int32'))
  674. proto.ClearField('optional_int32')
  675. self.assertEqual(0, proto.optional_int32)
  676. self.assertTrue(not proto.HasField('optional_int32'))
  677. # TODO(robinson): Test all other scalar field types.
  678. def testEnums(self):
  679. proto = unittest_pb2.TestAllTypes()
  680. self.assertEqual(1, proto.FOO)
  681. self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
  682. self.assertEqual(2, proto.BAR)
  683. self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
  684. self.assertEqual(3, proto.BAZ)
  685. self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
  686. def testEnum_Name(self):
  687. self.assertEqual('FOREIGN_FOO',
  688. unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
  689. self.assertEqual('FOREIGN_BAR',
  690. unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
  691. self.assertEqual('FOREIGN_BAZ',
  692. unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
  693. self.assertRaises(ValueError,
  694. unittest_pb2.ForeignEnum.Name, 11312)
  695. proto = unittest_pb2.TestAllTypes()
  696. self.assertEqual('FOO',
  697. proto.NestedEnum.Name(proto.FOO))
  698. self.assertEqual('FOO',
  699. unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
  700. self.assertEqual('BAR',
  701. proto.NestedEnum.Name(proto.BAR))
  702. self.assertEqual('BAR',
  703. unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
  704. self.assertEqual('BAZ',
  705. proto.NestedEnum.Name(proto.BAZ))
  706. self.assertEqual('BAZ',
  707. unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
  708. self.assertRaises(ValueError,
  709. proto.NestedEnum.Name, 11312)
  710. self.assertRaises(ValueError,
  711. unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
  712. def testEnum_Value(self):
  713. self.assertEqual(unittest_pb2.FOREIGN_FOO,
  714. unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
  715. self.assertEqual(unittest_pb2.FOREIGN_BAR,
  716. unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
  717. self.assertEqual(unittest_pb2.FOREIGN_BAZ,
  718. unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
  719. self.assertRaises(ValueError,
  720. unittest_pb2.ForeignEnum.Value, 'FO')
  721. proto = unittest_pb2.TestAllTypes()
  722. self.assertEqual(proto.FOO,
  723. proto.NestedEnum.Value('FOO'))
  724. self.assertEqual(proto.FOO,
  725. unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
  726. self.assertEqual(proto.BAR,
  727. proto.NestedEnum.Value('BAR'))
  728. self.assertEqual(proto.BAR,
  729. unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
  730. self.assertEqual(proto.BAZ,
  731. proto.NestedEnum.Value('BAZ'))
  732. self.assertEqual(proto.BAZ,
  733. unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
  734. self.assertRaises(ValueError,
  735. proto.NestedEnum.Value, 'Foo')
  736. self.assertRaises(ValueError,
  737. unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
  738. def testEnum_KeysAndValues(self):
  739. self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
  740. list(unittest_pb2.ForeignEnum.keys()))
  741. self.assertEqual([4, 5, 6],
  742. list(unittest_pb2.ForeignEnum.values()))
  743. self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
  744. ('FOREIGN_BAZ', 6)],
  745. list(unittest_pb2.ForeignEnum.items()))
  746. proto = unittest_pb2.TestAllTypes()
  747. self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
  748. self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
  749. self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
  750. list(proto.NestedEnum.items()))
  751. def testRepeatedScalars(self):
  752. proto = unittest_pb2.TestAllTypes()
  753. self.assertTrue(not proto.repeated_int32)
  754. self.assertEqual(0, len(proto.repeated_int32))
  755. proto.repeated_int32.append(5)
  756. proto.repeated_int32.append(10)
  757. proto.repeated_int32.append(15)
  758. self.assertTrue(proto.repeated_int32)
  759. self.assertEqual(3, len(proto.repeated_int32))
  760. self.assertEqual([5, 10, 15], proto.repeated_int32)
  761. # Test single retrieval.
  762. self.assertEqual(5, proto.repeated_int32[0])
  763. self.assertEqual(15, proto.repeated_int32[-1])
  764. # Test out-of-bounds indices.
  765. self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
  766. self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
  767. # Test incorrect types passed to __getitem__.
  768. self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
  769. self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
  770. # Test single assignment.
  771. proto.repeated_int32[1] = 20
  772. self.assertEqual([5, 20, 15], proto.repeated_int32)
  773. # Test insertion.
  774. proto.repeated_int32.insert(1, 25)
  775. self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
  776. # Test slice retrieval.
  777. proto.repeated_int32.append(30)
  778. self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
  779. self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
  780. # Test slice assignment with an iterator
  781. proto.repeated_int32[1:4] = (i for i in range(3))
  782. self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
  783. # Test slice assignment.
  784. proto.repeated_int32[1:4] = [35, 40, 45]
  785. self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
  786. # Test that we can use the field as an iterator.
  787. result = []
  788. for i in proto.repeated_int32:
  789. result.append(i)
  790. self.assertEqual([5, 35, 40, 45, 30], result)
  791. # Test single deletion.
  792. del proto.repeated_int32[2]
  793. self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
  794. # Test slice deletion.
  795. del proto.repeated_int32[2:]
  796. self.assertEqual([5, 35], proto.repeated_int32)
  797. # Test extending.
  798. proto.repeated_int32.extend([3, 13])
  799. self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
  800. # Test clearing.
  801. proto.ClearField('repeated_int32')
  802. self.assertTrue(not proto.repeated_int32)
  803. self.assertEqual(0, len(proto.repeated_int32))
  804. proto.repeated_int32.append(1)
  805. self.assertEqual(1, proto.repeated_int32[-1])
  806. # Test assignment to a negative index.
  807. proto.repeated_int32[-1] = 2
  808. self.assertEqual(2, proto.repeated_int32[-1])
  809. # Test deletion at negative indices.
  810. proto.repeated_int32[:] = [0, 1, 2, 3]
  811. del proto.repeated_int32[-1]
  812. self.assertEqual([0, 1, 2], proto.repeated_int32)
  813. del proto.repeated_int32[-2]
  814. self.assertEqual([0, 2], proto.repeated_int32)
  815. self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
  816. self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
  817. del proto.repeated_int32[-2:-1]
  818. self.assertEqual([2], proto.repeated_int32)
  819. del proto.repeated_int32[100:10000]
  820. self.assertEqual([2], proto.repeated_int32)
  821. def testRepeatedScalarsRemove(self):
  822. proto = unittest_pb2.TestAllTypes()
  823. self.assertTrue(not proto.repeated_int32)
  824. self.assertEqual(0, len(proto.repeated_int32))
  825. proto.repeated_int32.append(5)
  826. proto.repeated_int32.append(10)
  827. proto.repeated_int32.append(5)
  828. proto.repeated_int32.append(5)
  829. self.assertEqual(4, len(proto.repeated_int32))
  830. proto.repeated_int32.remove(5)
  831. self.assertEqual(3, len(proto.repeated_int32))
  832. self.assertEqual(10, proto.repeated_int32[0])
  833. self.assertEqual(5, proto.repeated_int32[1])
  834. self.assertEqual(5, proto.repeated_int32[2])
  835. proto.repeated_int32.remove(5)
  836. self.assertEqual(2, len(proto.repeated_int32))
  837. self.assertEqual(10, proto.repeated_int32[0])
  838. self.assertEqual(5, proto.repeated_int32[1])
  839. proto.repeated_int32.remove(10)
  840. self.assertEqual(1, len(proto.repeated_int32))
  841. self.assertEqual(5, proto.repeated_int32[0])
  842. # Remove a non-existent element.
  843. self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
  844. def testRepeatedComposites(self):
  845. proto = unittest_pb2.TestAllTypes()
  846. self.assertTrue(not proto.repeated_nested_message)
  847. self.assertEqual(0, len(proto.repeated_nested_message))
  848. m0 = proto.repeated_nested_message.add()
  849. m1 = proto.repeated_nested_message.add()
  850. self.assertTrue(proto.repeated_nested_message)
  851. self.assertEqual(2, len(proto.repeated_nested_message))
  852. self.assertListsEqual([m0, m1], proto.repeated_nested_message)
  853. self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
  854. # Test out-of-bounds indices.
  855. self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
  856. 1234)
  857. self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
  858. -1234)
  859. # Test incorrect types passed to __getitem__.
  860. self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
  861. 'foo')
  862. self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
  863. None)
  864. # Test slice retrieval.
  865. m2 = proto.repeated_nested_message.add()
  866. m3 = proto.repeated_nested_message.add()
  867. m4 = proto.repeated_nested_message.add()
  868. self.assertListsEqual(
  869. [m1, m2, m3], proto.repeated_nested_message[1:4])
  870. self.assertListsEqual(
  871. [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
  872. self.assertListsEqual(
  873. [m0, m1], proto.repeated_nested_message[:2])
  874. self.assertListsEqual(
  875. [m2, m3, m4], proto.repeated_nested_message[2:])
  876. self.assertEqual(
  877. m0, proto.repeated_nested_message[0])
  878. self.assertListsEqual(
  879. [m0], proto.repeated_nested_message[:1])
  880. # Test that we can use the field as an iterator.
  881. result = []
  882. for i in proto.repeated_nested_message:
  883. result.append(i)
  884. self.assertListsEqual([m0, m1, m2, m3, m4], result)
  885. # Test single deletion.
  886. del proto.repeated_nested_message[2]
  887. self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
  888. # Test slice deletion.
  889. del proto.repeated_nested_message[2:]
  890. self.assertListsEqual([m0, m1], proto.repeated_nested_message)
  891. # Test extending.
  892. n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
  893. n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
  894. proto.repeated_nested_message.extend([n1,n2])
  895. self.assertEqual(4, len(proto.repeated_nested_message))
  896. self.assertEqual(n1, proto.repeated_nested_message[2])
  897. self.assertEqual(n2, proto.repeated_nested_message[3])
  898. self.assertRaises(TypeError,
  899. proto.repeated_nested_message.extend, n1)
  900. self.assertRaises(TypeError,
  901. proto.repeated_nested_message.extend, [0])
  902. wrong_message_type = unittest_pb2.TestAllTypes()
  903. self.assertRaises(TypeError,
  904. proto.repeated_nested_message.extend,
  905. [wrong_message_type])
  906. # Test clearing.
  907. proto.ClearField('repeated_nested_message')
  908. self.assertTrue(not proto.repeated_nested_message)
  909. self.assertEqual(0, len(proto.repeated_nested_message))
  910. # Test constructing an element while adding it.
  911. proto.repeated_nested_message.add(bb=23)
  912. self.assertEqual(1, len(proto.repeated_nested_message))
  913. self.assertEqual(23, proto.repeated_nested_message[0].bb)
  914. self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
  915. with self.assertRaises(Exception):
  916. proto.repeated_nested_message[0] = 23
  917. def testRepeatedCompositeRemove(self):
  918. proto = unittest_pb2.TestAllTypes()
  919. self.assertEqual(0, len(proto.repeated_nested_message))
  920. m0 = proto.repeated_nested_message.add()
  921. # Need to set some differentiating variable so m0 != m1 != m2:
  922. m0.bb = len(proto.repeated_nested_message)
  923. m1 = proto.repeated_nested_message.add()
  924. m1.bb = len(proto.repeated_nested_message)
  925. self.assertTrue(m0 != m1)
  926. m2 = proto.repeated_nested_message.add()
  927. m2.bb = len(proto.repeated_nested_message)
  928. self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
  929. self.assertEqual(3, len(proto.repeated_nested_message))
  930. proto.repeated_nested_message.remove(m0)
  931. self.assertEqual(2, len(proto.repeated_nested_message))
  932. self.assertEqual(m1, proto.repeated_nested_message[0])
  933. self.assertEqual(m2, proto.repeated_nested_message[1])
  934. # Removing m0 again or removing None should raise error
  935. self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
  936. self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
  937. self.assertEqual(2, len(proto.repeated_nested_message))
  938. proto.repeated_nested_message.remove(m2)
  939. self.assertEqual(1, len(proto.repeated_nested_message))
  940. self.assertEqual(m1, proto.repeated_nested_message[0])
  941. def testHandWrittenReflection(self):
  942. # Hand written extensions are only supported by the pure-Python
  943. # implementation of the API.
  944. if api_implementation.Type() != 'python':
  945. return
  946. FieldDescriptor = descriptor.FieldDescriptor
  947. foo_field_descriptor = FieldDescriptor(
  948. name='foo_field', full_name='MyProto.foo_field',
  949. index=0, number=1, type=FieldDescriptor.TYPE_INT64,
  950. cpp_type=FieldDescriptor.CPPTYPE_INT64,
  951. label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
  952. containing_type=None, message_type=None, enum_type=None,
  953. is_extension=False, extension_scope=None,
  954. options=descriptor_pb2.FieldOptions())
  955. mydescriptor = descriptor.Descriptor(
  956. name='MyProto', full_name='MyProto', filename='ignored',
  957. containing_type=None, nested_types=[], enum_types=[],
  958. fields=[foo_field_descriptor], extensions=[],
  959. options=descriptor_pb2.MessageOptions())
  960. class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
  961. DESCRIPTOR = mydescriptor
  962. myproto_instance = MyProtoClass()
  963. self.assertEqual(0, myproto_instance.foo_field)
  964. self.assertTrue(not myproto_instance.HasField('foo_field'))
  965. myproto_instance.foo_field = 23
  966. self.assertEqual(23, myproto_instance.foo_field)
  967. self.assertTrue(myproto_instance.HasField('foo_field'))
  968. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  969. def testDescriptorProtoSupport(self):
  970. # Hand written descriptors/reflection are only supported by the pure-Python
  971. # implementation of the API.
  972. if api_implementation.Type() != 'python':
  973. return
  974. def AddDescriptorField(proto, field_name, field_type):
  975. AddDescriptorField.field_index += 1
  976. new_field = proto.field.add()
  977. new_field.name = field_name
  978. new_field.type = field_type
  979. new_field.number = AddDescriptorField.field_index
  980. new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
  981. AddDescriptorField.field_index = 0
  982. desc_proto = descriptor_pb2.DescriptorProto()
  983. desc_proto.name = 'Car'
  984. fdp = descriptor_pb2.FieldDescriptorProto
  985. AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
  986. AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
  987. AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
  988. AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
  989. # Add a repeated field
  990. AddDescriptorField.field_index += 1
  991. new_field = desc_proto.field.add()
  992. new_field.name = 'owners'
  993. new_field.type = fdp.TYPE_STRING
  994. new_field.number = AddDescriptorField.field_index
  995. new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
  996. desc = descriptor.MakeDescriptor(desc_proto)
  997. self.assertTrue('name' in desc.fields_by_name)
  998. self.assertTrue('year' in desc.fields_by_name)
  999. self.assertTrue('automatic' in desc.fields_by_name)
  1000. self.assertTrue('price' in desc.fields_by_name)
  1001. self.assertTrue('owners' in desc.fields_by_name)
  1002. class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
  1003. message.Message)):
  1004. DESCRIPTOR = desc
  1005. prius = CarMessage()
  1006. prius.name = 'prius'
  1007. prius.year = 2010
  1008. prius.automatic = True
  1009. prius.price = 25134.75
  1010. prius.owners.extend(['bob', 'susan'])
  1011. serialized_prius = prius.SerializeToString()
  1012. new_prius = reflection.ParseMessage(desc, serialized_prius)
  1013. self.assertTrue(new_prius is not prius)
  1014. self.assertEqual(prius, new_prius)
  1015. # these are unnecessary assuming message equality works as advertised but
  1016. # explicitly check to be safe since we're mucking about in metaclass foo
  1017. self.assertEqual(prius.name, new_prius.name)
  1018. self.assertEqual(prius.year, new_prius.year)
  1019. self.assertEqual(prius.automatic, new_prius.automatic)
  1020. self.assertEqual(prius.price, new_prius.price)
  1021. self.assertEqual(prius.owners, new_prius.owners)
  1022. def testTopLevelExtensionsForOptionalScalar(self):
  1023. extendee_proto = unittest_pb2.TestAllExtensions()
  1024. extension = unittest_pb2.optional_int32_extension
  1025. self.assertTrue(not extendee_proto.HasExtension(extension))
  1026. self.assertEqual(0, extendee_proto.Extensions[extension])
  1027. # As with normal scalar fields, just doing a read doesn't actually set the
  1028. # "has" bit.
  1029. self.assertTrue(not extendee_proto.HasExtension(extension))
  1030. # Actually set the thing.
  1031. extendee_proto.Extensions[extension] = 23
  1032. self.assertEqual(23, extendee_proto.Extensions[extension])
  1033. self.assertTrue(extendee_proto.HasExtension(extension))
  1034. # Ensure that clearing works as well.
  1035. extendee_proto.ClearExtension(extension)
  1036. self.assertEqual(0, extendee_proto.Extensions[extension])
  1037. self.assertTrue(not extendee_proto.HasExtension(extension))
  1038. def testTopLevelExtensionsForRepeatedScalar(self):
  1039. extendee_proto = unittest_pb2.TestAllExtensions()
  1040. extension = unittest_pb2.repeated_string_extension
  1041. self.assertEqual(0, len(extendee_proto.Extensions[extension]))
  1042. extendee_proto.Extensions[extension].append('foo')
  1043. self.assertEqual(['foo'], extendee_proto.Extensions[extension])
  1044. string_list = extendee_proto.Extensions[extension]
  1045. extendee_proto.ClearExtension(extension)
  1046. self.assertEqual(0, len(extendee_proto.Extensions[extension]))
  1047. self.assertTrue(string_list is not extendee_proto.Extensions[extension])
  1048. # Shouldn't be allowed to do Extensions[extension] = 'a'
  1049. self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
  1050. extension, 'a')
  1051. def testTopLevelExtensionsForOptionalMessage(self):
  1052. extendee_proto = unittest_pb2.TestAllExtensions()
  1053. extension = unittest_pb2.optional_foreign_message_extension
  1054. self.assertTrue(not extendee_proto.HasExtension(extension))
  1055. self.assertEqual(0, extendee_proto.Extensions[extension].c)
  1056. # As with normal (non-extension) fields, merely reading from the
  1057. # thing shouldn't set the "has" bit.
  1058. self.assertTrue(not extendee_proto.HasExtension(extension))
  1059. extendee_proto.Extensions[extension].c = 23
  1060. self.assertEqual(23, extendee_proto.Extensions[extension].c)
  1061. self.assertTrue(extendee_proto.HasExtension(extension))
  1062. # Save a reference here.
  1063. foreign_message = extendee_proto.Extensions[extension]
  1064. extendee_proto.ClearExtension(extension)
  1065. self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
  1066. # Setting a field on foreign_message now shouldn't set
  1067. # any "has" bits on extendee_proto.
  1068. foreign_message.c = 42
  1069. self.assertEqual(42, foreign_message.c)
  1070. self.assertTrue(foreign_message.HasField('c'))
  1071. self.assertTrue(not extendee_proto.HasExtension(extension))
  1072. # Shouldn't be allowed to do Extensions[extension] = 'a'
  1073. self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
  1074. extension, 'a')
  1075. def testTopLevelExtensionsForRepeatedMessage(self):
  1076. extendee_proto = unittest_pb2.TestAllExtensions()
  1077. extension = unittest_pb2.repeatedgroup_extension
  1078. self.assertEqual(0, len(extendee_proto.Extensions[extension]))
  1079. group = extendee_proto.Extensions[extension].add()
  1080. group.a = 23
  1081. self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
  1082. group.a = 42
  1083. self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
  1084. group_list = extendee_proto.Extensions[extension]
  1085. extendee_proto.ClearExtension(extension)
  1086. self.assertEqual(0, len(extendee_proto.Extensions[extension]))
  1087. self.assertTrue(group_list is not extendee_proto.Extensions[extension])
  1088. # Shouldn't be allowed to do Extensions[extension] = 'a'
  1089. self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
  1090. extension, 'a')
  1091. def testNestedExtensions(self):
  1092. extendee_proto = unittest_pb2.TestAllExtensions()
  1093. extension = unittest_pb2.TestRequired.single
  1094. # We just test the non-repeated case.
  1095. self.assertTrue(not extendee_proto.HasExtension(extension))
  1096. required = extendee_proto.Extensions[extension]
  1097. self.assertEqual(0, required.a)
  1098. self.assertTrue(not extendee_proto.HasExtension(extension))
  1099. required.a = 23
  1100. self.assertEqual(23, extendee_proto.Extensions[extension].a)
  1101. self.assertTrue(extendee_proto.HasExtension(extension))
  1102. extendee_proto.ClearExtension(extension)
  1103. self.assertTrue(required is not extendee_proto.Extensions[extension])
  1104. self.assertTrue(not extendee_proto.HasExtension(extension))
  1105. def testRegisteredExtensions(self):
  1106. pool = unittest_pb2.DESCRIPTOR.pool
  1107. self.assertTrue(
  1108. pool.FindExtensionByNumber(
  1109. unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
  1110. self.assertIs(
  1111. pool.FindExtensionByName(
  1112. 'protobuf_unittest.optional_int32_extension').containing_type,
  1113. unittest_pb2.TestAllExtensions.DESCRIPTOR)
  1114. # Make sure extensions haven't been registered into types that shouldn't
  1115. # have any.
  1116. self.assertEqual(0, len(
  1117. pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
  1118. # If message A directly contains message B, and
  1119. # a.HasField('b') is currently False, then mutating any
  1120. # extension in B should change a.HasField('b') to True
  1121. # (and so on up the object tree).
  1122. def testHasBitsForAncestorsOfExtendedMessage(self):
  1123. # Optional scalar extension.
  1124. toplevel = more_extensions_pb2.TopLevelMessage()
  1125. self.assertTrue(not toplevel.HasField('submessage'))
  1126. self.assertEqual(0, toplevel.submessage.Extensions[
  1127. more_extensions_pb2.optional_int_extension])
  1128. self.assertTrue(not toplevel.HasField('submessage'))
  1129. toplevel.submessage.Extensions[
  1130. more_extensions_pb2.optional_int_extension] = 23
  1131. self.assertEqual(23, toplevel.submessage.Extensions[
  1132. more_extensions_pb2.optional_int_extension])
  1133. self.assertTrue(toplevel.HasField('submessage'))
  1134. # Repeated scalar extension.
  1135. toplevel = more_extensions_pb2.TopLevelMessage()
  1136. self.assertTrue(not toplevel.HasField('submessage'))
  1137. self.assertEqual([], toplevel.submessage.Extensions[
  1138. more_extensions_pb2.repeated_int_extension])
  1139. self.assertTrue(not toplevel.HasField('submessage'))
  1140. toplevel.submessage.Extensions[
  1141. more_extensions_pb2.repeated_int_extension].append(23)
  1142. self.assertEqual([23], toplevel.submessage.Extensions[
  1143. more_extensions_pb2.repeated_int_extension])
  1144. self.assertTrue(toplevel.HasField('submessage'))
  1145. # Optional message extension.
  1146. toplevel = more_extensions_pb2.TopLevelMessage()
  1147. self.assertTrue(not toplevel.HasField('submessage'))
  1148. self.assertEqual(0, toplevel.submessage.Extensions[
  1149. more_extensions_pb2.optional_message_extension].foreign_message_int)
  1150. self.assertTrue(not toplevel.HasField('submessage'))
  1151. toplevel.submessage.Extensions[
  1152. more_extensions_pb2.optional_message_extension].foreign_message_int = 23
  1153. self.assertEqual(23, toplevel.submessage.Extensions[
  1154. more_extensions_pb2.optional_message_extension].foreign_message_int)
  1155. self.assertTrue(toplevel.HasField('submessage'))
  1156. # Repeated message extension.
  1157. toplevel = more_extensions_pb2.TopLevelMessage()
  1158. self.assertTrue(not toplevel.HasField('submessage'))
  1159. self.assertEqual(0, len(toplevel.submessage.Extensions[
  1160. more_extensions_pb2.repeated_message_extension]))
  1161. self.assertTrue(not toplevel.HasField('submessage'))
  1162. foreign = toplevel.submessage.Extensions[
  1163. more_extensions_pb2.repeated_message_extension].add()
  1164. self.assertEqual(foreign, toplevel.submessage.Extensions[
  1165. more_extensions_pb2.repeated_message_extension][0])
  1166. self.assertTrue(toplevel.HasField('submessage'))
  1167. def testDisconnectionAfterClearingEmptyMessage(self):
  1168. toplevel = more_extensions_pb2.TopLevelMessage()
  1169. extendee_proto = toplevel.submessage
  1170. extension = more_extensions_pb2.optional_message_extension
  1171. extension_proto = extendee_proto.Extensions[extension]
  1172. extendee_proto.ClearExtension(extension)
  1173. extension_proto.foreign_message_int = 23
  1174. self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
  1175. def testExtensionFailureModes(self):
  1176. extendee_proto = unittest_pb2.TestAllExtensions()
  1177. # Try non-extension-handle arguments to HasExtension,
  1178. # ClearExtension(), and Extensions[]...
  1179. self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
  1180. self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
  1181. self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
  1182. self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
  1183. # Try something that *is* an extension handle, just not for
  1184. # this message...
  1185. for unknown_handle in (more_extensions_pb2.optional_int_extension,
  1186. more_extensions_pb2.optional_message_extension,
  1187. more_extensions_pb2.repeated_int_extension,
  1188. more_extensions_pb2.repeated_message_extension):
  1189. self.assertRaises(KeyError, extendee_proto.HasExtension,
  1190. unknown_handle)
  1191. self.assertRaises(KeyError, extendee_proto.ClearExtension,
  1192. unknown_handle)
  1193. self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
  1194. unknown_handle)
  1195. self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
  1196. unknown_handle, 5)
  1197. # Try call HasExtension() with a valid handle, but for a
  1198. # *repeated* field. (Just as with non-extension repeated
  1199. # fields, Has*() isn't supported for extension repeated fields).
  1200. self.assertRaises(KeyError, extendee_proto.HasExtension,
  1201. unittest_pb2.repeated_string_extension)
  1202. def testStaticParseFrom(self):
  1203. proto1 = unittest_pb2.TestAllTypes()
  1204. test_util.SetAllFields(proto1)
  1205. string1 = proto1.SerializeToString()
  1206. proto2 = unittest_pb2.TestAllTypes.FromString(string1)
  1207. # Messages should be equal.
  1208. self.assertEqual(proto2, proto1)
  1209. def testMergeFromSingularField(self):
  1210. # Test merge with just a singular field.
  1211. proto1 = unittest_pb2.TestAllTypes()
  1212. proto1.optional_int32 = 1
  1213. proto2 = unittest_pb2.TestAllTypes()
  1214. # This shouldn't get overwritten.
  1215. proto2.optional_string = 'value'
  1216. proto2.MergeFrom(proto1)
  1217. self.assertEqual(1, proto2.optional_int32)
  1218. self.assertEqual('value', proto2.optional_string)
  1219. def testMergeFromRepeatedField(self):
  1220. # Test merge with just a repeated field.
  1221. proto1 = unittest_pb2.TestAllTypes()
  1222. proto1.repeated_int32.append(1)
  1223. proto1.repeated_int32.append(2)
  1224. proto2 = unittest_pb2.TestAllTypes()
  1225. proto2.repeated_int32.append(0)
  1226. proto2.MergeFrom(proto1)
  1227. self.assertEqual(0, proto2.repeated_int32[0])
  1228. self.assertEqual(1, proto2.repeated_int32[1])
  1229. self.assertEqual(2, proto2.repeated_int32[2])
  1230. def testMergeFromOptionalGroup(self):
  1231. # Test merge with an optional group.
  1232. proto1 = unittest_pb2.TestAllTypes()
  1233. proto1.optionalgroup.a = 12
  1234. proto2 = unittest_pb2.TestAllTypes()
  1235. proto2.MergeFrom(proto1)
  1236. self.assertEqual(12, proto2.optionalgroup.a)
  1237. def testMergeFromRepeatedNestedMessage(self):
  1238. # Test merge with a repeated nested message.
  1239. proto1 = unittest_pb2.TestAllTypes()
  1240. m = proto1.repeated_nested_message.add()
  1241. m.bb = 123
  1242. m = proto1.repeated_nested_message.add()
  1243. m.bb = 321
  1244. proto2 = unittest_pb2.TestAllTypes()
  1245. m = proto2.repeated_nested_message.add()
  1246. m.bb = 999
  1247. proto2.MergeFrom(proto1)
  1248. self.assertEqual(999, proto2.repeated_nested_message[0].bb)
  1249. self.assertEqual(123, proto2.repeated_nested_message[1].bb)
  1250. self.assertEqual(321, proto2.repeated_nested_message[2].bb)
  1251. proto3 = unittest_pb2.TestAllTypes()
  1252. proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
  1253. self.assertEqual(999, proto3.repeated_nested_message[0].bb)
  1254. self.assertEqual(123, proto3.repeated_nested_message[1].bb)
  1255. self.assertEqual(321, proto3.repeated_nested_message[2].bb)
  1256. def testMergeFromAllFields(self):
  1257. # With all fields set.
  1258. proto1 = unittest_pb2.TestAllTypes()
  1259. test_util.SetAllFields(proto1)
  1260. proto2 = unittest_pb2.TestAllTypes()
  1261. proto2.MergeFrom(proto1)
  1262. # Messages should be equal.
  1263. self.assertEqual(proto2, proto1)
  1264. # Serialized string should be equal too.
  1265. string1 = proto1.SerializeToString()
  1266. string2 = proto2.SerializeToString()
  1267. self.assertEqual(string1, string2)
  1268. def testMergeFromExtensionsSingular(self):
  1269. proto1 = unittest_pb2.TestAllExtensions()
  1270. proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
  1271. proto2 = unittest_pb2.TestAllExtensions()
  1272. proto2.MergeFrom(proto1)
  1273. self.assertEqual(
  1274. 1, proto2.Extensions[unittest_pb2.optional_int32_extension])
  1275. def testMergeFromExtensionsRepeated(self):
  1276. proto1 = unittest_pb2.TestAllExtensions()
  1277. proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
  1278. proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
  1279. proto2 = unittest_pb2.TestAllExtensions()
  1280. proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
  1281. proto2.MergeFrom(proto1)
  1282. self.assertEqual(
  1283. 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
  1284. self.assertEqual(
  1285. 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
  1286. self.assertEqual(
  1287. 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
  1288. self.assertEqual(
  1289. 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
  1290. def testMergeFromExtensionsNestedMessage(self):
  1291. proto1 = unittest_pb2.TestAllExtensions()
  1292. ext1 = proto1.Extensions[
  1293. unittest_pb2.repeated_nested_message_extension]
  1294. m = ext1.add()
  1295. m.bb = 222
  1296. m = ext1.add()
  1297. m.bb = 333
  1298. proto2 = unittest_pb2.TestAllExtensions()
  1299. ext2 = proto2.Extensions[
  1300. unittest_pb2.repeated_nested_message_extension]
  1301. m = ext2.add()
  1302. m.bb = 111
  1303. proto2.MergeFrom(proto1)
  1304. ext2 = proto2.Extensions[
  1305. unittest_pb2.repeated_nested_message_extension]
  1306. self.assertEqual(3, len(ext2))
  1307. self.assertEqual(111, ext2[0].bb)
  1308. self.assertEqual(222, ext2[1].bb)
  1309. self.assertEqual(333, ext2[2].bb)
  1310. def testMergeFromBug(self):
  1311. message1 = unittest_pb2.TestAllTypes()
  1312. message2 = unittest_pb2.TestAllTypes()
  1313. # Cause optional_nested_message to be instantiated within message1, even
  1314. # though it is not considered to be "present".
  1315. message1.optional_nested_message
  1316. self.assertFalse(message1.HasField('optional_nested_message'))
  1317. # Merge into message2. This should not instantiate the field is message2.
  1318. message2.MergeFrom(message1)
  1319. self.assertFalse(message2.HasField('optional_nested_message'))
  1320. def testCopyFromSingularField(self):
  1321. # Test copy with just a singular field.
  1322. proto1 = unittest_pb2.TestAllTypes()
  1323. proto1.optional_int32 = 1
  1324. proto1.optional_string = 'important-text'
  1325. proto2 = unittest_pb2.TestAllTypes()
  1326. proto2.optional_string = 'value'
  1327. proto2.CopyFrom(proto1)
  1328. self.assertEqual(1, proto2.optional_int32)
  1329. self.assertEqual('important-text', proto2.optional_string)
  1330. def testCopyFromRepeatedField(self):
  1331. # Test copy with a repeated field.
  1332. proto1 = unittest_pb2.TestAllTypes()
  1333. proto1.repeated_int32.append(1)
  1334. proto1.repeated_int32.append(2)
  1335. proto2 = unittest_pb2.TestAllTypes()
  1336. proto2.repeated_int32.append(0)
  1337. proto2.CopyFrom(proto1)
  1338. self.assertEqual(1, proto2.repeated_int32[0])
  1339. self.assertEqual(2, proto2.repeated_int32[1])
  1340. def testCopyFromAllFields(self):
  1341. # With all fields set.
  1342. proto1 = unittest_pb2.TestAllTypes()
  1343. test_util.SetAllFields(proto1)
  1344. proto2 = unittest_pb2.TestAllTypes()
  1345. proto2.CopyFrom(proto1)
  1346. # Messages should be equal.
  1347. self.assertEqual(proto2, proto1)
  1348. # Serialized string should be equal too.
  1349. string1 = proto1.SerializeToString()
  1350. string2 = proto2.SerializeToString()
  1351. self.assertEqual(string1, string2)
  1352. def testCopyFromSelf(self):
  1353. proto1 = unittest_pb2.TestAllTypes()
  1354. proto1.repeated_int32.append(1)
  1355. proto1.optional_int32 = 2
  1356. proto1.optional_string = 'important-text'
  1357. proto1.CopyFrom(proto1)
  1358. self.assertEqual(1, proto1.repeated_int32[0])
  1359. self.assertEqual(2, proto1.optional_int32)
  1360. self.assertEqual('important-text', proto1.optional_string)
  1361. def testCopyFromBadType(self):
  1362. # The python implementation doesn't raise an exception in this
  1363. # case. In theory it should.
  1364. if api_implementation.Type() == 'python':
  1365. return
  1366. proto1 = unittest_pb2.TestAllTypes()
  1367. proto2 = unittest_pb2.TestAllExtensions()
  1368. self.assertRaises(TypeError, proto1.CopyFrom, proto2)
  1369. def testDeepCopy(self):
  1370. proto1 = unittest_pb2.TestAllTypes()
  1371. proto1.optional_int32 = 1
  1372. proto2 = copy.deepcopy(proto1)
  1373. self.assertEqual(1, proto2.optional_int32)
  1374. proto1.repeated_int32.append(2)
  1375. proto1.repeated_int32.append(3)
  1376. container = copy.deepcopy(proto1.repeated_int32)
  1377. self.assertEqual([2, 3], container)
  1378. message1 = proto1.repeated_nested_message.add()
  1379. message1.bb = 1
  1380. messages = copy.deepcopy(proto1.repeated_nested_message)
  1381. self.assertEqual(proto1.repeated_nested_message, messages)
  1382. message1.bb = 2
  1383. self.assertNotEqual(proto1.repeated_nested_message, messages)
  1384. # TODO(anuraag): Implement deepcopy for extension dict
  1385. def testClear(self):
  1386. proto = unittest_pb2.TestAllTypes()
  1387. # C++ implementation does not support lazy fields right now so leave it
  1388. # out for now.
  1389. if api_implementation.Type() == 'python':
  1390. test_util.SetAllFields(proto)
  1391. else:
  1392. test_util.SetAllNonLazyFields(proto)
  1393. # Clear the message.
  1394. proto.Clear()
  1395. self.assertEqual(proto.ByteSize(), 0)
  1396. empty_proto = unittest_pb2.TestAllTypes()
  1397. self.assertEqual(proto, empty_proto)
  1398. # Test if extensions which were set are cleared.
  1399. proto = unittest_pb2.TestAllExtensions()
  1400. test_util.SetAllExtensions(proto)
  1401. # Clear the message.
  1402. proto.Clear()
  1403. self.assertEqual(proto.ByteSize(), 0)
  1404. empty_proto = unittest_pb2.TestAllExtensions()
  1405. self.assertEqual(proto, empty_proto)
  1406. def testDisconnectingBeforeClear(self):
  1407. proto = unittest_pb2.TestAllTypes()
  1408. nested = proto.optional_nested_message
  1409. proto.Clear()
  1410. self.assertTrue(nested is not proto.optional_nested_message)
  1411. nested.bb = 23
  1412. self.assertTrue(not proto.HasField('optional_nested_message'))
  1413. self.assertEqual(0, proto.optional_nested_message.bb)
  1414. proto = unittest_pb2.TestAllTypes()
  1415. nested = proto.optional_nested_message
  1416. nested.bb = 5
  1417. foreign = proto.optional_foreign_message
  1418. foreign.c = 6
  1419. proto.Clear()
  1420. self.assertTrue(nested is not proto.optional_nested_message)
  1421. self.assertTrue(foreign is not proto.optional_foreign_message)
  1422. self.assertEqual(5, nested.bb)
  1423. self.assertEqual(6, foreign.c)
  1424. nested.bb = 15
  1425. foreign.c = 16
  1426. self.assertFalse(proto.HasField('optional_nested_message'))
  1427. self.assertEqual(0, proto.optional_nested_message.bb)
  1428. self.assertFalse(proto.HasField('optional_foreign_message'))
  1429. self.assertEqual(0, proto.optional_foreign_message.c)
  1430. def testDisconnectingInOneof(self):
  1431. m = unittest_pb2.TestOneof2() # This message has two messages in a oneof.
  1432. m.foo_message.qux_int = 5
  1433. sub_message = m.foo_message
  1434. # Accessing another message's field does not clear the first one
  1435. self.assertEqual(m.foo_lazy_message.qux_int, 0)
  1436. self.assertEqual(m.foo_message.qux_int, 5)
  1437. # But mutating another message in the oneof detaches the first one.
  1438. m.foo_lazy_message.qux_int = 6
  1439. self.assertEqual(m.foo_message.qux_int, 0)
  1440. # The reference we got above was detached and is still valid.
  1441. self.assertEqual(sub_message.qux_int, 5)
  1442. sub_message.qux_int = 7
  1443. def testOneOf(self):
  1444. proto = unittest_pb2.TestAllTypes()
  1445. proto.oneof_uint32 = 10
  1446. proto.oneof_nested_message.bb = 11
  1447. self.assertEqual(11, proto.oneof_nested_message.bb)
  1448. self.assertFalse(proto.HasField('oneof_uint32'))
  1449. nested = proto.oneof_nested_message
  1450. proto.oneof_string = 'abc'
  1451. self.assertEqual('abc', proto.oneof_string)
  1452. self.assertEqual(11, nested.bb)
  1453. self.assertFalse(proto.HasField('oneof_nested_message'))
  1454. def assertInitialized(self, proto):
  1455. self.assertTrue(proto.IsInitialized())
  1456. # Neither method should raise an exception.
  1457. proto.SerializeToString()
  1458. proto.SerializePartialToString()
  1459. def assertNotInitialized(self, proto, error_size=None):
  1460. errors = []
  1461. self.assertFalse(proto.IsInitialized())
  1462. self.assertFalse(proto.IsInitialized(errors))
  1463. self.assertEqual(error_size, len(errors))
  1464. self.assertRaises(message.EncodeError, proto.SerializeToString)
  1465. # "Partial" serialization doesn't care if message is uninitialized.
  1466. proto.SerializePartialToString()
def testIsInitialized(self):
  """Checks IsInitialized() and its error list for required fields.

  Covers plain messages, required fields inside singular and repeated
  sub-messages, and required fields inside extensions.
  """
  # Trivial cases - all optional fields and extensions.
  proto = unittest_pb2.TestAllTypes()
  self.assertInitialized(proto)
  proto = unittest_pb2.TestAllExtensions()
  self.assertInitialized(proto)
  # The case of uninitialized required fields: a, b and c are all unset,
  # giving 3 errors.
  proto = unittest_pb2.TestRequired()
  self.assertNotInitialized(proto, 3)
  proto.a = proto.b = proto.c = 2
  self.assertInitialized(proto)
  # The case of uninitialized submessage. An untouched optional sub-message
  # does not count against initialization...
  proto = unittest_pb2.TestRequiredForeign()
  self.assertInitialized(proto)
  # ...but once it is set, its two still-missing required fields are reported.
  proto.optional_message.a = 1
  self.assertNotInitialized(proto, 2)
  proto.optional_message.b = 0
  proto.optional_message.c = 0
  self.assertInitialized(proto)
  # Uninitialized repeated submessage: a fresh element misses all 3 fields.
  message1 = proto.repeated_message.add()
  self.assertNotInitialized(proto, 3)
  message1.a = message1.b = message1.c = 0
  self.assertInitialized(proto)
  # Uninitialized repeated group in an extension: two empty elements with
  # three missing fields each give 6 errors.
  proto = unittest_pb2.TestAllExtensions()
  extension = unittest_pb2.TestRequired.multi
  message1 = proto.Extensions[extension].add()
  message2 = proto.Extensions[extension].add()
  self.assertNotInitialized(proto, 6)
  message1.a = 1
  message1.b = 1
  message1.c = 1
  self.assertNotInitialized(proto, 3)
  message2.a = 2
  message2.b = 2
  message2.c = 2
  self.assertInitialized(proto)
  # Uninitialized nonrepeated message in an extension.
  proto = unittest_pb2.TestAllExtensions()
  extension = unittest_pb2.TestRequired.single
  proto.Extensions[extension].a = 1
  self.assertNotInitialized(proto, 2)
  proto.Extensions[extension].b = 2
  proto.Extensions[extension].c = 3
  self.assertInitialized(proto)
  # Try passing an errors list: it is filled with the names of the missing
  # required fields.
  errors = []
  proto = unittest_pb2.TestRequired()
  self.assertFalse(proto.IsInitialized(errors))
  self.assertEqual(errors, ['a', 'b', 'c'])
  1518. @unittest.skipIf(
  1519. api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
  1520. 'Errors are only available from the most recent C++ implementation.')
  1521. def testFileDescriptorErrors(self):
  1522. file_name = 'test_file_descriptor_errors.proto'
  1523. package_name = 'test_file_descriptor_errors.proto'
  1524. file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
  1525. file_descriptor_proto.name = file_name
  1526. file_descriptor_proto.package = package_name
  1527. m1 = file_descriptor_proto.message_type.add()
  1528. m1.name = 'msg1'
  1529. # Compiles the proto into the C++ descriptor pool
  1530. descriptor.FileDescriptor(
  1531. file_name,
  1532. package_name,
  1533. serialized_pb=file_descriptor_proto.SerializeToString())
  1534. # Add a FileDescriptorProto that has duplicate symbols
  1535. another_file_name = 'another_test_file_descriptor_errors.proto'
  1536. file_descriptor_proto.name = another_file_name
  1537. m2 = file_descriptor_proto.message_type.add()
  1538. m2.name = 'msg2'
  1539. with self.assertRaises(TypeError) as cm:
  1540. descriptor.FileDescriptor(
  1541. another_file_name,
  1542. package_name,
  1543. serialized_pb=file_descriptor_proto.SerializeToString())
  1544. self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
  1545. getattr(cm.expected, '__name__', cm.expected))
  1546. self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
  1547. # Error message will say something about this definition being a
  1548. # duplicate, though we don't check the message exactly to avoid a
  1549. # dependency on the C++ logging code.
  1550. self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
  1551. def testStringUTF8Encoding(self):
  1552. proto = unittest_pb2.TestAllTypes()
  1553. # Assignment of a unicode object to a field of type 'bytes' is not allowed.
  1554. self.assertRaises(TypeError,
  1555. setattr, proto, 'optional_bytes', u'unicode object')
  1556. # Check that the default value is of python's 'unicode' type.
  1557. self.assertEqual(type(proto.optional_string), six.text_type)
  1558. proto.optional_string = six.text_type('Testing')
  1559. self.assertEqual(proto.optional_string, str('Testing'))
  1560. # Assign a value of type 'str' which can be encoded in UTF-8.
  1561. proto.optional_string = str('Testing')
  1562. self.assertEqual(proto.optional_string, six.text_type('Testing'))
  1563. # Try to assign a 'bytes' object which contains non-UTF-8.
  1564. self.assertRaises(ValueError,
  1565. setattr, proto, 'optional_string', b'a\x80a')
  1566. # No exception: Assign already encoded UTF-8 bytes to a string field.
  1567. utf8_bytes = u'Тест'.encode('utf-8')
  1568. proto.optional_string = utf8_bytes
  1569. # No exception: Assign the a non-ascii unicode object.
  1570. proto.optional_string = u'Тест'
  1571. # No exception thrown (normal str assignment containing ASCII).
  1572. proto.optional_string = 'abc'
def testStringUTF8Serialization(self):
  """Round-trips a non-ASCII string through the MessageSet wire format.

  Also checks how parsing reacts when the string's bytes are not valid UTF-8.
  """
  proto = message_set_extensions_pb2.TestMessageSet()
  extension_message = message_set_extensions_pb2.TestMessageSetExtension2
  extension = extension_message.message_set_extension
  # 'Test' in another language, using UTF-8 charset.
  test_utf8 = u'Тест'
  test_utf8_bytes = test_utf8.encode('utf-8')
  proto.Extensions[extension].str = test_utf8
  # Serialize using the MessageSet wire format (this is specified in the
  # .proto file).
  serialized = proto.SerializeToString()
  # Check byte size.
  self.assertEqual(proto.ByteSize(), len(serialized))
  # Re-parse as a raw MessageSet so the individual items can be inspected.
  raw = unittest_mset_pb2.RawMessageSet()
  bytes_read = raw.MergeFromString(serialized)
  self.assertEqual(len(serialized), bytes_read)
  message2 = message_set_extensions_pb2.TestMessageSetExtension2()
  self.assertEqual(1, len(raw.item))
  # Check that the type_id is the same as the tag ID in the .proto file.
  self.assertEqual(raw.item[0].type_id, 98418634)
  # Check the actual bytes on the wire.
  self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
  bytes_read = message2.MergeFromString(raw.item[0].message)
  self.assertEqual(len(raw.item[0].message), bytes_read)
  self.assertEqual(type(message2.str), six.text_type)
  self.assertEqual(message2.str, test_utf8)
  # The pure Python API throws an exception on MergeFromString(),
  # if any of the string fields of the message can't be UTF-8 decoded.
  # The C++ implementation of the API has no way to check that on
  # MergeFromString and thus has no way to throw the exception.
  #
  # The pure Python API always returns objects of type 'unicode' (UTF-8
  # encoded), or 'bytes' (in 7 bit ASCII).
  # Corrupt the payload by replacing the encoded text with 0xff bytes, which
  # cannot be decoded as UTF-8.
  badbytes = raw.item[0].message.replace(
      test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
  unicode_decode_failed = False
  try:
    message2.MergeFromString(badbytes)
  except UnicodeDecodeError:
    unicode_decode_failed = True
  string_field = message2.str
  # Accept either outcome: the parse raised, or the field reads back as bytes.
  self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
  1615. def testBytesInTextFormat(self):
  1616. proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
  1617. self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
  1618. six.text_type(proto))
  1619. def testEmptyNestedMessage(self):
  1620. proto = unittest_pb2.TestAllTypes()
  1621. proto.optional_nested_message.MergeFrom(
  1622. unittest_pb2.TestAllTypes.NestedMessage())
  1623. self.assertTrue(proto.HasField('optional_nested_message'))
  1624. proto = unittest_pb2.TestAllTypes()
  1625. proto.optional_nested_message.CopyFrom(
  1626. unittest_pb2.TestAllTypes.NestedMessage())
  1627. self.assertTrue(proto.HasField('optional_nested_message'))
  1628. proto = unittest_pb2.TestAllTypes()
  1629. bytes_read = proto.optional_nested_message.MergeFromString(b'')
  1630. self.assertEqual(0, bytes_read)
  1631. self.assertTrue(proto.HasField('optional_nested_message'))
  1632. proto = unittest_pb2.TestAllTypes()
  1633. proto.optional_nested_message.ParseFromString(b'')
  1634. self.assertTrue(proto.HasField('optional_nested_message'))
  1635. serialized = proto.SerializeToString()
  1636. proto2 = unittest_pb2.TestAllTypes()
  1637. self.assertEqual(
  1638. len(serialized),
  1639. proto2.MergeFromString(serialized))
  1640. self.assertTrue(proto2.HasField('optional_nested_message'))
  1641. def testSetInParent(self):
  1642. proto = unittest_pb2.TestAllTypes()
  1643. self.assertFalse(proto.HasField('optionalgroup'))
  1644. proto.optionalgroup.SetInParent()
  1645. self.assertTrue(proto.HasField('optionalgroup'))
  1646. def testPackageInitializationImport(self):
  1647. """Test that we can import nested messages from their __init__.py.
  1648. Such setup is not trivial since at the time of processing of __init__.py one
  1649. can't refer to its submodules by name in code, so expressions like
  1650. google.protobuf.internal.import_test_package.inner_pb2
  1651. don't work. They do work in imports, so we have assign an alias at import
  1652. and then use that alias in generated code.
  1653. """
  1654. # We import here since it's the import that used to fail, and we want
  1655. # the failure to have the right context.
  1656. # pylint: disable=g-import-not-at-top
  1657. from google.protobuf.internal import import_test_package
  1658. # pylint: enable=g-import-not-at-top
  1659. msg = import_test_package.myproto.Outer()
  1660. # Just check the default value.
  1661. self.assertEqual(57, msg.inner.value)
  1662. # Since we had so many tests for protocol buffer equality, we broke these out
  1663. # into separate TestCase classes.
  1664. class TestAllTypesEqualityTest(BaseTestCase):
  1665. def setUp(self):
  1666. self.first_proto = unittest_pb2.TestAllTypes()
  1667. self.second_proto = unittest_pb2.TestAllTypes()
  1668. def testNotHashable(self):
  1669. self.assertRaises(TypeError, hash, self.first_proto)
  1670. def testSelfEquality(self):
  1671. self.assertEqual(self.first_proto, self.first_proto)
  1672. def testEmptyProtosEqual(self):
  1673. self.assertEqual(self.first_proto, self.second_proto)
class FullProtosEqualityTest(BaseTestCase):
  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    # Two independently populated messages that start out equal.
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    # Populated messages must still refuse to be hashed.
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    # Comparison against None is unequal from either side.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    # A message of a different type never compares equal, from either side.
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field, then undo the change.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message, then restore its value.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field: both appending and clearing it break
    # equality.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field, then undo the change.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field; adding a matching element
    # on the other side restores equality.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field: unset differs from explicitly set to 0.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field: an absent sub-message differs from a
    # present-but-empty one.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting and then clearing bb leaves the sub-message present but empty,
    # which matches the other side.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)
  1746. class ExtensionEqualityTest(BaseTestCase):
  1747. def testExtensionEquality(self):
  1748. first_proto = unittest_pb2.TestAllExtensions()
  1749. second_proto = unittest_pb2.TestAllExtensions()
  1750. self.assertEqual(first_proto, second_proto)
  1751. test_util.SetAllExtensions(first_proto)
  1752. self.assertNotEqual(first_proto, second_proto)
  1753. test_util.SetAllExtensions(second_proto)
  1754. self.assertEqual(first_proto, second_proto)
  1755. # Ensure that we check value equality.
  1756. first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
  1757. self.assertNotEqual(first_proto, second_proto)
  1758. first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
  1759. self.assertEqual(first_proto, second_proto)
  1760. # Ensure that we also look at "has" bits.
  1761. first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
  1762. second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
  1763. self.assertNotEqual(first_proto, second_proto)
  1764. first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
  1765. self.assertEqual(first_proto, second_proto)
  1766. # Ensure that differences in cached values
  1767. # don't matter if "has" bits are both false.
  1768. first_proto = unittest_pb2.TestAllExtensions()
  1769. second_proto = unittest_pb2.TestAllExtensions()
  1770. self.assertEqual(
  1771. 0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
  1772. self.assertEqual(first_proto, second_proto)
  1773. class MutualRecursionEqualityTest(BaseTestCase):
  1774. def testEqualityWithMutualRecursion(self):
  1775. first_proto = unittest_pb2.TestMutualRecursionA()
  1776. second_proto = unittest_pb2.TestMutualRecursionA()
  1777. self.assertEqual(first_proto, second_proto)
  1778. first_proto.bb.a.bb.optional_int32 = 23
  1779. self.assertNotEqual(first_proto, second_proto)
  1780. second_proto.bb.a.bb.optional_int32 = 23
  1781. self.assertEqual(first_proto, second_proto)
  1782. class ByteSizeTest(BaseTestCase):
  1783. def setUp(self):
  1784. self.proto = unittest_pb2.TestAllTypes()
  1785. self.extended_proto = more_extensions_pb2.ExtendedMessage()
  1786. self.packed_proto = unittest_pb2.TestPackedTypes()
  1787. self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
  1788. def Size(self):
  1789. return self.proto.ByteSize()
  1790. def testEmptyMessage(self):
  1791. self.assertEqual(0, self.proto.ByteSize())
  1792. def testSizedOnKwargs(self):
  1793. # Use a separate message to ensure testing right after creation.
  1794. proto = unittest_pb2.TestAllTypes()
  1795. self.assertEqual(0, proto.ByteSize())
  1796. proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
  1797. # One byte for the tag, one to encode varint 1.
  1798. self.assertEqual(2, proto_kwargs.ByteSize())
  1799. def testVarints(self):
  1800. def Test(i, expected_varint_size):
  1801. self.proto.Clear()
  1802. self.proto.optional_int64 = i
  1803. # Add one to the varint size for the tag info
  1804. # for tag 1.
  1805. self.assertEqual(expected_varint_size + 1, self.Size())
  1806. Test(0, 1)
  1807. Test(1, 1)
  1808. for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
  1809. Test((1 << i) - 1, num_bytes)
  1810. Test(-1, 10)
  1811. Test(-2, 10)
  1812. Test(-(1 << 63), 10)
  1813. def testStrings(self):
  1814. self.proto.optional_string = ''
  1815. # Need one byte for tag info (tag #14), and one byte for length.
  1816. self.assertEqual(2, self.Size())
  1817. self.proto.optional_string = 'abc'
  1818. # Need one byte for tag info (tag #14), and one byte for length.
  1819. self.assertEqual(2 + len(self.proto.optional_string), self.Size())
  1820. self.proto.optional_string = 'x' * 128
  1821. # Need one byte for tag info (tag #14), and TWO bytes for length.
  1822. self.assertEqual(3 + len(self.proto.optional_string), self.Size())
  1823. def testOtherNumerics(self):
  1824. self.proto.optional_fixed32 = 1234
  1825. # One byte for tag and 4 bytes for fixed32.
  1826. self.assertEqual(5, self.Size())
  1827. self.proto = unittest_pb2.TestAllTypes()
  1828. self.proto.optional_fixed64 = 1234
  1829. # One byte for tag and 8 bytes for fixed64.
  1830. self.assertEqual(9, self.Size())
  1831. self.proto = unittest_pb2.TestAllTypes()
  1832. self.proto.optional_float = 1.234
  1833. # One byte for tag and 4 bytes for float.
  1834. self.assertEqual(5, self.Size())
  1835. self.proto = unittest_pb2.TestAllTypes()
  1836. self.proto.optional_double = 1.234
  1837. # One byte for tag and 8 bytes for float.
  1838. self.assertEqual(9, self.Size())
  1839. self.proto = unittest_pb2.TestAllTypes()
  1840. self.proto.optional_sint32 = 64
  1841. # One byte for tag and 2 bytes for zig-zag-encoded 64.
  1842. self.assertEqual(3, self.Size())
  1843. self.proto = unittest_pb2.TestAllTypes()
  1844. def testComposites(self):
  1845. # 3 bytes.
  1846. self.proto.optional_nested_message.bb = (1 << 14)
  1847. # Plus one byte for bb tag.
  1848. # Plus 1 byte for optional_nested_message serialized size.
  1849. # Plus two bytes for optional_nested_message tag.
  1850. self.assertEqual(3 + 1 + 1 + 2, self.Size())
  1851. def testGroups(self):
  1852. # 4 bytes.
  1853. self.proto.optionalgroup.a = (1 << 21)
  1854. # Plus two bytes for |a| tag.
  1855. # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
  1856. self.assertEqual(4 + 2 + 2*2, self.Size())
  1857. def testRepeatedScalars(self):
  1858. self.proto.repeated_int32.append(10) # 1 byte.
  1859. self.proto.repeated_int32.append(128) # 2 bytes.
  1860. # Also need 2 bytes for each entry for tag.
  1861. self.assertEqual(1 + 2 + 2*2, self.Size())
  1862. def testRepeatedScalarsExtend(self):
  1863. self.proto.repeated_int32.extend([10, 128]) # 3 bytes.
  1864. # Also need 2 bytes for each entry for tag.
  1865. self.assertEqual(1 + 2 + 2*2, self.Size())
  1866. def testRepeatedScalarsRemove(self):
  1867. self.proto.repeated_int32.append(10) # 1 byte.
  1868. self.proto.repeated_int32.append(128) # 2 bytes.
  1869. # Also need 2 bytes for each entry for tag.
  1870. self.assertEqual(1 + 2 + 2*2, self.Size())
  1871. self.proto.repeated_int32.remove(128)
  1872. self.assertEqual(1 + 2, self.Size())
  1873. def testRepeatedComposites(self):
  1874. # Empty message. 2 bytes tag plus 1 byte length.
  1875. foreign_message_0 = self.proto.repeated_nested_message.add()
  1876. # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  1877. foreign_message_1 = self.proto.repeated_nested_message.add()
  1878. foreign_message_1.bb = 7
  1879. self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
def testRepeatedCompositesDelete(self):
  """Checks ByteSize() as elements are deleted from a repeated message field.

  Also deletes, by index and by slice, from a deep copy of the container.
  """
  # Empty message. 2 bytes tag plus 1 byte length.
  foreign_message_0 = self.proto.repeated_nested_message.add()
  # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  foreign_message_1 = self.proto.repeated_nested_message.add()
  foreign_message_1.bb = 9
  self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
  # Snapshot the two-element container before deleting from the original.
  repeated_nested_message = copy.deepcopy(
      self.proto.repeated_nested_message)
  # Deleting the empty element leaves only the bb=9 one:
  # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  del self.proto.repeated_nested_message[0]
  self.assertEqual(2 + 1 + 1 + 1, self.Size())
  # Now add a new message.
  foreign_message_2 = self.proto.repeated_nested_message.add()
  foreign_message_2.bb = 12
  # Two elements now, each 2 bytes tag plus 1 byte length plus 1 byte bb
  # tag 1 byte int.
  self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
  # Deleting the bb=12 element leaves the bb=9 one:
  # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
  del self.proto.repeated_nested_message[1]
  self.assertEqual(2 + 1 + 1 + 1, self.Size())
  del self.proto.repeated_nested_message[0]
  self.assertEqual(0, self.Size())
  # The deep copy was not affected by the deletions above.
  self.assertEqual(2, len(repeated_nested_message))
  del repeated_nested_message[0:1]
  # TODO(jieluo): Fix cpp extension bug when delete repeated message.
  if api_implementation.Type() == 'python':
    self.assertEqual(1, len(repeated_nested_message))
  del repeated_nested_message[-1]
  # TODO(jieluo): Fix cpp extension bug when delete repeated message.
  if api_implementation.Type() == 'python':
    self.assertEqual(0, len(repeated_nested_message))
  def testRepeatedGroups(self):
    # Groups are delimited by START_GROUP/END_GROUP tags instead of a length.
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
  def testExtensions(self):
    # ByteSize() must include extensions, and Extensions[] must reject
    # descriptors of ordinary (non-extension) fields.
    extended = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, extended.ByteSize())
    int32_ext = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    extended.Extensions[int32_ext] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, extended.ByteSize())
    plain_field = (
        unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
    with self.assertRaises(KeyError):
      extended.Extensions[plain_field] = 23
  def testCacheInvalidationForNonrepeatedScalar(self):
    # Setting or clearing a singular scalar must invalidate the cached size.
    # Each (value, expected_size) pair is one tag byte plus the value's varint.
    sizes_by_value = ((1, 2), (128, 3))

    # Regular field.
    for value, expected_size in sizes_by_value:
      self.proto.optional_int32 = value
      self.assertEqual(expected_size, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # The same sequence through an extension field.
    extension = more_extensions_pb2.optional_int_extension
    for value, expected_size in sizes_by_value:
      self.extended_proto.Extensions[extension] = value
      self.assertEqual(expected_size, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())
  def testCacheInvalidationForRepeatedScalar(self):
    # Appending, overwriting and clearing repeated scalars must refresh the
    # cached size.

    # Regular field: each element is a 2-byte tag plus its varint.
    proto_size = self.proto.ByteSize
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, proto_size())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, proto_size())
    self.proto.repeated_int32[1] = 128  # That varint grows from 1 to 2 bytes.
    self.assertEqual(7, proto_size())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, proto_size())

    # Extension field: each element is a 1-byte tag plus its varint.
    extended_size = self.extended_proto.ByteSize
    extension = more_extensions_pb2.repeated_int_extension
    values = self.extended_proto.Extensions[extension]
    values.append(1)
    self.assertEqual(2, extended_size())
    values.append(1)
    self.assertEqual(4, extended_size())
    values[1] = 128
    self.assertEqual(5, extended_size())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, extended_size())
  def testCacheInvalidationForNonrepeatedMessage(self):
    # Edits inside a singular sub-message must invalidate the parent's size.

    # Regular field.
    proto_size = self.proto.ByteSize
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, proto_size())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, proto_size())
    self.proto.optional_foreign_message.ClearField('c')
    # The now-empty sub-message is still present and still costs 3 bytes.
    self.assertEqual(3, proto_size())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, proto_size())
    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API:
      # a child detached by ClearField no longer affects its old parent.
      detached = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      detached.c = 128
      self.assertEqual(0, proto_size())

    # Extension field.
    extended_size = self.extended_proto.ByteSize
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    # Merely reading the extension does not contribute to the size.
    self.assertEqual(0, extended_size())
    child.foreign_message_int = 1
    self.assertEqual(4, extended_size())
    child.foreign_message_int = 128
    self.assertEqual(5, extended_size())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, extended_size())
  def testCacheInvalidationForRepeatedMessage(self):
    # Adding, editing and clearing repeated sub-messages must refresh the
    # cached size.

    # Regular field: each empty element costs 3 bytes.
    proto_size = self.proto.ByteSize
    first = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, proto_size())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, proto_size())
    first.c = 1  # Adds 2 bytes inside |first|.
    self.assertEqual(8, proto_size())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, proto_size())

    # Extension field: each empty element costs 2 bytes.
    extended_size = self.extended_proto.ByteSize
    extension = more_extensions_pb2.repeated_message_extension
    elements = self.extended_proto.Extensions[extension]
    first = elements.add()
    self.assertEqual(2, extended_size())
    elements.add()
    self.assertEqual(4, extended_size())
    first.foreign_message_int = 1
    self.assertEqual(6, extended_size())
    first.ClearField('foreign_message_int')
    self.assertEqual(4, extended_size())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, extended_size())
  def testPackedRepeatedScalars(self):
    # A packed field pays for one tag and one length varint in total, not per
    # element.
    packed = self.packed_proto
    self.assertEqual(0, packed.ByteSize())
    packed.packed_int32.append(10)  # 1 byte.
    packed.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    overhead = 2 + 1
    int_size = (1 + 2) + overhead
    self.assertEqual(int_size, packed.ByteSize())
    packed.packed_double.append(4.2)  # 8 bytes.
    packed.packed_double.append(3.25)  # 8 bytes.
    # 2 more tag bytes, 1 more length byte.
    double_size = (8 + 8) + 3
    self.assertEqual(int_size + double_size, packed.ByteSize())
    # Clearing one packed field leaves only the other's contribution.
    packed.ClearField('packed_int32')
    self.assertEqual(double_size, packed.ByteSize())
  def testPackedExtensions(self):
    # A packed fixed32 extension: fixed-width payload plus a single tag and
    # length prefix for the whole run.
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])  # 4 fixed32 values: 16 bytes.
    # Plus 3 bytes of overhead: the 1-byte length varint (16 < 128) and a
    # 2-byte tag.
    self.assertEqual(16 + 3, self.packed_extended_proto.ByteSize())
  2039. # Issues to be sure to cover include:
  2040. # * Handling of unrecognized tags ("uninterpreted_bytes").
  2041. # * Handling of MessageSets.
  2042. # * Consistent ordering of tags in the wire format,
  2043. # including ordering between extensions and non-extension
  2044. # fields.
  2045. # * Consistent serialization of negative numbers, especially
  2046. # negative int32s.
  2047. # * Handling of empty submessages (with and without "has"
  2048. # bits set).
  2049. class SerializationTest(BaseTestCase):
  2050. def testSerializeEmtpyMessage(self):
  2051. first_proto = unittest_pb2.TestAllTypes()
  2052. second_proto = unittest_pb2.TestAllTypes()
  2053. serialized = first_proto.SerializeToString()
  2054. self.assertEqual(first_proto.ByteSize(), len(serialized))
  2055. self.assertEqual(
  2056. len(serialized),
  2057. second_proto.MergeFromString(serialized))
  2058. self.assertEqual(first_proto, second_proto)
  2059. def testSerializeAllFields(self):
  2060. first_proto = unittest_pb2.TestAllTypes()
  2061. second_proto = unittest_pb2.TestAllTypes()
  2062. test_util.SetAllFields(first_proto)
  2063. serialized = first_proto.SerializeToString()
  2064. self.assertEqual(first_proto.ByteSize(), len(serialized))
  2065. self.assertEqual(
  2066. len(serialized),
  2067. second_proto.MergeFromString(serialized))
  2068. self.assertEqual(first_proto, second_proto)
  2069. def testSerializeAllExtensions(self):
  2070. first_proto = unittest_pb2.TestAllExtensions()
  2071. second_proto = unittest_pb2.TestAllExtensions()
  2072. test_util.SetAllExtensions(first_proto)
  2073. serialized = first_proto.SerializeToString()
  2074. self.assertEqual(
  2075. len(serialized),
  2076. second_proto.MergeFromString(serialized))
  2077. self.assertEqual(first_proto, second_proto)
  2078. def testSerializeWithOptionalGroup(self):
  2079. first_proto = unittest_pb2.TestAllTypes()
  2080. second_proto = unittest_pb2.TestAllTypes()
  2081. first_proto.optionalgroup.a = 242
  2082. serialized = first_proto.SerializeToString()
  2083. self.assertEqual(
  2084. len(serialized),
  2085. second_proto.MergeFromString(serialized))
  2086. self.assertEqual(first_proto, second_proto)
  2087. def testSerializeNegativeValues(self):
  2088. first_proto = unittest_pb2.TestAllTypes()
  2089. first_proto.optional_int32 = -1
  2090. first_proto.optional_int64 = -(2 << 40)
  2091. first_proto.optional_sint32 = -3
  2092. first_proto.optional_sint64 = -(4 << 40)
  2093. first_proto.optional_sfixed32 = -5
  2094. first_proto.optional_sfixed64 = -(6 << 40)
  2095. second_proto = unittest_pb2.TestAllTypes.FromString(
  2096. first_proto.SerializeToString())
  2097. self.assertEqual(first_proto, second_proto)
  2098. def testParseTruncated(self):
  2099. # This test is only applicable for the Python implementation of the API.
  2100. if api_implementation.Type() != 'python':
  2101. return
  2102. first_proto = unittest_pb2.TestAllTypes()
  2103. test_util.SetAllFields(first_proto)
  2104. serialized = memoryview(first_proto.SerializeToString())
  2105. for truncation_point in range(len(serialized) + 1):
  2106. try:
  2107. second_proto = unittest_pb2.TestAllTypes()
  2108. unknown_fields = unittest_pb2.TestEmptyMessage()
  2109. pos = second_proto._InternalParse(serialized, 0, truncation_point)
  2110. # If we didn't raise an error then we read exactly the amount expected.
  2111. self.assertEqual(truncation_point, pos)
  2112. # Parsing to unknown fields should not throw if parsing to known fields
  2113. # did not.
  2114. try:
  2115. pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
  2116. self.assertEqual(truncation_point, pos2)
  2117. except message.DecodeError:
  2118. self.fail('Parsing unknown fields failed when parsing known fields '
  2119. 'did not.')
  2120. except message.DecodeError:
  2121. # Parsing unknown fields should also fail.
  2122. self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
  2123. serialized, 0, truncation_point)
  2124. def testCanonicalSerializationOrder(self):
  2125. proto = more_messages_pb2.OutOfOrderFields()
  2126. # These are also their tag numbers. Even though we're setting these in
  2127. # reverse-tag order AND they're listed in reverse tag-order in the .proto
  2128. # file, they should nonetheless be serialized in tag order.
  2129. proto.optional_sint32 = 5
  2130. proto.Extensions[more_messages_pb2.optional_uint64] = 4
  2131. proto.optional_uint32 = 3
  2132. proto.Extensions[more_messages_pb2.optional_int64] = 2
  2133. proto.optional_int32 = 1
  2134. serialized = proto.SerializeToString()
  2135. self.assertEqual(proto.ByteSize(), len(serialized))
  2136. d = _MiniDecoder(serialized)
  2137. ReadTag = d.ReadFieldNumberAndWireType
  2138. self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
  2139. self.assertEqual(1, d.ReadInt32())
  2140. self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
  2141. self.assertEqual(2, d.ReadInt64())
  2142. self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
  2143. self.assertEqual(3, d.ReadUInt32())
  2144. self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
  2145. self.assertEqual(4, d.ReadUInt64())
  2146. self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
  2147. self.assertEqual(5, d.ReadSInt32())
  2148. def testCanonicalSerializationOrderSameAsCpp(self):
  2149. # Copy of the same test we use for C++.
  2150. proto = unittest_pb2.TestFieldOrderings()
  2151. test_util.SetAllFieldsAndExtensions(proto)
  2152. serialized = proto.SerializeToString()
  2153. test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
  2154. def testMergeFromStringWhenFieldsAlreadySet(self):
  2155. first_proto = unittest_pb2.TestAllTypes()
  2156. first_proto.repeated_string.append('foobar')
  2157. first_proto.optional_int32 = 23
  2158. first_proto.optional_nested_message.bb = 42
  2159. serialized = first_proto.SerializeToString()
  2160. second_proto = unittest_pb2.TestAllTypes()
  2161. second_proto.repeated_string.append('baz')
  2162. second_proto.optional_int32 = 100
  2163. second_proto.optional_nested_message.bb = 999
  2164. bytes_parsed = second_proto.MergeFromString(serialized)
  2165. self.assertEqual(len(serialized), bytes_parsed)
  2166. # Ensure that we append to repeated fields.
  2167. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
  2168. # Ensure that we overwrite nonrepeatd scalars.
  2169. self.assertEqual(23, second_proto.optional_int32)
  2170. # Ensure that we recursively call MergeFromString() on
  2171. # submessages.
  2172. self.assertEqual(42, second_proto.optional_nested_message.bb)
  2173. def testMessageSetWireFormat(self):
  2174. proto = message_set_extensions_pb2.TestMessageSet()
  2175. extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
  2176. extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
  2177. extension1 = extension_message1.message_set_extension
  2178. extension2 = extension_message2.message_set_extension
  2179. extension3 = message_set_extensions_pb2.message_set_extension3
  2180. proto.Extensions[extension1].i = 123
  2181. proto.Extensions[extension2].str = 'foo'
  2182. proto.Extensions[extension3].text = 'bar'
  2183. # Serialize using the MessageSet wire format (this is specified in the
  2184. # .proto file).
  2185. serialized = proto.SerializeToString()
  2186. raw = unittest_mset_pb2.RawMessageSet()
  2187. self.assertEqual(False,
  2188. raw.DESCRIPTOR.GetOptions().message_set_wire_format)
  2189. self.assertEqual(
  2190. len(serialized),
  2191. raw.MergeFromString(serialized))
  2192. self.assertEqual(3, len(raw.item))
  2193. message1 = message_set_extensions_pb2.TestMessageSetExtension1()
  2194. self.assertEqual(
  2195. len(raw.item[0].message),
  2196. message1.MergeFromString(raw.item[0].message))
  2197. self.assertEqual(123, message1.i)
  2198. message2 = message_set_extensions_pb2.TestMessageSetExtension2()
  2199. self.assertEqual(
  2200. len(raw.item[1].message),
  2201. message2.MergeFromString(raw.item[1].message))
  2202. self.assertEqual('foo', message2.str)
  2203. message3 = message_set_extensions_pb2.TestMessageSetExtension3()
  2204. self.assertEqual(
  2205. len(raw.item[2].message),
  2206. message3.MergeFromString(raw.item[2].message))
  2207. self.assertEqual('bar', message3.text)
  2208. # Deserialize using the MessageSet wire format.
  2209. proto2 = message_set_extensions_pb2.TestMessageSet()
  2210. self.assertEqual(
  2211. len(serialized),
  2212. proto2.MergeFromString(serialized))
  2213. self.assertEqual(123, proto2.Extensions[extension1].i)
  2214. self.assertEqual('foo', proto2.Extensions[extension2].str)
  2215. self.assertEqual('bar', proto2.Extensions[extension3].text)
  2216. # Check byte size.
  2217. self.assertEqual(proto2.ByteSize(), len(serialized))
  2218. self.assertEqual(proto.ByteSize(), len(serialized))
  2219. def testMessageSetWireFormatUnknownExtension(self):
  2220. # Create a message using the message set wire format with an unknown
  2221. # message.
  2222. raw = unittest_mset_pb2.RawMessageSet()
  2223. # Add an item.
  2224. item = raw.item.add()
  2225. item.type_id = 98418603
  2226. extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
  2227. message1 = message_set_extensions_pb2.TestMessageSetExtension1()
  2228. message1.i = 12345
  2229. item.message = message1.SerializeToString()
  2230. # Add a second, unknown extension.
  2231. item = raw.item.add()
  2232. item.type_id = 98418604
  2233. extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
  2234. message1 = message_set_extensions_pb2.TestMessageSetExtension1()
  2235. message1.i = 12346
  2236. item.message = message1.SerializeToString()
  2237. # Add another unknown extension.
  2238. item = raw.item.add()
  2239. item.type_id = 98418605
  2240. message1 = message_set_extensions_pb2.TestMessageSetExtension2()
  2241. message1.str = 'foo'
  2242. item.message = message1.SerializeToString()
  2243. serialized = raw.SerializeToString()
  2244. # Parse message using the message set wire format.
  2245. proto = message_set_extensions_pb2.TestMessageSet()
  2246. self.assertEqual(
  2247. len(serialized),
  2248. proto.MergeFromString(serialized))
  2249. # Check that the message parsed well.
  2250. extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
  2251. extension1 = extension_message1.message_set_extension
  2252. self.assertEqual(12345, proto.Extensions[extension1].i)
  2253. def testUnknownFields(self):
  2254. proto = unittest_pb2.TestAllTypes()
  2255. test_util.SetAllFields(proto)
  2256. serialized = proto.SerializeToString()
  2257. # The empty message should be parsable with all of the fields
  2258. # unknown.
  2259. proto2 = unittest_pb2.TestEmptyMessage()
  2260. # Parsing this message should succeed.
  2261. self.assertEqual(
  2262. len(serialized),
  2263. proto2.MergeFromString(serialized))
  2264. # Now test with a int64 field set.
  2265. proto = unittest_pb2.TestAllTypes()
  2266. proto.optional_int64 = 0x0fffffffffffffff
  2267. serialized = proto.SerializeToString()
  2268. # The empty message should be parsable with all of the fields
  2269. # unknown.
  2270. proto2 = unittest_pb2.TestEmptyMessage()
  2271. # Parsing this message should succeed.
  2272. self.assertEqual(
  2273. len(serialized),
  2274. proto2.MergeFromString(serialized))
  2275. def _CheckRaises(self, exc_class, callable_obj, exception):
  2276. """This method checks if the excpetion type and message are as expected."""
  2277. try:
  2278. callable_obj()
  2279. except exc_class as ex:
  2280. # Check if the exception message is the right one.
  2281. self.assertEqual(exception, str(ex))
  2282. return
  2283. else:
  2284. raise self.failureException('%s not raised' % str(exc_class))
  2285. def testSerializeUninitialized(self):
  2286. proto = unittest_pb2.TestRequired()
  2287. self._CheckRaises(
  2288. message.EncodeError,
  2289. proto.SerializeToString,
  2290. 'Message protobuf_unittest.TestRequired is missing required fields: '
  2291. 'a,b,c')
  2292. # Shouldn't raise exceptions.
  2293. partial = proto.SerializePartialToString()
  2294. proto2 = unittest_pb2.TestRequired()
  2295. self.assertFalse(proto2.HasField('a'))
  2296. # proto2 ParseFromString does not check that required fields are set.
  2297. proto2.ParseFromString(partial)
  2298. self.assertFalse(proto2.HasField('a'))
  2299. proto.a = 1
  2300. self._CheckRaises(
  2301. message.EncodeError,
  2302. proto.SerializeToString,
  2303. 'Message protobuf_unittest.TestRequired is missing required fields: b,c')
  2304. # Shouldn't raise exceptions.
  2305. partial = proto.SerializePartialToString()
  2306. proto.b = 2
  2307. self._CheckRaises(
  2308. message.EncodeError,
  2309. proto.SerializeToString,
  2310. 'Message protobuf_unittest.TestRequired is missing required fields: c')
  2311. # Shouldn't raise exceptions.
  2312. partial = proto.SerializePartialToString()
  2313. proto.c = 3
  2314. serialized = proto.SerializeToString()
  2315. # Shouldn't raise exceptions.
  2316. partial = proto.SerializePartialToString()
  2317. proto2 = unittest_pb2.TestRequired()
  2318. self.assertEqual(
  2319. len(serialized),
  2320. proto2.MergeFromString(serialized))
  2321. self.assertEqual(1, proto2.a)
  2322. self.assertEqual(2, proto2.b)
  2323. self.assertEqual(3, proto2.c)
  2324. self.assertEqual(
  2325. len(partial),
  2326. proto2.MergeFromString(partial))
  2327. self.assertEqual(1, proto2.a)
  2328. self.assertEqual(2, proto2.b)
  2329. self.assertEqual(3, proto2.c)
  2330. def testSerializeUninitializedSubMessage(self):
  2331. proto = unittest_pb2.TestRequiredForeign()
  2332. # Sub-message doesn't exist yet, so this succeeds.
  2333. proto.SerializeToString()
  2334. proto.optional_message.a = 1
  2335. self._CheckRaises(
  2336. message.EncodeError,
  2337. proto.SerializeToString,
  2338. 'Message protobuf_unittest.TestRequiredForeign '
  2339. 'is missing required fields: '
  2340. 'optional_message.b,optional_message.c')
  2341. proto.optional_message.b = 2
  2342. proto.optional_message.c = 3
  2343. proto.SerializeToString()
  2344. proto.repeated_message.add().a = 1
  2345. proto.repeated_message.add().b = 2
  2346. self._CheckRaises(
  2347. message.EncodeError,
  2348. proto.SerializeToString,
  2349. 'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
  2350. 'repeated_message[0].b,repeated_message[0].c,'
  2351. 'repeated_message[1].a,repeated_message[1].c')
  2352. proto.repeated_message[0].b = 2
  2353. proto.repeated_message[0].c = 3
  2354. proto.repeated_message[1].a = 1
  2355. proto.repeated_message[1].c = 3
  2356. proto.SerializeToString()
  2357. def testSerializeAllPackedFields(self):
  2358. first_proto = unittest_pb2.TestPackedTypes()
  2359. second_proto = unittest_pb2.TestPackedTypes()
  2360. test_util.SetAllPackedFields(first_proto)
  2361. serialized = first_proto.SerializeToString()
  2362. self.assertEqual(first_proto.ByteSize(), len(serialized))
  2363. bytes_read = second_proto.MergeFromString(serialized)
  2364. self.assertEqual(second_proto.ByteSize(), bytes_read)
  2365. self.assertEqual(first_proto, second_proto)
  2366. def testSerializeAllPackedExtensions(self):
  2367. first_proto = unittest_pb2.TestPackedExtensions()
  2368. second_proto = unittest_pb2.TestPackedExtensions()
  2369. test_util.SetAllPackedExtensions(first_proto)
  2370. serialized = first_proto.SerializeToString()
  2371. bytes_read = second_proto.MergeFromString(serialized)
  2372. self.assertEqual(second_proto.ByteSize(), bytes_read)
  2373. self.assertEqual(first_proto, second_proto)
  2374. def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
  2375. first_proto = unittest_pb2.TestPackedTypes()
  2376. first_proto.packed_int32.extend([1, 2])
  2377. first_proto.packed_double.append(3.0)
  2378. serialized = first_proto.SerializeToString()
  2379. second_proto = unittest_pb2.TestPackedTypes()
  2380. second_proto.packed_int32.append(3)
  2381. second_proto.packed_double.extend([1.0, 2.0])
  2382. second_proto.packed_sint32.append(4)
  2383. self.assertEqual(
  2384. len(serialized),
  2385. second_proto.MergeFromString(serialized))
  2386. self.assertEqual([3, 1, 2], second_proto.packed_int32)
  2387. self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
  2388. self.assertEqual([4], second_proto.packed_sint32)
  2389. def testPackedFieldsWireFormat(self):
  2390. proto = unittest_pb2.TestPackedTypes()
  2391. proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes
  2392. proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes
  2393. proto.packed_float.append(2.0) # 4 bytes, will be before double
  2394. serialized = proto.SerializeToString()
  2395. self.assertEqual(proto.ByteSize(), len(serialized))
  2396. d = _MiniDecoder(serialized)
  2397. ReadTag = d.ReadFieldNumberAndWireType
  2398. self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  2399. self.assertEqual(1+1+1+2, d.ReadInt32())
  2400. self.assertEqual(1, d.ReadInt32())
  2401. self.assertEqual(2, d.ReadInt32())
  2402. self.assertEqual(150, d.ReadInt32())
  2403. self.assertEqual(3, d.ReadInt32())
  2404. self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  2405. self.assertEqual(4, d.ReadInt32())
  2406. self.assertEqual(2.0, d.ReadFloat())
  2407. self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
  2408. self.assertEqual(8+8, d.ReadInt32())
  2409. self.assertEqual(1.0, d.ReadDouble())
  2410. self.assertEqual(1000.0, d.ReadDouble())
  2411. self.assertTrue(d.EndOfStream())
  2412. def testParsePackedFromUnpacked(self):
  2413. unpacked = unittest_pb2.TestUnpackedTypes()
  2414. test_util.SetAllUnpackedFields(unpacked)
  2415. packed = unittest_pb2.TestPackedTypes()
  2416. serialized = unpacked.SerializeToString()
  2417. self.assertEqual(
  2418. len(serialized),
  2419. packed.MergeFromString(serialized))
  2420. expected = unittest_pb2.TestPackedTypes()
  2421. test_util.SetAllPackedFields(expected)
  2422. self.assertEqual(expected, packed)
  2423. def testParseUnpackedFromPacked(self):
  2424. packed = unittest_pb2.TestPackedTypes()
  2425. test_util.SetAllPackedFields(packed)
  2426. unpacked = unittest_pb2.TestUnpackedTypes()
  2427. serialized = packed.SerializeToString()
  2428. self.assertEqual(
  2429. len(serialized),
  2430. unpacked.MergeFromString(serialized))
  2431. expected = unittest_pb2.TestUnpackedTypes()
  2432. test_util.SetAllUnpackedFields(expected)
  2433. self.assertEqual(expected, unpacked)
  2434. def testFieldNumbers(self):
  2435. proto = unittest_pb2.TestAllTypes()
  2436. self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
  2437. self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
  2438. self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
  2439. self.assertEqual(
  2440. unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
  2441. self.assertEqual(
  2442. unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
  2443. self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
  2444. self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
  2445. self.assertEqual(
  2446. unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
  2447. self.assertEqual(
  2448. unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
  2449. def testExtensionFieldNumbers(self):
  2450. self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
  2451. self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
  2452. self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
  2453. self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
  2454. self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
  2455. self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
  2456. self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
  2457. self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
  2458. self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
  2459. self.assertEqual(
  2460. unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
  2461. self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
  2462. self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
  2463. 21)
  2464. self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
  2465. self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
  2466. self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
  2467. self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
  2468. self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
  2469. self.assertEqual(
  2470. unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
  2471. self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
  2472. self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
  2473. 51)
  2474. def testFieldProperties(self):
  2475. cls = unittest_pb2.TestAllTypes
  2476. self.assertIs(cls.optional_int32.DESCRIPTOR,
  2477. cls.DESCRIPTOR.fields_by_name['optional_int32'])
  2478. self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
  2479. cls.optional_int32.DESCRIPTOR.number)
  2480. self.assertIs(cls.optional_nested_message.DESCRIPTOR,
  2481. cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
  2482. self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
  2483. cls.optional_nested_message.DESCRIPTOR.number)
  2484. self.assertIs(cls.repeated_int32.DESCRIPTOR,
  2485. cls.DESCRIPTOR.fields_by_name['repeated_int32'])
  2486. self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
  2487. cls.repeated_int32.DESCRIPTOR.number)
  2488. def testFieldDataDescriptor(self):
  2489. msg = unittest_pb2.TestAllTypes()
  2490. msg.optional_int32 = 42
  2491. self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
  2492. unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
  2493. self.assertEqual(msg.optional_int32, 25)
  2494. with self.assertRaises(AttributeError):
  2495. del msg.optional_int32
  2496. try:
  2497. unittest_pb2.ForeignMessage.c.__get__(msg)
  2498. except TypeError:
  2499. pass # The cpp implementation cannot mix fields from other messages.
  2500. # This test exercises a specific check that avoids a crash.
  2501. else:
  2502. pass # The python implementation allows fields from other messages.
  2503. # This is useless, but works.
  2504. def testInitKwargs(self):
  2505. proto = unittest_pb2.TestAllTypes(
  2506. optional_int32=1,
  2507. optional_string='foo',
  2508. optional_bool=True,
  2509. optional_bytes=b'bar',
  2510. optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
  2511. optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
  2512. optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
  2513. optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
  2514. repeated_int32=[1, 2, 3])
  2515. self.assertTrue(proto.IsInitialized())
  2516. self.assertTrue(proto.HasField('optional_int32'))
  2517. self.assertTrue(proto.HasField('optional_string'))
  2518. self.assertTrue(proto.HasField('optional_bool'))
  2519. self.assertTrue(proto.HasField('optional_bytes'))
  2520. self.assertTrue(proto.HasField('optional_nested_message'))
  2521. self.assertTrue(proto.HasField('optional_foreign_message'))
  2522. self.assertTrue(proto.HasField('optional_nested_enum'))
  2523. self.assertTrue(proto.HasField('optional_foreign_enum'))
  2524. self.assertEqual(1, proto.optional_int32)
  2525. self.assertEqual('foo', proto.optional_string)
  2526. self.assertEqual(True, proto.optional_bool)
  2527. self.assertEqual(b'bar', proto.optional_bytes)
  2528. self.assertEqual(1, proto.optional_nested_message.bb)
  2529. self.assertEqual(1, proto.optional_foreign_message.c)
  2530. self.assertEqual(unittest_pb2.TestAllTypes.FOO,
  2531. proto.optional_nested_enum)
  2532. self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
  2533. self.assertEqual([1, 2, 3], proto.repeated_int32)
  2534. def testInitArgsUnknownFieldName(self):
  2535. def InitalizeEmptyMessageWithExtraKeywordArg():
  2536. unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
  2537. self._CheckRaises(
  2538. ValueError,
  2539. InitalizeEmptyMessageWithExtraKeywordArg,
  2540. 'Protocol message TestEmptyMessage has no "unknown" field.')
  def testInitRequiredKwargs(self):
    """Passing every required field as a kwarg yields an initialized message."""
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    # 'dummy2' was not passed to the constructor, so it must not be present.
    # assertFalse reports the failing value, unlike assertTrue(not ...).
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)
  def testInitRequiredForeignKwargs(self):
    """A fully-initialized sub-message passed as a kwarg is copied in intact."""
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    # 'dummy2' was never set on the sub-message, so it must not be present.
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)
  def testInitRepeatedKwargs(self):
    """A list passed for a repeated field is stored element by element."""
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    for position, wanted in enumerate([1, 2, 3]):
      self.assertEqual(wanted, proto.repeated_int32[position])
  2572. class OptionsTest(BaseTestCase):
  2573. def testMessageOptions(self):
  2574. proto = message_set_extensions_pb2.TestMessageSet()
  2575. self.assertEqual(True,
  2576. proto.DESCRIPTOR.GetOptions().message_set_wire_format)
  2577. proto = unittest_pb2.TestAllTypes()
  2578. self.assertEqual(False,
  2579. proto.DESCRIPTOR.GetOptions().message_set_wire_format)
  2580. def testPackedOptions(self):
  2581. proto = unittest_pb2.TestAllTypes()
  2582. proto.optional_int32 = 1
  2583. proto.optional_double = 3.0
  2584. for field_descriptor, _ in proto.ListFields():
  2585. self.assertEqual(False, field_descriptor.GetOptions().packed)
  2586. proto = unittest_pb2.TestPackedTypes()
  2587. proto.packed_int32.append(1)
  2588. proto.packed_double.append(3.0)
  2589. for field_descriptor, _ in proto.ListFields():
  2590. self.assertEqual(True, field_descriptor.GetOptions().packed)
  2591. self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
  2592. field_descriptor.label)
  2593. class ClassAPITest(BaseTestCase):
  2594. @unittest.skipIf(
  2595. api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
  2596. 'C++ implementation requires a call to MakeDescriptor()')
  2597. @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  2598. def testMakeClassWithNestedDescriptor(self):
  2599. leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
  2600. containing_type=None, fields=[],
  2601. nested_types=[], enum_types=[],
  2602. extensions=[])
  2603. child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
  2604. containing_type=None, fields=[],
  2605. nested_types=[leaf_desc], enum_types=[],
  2606. extensions=[])
  2607. sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
  2608. '', containing_type=None, fields=[],
  2609. nested_types=[], enum_types=[],
  2610. extensions=[])
  2611. parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
  2612. containing_type=None, fields=[],
  2613. nested_types=[child_desc, sibling_desc],
  2614. enum_types=[], extensions=[])
  2615. reflection.MakeClass(parent_desc)
  2616. def _GetSerializedFileDescriptor(self, name):
  2617. """Get a serialized representation of a test FileDescriptorProto.
  2618. Args:
  2619. name: All calls to this must use a unique message name, to avoid
  2620. collisions in the cpp descriptor pool.
  2621. Returns:
  2622. A string containing the serialized form of a test FileDescriptorProto.
  2623. """
  2624. file_descriptor_str = (
  2625. 'message_type {'
  2626. ' name: "' + name + '"'
  2627. ' field {'
  2628. ' name: "flat"'
  2629. ' number: 1'
  2630. ' label: LABEL_REPEATED'
  2631. ' type: TYPE_UINT32'
  2632. ' }'
  2633. ' field {'
  2634. ' name: "bar"'
  2635. ' number: 2'
  2636. ' label: LABEL_OPTIONAL'
  2637. ' type: TYPE_MESSAGE'
  2638. ' type_name: "Bar"'
  2639. ' }'
  2640. ' nested_type {'
  2641. ' name: "Bar"'
  2642. ' field {'
  2643. ' name: "baz"'
  2644. ' number: 3'
  2645. ' label: LABEL_OPTIONAL'
  2646. ' type: TYPE_MESSAGE'
  2647. ' type_name: "Baz"'
  2648. ' }'
  2649. ' nested_type {'
  2650. ' name: "Baz"'
  2651. ' enum_type {'
  2652. ' name: "deep_enum"'
  2653. ' value {'
  2654. ' name: "VALUE_A"'
  2655. ' number: 0'
  2656. ' }'
  2657. ' }'
  2658. ' field {'
  2659. ' name: "deep"'
  2660. ' number: 4'
  2661. ' label: LABEL_OPTIONAL'
  2662. ' type: TYPE_UINT32'
  2663. ' }'
  2664. ' }'
  2665. ' }'
  2666. '}')
  2667. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2668. text_format.Merge(file_descriptor_str, file_descriptor)
  2669. return file_descriptor.SerializeToString()
  2670. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2671. # This test can only run once; the second time, it raises errors about
  2672. # conflicting message descriptors.
  2673. def testParsingFlatClassWithExplicitClassDeclaration(self):
  2674. """Test that the generated class can parse a flat message."""
  2675. # TODO(xiaofeng): This test fails with cpp implemetnation in the call
  2676. # of six.with_metaclass(). The other two callsites of with_metaclass
  2677. # in this file are both excluded from cpp test, so it might be expected
  2678. # to fail. Need someone more familiar with the python code to take a
  2679. # look at this.
  2680. if api_implementation.Type() != 'python':
  2681. return
  2682. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2683. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
  2684. msg_descriptor = descriptor.MakeDescriptor(
  2685. file_descriptor.message_type[0])
  2686. class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
  2687. DESCRIPTOR = msg_descriptor
  2688. msg = MessageClass()
  2689. msg_str = (
  2690. 'flat: 0 '
  2691. 'flat: 1 '
  2692. 'flat: 2 ')
  2693. text_format.Merge(msg_str, msg)
  2694. self.assertEqual(msg.flat, [0, 1, 2])
  2695. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2696. def testParsingFlatClass(self):
  2697. """Test that the generated class can parse a flat message."""
  2698. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2699. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
  2700. msg_descriptor = descriptor.MakeDescriptor(
  2701. file_descriptor.message_type[0])
  2702. msg_class = reflection.MakeClass(msg_descriptor)
  2703. msg = msg_class()
  2704. msg_str = (
  2705. 'flat: 0 '
  2706. 'flat: 1 '
  2707. 'flat: 2 ')
  2708. text_format.Merge(msg_str, msg)
  2709. self.assertEqual(msg.flat, [0, 1, 2])
  2710. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2711. def testParsingNestedClass(self):
  2712. """Test that the generated class can parse a nested message."""
  2713. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2714. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
  2715. msg_descriptor = descriptor.MakeDescriptor(
  2716. file_descriptor.message_type[0])
  2717. msg_class = reflection.MakeClass(msg_descriptor)
  2718. msg = msg_class()
  2719. msg_str = (
  2720. 'bar {'
  2721. ' baz {'
  2722. ' deep: 4'
  2723. ' }'
  2724. '}')
  2725. text_format.Merge(msg_str, msg)
  2726. self.assertEqual(msg.bar.baz.deep, 4)
  2727. if __name__ == '__main__':
  2728. unittest.main()