Эх сурвалжийг харах

Implement Equals for nano.

Jisi Liu 11 жил өмнө
parent
commit
1536e93349

+ 40 - 0
javanano/src/main/java/com/google/protobuf/nano/InternalNano.java

@@ -491,4 +491,44 @@ public final class InternalNano {
     }
     return size;
   }
+
+  /**
+   * Checks whether two {@link Map} are equal. We don't use the default equals
+   * method of {@link Map} because it compares by identity not by content for
+   * byte arrays.
+   */
+  public static <K, V> boolean equals(Map<K, V> a, Map<K, V> b) {
+    if (a == b) {
+      return true;
+    }
+    if (a == null) {
+      return b.size() == 0;
+    }
+    if (b == null) {
+      return a.size() == 0;
+    }
+    if (a.size() != b.size()) {
+      return false;
+    }
+    for (Entry<K, V> entry : a.entrySet()) {
+      if (!b.containsKey(entry.getKey())) {
+        return false;
+      }
+      if (!equalsMapValue(entry.getValue(), b.get(entry.getKey()))) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private static boolean equalsMapValue(Object a, Object b) {
+    if (a == null || b == null) {
+      throw new IllegalStateException(
+          "keys and values in maps cannot be null");
+    }
+    if (a instanceof byte[] && b instanceof byte[]) {
+      return Arrays.equals((byte[]) a, (byte[]) b);
+    }
+    return a.equals(b);
+  }
 }

+ 121 - 2
javanano/src/test/java/com/google/protobuf/nano/NanoTest.java

@@ -31,6 +31,7 @@
 package com.google.protobuf.nano;
 
 import com.google.protobuf.nano.MapTestProto.TestMap;
+import com.google.protobuf.nano.MapTestProto.TestMap.MessageValue;
 import com.google.protobuf.nano.NanoAccessorsOuterClass.TestNanoAccessors;
 import com.google.protobuf.nano.NanoHasOuterClass.TestAllTypesNanoHas;
 import com.google.protobuf.nano.NanoOuterClass.TestAllTypesNano;
@@ -47,6 +48,7 @@ import junit.framework.TestCase;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.TreeMap;
 
 /**
  * Test nano runtime.
@@ -3824,15 +3826,107 @@ public class NanoTest extends TestCase {
     assertEquals(0, messageValue.value2);
   }
 
+  public void testMapEquals() throws Exception {
+    TestMap a = new TestMap();
+    TestMap b = new TestMap();
+
+    // empty and null map fields are equal.
+    assertTestMapEqual(a, b);
+    a.int32ToBytesField = new HashMap<Integer, byte[]>();
+    assertTestMapEqual(a, b);
+
+    a.int32ToInt32Field = new HashMap<Integer, Integer>();
+    b.int32ToInt32Field = new HashMap<Integer, Integer>();
+    setMap(a.int32ToInt32Field, deepCopy(int32Values), deepCopy(int32Values));
+    setMap(b.int32ToInt32Field, deepCopy(int32Values), deepCopy(int32Values));
+    assertTestMapEqual(a, b);
+
+    a.int32ToMessageField =
+        new HashMap<Integer, MapTestProto.TestMap.MessageValue>();
+    b.int32ToMessageField =
+        new HashMap<Integer, MapTestProto.TestMap.MessageValue>();
+    setMap(a.int32ToMessageField,
+        deepCopy(int32Values), deepCopy(messageValues));
+    setMap(b.int32ToMessageField,
+        deepCopy(int32Values), deepCopy(messageValues));
+    assertTestMapEqual(a, b);
+
+    a.stringToInt32Field = new HashMap<String, Integer>();
+    b.stringToInt32Field = new HashMap<String, Integer>();
+    setMap(a.stringToInt32Field, deepCopy(stringValues), deepCopy(int32Values));
+    setMap(b.stringToInt32Field, deepCopy(stringValues), deepCopy(int32Values));
+    assertTestMapEqual(a, b);
+
+    a.int32ToBytesField = new HashMap<Integer, byte[]>();
+    b.int32ToBytesField = new HashMap<Integer, byte[]>();
+    setMap(a.int32ToBytesField, deepCopy(int32Values), deepCopy(bytesValues));
+    setMap(b.int32ToBytesField, deepCopy(int32Values), deepCopy(bytesValues));
+    assertTestMapEqual(a, b);
+
+    // Make sure the map implementation does not matter.
+    a.int32ToStringField = new TreeMap<Integer, String>();
+    b.int32ToStringField = new HashMap<Integer, String>();
+    setMap(a.int32ToStringField, deepCopy(int32Values), deepCopy(stringValues));
+    setMap(b.int32ToStringField, deepCopy(int32Values), deepCopy(stringValues));
+    assertTestMapEqual(a, b);
+
+    a.clear();
+    b.clear();
+
+    // unequal cases: different value
+    a.int32ToInt32Field = new HashMap<Integer, Integer>();
+    b.int32ToInt32Field = new HashMap<Integer, Integer>();
+    a.int32ToInt32Field.put(1, 1);
+    b.int32ToInt32Field.put(1, 2);
+    assertTestMapUnequal(a, b);
+    // unequal case: additional entry
+    b.int32ToInt32Field.put(1, 1);
+    b.int32ToInt32Field.put(2, 1);
+    assertTestMapUnequal(a, b);
+    a.int32ToInt32Field.put(2, 1);
+    assertTestMapEqual(a, b);
+
+    // unequal case: different message value.
+    a.int32ToMessageField =
+        new HashMap<Integer, MapTestProto.TestMap.MessageValue>();
+    b.int32ToMessageField =
+        new HashMap<Integer, MapTestProto.TestMap.MessageValue>();
+    MessageValue va = new MessageValue();
+    va.value = 1;
+    MessageValue vb = new MessageValue();
+    vb.value = 1;
+    a.int32ToMessageField.put(1, va);
+    b.int32ToMessageField.put(1, vb);
+    assertTestMapEqual(a, b);
+    vb.value = 2;
+    assertTestMapUnequal(a, b);
+  }
+
+  private static void assertTestMapEqual(TestMap a, TestMap b)
+      throws Exception {
+    assertEquals(a.hashCode(), b.hashCode());
+    assertTrue(a.equals(b));
+    assertTrue(b.equals(a));
+  }
+
+  private static void assertTestMapUnequal(TestMap a, TestMap b)
+      throws Exception {
+    assertFalse(a.equals(b));
+    assertFalse(b.equals(a));
+  }
+
   private static final Integer[] int32Values = new Integer[] {
     0, 1, -1, Integer.MAX_VALUE, Integer.MIN_VALUE,
   };
+
   private static final Long[] int64Values = new Long[] {
     0L, 1L, -1L, Long.MAX_VALUE, Long.MIN_VALUE,
   };
+
   private static final String[] stringValues = new String[] {
     "", "hello", "world", "foo", "bar",
   };
+
   private static final byte[][] bytesValues = new byte[][] {
     new byte[] {},
     new byte[] {0},
@@ -3840,13 +3934,16 @@ public class NanoTest extends TestCase {
     new byte[] {127, -128},
     new byte[] {'a', 'b', '0', '1'},
   };
+
   private static final Boolean[] boolValues = new Boolean[] {
     false, true,
   };
+
   private static final Integer[] enumValues = new Integer[] {
     TestMap.FOO, TestMap.BAR, TestMap.BAZ, TestMap.QUX,
     Integer.MAX_VALUE /* unknown */,
   };
+
   private static final TestMap.MessageValue[] messageValues =
       new TestMap.MessageValue[] {
     newMapValueMessage(0),
@@ -3855,15 +3952,37 @@ public class NanoTest extends TestCase {
     newMapValueMessage(Integer.MAX_VALUE),
     newMapValueMessage(Integer.MIN_VALUE),
   };
+
   private static TestMap.MessageValue newMapValueMessage(int value) {
     TestMap.MessageValue result = new TestMap.MessageValue();
     result.value = value;
     return result;
   }
 
+  @SuppressWarnings("unchecked")
+  private static <T> T[] deepCopy(T[] orig) throws Exception {
+    if (orig instanceof MessageValue[]) {
+      MessageValue[] result = new MessageValue[orig.length];
+      for (int i = 0; i < orig.length; i++) {
+        result[i] = new MessageValue();
+        MessageNano.mergeFrom(
+            result[i], MessageNano.toByteArray((MessageValue) orig[i]));
+      }
+      return (T[]) result;
+    }
+    if (orig instanceof byte[][]) {
+      byte[][] result = new byte[orig.length][];
+      for (int i = 0; i < orig.length; i++) {
+        byte[] origBytes = (byte[]) orig[i];
+        result[i] = Arrays.copyOf(origBytes, origBytes.length);
+      }
+    }
+    return Arrays.copyOf(orig, orig.length);
+  }
+
   private <K, V> void setMap(Map<K, V> map, K[] keys, V[] values) {
     assert(keys.length == values.length);
-    for (int i = 0; i < keys.length; ++i) {
+    for (int i = 0; i < keys.length; i++) {
       map.put(keys[i], values[i]);
     }
   }
@@ -3871,7 +3990,7 @@ public class NanoTest extends TestCase {
   private <K, V> void assertMapSet(
       Map<K, V> map, K[] keys, V[] values) throws Exception {
     assert(keys.length == values.length);
-    for (int i = 0; i < values.length; ++i) {
+    for (int i = 0; i < values.length; i++) {
       assertEquals(values[i], map.get(keys[i]));
     }
     assertEquals(keys.length, map.size());

+ 5 - 0
src/google/protobuf/compiler/javanano/javanano_map_field.cc

@@ -166,6 +166,11 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
 
 void MapFieldGenerator::
 GenerateEqualsCode(io::Printer* printer) const {
+  printer->Print(variables_,
+    "if (!com.google.protobuf.nano.InternalNano.equals(\n"
+    "  this.$name$, other.$name$)) {\n"
+    "  return false;\n"
+    "}\n");
 }
 
 void MapFieldGenerator::