Преглед на файлове

Merge pull request #896 from jhump/jh/fix-ioexception-vs-invalidprotobuf-exception

throw IOException instead of InvalidProtocolBufferException when appropriate
Feng Xiao преди 9 години
родител
ревизия
bbe6e430f6

+ 56 - 4
java/core/src/main/java/com/google/protobuf/GeneratedMessage.java

@@ -36,15 +36,13 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor;
 import com.google.protobuf.Descriptors.FieldDescriptor;
 import com.google.protobuf.Descriptors.FieldDescriptor;
 import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.Descriptors.OneofDescriptor;
 import com.google.protobuf.Descriptors.OneofDescriptor;
-import com.google.protobuf.GeneratedMessageLite.ExtendableMessage;
-import com.google.protobuf.GeneratedMessageLite.GeneratedExtension;
 
 
 import java.io.IOException;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.ObjectStreamException;
 import java.io.ObjectStreamException;
 import java.io.Serializable;
 import java.io.Serializable;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.lang.reflect.Method;
-import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Iterator;
@@ -276,6 +274,60 @@ public abstract class GeneratedMessage extends AbstractMessage
     return unknownFields.mergeFieldFrom(tag, input);
     return unknownFields.mergeFieldFrom(tag, input);
   }
   }
 
 
+  protected static <M extends Message> M parseWithIOException(Parser<M> parser, InputStream input)
+      throws IOException {
+    try {
+      return parser.parseFrom(input);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
+  protected static <M extends Message> M parseWithIOException(Parser<M> parser, InputStream input,
+      ExtensionRegistryLite extensions) throws IOException {
+    try {
+      return parser.parseFrom(input, extensions);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
+  protected static <M extends Message> M parseWithIOException(Parser<M> parser,
+      CodedInputStream input) throws IOException {
+    try {
+      return parser.parseFrom(input);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
+  protected static <M extends Message> M parseWithIOException(Parser<M> parser,
+      CodedInputStream input, ExtensionRegistryLite extensions) throws IOException {
+    try {
+      return parser.parseFrom(input, extensions);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
+  protected static <M extends Message> M parseDelimitedWithIOException(Parser<M> parser,
+      InputStream input) throws IOException {
+    try {
+      return parser.parseDelimitedFrom(input);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
+  protected static <M extends Message> M parseDelimitedWithIOException(Parser<M> parser,
+      InputStream input, ExtensionRegistryLite extensions) throws IOException {
+    try {
+      return parser.parseDelimitedFrom(input, extensions);
+    } catch (InvalidProtocolBufferException e) {
+      throw e.unwrapIOException();
+    }
+  }
+
   @Override
   @Override
   public void writeTo(final CodedOutputStream output) throws IOException {
   public void writeTo(final CodedOutputStream output) throws IOException {
     MessageReflection.writeMessageTo(this, getAllFieldsRaw(), output, false);
     MessageReflection.writeMessageTo(this, getAllFieldsRaw(), output, false);
@@ -667,7 +719,7 @@ public abstract class GeneratedMessage extends AbstractMessage
           "No map fields found in " + getClass().getName());
           "No map fields found in " + getClass().getName());
     }
     }
 
 
-    /** Like {@link internalGetMapField} but return a mutable version. */
+    /** Like {@link #internalGetMapField} but return a mutable version. */
     @SuppressWarnings({"unused", "rawtypes"})
     @SuppressWarnings({"unused", "rawtypes"})
     protected MapField internalGetMutableMapField(int fieldNumber) {
     protected MapField internalGetMutableMapField(int fieldNumber) {
       // Note that we can't use descriptor names here because this method will
       // Note that we can't use descriptor names here because this method will

+ 12 - 0
java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java

@@ -46,6 +46,10 @@ public class InvalidProtocolBufferException extends IOException {
     super(description);
     super(description);
   }
   }
 
 
+  public InvalidProtocolBufferException(IOException e) {
+    super(e.getMessage(), e);
+  }
+
   /**
   /**
    * Attaches an unfinished message to the exception to support best-effort
    * Attaches an unfinished message to the exception to support best-effort
    * parsing in {@code Parser} interface.
    * parsing in {@code Parser} interface.
@@ -66,6 +70,14 @@ public class InvalidProtocolBufferException extends IOException {
     return unfinishedMessage;
     return unfinishedMessage;
   }
   }
 
 
+  /**
+   * Unwraps the underlying {@link IOException} if this exception was caused by an I/O
+   * problem. Otherwise, returns {@code this}.
+   */
+  public IOException unwrapIOException() {
+    return getCause() instanceof IOException ? (IOException) getCause() : this;
+  }
+
   static InvalidProtocolBufferException truncatedMessage() {
   static InvalidProtocolBufferException truncatedMessage() {
     return new InvalidProtocolBufferException(
     return new InvalidProtocolBufferException(
       "While parsing a protocol message, the input ended unexpectedly " +
       "While parsing a protocol message, the input ended unexpectedly " +

+ 12 - 0
java/core/src/main/java/com/google/protobuf/Parser.java

@@ -30,6 +30,7 @@
 
 
 package com.google.protobuf;
 package com.google.protobuf;
 
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStream;
 
 
 /**
 /**
@@ -37,9 +38,20 @@ import java.io.InputStream;
  *
  *
  * The implementation should be stateless and thread-safe.
  * The implementation should be stateless and thread-safe.
  *
  *
+ * <p>All methods may throw {@link InvalidProtocolBufferException}. In the event of invalid data,
+ * like an encoding error, the cause of the thrown exception will be {@code null}. However, if an
+ * I/O problem occurs, an exception is thrown with an {@link IOException} cause.
+ *
  * @author liujisi@google.com (Pherl Liu)
  * @author liujisi@google.com (Pherl Liu)
  */
  */
 public interface Parser<MessageType> {
 public interface Parser<MessageType> {
+
+  // NB(jh): Other parts of the protobuf API that parse messages distinguish between an I/O problem
+  // (like failure reading bytes from a socket) and invalid data (encoding error) via the type of
+  // thrown exception. But it would be source-incompatible to make the methods in this interface do
+  // so since they were originally spec'ed to only throw InvalidProtocolBufferException. So callers
+  // must inspect the cause of the exception to distinguish these two cases.
+
   /**
   /**
    * Parses a message of {@code MessageType} from the input.
    * Parses a message of {@code MessageType} from the input.
    *
    *

+ 211 - 0
java/core/src/test/java/com/google/protobuf/ParseExceptionsTest.java

@@ -0,0 +1,211 @@
+package com.google.protobuf;
+
+import com.google.protobuf.DescriptorProtos.DescriptorProto;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests the exceptions thrown when parsing from a stream. The methods on the {@link Parser}
+ * interface are specified to only throw {@link InvalidProtocolBufferException}. But we really want
+ * to distinguish between invalid protos vs. actual I/O errors (like failures reading from a
+ * socket, etc.). So, when we're not using the parser directly, an {@link IOException} should be
+ * thrown where appropriate, instead of always an {@link InvalidProtocolBufferException}.
+ *
+ * @author jh@squareup.com (Joshua Humphries)
+ */
+public class ParseExceptionsTest {
+
+  private interface ParseTester {
+    DescriptorProto parse(InputStream in) throws IOException;
+  }
+
+  private byte serializedProto[];
+
+  private void setup() {
+    serializedProto = DescriptorProto.getDescriptor().toProto().toByteArray();
+  }
+
+  private void setupDelimited() {
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    try {
+      DescriptorProto.getDescriptor().toProto().writeDelimitedTo(bos);
+    } catch (IOException e) {
+      fail("Exception not expected: " + e);
+    }
+    serializedProto = bos.toByteArray();
+  }
+
+  @Test public void message_parseFrom_InputStream() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseFrom(in);
+      }
+    });
+  }
+
+  @Test public void message_parseFrom_InputStreamAndExtensionRegistry() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseFrom(in, ExtensionRegistry.newInstance());
+      }
+    });
+  }
+
+  @Test public void message_parseFrom_CodedInputStream() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseFrom(CodedInputStream.newInstance(in));
+      }
+    });
+  }
+
+  @Test public void message_parseFrom_CodedInputStreamAndExtensionRegistry() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseFrom(CodedInputStream.newInstance(in),
+            ExtensionRegistry.newInstance());
+      }
+    });
+  }
+
+  @Test public void message_parseDelimitedFrom_InputStream() {
+    setupDelimited();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseDelimitedFrom(in);
+      }
+    });
+  }
+
+  @Test public void message_parseDelimitedFrom_InputStreamAndExtensionRegistry() {
+    setupDelimited();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.parseDelimitedFrom(in, ExtensionRegistry.newInstance());
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeFrom_InputStream() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.newBuilder().mergeFrom(in).build();
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeFrom_InputStreamAndExtensionRegistry() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.newBuilder().mergeFrom(in, ExtensionRegistry.newInstance()).build();
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeFrom_CodedInputStream() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.newBuilder().mergeFrom(CodedInputStream.newInstance(in)).build();
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeFrom_CodedInputStreamAndExtensionRegistry() {
+    setup();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        return DescriptorProto.newBuilder()
+            .mergeFrom(CodedInputStream.newInstance(in), ExtensionRegistry.newInstance()).build();
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeDelimitedFrom_InputStream() {
+    setupDelimited();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        DescriptorProto.Builder builder = DescriptorProto.newBuilder();
+        builder.mergeDelimitedFrom(in);
+        return builder.build();
+      }
+    });
+  }
+
+  @Test public void messageBuilder_mergeDelimitedFrom_InputStreamAndExtensionRegistry() {
+    setupDelimited();
+    verifyExceptions(new ParseTester() {
+      public DescriptorProto parse(InputStream in) throws IOException {
+        DescriptorProto.Builder builder = DescriptorProto.newBuilder();
+        builder.mergeDelimitedFrom(in, ExtensionRegistry.newInstance());
+        return builder.build();
+      }
+    });
+  }
+
+  private void verifyExceptions(ParseTester parseTester) {
+    // No exception
+    try {
+      assertEquals(DescriptorProto.getDescriptor().toProto(),
+          parseTester.parse(new ByteArrayInputStream(serializedProto)));
+    } catch (IOException e) {
+      fail("No exception expected: " + e);
+    }
+
+    // IOException
+    try {
+      // using a "broken" stream that will throw part-way through reading the message
+      parseTester.parse(broken(new ByteArrayInputStream(serializedProto)));
+      fail("IOException expected but not thrown");
+    } catch (IOException e) {
+      assertFalse(e instanceof InvalidProtocolBufferException);
+    }
+
+    // InvalidProtocolBufferException
+    try {
+      // make the serialized proto invalid
+      for (int i = 0; i < 50; i++) {
+        serializedProto[i] = -1;
+      }
+      parseTester.parse(new ByteArrayInputStream(serializedProto));
+      fail("InvalidProtocolBufferException expected but not thrown");
+    } catch (IOException e) {
+      assertTrue(e instanceof InvalidProtocolBufferException);
+    }
+  }
+
+  private InputStream broken(InputStream i) {
+    return new FilterInputStream(i) {
+      int count = 0;
+
+      @Override public int read() throws IOException {
+        if (count++ >= 50) {
+          throw new IOException("I'm broken!");
+        }
+        return super.read();
+      }
+
+      @Override public int read(byte b[], int off, int len) throws IOException {
+        if ((count += len) >= 50) {
+          throw new IOException("I'm broken!");
+        }
+        return super.read(b, off, len);
+      }
+    };
+  }
+}

+ 8 - 9
src/google/protobuf/compiler/java/java_message.cc

@@ -664,34 +664,34 @@ GenerateParseFromMethods(io::Printer* printer) {
     "}\n"
     "}\n"
     "public static $classname$ parseFrom(java.io.InputStream input)\n"
     "public static $classname$ parseFrom(java.io.InputStream input)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseFrom(input);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseWithIOException(PARSER, input);"
     "}\n"
     "}\n"
     "public static $classname$ parseFrom(\n"
     "public static $classname$ parseFrom(\n"
     "    java.io.InputStream input,\n"
     "    java.io.InputStream input,\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseFrom(input, extensionRegistry);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseWithIOException(PARSER, input, extensionRegistry);"
     "}\n"
     "}\n"
     "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
     "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseDelimitedFrom(input);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseDelimitedWithIOException(PARSER, input);"
     "}\n"
     "}\n"
     "public static $classname$ parseDelimitedFrom(\n"
     "public static $classname$ parseDelimitedFrom(\n"
     "    java.io.InputStream input,\n"
     "    java.io.InputStream input,\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseDelimitedFrom(input, extensionRegistry);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseDelimitedWithIOException(PARSER, input, extensionRegistry);"
     "}\n"
     "}\n"
     "public static $classname$ parseFrom(\n"
     "public static $classname$ parseFrom(\n"
     "    com.google.protobuf.CodedInputStream input)\n"
     "    com.google.protobuf.CodedInputStream input)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseFrom(input);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseWithIOException(PARSER, input);"
     "}\n"
     "}\n"
     "public static $classname$ parseFrom(\n"
     "public static $classname$ parseFrom(\n"
     "    com.google.protobuf.CodedInputStream input,\n"
     "    com.google.protobuf.CodedInputStream input,\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    throws java.io.IOException {\n"
     "    throws java.io.IOException {\n"
-    "  return PARSER.parseFrom(input, extensionRegistry);\n"
+    "  return com.google.protobuf.GeneratedMessage.parseWithIOException(PARSER, input, extensionRegistry);"
     "}\n"
     "}\n"
     "\n",
     "\n",
     "classname", name_resolver_->GetImmutableClassName(descriptor_));
     "classname", name_resolver_->GetImmutableClassName(descriptor_));
@@ -1217,9 +1217,8 @@ GenerateParsingConstructor(io::Printer* printer) {
       "} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
       "} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
       "  throw new RuntimeException(e.setUnfinishedMessage(this));\n"
       "  throw new RuntimeException(e.setUnfinishedMessage(this));\n"
       "} catch (java.io.IOException e) {\n"
       "} catch (java.io.IOException e) {\n"
-      "  throw new RuntimeException(\n"
-      "      new com.google.protobuf.InvalidProtocolBufferException(\n"
-      "          e.getMessage()).setUnfinishedMessage(this));\n"
+      "  throw new RuntimeException(new com.google.protobuf.InvalidProtocolBufferException(e)\n"
+      "      .setUnfinishedMessage(this));\n"
       "} finally {\n");
       "} finally {\n");
   printer->Indent();
   printer->Indent();
 
 

+ 1 - 1
src/google/protobuf/compiler/java/java_message_builder.cc

@@ -538,7 +538,7 @@ GenerateBuilderParsingMethods(io::Printer* printer) {
     "    parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n"
     "    parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n"
     "  } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
     "  } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
     "    parsedMessage = ($classname$) e.getUnfinishedMessage();\n"
     "    parsedMessage = ($classname$) e.getUnfinishedMessage();\n"
-    "    throw e;\n"
+    "    throw e.unwrapIOException();\n"
     "  } finally {\n"
     "  } finally {\n"
     "    if (parsedMessage != null) {\n"
     "    if (parsedMessage != null) {\n"
     "      mergeFrom(parsedMessage);\n"
     "      mergeFrom(parsedMessage);\n"