Эх сурвалжийг харах

Calling Keychecker before checking key in MessageMap

cyyber 7 жил өмнө
parent
commit
0e2089c775

+ 3 - 1
python/google/protobuf/internal/containers.py

@@ -549,10 +549,10 @@ class MessageMap(MutableMapping):
     self._values = {}
     self._values = {}
 
 
   def __getitem__(self, key):
   def __getitem__(self, key):
+    key = self._key_checker.CheckValue(key)
     try:
     try:
       return self._values[key]
       return self._values[key]
     except KeyError:
     except KeyError:
-      key = self._key_checker.CheckValue(key)
       new_element = self._message_descriptor._concrete_class()
       new_element = self._message_descriptor._concrete_class()
       new_element._SetListener(self._message_listener)
       new_element._SetListener(self._message_listener)
       self._values[key] = new_element
       self._values[key] = new_element
@@ -584,12 +584,14 @@ class MessageMap(MutableMapping):
       return default
       return default
 
 
   def __contains__(self, item):
   def __contains__(self, item):
+    item = self._key_checker.CheckValue(item)
     return item in self._values
     return item in self._values
 
 
   def __setitem__(self, key, value):
   def __setitem__(self, key, value):
     raise ValueError('May not set values directly, call my_map[key].foo = 5')
     raise ValueError('May not set values directly, call my_map[key].foo = 5')
 
 
   def __delitem__(self, key):
   def __delitem__(self, key):
+    key = self._key_checker.CheckValue(key)
     del self._values[key]
     del self._values[key]
     self._message_listener.Modified()
     self._message_listener.Modified()
 
 

+ 4 - 12
python/google/protobuf/internal/message_test.py

@@ -1480,12 +1480,8 @@ class Proto3Test(BaseTestCase):
 
 
     submsg = msg.map_int32_foreign_message[5]
     submsg = msg.map_int32_foreign_message[5]
     self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
     self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
-    # TODO(jieluo): Fix python and cpp extension diff.
-    if api_implementation.Type() == 'cpp':
-      with self.assertRaises(TypeError):
-        msg.map_int32_foreign_message.get('')
-    else:
-      self.assertEqual(None, msg.map_int32_foreign_message.get(''))
+    with self.assertRaises(TypeError):
+      msg.map_int32_foreign_message.get('')
 
 
   def testScalarMap(self):
   def testScalarMap(self):
     msg = map_unittest_pb2.TestMap()
     msg = map_unittest_pb2.TestMap()
@@ -1695,12 +1691,8 @@ class Proto3Test(BaseTestCase):
 
 
     del msg2.map_int32_foreign_message[222]
     del msg2.map_int32_foreign_message[222]
     self.assertFalse(222 in msg2.map_int32_foreign_message)
     self.assertFalse(222 in msg2.map_int32_foreign_message)
-    if api_implementation.Type() == 'cpp':
-      with self.assertRaises(TypeError):
-        del msg2.map_int32_foreign_message['']
-    else:
-      with self.assertRaises(KeyError):
-        del msg2.map_int32_foreign_message['']
+    with self.assertRaises(TypeError):
+      del msg2.map_int32_foreign_message['']
 
 
   def testMergeFromBadType(self):
   def testMergeFromBadType(self):
     msg = map_unittest_pb2.TestMap()
     msg = map_unittest_pb2.TestMap()