Pārlūkot izejas kodu

fix equal and hash for bytes field for javanano oneof

Jisi Liu 10 gadi atpakaļ
vecāks
revīzija
7b72a24a20

+ 38 - 16
src/google/protobuf/compiler/javanano/javanano_field.cc

@@ -168,26 +168,48 @@ void SetCommonOneofVariables(const FieldDescriptor* descriptor,
       SimpleItoa(descriptor->number());
 }
 
-void GenerateOneofFieldEquals(const map<string, string>& variables,
+void GenerateOneofFieldEquals(const FieldDescriptor* descriptor,
+                              const map<string, string>& variables,
                               io::Printer* printer) {
-  printer->Print(variables,
-    "if (this.has$capitalized_name$()) {\n"
-    "  if (!this.$oneof_name$_.equals(other.$oneof_name$_)) {\n"
-    "    return false;\n"
-    "  }\n"
-    "} else {\n"
-    "  if (other.has$capitalized_name$()) {\n"
-    "    return false;\n"
-    "  }\n"
-    "}\n");
-
+  if (GetJavaType(descriptor) == JAVATYPE_BYTES) {
+    printer->Print(variables,
+      "if (this.has$capitalized_name$()) {\n"
+      "  if (!other.has$capitalized_name$() ||\n"
+      "      !java.util.Arrays.equals((byte[]) this.$oneof_name$_,\n"
+      "                               (byte[]) other.$oneof_name$_)) {\n"
+      "    return false;\n"
+      "  }\n"
+      "} else {\n"
+      "  if (other.has$capitalized_name$()) {\n"
+      "    return false;\n"
+      "  }\n"
+      "}\n");
+  } else {
+    printer->Print(variables,
+      "if (this.has$capitalized_name$()) {\n"
+      "  if (!this.$oneof_name$_.equals(other.$oneof_name$_)) {\n"
+      "    return false;\n"
+      "  }\n"
+      "} else {\n"
+      "  if (other.has$capitalized_name$()) {\n"
+      "    return false;\n"
+      "  }\n"
+      "}\n");
+  }
 }
 
-void GenerateOneofFieldHashCode(const map<string, string>& variables,
+void GenerateOneofFieldHashCode(const FieldDescriptor* descriptor,
+                                const map<string, string>& variables,
                                 io::Printer* printer) {
-  printer->Print(variables,
-    "result = 31 * result +\n"
-    "  ($has_oneof_case$ ? this.$oneof_name$_.hashCode() : 0);\n");
+  if (GetJavaType(descriptor) == JAVATYPE_BYTES) {
+    printer->Print(variables,
+      "result = 31 * result + ($has_oneof_case$\n"
+      "   ? java.util.Arrays.hashCode((byte[]) this.$oneof_name$_) : 0);\n");
+  } else {
+    printer->Print(variables,
+      "result = 31 * result +\n"
+      "  ($has_oneof_case$ ? this.$oneof_name$_.hashCode() : 0);\n");
+  }
 }
 
 }  // namespace javanano

+ 4 - 2
src/google/protobuf/compiler/javanano/javanano_field.h

@@ -114,9 +114,11 @@ class FieldGeneratorMap {
 
 void SetCommonOneofVariables(const FieldDescriptor* descriptor,
                              map<string, string>* variables);
-void GenerateOneofFieldEquals(const map<string, string>& variables,
+void GenerateOneofFieldEquals(const FieldDescriptor* descriptor,
+                              const map<string, string>& variables,
                               io::Printer* printer);
-void GenerateOneofFieldHashCode(const map<string, string>& variables,
+void GenerateOneofFieldHashCode(const FieldDescriptor* descriptor,
+                                const map<string, string>& variables,
                                 io::Printer* printer);
 
 }  // namespace javanano

+ 2 - 2
src/google/protobuf/compiler/javanano/javanano_message_field.cc

@@ -214,12 +214,12 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
 
 void MessageOneofFieldGenerator::
 GenerateEqualsCode(io::Printer* printer) const {
-  GenerateOneofFieldEquals(variables_, printer);
+  GenerateOneofFieldEquals(descriptor_, variables_, printer);
 }
 
 void MessageOneofFieldGenerator::
 GenerateHashCodeCode(io::Printer* printer) const {
-  GenerateOneofFieldHashCode(variables_, printer);
+  GenerateOneofFieldHashCode(descriptor_, variables_, printer);
 }
 
 // ===================================================================

+ 2 - 2
src/google/protobuf/compiler/javanano/javanano_primitive_field.cc

@@ -767,12 +767,12 @@ void PrimitiveOneofFieldGenerator::GenerateSerializedSizeCode(
 
 void PrimitiveOneofFieldGenerator::GenerateEqualsCode(
     io::Printer* printer) const {
-  GenerateOneofFieldEquals(variables_, printer);
+  GenerateOneofFieldEquals(descriptor_, variables_, printer);
 }
 
 void PrimitiveOneofFieldGenerator::GenerateHashCodeCode(
     io::Printer* printer) const {
-  GenerateOneofFieldHashCode(variables_, printer);
+  GenerateOneofFieldHashCode(descriptor_, variables_, printer);
 }
 
 // ===================================================================