Browse Source

Protect against null repeated fields.

There's no distinction between a repeated field being null and being
empty. In both cases, nothing is sent on the wire. Clients might for
whatever reason inadvertently set a repeated field to null, so
protect against that and treat it just as if the field was empty.

Change-Id: Ic3846f7f2189d6cfff6f8ef3ca217daecc3c8be7
Brian Duff 12 years ago
parent
commit
34547de99c

+ 1 - 1
java/src/main/java/com/google/protobuf/nano/MessageNanoPrinter.java

@@ -105,7 +105,7 @@ public final class MessageNanoPrinter {
                     if (arrayType == byte.class) {
                         print(fieldName, fieldType, value, indentBuf, buf);
                     } else {
-                        int len = Array.getLength(value);
+                        int len = value == null ? 0 : Array.getLength(value);
                         for (int i = 0; i < len; i++) {
                             Object elem = Array.get(value, i);
                             print(fieldName, arrayType, elem, indentBuf, buf);

+ 52 - 2
java/src/test/java/com/google/protobuf/NanoTest.java

@@ -2677,13 +2677,63 @@ public class NanoTest extends TestCase {
     assertHasWireData(message, false);
   }
 
+  public void testNullRepeatedFields() throws Exception {
+    // Check that serialization after explicitly setting a repeated field
+    // to null doesn't NPE.
+    TestAllTypesNano message = new TestAllTypesNano();
+    message.repeatedInt32 = null;
+    MessageNano.toByteArray(message);  // should not NPE
+    message.toString(); // should not NPE
+
+    message = new TestAllTypesNano();
+    message.repeatedNestedEnum = null;
+    MessageNano.toByteArray(message);  // should not NPE
+    message.toString(); // should not NPE
+
+    message = new TestAllTypesNano();
+    message.repeatedBytes = null;
+    MessageNano.toByteArray(message); // should not NPE
+    message.toString(); // should not NPE
+
+    message = new TestAllTypesNano();
+    message.repeatedNestedMessage = null;
+    MessageNano.toByteArray(message); // should not NPE
+    message.toString(); // should not NPE
+
+    // Create a second message to merge into message.
+    TestAllTypesNano secondMessage = new TestAllTypesNano();
+    TestAllTypesNano.NestedMessage nested =
+        new TestAllTypesNano.NestedMessage();
+    nested.bb = 55;
+    secondMessage.repeatedNestedMessage =
+        new TestAllTypesNano.NestedMessage[] { nested };
+
+    // Should not NPE
+    message.mergeFrom(CodedInputByteBufferNano.newInstance(
+        MessageNano.toByteArray(secondMessage)));
+    assertEquals(55, message.repeatedNestedMessage[0].bb);
+  }
+
   private void assertHasWireData(MessageNano message, boolean expected) {
-    int wireLength = MessageNano.toByteArray(message).length;
+    byte[] bytes = MessageNano.toByteArray(message);
+    int wireLength = bytes.length;
     if (expected) {
       assertFalse(wireLength == 0);
     } else {
-      assertEquals(0, wireLength);
+      if (wireLength != 0) {
+        fail("Expected no wire data for message \n" + message
+            + "\nBut got:\n"
+            + hexDump(bytes));
+      }
+    }
+  }
+
+  private static String hexDump(byte[] bytes) {
+    StringBuilder sb = new StringBuilder();
+    for (byte b : bytes) {
+      sb.append(String.format("%02x ", b));
     }
+    return sb.toString();
   }
 
   private <T> List<T> list(T first, T... remaining) {

+ 2 - 2
src/google/protobuf/compiler/javanano/javanano_enum_field.cc

@@ -293,7 +293,7 @@ GenerateMergingCode(io::Printer* printer) const {
 void RepeatedEnumFieldGenerator::
 GenerateSerializationCode(io::Printer* printer) const {
   printer->Print(variables_,
-    "if (this.$name$.length > 0) {\n");
+    "if (this.$name$ != null && this.$name$.length > 0) {\n");
   printer->Indent();
 
   if (descriptor_->options().packed()) {
@@ -317,7 +317,7 @@ GenerateSerializationCode(io::Printer* printer) const {
 void RepeatedEnumFieldGenerator::
 GenerateSerializedSizeCode(io::Printer* printer) const {
   printer->Print(variables_,
-    "if (this.$name$.length > 0) {\n");
+    "if (this.$name$ != null && this.$name$.length > 0) {\n");
   printer->Indent();
 
   printer->Print(variables_,

+ 13 - 7
src/google/protobuf/compiler/javanano/javanano_message_field.cc

@@ -233,9 +233,11 @@ GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
     "int arrayLength = com.google.protobuf.nano.WireFormatNano"
     "    .getRepeatedFieldArrayLength(input, $tag$);\n"
-    "int i = this.$name$.length;\n"
+    "int i = this.$name$ == null ? 0 : this.$name$.length;\n"
     "$type$[] newArray = new $type$[i + arrayLength];\n"
-    "System.arraycopy(this.$name$, 0, newArray, 0, i);\n"
+    "if (this.$name$ != null) {\n"
+    "  System.arraycopy(this.$name$, 0, newArray, 0, i);\n"
+    "}\n"
     "this.$name$ = newArray;\n"
     "for (; i < this.$name$.length - 1; i++) {\n"
     "  this.$name$[i] = new $type$();\n");
@@ -266,17 +268,21 @@ GenerateMergingCode(io::Printer* printer) const {
 void RepeatedMessageFieldGenerator::
 GenerateSerializationCode(io::Printer* printer) const {
   printer->Print(variables_,
-    "for ($type$ element : this.$name$) {\n"
-    "  output.write$group_or_message$($number$, element);\n"
+    "if (this.$name$ != null) {\n"
+    "  for ($type$ element : this.$name$) {\n"
+    "    output.write$group_or_message$($number$, element);\n"
+    "  }\n"
     "}\n");
 }
 
 void RepeatedMessageFieldGenerator::
 GenerateSerializedSizeCode(io::Printer* printer) const {
   printer->Print(variables_,
-    "for ($type$ element : this.$name$) {\n"
-    "  size += com.google.protobuf.nano.CodedOutputByteBufferNano\n"
-    "    .compute$group_or_message$Size($number$, element);\n"
+    "if (this.$name$ != null) {\n"
+    "  for ($type$ element : this.$name$) {\n"
+    "    size += com.google.protobuf.nano.CodedOutputByteBufferNano\n"
+    "      .compute$group_or_message$Size($number$, element);\n"
+    "  }\n"
     "}\n");
 }
 

+ 10 - 9
src/google/protobuf/compiler/javanano/javanano_primitive_field.cc

@@ -570,17 +570,15 @@ GenerateRepeatedDataSizeCode(io::Printer* printer) const {
 
 void RepeatedPrimitiveFieldGenerator::
 GenerateSerializationCode(io::Printer* printer) const {
+  printer->Print(variables_,
+    "if (this.$name$ != null && this.$name$.length > 0) {\n");
+  printer->Indent();
+
   if (descriptor_->options().packed()) {
-    printer->Print(variables_,
-      "if (this.$name$.length > 0) {\n");
-    printer->Indent();
     GenerateRepeatedDataSizeCode(printer);
-    printer->Outdent();
-    printer->Print(variables_,
-      "  output.writeRawVarint32($tag$);\n"
-      "  output.writeRawVarint32(dataSize);\n"
-      "}\n");
     printer->Print(variables_,
+      "output.writeRawVarint32($tag$);\n"
+      "output.writeRawVarint32(dataSize);\n"
       "for ($type$ element : this.$name$) {\n"
       "  output.write$capitalized_type$NoTag(element);\n"
       "}\n");
@@ -590,12 +588,15 @@ GenerateSerializationCode(io::Printer* printer) const {
       "  output.write$capitalized_type$($number$, element);\n"
       "}\n");
   }
+
+  printer->Outdent();
+  printer->Print("}\n");
 }
 
 void RepeatedPrimitiveFieldGenerator::
 GenerateSerializedSizeCode(io::Printer* printer) const {
   printer->Print(variables_,
-    "if (this.$name$.length > 0) {\n");
+    "if (this.$name$ != null && this.$name$.length > 0) {\n");
   printer->Indent();
 
   GenerateRepeatedDataSizeCode(printer);