reflection_test.py 134 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Protocol Buffers - Google's data interchange format
  5. # Copyright 2008 Google Inc. All rights reserved.
  6. # https://developers.google.com/protocol-buffers/
  7. #
  8. # Redistribution and use in source and binary forms, with or without
  9. # modification, are permitted provided that the following conditions are
  10. # met:
  11. #
  12. # * Redistributions of source code must retain the above copyright
  13. # notice, this list of conditions and the following disclaimer.
  14. # * Redistributions in binary form must reproduce the above
  15. # copyright notice, this list of conditions and the following disclaimer
  16. # in the documentation and/or other materials provided with the
  17. # distribution.
  18. # * Neither the name of Google Inc. nor the names of its
  19. # contributors may be used to endorse or promote products derived from
  20. # this software without specific prior written permission.
  21. #
  22. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  23. # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  24. # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  25. # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  26. # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  27. # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  28. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  29. # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  30. # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  31. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  32. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  33. """Unittest for reflection.py, which also indirectly tests the output of the
  34. pure-Python protocol compiler.
  35. """
  36. import copy
  37. import gc
  38. import operator
  39. import six
  40. import struct
  41. import warnings
  42. try:
  43. import unittest2 as unittest #PY26
  44. except ImportError:
  45. import unittest
  46. from google.protobuf import unittest_import_pb2
  47. from google.protobuf import unittest_mset_pb2
  48. from google.protobuf import unittest_pb2
  49. from google.protobuf import unittest_proto3_arena_pb2
  50. from google.protobuf import descriptor_pb2
  51. from google.protobuf import descriptor
  52. from google.protobuf import message
  53. from google.protobuf import reflection
  54. from google.protobuf import text_format
  55. from google.protobuf.internal import api_implementation
  56. from google.protobuf.internal import more_extensions_pb2
  57. from google.protobuf.internal import more_messages_pb2
  58. from google.protobuf.internal import message_set_extensions_pb2
  59. from google.protobuf.internal import wire_format
  60. from google.protobuf.internal import test_util
  61. from google.protobuf.internal import testing_refleaks
  62. from google.protobuf.internal import decoder
  63. from google.protobuf.internal import _parameterized
  64. if six.PY3:
  65. long = int # pylint: disable=redefined-builtin,invalid-name
  66. warnings.simplefilter('error', DeprecationWarning)
  67. class _MiniDecoder(object):
  68. """Decodes a stream of values from a string.
  69. Once upon a time we actually had a class called decoder.Decoder. Then we
  70. got rid of it during a redesign that made decoding much, much faster overall.
  71. But a couple tests in this file used it to check that the serialized form of
  72. a message was correct. So, this class implements just the methods that were
  73. used by said tests, so that we don't have to rewrite the tests.
  74. """
  75. def __init__(self, bytes):
  76. self._bytes = bytes
  77. self._pos = 0
  78. def ReadVarint(self):
  79. result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
  80. return result
  81. ReadInt32 = ReadVarint
  82. ReadInt64 = ReadVarint
  83. ReadUInt32 = ReadVarint
  84. ReadUInt64 = ReadVarint
  85. def ReadSInt64(self):
  86. return wire_format.ZigZagDecode(self.ReadVarint())
  87. ReadSInt32 = ReadSInt64
  88. def ReadFieldNumberAndWireType(self):
  89. return wire_format.UnpackTag(self.ReadVarint())
  90. def ReadFloat(self):
  91. result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
  92. self._pos += 4
  93. return result
  94. def ReadDouble(self):
  95. result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
  96. self._pos += 8
  97. return result
  98. def EndOfStream(self):
  99. return self._pos == len(self._bytes)
  100. @_parameterized.named_parameters(
  101. ('_proto2', unittest_pb2),
  102. ('_proto3', unittest_proto3_arena_pb2))
  103. @testing_refleaks.TestCase
  104. class ReflectionTest(unittest.TestCase):
  105. def assertListsEqual(self, values, others):
  106. self.assertEqual(len(values), len(others))
  107. for i in range(len(values)):
  108. self.assertEqual(values[i], others[i])
  109. def testScalarConstructor(self, message_module):
  110. # Constructor with only scalar types should succeed.
  111. proto = message_module.TestAllTypes(
  112. optional_int32=24,
  113. optional_double=54.321,
  114. optional_string='optional_string',
  115. optional_float=None)
  116. self.assertEqual(24, proto.optional_int32)
  117. self.assertEqual(54.321, proto.optional_double)
  118. self.assertEqual('optional_string', proto.optional_string)
  119. if message_module is unittest_pb2:
  120. self.assertFalse(proto.HasField("optional_float"))
  121. def testRepeatedScalarConstructor(self, message_module):
  122. # Constructor with only repeated scalar types should succeed.
  123. proto = message_module.TestAllTypes(
  124. repeated_int32=[1, 2, 3, 4],
  125. repeated_double=[1.23, 54.321],
  126. repeated_bool=[True, False, False],
  127. repeated_string=["optional_string"],
  128. repeated_float=None)
  129. self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
  130. self.assertEqual([1.23, 54.321], list(proto.repeated_double))
  131. self.assertEqual([True, False, False], list(proto.repeated_bool))
  132. self.assertEqual(["optional_string"], list(proto.repeated_string))
  133. self.assertEqual([], list(proto.repeated_float))
  134. def testMixedConstructor(self, message_module):
  135. # Constructor with only mixed types should succeed.
  136. proto = message_module.TestAllTypes(
  137. optional_int32=24,
  138. optional_string='optional_string',
  139. repeated_double=[1.23, 54.321],
  140. repeated_bool=[True, False, False],
  141. repeated_nested_message=[
  142. message_module.TestAllTypes.NestedMessage(
  143. bb=message_module.TestAllTypes.FOO),
  144. message_module.TestAllTypes.NestedMessage(
  145. bb=message_module.TestAllTypes.BAR)],
  146. repeated_foreign_message=[
  147. message_module.ForeignMessage(c=-43),
  148. message_module.ForeignMessage(c=45324),
  149. message_module.ForeignMessage(c=12)],
  150. optional_nested_message=None)
  151. self.assertEqual(24, proto.optional_int32)
  152. self.assertEqual('optional_string', proto.optional_string)
  153. self.assertEqual([1.23, 54.321], list(proto.repeated_double))
  154. self.assertEqual([True, False, False], list(proto.repeated_bool))
  155. self.assertEqual(
  156. [message_module.TestAllTypes.NestedMessage(
  157. bb=message_module.TestAllTypes.FOO),
  158. message_module.TestAllTypes.NestedMessage(
  159. bb=message_module.TestAllTypes.BAR)],
  160. list(proto.repeated_nested_message))
  161. self.assertEqual(
  162. [message_module.ForeignMessage(c=-43),
  163. message_module.ForeignMessage(c=45324),
  164. message_module.ForeignMessage(c=12)],
  165. list(proto.repeated_foreign_message))
  166. self.assertFalse(proto.HasField("optional_nested_message"))
  167. def testConstructorTypeError(self, message_module):
  168. self.assertRaises(
  169. TypeError, message_module.TestAllTypes, optional_int32='foo')
  170. self.assertRaises(
  171. TypeError, message_module.TestAllTypes, optional_string=1234)
  172. self.assertRaises(
  173. TypeError, message_module.TestAllTypes, optional_nested_message=1234)
  174. self.assertRaises(
  175. TypeError, message_module.TestAllTypes, repeated_int32=1234)
  176. self.assertRaises(
  177. TypeError, message_module.TestAllTypes, repeated_int32=['foo'])
  178. self.assertRaises(
  179. TypeError, message_module.TestAllTypes, repeated_string=1234)
  180. self.assertRaises(
  181. TypeError, message_module.TestAllTypes, repeated_string=[1234])
  182. self.assertRaises(
  183. TypeError, message_module.TestAllTypes, repeated_nested_message=1234)
  184. self.assertRaises(
  185. TypeError, message_module.TestAllTypes, repeated_nested_message=[1234])
  186. def testConstructorInvalidatesCachedByteSize(self, message_module):
  187. message = message_module.TestAllTypes(optional_int32=12)
  188. self.assertEqual(2, message.ByteSize())
  189. message = message_module.TestAllTypes(
  190. optional_nested_message=message_module.TestAllTypes.NestedMessage())
  191. self.assertEqual(3, message.ByteSize())
  192. message = message_module.TestAllTypes(repeated_int32=[12])
  193. # TODO(jieluo): Add this test back for proto3
  194. if message_module is unittest_pb2:
  195. self.assertEqual(3, message.ByteSize())
  196. message = message_module.TestAllTypes(
  197. repeated_nested_message=[message_module.TestAllTypes.NestedMessage()])
  198. self.assertEqual(3, message.ByteSize())
  199. def testReferencesToNestedMessage(self, message_module):
  200. proto = message_module.TestAllTypes()
  201. nested = proto.optional_nested_message
  202. del proto
  203. # A previous version had a bug where this would raise an exception when
  204. # hitting a now-dead weak reference.
  205. nested.bb = 23
  206. def testOneOf(self, message_module):
  207. proto = message_module.TestAllTypes()
  208. proto.oneof_uint32 = 10
  209. proto.oneof_nested_message.bb = 11
  210. self.assertEqual(11, proto.oneof_nested_message.bb)
  211. self.assertFalse(proto.HasField('oneof_uint32'))
  212. nested = proto.oneof_nested_message
  213. proto.oneof_string = 'abc'
  214. self.assertEqual('abc', proto.oneof_string)
  215. self.assertEqual(11, nested.bb)
  216. self.assertFalse(proto.HasField('oneof_nested_message'))
  217. def testGetDefaultMessageAfterDisconnectingDefaultMessage(
  218. self, message_module):
  219. proto = message_module.TestAllTypes()
  220. nested = proto.optional_nested_message
  221. proto.ClearField('optional_nested_message')
  222. del proto
  223. del nested
  224. # Force a garbage collect so that the underlying CMessages are freed along
  225. # with the Messages they point to. This is to make sure we're not deleting
  226. # default message instances.
  227. gc.collect()
  228. proto = message_module.TestAllTypes()
  229. nested = proto.optional_nested_message
  230. def testDisconnectingNestedMessageAfterSettingField(self, message_module):
  231. proto = message_module.TestAllTypes()
  232. nested = proto.optional_nested_message
  233. nested.bb = 5
  234. self.assertTrue(proto.HasField('optional_nested_message'))
  235. proto.ClearField('optional_nested_message') # Should disconnect from parent
  236. self.assertEqual(5, nested.bb)
  237. self.assertEqual(0, proto.optional_nested_message.bb)
  238. self.assertIsNot(nested, proto.optional_nested_message)
  239. nested.bb = 23
  240. self.assertFalse(proto.HasField('optional_nested_message'))
  241. self.assertEqual(0, proto.optional_nested_message.bb)
  242. def testDisconnectingNestedMessageBeforeGettingField(self, message_module):
  243. proto = message_module.TestAllTypes()
  244. self.assertFalse(proto.HasField('optional_nested_message'))
  245. proto.ClearField('optional_nested_message')
  246. self.assertFalse(proto.HasField('optional_nested_message'))
  247. def testDisconnectingNestedMessageAfterMerge(self, message_module):
  248. # This test exercises the code path that does not use ReleaseMessage().
  249. # The underlying fear is that if we use ReleaseMessage() incorrectly,
  250. # we will have memory leaks. It's hard to check that that doesn't happen,
  251. # but at least we can exercise that code path to make sure it works.
  252. proto1 = message_module.TestAllTypes()
  253. proto2 = message_module.TestAllTypes()
  254. proto2.optional_nested_message.bb = 5
  255. proto1.MergeFrom(proto2)
  256. self.assertTrue(proto1.HasField('optional_nested_message'))
  257. proto1.ClearField('optional_nested_message')
  258. self.assertFalse(proto1.HasField('optional_nested_message'))
  259. def testDisconnectingLazyNestedMessage(self, message_module):
  260. # This test exercises releasing a nested message that is lazy. This test
  261. # only exercises real code in the C++ implementation as Python does not
  262. # support lazy parsing, but the current C++ implementation results in
  263. # memory corruption and a crash.
  264. if api_implementation.Type() != 'python':
  265. return
  266. proto = message_module.TestAllTypes()
  267. proto.optional_lazy_message.bb = 5
  268. proto.ClearField('optional_lazy_message')
  269. del proto
  270. gc.collect()
  271. def testSingularListFields(self, message_module):
  272. proto = message_module.TestAllTypes()
  273. proto.optional_fixed32 = 1
  274. proto.optional_int32 = 5
  275. proto.optional_string = 'foo'
  276. # Access sub-message but don't set it yet.
  277. nested_message = proto.optional_nested_message
  278. self.assertEqual(
  279. [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
  280. (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
  281. (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
  282. proto.ListFields())
  283. proto.optional_nested_message.bb = 123
  284. self.assertEqual(
  285. [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
  286. (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
  287. (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
  288. (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
  289. nested_message) ],
  290. proto.ListFields())
  291. def testRepeatedListFields(self, message_module):
  292. proto = message_module.TestAllTypes()
  293. proto.repeated_fixed32.append(1)
  294. proto.repeated_int32.append(5)
  295. proto.repeated_int32.append(11)
  296. proto.repeated_string.extend(['foo', 'bar'])
  297. proto.repeated_string.extend([])
  298. proto.repeated_string.append('baz')
  299. proto.repeated_string.extend(str(x) for x in range(2))
  300. proto.optional_int32 = 21
  301. proto.repeated_bool # Access but don't set anything; should not be listed.
  302. self.assertEqual(
  303. [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
  304. (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
  305. (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
  306. (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
  307. ['foo', 'bar', 'baz', '0', '1']) ],
  308. proto.ListFields())
  309. def testClearFieldWithUnknownFieldName(self, message_module):
  310. proto = message_module.TestAllTypes()
  311. self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
  312. self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field')
  313. def testDisallowedAssignments(self, message_module):
  314. # It's illegal to assign values directly to repeated fields
  315. # or to nonrepeated composite fields. Ensure that this fails.
  316. proto = message_module.TestAllTypes()
  317. # Repeated fields.
  318. self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
  319. # Lists shouldn't work, either.
  320. self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
  321. # Composite fields.
  322. self.assertRaises(AttributeError, setattr, proto,
  323. 'optional_nested_message', 23)
  324. # Assignment to a repeated nested message field without specifying
  325. # the index in the array of nested messages.
  326. self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
  327. 'bb', 34)
  328. # Assignment to an attribute of a repeated field.
  329. self.assertRaises(AttributeError, setattr, proto.repeated_float,
  330. 'some_attribute', 34)
  331. # proto.nonexistent_field = 23 should fail as well.
  332. self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
  333. def testSingleScalarTypeSafety(self, message_module):
  334. proto = message_module.TestAllTypes()
  335. self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
  336. self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
  337. self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
  338. self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
  339. self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
  340. self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
  341. self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
  342. # TODO(jieluo): Fix type checking difference for python and c extension
  343. if api_implementation.Type() == 'python':
  344. self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
  345. else:
  346. proto.optional_bool = 1.1
  347. def assertIntegerTypes(self, integer_fn, message_module):
  348. """Verifies setting of scalar integers.
  349. Args:
  350. integer_fn: A function to wrap the integers that will be assigned.
  351. message_module: unittest_pb2 or unittest_proto3_arena_pb2
  352. """
  353. def TestGetAndDeserialize(field_name, value, expected_type):
  354. proto = message_module.TestAllTypes()
  355. value = integer_fn(value)
  356. setattr(proto, field_name, value)
  357. self.assertIsInstance(getattr(proto, field_name), expected_type)
  358. proto2 = message_module.TestAllTypes()
  359. proto2.ParseFromString(proto.SerializeToString())
  360. self.assertIsInstance(getattr(proto2, field_name), expected_type)
  361. TestGetAndDeserialize('optional_int32', 1, int)
  362. TestGetAndDeserialize('optional_int32', 1 << 30, int)
  363. TestGetAndDeserialize('optional_uint32', 1 << 30, int)
  364. integer_64 = long
  365. if struct.calcsize('L') == 4:
  366. # Python only has signed ints, so 32-bit python can't fit an uint32
  367. # in an int.
  368. TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
  369. else:
  370. # 64-bit python can fit uint32 inside an int
  371. TestGetAndDeserialize('optional_uint32', 1 << 31, int)
  372. TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
  373. TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
  374. TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
  375. TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
  376. def testIntegerTypes(self, message_module):
  377. self.assertIntegerTypes(lambda x: x, message_module)
  378. def testNonStandardIntegerTypes(self, message_module):
  379. self.assertIntegerTypes(test_util.NonStandardInteger, message_module)
  380. def testIllegalValuesForIntegers(self, message_module):
  381. pb = message_module.TestAllTypes()
  382. # Strings are illegal, even when the represent an integer.
  383. with self.assertRaises(TypeError):
  384. pb.optional_uint64 = '2'
  385. # The exact error should propagate with a poorly written custom integer.
  386. with self.assertRaisesRegexp(RuntimeError, 'my_error'):
  387. pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
  def assetIntegerBoundsChecking(self, integer_fn, message_module):
    """Verifies bounds checking for scalar integer fields.

    The "asset" in the name appears to be a misspelling of "assert"; it is
    kept because testSingleScalarBoundsChecking and
    testNonStandardSingleScalarBoundsChecking call it by this name.

    Args:
      integer_fn: A function to wrap the integers that will be assigned.
      message_module: unittest_pb2 or unittest_proto3_arena_pb2
    """
    def CheckRange(field_name, lowest, highest):
      msg = message_module.TestAllTypes()
      lowest = integer_fn(lowest)
      highest = integer_fn(highest)
      # Both endpoints are accepted and read back unchanged.
      for boundary in (lowest, highest):
        setattr(msg, field_name, boundary)
        self.assertEqual(boundary, getattr(msg, field_name))
      # One step beyond either endpoint is rejected.
      self.assertRaises((ValueError, TypeError), setattr, msg, field_name,
                        lowest - 1)
      self.assertRaises((ValueError, TypeError), setattr, msg, field_name,
                        highest + 1)

    CheckRange('optional_int32', -(1 << 31), (1 << 31) - 1)
    CheckRange('optional_uint32', 0, 0xffffffff)
    CheckRange('optional_int64', -(1 << 63), (1 << 63) - 1)
    CheckRange('optional_uint64', 0, 0xffffffffffffffff)

    # White-box check: negative values may take a separate path in the C++
    # implementation (the original note mentions -1 being an int rather than
    # a long there). An unsigned 64-bit field must still reject them.
    msg = message_module.TestAllTypes()
    with self.assertRaises((ValueError, TypeError)):
      msg.optional_uint64 = integer_fn(-(1 << 63))

    # Enum fields accept wrapped integers too.
    msg = message_module.TestAllTypes()
    msg.optional_nested_enum = integer_fn(1)
    self.assertEqual(1, msg.optional_nested_enum)
  def testSingleScalarBoundsChecking(self, message_module):
    """Bounds-checks integer fields using plain Python integers."""
    def _passthrough(number):
      return number
    self.assetIntegerBoundsChecking(_passthrough, message_module)
  def testNonStandardSingleScalarBoundsChecking(self, message_module):
    """Bounds-checks integer fields with values wrapped by NonStandardInteger."""
    wrapper = test_util.NonStandardInteger
    self.assetIntegerBoundsChecking(wrapper, message_module)
  def testRepeatedScalarTypeSafety(self, message_module):
    """Repeated scalar containers reject values and indices of the wrong type."""
    proto = message_module.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # Appending an int to a string/bytes container must be rejected. These
    # checks previously called the container object itself, which raises
    # TypeError only because containers are not callable, so the element
    # type was never actually exercised.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)

    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
                      'index', 23)

    proto.repeated_string.append('2')
    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
    # TODO: Add type-safety tests for repeated enum fields.
  def testSingleScalarGettersAndSetters(self, message_module):
    """Scalar fields read back their default, then whatever was assigned."""
    msg = message_module.TestAllTypes()
    self.assertEqual(0, msg.optional_int32)
    msg.optional_int32 = 1
    self.assertEqual(1, msg.optional_int32)
    # Large uint64 values, up to the all-ones maximum, round-trip intact.
    for big in (0xffffffffffff, 0xffffffffffffffff):
      msg.optional_uint64 = big
      self.assertEqual(big, msg.optional_uint64)
    # TODO(robinson): Test all other scalar field types.
  def testEnums(self, message_module):
    """Nested enum constants are reachable from instances and from the class."""
    msg = message_module.TestAllTypes()
    for name, number in (('FOO', 1), ('BAR', 2), ('BAZ', 3)):
      self.assertEqual(number, getattr(msg, name))
      self.assertEqual(number, getattr(message_module.TestAllTypes, name))
  def testEnum_Name(self, message_module):
    """Enum Name() maps numbers to names and rejects bad input.

    Unknown numbers raise ValueError; floats and None raise TypeError; True
    is accepted and resolves like the number 1.
    """
    foreign = message_module.ForeignEnum
    for label in ('FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'):
      self.assertEqual(label, foreign.Name(getattr(message_module, label)))
    self.assertRaises(ValueError, foreign.Name, 11312)

    proto = message_module.TestAllTypes()
    nested_via_class = message_module.TestAllTypes.NestedEnum
    for label in ('FOO', 'BAR', 'BAZ'):
      number = getattr(proto, label)
      self.assertEqual(label, proto.NestedEnum.Name(number))
      self.assertEqual(label, nested_via_class.Name(number))
    self.assertRaises(ValueError, proto.NestedEnum.Name, 11312)
    self.assertRaises(ValueError, nested_via_class.Name, 11312)

    # Coercion edge cases.
    self.assertRaises(TypeError, nested_via_class.Name, 11312.0)
    self.assertRaises(TypeError, nested_via_class.Name, None)
    self.assertEqual('FOO', nested_via_class.Name(True))
  def testEnum_Value(self, message_module):
    """Enum Value() and attribute access both map names to numbers.

    Unknown names raise ValueError from Value() and AttributeError from
    attribute access on the enum itself.
    """
    foreign_enum = message_module.ForeignEnum
    for name in ('FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'):
      number = getattr(message_module, name)
      self.assertEqual(number, foreign_enum.Value(name))
      self.assertEqual(number, getattr(foreign_enum, name))
    self.assertRaises(ValueError, foreign_enum.Value, 'FO')
    with self.assertRaises(AttributeError):
      foreign_enum.FO

    proto = message_module.TestAllTypes()
    nested_via_class = message_module.TestAllTypes.NestedEnum
    for name in ('FOO', 'BAR', 'BAZ'):
      number = getattr(proto, name)
      self.assertEqual(number, proto.NestedEnum.Value(name))
      self.assertEqual(number, getattr(proto.NestedEnum, name))
      self.assertEqual(number, nested_via_class.Value(name))
      self.assertEqual(number, getattr(nested_via_class, name))

    self.assertRaises(ValueError, proto.NestedEnum.Value, 'Foo')
    # Look the unknown name up on the enum itself. The previous form,
    # `proto.NestedEnum.Value.Foo`, read an attribute of the Value method,
    # which raises AttributeError unconditionally and so could never fail.
    with self.assertRaises(AttributeError):
      proto.NestedEnum.Foo
    self.assertRaises(ValueError, nested_via_class.Value, 'Foo')
    with self.assertRaises(AttributeError):
      nested_via_class.Foo
  def testEnum_KeysAndValues(self, message_module):
    """Enum keys(), values() and items() list every entry in a fixed order.

    For modules other than unittest_pb2 (e.g. unittest_proto3_arena_pb2),
    each enum has an extra leading zero-valued entry.
    """
    is_proto2 = message_module == unittest_pb2

    foreign_items = [
        ('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)]
    if not is_proto2:
      foreign_items.insert(0, ('FOREIGN_ZERO', 0))
    foreign_enum = message_module.ForeignEnum
    self.assertEqual([name for name, _ in foreign_items],
                     list(foreign_enum.keys()))
    self.assertEqual([number for _, number in foreign_items],
                     list(foreign_enum.values()))
    self.assertEqual(foreign_items, list(foreign_enum.items()))

    proto = message_module.TestAllTypes()
    nested_items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
    if not is_proto2:
      nested_items.insert(0, ('ZERO', 0))
    self.assertEqual([name for name, _ in nested_items],
                     list(proto.NestedEnum.keys()))
    self.assertEqual([number for _, number in nested_items],
                     list(proto.NestedEnum.values()))
    self.assertEqual(nested_items, list(proto.NestedEnum.items()))
  def testStaticParseFrom(self, message_module):
    """FromString() rebuilds a message equal to the one that was serialized."""
    original = message_module.TestAllTypes()
    test_util.SetAllFields(original)
    wire = original.SerializeToString()
    parsed = message_module.TestAllTypes.FromString(wire)
    self.assertEqual(parsed, original)
  def testMergeFromSingularField(self, message_module):
    """MergeFrom copies set singular fields and keeps the target's others."""
    source = message_module.TestAllTypes()
    source.optional_int32 = 1
    target = message_module.TestAllTypes()
    target.optional_string = 'value'  # Unset in source, so it must survive.
    target.MergeFrom(source)
    self.assertEqual(1, target.optional_int32)
    self.assertEqual('value', target.optional_string)
  def testMergeFromRepeatedField(self, message_module):
    """MergeFrom appends repeated elements after the target's existing ones."""
    source = message_module.TestAllTypes()
    for n in (1, 2):
      source.repeated_int32.append(n)
    target = message_module.TestAllTypes()
    target.repeated_int32.append(0)
    target.MergeFrom(source)
    for index, expected in enumerate((0, 1, 2)):
      self.assertEqual(expected, target.repeated_int32[index])
  def testMergeFromRepeatedNestedMessage(self, message_module):
    """MergeFrom appends repeated sub-messages (message and container level)."""
    source = message_module.TestAllTypes()
    for bb in (123, 321):
      source.repeated_nested_message.add().bb = bb
    target = message_module.TestAllTypes()
    target.repeated_nested_message.add().bb = 999
    target.MergeFrom(source)

    expected = (999, 123, 321)
    for i, bb in enumerate(expected):
      self.assertEqual(bb, target.repeated_nested_message[i].bb)

    # Container-level MergeFrom into an empty field copies every element.
    third = message_module.TestAllTypes()
    third.repeated_nested_message.MergeFrom(target.repeated_nested_message)
    for i, bb in enumerate(expected):
      self.assertEqual(bb, third.repeated_nested_message[i].bb)
  def testMergeFromAllFields(self, message_module):
    """Merging a fully populated message into an empty one reproduces it."""
    populated = message_module.TestAllTypes()
    test_util.SetAllFields(populated)
    merged = message_module.TestAllTypes()
    merged.MergeFrom(populated)
    # Equal as messages, and byte-for-byte once serialized.
    self.assertEqual(merged, populated)
    expected_wire = populated.SerializeToString()
    self.assertEqual(expected_wire, merged.SerializeToString())
  def testMergeFromBug(self, message_module):
    """Merging from a sub-message that was only read must not mark it set."""
    src = message_module.TestAllTypes()
    dst = message_module.TestAllTypes()
    # Reading optional_nested_message instantiates it inside src, but it is
    # still not considered "present".
    _ = src.optional_nested_message
    self.assertFalse(src.HasField('optional_nested_message'))
    # Merging into dst must not instantiate the field in dst either.
    dst.MergeFrom(src)
    self.assertFalse(dst.HasField('optional_nested_message'))
  def testCopyFromSingularField(self, message_module):
    """CopyFrom overwrites singular fields the target had already set."""
    source = message_module.TestAllTypes()
    source.optional_int32 = 1
    source.optional_string = 'important-text'
    target = message_module.TestAllTypes()
    target.optional_string = 'value'
    target.CopyFrom(source)
    self.assertEqual(1, target.optional_int32)
    self.assertEqual('important-text', target.optional_string)
  def testCopyFromRepeatedField(self, message_module):
    """CopyFrom replaces, rather than extends, a repeated field."""
    source = message_module.TestAllTypes()
    for n in (1, 2):
      source.repeated_int32.append(n)
    target = message_module.TestAllTypes()
    target.repeated_int32.append(0)
    target.CopyFrom(source)
    # The target's original leading 0 is gone.
    for index, expected in enumerate((1, 2)):
      self.assertEqual(expected, target.repeated_int32[index])
  def testCopyFromAllFields(self, message_module):
    """Copying a fully populated message reproduces it exactly."""
    populated = message_module.TestAllTypes()
    test_util.SetAllFields(populated)
    copied = message_module.TestAllTypes()
    copied.CopyFrom(populated)
    # Equal as messages, and byte-for-byte once serialized.
    self.assertEqual(copied, populated)
    expected_wire = populated.SerializeToString()
    self.assertEqual(expected_wire, copied.SerializeToString())
  def testCopyFromSelf(self, message_module):
    """CopyFrom with the message itself leaves its contents intact."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.append(1)
    msg.optional_int32 = 2
    msg.optional_string = 'important-text'
    msg.CopyFrom(msg)
    self.assertEqual(1, msg.repeated_int32[0])
    self.assertEqual(2, msg.optional_int32)
    self.assertEqual('important-text', msg.optional_string)
  def testDeepCopy(self, message_module):
    """copy.deepcopy works on messages and on repeated containers."""
    original = message_module.TestAllTypes()
    original.optional_int32 = 1
    self.assertEqual(1, copy.deepcopy(original).optional_int32)

    # Repeated scalars: the copy is a standalone, mutable container.
    for n in (2, 3):
      original.repeated_int32.append(n)
    scalars = copy.deepcopy(original.repeated_int32)
    self.assertEqual([2, 3], scalars)
    scalars.remove(scalars[0])
    self.assertEqual([3], scalars)

    # Repeated messages: later edits to the source do not reach the copy.
    child = original.repeated_nested_message.add()
    child.bb = 1
    children = copy.deepcopy(original.repeated_nested_message)
    self.assertEqual(original.repeated_nested_message, children)
    child.bb = 2
    self.assertNotEqual(original.repeated_nested_message, children)
    children.remove(children[0])
    self.assertEqual(len(children), 0)
    # TODO(anuraag): Implement deepcopy for extension dict
  def testDisconnectingBeforeClear(self, message_module):
    """Clear() detaches sub-message references obtained earlier.

    After Clear(), the old references are distinct objects that keep their
    own values, and writes through them do not affect the cleared parent.
    """
    parent = message_module.TestAllTypes()
    stale = parent.optional_nested_message
    parent.Clear()
    self.assertIsNot(stale, parent.optional_nested_message)
    stale.bb = 23
    self.assertFalse(parent.HasField('optional_nested_message'))
    self.assertEqual(0, parent.optional_nested_message.bb)

    parent = message_module.TestAllTypes()
    stale_nested = parent.optional_nested_message
    stale_nested.bb = 5
    stale_foreign = parent.optional_foreign_message
    stale_foreign.c = 6
    parent.Clear()
    self.assertIsNot(stale_nested, parent.optional_nested_message)
    self.assertIsNot(stale_foreign, parent.optional_foreign_message)
    # The detached children still hold the values set before Clear().
    self.assertEqual(5, stale_nested.bb)
    self.assertEqual(6, stale_foreign.c)
    stale_nested.bb = 15
    stale_foreign.c = 16
    self.assertFalse(parent.HasField('optional_nested_message'))
    self.assertEqual(0, parent.optional_nested_message.bb)
    self.assertFalse(parent.HasField('optional_foreign_message'))
    self.assertEqual(0, parent.optional_foreign_message.c)
  def testStringUTF8Encoding(self, message_module):
    """String fields take text or valid UTF-8; bytes fields reject text."""
    msg = message_module.TestAllTypes()
    # A unicode object cannot be assigned to a 'bytes' field.
    self.assertRaises(TypeError,
                      setattr, msg, 'optional_bytes', u'unicode object')
    # The default string value is a text (unicode) object.
    self.assertEqual(type(msg.optional_string), six.text_type)
    msg.optional_string = six.text_type('Testing')
    self.assertEqual(msg.optional_string, str('Testing'))
    # A native 'str' that can be encoded as UTF-8 is accepted as well.
    msg.optional_string = str('Testing')
    self.assertEqual(msg.optional_string, six.text_type('Testing'))
    # Bytes that are not valid UTF-8 are rejected.
    self.assertRaises(ValueError,
                      setattr, msg, 'optional_string', b'a\x80a')
    # None of these raise: already-encoded UTF-8 bytes, a non-ASCII unicode
    # object, and a plain ASCII str.
    msg.optional_string = u'Тест'.encode('utf-8')
    msg.optional_string = u'Тест'
    msg.optional_string = 'abc'
  def testBytesInTextFormat(self, message_module):
    """Non-printable bytes are octal-escaped in the text representation."""
    raw = b'\x00\x7f\x80\xff'
    msg = message_module.TestAllTypes(optional_bytes=raw)
    expected = u'optional_bytes: "\\000\\177\\200\\377"\n'
    self.assertEqual(expected, six.text_type(msg))
  def testEmptyNestedMessage(self, message_module):
    """Merging, copying or parsing an empty sub-message marks it present."""
    nested_cls = message_module.TestAllTypes.NestedMessage

    msg = message_module.TestAllTypes()
    msg.optional_nested_message.MergeFrom(nested_cls())
    self.assertTrue(msg.HasField('optional_nested_message'))

    msg = message_module.TestAllTypes()
    msg.optional_nested_message.CopyFrom(nested_cls())
    self.assertTrue(msg.HasField('optional_nested_message'))

    msg = message_module.TestAllTypes()
    consumed = msg.optional_nested_message.MergeFromString(b'')
    self.assertEqual(0, consumed)
    self.assertTrue(msg.HasField('optional_nested_message'))

    msg = message_module.TestAllTypes()
    msg.optional_nested_message.ParseFromString(b'')
    self.assertTrue(msg.HasField('optional_nested_message'))

    # The present-but-empty sub-message survives a serialize/parse round trip.
    wire = msg.SerializeToString()
    reparsed = message_module.TestAllTypes()
    self.assertEqual(len(wire), reparsed.MergeFromString(wire))
    self.assertTrue(reparsed.HasField('optional_nested_message'))
  774. # Class to test proto2-only features (required, extensions, etc.)
  775. @testing_refleaks.TestCase
  776. class Proto2ReflectionTest(unittest.TestCase):
  def testRepeatedCompositeConstructor(self):
    """Repeated message and group fields can be filled via the constructor."""
    nested_cls = unittest_pb2.TestAllTypes.NestedMessage
    foreign_cls = unittest_pb2.ForeignMessage
    group_cls = unittest_pb2.TestAllTypes.RepeatedGroup
    foo = unittest_pb2.TestAllTypes.FOO
    bar = unittest_pb2.TestAllTypes.BAR

    # A constructor call that sets only repeated composite fields succeeds.
    proto = unittest_pb2.TestAllTypes(
        repeated_nested_message=[nested_cls(bb=foo), nested_cls(bb=bar)],
        repeated_foreign_message=[
            foreign_cls(c=-43), foreign_cls(c=45324), foreign_cls(c=12)],
        repeatedgroup=[group_cls(), group_cls(a=1), group_cls(a=2)])

    # Each field reads back the same elements in the same order.
    self.assertEqual([nested_cls(bb=foo), nested_cls(bb=bar)],
                     list(proto.repeated_nested_message))
    self.assertEqual(
        [foreign_cls(c=-43), foreign_cls(c=45324), foreign_cls(c=12)],
        list(proto.repeated_foreign_message))
    self.assertEqual([group_cls(), group_cls(a=1), group_cls(a=2)],
                     list(proto.repeatedgroup))
  def assertListsEqual(self, values, others):
    """Asserts two sequences have the same length and pairwise-equal items.

    Args:
      values: First sequence (must support len() and iteration).
      others: Second sequence (must support len() and iteration).
    """
    self.assertEqual(len(values), len(others))
    # The lengths already match, so zip() visits every pair.
    for value, other in zip(values, others):
      self.assertEqual(value, other)
  def testSimpleHasBits(self):
    """A scalar's has-bit follows assignment and ClearField, not reads."""
    msg = unittest_pb2.TestAllTypes()
    field = 'optional_int32'
    self.assertFalse(msg.HasField(field))
    # Reading the default value must not set the has-bit...
    self.assertEqual(0, msg.optional_int32)
    self.assertFalse(msg.HasField(field))
    # ...assigning a value must set it...
    msg.optional_int32 = 1
    self.assertTrue(msg.HasField(field))
    # ...and clearing the field must unset it again.
    msg.ClearField(field)
    self.assertFalse(msg.HasField(field))
  def testHasBitsWithSinglyNestedScalar(self):
    """Has-bits propagate from a nested scalar up to its composite field.

    Each case pairs a non-repeated composite (group or message) field of
    TestAllTypes with an integer scalar inside it, and checks that:
      * reading the scalar returns 0 and sets no has-bits;
      * assigning the scalar sets has-bits on the composite and on proto;
      * ClearField on the composite resets both has-bits and the scalar,
        returns a new composite object afterwards, and writes through the
        old (disconnected) object no longer affect proto.
    """
    def CheckCompositeHasBits(composite_name, scalar_name):
      proto = unittest_pb2.TestAllTypes()

      # Before any write: the scalar is 0 and nothing reports presence.
      sub = getattr(proto, composite_name)
      self.assertEqual(0, getattr(sub, scalar_name))
      self.assertFalse(sub.HasField(scalar_name))
      self.assertFalse(proto.HasField(composite_name))

      # Writing the scalar marks both the scalar and the composite present.
      new_val = 20
      setattr(sub, scalar_name, new_val)
      self.assertEqual(new_val, getattr(sub, scalar_name))
      detached = sub  # Keep the pre-clear object for the checks below.
      self.assertTrue(sub.HasField(scalar_name))
      self.assertTrue(proto.HasField(composite_name))

      # Clearing the composite resets the has-bits and the scalar's value.
      proto.ClearField(composite_name)
      sub = getattr(proto, composite_name)
      self.assertFalse(sub.HasField(scalar_name))
      self.assertFalse(proto.HasField(composite_name))
      self.assertEqual(0, getattr(sub, scalar_name))

      # The old object is disconnected: writing to it leaves proto alone.
      self.assertIsNot(detached, sub)
      setattr(detached, scalar_name, new_val)
      self.assertFalse(sub.HasField(scalar_name))
      self.assertFalse(proto.HasField(composite_name))
      self.assertEqual(0, getattr(sub, scalar_name))

    # Simple, single-level nesting with a scalar being set.
    for composite_name, scalar_name in [('optionalgroup', 'a'),
                                        ('optional_nested_message', 'bb'),
                                        ('optional_foreign_message', 'c'),
                                        ('optional_import_message', 'd')]:
      CheckCompositeHasBits(composite_name, scalar_name)
  def testHasBitsWhenModifyingRepeatedFields(self):
    """Adding to a repeated field inside a sub-message marks it present."""
    msg = unittest_pb2.TestNestedMessageHasBits()
    sub_name = 'optional_nested_message'

    # Repeated scalar inside the sub-message.
    msg.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], msg.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(msg.HasField(sub_name))

    # Repeated composite inside the sub-message, starting from a clean slate.
    msg.ClearField(sub_name)
    self.assertFalse(msg.HasField(sub_name))
    msg.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(msg.HasField(sub_name))
  def testHasBitsForManyLevelsOfNesting(self):
    """Setting a deeply nested scalar marks every ancestor link present."""
    root = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(root.HasField('bb'))
    # Reading through the chain must not set any has-bits.
    self.assertEqual(0, root.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(root.HasField('bb'))

    root.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, root.bb.a.bb.a.bb.optional_int32)

    # Walk root.bb.a.bb.a.bb, checking each link on the way down.
    node = root
    for link in ('bb', 'a', 'bb', 'a', 'bb'):
      self.assertTrue(node.HasField(link))
      node = getattr(node, link)
    # At the innermost message the scalar is set, but its own sub-message
    # was never touched.
    self.assertFalse(node.HasField('a'))
    self.assertTrue(node.HasField('optional_int32'))
  def testSingularListExtensions(self):
    """ListFields() reports singular extensions in a fixed order.

    The expected order differs from the assignment order, and deleted
    extensions disappear from the listing.
    """
    msg = unittest_pb2.TestAllExtensions()
    for extension, value in ((unittest_pb2.optional_fixed32_extension, 1),
                             (unittest_pb2.optional_int32_extension, 5),
                             (unittest_pb2.optional_string_extension, 'foo')):
      msg.Extensions[extension] = value
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        msg.ListFields())

    del msg.Extensions[unittest_pb2.optional_fixed32_extension]
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_string_extension, 'foo')],
        msg.ListFields())
  def testRepeatedListExtensions(self):
    """ListFields() reports repeated extensions alongside singular ones."""
    msg = unittest_pb2.TestAllExtensions()
    appends = [
        (unittest_pb2.repeated_fixed32_extension, 1),
        (unittest_pb2.repeated_int32_extension, 5),
        (unittest_pb2.repeated_int32_extension, 11),
        (unittest_pb2.repeated_string_extension, 'foo'),
        (unittest_pb2.repeated_string_extension, 'bar'),
        (unittest_pb2.repeated_string_extension, 'baz'),
    ]
    for extension, item in appends:
      msg.Extensions[extension].append(item)
    msg.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        msg.ListFields())

    for extension in (unittest_pb2.repeated_int32_extension,
                      unittest_pb2.repeated_string_extension):
      del msg.Extensions[extension]
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_fixed32_extension, [1])],
        msg.ListFields())
  def testListFieldsAndExtensions(self):
    """ListFields() interleaves regular fields and extensions in one order.

    The expected list mixes TestFieldOrderings' own fields with the
    my_extension_* extensions, rather than listing the two kinds as
    separate groups.
    """
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    # (A bare `unittest_pb2.my_extension_int` expression statement that
    # used to sit here had no effect and has been removed.)
    fields = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(fields['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (fields['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (fields['my_float'], 1.0)],
        proto.ListFields())
  def testDefaultValues(self):
    """Unset fields read back as their defaults.

    optional_* fields fall back to the type's zero value, while the
    default_* fields of TestAllTypes have non-zero defaults.
    """
    proto = unittest_pb2.TestAllTypes()
    expectations = [
        ('optional_int32', 0),
        ('optional_int64', 0),
        ('optional_uint32', 0),
        ('optional_uint64', 0),
        ('optional_sint32', 0),
        ('optional_sint64', 0),
        ('optional_fixed32', 0),
        ('optional_fixed64', 0),
        ('optional_sfixed32', 0),
        ('optional_sfixed64', 0),
        ('optional_float', 0.0),
        ('optional_double', 0.0),
        ('optional_bool', False),
        ('optional_string', ''),
        ('optional_bytes', b''),
        ('default_int32', 41),
        ('default_int64', 42),
        ('default_uint32', 43),
        ('default_uint64', 44),
        ('default_sint32', -45),
        ('default_sint64', 46),
        ('default_fixed32', 47),
        ('default_fixed64', 48),
        ('default_sfixed32', 49),
        ('default_sfixed64', -50),
        ('default_float', 51.5),
        ('default_double', 52e3),
        ('default_bool', True),
        ('default_string', 'hello'),
        ('default_bytes', b'world'),
        ('default_nested_enum', unittest_pb2.TestAllTypes.BAR),
        ('default_foreign_enum', unittest_pb2.FOREIGN_BAR),
        ('default_import_enum', unittest_import_pb2.IMPORT_BAR),
    ]
    for field_name, expected in expectations:
      self.assertEqual(expected, getattr(proto, field_name))

    extreme = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', extreme.utf8_string)
  def testHasFieldWithUnknownFieldName(self):
    """HasField() raises ValueError for a name that is not a field."""
    msg = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
      msg.HasField('nonexistent_field')
  def testClearRemovesChildren(self):
    """Copying an empty message over a populated one drops all children.

    Guards against partial clearing, which can happen in the more complex
    C++ implementation because it keeps parallel message lists.
    """
    populated = unittest_pb2.TestRequiredForeign()
    for _ in range(10):
      populated.repeated_message.add()
    empty = unittest_pb2.TestRequiredForeign()
    populated.CopyFrom(empty)
    with self.assertRaises(IndexError):
      populated.repeated_message[5]
  def testSingleScalarClearField(self):
    """ClearField on a scalar is harmless when unset and resets it when set."""
    msg = unittest_pb2.TestAllTypes()
    # Clearing a field that was never set is allowed (a no-op).
    msg.ClearField('optional_int32')
    msg.optional_int32 = 1
    self.assertTrue(msg.HasField('optional_int32'))
    msg.ClearField('optional_int32')
    # Afterwards the value is back to its default and the has-bit is gone.
    self.assertEqual(0, msg.optional_int32)
    self.assertFalse(msg.HasField('optional_int32'))
    # TODO(robinson): Test all other scalar field types.
  def testRepeatedScalars(self):
    """Exercises the list-like API of a repeated scalar field.

    Walks a single repeated_int32 container through append, indexing,
    slicing, slice assignment, iteration, deletion, extend and ClearField.
    The full contents are checked after each mutation. Every step depends on
    the state left by the previous one, so statement order matters.
    """
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    proto.repeated_int32.append(15)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(3, len(proto.repeated_int32))
    self.assertEqual([5, 10, 15], proto.repeated_int32)
    # Test single retrieval.
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(15, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
    # Test single assignment.
    proto.repeated_int32[1] = 20
    self.assertEqual([5, 20, 15], proto.repeated_int32)
    # Test insertion.
    proto.repeated_int32.insert(1, 25)
    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
    # Test slice retrieval.
    proto.repeated_int32.append(30)
    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
    # Test slice assignment with an iterator. The right-hand side is a
    # generator expression, which has no len().
    proto.repeated_int32[1:4] = (i for i in range(3))
    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
    # Test slice assignment.
    proto.repeated_int32[1:4] = [35, 40, 45]
    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 35, 40, 45, 30], result)
    # Test single deletion.
    del proto.repeated_int32[2]
    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
    # Test slice deletion.
    del proto.repeated_int32[2:]
    self.assertEqual([5, 35], proto.repeated_int32)
    # Test extending.
    proto.repeated_int32.extend([3, 13])
    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
    # Test clearing, then check that the container is usable again.
    proto.ClearField('repeated_int32')
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(1)
    self.assertEqual(1, proto.repeated_int32[-1])
    # Test assignment to a negative index.
    proto.repeated_int32[-1] = 2
    self.assertEqual(2, proto.repeated_int32[-1])
    # Test deletion at negative indices.
    proto.repeated_int32[:] = [0, 1, 2, 3]
    del proto.repeated_int32[-1]
    self.assertEqual([0, 1, 2], proto.repeated_int32)
    del proto.repeated_int32[-2]
    self.assertEqual([0, 2], proto.repeated_int32)
    # Out-of-range single-element deletes must raise IndexError.
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
    del proto.repeated_int32[-2:-1]
    self.assertEqual([2], proto.repeated_int32)
    # An out-of-range slice delete removes nothing and does not raise.
    del proto.repeated_int32[100:10000]
    self.assertEqual([2], proto.repeated_int32)
  def testRepeatedScalarsRemove(self):
    """remove() drops only the first matching element of a repeated scalar."""
    proto = unittest_pb2.TestAllTypes()
    values = proto.repeated_int32
    self.assertFalse(values)
    self.assertEqual(0, len(values))
    for number in (5, 10, 5, 5):
      values.append(number)
    self.assertEqual(4, len(values))

    def check_contents(*expected):
      # Length first, then every position, matching element-by-element.
      self.assertEqual(len(expected), len(values))
      for position, wanted in enumerate(expected):
        self.assertEqual(wanted, values[position])

    values.remove(5)
    check_contents(10, 5, 5)
    values.remove(5)
    check_contents(10, 5)
    values.remove(10)
    check_contents(5)
    # Removing a value that is not present raises ValueError.
    self.assertRaises(ValueError, values.remove, 123)
  def testRepeatedComposites(self):
    """Exercises the list-like API of a repeated message field.

    Covers the following behaviors:
      - add(), truthiness and len();
      - indexing and slicing, including bad index types;
      - iteration and deletion;
      - extend() with valid and invalid arguments;
      - ClearField and add() with keyword initializers;
      - rejection of direct element assignment.
    Steps build on one another, so statement order matters.
    """
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)
    # Test slice retrieval.
    m2 = proto.repeated_nested_message.add()
    m3 = proto.repeated_nested_message.add()
    m4 = proto.repeated_nested_message.add()
    self.assertListsEqual(
        [m1, m2, m3], proto.repeated_nested_message[1:4])
    self.assertListsEqual(
        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
    self.assertListsEqual(
        [m0, m1], proto.repeated_nested_message[:2])
    self.assertListsEqual(
        [m2, m3, m4], proto.repeated_nested_message[2:])
    self.assertEqual(
        m0, proto.repeated_nested_message[0])
    self.assertListsEqual(
        [m0], proto.repeated_nested_message[:1])
    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertListsEqual([m0, m1, m2, m3, m4], result)
    # Test single deletion.
    del proto.repeated_nested_message[2]
    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
    # Test slice deletion.
    del proto.repeated_nested_message[2:]
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
    # Test extending.
    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
    proto.repeated_nested_message.extend([n1,n2])
    self.assertEqual(4, len(proto.repeated_nested_message))
    self.assertEqual(n1, proto.repeated_nested_message[2])
    self.assertEqual(n2, proto.repeated_nested_message[3])
    # extend() must reject three inputs with TypeError: a bare message
    # instead of a list, a list holding a non-message, and a list holding a
    # message of the wrong type.
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend, n1)
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend, [0])
    wrong_message_type = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend,
                      [wrong_message_type])
    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    # Test constructing an element while adding it.
    proto.repeated_nested_message.add(bb=23)
    self.assertEqual(1, len(proto.repeated_nested_message))
    self.assertEqual(23, proto.repeated_nested_message[0].bb)
    # add() accepts keyword initializers only; a positional value is rejected.
    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
    # Direct element assignment must fail. The exception type is not pinned
    # here, presumably because it differs between implementations.
    with self.assertRaises(Exception):
      proto.repeated_nested_message[0] = 23
  def testRepeatedCompositeRemove(self):
    """remove() on a repeated message field removes the matching element."""
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, len(proto.repeated_nested_message))
    # Give each element a distinct bb so that m0 != m1 != m2. The value is
    # the container length right after the element was added (1, 2, 3).
    m0 = proto.repeated_nested_message.add()
    m0.bb = len(proto.repeated_nested_message)
    m1 = proto.repeated_nested_message.add()
    m1.bb = len(proto.repeated_nested_message)
    self.assertNotEqual(m0, m1)
    m2 = proto.repeated_nested_message.add()
    m2.bb = len(proto.repeated_nested_message)
    # All three must be pairwise distinct. Otherwise remove() could
    # legitimately drop a different element and the checks below would be
    # vacuous.
    self.assertNotEqual(m1, m2)
    self.assertNotEqual(m0, m2)
    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
    self.assertEqual(3, len(proto.repeated_nested_message))
    proto.repeated_nested_message.remove(m0)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertEqual(m1, proto.repeated_nested_message[0])
    self.assertEqual(m2, proto.repeated_nested_message[1])
    # Removing m0 again, or removing None, must raise ValueError and leave
    # the container unchanged.
    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
    self.assertEqual(2, len(proto.repeated_nested_message))
    proto.repeated_nested_message.remove(m2)
    self.assertEqual(1, len(proto.repeated_nested_message))
    self.assertEqual(m1, proto.repeated_nested_message[0])
  def testHandWrittenReflection(self):
    """A hand-built Descriptor can drive GeneratedProtocolMessageType.

    Builds a one-field descriptor by hand and creates a message class from
    it through the reflection metaclass. Then checks the default value,
    assignment and presence of that field.
    """
    # Hand written descriptors/reflection are only supported by the
    # pure-Python implementation of the API. Report a skip on other
    # implementations instead of silently passing.
    if api_implementation.Type() != 'python':
      self.skipTest('Hand written reflection requires the pure-Python '
                    'implementation')
    FieldDescriptor = descriptor.FieldDescriptor
    foo_field_descriptor = FieldDescriptor(
        name='foo_field', full_name='MyProto.foo_field',
        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
        cpp_type=FieldDescriptor.CPPTYPE_INT64,
        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
        containing_type=None, message_type=None, enum_type=None,
        is_extension=False, extension_scope=None,
        options=descriptor_pb2.FieldOptions(),
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    mydescriptor = descriptor.Descriptor(
        name='MyProto', full_name='MyProto', filename='ignored',
        containing_type=None, nested_types=[], enum_types=[],
        fields=[foo_field_descriptor], extensions=[],
        options=descriptor_pb2.MessageOptions(),
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)

    class MyProtoClass(
        six.with_metaclass(reflection.GeneratedProtocolMessageType,
                           message.Message)):
      DESCRIPTOR = mydescriptor

    myproto_instance = MyProtoClass()
    self.assertEqual(0, myproto_instance.foo_field)
    self.assertFalse(myproto_instance.HasField('foo_field'))
    myproto_instance.foo_field = 23
    self.assertEqual(23, myproto_instance.foo_field)
    self.assertTrue(myproto_instance.HasField('foo_field'))
  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testDescriptorProtoSupport(self):
    """A class built from a DescriptorProto round-trips through bytes.

    Steps:
      1. Assemble a DescriptorProto by hand.
      2. Turn it into a Descriptor with descriptor.MakeDescriptor.
      3. Define a message class for it.
      4. Serialize, re-parse with reflection.ParseMessage, and check that
         every field is preserved.
    """
    # Hand written descriptors/reflection are only supported by the
    # pure-Python implementation of the API. Report a skip on other
    # implementations instead of silently passing.
    if api_implementation.Type() != 'python':
      self.skipTest('Hand written descriptors require the pure-Python '
                    'implementation')
    fdp = descriptor_pb2.FieldDescriptorProto

    def AddDescriptorField(proto, field_name, field_type,
                           label=fdp.LABEL_OPTIONAL):
      """Appends a field to `proto`, numbered after the existing fields."""
      new_field = proto.field.add()
      new_field.name = field_name
      new_field.type = field_type
      # Field numbers are 1-based and follow declaration order.
      new_field.number = len(proto.field)
      new_field.label = label

    desc_proto = descriptor_pb2.DescriptorProto()
    desc_proto.name = 'Car'
    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
    # Add a repeated field.
    AddDescriptorField(desc_proto, 'owners', fdp.TYPE_STRING,
                       label=fdp.LABEL_REPEATED)
    field_names = ('name', 'year', 'automatic', 'price', 'owners')

    desc = descriptor.MakeDescriptor(desc_proto)
    for field_name in field_names:
      self.assertIn(field_name, desc.fields_by_name)

    class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
                                        message.Message)):
      DESCRIPTOR = desc

    prius = CarMessage()
    prius.name = 'prius'
    prius.year = 2010
    prius.automatic = True
    prius.price = 25134.75
    prius.owners.extend(['bob', 'susan'])

    serialized_prius = prius.SerializeToString()
    new_prius = reflection.ParseMessage(desc, serialized_prius)
    self.assertIsNot(new_prius, prius)
    self.assertEqual(prius, new_prius)

    # These checks are redundant if message equality works as advertised,
    # but they are explicit to be safe, since the class was built through
    # the reflection metaclass.
    for field_name in field_names:
      self.assertEqual(getattr(prius, field_name),
                       getattr(new_prius, field_name))
  def testExtensionDelete(self):
    """`del m.Extensions[ext]` removes set extensions and tolerates unset ones."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()
    int_ext = more_extensions_pb2.optional_int_extension
    repeated_ext = more_extensions_pb2.repeated_int_extension
    msg_ext = more_extensions_pb2.optional_message_extension

    extendee_proto.Extensions[int_ext] = 23
    extendee_proto.Extensions[repeated_ext].append(11)
    extendee_proto.Extensions[msg_ext].foreign_message_int = 56
    self.assertEqual(len(extendee_proto.Extensions), 3)

    # Each entry pairs the handle to delete with the number of extensions
    # expected to remain afterwards.
    deletions = (
        (msg_ext, 2),
        (repeated_ext, 1),
        # Deleting an extension that is not present is allowed and does not
        # raise KeyError. This matches m.Extensions[ext], which returns a
        # default value even when nothing was set.
        (repeated_ext, 1),
        (int_ext, 0),
    )
    for handle, remaining in deletions:
      del extendee_proto.Extensions[handle]
      self.assertEqual(len(extendee_proto.Extensions), remaining)
  def testExtensionIter(self):
    """Iterating Extensions yields exactly the set extensions, in order."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()
    extension_int32 = more_extensions_pb2.optional_int_extension
    extendee_proto.Extensions[extension_int32] = 23
    extension_repeated = more_extensions_pb2.repeated_int_extension
    extendee_proto.Extensions[extension_repeated].append(11)
    extension_msg = more_extensions_pb2.optional_message_extension
    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
    # Set some normal fields; they must not appear when iterating Extensions.
    extendee_proto.optional_int32 = 1
    extendee_proto.repeated_string.append('hi')
    expected = (extension_int32, extension_msg, extension_repeated)
    seen_names = []
    for item in extendee_proto.Extensions:
      self.assertIn(item, extendee_proto.Extensions)
      seen_names.append(item.name)
    # One comparison checks both count and order. An unexpected extra item
    # now yields a readable diff instead of an IndexError on `expected`.
    self.assertEqual([ext.name for ext in expected], seen_names)
  def testExtensionContainsError(self):
    """`in` on Extensions raises KeyError for keys that are not extensions."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()
    message_descriptor = more_extensions_pb2.ExtendedMessage.DESCRIPTOR
    regular_field = message_descriptor.fields_by_name['optional_int32']
    # Neither a plain integer nor a regular (non-extension) field descriptor
    # is a valid extension key.
    for bad_key in (0, regular_field):
      self.assertRaises(KeyError, extendee_proto.Extensions.__contains__,
                        bad_key)
  def testTopLevelExtensionsForOptionalScalar(self):
    """Presence and clearing of a top-level optional scalar extension."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_int32_extension

    def check_presence(present):
      # HasExtension() and `in` must agree on whether the extension is set.
      if present:
        self.assertTrue(extendee_proto.HasExtension(extension))
        self.assertIn(extension, extendee_proto.Extensions)
      else:
        self.assertFalse(extendee_proto.HasExtension(extension))
        self.assertNotIn(extension, extendee_proto.Extensions)

    check_presence(False)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    # As with normal scalar fields, reading the default does not set the
    # "has" bit.
    check_presence(False)
    # Assigning a value makes the extension present.
    extendee_proto.Extensions[extension] = 23
    self.assertEqual(23, extendee_proto.Extensions[extension])
    check_presence(True)
    # ClearExtension() restores both the default value and absence.
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    check_presence(False)
  def testTopLevelExtensionsForRepeatedScalar(self):
    """A top-level repeated scalar extension behaves like a repeated field."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeated_string_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertNotIn(extension, extendee_proto.Extensions)
    extendee_proto.Extensions[extension].append('foo')
    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
    self.assertIn(extension, extendee_proto.Extensions)
    # Keep the pre-clear container to check that clearing hands out a new
    # one.
    old_container = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertNotIn(extension, extendee_proto.Extensions)
    self.assertIsNot(old_container, extendee_proto.Extensions[extension])
    # Assigning a whole value (Extensions[extension] = 'a') is rejected.
    with self.assertRaises(TypeError):
      extendee_proto.Extensions[extension] = 'a'
  def testTopLevelExtensionsForOptionalMessage(self):
    """Presence and detachment rules of an optional message extension."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_foreign_message_extension

    def check_presence(present):
      # HasExtension() and `in` must agree on whether the extension is set.
      if present:
        self.assertTrue(extendee_proto.HasExtension(extension))
        self.assertIn(extension, extendee_proto.Extensions)
      else:
        self.assertFalse(extendee_proto.HasExtension(extension))
        self.assertNotIn(extension, extendee_proto.Extensions)

    check_presence(False)
    self.assertEqual(0, extendee_proto.Extensions[extension].c)
    # As with normal (non-extension) fields, merely reading from the
    # sub-message does not set the "has" bit.
    check_presence(False)
    extendee_proto.Extensions[extension].c = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].c)
    check_presence(True)
    # Hold on to the current sub-message across ClearExtension().
    detached = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertIsNot(detached, extendee_proto.Extensions[extension])
    # Writes to the detached sub-message must not set any "has" bits on
    # extendee_proto.
    detached.c = 42
    self.assertEqual(42, detached.c)
    self.assertTrue(detached.HasField('c'))
    check_presence(False)
    # Assigning a whole value (Extensions[extension] = 'a') is rejected.
    with self.assertRaises(TypeError):
      extendee_proto.Extensions[extension] = 'a'
  def testTopLevelExtensionsForRepeatedMessage(self):
    """Elements of a repeated group extension reflect later writes."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeatedgroup_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    group = extendee_proto.Extensions[extension].add()
    # Writes through the object returned by add() are visible via the
    # extension.
    for value in (23, 42):
      group.a = value
      self.assertEqual(value, extendee_proto.Extensions[extension][0].a)
    previous = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertIsNot(previous, extendee_proto.Extensions[extension])
    # Assigning a whole value (Extensions[extension] = 'a') is rejected.
    with self.assertRaises(TypeError):
      extendee_proto.Extensions[extension] = 'a'
  def testNestedExtensions(self):
    """An extension declared inside a message (TestRequired.single) works.

    Only the non-repeated case is exercised.
    """
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single

    def check_presence(present):
      # HasExtension() and `in` must agree on whether the extension is set.
      if present:
        self.assertTrue(extendee_proto.HasExtension(extension))
        self.assertIn(extension, extendee_proto.Extensions)
      else:
        self.assertFalse(extendee_proto.HasExtension(extension))
        self.assertNotIn(extension, extendee_proto.Extensions)

    check_presence(False)
    required = extendee_proto.Extensions[extension]
    self.assertEqual(0, required.a)
    check_presence(False)
    required.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].a)
    check_presence(True)
    extendee_proto.ClearExtension(extension)
    self.assertIsNot(required, extendee_proto.Extensions[extension])
    check_presence(False)
  def testRegisteredExtensions(self):
    """Generated extensions are registered in the file's descriptor pool."""
    pool = unittest_pb2.DESCRIPTOR.pool
    all_extensions_desc = unittest_pb2.TestAllExtensions.DESCRIPTOR
    # Lookup by (extendee, number) succeeds for extension number 1.
    self.assertTrue(pool.FindExtensionByNumber(all_extensions_desc, 1))
    # Lookup by full name yields an extension of TestAllExtensions.
    by_name = pool.FindExtensionByName(
        'protobuf_unittest.optional_int32_extension')
    self.assertIs(by_name.containing_type, all_extensions_desc)
    # Make sure extensions haven't been registered into types that shouldn't
    # have any.
    unrelated = pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)
    self.assertEqual(0, len(unrelated))
  # If message A directly contains message B, and
  # a.HasField('b') is currently False, then mutating any
  # extension in B should change a.HasField('b') to True
  # (and so on up the object tree).
  def testHasBitsForAncestorsOfExtendedMessage(self):
    """Mutating an extension of a sub-message marks the sub-message present.

    Runs the same read-then-write scenario for four extension shapes:
    optional scalar, repeated scalar, optional message and repeated message.
    In each case a read must leave toplevel.HasField('submessage') False, and
    the first write must flip it to True. Each scenario starts from a fresh
    TopLevelMessage.
    """
    # Optional scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))
    # Repeated scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual([], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension].append(23)
    self.assertEqual([23], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))
    # Optional message extension: here the write goes to a field of the
    # extension's sub-message.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(toplevel.HasField('submessage'))
    # Repeated message extension: add() alone counts as the mutation.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, len(toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension]))
    self.assertFalse(toplevel.HasField('submessage'))
    foreign = toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension].add()
    self.assertEqual(foreign, toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension][0])
    self.assertTrue(toplevel.HasField('submessage'))
  def testDisconnectionAfterClearingEmptyMessage(self):
    """ClearExtension() detaches a message extension even if it was empty.

    The extension sub-message is only read, never written, before it is
    cleared. Writes made through the old reference afterwards must not
    reach the extendee.
    """
    toplevel = more_extensions_pb2.TopLevelMessage()
    extendee_proto = toplevel.submessage
    extension = more_extensions_pb2.optional_message_extension
    extension_proto = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    extension_proto.foreign_message_int = 23
    self.assertIsNot(extension_proto, extendee_proto.Extensions[extension])
    # The write above landed only on the detached message. The extendee
    # still reports the extension as absent and serves a default-valued
    # instance.
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertEqual(
        0, extendee_proto.Extensions[extension].foreign_message_int)
  def testExtensionFailureModes(self):
    """Extension APIs raise KeyError for handles this message cannot take."""
    extendee_proto = unittest_pb2.TestAllExtensions()

    def check_rejected(handle):
      # All four entry points must refuse the handle: HasExtension(),
      # ClearExtension(), Extensions[] read, and Extensions[] write.
      self.assertRaises(KeyError, extendee_proto.HasExtension, handle)
      self.assertRaises(KeyError, extendee_proto.ClearExtension, handle)
      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
                        handle)
      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
                        handle, 5)

    # Something that is not an extension handle at all.
    check_rejected(1234)
    # Genuine extension handles that belong to a different message.
    for foreign_handle in (more_extensions_pb2.optional_int_extension,
                           more_extensions_pb2.optional_message_extension,
                           more_extensions_pb2.repeated_int_extension,
                           more_extensions_pb2.repeated_message_extension):
      check_rejected(foreign_handle)
    # A valid handle, but for a *repeated* extension. As with ordinary
    # repeated fields, Has*() is not supported for repeated extensions.
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unittest_pb2.repeated_string_extension)
  def testMergeFromOptionalGroup(self):
    """MergeFrom() copies a set optional group into an empty message."""
    source = unittest_pb2.TestAllTypes()
    source.optionalgroup.a = 12
    destination = unittest_pb2.TestAllTypes()
    destination.MergeFrom(source)
    self.assertEqual(12, destination.optionalgroup.a)
  def testMergeFromExtensionsSingular(self):
    """MergeFrom() copies a set singular extension."""
    extension = unittest_pb2.optional_int32_extension
    source = unittest_pb2.TestAllExtensions()
    source.Extensions[extension] = 1
    destination = unittest_pb2.TestAllExtensions()
    destination.MergeFrom(source)
    self.assertEqual(1, destination.Extensions[extension])
  def testMergeFromExtensionsRepeated(self):
    """MergeFrom() appends repeated extension values after existing ones."""
    extension = unittest_pb2.repeated_int32_extension
    source = unittest_pb2.TestAllExtensions()
    for value in (1, 2):
      source.Extensions[extension].append(value)
    destination = unittest_pb2.TestAllExtensions()
    destination.Extensions[extension].append(0)
    destination.MergeFrom(source)
    merged = destination.Extensions[extension]
    self.assertEqual(3, len(merged))
    # The pre-existing 0 stays first; merged values follow in source order.
    for index, expected in enumerate((0, 1, 2)):
      self.assertEqual(expected, merged[index])
  def testMergeFromExtensionsNestedMessage(self):
    """MergeFrom() appends repeated message extensions in order."""
    extension = unittest_pb2.repeated_nested_message_extension
    source = unittest_pb2.TestAllExtensions()
    source_list = source.Extensions[extension]
    for bb in (222, 333):
      source_list.add().bb = bb
    destination = unittest_pb2.TestAllExtensions()
    destination.Extensions[extension].add().bb = 111
    destination.MergeFrom(source)
    merged = destination.Extensions[extension]
    self.assertEqual(3, len(merged))
    for index, expected_bb in enumerate((111, 222, 333)):
      self.assertEqual(expected_bb, merged[index].bb)
  def testCopyFromBadType(self):
    """CopyFrom() with a message of a different type raises TypeError."""
    # The python implementation doesn't raise an exception in this case,
    # though in theory it should. Report a skip instead of a silent pass so
    # the gap stays visible.
    if api_implementation.Type() == 'python':
      self.skipTest('The pure-Python CopyFrom does not raise TypeError for '
                    'a mismatched message type')
    proto1 = unittest_pb2.TestAllTypes()
    proto2 = unittest_pb2.TestAllExtensions()
    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
  def testClear(self):
    """Clear() leaves a message equal to a freshly constructed one.

    Checked for a fully populated TestAllTypes and for a TestAllExtensions
    with every extension set.
    """
    # C++ implementation does not support lazy fields right now, so only
    # the pure-Python implementation populates them.
    if api_implementation.Type() == 'python':
      populate_all_types = test_util.SetAllFields
    else:
      populate_all_types = test_util.SetAllNonLazyFields
    cases = (
        (unittest_pb2.TestAllTypes, populate_all_types),
        (unittest_pb2.TestAllExtensions, test_util.SetAllExtensions),
    )
    for message_class, populate in cases:
      proto = message_class()
      populate(proto)
      proto.Clear()
      self.assertEqual(proto.ByteSize(), 0)
      self.assertEqual(proto, message_class())
  def testDisconnectingInOneof(self):
    """Switching a oneof to another message detaches the previous one."""
    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
    m.foo_message.qux_int = 5
    sub_message = m.foo_message
    # Accessing another message's field does not clear the first one.
    self.assertEqual(m.foo_lazy_message.qux_int, 0)
    self.assertEqual(m.foo_message.qux_int, 5)
    # But mutating another message in the oneof detaches the first one.
    m.foo_lazy_message.qux_int = 6
    self.assertEqual(m.foo_message.qux_int, 0)
    # The reference we got above was detached and is still valid.
    self.assertEqual(sub_message.qux_int, 5)
    sub_message.qux_int = 7
    # Writing through the detached reference updates only that object. The
    # oneof member currently set on `m` keeps its value.
    self.assertEqual(sub_message.qux_int, 7)
    self.assertEqual(m.foo_lazy_message.qux_int, 6)
  def assertInitialized(self, proto):
    """Asserts that `proto` is initialized and that serializing it succeeds.

    Args:
      proto: The message under test.
    """
    self.assertTrue(proto.IsInitialized())
    # Neither serializer may raise; any exception fails the calling test.
    for serialize in (proto.SerializeToString,
                      proto.SerializePartialToString):
      serialize()
  def assertNotInitialized(self, proto, error_size=None):
    """Asserts that `proto` is missing required fields.

    Args:
      proto: The message under test.
      error_size: Expected number of entries IsInitialized() appends to its
        errors list, or None to skip checking that count.
    """
    errors = []
    self.assertFalse(proto.IsInitialized())
    self.assertFalse(proto.IsInitialized(errors))
    # len() never returns None, so comparing unconditionally would make the
    # default argument unusable; only check the count when one is requested.
    if error_size is not None:
      self.assertEqual(error_size, len(errors))
    self.assertRaises(message.EncodeError, proto.SerializeToString)
    # "Partial" serialization doesn't care if message is uninitialized.
    proto.SerializePartialToString()
  def testIsInitialized(self):
    """Exercises IsInitialized() on required fields, submessages, extensions."""
    # Trivial cases: messages with only optional fields and extensions.
    self.assertInitialized(unittest_pb2.TestAllTypes())
    self.assertInitialized(unittest_pb2.TestAllExtensions())

    # Uninitialized required fields a, b and c.
    required = unittest_pb2.TestRequired()
    self.assertNotInitialized(required, 3)
    required.a = required.b = required.c = 2
    self.assertInitialized(required)

    # A present but incomplete optional submessage.
    foreign = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(foreign)
    foreign.optional_message.a = 1
    self.assertNotInitialized(foreign, 2)
    foreign.optional_message.b = 0
    foreign.optional_message.c = 0
    self.assertInitialized(foreign)

    # An incomplete element in a repeated submessage field.
    element = foreign.repeated_message.add()
    self.assertNotInitialized(foreign, 3)
    element.a = element.b = element.c = 0
    self.assertInitialized(foreign)

    # Incomplete entries in a repeated group extension.
    with_extension = unittest_pb2.TestAllExtensions()
    multi = unittest_pb2.TestRequired.multi
    first_entry = with_extension.Extensions[multi].add()
    second_entry = with_extension.Extensions[multi].add()
    self.assertNotInitialized(with_extension, 6)
    first_entry.a = 1
    first_entry.b = 1
    first_entry.c = 1
    self.assertNotInitialized(with_extension, 3)
    second_entry.a = 2
    second_entry.b = 2
    second_entry.c = 2
    self.assertInitialized(with_extension)

    # An incomplete non-repeated message extension.
    with_extension = unittest_pb2.TestAllExtensions()
    single = unittest_pb2.TestRequired.single
    with_extension.Extensions[single].a = 1
    self.assertNotInitialized(with_extension, 2)
    with_extension.Extensions[single].b = 2
    with_extension.Extensions[single].c = 3
    self.assertInitialized(with_extension)

    # A caller-supplied list collects the names of the missing fields.
    missing = []
    bare = unittest_pb2.TestRequired()
    self.assertFalse(bare.IsInitialized(missing))
    self.assertEqual(missing, ['a', 'b', 'c'])
    # More than one argument is rejected.
    with self.assertRaises(TypeError):
      bare.IsInitialized(1, 2, 3)
  @unittest.skipIf(
      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
      'Errors are only available from the most recent C++ implementation.')
  def testFileDescriptorErrors(self):
    """A file redefining an existing symbol is rejected with a TypeError."""
    file_name = 'test_file_descriptor_errors.proto'
    package_name = 'test_file_descriptor_errors.proto'
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
    file_descriptor_proto.name = file_name
    file_descriptor_proto.package = package_name
    m1 = file_descriptor_proto.message_type.add()
    m1.name = 'msg1'
    # Compiles the proto into the C++ descriptor pool
    descriptor.FileDescriptor(
        file_name,
        package_name,
        serialized_pb=file_descriptor_proto.SerializeToString())
    # Add a FileDescriptorProto that has duplicate symbols: it keeps msg1,
    # which is already in the pool, and adds msg2.
    another_file_name = 'another_test_file_descriptor_errors.proto'
    file_descriptor_proto.name = another_file_name
    m2 = file_descriptor_proto.message_type.add()
    m2.name = 'msg2'
    with self.assertRaises(TypeError) as cm:
      descriptor.FileDescriptor(
          another_file_name,
          package_name,
          serialized_pb=file_descriptor_proto.SerializeToString())
    # assertRaises() already fails the test when nothing is raised, so
    # cm.exception is always populated at this point.
    error_text = str(cm.exception)
    self.assertIn('test_file_descriptor_errors.proto', error_text)
    # Error message will say something about this definition being a
    # duplicate, though we don't check the message exactly to avoid a
    # dependency on the C++ logging code.
    self.assertIn('test_file_descriptor_errors.msg1', error_text)
  def testStringUTF8Serialization(self):
    """Round-trips a non-ASCII string through a MessageSet extension."""
    # 'Test' in another language, using UTF-8 charset.
    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')
    proto = message_set_extensions_pb2.TestMessageSet()
    ext_type = message_set_extensions_pb2.TestMessageSetExtension2
    ext = ext_type.message_set_extension
    proto.Extensions[ext].str = test_utf8

    # The .proto file selects the MessageSet wire format for serialization.
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(len(serialized), raw.MergeFromString(serialized))
    decoded = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(1, len(raw.item))
    item = raw.item[0]
    # The type_id must equal the extension's tag number in the .proto file.
    self.assertEqual(item.type_id, 98418634)
    # The encoded UTF-8 text must appear verbatim at the end of the payload.
    payload = item.message
    self.assertTrue(payload.endswith(test_utf8_bytes))
    self.assertEqual(len(payload), decoded.MergeFromString(payload))
    self.assertEqual(type(decoded.str), six.text_type)
    self.assertEqual(decoded.str, test_utf8)

    # Swap the text for bytes that are not valid UTF-8. The pure Python API
    # raises UnicodeDecodeError from MergeFromString() in that case; the C++
    # implementation cannot check this while merging, so the field may end up
    # holding bytes instead.
    badbytes = payload.replace(test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
    try:
      decoded.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    else:
      unicode_decode_failed = False
    string_field = decoded.str
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
  def testSetInParent(self):
    """SetInParent() marks an untouched group as present in its parent."""
    group_name = 'optionalgroup'
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField(group_name))
    proto.optionalgroup.SetInParent()
    self.assertTrue(proto.HasField(group_name))
  def testPackageInitializationImport(self):
    """Test that we can import nested messages from their __init__.py.

    Such a setup is not trivial: while __init__.py is being processed, its
    submodules cannot be referred to by name in code, so expressions like
    google.protobuf.internal.import_test_package.inner_pb2 don't work. They do
    work in import statements, so we have to assign an alias at import time
    and then use that alias in the generated code.
    """
    # The import lives here because it is the step that used to fail, and a
    # failure should be reported in the context of this test.
    # pylint: disable=g-import-not-at-top
    from google.protobuf.internal import import_test_package
    # pylint: enable=g-import-not-at-top
    outer = import_test_package.myproto.Outer()
    # Only the default value of the nested field is checked.
    self.assertEqual(57, outer.inner.value)
  1805. # Since we had so many tests for protocol buffer equality, we broke these out
  1806. # into separate TestCase classes.
  @testing_refleaks.TestCase
  class TestAllTypesEqualityTest(unittest.TestCase):
    """Equality and hashing checks on freshly constructed TestAllTypes."""

    def setUp(self):
      self.first_proto, self.second_proto = (
          unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

    def testNotHashable(self):
      # hash() on a message must raise TypeError.
      with self.assertRaises(TypeError):
        hash(self.first_proto)

    def testSelfEquality(self):
      self.assertEqual(self.first_proto, self.first_proto)

    def testEmptyProtosEqual(self):
      self.assertEqual(self.first_proto, self.second_proto)
  @testing_refleaks.TestCase
  class FullProtosEqualityTest(unittest.TestCase):
    """Equality tests using completely-full protos as a starting point."""

    def setUp(self):
      self.first_proto = unittest_pb2.TestAllTypes()
      self.second_proto = unittest_pb2.TestAllTypes()
      for proto in (self.first_proto, self.second_proto):
        test_util.SetAllFields(proto)

    def testNotHashable(self):
      with self.assertRaises(TypeError):
        hash(self.first_proto)

    def testNoneNotEqual(self):
      # Comparing with None must be unequal from either side.
      self.assertNotEqual(self.first_proto, None)
      self.assertNotEqual(None, self.second_proto)

    def testNotEqualToOtherMessage(self):
      other = unittest_pb2.TestRequired()
      self.assertNotEqual(self.first_proto, other)
      self.assertNotEqual(other, self.second_proto)

    def testAllFieldsFilledEquality(self):
      self.assertEqual(self.first_proto, self.second_proto)

    def testNonRepeatedScalar(self):
      first, second = self.first_proto, self.second_proto
      # Changing a singular scalar value breaks equality...
      first.optional_int32 += 1
      self.assertNotEqual(first, second)
      # ...and so does clearing the field.
      first.ClearField('optional_int32')
      self.assertNotEqual(first, second)

    def testNonRepeatedComposite(self):
      first, second = self.first_proto, self.second_proto
      nested = first.optional_nested_message
      # Changing a value inside a singular submessage breaks equality.
      nested.bb += 1
      self.assertNotEqual(first, second)
      nested.bb -= 1
      self.assertEqual(first, second)
      # Clearing a field inside the submessage does too.
      nested.ClearField('bb')
      self.assertNotEqual(first, second)
      nested.bb = second.optional_nested_message.bb
      self.assertEqual(first, second)
      # And so does removing the submessage entirely.
      first.ClearField('optional_nested_message')
      self.assertNotEqual(first, second)

    def testRepeatedScalar(self):
      first, second = self.first_proto, self.second_proto
      first.repeated_int32.append(5)
      self.assertNotEqual(first, second)
      first.ClearField('repeated_int32')
      self.assertNotEqual(first, second)

    def testRepeatedComposite(self):
      first, second = self.first_proto, self.second_proto
      # Changing a value inside an existing element breaks equality.
      first.repeated_nested_message[0].bb += 1
      self.assertNotEqual(first, second)
      first.repeated_nested_message[0].bb -= 1
      self.assertEqual(first, second)
      # Appending an element to only one side breaks it as well.
      first.repeated_nested_message.add()
      self.assertNotEqual(first, second)
      second.repeated_nested_message.add()
      self.assertEqual(first, second)

    def testNonRepeatedScalarHasBits(self):
      # Presence is compared as well as the value: an unset field differs
      # from one explicitly set to 0.
      self.first_proto.ClearField('optional_int32')
      self.second_proto.optional_int32 = 0
      self.assertNotEqual(self.first_proto, self.second_proto)

    def testNonRepeatedCompositeHasBits(self):
      first, second = self.first_proto, self.second_proto
      # Presence is compared for submessages too: a cleared submessage differs
      # from one that is present but has its only checked field cleared.
      first.ClearField('optional_nested_message')
      second.optional_nested_message.ClearField('bb')
      self.assertNotEqual(first, second)
      # Setting and then clearing a nested field makes the two match again.
      first.optional_nested_message.bb = 0
      first.optional_nested_message.ClearField('bb')
      self.assertEqual(first, second)
  @testing_refleaks.TestCase
  class ExtensionEqualityTest(unittest.TestCase):
    """Equality checks involving extension fields."""

    def testExtensionEquality(self):
      ext = unittest_pb2.optional_int32_extension
      first_proto = unittest_pb2.TestAllExtensions()
      second_proto = unittest_pb2.TestAllExtensions()
      self.assertEqual(first_proto, second_proto)
      test_util.SetAllExtensions(first_proto)
      self.assertNotEqual(first_proto, second_proto)
      test_util.SetAllExtensions(second_proto)
      self.assertEqual(first_proto, second_proto)

      # Extension values take part in the comparison.
      first_proto.Extensions[ext] += 1
      self.assertNotEqual(first_proto, second_proto)
      first_proto.Extensions[ext] -= 1
      self.assertEqual(first_proto, second_proto)

      # So does extension presence ("has" bits).
      first_proto.ClearExtension(ext)
      second_proto.Extensions[ext] = 0
      self.assertNotEqual(first_proto, second_proto)
      first_proto.Extensions[ext] = 0
      self.assertEqual(first_proto, second_proto)

      # Reading an unset extension on one side must not make otherwise empty
      # messages compare unequal.
      first_proto = unittest_pb2.TestAllExtensions()
      second_proto = unittest_pb2.TestAllExtensions()
      self.assertEqual(0, first_proto.Extensions[ext])
      self.assertEqual(first_proto, second_proto)
  @testing_refleaks.TestCase
  class MutualRecursionEqualityTest(unittest.TestCase):
    """Equality across message types that reference each other."""

    def testEqualityWithMutualRecursion(self):
      lhs = unittest_pb2.TestMutualRecursionA()
      rhs = unittest_pb2.TestMutualRecursionA()
      self.assertEqual(lhs, rhs)
      # A difference deep in the nesting must be detected...
      lhs.bb.a.bb.optional_int32 = 23
      self.assertNotEqual(lhs, rhs)
      # ...and must disappear once the other side matches.
      rhs.bb.a.bb.optional_int32 = 23
      self.assertEqual(lhs, rhs)
  @testing_refleaks.TestCase
  class ByteSizeTest(unittest.TestCase):
    """Checks ByteSize() against hand-computed wire-format sizes."""

    def setUp(self):
      self.proto = unittest_pb2.TestAllTypes()
      self.extended_proto = more_extensions_pb2.ExtendedMessage()
      self.packed_proto = unittest_pb2.TestPackedTypes()
      self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

    def Size(self):
      """Returns the serialized size of self.proto."""
      return self.proto.ByteSize()

    def testEmptyMessage(self):
      self.assertEqual(0, self.proto.ByteSize())

    def testSizedOnKwargs(self):
      # Use a separate message to ensure testing right after creation.
      proto = unittest_pb2.TestAllTypes()
      self.assertEqual(0, proto.ByteSize())
      proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
      # One byte for the tag, one to encode varint 1.
      self.assertEqual(2, proto_kwargs.ByteSize())

    def testVarints(self):
      def Test(i, expected_varint_size):
        self.proto.Clear()
        self.proto.optional_int64 = i
        # Add one to the varint size for the single-byte tag of
        # optional_int64.
        self.assertEqual(expected_varint_size + 1, self.Size())
      Test(0, 1)
      Test(1, 1)
      # (1 << bits) - 1 occupies exactly bits / 7 varint bytes. Going up to 63
      # bits also covers the largest positive int64, which needs 9 bytes.
      for num_bytes, bits in enumerate(range(7, 64, 7), start=1):
        Test((1 << bits) - 1, num_bytes)
      # Negative int64 values always take 10 varint bytes.
      Test(-1, 10)
      Test(-2, 10)
      Test(-(1 << 63), 10)

    def testStrings(self):
      self.proto.optional_string = ''
      # Need one byte for tag info (tag #14), and one byte for length.
      self.assertEqual(2, self.Size())
      self.proto.optional_string = 'abc'
      # Need one byte for tag info (tag #14), and one byte for length.
      self.assertEqual(2 + len(self.proto.optional_string), self.Size())
      self.proto.optional_string = 'x' * 128
      # Need one byte for tag info (tag #14), and TWO bytes for length.
      self.assertEqual(3 + len(self.proto.optional_string), self.Size())

    def testOtherNumerics(self):
      self.proto.optional_fixed32 = 1234
      # One byte for tag and 4 bytes for fixed32.
      self.assertEqual(5, self.Size())
      self.proto = unittest_pb2.TestAllTypes()
      self.proto.optional_fixed64 = 1234
      # One byte for tag and 8 bytes for fixed64.
      self.assertEqual(9, self.Size())
      self.proto = unittest_pb2.TestAllTypes()
      self.proto.optional_float = 1.234
      # One byte for tag and 4 bytes for float.
      self.assertEqual(5, self.Size())
      self.proto = unittest_pb2.TestAllTypes()
      self.proto.optional_double = 1.234
      # One byte for tag and 8 bytes for double.
      self.assertEqual(9, self.Size())
      self.proto = unittest_pb2.TestAllTypes()
      self.proto.optional_sint32 = 64
      # One byte for tag and 2 bytes for zig-zag-encoded 64.
      self.assertEqual(3, self.Size())

    def testComposites(self):
      # 3 bytes.
      self.proto.optional_nested_message.bb = (1 << 14)
      # Plus one byte for bb tag.
      # Plus 1 byte for optional_nested_message serialized size.
      # Plus two bytes for optional_nested_message tag.
      self.assertEqual(3 + 1 + 1 + 2, self.Size())

    def testGroups(self):
      # 4 bytes.
      self.proto.optionalgroup.a = (1 << 21)
      # Plus two bytes for |a| tag.
      # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
      self.assertEqual(4 + 2 + 2*2, self.Size())

    def testRepeatedScalars(self):
      self.proto.repeated_int32.append(10)  # 1 byte.
      self.proto.repeated_int32.append(128)  # 2 bytes.
      # Also need 2 bytes for each entry for tag.
      self.assertEqual(1 + 2 + 2*2, self.Size())

    def testRepeatedScalarsExtend(self):
      self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
      # Also need 2 bytes for each entry for tag.
      self.assertEqual(1 + 2 + 2*2, self.Size())

    def testRepeatedScalarsRemove(self):
      self.proto.repeated_int32.append(10)  # 1 byte.
      self.proto.repeated_int32.append(128)  # 2 bytes.
      # Also need 2 bytes for each entry for tag.
      self.assertEqual(1 + 2 + 2*2, self.Size())
      self.proto.repeated_int32.remove(128)
      self.assertEqual(1 + 2, self.Size())

    def testRepeatedComposites(self):
      # Empty message. 2 bytes tag plus 1 byte length.
      self.proto.repeated_nested_message.add()
      # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
      second_message = self.proto.repeated_nested_message.add()
      second_message.bb = 7
      self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    def testRepeatedCompositesDelete(self):
      # Empty message. 2 bytes tag plus 1 byte length.
      self.proto.repeated_nested_message.add()
      # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
      second_message = self.proto.repeated_nested_message.add()
      second_message.bb = 9
      self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
      # Deep copy of the container; the deletions on self.proto below must
      # leave it untouched.
      repeated_nested_message = copy.deepcopy(
          self.proto.repeated_nested_message)
      # Only the bb=9 element remains: 2 bytes tag plus 1 byte length plus
      # 1 byte bb tag plus 1 byte int.
      del self.proto.repeated_nested_message[0]
      self.assertEqual(2 + 1 + 1 + 1, self.Size())
      # Now add a new message; each of the two elements takes 2 bytes tag
      # plus 1 byte length plus 1 byte bb tag plus 1 byte int.
      third_message = self.proto.repeated_nested_message.add()
      third_message.bb = 12
      self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
      # Deleting one element leaves a single element of the same size.
      del self.proto.repeated_nested_message[1]
      self.assertEqual(2 + 1 + 1 + 1, self.Size())
      del self.proto.repeated_nested_message[0]
      self.assertEqual(0, self.Size())
      # The deep copy still holds both original elements.
      self.assertEqual(2, len(repeated_nested_message))
      del repeated_nested_message[0:1]
      # TODO(jieluo): Fix cpp extension bug when delete repeated message.
      if api_implementation.Type() == 'python':
        self.assertEqual(1, len(repeated_nested_message))
      del repeated_nested_message[-1]
      # TODO(jieluo): Fix cpp extension bug when delete repeated message.
      if api_implementation.Type() == 'python':
        self.assertEqual(0, len(repeated_nested_message))

    def testRepeatedGroups(self):
      # 2-byte START_GROUP plus 2-byte END_GROUP.
      self.proto.repeatedgroup.add()
      # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
      # plus 2-byte END_GROUP.
      second_group = self.proto.repeatedgroup.add()
      second_group.a = 7
      self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

    def testExtensions(self):
      proto = unittest_pb2.TestAllExtensions()
      self.assertEqual(0, proto.ByteSize())
      extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
      proto.Extensions[extension] = 23
      # 1 byte for tag, 1 byte for value.
      self.assertEqual(2, proto.ByteSize())
      # A regular (non-extension) field descriptor is not a valid key.
      field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
          'optional_int32']
      with self.assertRaises(KeyError):
        proto.Extensions[field] = 23

    def testCacheInvalidationForNonrepeatedScalar(self):
      # Test non-extension.
      self.proto.optional_int32 = 1
      self.assertEqual(2, self.proto.ByteSize())
      self.proto.optional_int32 = 128
      self.assertEqual(3, self.proto.ByteSize())
      self.proto.ClearField('optional_int32')
      self.assertEqual(0, self.proto.ByteSize())
      # Test within extension.
      extension = more_extensions_pb2.optional_int_extension
      self.extended_proto.Extensions[extension] = 1
      self.assertEqual(2, self.extended_proto.ByteSize())
      self.extended_proto.Extensions[extension] = 128
      self.assertEqual(3, self.extended_proto.ByteSize())
      self.extended_proto.ClearExtension(extension)
      self.assertEqual(0, self.extended_proto.ByteSize())

    def testCacheInvalidationForRepeatedScalar(self):
      # Test non-extension.
      self.proto.repeated_int32.append(1)
      self.assertEqual(3, self.proto.ByteSize())
      self.proto.repeated_int32.append(1)
      self.assertEqual(6, self.proto.ByteSize())
      self.proto.repeated_int32[1] = 128
      self.assertEqual(7, self.proto.ByteSize())
      self.proto.ClearField('repeated_int32')
      self.assertEqual(0, self.proto.ByteSize())
      # Test within extension.
      extension = more_extensions_pb2.repeated_int_extension
      repeated = self.extended_proto.Extensions[extension]
      repeated.append(1)
      self.assertEqual(2, self.extended_proto.ByteSize())
      repeated.append(1)
      self.assertEqual(4, self.extended_proto.ByteSize())
      repeated[1] = 128
      self.assertEqual(5, self.extended_proto.ByteSize())
      self.extended_proto.ClearExtension(extension)
      self.assertEqual(0, self.extended_proto.ByteSize())

    def testCacheInvalidationForNonrepeatedMessage(self):
      # Test non-extension.
      self.proto.optional_foreign_message.c = 1
      self.assertEqual(5, self.proto.ByteSize())
      self.proto.optional_foreign_message.c = 128
      self.assertEqual(6, self.proto.ByteSize())
      self.proto.optional_foreign_message.ClearField('c')
      self.assertEqual(3, self.proto.ByteSize())
      self.proto.ClearField('optional_foreign_message')
      self.assertEqual(0, self.proto.ByteSize())

      if api_implementation.Type() == 'python':
        # This is only possible in pure-Python implementation of the API.
        child = self.proto.optional_foreign_message
        self.proto.ClearField('optional_foreign_message')
        child.c = 128
        self.assertEqual(0, self.proto.ByteSize())

      # Test within extension.
      extension = more_extensions_pb2.optional_message_extension
      child = self.extended_proto.Extensions[extension]
      self.assertEqual(0, self.extended_proto.ByteSize())
      child.foreign_message_int = 1
      self.assertEqual(4, self.extended_proto.ByteSize())
      child.foreign_message_int = 128
      self.assertEqual(5, self.extended_proto.ByteSize())
      self.extended_proto.ClearExtension(extension)
      self.assertEqual(0, self.extended_proto.ByteSize())

    def testCacheInvalidationForRepeatedMessage(self):
      # Test non-extension.
      child0 = self.proto.repeated_foreign_message.add()
      self.assertEqual(3, self.proto.ByteSize())
      self.proto.repeated_foreign_message.add()
      self.assertEqual(6, self.proto.ByteSize())
      child0.c = 1
      self.assertEqual(8, self.proto.ByteSize())
      self.proto.ClearField('repeated_foreign_message')
      self.assertEqual(0, self.proto.ByteSize())
      # Test within extension.
      extension = more_extensions_pb2.repeated_message_extension
      child_list = self.extended_proto.Extensions[extension]
      child0 = child_list.add()
      self.assertEqual(2, self.extended_proto.ByteSize())
      child_list.add()
      self.assertEqual(4, self.extended_proto.ByteSize())
      child0.foreign_message_int = 1
      self.assertEqual(6, self.extended_proto.ByteSize())
      child0.ClearField('foreign_message_int')
      self.assertEqual(4, self.extended_proto.ByteSize())
      self.extended_proto.ClearExtension(extension)
      self.assertEqual(0, self.extended_proto.ByteSize())

    def testPackedRepeatedScalars(self):
      self.assertEqual(0, self.packed_proto.ByteSize())

      self.packed_proto.packed_int32.append(10)  # 1 byte.
      self.packed_proto.packed_int32.append(128)  # 2 bytes.
      # The tag is 2 bytes (the field number is 90), and the varint
      # storing the length is 1 byte.
      int_size = 1 + 2 + 3
      self.assertEqual(int_size, self.packed_proto.ByteSize())

      self.packed_proto.packed_double.append(4.2)  # 8 bytes
      self.packed_proto.packed_double.append(3.25)  # 8 bytes
      # 2 more tag bytes, 1 more length byte.
      double_size = 8 + 8 + 3
      self.assertEqual(int_size + double_size, self.packed_proto.ByteSize())

      self.packed_proto.ClearField('packed_int32')
      self.assertEqual(double_size, self.packed_proto.ByteSize())

    def testPackedExtensions(self):
      self.assertEqual(0, self.packed_extended_proto.ByteSize())
      extension = self.packed_extended_proto.Extensions[
          unittest_pb2.packed_fixed32_extension]
      extension.extend([1, 2, 3, 4])  # 16 bytes
      # Tag is 3 bytes.
      self.assertEqual(19, self.packed_extended_proto.ByteSize())
  2187. # Issues to be sure to cover include:
  2188. # * Handling of unrecognized tags ("uninterpreted_bytes").
  2189. # * Handling of MessageSets.
  2190. # * Consistent ordering of tags in the wire format,
  2191. # including ordering between extensions and non-extension
  2192. # fields.
  2193. # * Consistent serialization of negative numbers, especially
  2194. # negative int32s.
  2195. # * Handling of empty submessages (with and without "has"
  2196. # bits set).
  2197. @testing_refleaks.TestCase
  2198. class SerializationTest(unittest.TestCase):
  def testSerializeEmtpyMessage(self):
    """An empty message serializes to ByteSize() bytes and round-trips."""
    source = unittest_pb2.TestAllTypes()
    target = unittest_pb2.TestAllTypes()
    wire = source.SerializeToString()
    self.assertEqual(source.ByteSize(), len(wire))
    consumed = target.MergeFromString(wire)
    self.assertEqual(len(wire), consumed)
    self.assertEqual(source, target)
  def testSerializeAllFields(self):
    """A fully populated TestAllTypes survives a serialize/parse round trip."""
    original = unittest_pb2.TestAllTypes()
    parsed = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(original)
    data = original.SerializeToString()
    # The reported size must match what was actually written.
    self.assertEqual(original.ByteSize(), len(data))
    # MergeFromString() must consume every serialized byte.
    self.assertEqual(len(data), parsed.MergeFromString(data))
    self.assertEqual(original, parsed)
  def testSerializeAllExtensions(self):
    """A message with every extension set survives a serialization round trip.

    Like its sibling round-trip tests, this also checks that ByteSize()
    agrees with the length of the serialized output.
    """
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)
  def testSerializeWithOptionalGroup(self):
    """A message with an optional group set survives a round trip.

    Like its sibling round-trip tests, this also checks that ByteSize()
    agrees with the length of the serialized output.
    """
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)
  def testSerializeNegativeValues(self):
    """Negative values of each signed integer encoding round-trip intact."""
    first_proto = unittest_pb2.TestAllTypes()
    negative_values = (
        ('optional_int32', -1),
        ('optional_int64', -(2 << 40)),
        ('optional_sint32', -3),
        ('optional_sint64', -(4 << 40)),
        ('optional_sfixed32', -5),
        ('optional_sfixed64', -(6 << 40)),
    )
    for field_name, value in negative_values:
      setattr(first_proto, field_name, value)
    wire = first_proto.SerializeToString()
    second_proto = unittest_pb2.TestAllTypes.FromString(wire)
    self.assertEqual(first_proto, second_proto)
  def testParseTruncated(self):
    """Every prefix of a valid message parses cleanly or fails consistently.

    For each truncation point, parsing into TestAllTypes and into
    TestEmptyMessage (so everything becomes unknown fields) must agree: both
    consume exactly `truncation_point` bytes, or both raise DecodeError.
    """
    # This test is only applicable for the Python implementation of the API;
    # report it as skipped elsewhere instead of silently passing.
    if api_implementation.Type() != 'python':
      self.skipTest('Only applicable to the python implementation.')
    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = memoryview(first_proto.SerializeToString())

    for truncation_point in range(len(serialized) + 1):
      second_proto = unittest_pb2.TestAllTypes()
      unknown_fields = unittest_pb2.TestEmptyMessage()
      # Only this call is expected to raise DecodeError.
      try:
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)
        continue
      # If we didn't raise an error then we read exactly the amount expected.
      self.assertEqual(truncation_point, pos)
      # Parsing to unknown fields should not throw if parsing to known fields
      # did not.
      try:
        pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
      except message.DecodeError:
        self.fail('Parsing unknown fields failed when parsing known fields '
                  'did not.')
      self.assertEqual(truncation_point, pos2)
  def testCanonicalSerializationOrder(self):
    """Fields and extensions are written in ascending field-number order."""
    proto = more_messages_pb2.OutOfOrderFields()
    # Each value equals its field's tag number. The fields are set in reverse
    # tag order (and declared that way in the .proto file), yet they must
    # still be serialized in ascending tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))

    decoder = _MiniDecoder(serialized)
    value_readers = (decoder.ReadInt32, decoder.ReadInt64, decoder.ReadUInt32,
                     decoder.ReadUInt64, decoder.ReadSInt32)
    for field_number, read_value in enumerate(value_readers, start=1):
      self.assertEqual((field_number, wire_format.WIRETYPE_VARINT),
                       decoder.ReadFieldNumberAndWireType())
      self.assertEqual(field_number, read_value())
  def testCanonicalSerializationOrderSameAsCpp(self):
    """Mirror of the C++ field-ordering test, run against TestFieldOrderings."""
    fully_populated = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(fully_populated)
    test_util.ExpectAllFieldsAndExtensionsInOrder(
        fully_populated.SerializeToString())
  def testMergeFromStringWhenFieldsAlreadySet(self):
    """MergeFromString merges into fields that already hold values."""
    source = unittest_pb2.TestAllTypes()
    source.repeated_string.append('foobar')
    source.optional_int32 = 23
    source.optional_nested_message.bb = 42
    serialized = source.SerializeToString()

    target = unittest_pb2.TestAllTypes()
    target.repeated_string.append('baz')
    target.optional_int32 = 100
    target.optional_nested_message.bb = 999

    self.assertEqual(len(serialized), target.MergeFromString(serialized))
    # Repeated fields are appended to, not replaced.
    self.assertEqual(['baz', 'foobar'], list(target.repeated_string))
    # Non-repeated scalars are overwritten.
    self.assertEqual(23, target.optional_int32)
    # Sub-messages are merged recursively.
    self.assertEqual(42, target.optional_nested_message.bb)
  def testMessageSetWireFormat(self):
    """TestMessageSet round-trips through the MessageSet wire format."""
    proto = message_set_extensions_pb2.TestMessageSet()
    ext1_class = message_set_extensions_pb2.TestMessageSetExtension1
    ext2_class = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = ext1_class.message_set_extension
    extension2 = ext2_class.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'
    # TestMessageSet is declared with the MessageSet wire format in its .proto
    # file, so this serialization uses that encoding.
    serialized = proto.SerializeToString()

    # RawMessageSet does not itself use the MessageSet wire format; it exposes
    # the encoded items as ordinary repeated fields for inspection.
    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(len(serialized), raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    raw_expectations = (
        (message_set_extensions_pb2.TestMessageSetExtension1, 'i', 123),
        (message_set_extensions_pb2.TestMessageSetExtension2, 'str', 'foo'),
        (message_set_extensions_pb2.TestMessageSetExtension3, 'text', 'bar'),
    )
    for index, (payload_class, attribute, expected) in enumerate(
        raw_expectations):
      payload = raw.item[index].message
      decoded = payload_class()
      self.assertEqual(len(payload), decoded.MergeFromString(payload))
      self.assertEqual(expected, getattr(decoded, attribute))

    # Parsing back into TestMessageSet decodes the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(len(serialized), proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Both messages report a byte size equal to the encoded length.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))
  def testMessageSetWireFormatUnknownExtension(self):
    """Unknown MessageSet items do not prevent parsing a known extension.

    A RawMessageSet is built by hand with one item for a known extension
    (type_id 98418603, read back below via TestMessageSetExtension1) and two
    items whose type_ids TestMessageSet is not expected to know. Parsing it as
    a TestMessageSet must consume the whole buffer and still decode the known
    item.
    """
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item for the known extension.
    item = raw.item.add()
    item.type_id = 98418603
    known_payload = message_set_extensions_pb2.TestMessageSetExtension1()
    known_payload.i = 12345
    item.message = known_payload.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    unknown_payload1 = message_set_extensions_pb2.TestMessageSetExtension1()
    unknown_payload1.i = 12346
    item.message = unknown_payload1.SerializeToString()

    # Add another unknown extension, this time with a different payload type.
    item = raw.item.add()
    item.type_id = 98418605
    unknown_payload2 = message_set_extensions_pb2.TestMessageSetExtension2()
    unknown_payload2.str = 'foo'
    item.message = unknown_payload2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(len(serialized), proto.MergeFromString(serialized))

    # Check that the known extension parsed correctly.
    extension1 = (
        message_set_extensions_pb2.TestMessageSetExtension1
        .message_set_extension)
    self.assertEqual(12345, proto.Extensions[extension1].i)
  def testUnknownFields(self):
    """A message that declares no fields still parses any payload fully."""
    everything_set = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(everything_set)
    large_int64 = unittest_pb2.TestAllTypes()
    large_int64.optional_int64 = 0x0fffffffffffffff

    for source in (everything_set, large_int64):
      wire_bytes = source.SerializeToString()
      # TestEmptyMessage has no fields, so every field in the payload is
      # unknown to it; parsing must still succeed and consume all bytes.
      empty = unittest_pb2.TestEmptyMessage()
      self.assertEqual(len(wire_bytes), empty.MergeFromString(wire_bytes))
  2424. def _CheckRaises(self, exc_class, callable_obj, exception):
  2425. """This method checks if the excpetion type and message are as expected."""
  2426. try:
  2427. callable_obj()
  2428. except exc_class as ex:
  2429. # Check if the exception message is the right one.
  2430. self.assertEqual(exception, str(ex))
  2431. return
  2432. else:
  2433. raise self.failureException('%s not raised' % str(exc_class))
  def testSerializeUninitialized(self):
    """SerializeToString rejects missing required fields; partial does not."""
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # The partial serializer never raises.
    partial = proto.SerializePartialToString()

    # proto2 ParseFromString does not check that required fields are set.
    parsed = unittest_pb2.TestRequired()
    self.assertFalse(parsed.HasField('a'))
    parsed.ParseFromString(partial)
    self.assertFalse(parsed.HasField('a'))

    # Fill in required fields one at a time; the error lists what remains.
    for field_name, value, still_missing in (('a', 1, 'b,c'), ('b', 2, 'c')):
      setattr(proto, field_name, value)
      self._CheckRaises(
          message.EncodeError,
          proto.SerializeToString,
          'Message protobuf_unittest.TestRequired is missing required fields: '
          + still_missing)
      partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    partial = proto.SerializePartialToString()

    parsed = unittest_pb2.TestRequired()
    for data in (serialized, partial):
      self.assertEqual(len(data), parsed.MergeFromString(data))
      self.assertEqual(1, parsed.a)
      self.assertEqual(2, parsed.b)
      self.assertEqual(3, parsed.c)
  def testSerializeUninitializedSubMessage(self):
    """Missing required fields inside sub-messages are reported by path."""
    proto = unittest_pb2.TestRequiredForeign()
    # No sub-message exists yet, so serialization succeeds.
    proto.SerializeToString()

    # A partially filled singular sub-message names its missing fields.
    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')
    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    # Repeated sub-messages are reported with their index.
    first = proto.repeated_message.add()
    first.a = 1
    second = proto.repeated_message.add()
    second.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')
    first.b = 2
    first.c = 3
    second.a = 1
    second.c = 3
    proto.SerializeToString()
  def testSerializeAllPackedFields(self):
    """Every packed field survives a serialize/parse round trip."""
    original = unittest_pb2.TestPackedTypes()
    roundtripped = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(original)
    wire_bytes = original.SerializeToString()
    self.assertEqual(original.ByteSize(), len(wire_bytes))
    consumed = roundtripped.MergeFromString(wire_bytes)
    self.assertEqual(roundtripped.ByteSize(), consumed)
    self.assertEqual(original, roundtripped)
  def testSerializeAllPackedExtensions(self):
    """Every packed extension survives a serialize/parse round trip."""
    original = unittest_pb2.TestPackedExtensions()
    roundtripped = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(original)
    consumed = roundtripped.MergeFromString(original.SerializeToString())
    self.assertEqual(roundtripped.ByteSize(), consumed)
    self.assertEqual(original, roundtripped)
  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    """Merging packed data appends to existing fields and skips absent ones."""
    incoming = unittest_pb2.TestPackedTypes()
    incoming.packed_int32.extend([1, 2])
    incoming.packed_double.append(3.0)
    payload = incoming.SerializeToString()

    existing = unittest_pb2.TestPackedTypes()
    existing.packed_int32.append(3)
    existing.packed_double.extend([1.0, 2.0])
    existing.packed_sint32.append(4)

    self.assertEqual(len(payload), existing.MergeFromString(payload))
    # Incoming values land after the values already present...
    self.assertEqual([3, 1, 2], existing.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], existing.packed_double)
    # ...and a field missing from the payload keeps its contents.
    self.assertEqual([4], existing.packed_sint32)
  def testPackedFieldsWireFormat(self):
    """Packed repeated fields are encoded as length-delimited payloads."""
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # varints: 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)  # 4 bytes; emitted before the doubles
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))

    decoder = _MiniDecoder(serialized)
    read_tag = decoder.ReadFieldNumberAndWireType
    length_delimited = wire_format.WIRETYPE_LENGTH_DELIMITED

    # Field 90 (packed_int32): length prefix, then the varints.
    self.assertEqual((90, length_delimited), read_tag())
    self.assertEqual(1 + 1 + 1 + 2, decoder.ReadInt32())
    for expected in (1, 2, 150, 3):
      self.assertEqual(expected, decoder.ReadInt32())

    # Field 100 (packed_float): a single 4-byte float.
    self.assertEqual((100, length_delimited), read_tag())
    self.assertEqual(4, decoder.ReadInt32())
    self.assertEqual(2.0, decoder.ReadFloat())

    # Field 101 (packed_double): two 8-byte doubles.
    self.assertEqual((101, length_delimited), read_tag())
    self.assertEqual(8 + 8, decoder.ReadInt32())
    for expected in (1.0, 1000.0):
      self.assertEqual(expected, decoder.ReadDouble())

    self.assertTrue(decoder.EndOfStream())
  def testParsePackedFromUnpacked(self):
    """A packed message parses the unpacked encoding of the same fields."""
    source = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(source)
    parsed = unittest_pb2.TestPackedTypes()
    wire_bytes = source.SerializeToString()
    self.assertEqual(len(wire_bytes), parsed.MergeFromString(wire_bytes))

    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, parsed)
  def testParseUnpackedFromPacked(self):
    """An unpacked message parses the packed encoding of the same fields."""
    source = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(source)
    parsed = unittest_pb2.TestUnpackedTypes()
    wire_bytes = source.SerializeToString()
    self.assertEqual(len(wire_bytes), parsed.MergeFromString(wire_bytes))

    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, parsed)
  def testFieldNumbers(self):
    """Generated <FIELD>_FIELD_NUMBER class constants carry the field tags.

    Only class-level constants are inspected, so no message instance is
    created.
    """
    cls = unittest_pb2.TestAllTypes
    self.assertEqual(1, cls.NestedMessage.BB_FIELD_NUMBER)
    self.assertEqual(1, cls.OPTIONAL_INT32_FIELD_NUMBER)
    self.assertEqual(16, cls.OPTIONALGROUP_FIELD_NUMBER)
    self.assertEqual(18, cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER)
    self.assertEqual(21, cls.OPTIONAL_NESTED_ENUM_FIELD_NUMBER)
    self.assertEqual(31, cls.REPEATED_INT32_FIELD_NUMBER)
    self.assertEqual(46, cls.REPEATEDGROUP_FIELD_NUMBER)
    self.assertEqual(48, cls.REPEATED_NESTED_MESSAGE_FIELD_NUMBER)
    self.assertEqual(51, cls.REPEATED_NESTED_ENUM_FIELD_NUMBER)
  def testExtensionFieldNumbers(self):
    """Extensions expose their tag via `.number` and a *_FIELD_NUMBER constant."""
    required = unittest_pb2.TestRequired
    extension_table = (
        (required, 'single', 1000),
        (required, 'multi', 1001),
        (unittest_pb2, 'optional_int32_extension', 1),
        (unittest_pb2, 'optionalgroup_extension', 16),
        (unittest_pb2, 'optional_nested_message_extension', 18),
        (unittest_pb2, 'optional_nested_enum_extension', 21),
        (unittest_pb2, 'repeated_int32_extension', 31),
        (unittest_pb2, 'repeatedgroup_extension', 46),
        (unittest_pb2, 'repeated_nested_message_extension', 48),
        (unittest_pb2, 'repeated_nested_enum_extension', 51),
    )
    for owner, extension_name, number in extension_table:
      # The constant's name is the upper-cased extension name plus a suffix,
      # e.g. 'single' -> SINGLE_FIELD_NUMBER.
      constant_name = extension_name.upper() + '_FIELD_NUMBER'
      self.assertEqual(getattr(owner, extension_name).number, number)
      self.assertEqual(getattr(owner, constant_name), number)
  def testFieldProperties(self):
    """Class-level field properties expose the matching FieldDescriptor."""
    cls = unittest_pb2.TestAllTypes
    for field_name in ('optional_int32', 'optional_nested_message',
                       'repeated_int32'):
      field_property = getattr(cls, field_name)
      self.assertIs(field_property.DESCRIPTOR,
                    cls.DESCRIPTOR.fields_by_name[field_name])
      # e.g. 'optional_int32' -> OPTIONAL_INT32_FIELD_NUMBER
      self.assertEqual(getattr(cls, field_name.upper() + '_FIELD_NUMBER'),
                       field_property.DESCRIPTOR.number)
  def testFieldDataDescriptor(self):
    """Field properties support __get__/__set__ but refuse deletion."""
    msg = unittest_pb2.TestAllTypes()
    msg.optional_int32 = 42
    int32_field = unittest_pb2.TestAllTypes.optional_int32
    self.assertEqual(int32_field.__get__(msg), 42)
    int32_field.__set__(msg, 25)
    self.assertEqual(msg.optional_int32, 25)
    with self.assertRaises(AttributeError):
      del msg.optional_int32

    try:
      # The python implementation allows reading a field that belongs to a
      # different message type; the result is useless, but the call works.
      unittest_pb2.ForeignMessage.c.__get__(msg)
    except TypeError:
      # The cpp implementation cannot mix fields from other messages.
      # This exercises a specific check that avoids a crash.
      pass
  def testInitKwargs(self):
    """Constructor keyword arguments populate the corresponding fields."""
    nested = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
    foreign = unittest_pb2.ForeignMessage(c=1)
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes=b'bar',
        optional_nested_message=nested,
        optional_foreign_message=foreign,
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())

    for field_name in ('optional_int32', 'optional_string', 'optional_bool',
                       'optional_bytes', 'optional_nested_message',
                       'optional_foreign_message', 'optional_nested_enum',
                       'optional_foreign_enum'):
      self.assertTrue(proto.HasField(field_name))

    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual(b'bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)
  def testInitArgsUnknownFieldName(self):
    """An undeclared keyword argument to a message constructor is rejected."""
    self._CheckRaises(
        ValueError,
        lambda: unittest_pb2.TestEmptyMessage(unknown='unknown'),
        'Protocol message TestEmptyMessage has no "unknown" field.')
  def testInitRequiredKwargs(self):
    """Setting every required field via kwargs yields an initialized message."""
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    for name in ('a', 'b', 'c'):
      self.assertTrue(proto.HasField(name))
    # A field that was not passed to the constructor stays unset.
    self.assertFalse(proto.HasField('dummy2'))
    for name in ('a', 'b', 'c'):
      self.assertEqual(1, getattr(proto, name))
  def testInitRequiredForeignKwargs(self):
    """A sub-message passed as a kwarg populates the parent's message field."""
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))

    sub = proto.optional_message
    self.assertTrue(sub.IsInitialized())
    for name in ('a', 'b', 'c'):
      self.assertTrue(sub.HasField(name))
    self.assertFalse(sub.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), sub)
    for name in ('a', 'b', 'c'):
      self.assertEqual(1, getattr(sub, name))
  def testInitRepeatedKwargs(self):
    """A list kwarg fills a repeated field element by element, in order."""
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    for index, expected in enumerate([1, 2, 3]):
      self.assertEqual(expected, proto.repeated_int32[index])
  2721. @testing_refleaks.TestCase
  2722. class OptionsTest(unittest.TestCase):
  2723. def testMessageOptions(self):
  2724. proto = message_set_extensions_pb2.TestMessageSet()
  2725. self.assertEqual(True,
  2726. proto.DESCRIPTOR.GetOptions().message_set_wire_format)
  2727. proto = unittest_pb2.TestAllTypes()
  2728. self.assertEqual(False,
  2729. proto.DESCRIPTOR.GetOptions().message_set_wire_format)
  2730. def testPackedOptions(self):
  2731. proto = unittest_pb2.TestAllTypes()
  2732. proto.optional_int32 = 1
  2733. proto.optional_double = 3.0
  2734. for field_descriptor, _ in proto.ListFields():
  2735. self.assertEqual(False, field_descriptor.GetOptions().packed)
  2736. proto = unittest_pb2.TestPackedTypes()
  2737. proto.packed_int32.append(1)
  2738. proto.packed_double.append(3.0)
  2739. for field_descriptor, _ in proto.ListFields():
  2740. self.assertEqual(True, field_descriptor.GetOptions().packed)
  2741. self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
  2742. field_descriptor.label)
  2743. @testing_refleaks.TestCase
  2744. class ClassAPITest(unittest.TestCase):
  2745. @unittest.skipIf(
  2746. api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
  2747. 'C++ implementation requires a call to MakeDescriptor()')
  2748. @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  2749. def testMakeClassWithNestedDescriptor(self):
  2750. leaf_desc = descriptor.Descriptor(
  2751. 'leaf', 'package.parent.child.leaf', '',
  2752. containing_type=None, fields=[],
  2753. nested_types=[], enum_types=[],
  2754. extensions=[],
  2755. # pylint: disable=protected-access
  2756. create_key=descriptor._internal_create_key)
  2757. child_desc = descriptor.Descriptor(
  2758. 'child', 'package.parent.child', '',
  2759. containing_type=None, fields=[],
  2760. nested_types=[leaf_desc], enum_types=[],
  2761. extensions=[],
  2762. # pylint: disable=protected-access
  2763. create_key=descriptor._internal_create_key)
  2764. sibling_desc = descriptor.Descriptor(
  2765. 'sibling', 'package.parent.sibling',
  2766. '', containing_type=None, fields=[],
  2767. nested_types=[], enum_types=[],
  2768. extensions=[],
  2769. # pylint: disable=protected-access
  2770. create_key=descriptor._internal_create_key)
  2771. parent_desc = descriptor.Descriptor(
  2772. 'parent', 'package.parent', '',
  2773. containing_type=None, fields=[],
  2774. nested_types=[child_desc, sibling_desc],
  2775. enum_types=[], extensions=[],
  2776. # pylint: disable=protected-access
  2777. create_key=descriptor._internal_create_key)
  2778. reflection.MakeClass(parent_desc)
  2779. def _GetSerializedFileDescriptor(self, name):
  2780. """Get a serialized representation of a test FileDescriptorProto.
  2781. Args:
  2782. name: All calls to this must use a unique message name, to avoid
  2783. collisions in the cpp descriptor pool.
  2784. Returns:
  2785. A string containing the serialized form of a test FileDescriptorProto.
  2786. """
  2787. file_descriptor_str = (
  2788. 'message_type {'
  2789. ' name: "' + name + '"'
  2790. ' field {'
  2791. ' name: "flat"'
  2792. ' number: 1'
  2793. ' label: LABEL_REPEATED'
  2794. ' type: TYPE_UINT32'
  2795. ' }'
  2796. ' field {'
  2797. ' name: "bar"'
  2798. ' number: 2'
  2799. ' label: LABEL_OPTIONAL'
  2800. ' type: TYPE_MESSAGE'
  2801. ' type_name: "Bar"'
  2802. ' }'
  2803. ' nested_type {'
  2804. ' name: "Bar"'
  2805. ' field {'
  2806. ' name: "baz"'
  2807. ' number: 3'
  2808. ' label: LABEL_OPTIONAL'
  2809. ' type: TYPE_MESSAGE'
  2810. ' type_name: "Baz"'
  2811. ' }'
  2812. ' nested_type {'
  2813. ' name: "Baz"'
  2814. ' enum_type {'
  2815. ' name: "deep_enum"'
  2816. ' value {'
  2817. ' name: "VALUE_A"'
  2818. ' number: 0'
  2819. ' }'
  2820. ' }'
  2821. ' field {'
  2822. ' name: "deep"'
  2823. ' number: 4'
  2824. ' label: LABEL_OPTIONAL'
  2825. ' type: TYPE_UINT32'
  2826. ' }'
  2827. ' }'
  2828. ' }'
  2829. '}')
  2830. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2831. text_format.Merge(file_descriptor_str, file_descriptor)
  2832. return file_descriptor.SerializeToString()
  2833. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2834. # This test can only run once; the second time, it raises errors about
  2835. # conflicting message descriptors.
  2836. def testParsingFlatClassWithExplicitClassDeclaration(self):
  2837. """Test that the generated class can parse a flat message."""
  2838. # TODO(xiaofeng): This test fails with cpp implemetnation in the call
  2839. # of six.with_metaclass(). The other two callsites of with_metaclass
  2840. # in this file are both excluded from cpp test, so it might be expected
  2841. # to fail. Need someone more familiar with the python code to take a
  2842. # look at this.
  2843. if api_implementation.Type() != 'python':
  2844. return
  2845. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2846. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
  2847. msg_descriptor = descriptor.MakeDescriptor(
  2848. file_descriptor.message_type[0])
  2849. class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
  2850. DESCRIPTOR = msg_descriptor
  2851. msg = MessageClass()
  2852. msg_str = (
  2853. 'flat: 0 '
  2854. 'flat: 1 '
  2855. 'flat: 2 ')
  2856. text_format.Merge(msg_str, msg)
  2857. self.assertEqual(msg.flat, [0, 1, 2])
  2858. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2859. def testParsingFlatClass(self):
  2860. """Test that the generated class can parse a flat message."""
  2861. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2862. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
  2863. msg_descriptor = descriptor.MakeDescriptor(
  2864. file_descriptor.message_type[0])
  2865. msg_class = reflection.MakeClass(msg_descriptor)
  2866. msg = msg_class()
  2867. msg_str = (
  2868. 'flat: 0 '
  2869. 'flat: 1 '
  2870. 'flat: 2 ')
  2871. text_format.Merge(msg_str, msg)
  2872. self.assertEqual(msg.flat, [0, 1, 2])
  2873. @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  2874. def testParsingNestedClass(self):
  2875. """Test that the generated class can parse a nested message."""
  2876. file_descriptor = descriptor_pb2.FileDescriptorProto()
  2877. file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
  2878. msg_descriptor = descriptor.MakeDescriptor(
  2879. file_descriptor.message_type[0])
  2880. msg_class = reflection.MakeClass(msg_descriptor)
  2881. msg = msg_class()
  2882. msg_str = (
  2883. 'bar {'
  2884. ' baz {'
  2885. ' deep: 4'
  2886. ' }'
  2887. '}')
  2888. text_format.Merge(msg_str, msg)
  2889. self.assertEqual(msg.bar.baz.deep, 4)
  2890. if __name__ == '__main__':
  2891. unittest.main()