Feng Xiao 8 жил өмнө
parent
commit
5777259273

+ 14 - 0
src/google/protobuf/map_field.cc

@@ -67,6 +67,13 @@ size_t MapFieldBase::SpaceUsedExcludingSelfNoLock() const {
   }
 }
 
+bool MapFieldBase::IsMapValid() const {
+  // "Acquire" insures the operation after SyncRepeatedFieldWithMap won't get
+  // executed before state_ is checked.
+  Atomic32 state = google::protobuf::internal::Acquire_Load(&state_);
+  return state != STATE_MODIFIED_REPEATED;
+}
+
 void MapFieldBase::SetMapDirty() { state_ = STATE_MODIFIED_MAP; }
 
 void MapFieldBase::SetRepeatedDirty() { state_ = STATE_MODIFIED_REPEATED; }
@@ -359,6 +366,13 @@ void DynamicMapField::SyncMapWithRepeatedFieldNoLock() const {
         GOOGLE_LOG(FATAL) << "Can't get here.";
         break;
     }
+
+    // Remove existing map value with same key.
+    Map<MapKey, MapValueRef>::iterator iter = map->find(map_key);
+    if (iter != map->end()) {
+      iter->second.DeleteData();
+    }
+
     MapValueRef& map_val = (*map)[map_key];
     map_val.SetType(val_des->cpp_type());
     switch (val_des->cpp_type()) {

+ 2 - 0
src/google/protobuf/map_field.h

@@ -86,6 +86,8 @@ class LIBPROTOBUF_EXPORT MapFieldBase {
   virtual bool ContainsMapKey(const MapKey& map_key) const = 0;
   virtual bool InsertOrLookupMapValue(
       const MapKey& map_key, MapValueRef* val) = 0;
+  // Insures operations after won't get executed before calling this.
+  bool IsMapValid() const;
   virtual bool DeleteMapValue(const MapKey& map_key) = 0;
   virtual bool EqualIterator(const MapIterator& a,
                              const MapIterator& b) const = 0;

+ 127 - 1
src/google/protobuf/map_test.cc

@@ -975,6 +975,11 @@ static int Int(const string& value) {
 class MapFieldReflectionTest : public testing::Test {
  protected:
   typedef FieldDescriptor FD;
+
+  int MapSize(const Reflection* reflection, const FieldDescriptor* field,
+              const Message& message) {
+    return reflection->MapSize(message, field);
+  }
 };
 
 TEST_F(MapFieldReflectionTest, RegularFields) {
@@ -1782,6 +1787,50 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefMergeFromAndSwap) {
   // TODO(teboring): add test for duplicated key
 }
 
+TEST_F(MapFieldReflectionTest, MapSizeWithDuplicatedKey) {
+  // Dynamic Message
+  {
+    DynamicMessageFactory factory;
+    google::protobuf::scoped_ptr<Message> message(
+        factory.GetPrototype(unittest::TestMap::descriptor())->New());
+    const Reflection* reflection = message->GetReflection();
+    const FieldDescriptor* field =
+        unittest::TestMap::descriptor()->FindFieldByName("map_int32_int32");
+
+    Message* entry1 = reflection->AddMessage(message.get(), field);
+    Message* entry2 = reflection->AddMessage(message.get(), field);
+
+    const Reflection* entry_reflection = entry1->GetReflection();
+    const FieldDescriptor* key_field =
+        entry1->GetDescriptor()->FindFieldByName("key");
+    entry_reflection->SetInt32(entry1, key_field, 1);
+    entry_reflection->SetInt32(entry2, key_field, 1);
+
+    EXPECT_EQ(2, reflection->FieldSize(*message, field));
+    EXPECT_EQ(1, MapSize(reflection, field, *message));
+  }
+
+  // Generated Message
+  {
+    unittest::TestMap message;
+    const Reflection* reflection = message.GetReflection();
+    const FieldDescriptor* field =
+        message.GetDescriptor()->FindFieldByName("map_int32_int32");
+
+    Message* entry1 = reflection->AddMessage(&message, field);
+    Message* entry2 = reflection->AddMessage(&message, field);
+
+    const Reflection* entry_reflection = entry1->GetReflection();
+    const FieldDescriptor* key_field =
+        entry1->GetDescriptor()->FindFieldByName("key");
+    entry_reflection->SetInt32(entry1, key_field, 1);
+    entry_reflection->SetInt32(entry2, key_field, 1);
+
+    EXPECT_EQ(2, reflection->FieldSize(message, field));
+    EXPECT_EQ(1, MapSize(reflection, field, message));
+  }
+}
+
 // Generated Message Test ===========================================
 
 TEST(GeneratedMapFieldTest, Accessors) {
@@ -2689,6 +2738,69 @@ TEST_F(MapFieldInDynamicMessageTest, RecursiveMap) {
   ASSERT_TRUE(to->ParseFromString(data));
 }
 
+TEST_F(MapFieldInDynamicMessageTest, MapValueReferernceValidAfterSerialize) {
+  google::protobuf::scoped_ptr<Message> message(map_prototype_->New());
+  MapReflectionTester reflection_tester(map_descriptor_);
+  reflection_tester.SetMapFieldsViaMapReflection(message.get());
+
+  // Get value reference before serialization, so that we know the value is from
+  // map.
+  MapKey map_key;
+  MapValueRef map_val;
+  map_key.SetInt32Value(0);
+  reflection_tester.GetMapValueViaMapReflection(
+      message.get(), "map_int32_foreign_message", map_key, &map_val);
+  Message* submsg = map_val.MutableMessageValue();
+
+  // In previous implementation, calling SerializeToString will cause syncing
+  // from map to repeated field, which will invalidate the submsg we previously
+  // got.
+  string data;
+  message->SerializeToString(&data);
+
+  const Reflection* submsg_reflection = submsg->GetReflection();
+  const Descriptor* submsg_desc = submsg->GetDescriptor();
+  const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c");
+  submsg_reflection->SetInt32(submsg, submsg_field, 128);
+
+  message->SerializeToString(&data);
+  TestMap to;
+  to.ParseFromString(data);
+  EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c());
+}
+
+TEST_F(MapFieldInDynamicMessageTest, MapEntryReferernceValidAfterSerialize) {
+  google::protobuf::scoped_ptr<Message> message(map_prototype_->New());
+  MapReflectionTester reflection_tester(map_descriptor_);
+  reflection_tester.SetMapFieldsViaReflection(message.get());
+
+  // Get map entry before serialization, so that we know the it is from
+  // repeated field.
+  Message* map_entry = reflection_tester.GetMapEntryViaReflection(
+      message.get(), "map_int32_foreign_message", 0);
+  const Reflection* map_entry_reflection = map_entry->GetReflection();
+  const Descriptor* map_entry_desc = map_entry->GetDescriptor();
+  const FieldDescriptor* value_field = map_entry_desc->FindFieldByName("value");
+  Message* submsg =
+      map_entry_reflection->MutableMessage(map_entry, value_field);
+
+  // In previous implementation, calling SerializeToString will cause syncing
+  // from repeated field to map, which will invalidate the map_entry we
+  // previously got.
+  string data;
+  message->SerializeToString(&data);
+
+  const Reflection* submsg_reflection = submsg->GetReflection();
+  const Descriptor* submsg_desc = submsg->GetDescriptor();
+  const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c");
+  submsg_reflection->SetInt32(submsg, submsg_field, 128);
+
+  message->SerializeToString(&data);
+  TestMap to;
+  to.ParseFromString(data);
+  EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c());
+}
+
 // ReflectionOps Test ===============================================
 
 TEST(ReflectionOpsForMapFieldTest, MapSanityCheck) {
@@ -2751,6 +2863,20 @@ TEST(ReflectionOpsForMapFieldTest, MapDiscardUnknownFields) {
       GetUnknownFields(message).field_count());
 }
 
+TEST(ReflectionOpsForMapFieldTest, IsInitialized) {
+  unittest::TestRequiredMessageMap map_message;
+
+  // Add an uninitialized message.
+  (*map_message.mutable_map_field())[0];
+  EXPECT_FALSE(ReflectionOps::IsInitialized(map_message));
+
+  // Initialize uninitialized message
+  (*map_message.mutable_map_field())[0].set_a(0);
+  (*map_message.mutable_map_field())[0].set_b(0);
+  (*map_message.mutable_map_field())[0].set_c(0);
+  EXPECT_TRUE(ReflectionOps::IsInitialized(map_message));
+}
+
 // Wire Format Test =================================================
 
 TEST(WireFormatForMapFieldTest, ParseMap) {
@@ -3089,7 +3215,7 @@ TEST(ArenaTest, ParsingAndSerializingNoHeapAllocation) {
 }
 
 // Use text format parsing and serializing to test reflection api.
-TEST(ArenaTest, RelfectionInTextFormat) {
+TEST(ArenaTest, ReflectionInTextFormat) {
   Arena arena;
   string data;
 

+ 16 - 0
src/google/protobuf/map_test_util.cc

@@ -744,6 +744,22 @@ void MapReflectionTester::SetMapFieldsViaMapReflection(
       sub_foreign_message, foreign_c_, 1);
 }
 
+void MapReflectionTester::GetMapValueViaMapReflection(Message* message,
+                                                      const string& field_name,
+                                                      const MapKey& map_key,
+                                                      MapValueRef* map_val) {
+  const Reflection* reflection = message->GetReflection();
+  EXPECT_FALSE(reflection->InsertOrLookupMapValue(message, F(field_name),
+                                                  map_key, map_val));
+}
+
+Message* MapReflectionTester::GetMapEntryViaReflection(Message* message,
+                                                       const string& field_name,
+                                                       int index) {
+  const Reflection* reflection = message->GetReflection();
+  return reflection->MutableRepeatedMessage(message, F(field_name), index);
+}
+
 void MapReflectionTester::ClearMapFieldsViaReflection(
     Message* message) {
   const Reflection* reflection = message->GetReflection();

+ 5 - 0
src/google/protobuf/map_test_util.h

@@ -106,6 +106,11 @@ class MapReflectionTester {
   void ExpectClearViaReflection(const Message& message);
   void ExpectClearViaReflectionIterator(Message* message);
   void ExpectMapEntryClearViaReflection(Message* message);
+  void GetMapValueViaMapReflection(Message* message,
+                                   const string& field_name,
+                                   const MapKey& map_key, MapValueRef* map_val);
+  Message* GetMapEntryViaReflection(Message* message, const string& field_name,
+                                    int index);
 
  private:
   const FieldDescriptor* F(const string& name);

+ 11 - 0
src/google/protobuf/message.h

@@ -154,6 +154,13 @@ class MapReflectionFriend;     // scalar_map_container.h
 }
 
 
+namespace internal {
+class ReflectionOps;     // reflection_ops.h
+class MapKeySorter;      // wire_format.cc
+class WireFormat;        // wire_format.h
+class MapFieldReflectionTest;  // map_test.cc
+}
+
 template<typename T>
 class RepeatedField;     // repeated_field.h
 
@@ -936,6 +943,10 @@ class LIBPROTOBUF_EXPORT Reflection {
   template<typename T, typename Enable>
   friend class MutableRepeatedFieldRef;
   friend class ::google::protobuf::python::MapReflectionFriend;
+  friend class internal::MapFieldReflectionTest;
+  friend class internal::MapKeySorter;
+  friend class internal::WireFormat;
+  friend class internal::ReflectionOps;
 
   // Special version for specialized implementations of string.  We can't call
   // MutableRawRepeatedField directly here because we don't have access to

+ 22 - 0
src/google/protobuf/reflection_ops.cc

@@ -38,6 +38,7 @@
 #include <google/protobuf/reflection_ops.h>
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/map_field.h>
 #include <google/protobuf/unknown_field_set.h>
 #include <google/protobuf/stubs/strutil.h>
 
@@ -158,6 +159,27 @@ bool ReflectionOps::IsInitialized(const Message& message) {
     const FieldDescriptor* field = fields[i];
     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
 
+      if (field->is_map()) {
+        const FieldDescriptor* value_field = field->message_type()->field(1);
+        if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+          MapFieldBase* map_field =
+              reflection->MapData(const_cast<Message*>(&message), field);
+          if (map_field->IsMapValid()) {
+            MapIterator iter(const_cast<Message*>(&message), field);
+            MapIterator end(const_cast<Message*>(&message), field);
+            for (map_field->MapBegin(&iter), map_field->MapEnd(&end);
+                 iter != end; ++iter) {
+              if (!iter.GetValueRef().GetMessageValue().IsInitialized()) {
+                return false;
+              }
+            }
+            continue;
+          }
+        } else {
+          continue;
+        }
+      }
+
       if (field->is_repeated()) {
         int size = reflection->FieldSize(message, field);
 

+ 275 - 1
src/google/protobuf/wire_format.cc

@@ -54,9 +54,17 @@
 
 
 namespace google {
+const size_t kMapEntryTagByteSize = 2;
+
 namespace protobuf {
 namespace internal {
 
+// Forward declare static functions
+static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field,
+                                     const MapKey& value);
+static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
+                                          const MapValueRef& value);
+
 // ===================================================================
 
 bool UnknownFieldSetFieldSkipper::SkipField(
@@ -825,6 +833,129 @@ void WireFormat::SerializeWithCachedSizes(
        "during serialization?";
 }
 
+static void SerializeMapKeyWithCachedSizes(const FieldDescriptor* field,
+                                           const MapKey& value,
+                                           io::CodedOutputStream* output) {
+  switch (field->type()) {
+    case FieldDescriptor::TYPE_DOUBLE:
+    case FieldDescriptor::TYPE_FLOAT:
+    case FieldDescriptor::TYPE_GROUP:
+    case FieldDescriptor::TYPE_MESSAGE:
+    case FieldDescriptor::TYPE_BYTES:
+    case FieldDescriptor::TYPE_ENUM:
+      GOOGLE_LOG(FATAL) << "Unsupported";
+      break;
+#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)                     \
+  case FieldDescriptor::TYPE_##FieldType:                                      \
+    WireFormatLite::Write##CamelFieldType(1, value.Get##CamelCppType##Value(), \
+                                          output);                             \
+    break;
+      CASE_TYPE(INT64, Int64, Int64)
+      CASE_TYPE(UINT64, UInt64, UInt64)
+      CASE_TYPE(INT32, Int32, Int32)
+      CASE_TYPE(FIXED64, Fixed64, UInt64)
+      CASE_TYPE(FIXED32, Fixed32, UInt32)
+      CASE_TYPE(BOOL, Bool, Bool)
+      CASE_TYPE(UINT32, UInt32, UInt32)
+      CASE_TYPE(SFIXED32, SFixed32, Int32)
+      CASE_TYPE(SFIXED64, SFixed64, Int64)
+      CASE_TYPE(SINT32, SInt32, Int32)
+      CASE_TYPE(SINT64, SInt64, Int64)
+      CASE_TYPE(STRING, String, String)
+#undef CASE_TYPE
+  }
+}
+
+static void SerializeMapValueRefWithCachedSizes(const FieldDescriptor* field,
+                                                const MapValueRef& value,
+                                                io::CodedOutputStream* output) {
+  switch (field->type()) {
+#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)                     \
+  case FieldDescriptor::TYPE_##FieldType:                                      \
+    WireFormatLite::Write##CamelFieldType(2, value.Get##CamelCppType##Value(), \
+                                          output);                             \
+    break;
+    CASE_TYPE(INT64, Int64, Int64)
+    CASE_TYPE(UINT64, UInt64, UInt64)
+    CASE_TYPE(INT32, Int32, Int32)
+    CASE_TYPE(FIXED64, Fixed64, UInt64)
+    CASE_TYPE(FIXED32, Fixed32, UInt32)
+    CASE_TYPE(BOOL, Bool, Bool)
+    CASE_TYPE(UINT32, UInt32, UInt32)
+    CASE_TYPE(SFIXED32, SFixed32, Int32)
+    CASE_TYPE(SFIXED64, SFixed64, Int64)
+    CASE_TYPE(SINT32, SInt32, Int32)
+    CASE_TYPE(SINT64, SInt64, Int64)
+    CASE_TYPE(ENUM, Enum, Enum)
+    CASE_TYPE(DOUBLE, Double, Double)
+    CASE_TYPE(FLOAT, Float, Float)
+    CASE_TYPE(STRING, String, String)
+    CASE_TYPE(BYTES, Bytes, String)
+    CASE_TYPE(MESSAGE, Message, Message)
+    CASE_TYPE(GROUP, Group, Message)
+#undef CASE_TYPE
+  }
+}
+
+class MapKeySorter {
+ public:
+  static std::vector<MapKey> SortKey(const Message& message,
+                                     const Reflection* reflection,
+                                     const FieldDescriptor* field) {
+    std::vector<MapKey> sorted_key_list;
+    for (MapIterator it =
+             reflection->MapBegin(const_cast<Message*>(&message), field);
+         it != reflection->MapEnd(const_cast<Message*>(&message), field);
+         ++it) {
+      sorted_key_list.push_back(it.GetKey());
+    }
+    MapKeyComparator comparator;
+    std::sort(sorted_key_list.begin(), sorted_key_list.end(), comparator);
+    return sorted_key_list;
+  }
+
+ private:
+  class MapKeyComparator {
+   public:
+    bool operator()(const MapKey& a, const MapKey& b) const {
+      GOOGLE_DCHECK(a.type() == b.type());
+      switch (a.type()) {
+#define CASE_TYPE(CppType, CamelCppType)                                \
+  case FieldDescriptor::CPPTYPE_##CppType: {                            \
+    return a.Get##CamelCppType##Value() < b.Get##CamelCppType##Value(); \
+  }
+        CASE_TYPE(STRING, String)
+        CASE_TYPE(INT64, Int64)
+        CASE_TYPE(INT32, Int32)
+        CASE_TYPE(UINT64, UInt64)
+        CASE_TYPE(UINT32, UInt32)
+        CASE_TYPE(BOOL, Bool)
+#undef CASE_TYPE
+
+        default:
+          GOOGLE_LOG(DFATAL) << "Invalid key for map field.";
+          return true;
+      }
+    }
+  };
+};
+
+static void SerializeMapEntry(const FieldDescriptor* field, const MapKey& key,
+                              const MapValueRef& value,
+                              io::CodedOutputStream* output) {
+  const FieldDescriptor* key_field = field->message_type()->field(0);
+  const FieldDescriptor* value_field = field->message_type()->field(1);
+
+  WireFormatLite::WriteTag(field->number(),
+                           WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
+  size_t size = kMapEntryTagByteSize;
+  size += MapKeyDataOnlyByteSize(key_field, key);
+  size += MapValueRefDataOnlyByteSize(value_field, value);
+  output->WriteVarint32(size);
+  SerializeMapKeyWithCachedSizes(key_field, key, output);
+  SerializeMapValueRefWithCachedSizes(value_field, value, output);
+}
+
 void WireFormat::SerializeFieldWithCachedSizes(
     const FieldDescriptor* field,
     const Message& message,
@@ -839,6 +970,48 @@ void WireFormat::SerializeFieldWithCachedSizes(
     return;
   }
 
+  // For map fields, we can use either repeated field reflection or map
+  // reflection.  Our choice has some subtle effects.  If we use repeated field
+  // reflection here, then the repeated field representation becomes
+  // authoritative for this field: any existing references that came from map
+  // reflection remain valid for reading, but mutations to them are lost and
+  // will be overwritten next time we call map reflection!
+  //
+  // So far this mainly affects Python, which keeps long-term references to map
+  // values around, and always uses map reflection.  See: b/35918691
+  //
+  // Here we choose to use map reflection API as long as the internal
+  // map is valid. In this way, the serialization doesn't change map field's
+  // internal state and existing references that came from map reflection remain
+  // valid for both reading and writing.
+  if (field->is_map()) {
+    MapFieldBase* map_field =
+        message_reflection->MapData(const_cast<Message*>(&message), field);
+    if (map_field->IsMapValid()) {
+      if (output->IsSerializationDeterministic()) {
+        std::vector<MapKey> sorted_key_list =
+            MapKeySorter::SortKey(message, message_reflection, field);
+        for (std::vector<MapKey>::iterator it = sorted_key_list.begin();
+             it != sorted_key_list.end(); ++it) {
+          MapValueRef map_value;
+          message_reflection->InsertOrLookupMapValue(
+              const_cast<Message*>(&message), field, *it, &map_value);
+          SerializeMapEntry(field, *it, map_value, output);
+        }
+      } else {
+        for (MapIterator it = message_reflection->MapBegin(
+                 const_cast<Message*>(&message), field);
+             it !=
+             message_reflection->MapEnd(const_cast<Message*>(&message), field);
+             ++it) {
+          SerializeMapEntry(field, it.GetKey(), it.GetValueRef(), output);
+        }
+      }
+
+      return;
+    }
+  }
+
   int count = 0;
 
   if (field->is_repeated()) {
@@ -1059,11 +1232,113 @@ size_t WireFormat::FieldByteSize(
   return our_size;
 }
 
+static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field,
+                                     const MapKey& value) {
+  GOOGLE_DCHECK_EQ(FieldDescriptor::TypeToCppType(field->type()), value.type());
+  switch (field->type()) {
+    case FieldDescriptor::TYPE_DOUBLE:
+    case FieldDescriptor::TYPE_FLOAT:
+    case FieldDescriptor::TYPE_GROUP:
+    case FieldDescriptor::TYPE_MESSAGE:
+    case FieldDescriptor::TYPE_BYTES:
+    case FieldDescriptor::TYPE_ENUM:
+      GOOGLE_LOG(FATAL) << "Unsupported";
+      return 0;
+#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
+  case FieldDescriptor::TYPE_##FieldType:                  \
+    return WireFormatLite::CamelFieldType##Size(           \
+        value.Get##CamelCppType##Value());
+
+#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
+  case FieldDescriptor::TYPE_##FieldType:          \
+    return WireFormatLite::k##CamelFieldType##Size;
+
+      CASE_TYPE(INT32, Int32, Int32);
+      CASE_TYPE(INT64, Int64, Int64);
+      CASE_TYPE(UINT32, UInt32, UInt32);
+      CASE_TYPE(UINT64, UInt64, UInt64);
+      CASE_TYPE(SINT32, SInt32, Int32);
+      CASE_TYPE(SINT64, SInt64, Int64);
+      CASE_TYPE(STRING, String, String);
+      FIXED_CASE_TYPE(FIXED32, Fixed32);
+      FIXED_CASE_TYPE(FIXED64, Fixed64);
+      FIXED_CASE_TYPE(SFIXED32, SFixed32);
+      FIXED_CASE_TYPE(SFIXED64, SFixed64);
+      FIXED_CASE_TYPE(BOOL, Bool);
+
+#undef CASE_TYPE
+#undef FIXED_CASE_TYPE
+  }
+  GOOGLE_LOG(FATAL) << "Cannot get here";
+  return 0;
+}
+
+static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
+                                          const MapValueRef& value) {
+  switch (field->type()) {
+    case FieldDescriptor::TYPE_GROUP:
+      GOOGLE_LOG(FATAL) << "Unsupported";
+      return 0;
+#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
+  case FieldDescriptor::TYPE_##FieldType:                  \
+    return WireFormatLite::CamelFieldType##Size(           \
+        value.Get##CamelCppType##Value());
+
+#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
+  case FieldDescriptor::TYPE_##FieldType:          \
+    return WireFormatLite::k##CamelFieldType##Size;
+
+      CASE_TYPE(INT32, Int32, Int32);
+      CASE_TYPE(INT64, Int64, Int64);
+      CASE_TYPE(UINT32, UInt32, UInt32);
+      CASE_TYPE(UINT64, UInt64, UInt64);
+      CASE_TYPE(SINT32, SInt32, Int32);
+      CASE_TYPE(SINT64, SInt64, Int64);
+      CASE_TYPE(STRING, String, String);
+      CASE_TYPE(BYTES, Bytes, String);
+      CASE_TYPE(ENUM, Enum, Enum);
+      CASE_TYPE(MESSAGE, Message, Message);
+      FIXED_CASE_TYPE(FIXED32, Fixed32);
+      FIXED_CASE_TYPE(FIXED64, Fixed64);
+      FIXED_CASE_TYPE(SFIXED32, SFixed32);
+      FIXED_CASE_TYPE(SFIXED64, SFixed64);
+      FIXED_CASE_TYPE(DOUBLE, Double);
+      FIXED_CASE_TYPE(FLOAT, Float);
+      FIXED_CASE_TYPE(BOOL, Bool);
+
+#undef CASE_TYPE
+#undef FIXED_CASE_TYPE
+  }
+  GOOGLE_LOG(FATAL) << "Cannot get here";
+  return 0;
+}
+
 size_t WireFormat::FieldDataOnlyByteSize(
     const FieldDescriptor* field,
     const Message& message) {
   const Reflection* message_reflection = message.GetReflection();
 
+  size_t data_size = 0;
+
+  if (field->is_map()) {
+    MapFieldBase* map_field =
+        message_reflection->MapData(const_cast<Message*>(&message), field);
+    if (map_field->IsMapValid()) {
+      MapIterator iter(const_cast<Message*>(&message), field);
+      MapIterator end(const_cast<Message*>(&message), field);
+      const FieldDescriptor* key_field = field->message_type()->field(0);
+      const FieldDescriptor* value_field = field->message_type()->field(1);
+      for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end;
+           ++iter) {
+        size_t size = kMapEntryTagByteSize;
+        size += MapKeyDataOnlyByteSize(key_field, iter.GetKey());
+        size += MapValueRefDataOnlyByteSize(value_field, iter.GetValueRef());
+        data_size += WireFormatLite::LengthDelimitedSize(size);
+      }
+      return data_size;
+    }
+  }
+
   size_t count = 0;
   if (field->is_repeated()) {
     count =
@@ -1075,7 +1350,6 @@ size_t WireFormat::FieldDataOnlyByteSize(
     count = 1;
   }
 
-  size_t data_size = 0;
   switch (field->type()) {
 #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD)                     \
     case FieldDescriptor::TYPE_##TYPE:                                     \