Przeglądaj źródła

am 4b5874fa: Merge "Correctness: floating point equality using bits instead of ==."

* commit '4b5874fad099faefb469c632e4c7b854cea733ae':
  Correctness: floating point equality using bits instead of ==.
Max Cai 11 lat temu
rodzic
commit
d44a519d8f

+ 121 - 10
java/src/test/java/com/google/protobuf/NanoTest.java

@@ -2886,13 +2886,6 @@ public class NanoTest extends TestCase {
       TestAllTypesNano.BAR,
       TestAllTypesNano.BAZ
     };
-    // We set the _nan fields to something other than nan, because equality
-    // is defined for nan such that Float.NaN != Float.NaN, which makes any
-    // instance of TestAllTypesNano unequal to any other instance unless
-    // these fields are set. This is also the behavior of the regular java
-    // generator when the value of a field is NaN.
-    message.defaultFloatNan = 1.0f;
-    message.defaultDoubleNan = 1.0;
     return message;
   }
 
@@ -2915,7 +2908,6 @@ public class NanoTest extends TestCase {
       TestAllTypesNano.BAR,
       TestAllTypesNano.BAZ
     };
-    message.defaultFloatNan = 1.0f;
     return message;
   }
 
@@ -2924,8 +2916,7 @@ public class NanoTest extends TestCase {
         .setOptionalInt32(5)
         .setOptionalString("Hello")
         .setOptionalBytes(new byte[] {1, 2, 3})
-        .setOptionalNestedEnum(TestNanoAccessors.BAR)
-        .setDefaultFloatNan(1.0f);
+        .setOptionalNestedEnum(TestNanoAccessors.BAR);
     message.optionalNestedMessage = new TestNanoAccessors.NestedMessage().setBb(27);
     message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
     message.repeatedString = new String[] { "One", "Two" };
@@ -2973,6 +2964,126 @@ public class NanoTest extends TestCase {
     return message;
   }
 
+  public void testEqualsWithSpecialFloatingPointValues() throws Exception {
+    // Checks that the nano implementation complies with Object.equals() when treating
+    // floating point numbers, i.e. NaN == NaN and +0.0 != -0.0.
+    // This test assumes that the generated equals() implementations are symmetric, so
+    // there will only be one direction for each equality check.
+
+    TestAllTypesNano m1 = new TestAllTypesNano();
+    m1.optionalFloat = Float.NaN;
+    m1.optionalDouble = Double.NaN;
+    TestAllTypesNano m2 = new TestAllTypesNano();
+    m2.optionalFloat = Float.NaN;
+    m2.optionalDouble = Double.NaN;
+    assertTrue(m1.equals(m2));
+    assertTrue(m1.equals(
+        MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
+
+    m1.optionalFloat = +0f;
+    m2.optionalFloat = -0f;
+    assertFalse(m1.equals(m2));
+
+    m1.optionalFloat = -0f;
+    m1.optionalDouble = +0d;
+    m2.optionalDouble = -0d;
+    assertFalse(m1.equals(m2));
+
+    m1.optionalDouble = -0d;
+    assertTrue(m1.equals(m2));
+    assertFalse(m1.equals(new TestAllTypesNano())); // -0 does not equals() the default +0
+    assertTrue(m1.equals(
+        MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
+
+    // -------
+
+    TestAllTypesNanoHas m3 = new TestAllTypesNanoHas();
+    m3.optionalFloat = Float.NaN;
+    m3.hasOptionalFloat = true;
+    m3.optionalDouble = Double.NaN;
+    m3.hasOptionalDouble = true;
+    TestAllTypesNanoHas m4 = new TestAllTypesNanoHas();
+    m4.optionalFloat = Float.NaN;
+    m4.hasOptionalFloat = true;
+    m4.optionalDouble = Double.NaN;
+    m4.hasOptionalDouble = true;
+    assertTrue(m3.equals(m4));
+    assertTrue(m3.equals(
+        MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
+
+    m3.optionalFloat = +0f;
+    m4.optionalFloat = -0f;
+    assertFalse(m3.equals(m4));
+
+    m3.optionalFloat = -0f;
+    m3.optionalDouble = +0d;
+    m4.optionalDouble = -0d;
+    assertFalse(m3.equals(m4));
+
+    m3.optionalDouble = -0d;
+    m3.hasOptionalFloat = false;  // -0 does not equals() the default +0,
+    m3.hasOptionalDouble = false; // so these incorrect 'has' flags should be disregarded.
+    assertTrue(m3.equals(m4));    // note: m4 has the 'has' flags set.
+    assertFalse(m3.equals(new TestAllTypesNanoHas())); // note: the new message has +0 defaults
+    assertTrue(m3.equals(
+        MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
+                                  // note: the deserialized message has the 'has' flags set.
+
+    // -------
+
+    TestNanoAccessors m5 = new TestNanoAccessors();
+    m5.setOptionalFloat(Float.NaN);
+    m5.setOptionalDouble(Double.NaN);
+    TestNanoAccessors m6 = new TestNanoAccessors();
+    m6.setOptionalFloat(Float.NaN);
+    m6.setOptionalDouble(Double.NaN);
+    assertTrue(m5.equals(m6));
+    assertTrue(m5.equals(
+        MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
+
+    m5.setOptionalFloat(+0f);
+    m6.setOptionalFloat(-0f);
+    assertFalse(m5.equals(m6));
+
+    m5.setOptionalFloat(-0f);
+    m5.setOptionalDouble(+0d);
+    m6.setOptionalDouble(-0d);
+    assertFalse(m5.equals(m6));
+
+    m5.setOptionalDouble(-0d);
+    assertTrue(m5.equals(m6));
+    assertFalse(m5.equals(new TestNanoAccessors()));
+    assertTrue(m5.equals(
+        MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
+
+    // -------
+
+    NanoReferenceTypes.TestAllTypesNano m7 = new NanoReferenceTypes.TestAllTypesNano();
+    m7.optionalFloat = Float.NaN;
+    m7.optionalDouble = Double.NaN;
+    NanoReferenceTypes.TestAllTypesNano m8 = new NanoReferenceTypes.TestAllTypesNano();
+    m8.optionalFloat = Float.NaN;
+    m8.optionalDouble = Double.NaN;
+    assertTrue(m7.equals(m8));
+    assertTrue(m7.equals(MessageNano.mergeFrom(
+        new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
+
+    m7.optionalFloat = +0f;
+    m8.optionalFloat = -0f;
+    assertFalse(m7.equals(m8));
+
+    m7.optionalFloat = -0f;
+    m7.optionalDouble = +0d;
+    m8.optionalDouble = -0d;
+    assertFalse(m7.equals(m8));
+
+    m7.optionalDouble = -0d;
+    assertTrue(m7.equals(m8));
+    assertFalse(m7.equals(new NanoReferenceTypes.TestAllTypesNano()));
+    assertTrue(m7.equals(MessageNano.mergeFrom(
+        new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
+  }
+
   public void testNullRepeatedFields() throws Exception {
     // Check that serialization after explicitly setting a repeated field
     // to null doesn't NPE.

+ 58 - 40
src/google/protobuf/compiler/javanano/javanano_primitive_field.cc

@@ -175,38 +175,6 @@ int FixedSize(FieldDescriptor::Type type) {
   return -1;
 }
 
-// Returns true if the field has a default value equal to NaN.
-bool IsDefaultNaN(const FieldDescriptor* field) {
-  switch (field->type()) {
-    case FieldDescriptor::TYPE_INT32   : return false;
-    case FieldDescriptor::TYPE_UINT32  : return false;
-    case FieldDescriptor::TYPE_SINT32  : return false;
-    case FieldDescriptor::TYPE_FIXED32 : return false;
-    case FieldDescriptor::TYPE_SFIXED32: return false;
-    case FieldDescriptor::TYPE_INT64   : return false;
-    case FieldDescriptor::TYPE_UINT64  : return false;
-    case FieldDescriptor::TYPE_SINT64  : return false;
-    case FieldDescriptor::TYPE_FIXED64 : return false;
-    case FieldDescriptor::TYPE_SFIXED64: return false;
-    case FieldDescriptor::TYPE_FLOAT   :
-      return isnan(field->default_value_float());
-    case FieldDescriptor::TYPE_DOUBLE  :
-      return isnan(field->default_value_double());
-    case FieldDescriptor::TYPE_BOOL    : return false;
-    case FieldDescriptor::TYPE_STRING  : return false;
-    case FieldDescriptor::TYPE_BYTES   : return false;
-    case FieldDescriptor::TYPE_ENUM    : return false;
-    case FieldDescriptor::TYPE_GROUP   : return false;
-    case FieldDescriptor::TYPE_MESSAGE : return false;
-
-    // No default because we want the compiler to complain if any new
-    // types are added.
-  }
-
-  GOOGLE_LOG(FATAL) << "Can't get here.";
-  return false;
-}
-
 // Return true if the type is a that has variable length
 // for instance String's.
 bool IsVariableLenType(JavaType type) {
@@ -384,15 +352,21 @@ GenerateSerializationConditional(io::Printer* printer) const {
     printer->Print(variables_,
       "if (");
   }
-  if (IsArrayType(GetJavaType(descriptor_))) {
+  JavaType java_type = GetJavaType(descriptor_);
+  if (IsArrayType(java_type)) {
     printer->Print(variables_,
       "!java.util.Arrays.equals(this.$name$, $default$)) {\n");
-  } else if (IsReferenceType(GetJavaType(descriptor_))) {
+  } else if (IsReferenceType(java_type)) {
     printer->Print(variables_,
       "!this.$name$.equals($default$)) {\n");
-  } else if (IsDefaultNaN(descriptor_)) {
+  } else if (java_type == JAVATYPE_FLOAT) {
     printer->Print(variables_,
-      "!$capitalized_type$.isNaN(this.$name$)) {\n");
+      "java.lang.Float.floatToIntBits(this.$name$)\n"
+      "    != java.lang.Float.floatToIntBits($default$)) {\n");
+  } else if (java_type == JAVATYPE_DOUBLE) {
+    printer->Print(variables_,
+      "java.lang.Double.doubleToLongBits(this.$name$)\n"
+      "    != java.lang.Double.doubleToLongBits($default$)) {\n");
   } else {
     printer->Print(variables_,
       "this.$name$ != $default$) {\n");
@@ -464,6 +438,36 @@ GenerateEqualsCode(io::Printer* printer) const {
     printer->Print(") {\n"
       "  return false;\n"
       "}\n");
+  } else if (java_type == JAVATYPE_FLOAT) {
+    printer->Print(variables_,
+      "{\n"
+      "  int bits = java.lang.Float.floatToIntBits(this.$name$);\n"
+      "  if (bits != java.lang.Float.floatToIntBits(other.$name$)");
+    if (params_.generate_has()) {
+      printer->Print(variables_,
+        "\n"
+        "      || (bits == java.lang.Float.floatToIntBits($default$)\n"
+        "          && this.has$capitalized_name$ != other.has$capitalized_name$)");
+    }
+    printer->Print(") {\n"
+      "    return false;\n"
+      "  }\n"
+      "}\n");
+  } else if (java_type == JAVATYPE_DOUBLE) {
+    printer->Print(variables_,
+      "{\n"
+      "  long bits = java.lang.Double.doubleToLongBits(this.$name$);\n"
+      "  if (bits != java.lang.Double.doubleToLongBits(other.$name$)");
+    if (params_.generate_has()) {
+      printer->Print(variables_,
+        "\n"
+        "      || (bits == java.lang.Double.doubleToLongBits($default$)\n"
+        "          && this.has$capitalized_name$ != other.has$capitalized_name$)");
+    }
+    printer->Print(") {\n"
+      "    return false;\n"
+      "  }\n"
+      "}\n");
   } else {
     printer->Print(variables_,
       "if (this.$name$ != other.$name$");
@@ -623,12 +627,26 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
 void AccessorPrimitiveFieldGenerator::
 GenerateEqualsCode(io::Printer* printer) const {
   switch (GetJavaType(descriptor_)) {
-    // For all Java primitive types below, the hash codes match the
-    // results of BoxedType.valueOf(primitiveValue).hashCode().
-    case JAVATYPE_INT:
-    case JAVATYPE_LONG:
+    // For all Java primitive types below, the equality checks match the
+    // results of BoxedType.valueOf(primitiveValue).equals(otherValue).
     case JAVATYPE_FLOAT:
+      printer->Print(variables_,
+        "if ($different_has$\n"
+        "    || java.lang.Float.floatToIntBits($name$_)\n"
+        "        != java.lang.Float.floatToIntBits(other.$name$_)) {\n"
+        "  return false;\n"
+        "}\n");
+      break;
     case JAVATYPE_DOUBLE:
+      printer->Print(variables_,
+        "if ($different_has$\n"
+        "    || java.lang.Double.doubleToLongBits($name$_)\n"
+        "        != java.lang.Double.doubleToLongBits(other.$name$_)) {\n"
+        "  return false;\n"
+        "}\n");
+      break;
+    case JAVATYPE_INT:
+    case JAVATYPE_LONG:
     case JAVATYPE_BOOLEAN:
       printer->Print(variables_,
         "if ($different_has$\n"

+ 2 - 0
src/google/protobuf/unittest_accessors_nano.proto

@@ -49,6 +49,8 @@ message TestNanoAccessors {
 
   // Singular
   optional int32  optional_int32    =  1;
+  optional float  optional_float    = 11;
+  optional double optional_double   = 12;
   optional string optional_string   = 14;
   optional bytes  optional_bytes    = 15;
 

+ 2 - 0
src/google/protobuf/unittest_has_nano.proto

@@ -49,6 +49,8 @@ message TestAllTypesNanoHas {
 
   // Singular
   optional int32  optional_int32    =  1;
+  optional float  optional_float    = 11;
+  optional double optional_double   = 12;
   optional string optional_string   = 14;
   optional bytes  optional_bytes    = 15;