瀏覽代碼

Massive roll-up of changes. See CHANGES.txt.

kenton@google.com 16 年之前
父節點
當前提交
fccb146e3f
共有 100 個文件被更改,包括 5803 次插入和 4131 次删除
  1. 46 0
      CHANGES.txt
  2. 0 6
      Makefile.am
  3. 10 0
      autogen.sh
  4. 3 1
      generate_descriptor_proto.sh
  5. 1 0
      java/pom.xml
  6. 27 8
      java/src/main/java/com/google/protobuf/AbstractMessage.java
  7. 13 15
      java/src/main/java/com/google/protobuf/AbstractMessageLite.java
  8. 18 0
      java/src/main/java/com/google/protobuf/ByteString.java
  9. 23 4
      java/src/main/java/com/google/protobuf/CodedInputStream.java
  10. 63 55
      java/src/main/java/com/google/protobuf/Descriptors.java
  11. 5 0
      java/src/main/java/com/google/protobuf/ExtensionRegistry.java
  12. 4 0
      java/src/main/java/com/google/protobuf/GeneratedMessage.java
  13. 23 8
      java/src/main/java/com/google/protobuf/GeneratedMessageLite.java
  14. 2 2
      java/src/main/java/com/google/protobuf/Message.java
  15. 6 2
      java/src/main/java/com/google/protobuf/MessageLite.java
  16. 1 1
      java/src/main/java/com/google/protobuf/TextFormat.java
  17. 13 7
      java/src/main/java/com/google/protobuf/UnknownFieldSet.java
  18. 14 4
      java/src/main/java/com/google/protobuf/WireFormat.java
  19. 38 0
      java/src/test/java/com/google/protobuf/AbstractMessageTest.java
  20. 14 0
      java/src/test/java/com/google/protobuf/CodedInputStreamTest.java
  21. 51 0
      java/src/test/java/com/google/protobuf/DescriptorsTest.java
  22. 33 2
      java/src/test/java/com/google/protobuf/GeneratedMessageTest.java
  23. 47 0
      java/src/test/java/com/google/protobuf/ServiceTest.java
  24. 92 0
      java/src/test/java/com/google/protobuf/TestUtil.java
  25. 19 3
      java/src/test/java/com/google/protobuf/TextFormatTest.java
  26. 3 0
      java/src/test/java/com/google/protobuf/WireFormatTest.java
  27. 184 27
      python/google/protobuf/descriptor.py
  28. 43 26
      python/google/protobuf/internal/containers.py
  29. 601 169
      python/google/protobuf/internal/decoder.py
  30. 0 256
      python/google/protobuf/internal/decoder_test.py
  31. 224 3
      python/google/protobuf/internal/descriptor_test.py
  32. 647 241
      python/google/protobuf/internal/encoder.py
  33. 0 286
      python/google/protobuf/internal/encoder_test.py
  34. 106 1
      python/google/protobuf/internal/generator_test.py
  35. 0 338
      python/google/protobuf/internal/input_stream.py
  36. 0 314
      python/google/protobuf/internal/input_stream_test.py
  37. 25 16
      python/google/protobuf/internal/message_listener.py
  38. 39 3
      python/google/protobuf/internal/message_test.py
  39. 0 125
      python/google/protobuf/internal/output_stream.py
  40. 0 178
      python/google/protobuf/internal/output_stream_test.py
  41. 307 47
      python/google/protobuf/internal/reflection_test.py
  42. 245 222
      python/google/protobuf/internal/test_util.py
  43. 22 3
      python/google/protobuf/internal/text_format_test.py
  44. 63 64
      python/google/protobuf/internal/type_checkers.py
  45. 24 3
      python/google/protobuf/internal/wire_format.py
  46. 10 1
      python/google/protobuf/message.py
  47. 400 733
      python/google/protobuf/reflection.py
  48. 5 0
      python/google/protobuf/text_format.py
  49. 1 10
      python/setup.py
  50. 27 2
      src/Makefile.am
  51. 8 1
      src/google/protobuf/compiler/code_generator.cc
  52. 8 1
      src/google/protobuf/compiler/code_generator.h
  53. 432 36
      src/google/protobuf/compiler/command_line_interface.cc
  54. 62 3
      src/google/protobuf/compiler/command_line_interface.h
  55. 271 287
      src/google/protobuf/compiler/command_line_interface_unittest.cc
  56. 10 1
      src/google/protobuf/compiler/cpp/cpp_bootstrap_unittest.cc
  57. 10 4
      src/google/protobuf/compiler/cpp/cpp_enum.cc
  58. 45 26
      src/google/protobuf/compiler/cpp/cpp_enum_field.cc
  59. 1 0
      src/google/protobuf/compiler/cpp/cpp_enum_field.h
  60. 17 4
      src/google/protobuf/compiler/cpp/cpp_extension.cc
  61. 26 3
      src/google/protobuf/compiler/cpp/cpp_field.cc
  62. 6 0
      src/google/protobuf/compiler/cpp/cpp_field.h
  63. 45 14
      src/google/protobuf/compiler/cpp/cpp_file.cc
  64. 44 9
      src/google/protobuf/compiler/cpp/cpp_helpers.cc
  65. 19 5
      src/google/protobuf/compiler/cpp/cpp_helpers.h
  66. 72 27
      src/google/protobuf/compiler/cpp/cpp_message.cc
  67. 1 0
      src/google/protobuf/compiler/cpp/cpp_message.h
  68. 21 15
      src/google/protobuf/compiler/cpp/cpp_message_field.cc
  69. 41 37
      src/google/protobuf/compiler/cpp/cpp_primitive_field.cc
  70. 1 0
      src/google/protobuf/compiler/cpp/cpp_primitive_field.h
  71. 19 16
      src/google/protobuf/compiler/cpp/cpp_string_field.cc
  72. 69 4
      src/google/protobuf/compiler/cpp/cpp_unittest.cc
  73. 5 0
      src/google/protobuf/compiler/java/java_enum.cc
  74. 29 17
      src/google/protobuf/compiler/java/java_enum_field.cc
  75. 3 0
      src/google/protobuf/compiler/java/java_enum_field.h
  76. 1 2
      src/google/protobuf/compiler/java/java_extension.cc
  77. 10 0
      src/google/protobuf/compiler/java/java_field.cc
  78. 2 0
      src/google/protobuf/compiler/java/java_field.h
  79. 29 11
      src/google/protobuf/compiler/java/java_file.cc
  80. 5 0
      src/google/protobuf/compiler/java/java_file.h
  81. 1 0
      src/google/protobuf/compiler/java/java_generator.cc
  82. 34 9
      src/google/protobuf/compiler/java/java_helpers.cc
  83. 13 5
      src/google/protobuf/compiler/java/java_helpers.h
  84. 62 21
      src/google/protobuf/compiler/java/java_message.cc
  85. 14 4
      src/google/protobuf/compiler/java/java_message_field.cc
  86. 2 0
      src/google/protobuf/compiler/java/java_message_field.h
  87. 29 17
      src/google/protobuf/compiler/java/java_primitive_field.cc
  88. 3 0
      src/google/protobuf/compiler/java/java_primitive_field.h
  89. 1 0
      src/google/protobuf/compiler/main.cc
  90. 10 1
      src/google/protobuf/compiler/parser.cc
  91. 6 0
      src/google/protobuf/compiler/parser_unittest.cc
  92. 197 30
      src/google/protobuf/compiler/python/python_generator.cc
  93. 14 2
      src/google/protobuf/compiler/python/python_generator.h
  94. 8 12
      src/google/protobuf/descriptor.cc
  95. 19 4
      src/google/protobuf/descriptor.h
  96. 270 218
      src/google/protobuf/descriptor.pb.cc
  97. 211 98
      src/google/protobuf/descriptor.pb.h
  98. 21 1
      src/google/protobuf/descriptor.proto
  99. 30 0
      src/google/protobuf/descriptor_database.cc
  100. 4 0
      src/google/protobuf/descriptor_database.h

+ 46 - 0
CHANGES.txt

@@ -1,3 +1,49 @@
+2009-12-17 version 2.3.0:
+
+  General
+  * Parsers for repeated numeric fields now always accept both packed and
+    unpacked input.  The [packed=true] option only affects serializers.
+    Therefore, it is possible to switch a field to packed format without
+    breaking backwards-compatibility -- as long as all parties are using
+    protobuf 2.3.0 or above, at least.
+  * The generic RPC service code generated by the C++, Java, and Python
+    generators can be disabled via file options:
+      option cc_generic_services = false;
+      option java_generic_services = false;
+      option py_generic_services = false;
+    This allows plugins to generate alternative code, possibly specific to some
+    particular RPC implementation.
+
+  protoc
+  * Now supports a plugin system for code generators.  Plugins can generate
+    code for new languages or inject additional code into the output of other
+    code generators.  Plugins are just binaries which accept a protocol buffer
+    on stdin and write a protocol buffer to stdout, so they may be written in
+    any language.  See src/google/protobuf/compiler/plugin.proto.
+  * inf, -inf, and nan can now be used as default values for float and double
+    fields.
+
+  C++
+  * Various speed and code size optimizations.
+  * DynamicMessageFactory is now fully thread-safe.
+  * Message::Utf8DebugString() method is like DebugString() but avoids escaping
+    UTF-8 bytes.
+  * Compiled-in message types can now contain dynamic extensions, through use
+    of CodedInputStream::SetExtensionRegistry().
+
+  Java
+  * parseDelimitedFrom() and mergeDelimitedFrom() now detect EOF and return
+    false/null instead of throwing an exception.
+  * Fixed some initialization ordering bugs.
+  * Fixes for OpenJDK 7.
+
+  Python
+  * 10-25 times faster than 2.2.0, still pure-Python.
+  * Calling a mutating method on a sub-message always instantiates the message
+    in its parent even if the mutating method doesn't actually mutate anything
+    (e.g. parsing from an empty string).
+  * Expanded descriptors a bit.
+
 2009-08-11 version 2.2.0:
 
   C++

+ 0 - 6
Makefile.am

@@ -114,18 +114,12 @@ EXTRA_DIST =                                                                 \
   python/google/protobuf/internal/generator_test.py                          \
   python/google/protobuf/internal/containers.py                              \
   python/google/protobuf/internal/decoder.py                                 \
-  python/google/protobuf/internal/decoder_test.py                            \
   python/google/protobuf/internal/descriptor_test.py                         \
   python/google/protobuf/internal/encoder.py                                 \
-  python/google/protobuf/internal/encoder_test.py                            \
-  python/google/protobuf/internal/input_stream.py                            \
-  python/google/protobuf/internal/input_stream_test.py                       \
   python/google/protobuf/internal/message_listener.py                        \
   python/google/protobuf/internal/message_test.py                            \
   python/google/protobuf/internal/more_extensions.proto                      \
   python/google/protobuf/internal/more_messages.proto                        \
-  python/google/protobuf/internal/output_stream.py                           \
-  python/google/protobuf/internal/output_stream_test.py                      \
   python/google/protobuf/internal/reflection_test.py                         \
   python/google/protobuf/internal/service_reflection_test.py                 \
   python/google/protobuf/internal/test_util.py                               \

+ 10 - 0
autogen.sh

@@ -4,6 +4,8 @@
 # be included in the distribution.  These files are not checked in because they
 # are automatically generated.
 
+set -e
+
 # Check that we're being run from the right directory.
 if test ! -f src/google/protobuf/stubs/common.h; then
   cat >&2 << __EOF__
@@ -13,6 +15,14 @@ __EOF__
   exit 1
 fi
 
+# Check that gtest is present.  Usually it is already there since the
+# directory is set up as an SVN external.
+if test ! -e gtest; then
+  echo "Google Test not present.  Fetching gtest-1.3.0 from the web..."
+  curl http://googletest.googlecode.com/files/gtest-1.3.0.tar.bz2 | tar jx
+  mv gtest-1.3.0 gtest
+fi
+
 set -ex
 
 # Temporary hack:  Must change C runtime library to "multi-threaded DLL",

+ 3 - 1
generate_descriptor_proto.sh

@@ -27,5 +27,7 @@ __EOF__
 fi
 
 cd src
-make $@ protoc && ./protoc --cpp_out=dllexport_decl=LIBPROTOBUF_EXPORT:. google/protobuf/descriptor.proto
+make $@ protoc &&
+  ./protoc --cpp_out=dllexport_decl=LIBPROTOBUF_EXPORT:. google/protobuf/descriptor.proto && \
+  ./protoc --cpp_out=dllexport_decl=LIBPROTOC_EXPORT:. google/protobuf/compiler/plugin.proto
 cd ..

+ 1 - 0
java/pom.xml

@@ -113,6 +113,7 @@
                   <arg value="../src/google/protobuf/unittest_import_lite.proto" />
                   <arg value="../src/google/protobuf/unittest_lite_imports_nonlite.proto" />
                   <arg value="../src/google/protobuf/unittest_enormous_descriptor.proto" />
+                  <arg value="../src/google/protobuf/unittest_no_generic_services.proto" />
                 </exec>
               </tasks>
               <testSourceRoot>target/generated-test-sources</testSourceRoot>

+ 27 - 8
java/src/main/java/com/google/protobuf/AbstractMessage.java

@@ -311,6 +311,12 @@ public abstract class AbstractMessage extends AbstractMessageLite
           } else {
             field = extension.descriptor;
             defaultInstance = extension.defaultInstance;
+            if (defaultInstance == null &&
+                field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
+              throw new IllegalStateException(
+                  "Message-typed extension lacked default instance: " +
+                  field.getFullName());
+            }
           }
         } else {
           field = null;
@@ -319,15 +325,28 @@ public abstract class AbstractMessage extends AbstractMessageLite
         field = type.findFieldByNumber(fieldNumber);
       }
 
-      if (field == null || wireType !=
-            FieldSet.getWireFormatForFieldType(
-                field.getLiteType(),
-                field.getOptions().getPacked())) {
-        // Unknown field or wrong wire type.  Skip.
+      boolean unknown = false;
+      boolean packed = false;
+      if (field == null) {
+        unknown = true;  // Unknown field.
+      } else if (wireType == FieldSet.getWireFormatForFieldType(
+                   field.getLiteType(),
+                   false  /* isPacked */)) {
+        packed = false;
+      } else if (field.isPackable() &&
+                 wireType == FieldSet.getWireFormatForFieldType(
+                   field.getLiteType(),
+                   true  /* isPacked */)) {
+        packed = true;
+      } else {
+        unknown = true;  // Unknown wire type.
+      }
+
+      if (unknown) {  // Unknown field or wrong wire type.  Skip.
         return unknownFields.mergeFieldFrom(tag, input);
       }
 
-      if (field.getOptions().getPacked()) {
+      if (packed) {
         final int length = input.readRawVarint32();
         final int limit = input.pushLimit(length);
         if (field.getLiteType() == WireFormat.FieldType.ENUM) {
@@ -673,13 +692,13 @@ public abstract class AbstractMessage extends AbstractMessageLite
     }
 
     @Override
-    public BuilderType mergeDelimitedFrom(final InputStream input)
+    public boolean mergeDelimitedFrom(final InputStream input)
         throws IOException {
       return super.mergeDelimitedFrom(input);
     }
 
     @Override
-    public BuilderType mergeDelimitedFrom(
+    public boolean mergeDelimitedFrom(
         final InputStream input,
         final ExtensionRegistryLite extensionRegistry)
         throws IOException {

+ 13 - 15
java/src/main/java/com/google/protobuf/AbstractMessageLite.java

@@ -86,7 +86,7 @@ public abstract class AbstractMessageLite implements MessageLite {
         CodedOutputStream.computeRawVarint32Size(serialized) + serialized);
     final CodedOutputStream codedOutput =
         CodedOutputStream.newInstance(output, bufferSize);
-    codedOutput.writeRawVarint32(getSerializedSize());
+    codedOutput.writeRawVarint32(serialized);
     writeTo(codedOutput);
     codedOutput.flush();
   }
@@ -105,13 +105,7 @@ public abstract class AbstractMessageLite implements MessageLite {
 
     public BuilderType mergeFrom(final CodedInputStream input)
                                  throws IOException {
-      // TODO(kenton):  Don't use null here.  Currently we have to because
-      //   using ExtensionRegistry.getEmptyRegistry() would imply a dependency
-      //   on ExtensionRegistry.  However, AbstractMessage overrides this with
-      //   a correct implementation, and lite messages don't yet support
-      //   extensions, so it ends up not mattering for now.  It will matter
-      //   once lite messages support extensions.
-      return mergeFrom(input, null);
+      return mergeFrom(input, ExtensionRegistryLite.getEmptyRegistry());
     }
 
     // Re-defined here for return type covariance.
@@ -275,20 +269,24 @@ public abstract class AbstractMessageLite implements MessageLite {
       }
     }
 
-    public BuilderType mergeDelimitedFrom(
+    public boolean mergeDelimitedFrom(
         final InputStream input,
         final ExtensionRegistryLite extensionRegistry)
         throws IOException {
-      final int size = CodedInputStream.readRawVarint32(input);
+      final int firstByte = input.read();
+      if (firstByte == -1) {
+        return false;
+      }
+      final int size = CodedInputStream.readRawVarint32(firstByte, input);
       final InputStream limitedInput = new LimitedInputStream(input, size);
-      return mergeFrom(limitedInput, extensionRegistry);
+      mergeFrom(limitedInput, extensionRegistry);
+      return true;
     }
 
-    public BuilderType mergeDelimitedFrom(final InputStream input)
+    public boolean mergeDelimitedFrom(final InputStream input)
         throws IOException {
-      final int size = CodedInputStream.readRawVarint32(input);
-      final InputStream limitedInput = new LimitedInputStream(input, size);
-      return mergeFrom(limitedInput);
+      return mergeDelimitedFrom(input,
+          ExtensionRegistryLite.getEmptyRegistry());
     }
 
     /**

+ 18 - 0
java/src/main/java/com/google/protobuf/ByteString.java

@@ -98,6 +98,24 @@ public final class ByteString {
     return copyFrom(bytes, 0, bytes.length);
   }
 
+  /**
+   * Copies {@code size} bytes from a {@code java.nio.ByteBuffer} into
+   * a {@code ByteString}.
+   */
+  public static ByteString copyFrom(final ByteBuffer bytes, final int size) {
+    final byte[] copy = new byte[size];
+    bytes.get(copy);
+    return new ByteString(copy);
+  }
+
+  /**
+   * Copies the remaining bytes from a {@code java.nio.ByteBuffer} into
+   * a {@code ByteString}.
+   */
+  public static ByteString copyFrom(final ByteBuffer bytes) {
+    return copyFrom(bytes, bytes.remaining());
+  }
+
   /**
    * Encodes {@code text} into a sequence of bytes using the named charset
    * and returns the result as a {@code ByteString}.

+ 23 - 4
java/src/main/java/com/google/protobuf/CodedInputStream.java

@@ -84,8 +84,9 @@ public final class CodedInputStream {
     }
 
     lastTag = readRawVarint32();
-    if (lastTag == 0) {
-      // If we actually read zero, that's not a valid tag.
+    if (WireFormat.getTagFieldNumber(lastTag) == 0) {
+      // If we actually read zero (or any tag number corresponding to field
+      // number zero), that's not a valid tag.
       throw InvalidProtocolBufferException.invalidTag();
     }
     return lastTag;
@@ -355,8 +356,26 @@ public final class CodedInputStream {
    * CodedInputStream buffers its input.
    */
   static int readRawVarint32(final InputStream input) throws IOException {
-    int result = 0;
-    int offset = 0;
+    final int firstByte = input.read();
+    if (firstByte == -1) {
+      throw InvalidProtocolBufferException.truncatedMessage();
+    }
+    return readRawVarint32(firstByte, input);
+  }
+
+  /**
+   * Like {@link #readRawVarint32(InputStream)}, but expects that the caller
+   * has already read one byte.  This allows the caller to determine if EOF
+   * has been reached before attempting to read.
+   */
+  static int readRawVarint32(final int firstByte,
+                             final InputStream input) throws IOException {
+    if ((firstByte & 0x80) == 0) {
+      return firstByte;
+    }
+
+    int result = firstByte & 0x7f;
+    int offset = 7;
     for (; offset < 32; offset += 7) {
       final int b = input.read();
       if (b == -1) {

+ 63 - 55
java/src/main/java/com/google/protobuf/Descriptors.java

@@ -48,7 +48,7 @@ import java.io.UnsupportedEncodingException;
  * (given a message object of the type) {@code message.getDescriptorForType()}.
  *
  * Descriptors are built from DescriptorProtos, as defined in
- * {@code net/proto2/proto/descriptor.proto}.
+ * {@code google/protobuf/descriptor.proto}.
  *
  * @author kenton@google.com Kenton Varda
  */
@@ -699,6 +699,11 @@ public final class Descriptors {
       return getOptions().getPacked();
     }
 
+    /** Can this field be packed? i.e. is it a repeated primitive field? */
+    public boolean isPackable() {
+      return isRepeated() && getLiteType().isPackable();
+    }
+
     /** Returns true if the field had an explicitly-defined default value. */
     public boolean hasDefaultValue() { return proto.hasDefaultValue(); }
 
@@ -810,39 +815,34 @@ public final class Descriptors {
     private Object defaultValue;
 
     public enum Type {
-      DOUBLE  (FieldDescriptorProto.Type.TYPE_DOUBLE  , JavaType.DOUBLE     ),
-      FLOAT   (FieldDescriptorProto.Type.TYPE_FLOAT   , JavaType.FLOAT      ),
-      INT64   (FieldDescriptorProto.Type.TYPE_INT64   , JavaType.LONG       ),
-      UINT64  (FieldDescriptorProto.Type.TYPE_UINT64  , JavaType.LONG       ),
-      INT32   (FieldDescriptorProto.Type.TYPE_INT32   , JavaType.INT        ),
-      FIXED64 (FieldDescriptorProto.Type.TYPE_FIXED64 , JavaType.LONG       ),
-      FIXED32 (FieldDescriptorProto.Type.TYPE_FIXED32 , JavaType.INT        ),
-      BOOL    (FieldDescriptorProto.Type.TYPE_BOOL    , JavaType.BOOLEAN    ),
-      STRING  (FieldDescriptorProto.Type.TYPE_STRING  , JavaType.STRING     ),
-      GROUP   (FieldDescriptorProto.Type.TYPE_GROUP   , JavaType.MESSAGE    ),
-      MESSAGE (FieldDescriptorProto.Type.TYPE_MESSAGE , JavaType.MESSAGE    ),
-      BYTES   (FieldDescriptorProto.Type.TYPE_BYTES   , JavaType.BYTE_STRING),
-      UINT32  (FieldDescriptorProto.Type.TYPE_UINT32  , JavaType.INT        ),
-      ENUM    (FieldDescriptorProto.Type.TYPE_ENUM    , JavaType.ENUM       ),
-      SFIXED32(FieldDescriptorProto.Type.TYPE_SFIXED32, JavaType.INT        ),
-      SFIXED64(FieldDescriptorProto.Type.TYPE_SFIXED64, JavaType.LONG       ),
-      SINT32  (FieldDescriptorProto.Type.TYPE_SINT32  , JavaType.INT        ),
-      SINT64  (FieldDescriptorProto.Type.TYPE_SINT64  , JavaType.LONG       );
-
-      Type(final FieldDescriptorProto.Type proto, final JavaType javaType) {
-        this.proto = proto;
+      DOUBLE  (JavaType.DOUBLE     ),
+      FLOAT   (JavaType.FLOAT      ),
+      INT64   (JavaType.LONG       ),
+      UINT64  (JavaType.LONG       ),
+      INT32   (JavaType.INT        ),
+      FIXED64 (JavaType.LONG       ),
+      FIXED32 (JavaType.INT        ),
+      BOOL    (JavaType.BOOLEAN    ),
+      STRING  (JavaType.STRING     ),
+      GROUP   (JavaType.MESSAGE    ),
+      MESSAGE (JavaType.MESSAGE    ),
+      BYTES   (JavaType.BYTE_STRING),
+      UINT32  (JavaType.INT        ),
+      ENUM    (JavaType.ENUM       ),
+      SFIXED32(JavaType.INT        ),
+      SFIXED64(JavaType.LONG       ),
+      SINT32  (JavaType.INT        ),
+      SINT64  (JavaType.LONG       );
+
+      Type(final JavaType javaType) {
         this.javaType = javaType;
-
-        if (ordinal() != proto.getNumber() - 1) {
-          throw new RuntimeException(
-            "descriptor.proto changed but Desrciptors.java wasn't updated.");
-        }
       }
 
-      private FieldDescriptorProto.Type proto;
       private JavaType javaType;
 
-      public FieldDescriptorProto.Type toProto() { return proto; }
+      public FieldDescriptorProto.Type toProto() {
+        return FieldDescriptorProto.Type.valueOf(ordinal() + 1);
+      }
       public JavaType getJavaType() { return javaType; }
 
       public static Type valueOf(final FieldDescriptorProto.Type type) {
@@ -902,16 +902,10 @@ public final class Descriptors {
       }
 
       // Only repeated primitive fields may be packed.
-      if (proto.getOptions().getPacked()) {
-        if (proto.getLabel() != FieldDescriptorProto.Label.LABEL_REPEATED ||
-            proto.getType() == FieldDescriptorProto.Type.TYPE_STRING ||
-            proto.getType() == FieldDescriptorProto.Type.TYPE_GROUP ||
-            proto.getType() == FieldDescriptorProto.Type.TYPE_MESSAGE ||
-            proto.getType() == FieldDescriptorProto.Type.TYPE_BYTES) {
-          throw new DescriptorValidationException(this,
-            "[packed = true] can only be specified for repeated primitive " +
-            "fields.");
-        }
+      if (proto.getOptions().getPacked() && !isPackable()) {
+        throw new DescriptorValidationException(this,
+          "[packed = true] can only be specified for repeated primitive " +
+          "fields.");
       }
 
       if (isExtension) {
@@ -1030,10 +1024,26 @@ public final class Descriptors {
               defaultValue = TextFormat.parseUInt64(proto.getDefaultValue());
               break;
             case FLOAT:
-              defaultValue = Float.valueOf(proto.getDefaultValue());
+              if (proto.getDefaultValue().equals("inf")) {
+                defaultValue = Float.POSITIVE_INFINITY;
+              } else if (proto.getDefaultValue().equals("-inf")) {
+                defaultValue = Float.NEGATIVE_INFINITY;
+              } else if (proto.getDefaultValue().equals("nan")) {
+                defaultValue = Float.NaN;
+              } else {
+                defaultValue = Float.valueOf(proto.getDefaultValue());
+              }
               break;
             case DOUBLE:
-              defaultValue = Double.valueOf(proto.getDefaultValue());
+              if (proto.getDefaultValue().equals("inf")) {
+                defaultValue = Double.POSITIVE_INFINITY;
+              } else if (proto.getDefaultValue().equals("-inf")) {
+                defaultValue = Double.NEGATIVE_INFINITY;
+              } else if (proto.getDefaultValue().equals("nan")) {
+                defaultValue = Double.NaN;
+              } else {
+                defaultValue = Double.valueOf(proto.getDefaultValue());
+              }
               break;
             case BOOL:
               defaultValue = Boolean.valueOf(proto.getDefaultValue());
@@ -1064,12 +1074,9 @@ public final class Descriptors {
                 "Message type had default value.");
           }
         } catch (NumberFormatException e) {
-          final DescriptorValidationException validationException =
-            new DescriptorValidationException(this,
-              "Could not parse default value: \"" +
-              proto.getDefaultValue() + '\"');
-          validationException.initCause(e);
-          throw validationException;
+          throw new DescriptorValidationException(this, 
+              "Could not parse default value: \"" + 
+              proto.getDefaultValue() + '\"', e);
         }
       } else {
         // Determine the default default for this field.
@@ -1536,14 +1543,7 @@ public final class Descriptors {
     private DescriptorValidationException(
         final GenericDescriptor problemDescriptor,
         final String description) {
-      this(problemDescriptor, description, null);
-    }
-
-    private DescriptorValidationException(
-        final GenericDescriptor problemDescriptor,
-        final String description,
-        final Throwable cause) {
-      super(problemDescriptor.getFullName() + ": " + description, cause);
+      super(problemDescriptor.getFullName() + ": " + description);
 
       // Note that problemDescriptor may be partially uninitialized, so we
       // don't want to expose it directly to the user.  So, we only provide
@@ -1553,6 +1553,14 @@ public final class Descriptors {
       this.description = description;
     }
 
+    private DescriptorValidationException(
+        final GenericDescriptor problemDescriptor,
+        final String description,
+        final Throwable cause) {
+      this(problemDescriptor, description);
+      initCause(cause);
+    }
+
     private DescriptorValidationException(
         final FileDescriptor problemDescriptor,
         final String description) {

+ 5 - 0
java/src/main/java/com/google/protobuf/ExtensionRegistry.java

@@ -157,6 +157,11 @@ public final class ExtensionRegistry extends ExtensionRegistryLite {
   public void add(final GeneratedMessage.GeneratedExtension<?, ?> extension) {
     if (extension.getDescriptor().getJavaType() ==
         FieldDescriptor.JavaType.MESSAGE) {
+      if (extension.getMessageDefaultInstance() == null) {
+        throw new IllegalStateException(
+            "Registered message-type extension had null default instance: " +
+            extension.getDescriptor().getFullName());
+      }
       add(new ExtensionInfo(extension.getDescriptor(),
                             extension.getMessageDefaultInstance()));
     } else {

+ 4 - 0
java/src/main/java/com/google/protobuf/GeneratedMessage.java

@@ -789,6 +789,10 @@ public abstract class GeneratedMessage extends AbstractMessage {
           messageDefaultInstance =
             (Message) invokeOrDie(getMethodOrDie(type, "getDefaultInstance"),
                                   null);
+          if (messageDefaultInstance == null) {
+            throw new IllegalStateException(
+                type.getName() + ".getDefaultInstance() returned null.");
+          }
           break;
         case ENUM:
           enumValueOf = getMethodOrDie(type, "valueOf",

+ 23 - 8
java/src/main/java/com/google/protobuf/GeneratedMessageLite.java

@@ -303,7 +303,7 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
         final ExtensionRegistryLite extensionRegistry,
         final int tag) throws IOException {
       final FieldSet<ExtensionDescriptor> extensions =
-          internalGetResult().extensions;
+          ((ExtendableMessage) internalGetResult()).extensions;
 
       final int wireType = WireFormat.getTagWireType(tag);
       final int fieldNumber = WireFormat.getTagFieldNumber(tag);
@@ -312,15 +312,29 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
         extensionRegistry.findLiteExtensionByNumber(
             getDefaultInstanceForType(), fieldNumber);
 
-      if (extension == null || wireType !=
-            FieldSet.getWireFormatForFieldType(
-                extension.descriptor.getLiteType(),
-                extension.descriptor.isPacked())) {
-        // Unknown field or wrong wire type.  Skip.
+      boolean unknown = false;
+      boolean packed = false;
+      if (extension == null) {
+        unknown = true;  // Unknown field.
+      } else if (wireType == FieldSet.getWireFormatForFieldType(
+                   extension.descriptor.getLiteType(),
+                   false  /* isPacked */)) {
+        packed = false;  // Normal, unpacked value.
+      } else if (extension.descriptor.isRepeated &&
+                 extension.descriptor.type.isPackable() &&
+                 wireType == FieldSet.getWireFormatForFieldType(
+                   extension.descriptor.getLiteType(),
+                   true  /* isPacked */)) {
+        packed = true;  // Packed value.
+      } else {
+        unknown = true;  // Wrong wire type.
+      }
+
+      if (unknown) {  // Unknown field or wrong wire type.  Skip.
         return input.skipField(tag);
       }
 
-      if (extension.descriptor.isPacked()) {
+      if (packed) {
         final int length = input.readRawVarint32();
         final int limit = input.pushLimit(length);
         if (extension.descriptor.getLiteType() == WireFormat.FieldType.ENUM) {
@@ -396,7 +410,8 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
     }
 
     protected final void mergeExtensionFields(final MessageType other) {
-      internalGetResult().extensions.mergeFrom(other.extensions);
+      ((ExtendableMessage) internalGetResult()).extensions.mergeFrom(
+          ((ExtendableMessage) other).extensions);
     }
   }
 

+ 2 - 2
java/src/main/java/com/google/protobuf/Message.java

@@ -296,9 +296,9 @@ public interface Message extends MessageLite {
     Builder mergeFrom(InputStream input,
                       ExtensionRegistryLite extensionRegistry)
                       throws IOException;
-    Builder mergeDelimitedFrom(InputStream input)
+    boolean mergeDelimitedFrom(InputStream input)
                                throws IOException;
-    Builder mergeDelimitedFrom(InputStream input,
+    boolean mergeDelimitedFrom(InputStream input,
                                ExtensionRegistryLite extensionRegistry)
                                throws IOException;
   }

+ 6 - 2
java/src/main/java/com/google/protobuf/MessageLite.java

@@ -317,14 +317,18 @@ public interface MessageLite {
      * then the message data.  Use
      * {@link MessageLite#writeDelimitedTo(OutputStream)} to write messages in
      * this format.
+     *
+     * @return True if successful, or false if the stream is at EOF when the
+     *         method starts.  Any other error (including reaching EOF during
+     *         parsing) will cause an exception to be thrown.
      */
-    Builder mergeDelimitedFrom(InputStream input)
+    boolean mergeDelimitedFrom(InputStream input)
                                throws IOException;
 
     /**
      * Like {@link #mergeDelimitedFrom(InputStream)} but supporting extensions.
      */
-    Builder mergeDelimitedFrom(InputStream input,
+    boolean mergeDelimitedFrom(InputStream input,
                                ExtensionRegistryLite extensionRegistry)
                                throws IOException;
   }

+ 1 - 1
java/src/main/java/com/google/protobuf/TextFormat.java

@@ -426,7 +426,7 @@ public final class TextFormat {
       Pattern.compile("(\\s|(#.*$))++", Pattern.MULTILINE);
     private static final Pattern TOKEN = Pattern.compile(
       "[a-zA-Z_][0-9a-zA-Z_+-]*+|" +                // an identifier
-      "[0-9+-][0-9a-zA-Z_.+-]*+|" +                 // a number
+      "[.]?[0-9+-][0-9a-zA-Z_.+-]*+|" +             // a number
       "\"([^\"\n\\\\]|\\\\.)*+(\"|\\\\?$)|" +       // a double-quoted string
       "\'([^\"\n\\\\]|\\\\.)*+(\'|\\\\?$)",         // a single-quoted string
       Pattern.MULTILINE);

+ 13 - 7
java/src/main/java/com/google/protobuf/UnknownFieldSet.java

@@ -30,6 +30,8 @@
 
 package com.google.protobuf;
 
+import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -551,19 +553,23 @@ public final class UnknownFieldSet implements MessageLite {
       return this;
     }
 
-    public Builder mergeDelimitedFrom(InputStream input)
+    public boolean mergeDelimitedFrom(InputStream input)
         throws IOException {
-      final int size = CodedInputStream.readRawVarint32(input);
-      final InputStream limitedInput =
-        new AbstractMessage.Builder.LimitedInputStream(input, size);
-      return mergeFrom(limitedInput, null);
+      final int firstByte = input.read();
+      if (firstByte == -1) {
+        return false;
+      }
+      final int size = CodedInputStream.readRawVarint32(firstByte, input);
+      final InputStream limitedInput = new LimitedInputStream(input, size);
+      mergeFrom(limitedInput);
+      return true;
     }
 
-    public Builder mergeDelimitedFrom(
+    public boolean mergeDelimitedFrom(
         InputStream input,
         ExtensionRegistryLite extensionRegistry) throws IOException {
       // UnknownFieldSet has no extensions.
-      return mergeFrom(input);
+      return mergeDelimitedFrom(input);
     }
 
     public Builder mergeFrom(

+ 14 - 4
java/src/main/java/com/google/protobuf/WireFormat.java

@@ -113,10 +113,18 @@ public final class WireFormat {
     FIXED64 (JavaType.LONG       , WIRETYPE_FIXED64         ),
     FIXED32 (JavaType.INT        , WIRETYPE_FIXED32         ),
     BOOL    (JavaType.BOOLEAN    , WIRETYPE_VARINT          ),
-    STRING  (JavaType.STRING     , WIRETYPE_LENGTH_DELIMITED),
-    GROUP   (JavaType.MESSAGE    , WIRETYPE_START_GROUP     ),
-    MESSAGE (JavaType.MESSAGE    , WIRETYPE_LENGTH_DELIMITED),
-    BYTES   (JavaType.BYTE_STRING, WIRETYPE_LENGTH_DELIMITED),
+    STRING  (JavaType.STRING     , WIRETYPE_LENGTH_DELIMITED) {
+      public boolean isPackable() { return false; }
+    },
+    GROUP   (JavaType.MESSAGE    , WIRETYPE_START_GROUP     ) {
+      public boolean isPackable() { return false; }
+    },
+    MESSAGE (JavaType.MESSAGE    , WIRETYPE_LENGTH_DELIMITED) {
+      public boolean isPackable() { return false; }
+    },
+    BYTES   (JavaType.BYTE_STRING, WIRETYPE_LENGTH_DELIMITED) {
+      public boolean isPackable() { return false; }
+    },
     UINT32  (JavaType.INT        , WIRETYPE_VARINT          ),
     ENUM    (JavaType.ENUM       , WIRETYPE_VARINT          ),
     SFIXED32(JavaType.INT        , WIRETYPE_FIXED32         ),
@@ -134,6 +142,8 @@ public final class WireFormat {
 
     public JavaType getJavaType() { return javaType; }
     public int getWireType() { return wireType; }
+
+    public boolean isPackable() { return true; }
   }
 
   // Field numbers for fields in MessageSet wire format.

+ 38 - 0
java/src/test/java/com/google/protobuf/AbstractMessageTest.java

@@ -38,6 +38,7 @@ import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestPackedTypes;
 import protobuf_unittest.UnittestProto.TestRequired;
 import protobuf_unittest.UnittestProto.TestRequiredForeign;
+import protobuf_unittest.UnittestProto.TestUnpackedTypes;
 
 import junit.framework.TestCase;
 
@@ -238,6 +239,43 @@ public class AbstractMessageTest extends TestCase {
     TestUtil.assertPackedFieldsSet((TestPackedTypes) message.wrappedMessage);
   }
 
+  public void testUnpackedSerialization() throws Exception {
+    Message abstractMessage =
+      new AbstractMessageWrapper(TestUtil.getUnpackedSet());
+
+    TestUtil.assertUnpackedFieldsSet(
+      TestUnpackedTypes.parseFrom(abstractMessage.toByteString()));
+
+    assertEquals(TestUtil.getUnpackedSet().toByteString(),
+                 abstractMessage.toByteString());
+  }
+
+  public void testParsePackedToUnpacked() throws Exception {
+    AbstractMessageWrapper.Builder builder =
+      new AbstractMessageWrapper.Builder(TestUnpackedTypes.newBuilder());
+    AbstractMessageWrapper message =
+      builder.mergeFrom(TestUtil.getPackedSet().toByteString()).build();
+    TestUtil.assertUnpackedFieldsSet(
+      (TestUnpackedTypes) message.wrappedMessage);
+  }
+
+  public void testParseUnpackedToPacked() throws Exception {
+    AbstractMessageWrapper.Builder builder =
+      new AbstractMessageWrapper.Builder(TestPackedTypes.newBuilder());
+    AbstractMessageWrapper message =
+      builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
+    TestUtil.assertPackedFieldsSet((TestPackedTypes) message.wrappedMessage);
+  }
+
+  public void testUnpackedParsing() throws Exception {
+    AbstractMessageWrapper.Builder builder =
+      new AbstractMessageWrapper.Builder(TestUnpackedTypes.newBuilder());
+    AbstractMessageWrapper message =
+      builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
+    TestUtil.assertUnpackedFieldsSet(
+      (TestUnpackedTypes) message.wrappedMessage);
+  }
+
   public void testOptimizedForSize() throws Exception {
     // We're mostly only checking that this class was compiled successfully.
     TestOptimizedForSize message =

+ 14 - 0
java/src/test/java/com/google/protobuf/CodedInputStreamTest.java

@@ -490,4 +490,18 @@ public class CodedInputStreamTest extends TestCase {
     assertEquals(0, in.readTag());
     assertEquals(5, in.getTotalBytesRead());
   }
+
+  public void testInvalidTag() throws Exception {
+    // Any tag number which corresponds to field number zero is invalid and
+    // should throw InvalidProtocolBufferException.
+    for (int i = 0; i < 8; i++) {
+      try {
+        CodedInputStream.newInstance(bytes(i)).readTag();
+        fail("Should have thrown an exception.");
+      } catch (InvalidProtocolBufferException e) {
+        assertEquals(InvalidProtocolBufferException.invalidTag().getMessage(),
+                     e.getMessage());
+      }
+    }
+  }
 }

+ 51 - 0
java/src/test/java/com/google/protobuf/DescriptorsTest.java

@@ -30,6 +30,10 @@
 
 package com.google.protobuf;
 
+import com.google.protobuf.DescriptorProtos.DescriptorProto;
+import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
+import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
+import com.google.protobuf.Descriptors.DescriptorValidationException;
 import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.Descriptors.Descriptor;
 import com.google.protobuf.Descriptors.FieldDescriptor;
@@ -63,6 +67,22 @@ import java.util.Collections;
  * @author kenton@google.com Kenton Varda
  */
 public class DescriptorsTest extends TestCase {
+
+  // Regression test for bug where referencing a FieldDescriptor.Type value
+  // before a FieldDescriptorProto.Type value would yield an
+  // ExceptionInInitializerError.
+  private static final Object STATIC_INIT_TEST = FieldDescriptor.Type.BOOL;
+
+  public void testFieldTypeEnumMapping() throws Exception {
+    assertEquals(FieldDescriptor.Type.values().length,
+        FieldDescriptorProto.Type.values().length);
+    for (FieldDescriptor.Type type : FieldDescriptor.Type.values()) {
+      FieldDescriptorProto.Type protoType = type.toProto();
+      assertEquals("TYPE_" + type.name(), protoType.name());
+      assertEquals(type, FieldDescriptor.Type.valueOf(protoType));
+    }
+  }
+
   public void testFileDescriptor() throws Exception {
     FileDescriptor file = UnittestProto.getDescriptor();
 
@@ -405,4 +425,35 @@ public class DescriptorsTest extends TestCase {
         UnittestEnormousDescriptor.getDescriptor()
           .toProto().getSerializedSize() > 65536);
   }
+  
+  /**
+   * Tests that the DescriptorValidationException works as intended.
+   */
+  public void testDescriptorValidatorException() throws Exception {
+    FileDescriptorProto fileDescriptorProto = FileDescriptorProto.newBuilder()
+      .setName("foo.proto")
+      .addMessageType(DescriptorProto.newBuilder()
+      .setName("Foo")
+        .addField(FieldDescriptorProto.newBuilder()
+          .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL)
+          .setType(FieldDescriptorProto.Type.TYPE_INT32)
+          .setName("foo")
+          .setNumber(1)
+          .setDefaultValue("invalid")
+          .build())
+        .build())
+      .build();
+    try {
+      Descriptors.FileDescriptor.buildFrom(fileDescriptorProto, 
+          new FileDescriptor[0]);
+      fail("DescriptorValidationException expected");
+    } catch (DescriptorValidationException e) {
+      // Expected; check that the error message contains some useful hints
+      assertTrue(e.getMessage().indexOf("foo") != -1);
+      assertTrue(e.getMessage().indexOf("Foo") != -1);
+      assertTrue(e.getMessage().indexOf("invalid") != -1);
+      assertTrue(e.getCause() instanceof NumberFormatException);
+      assertTrue(e.getCause().getMessage().indexOf("invalid") != -1);
+    }
+  }
 }

+ 33 - 2
java/src/test/java/com/google/protobuf/GeneratedMessageTest.java

@@ -39,6 +39,8 @@ import protobuf_unittest.UnittestProto.ForeignEnum;
 import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestAllExtensions;
 import protobuf_unittest.UnittestProto.TestExtremeDefaultValues;
+import protobuf_unittest.UnittestProto.TestPackedTypes;
+import protobuf_unittest.UnittestProto.TestUnpackedTypes;
 import protobuf_unittest.MultipleFilesTestProto;
 import protobuf_unittest.MessageWithNoOuter;
 import protobuf_unittest.EnumWithNoOuter;
@@ -303,8 +305,15 @@ public class GeneratedMessageTest extends TestCase {
     TestUtil.assertClear(TestAllTypes.getDefaultInstance());
     TestUtil.assertClear(TestAllTypes.newBuilder().build());
 
-    assertEquals("\u1234",
-                 TestExtremeDefaultValues.getDefaultInstance().getUtf8String());
+    TestExtremeDefaultValues message =
+        TestExtremeDefaultValues.getDefaultInstance();
+    assertEquals("\u1234", message.getUtf8String());
+    assertEquals(Double.POSITIVE_INFINITY, message.getInfDouble());
+    assertEquals(Double.NEGATIVE_INFINITY, message.getNegInfDouble());
+    assertTrue(Double.isNaN(message.getNanDouble()));
+    assertEquals(Float.POSITIVE_INFINITY, message.getInfFloat());
+    assertEquals(Float.NEGATIVE_INFINITY, message.getNegInfFloat());
+    assertTrue(Float.isNaN(message.getNanFloat()));
   }
 
   public void testReflectionGetters() throws Exception {
@@ -361,6 +370,20 @@ public class GeneratedMessageTest extends TestCase {
     assertTrue(map.findValueByNumber(12345) == null);
   }
 
+  public void testParsePackedToUnpacked() throws Exception {
+    TestUnpackedTypes.Builder builder = TestUnpackedTypes.newBuilder();
+    TestUnpackedTypes message =
+      builder.mergeFrom(TestUtil.getPackedSet().toByteString()).build();
+    TestUtil.assertUnpackedFieldsSet(message);
+  }
+
+  public void testParseUnpackedToPacked() throws Exception {
+    TestPackedTypes.Builder builder = TestPackedTypes.newBuilder();
+    TestPackedTypes message =
+      builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
+    TestUtil.assertPackedFieldsSet(message);
+  }
+
   // =================================================================
   // Extensions.
 
@@ -615,4 +638,12 @@ public class GeneratedMessageTest extends TestCase {
       UnittestProto.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48);
     assertEquals(UnittestProto.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 51);
   }
+
+  public void testRecursiveMessageDefaultInstance() throws Exception {
+    UnittestProto.TestRecursiveMessage message =
+        UnittestProto.TestRecursiveMessage.getDefaultInstance();
+    assertTrue(message != null);
+    assertTrue(message.getA() != null);
+    assertTrue(message.getA() == message);
+  }
 }

+ 47 - 0
java/src/test/java/com/google/protobuf/ServiceTest.java

@@ -30,7 +30,9 @@
 
 package com.google.protobuf;
 
+import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.Descriptors.MethodDescriptor;
+import google.protobuf.no_generic_services_test.UnittestNoGenericServices;
 import protobuf_unittest.MessageWithNoOuter;
 import protobuf_unittest.ServiceWithNoOuter;
 import protobuf_unittest.UnittestProto.TestAllTypes;
@@ -44,6 +46,9 @@ import org.easymock.classextension.EasyMock;
 import org.easymock.classextension.IMocksControl;
 import org.easymock.IArgumentMatcher;
 
+import java.util.HashSet;
+import java.util.Set;
+
 import junit.framework.TestCase;
 
 /**
@@ -220,6 +225,48 @@ public class ServiceTest extends TestCase {
     control.verify();
   }
 
+  public void testNoGenericServices() throws Exception {
+    // Non-services should be usable.
+    UnittestNoGenericServices.TestMessage message =
+      UnittestNoGenericServices.TestMessage.newBuilder()
+        .setA(123)
+        .setExtension(UnittestNoGenericServices.testExtension, 456)
+        .build();
+    assertEquals(123, message.getA());
+    assertEquals(1, UnittestNoGenericServices.TestEnum.FOO.getNumber());
+
+    // Build a list of the class names nested in UnittestNoGenericServices.
+    String outerName = "google.protobuf.no_generic_services_test." +
+                       "UnittestNoGenericServices";
+    Class<?> outerClass = Class.forName(outerName);
+
+    Set<String> innerClassNames = new HashSet<String>();
+    for (Class<?> innerClass : outerClass.getClasses()) {
+      String fullName = innerClass.getName();
+      // Figure out the unqualified name of the inner class.
+      // Note:  Surprisingly, the full name of an inner class will be separated
+      //   from the outer class name by a '$' rather than a '.'.  This is not
+      //   mentioned in the documentation for java.lang.Class.  I don't want to
+      //   make assumptions, so I'm just going to accept any character as the
+      //   separator.
+      assertTrue(fullName.startsWith(outerName));
+      innerClassNames.add(fullName.substring(outerName.length() + 1));
+    }
+
+    // No service class should have been generated.
+    assertTrue(innerClassNames.contains("TestMessage"));
+    assertTrue(innerClassNames.contains("TestEnum"));
+    assertFalse(innerClassNames.contains("TestService"));
+
+    // But descriptors are there.
+    FileDescriptor file = UnittestNoGenericServices.getDescriptor();
+    assertEquals(1, file.getServices().size());
+    assertEquals("TestService", file.getServices().get(0).getName());
+    assertEquals(1, file.getServices().get(0).getMethods().size());
+    assertEquals("Foo",
+        file.getServices().get(0).getMethods().get(0).getName());
+  }
+
   // =================================================================
 
   /**

+ 92 - 0
java/src/test/java/com/google/protobuf/TestUtil.java

@@ -217,6 +217,7 @@ import protobuf_unittest.UnittestProto.TestAllExtensions;
 import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestPackedExtensions;
 import protobuf_unittest.UnittestProto.TestPackedTypes;
+import protobuf_unittest.UnittestProto.TestUnpackedTypes;
 import protobuf_unittest.UnittestProto.ForeignMessage;
 import protobuf_unittest.UnittestProto.ForeignEnum;
 import com.google.protobuf.test.UnittestImport.ImportMessage;
@@ -289,6 +290,12 @@ class TestUtil {
     return builder.build();
   }
 
+  public static TestUnpackedTypes getUnpackedSet() {
+    TestUnpackedTypes.Builder builder = TestUnpackedTypes.newBuilder();
+    setUnpackedFields(builder);
+    return builder.build();
+  }
+
   public static TestPackedExtensions getPackedExtensionsSet() {
     TestPackedExtensions.Builder builder = TestPackedExtensions.newBuilder();
     setPackedExtensions(builder);
@@ -955,6 +962,42 @@ class TestUtil {
     message.addPackedEnum    (ForeignEnum.FOREIGN_BAZ);
   }
 
+  /**
+   * Set every field of {@code message} to a unique value. Must correspond with
+   * the values applied by {@code setPackedFields}.
+   */
+  public static void setUnpackedFields(TestUnpackedTypes.Builder message) {
+    message.addUnpackedInt32   (601);
+    message.addUnpackedInt64   (602);
+    message.addUnpackedUint32  (603);
+    message.addUnpackedUint64  (604);
+    message.addUnpackedSint32  (605);
+    message.addUnpackedSint64  (606);
+    message.addUnpackedFixed32 (607);
+    message.addUnpackedFixed64 (608);
+    message.addUnpackedSfixed32(609);
+    message.addUnpackedSfixed64(610);
+    message.addUnpackedFloat   (611);
+    message.addUnpackedDouble  (612);
+    message.addUnpackedBool    (true);
+    message.addUnpackedEnum    (ForeignEnum.FOREIGN_BAR);
+    // Add a second one of each field.
+    message.addUnpackedInt32   (701);
+    message.addUnpackedInt64   (702);
+    message.addUnpackedUint32  (703);
+    message.addUnpackedUint64  (704);
+    message.addUnpackedSint32  (705);
+    message.addUnpackedSint64  (706);
+    message.addUnpackedFixed32 (707);
+    message.addUnpackedFixed64 (708);
+    message.addUnpackedSfixed32(709);
+    message.addUnpackedSfixed64(710);
+    message.addUnpackedFloat   (711);
+    message.addUnpackedDouble  (712);
+    message.addUnpackedBool    (false);
+    message.addUnpackedEnum    (ForeignEnum.FOREIGN_BAZ);
+  }
+
   /**
    * Assert (using {@code junit.framework.Assert}) that all fields of
    * {@code message} are set to the values assigned by {@code setPackedFields}.
@@ -1004,6 +1047,55 @@ class TestUtil {
     Assert.assertEquals(ForeignEnum.FOREIGN_BAZ, message.getPackedEnum(1));
   }
 
+  /**
+   * Assert (using {@code junit.framework.Assert}) that all fields of
+   * {@code message} are set to the values assigned by {@code setUnpackedFields}.
+   */
+  public static void assertUnpackedFieldsSet(TestUnpackedTypes message) {
+    Assert.assertEquals(2, message.getUnpackedInt32Count   ());
+    Assert.assertEquals(2, message.getUnpackedInt64Count   ());
+    Assert.assertEquals(2, message.getUnpackedUint32Count  ());
+    Assert.assertEquals(2, message.getUnpackedUint64Count  ());
+    Assert.assertEquals(2, message.getUnpackedSint32Count  ());
+    Assert.assertEquals(2, message.getUnpackedSint64Count  ());
+    Assert.assertEquals(2, message.getUnpackedFixed32Count ());
+    Assert.assertEquals(2, message.getUnpackedFixed64Count ());
+    Assert.assertEquals(2, message.getUnpackedSfixed32Count());
+    Assert.assertEquals(2, message.getUnpackedSfixed64Count());
+    Assert.assertEquals(2, message.getUnpackedFloatCount   ());
+    Assert.assertEquals(2, message.getUnpackedDoubleCount  ());
+    Assert.assertEquals(2, message.getUnpackedBoolCount    ());
+    Assert.assertEquals(2, message.getUnpackedEnumCount   ());
+    Assert.assertEquals(601  , message.getUnpackedInt32   (0));
+    Assert.assertEquals(602  , message.getUnpackedInt64   (0));
+    Assert.assertEquals(603  , message.getUnpackedUint32  (0));
+    Assert.assertEquals(604  , message.getUnpackedUint64  (0));
+    Assert.assertEquals(605  , message.getUnpackedSint32  (0));
+    Assert.assertEquals(606  , message.getUnpackedSint64  (0));
+    Assert.assertEquals(607  , message.getUnpackedFixed32 (0));
+    Assert.assertEquals(608  , message.getUnpackedFixed64 (0));
+    Assert.assertEquals(609  , message.getUnpackedSfixed32(0));
+    Assert.assertEquals(610  , message.getUnpackedSfixed64(0));
+    Assert.assertEquals(611  , message.getUnpackedFloat   (0), 0.0);
+    Assert.assertEquals(612  , message.getUnpackedDouble  (0), 0.0);
+    Assert.assertEquals(true , message.getUnpackedBool    (0));
+    Assert.assertEquals(ForeignEnum.FOREIGN_BAR, message.getUnpackedEnum(0));
+    Assert.assertEquals(701  , message.getUnpackedInt32   (1));
+    Assert.assertEquals(702  , message.getUnpackedInt64   (1));
+    Assert.assertEquals(703  , message.getUnpackedUint32  (1));
+    Assert.assertEquals(704  , message.getUnpackedUint64  (1));
+    Assert.assertEquals(705  , message.getUnpackedSint32  (1));
+    Assert.assertEquals(706  , message.getUnpackedSint64  (1));
+    Assert.assertEquals(707  , message.getUnpackedFixed32 (1));
+    Assert.assertEquals(708  , message.getUnpackedFixed64 (1));
+    Assert.assertEquals(709  , message.getUnpackedSfixed32(1));
+    Assert.assertEquals(710  , message.getUnpackedSfixed64(1));
+    Assert.assertEquals(711  , message.getUnpackedFloat   (1), 0.0);
+    Assert.assertEquals(712  , message.getUnpackedDouble  (1), 0.0);
+    Assert.assertEquals(false, message.getUnpackedBool    (1));
+    Assert.assertEquals(ForeignEnum.FOREIGN_BAZ, message.getUnpackedEnum(1));
+  }
+
   // ===================================================================
   // Like above, but for extensions
 

+ 19 - 3
java/src/test/java/com/google/protobuf/TextFormatTest.java

@@ -68,7 +68,7 @@ public class TextFormatTest extends TestCase {
   private static String allExtensionsSetText = TestUtil.readTextFromFile(
     "text_format_unittest_extensions_data.txt");
 
-  private String exoticText =
+  private static String exoticText =
     "repeated_int32: -1\n" +
     "repeated_int32: -2147483648\n" +
     "repeated_int64: -1\n" +
@@ -80,7 +80,13 @@ public class TextFormatTest extends TestCase {
     "repeated_double: 123.0\n" +
     "repeated_double: 123.5\n" +
     "repeated_double: 0.125\n" +
+    "repeated_double: .125\n" +
+    "repeated_double: -.125\n" +
     "repeated_double: 1.23E17\n" +
+    "repeated_double: 1.23E+17\n" +
+    "repeated_double: -1.23e-17\n" +
+    "repeated_double: .23e+17\n" +
+    "repeated_double: -.23E17\n" +
     "repeated_double: 1.235E22\n" +
     "repeated_double: 1.235E-18\n" +
     "repeated_double: 123.456789\n" +
@@ -91,6 +97,10 @@ public class TextFormatTest extends TestCase {
       "\\341\\210\\264\"\n" +
     "repeated_bytes: \"\\000\\001\\a\\b\\f\\n\\r\\t\\v\\\\\\'\\\"\\376\"\n";
 
+  private static String canonicalExoticText =
+      exoticText.replace(": .", ": 0.").replace(": -.", ": -0.")   // short-form double
+      .replace("23e", "23E").replace("E+", "E").replace("0.23E17", "2.3E16");
+
   private String messageSetText =
     "[protobuf_unittest.TestMessageSetExtension1] {\n" +
     "  i: 123\n" +
@@ -231,7 +241,13 @@ public class TextFormatTest extends TestCase {
       .addRepeatedDouble(123)
       .addRepeatedDouble(123.5)
       .addRepeatedDouble(0.125)
+      .addRepeatedDouble(.125)
+      .addRepeatedDouble(-.125)
+      .addRepeatedDouble(123e15)
       .addRepeatedDouble(123e15)
+      .addRepeatedDouble(-1.23e-17)
+      .addRepeatedDouble(.23e17)
+      .addRepeatedDouble(-23e15)
       .addRepeatedDouble(123.5e20)
       .addRepeatedDouble(123.5e-20)
       .addRepeatedDouble(123.456789)
@@ -244,7 +260,7 @@ public class TextFormatTest extends TestCase {
       .addRepeatedBytes(bytes("\0\001\007\b\f\n\r\t\013\\\'\"\u00fe"))
       .build();
 
-    assertEquals(exoticText, message.toString());
+    assertEquals(canonicalExoticText, message.toString());
   }
 
   public void testPrintMessageSet() throws Exception {
@@ -319,7 +335,7 @@ public class TextFormatTest extends TestCase {
 
     // Too lazy to check things individually.  Don't try to debug this
     // if testPrintExotic() is failing.
-    assertEquals(exoticText, builder.build().toString());
+    assertEquals(canonicalExoticText, builder.build().toString());
   }
 
   public void testParseMessageSet() throws Exception {

+ 3 - 0
java/src/test/java/com/google/protobuf/WireFormatTest.java

@@ -235,6 +235,9 @@ public class WireFormatTest extends TestCase {
     TestUtil.assertPackedFieldsSet(TestPackedTypes.parseDelimitedFrom(input));
     assertEquals(34, input.read());
     assertEquals(-1, input.read());
+
+    // We're at EOF, so parsing again should return null.
+    assertTrue(TestAllTypes.parseDelimitedFrom(input) == null);
   }
 
   private void assertFieldsInOrder(ByteString data) throws Exception {

+ 184 - 27
python/google/protobuf/descriptor.py

@@ -44,12 +44,24 @@ file, in types that make this information accessible in Python.
 
 __author__ = 'robinson@google.com (Will Robinson)'
 
+
+class Error(Exception):
+  """Base error for this module."""
+
+
 class DescriptorBase(object):
 
   """Descriptors base class.
 
   This class is the base of all descriptor classes. It provides common options
   related functionality.
+
+  Attributes:
+    has_options:  True if the descriptor has non-default options.  Usually it
+        is not necessary to read this -- just call GetOptions() which will
+        happily return the default instance.  However, it's sometimes useful
+        for efficiency, and also useful inside the protobuf implementation to
+        avoid some bootstrapping issues.
   """
 
   def __init__(self, options, options_class_name):
@@ -60,6 +72,9 @@ class DescriptorBase(object):
     self._options = options
     self._options_class_name = options_class_name
 
+    # Does this descriptor have non-default options?
+    self.has_options = options is not None
+
   def GetOptions(self):
     """Retrieves descriptor options.
 
@@ -78,7 +93,70 @@ class DescriptorBase(object):
     return self._options
 
 
-class Descriptor(DescriptorBase):
+class _NestedDescriptorBase(DescriptorBase):
+  """Common class for descriptors that can be nested."""
+
+  def __init__(self, options, options_class_name, name, full_name,
+               file, containing_type, serialized_start=None,
+               serialized_end=None):
+    """Constructor.
+
+    Args:
+      options: Protocol message options or None
+        to use default message options.
+      options_class_name: (str) The class name of the above options.
+
+      name: (str) Name of this protocol message type.
+      full_name: (str) Fully-qualified name of this protocol message type,
+        which will include protocol "package" name and the name of any
+        enclosing types.
+      file: (FileDescriptor) Reference to file info.
+      containing_type: if provided, this is a nested descriptor, with this
+        descriptor as parent, otherwise None.
+      serialized_start: The start index (inclusive) in block in the
+        file.serialized_pb that describes this descriptor.
+      serialized_end: The end index (exclusive) in block in the
+        file.serialized_pb that describes this descriptor.
+    """
+    super(_NestedDescriptorBase, self).__init__(
+        options, options_class_name)
+
+    self.name = name
+    # TODO(falk): Add function to calculate full_name instead of having it in
+    #             memory?
+    self.full_name = full_name
+    self.file = file
+    self.containing_type = containing_type
+
+    self._serialized_start = serialized_start
+    self._serialized_end = serialized_end
+
+  def GetTopLevelContainingType(self):
+    """Returns the root if this is a nested type, or itself if it's the root."""
+    desc = self
+    while desc.containing_type is not None:
+      desc = desc.containing_type
+    return desc
+
+  def CopyToProto(self, proto):
+    """Copies this to the matching proto in descriptor_pb2.
+
+    Args:
+      proto: An empty proto instance from descriptor_pb2.
+
+    Raises:
+      Error: If self couldn't be serialized, due to too few constructor
+        arguments.
+    """
+    if (self.file is not None and
+        self._serialized_start is not None and
+        self._serialized_end is not None):
+      proto.ParseFromString(self.file.serialized_pb[
+          self._serialized_start:self._serialized_end])
+    else:
+      raise Error('Descriptor does not contain serialization.')
+
+
+class Descriptor(_NestedDescriptorBase):
 
   """Descriptor for a protocol message type.
 
@@ -89,10 +167,8 @@ class Descriptor(DescriptorBase):
       which will include protocol "package" name and the name of any
       enclosing types.
 
-    filename: (str) Name of the .proto file containing this message.
-
     containing_type: (Descriptor) Reference to the descriptor of the
-      type containing us, or None if we have no containing type.
+      type containing us, or None if this is top-level.
 
     fields: (list of FieldDescriptors) Field descriptors for all
       fields in this type.
@@ -123,20 +199,28 @@ class Descriptor(DescriptorBase):
       objects as |extensions|, but indexed by "name" attribute of each
       FieldDescriptor.
 
+    is_extendable:  Does this type define any extension ranges?
+
     options: (descriptor_pb2.MessageOptions) Protocol message options or None
       to use default message options.
+
+    file: (FileDescriptor) Reference to file descriptor.
   """
 
-  def __init__(self, name, full_name, filename, containing_type,
-               fields, nested_types, enum_types, extensions, options=None):
+  def __init__(self, name, full_name, filename, containing_type, fields,
+               nested_types, enum_types, extensions, options=None,
+               is_extendable=True, extension_ranges=None, file=None,
+               serialized_start=None, serialized_end=None):
     """Arguments to __init__() are as described in the description
     of Descriptor fields above.
+
+    Note that filename is an obsolete argument, that is not used anymore.
+    Please use file.name to access this as an attribute.
     """
-    super(Descriptor, self).__init__(options, 'MessageOptions')
-    self.name = name
-    self.full_name = full_name
-    self.filename = filename
-    self.containing_type = containing_type
+    super(Descriptor, self).__init__(
+        options, 'MessageOptions', name, full_name, file,
+        containing_type, serialized_start=serialized_start,
+        serialized_end=serialized_end)
 
     # We have fields in addition to fields_by_name and fields_by_number,
     # so that:
@@ -163,6 +247,20 @@ class Descriptor(DescriptorBase):
     for extension in self.extensions:
       extension.extension_scope = self
     self.extensions_by_name = dict((f.name, f) for f in extensions)
+    self.is_extendable = is_extendable
+    self.extension_ranges = extension_ranges
+
+    self._serialized_start = serialized_start
+    self._serialized_end = serialized_end
+
+  def CopyToProto(self, proto):
+    """Copies this to a descriptor_pb2.DescriptorProto.
+
+    Args:
+      proto: An empty descriptor_pb2.DescriptorProto.
+    """
+    # This function is overridden to give a better doc comment.
+    super(Descriptor, self).CopyToProto(proto)
 
 
 # TODO(robinson): We should have aggressive checking here,
@@ -195,6 +293,8 @@ class FieldDescriptor(DescriptorBase):
 
     label: (One of the LABEL_* constants below) Tells whether this
       field is optional, required, or repeated.
+    has_default_value: (bool) True if this field has a default value defined,
+      otherwise false.
     default_value: (Varies) Default value of this field.  Only
       meaningful for non-repeated scalar fields.  Repeated fields
       should always set this to [], and non-repeated composite
@@ -272,7 +372,8 @@ class FieldDescriptor(DescriptorBase):
 
   def __init__(self, name, full_name, index, number, type, cpp_type, label,
                default_value, message_type, enum_type, containing_type,
-               is_extension, extension_scope, options=None):
+               is_extension, extension_scope, options=None,
+               has_default_value=True):
     """The arguments are as described in the description of FieldDescriptor
     attributes above.
 
@@ -288,6 +389,7 @@ class FieldDescriptor(DescriptorBase):
     self.type = type
     self.cpp_type = cpp_type
     self.label = label
+    self.has_default_value = has_default_value
     self.default_value = default_value
     self.containing_type = containing_type
     self.message_type = message_type
@@ -296,7 +398,7 @@ class FieldDescriptor(DescriptorBase):
     self.extension_scope = extension_scope
 
 
-class EnumDescriptor(DescriptorBase):
+class EnumDescriptor(_NestedDescriptorBase):
 
   """Descriptor for an enum defined in a .proto file.
 
@@ -305,7 +407,6 @@ class EnumDescriptor(DescriptorBase):
     name: (str) Name of the enum type.
     full_name: (str) Full name of the type, including package name
       and any enclosing type(s).
-    filename: (str) Name of the .proto file in which this appears.
 
     values: (list of EnumValueDescriptors) List of the values
       in this enum.
@@ -317,23 +418,41 @@ class EnumDescriptor(DescriptorBase):
       type of this enum, or None if this is an enum defined at the
       top level in a .proto file.  Set by Descriptor's constructor
       if we're passed into one.
+    file: (FileDescriptor) Reference to file descriptor.
     options: (descriptor_pb2.EnumOptions) Enum options message or
       None to use default enum options.
   """
 
   def __init__(self, name, full_name, filename, values,
-               containing_type=None, options=None):
-    """Arguments are as described in the attribute description above."""
-    super(EnumDescriptor, self).__init__(options, 'EnumOptions')
-    self.name = name
-    self.full_name = full_name
-    self.filename = filename
+               containing_type=None, options=None, file=None,
+               serialized_start=None, serialized_end=None):
+    """Arguments are as described in the attribute description above.
+
+    Note that filename is an obsolete argument, that is not used anymore.
+    Please use file.name to access this as an attribute.
+    """
+    super(EnumDescriptor, self).__init__(
+        options, 'EnumOptions', name, full_name, file,
+        containing_type, serialized_start=serialized_start,
+        serialized_end=serialized_end)
+
     self.values = values
     for value in self.values:
       value.type = self
     self.values_by_name = dict((v.name, v) for v in values)
     self.values_by_number = dict((v.number, v) for v in values)
-    self.containing_type = containing_type
+
+    self._serialized_start = serialized_start
+    self._serialized_end = serialized_end
+
+  def CopyToProto(self, proto):
+    """Copies this to a descriptor_pb2.EnumDescriptorProto.
+
+    Args:
+      proto: An empty descriptor_pb2.EnumDescriptorProto.
+    """
+    # This function is overridden to give a better doc comment.
+    super(EnumDescriptor, self).CopyToProto(proto)
 
 
 class EnumValueDescriptor(DescriptorBase):
@@ -360,7 +479,7 @@ class EnumValueDescriptor(DescriptorBase):
     self.type = type
 
 
-class ServiceDescriptor(DescriptorBase):
+class ServiceDescriptor(_NestedDescriptorBase):
 
   """Descriptor for a service.
 
@@ -372,12 +491,15 @@ class ServiceDescriptor(DescriptorBase):
       service.
     options: (descriptor_pb2.ServiceOptions) Service options message or
       None to use default service options.
+    file: (FileDescriptor) Reference to file info.
   """
 
-  def __init__(self, name, full_name, index, methods, options=None):
-    super(ServiceDescriptor, self).__init__(options, 'ServiceOptions')
-    self.name = name
-    self.full_name = full_name
+  def __init__(self, name, full_name, index, methods, options=None, file=None,
+               serialized_start=None, serialized_end=None):
+    super(ServiceDescriptor, self).__init__(
+        options, 'ServiceOptions', name, full_name, file,
+        None, serialized_start=serialized_start,
+        serialized_end=serialized_end)
     self.index = index
     self.methods = methods
     # Set the containing service for each method in this service.
@@ -391,6 +513,15 @@ class ServiceDescriptor(DescriptorBase):
         return method
     return None
 
+  def CopyToProto(self, proto):
+    """Copies this to a descriptor_pb2.ServiceDescriptorProto.
+
+    Args:
+      proto: An empty descriptor_pb2.ServiceDescriptorProto.
+    """
+    # This function is overridden to give a better doc comment.
+    super(ServiceDescriptor, self).CopyToProto(proto)
+
 
 class MethodDescriptor(DescriptorBase):
 
@@ -423,6 +554,32 @@ class MethodDescriptor(DescriptorBase):
     self.output_type = output_type
 
 
+class FileDescriptor(DescriptorBase):
+  """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
+
+  name: name of file, relative to root of source tree.
+  package: name of the package
+  serialized_pb: (str) Byte string of serialized
+    descriptor_pb2.FileDescriptorProto.
+  """
+
+  def __init__(self, name, package, options=None, serialized_pb=None):
+    """Constructor."""
+    super(FileDescriptor, self).__init__(options, 'FileOptions')
+
+    self.name = name
+    self.package = package
+    self.serialized_pb = serialized_pb
+
+  def CopyToProto(self, proto):
+    """Copies this to a descriptor_pb2.FileDescriptorProto.
+
+    Args:
+      proto: An empty descriptor_pb2.FileDescriptorProto.
+    """
+    proto.ParseFromString(self.serialized_pb)
+
+
 def _ParseOptions(message, string):
   """Parses serialized options.
 
@@ -430,4 +587,4 @@ def _ParseOptions(message, string):
   proto2 files. It must not be used outside proto2.
   """
   message.ParseFromString(string)
-  return message;
+  return message

+ 43 - 26
python/google/protobuf/internal/containers.py

@@ -54,8 +54,7 @@ class BaseContainer(object):
     Args:
       message_listener: A MessageListener implementation.
         The RepeatedScalarFieldContainer will call this object's
-        TransitionToNonempty() method when it transitions from being empty to
-        being nonempty.
+        Modified() method when it is modified.
     """
     self._message_listener = message_listener
     self._values = []
@@ -73,6 +72,9 @@ class BaseContainer(object):
     # The concrete classes should define __eq__.
     return not self == other
 
+  def __repr__(self):
+    return repr(self._values)
+
 
 class RepeatedScalarFieldContainer(BaseContainer):
 
@@ -86,8 +88,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
     Args:
       message_listener: A MessageListener implementation.
         The RepeatedScalarFieldContainer will call this object's
-        TransitionToNonempty() method when it transitions from being empty to
-        being nonempty.
+        Modified() method when it is modified.
       type_checker: A type_checkers.ValueChecker instance to run on elements
         inserted into this container.
     """
@@ -96,44 +97,47 @@ class RepeatedScalarFieldContainer(BaseContainer):
 
   def append(self, value):
     """Appends an item to the list. Similar to list.append()."""
-    self.insert(len(self._values), value)
+    self._type_checker.CheckValue(value)
+    self._values.append(value)
+    if not self._message_listener.dirty:
+      self._message_listener.Modified()
 
   def insert(self, key, value):
     """Inserts the item at the specified position. Similar to list.insert()."""
     self._type_checker.CheckValue(value)
     self._values.insert(key, value)
-    self._message_listener.ByteSizeDirty()
-    if len(self._values) == 1:
-      self._message_listener.TransitionToNonempty()
+    if not self._message_listener.dirty:
+      self._message_listener.Modified()
 
   def extend(self, elem_seq):
     """Extends by appending the given sequence. Similar to list.extend()."""
     if not elem_seq:
       return
 
-    orig_empty = len(self._values) == 0
     new_values = []
     for elem in elem_seq:
       self._type_checker.CheckValue(elem)
       new_values.append(elem)
     self._values.extend(new_values)
-    self._message_listener.ByteSizeDirty()
-    if orig_empty:
-      self._message_listener.TransitionToNonempty()
+    self._message_listener.Modified()
+
+  def MergeFrom(self, other):
+    """Appends the contents of another repeated field of the same type to this
+    one. We do not check the types of the individual fields.
+    """
+    self._values.extend(other._values)
+    self._message_listener.Modified()
 
   def remove(self, elem):
     """Removes an item from the list. Similar to list.remove()."""
     self._values.remove(elem)
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __setitem__(self, key, value):
     """Sets the item on the specified position."""
-    # No need to call TransitionToNonempty(), since if we're able to
-    # set the element at this index, we were already nonempty before
-    # this method was called.
-    self._message_listener.ByteSizeDirty()
     self._type_checker.CheckValue(value)
     self._values[key] = value
+    self._message_listener.Modified()
 
   def __getslice__(self, start, stop):
     """Retrieves the subset of items from between the specified indices."""
@@ -146,17 +150,17 @@ class RepeatedScalarFieldContainer(BaseContainer):
       self._type_checker.CheckValue(value)
       new_values.append(value)
     self._values[start:stop] = new_values
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __delitem__(self, key):
     """Deletes the item at the specified position."""
     del self._values[key]
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __delslice__(self, start, stop):
     """Deletes the subset of items from between the specified indices."""
     del self._values[start:stop]
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __eq__(self, other):
     """Compares the current instance with another one."""
@@ -186,8 +190,7 @@ class RepeatedCompositeFieldContainer(BaseContainer):
     Args:
       message_listener: A MessageListener implementation.
         The RepeatedCompositeFieldContainer will call this object's
-        TransitionToNonempty() method when it transitions from being empty to
-        being nonempty.
+        Modified() method when it is modified.
       message_descriptor: A Descriptor instance describing the protocol type
         that should be present in this container.  We'll use the
         _concrete_class field of this descriptor when the client calls add().
@@ -199,10 +202,24 @@ class RepeatedCompositeFieldContainer(BaseContainer):
     new_element = self._message_descriptor._concrete_class()
     new_element._SetListener(self._message_listener)
     self._values.append(new_element)
-    self._message_listener.ByteSizeDirty()
-    self._message_listener.TransitionToNonempty()
+    if not self._message_listener.dirty:
+      self._message_listener.Modified()
     return new_element
 
+  def MergeFrom(self, other):
+    """Appends the contents of another repeated field of the same type to this
+    one, copying each individual message.
+    """
+    message_class = self._message_descriptor._concrete_class
+    listener = self._message_listener
+    values = self._values
+    for message in other._values:
+      new_element = message_class()
+      new_element._SetListener(listener)
+      new_element.MergeFrom(message)
+      values.append(new_element)
+    listener.Modified()
+
   def __getslice__(self, start, stop):
     """Retrieves the subset of items from between the specified indices."""
     return self._values[start:stop]
@@ -210,12 +227,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
   def __delitem__(self, key):
     """Deletes the item at the specified position."""
     del self._values[key]
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __delslice__(self, start, stop):
     """Deletes the subset of items from between the specified indices."""
     del self._values[start:stop]
-    self._message_listener.ByteSizeDirty()
+    self._message_listener.Modified()
 
   def __eq__(self, other):
     """Compares the current instance with another one."""

+ 601 - 169
python/google/protobuf/internal/decoder.py

@@ -28,182 +28,614 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-"""Class for decoding protocol buffer primitives.
-
-Contains the logic for decoding every logical protocol field type
-from one of the 5 physical wire types.
+"""Code for decoding protocol buffer primitives.
+
+This code is very similar to encoder.py -- read the docs for that module first.
+
+A "decoder" is a function with the signature:
+  Decode(buffer, pos, end, message, field_dict)
+The arguments are:
+  buffer:     The string containing the encoded message.
+  pos:        The current position in the string.
+  end:        The position in the string where the current message ends.  May be
+              less than len(buffer) if we're reading a sub-message.
+  message:    The message object into which we're parsing.
+  field_dict: message._fields (avoids a hashtable lookup).
+The decoder reads the field and stores it into field_dict, returning the new
+buffer position.  A decoder for a repeated field may proactively decode all of
+the elements of that field, if they appear consecutively.
+
+Note that decoders may throw any of the following:
+  IndexError:  Indicates a truncated message.
+  struct.error:  Unpacking of a fixed-width field failed.
+  message.DecodeError:  Other errors.
+
+Decoders are expected to raise an exception if they are called with pos > end.
+This allows callers to be lax about bounds checking:  it's fine to read past
+"end" as long as you are sure that someone else will notice and throw an
+exception later on.
+
+Something up the call stack is expected to catch IndexError and struct.error
+and convert them to message.DecodeError.
+
+Decoders are constructed using decoder constructors with the signature:
+  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
+The arguments are:
+  field_number:  The field number of the field we want to decode.
+  is_repeated:   Is the field a repeated field? (bool)
+  is_packed:     Is the field a packed field? (bool)
+  key:           The key to use when looking up the field within field_dict.
+                 (This is actually the FieldDescriptor but nothing in this
+                 file should depend on that.)
+  new_default:   A function which takes a message object as a parameter and
+                 returns a new instance of the default value for this field.
+                 (This is called for repeated fields and sub-messages, when an
+                 instance does not already exist.)
+
+As with encoders, we define a decoder constructor for every type of field.
+Then, for every field of every message class we construct an actual decoder.
+That decoder goes into a dict indexed by tag, so when we decode a message
+we repeatedly read a tag, look up the corresponding decoder, and invoke it.
 """
 
-__author__ = 'robinson@google.com (Will Robinson)'
+__author__ = 'kenton@google.com (Kenton Varda)'
 
 import struct
-from google.protobuf import message
-from google.protobuf.internal import input_stream
+from google.protobuf.internal import encoder
 from google.protobuf.internal import wire_format
+from google.protobuf import message
 
 
+# This is not for optimization, but rather to avoid conflicts with local
+# variables named "message".
+_DecodeError = message.DecodeError
+
+
+def _VarintDecoder(mask):
+  """Return a decoder for a basic varint value (does not include tag).
+
+  Decoded values will be bitwise-anded with the given mask before being
+  returned, e.g. to limit them to 32 bits.  The returned decoder does not
+  take the usual "end" parameter -- the caller is expected to do bounds checking
+  after the fact (often the caller can defer such checking until later).  The
+  decoder returns a (value, new_pos) pair.
+  """
+
+  local_ord = ord
+  def DecodeVarint(buffer, pos):
+    result = 0
+    shift = 0
+    while 1:
+      b = local_ord(buffer[pos])
+      result |= ((b & 0x7f) << shift)
+      pos += 1
+      if not (b & 0x80):
+        result &= mask
+        return (result, pos)
+      shift += 7
+      if shift >= 64:
+        raise _DecodeError('Too many bytes when decoding varint.')
+  return DecodeVarint
+
+
+def _SignedVarintDecoder(mask):
+  """Like _VarintDecoder() but decodes signed values."""
+
+  local_ord = ord
+  def DecodeVarint(buffer, pos):
+    result = 0
+    shift = 0
+    while 1:
+      b = local_ord(buffer[pos])
+      result |= ((b & 0x7f) << shift)
+      pos += 1
+      if not (b & 0x80):
+        if result > 0x7fffffffffffffff:
+          result -= (1 << 64)
+          result |= ~mask
+        else:
+          result &= mask
+        return (result, pos)
+      shift += 7
+      if shift >= 64:
+        raise _DecodeError('Too many bytes when decoding varint.')
+  return DecodeVarint
+
+
+_DecodeVarint = _VarintDecoder((1 << 64) - 1)
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
+
+# Use these versions for values which must be limited to 32 bits.
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
+
+
+def ReadTag(buffer, pos):
+  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
+
+  We return the raw bytes of the tag rather than decoding them.  The raw
+  bytes can then be used to look up the proper decoder.  This effectively allows
+  us to trade some work that would be done in pure-python (decoding a varint)
+  for work that is done in C (searching for a byte string in a hash table).
+  In a low-level language it would be much cheaper to decode the varint and
+  use that, but not in Python.
+  """
+
+  start = pos
+  while ord(buffer[pos]) & 0x80:
+    pos += 1
+  pos += 1
+  return (buffer[start:pos], pos)
+
+
+# --------------------------------------------------------------------
+
+
+def _SimpleDecoder(wire_type, decode_value):
+  """Return a constructor for a decoder for fields of a particular type.
+
+  Args:
+      wire_type:  The field's wire type.
+      decode_value:  A function which decodes an individual value, e.g.
+        _DecodeVarint()
+  """
+
+  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
+    if is_packed:
+      local_DecodeVarint = _DecodeVarint
+      def DecodePackedField(buffer, pos, end, message, field_dict):
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        (endpoint, pos) = local_DecodeVarint(buffer, pos)
+        endpoint += pos
+        if endpoint > end:
+          raise _DecodeError('Truncated message.')
+        while pos < endpoint:
+          (element, pos) = decode_value(buffer, pos)
+          value.append(element)
+        if pos > endpoint:
+          del value[-1]   # Discard corrupt value.
+          raise _DecodeError('Packed element was truncated.')
+        return pos
+      return DecodePackedField
+    elif is_repeated:
+      tag_bytes = encoder.TagBytes(field_number, wire_type)
+      tag_len = len(tag_bytes)
+      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        while 1:
+          (element, new_pos) = decode_value(buffer, pos)
+          value.append(element)
+          # Predict that the next tag is another copy of the same repeated
+          # field.
+          pos = new_pos + tag_len
+          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+            # Prediction failed.  Return.
+            if new_pos > end:
+              raise _DecodeError('Truncated message.')
+            return new_pos
+      return DecodeRepeatedField
+    else:
+      def DecodeField(buffer, pos, end, message, field_dict):
+        (field_dict[key], pos) = decode_value(buffer, pos)
+        if pos > end:
+          del field_dict[key]  # Discard corrupt value.
+          raise _DecodeError('Truncated message.')
+        return pos
+      return DecodeField
+
+  return SpecificDecoder
+
+
+def _ModifiedDecoder(wire_type, decode_value, modify_value):
+  """Like _SimpleDecoder but additionally invokes modify_value on every value
+  before storing it.  Usually modify_value is ZigZagDecode.
+  """
+
+  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+  # not enough to make a significant difference.
+
+  def InnerDecode(buffer, pos):
+    (result, new_pos) = decode_value(buffer, pos)
+    return (modify_value(result), new_pos)
+  return _SimpleDecoder(wire_type, InnerDecode)
+
+
+def _StructPackDecoder(wire_type, format):
+  """Return a constructor for a decoder for a fixed-width field.
+
+  Args:
+      wire_type:  The field's wire type.
+      format:  The format string to pass to struct.unpack().
+  """
+
+  value_size = struct.calcsize(format)
+  local_unpack = struct.unpack
+
+  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+  # not enough to make a significant difference.
+
+  # Note that we expect someone up-stack to catch struct.error and convert
+  # it to _DecodeError -- this way we don't have to set up exception-
+  # handling blocks every time we parse one value.
+
+  def InnerDecode(buffer, pos):
+    new_pos = pos + value_size
+    result = local_unpack(format, buffer[pos:new_pos])[0]
+    return (result, new_pos)
+  return _SimpleDecoder(wire_type, InnerDecode)
+
+
+# --------------------------------------------------------------------
+
+
+Int32Decoder = EnumDecoder = _SimpleDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
+
+Int64Decoder = _SimpleDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
+
+UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
+UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
+
+SInt32Decoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
+SInt64Decoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatDecoder    = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleDecoder   = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
+
+BoolDecoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
+
+
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a string field."""
+
+  local_DecodeVarint = _DecodeVarint
+  local_unicode = unicode
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated string.')
+        value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated string.')
+      field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
+      return new_pos
+    return DecodeField
+
+
+def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a bytes field."""
+
+  local_DecodeVarint = _DecodeVarint
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated string.')
+        value.append(buffer[pos:new_pos])
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated string.')
+      field_dict[key] = buffer[pos:new_pos]
+      return new_pos
+    return DecodeField
+
+
+def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a group field."""
+
+  end_tag_bytes = encoder.TagBytes(field_number,
+                                   wire_format.WIRETYPE_END_GROUP)
+  end_tag_len = len(end_tag_bytes)
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_START_GROUP)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        # Read sub-message.
+        pos = value.add()._InternalParse(buffer, pos, end)
+        # Read end tag.
+        new_pos = pos+end_tag_len
+        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+          raise _DecodeError('Missing group end tag.')
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      # Read sub-message.
+      pos = value._InternalParse(buffer, pos, end)
+      # Read end tag.
+      new_pos = pos+end_tag_len
+      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+        raise _DecodeError('Missing group end tag.')
+      return new_pos
+    return DecodeField
+
+
+def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a message field."""
+
+  local_DecodeVarint = _DecodeVarint
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        # Read length.
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated message.')
+        # Read sub-message.
+        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+          # The only reason _InternalParse would return early is if it
+          # encountered an end-group tag.
+          raise _DecodeError('Unexpected end-group tag.')
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      # Read length.
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated message.')
+      # Read sub-message.
+      if value._InternalParse(buffer, pos, new_pos) != new_pos:
+        # The only reason _InternalParse would return early is if it encountered
+        # an end-group tag.
+        raise _DecodeError('Unexpected end-group tag.')
+      return new_pos
+    return DecodeField
+
+
+# --------------------------------------------------------------------
+
+MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
+
+def MessageSetItemDecoder(extensions_by_number):
+  """Returns a decoder for a MessageSet item.
+
+  The parameter is the _extensions_by_number map for the message class.
+
+  The message set message looks like this:
+    message MessageSet {
+      repeated group Item = 1 {
+        required int32 type_id = 2;
+        required string message = 3;
+      }
+    }
+  """
+
+  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
+  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+
+  local_ReadTag = ReadTag
+  local_DecodeVarint = _DecodeVarint
+  local_SkipField = SkipField
+
+  def DecodeItem(buffer, pos, end, message, field_dict):
+    type_id = -1
+    message_start = -1
+    message_end = -1
+
+    # Technically, type_id and message can appear in any order, so we need
+    # a little loop here.
+    while 1:
+      (tag_bytes, pos) = local_ReadTag(buffer, pos)
+      if tag_bytes == type_id_tag_bytes:
+        (type_id, pos) = local_DecodeVarint(buffer, pos)
+      elif tag_bytes == message_tag_bytes:
+        (size, message_start) = local_DecodeVarint(buffer, pos)
+        pos = message_end = message_start + size
+      elif tag_bytes == item_end_tag_bytes:
+        break
+      else:
+        pos = SkipField(buffer, pos, end, tag_bytes)
+        if pos == -1:
+          raise _DecodeError('Missing group end tag.')
+
+    if pos > end:
+      raise _DecodeError('Truncated message.')
+
+    if type_id == -1:
+      raise _DecodeError('MessageSet item missing type_id.')
+    if message_start == -1:
+      raise _DecodeError('MessageSet item missing message.')
+
+    extension = extensions_by_number.get(type_id)
+    if extension is not None:
+      value = field_dict.get(extension)
+      if value is None:
+        value = field_dict.setdefault(
+            extension, extension.message_type._concrete_class())
+      if value._InternalParse(buffer, message_start,message_end) != message_end:
+        # The only reason _InternalParse would return early is if it encountered
+        # an end-group tag.
+        raise _DecodeError('Unexpected end-group tag.')
+
+    return pos
+
+  return DecodeItem
+
+# --------------------------------------------------------------------
+# Optimization is not as heavy here because calls to SkipField() are rare,
+# except for handling end-group tags.
+
+def _SkipVarint(buffer, pos, end):
+  """Skip a varint value.  Returns the new position."""
+
+  while ord(buffer[pos]) & 0x80:
+    pos += 1
+  pos += 1
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+def _SkipFixed64(buffer, pos, end):
+  """Skip a fixed64 value.  Returns the new position."""
+
+  pos += 8
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+def _SkipLengthDelimited(buffer, pos, end):
+  """Skip a length-delimited value.  Returns the new position."""
+
+  (size, pos) = _DecodeVarint(buffer, pos)
+  pos += size
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+def _SkipGroup(buffer, pos, end):
+  """Skip sub-group.  Returns the new position."""
+
+  while 1:
+    (tag_bytes, pos) = ReadTag(buffer, pos)
+    new_pos = SkipField(buffer, pos, end, tag_bytes)
+    if new_pos == -1:
+      return pos
+    pos = new_pos
+
+def _EndGroup(buffer, pos, end):
+  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.
+
+  The arguments are ignored; they exist so this function can sit in the same
+  dispatch table as the other skippers.
+  """
+
+  # -1 is the sentinel that SkipField's callers test for.
+  return -1
+
+def _SkipFixed32(buffer, pos, end):
+  """Skip a fixed32 value.  Returns the new position."""
+
+  pos += 4
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+def _RaiseInvalidWireType(buffer, pos, end):
+  """Skip function for unknown wire types.  Raises an exception.
+
+  Never returns: it always raises _DecodeError.  The arguments are ignored;
+  they exist so this function can sit in the same dispatch table as the
+  other skippers.
+  """
+
+  raise _DecodeError('Tag had invalid wire type.')
+
+def _FieldSkipper():
+  """Constructs the SkipField function.
+
+  The dispatch table and the wire-type mask are built once here and captured
+  by the returned closure.
+  """
+
+  # Indexed by the wire type taken from the tag: entry i skips a value of
+  # wire type i.  The last two entries reject wire types that have no
+  # skipper.  (Presumably TAG_TYPE_MASK selects three bits, hence eight
+  # entries -- confirm in wire_format.)
+  WIRETYPE_TO_SKIPPER = [
+      _SkipVarint,
+      _SkipFixed64,
+      _SkipLengthDelimited,
+      _SkipGroup,
+      _EndGroup,
+      _SkipFixed32,
+      _RaiseInvalidWireType,
+      _RaiseInvalidWireType,
+      ]
+
+  wiretype_mask = wire_format.TAG_TYPE_MASK
+  local_ord = ord
+
+  def SkipField(buffer, pos, end, tag_bytes):
+    """Skips a field with the specified tag.
+
+    |pos| should point to the byte immediately after the tag.
+
+    Returns:
+        The new position (after the tag value), or -1 if the tag is an end-group
+        tag (in which case the calling loop should break).
+    """
 
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by WireFormat from the C++ proto2
-# implementation.
-
-
-class Decoder(object):
-
-  """Decodes logical protocol buffer fields from the wire."""
+    # The wire type is always in the first byte since varints are little-endian.
+    wire_type = local_ord(tag_bytes[0]) & wiretype_mask
+    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
 
-  def __init__(self, s):
-    """Initializes the decoder to read from s.
+  return SkipField
 
-    Args:
-      s: An immutable sequence of bytes, which must be accessible
-        via the Python buffer() primitive (i.e., buffer(s)).
-    """
-    self._stream = input_stream.InputStream(s)
-
-  def EndOfStream(self):
-    """Returns true iff we've reached the end of the bytes we're reading."""
-    return self._stream.EndOfStream()
-
-  def Position(self):
-    """Returns the 0-indexed position in |s|."""
-    return self._stream.Position()
-
-  def ReadFieldNumberAndWireType(self):
-    """Reads a tag from the wire. Returns a (field_number, wire_type) pair."""
-    tag_and_type = self.ReadUInt32()
-    return wire_format.UnpackTag(tag_and_type)
-
-  def SkipBytes(self, bytes):
-    """Skips the specified number of bytes on the wire."""
-    self._stream.SkipBytes(bytes)
-
-  # Note that the Read*() methods below are not exactly symmetrical with the
-  # corresponding Encoder.Append*() methods.  Those Encoder methods first
-  # encode a tag, but the Read*() methods below assume that the tag has already
-  # been read, and that the client wishes to read a field of the specified type
-  # starting at the current position.
-
-  def ReadInt32(self):
-    """Reads and returns a signed, varint-encoded, 32-bit integer."""
-    return self._stream.ReadVarint32()
-
-  def ReadInt64(self):
-    """Reads and returns a signed, varint-encoded, 64-bit integer."""
-    return self._stream.ReadVarint64()
-
-  def ReadUInt32(self):
-    """Reads and returns an signed, varint-encoded, 32-bit integer."""
-    return self._stream.ReadVarUInt32()
-
-  def ReadUInt64(self):
-    """Reads and returns an signed, varint-encoded,64-bit integer."""
-    return self._stream.ReadVarUInt64()
-
-  def ReadSInt32(self):
-    """Reads and returns a signed, zigzag-encoded, varint-encoded,
-    32-bit integer."""
-    return wire_format.ZigZagDecode(self._stream.ReadVarUInt32())
-
-  def ReadSInt64(self):
-    """Reads and returns a signed, zigzag-encoded, varint-encoded,
-    64-bit integer."""
-    return wire_format.ZigZagDecode(self._stream.ReadVarUInt64())
-
-  def ReadFixed32(self):
-    """Reads and returns an unsigned, fixed-width, 32-bit integer."""
-    return self._stream.ReadLittleEndian32()
-
-  def ReadFixed64(self):
-    """Reads and returns an unsigned, fixed-width, 64-bit integer."""
-    return self._stream.ReadLittleEndian64()
-
-  def ReadSFixed32(self):
-    """Reads and returns a signed, fixed-width, 32-bit integer."""
-    value = self._stream.ReadLittleEndian32()
-    if value >= (1 << 31):
-      value -= (1 << 32)
-    return value
-
-  def ReadSFixed64(self):
-    """Reads and returns a signed, fixed-width, 64-bit integer."""
-    value = self._stream.ReadLittleEndian64()
-    if value >= (1 << 63):
-      value -= (1 << 64)
-    return value
-
-  def ReadFloat(self):
-    """Reads and returns a 4-byte floating-point number."""
-    serialized = self._stream.ReadBytes(4)
-    return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0]
-
-  def ReadDouble(self):
-    """Reads and returns an 8-byte floating-point number."""
-    serialized = self._stream.ReadBytes(8)
-    return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0]
-
-  def ReadBool(self):
-    """Reads and returns a bool."""
-    i = self._stream.ReadVarUInt32()
-    return bool(i)
-
-  def ReadEnum(self):
-    """Reads and returns an enum value."""
-    return self._stream.ReadVarUInt32()
-
-  def ReadString(self):
-    """Reads and returns a length-delimited string."""
-    bytes = self.ReadBytes()
-    return unicode(bytes, 'utf-8')
-
-  def ReadBytes(self):
-    """Reads and returns a length-delimited byte sequence."""
-    length = self._stream.ReadVarUInt32()
-    return self._stream.ReadBytes(length)
-
-  def ReadMessageInto(self, msg):
-    """Calls msg.MergeFromString() to merge
-    length-delimited serialized message data into |msg|.
-
-    REQUIRES: The decoder must be positioned at the serialized "length"
-      prefix to a length-delmiited serialized message.
-
-    POSTCONDITION: The decoder is positioned just after the
-      serialized message, and we have merged those serialized
-      contents into |msg|.
-    """
-    length = self._stream.ReadVarUInt32()
-    sub_buffer = self._stream.GetSubBuffer(length)
-    num_bytes_used = msg.MergeFromString(sub_buffer)
-    if num_bytes_used != length:
-      raise message.DecodeError(
-          'Submessage told to deserialize from %d-byte encoding, '
-          'but used only %d bytes' % (length, num_bytes_used))
-    self._stream.SkipBytes(num_bytes_used)
-
-  def ReadGroupInto(self, expected_field_number, group):
-    """Calls group.MergeFromString() to merge
-    END_GROUP-delimited serialized message data into |group|.
-    We'll raise an exception if we don't find an END_GROUP
-    tag immediately after the serialized message contents.
-
-    REQUIRES: The decoder is positioned just after the START_GROUP
-      tag for this group.
-
-    POSTCONDITION: The decoder is positioned just after the
-      END_GROUP tag for this group, and we have merged
-      the contents of the group into |group|.
-    """
-    sub_buffer = self._stream.GetSubBuffer()  # No a priori length limit.
-    num_bytes_used = group.MergeFromString(sub_buffer)
-    if num_bytes_used < 0:
-      raise message.DecodeError('Group message reported negative bytes read.')
-    self._stream.SkipBytes(num_bytes_used)
-    field_number, field_type = self.ReadFieldNumberAndWireType()
-    if field_type != wire_format.WIRETYPE_END_GROUP:
-      raise message.DecodeError('Group message did not end with an END_GROUP.')
-    if field_number != expected_field_number:
-      raise message.DecodeError('END_GROUP tag had field '
-                                'number %d, was expecting field number %d' % (
-          field_number, expected_field_number))
-    # We're now positioned just after the END_GROUP tag.  Perfect.
+SkipField = _FieldSkipper()

+ 0 - 256
python/google/protobuf/internal/decoder_test.py

@@ -1,256 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.decoder."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import struct
-import unittest
-from google.protobuf.internal import decoder
-from google.protobuf.internal import encoder
-from google.protobuf.internal import input_stream
-from google.protobuf.internal import wire_format
-from google.protobuf import message
-import logging
-import mox
-
-
-class DecoderTest(unittest.TestCase):
-
-  def setUp(self):
-    self.mox = mox.Mox()
-    self.mock_stream = self.mox.CreateMock(input_stream.InputStream)
-    self.mock_message = self.mox.CreateMock(message.Message)
-
-  def testReadFieldNumberAndWireType(self):
-    # Test field numbers that will require various varint sizes.
-    for expected_field_number in (1, 15, 16, 2047, 2048):
-      for expected_wire_type in range(6):  # Highest-numbered wiretype is 5.
-        e = encoder.Encoder()
-        e.AppendTag(expected_field_number, expected_wire_type)
-        s = e.ToString()
-        d = decoder.Decoder(s)
-        field_number, wire_type = d.ReadFieldNumberAndWireType()
-        self.assertEqual(expected_field_number, field_number)
-        self.assertEqual(expected_wire_type, wire_type)
-
-  def ReadScalarTestHelper(self, test_name, decoder_method, expected_result,
-                           expected_stream_method_name,
-                           stream_method_return, *args):
-    """Helper for testReadScalars below.
-
-    Calls one of the Decoder.Read*() methods and ensures that the results are
-    as expected.
-
-    Args:
-      test_name: Name of this test, used for logging only.
-      decoder_method: Unbound decoder.Decoder method to call.
-      expected_result: Value we expect returned from decoder_method().
-      expected_stream_method_name: (string) Name of the InputStream
-        method we expect Decoder to call to actually read the value
-        on the wire.
-      stream_method_return: Value our mocked-out stream method should
-        return to the decoder.
-      args: Additional arguments that we expect to be passed to the
-        stream method.
-    """
-    logging.info('Testing %s scalar input.\n'
-                 'Calling %r(), and expecting that to call the '
-                 'stream method %s(%r), which will return %r.  Finally, '
-                 'expecting the Decoder method to return %r'% (
-        test_name, decoder_method,
-        expected_stream_method_name, args, stream_method_return,
-        expected_result))
-
-    d = decoder.Decoder('')
-    d._stream = self.mock_stream
-    if decoder_method in (decoder.Decoder.ReadString,
-                          decoder.Decoder.ReadBytes):
-      self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return))
-    # We have to use names instead of methods to work around some
-    # mox weirdness.  (ResetAll() is overzealous).
-    expected_stream_method = getattr(self.mock_stream,
-                                     expected_stream_method_name)
-    expected_stream_method(*args).AndReturn(stream_method_return)
-
-    self.mox.ReplayAll()
-    result = decoder_method(d)
-    self.assertEqual(expected_result, result)
-    self.assert_(isinstance(result, type(expected_result)))
-    self.mox.VerifyAll()
-    self.mox.ResetAll()
-
-  VAL = 1.125  # Perfectly representable as a float (no rounding error).
-  LITTLE_FLOAT_VAL = '\x00\x00\x90?'
-  LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
-
-  def testReadScalars(self):
-    test_string = 'I can feel myself getting sutpider.'
-    scalar_tests = [
-        ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0],
-        ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0],
-        ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0],
-        ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0],
-        ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff,
-         'ReadLittleEndian32', 0xffffffff],
-        ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff,
-        'ReadLittleEndian64', 0xffffffffffffffff],
-        ['sfixed32', decoder.Decoder.ReadSFixed32, long(-1),
-         'ReadLittleEndian32', long(0xffffffff)],
-        ['sfixed64', decoder.Decoder.ReadSFixed64, long(-1),
-         'ReadLittleEndian64', 0xffffffffffffffff],
-        ['float', decoder.Decoder.ReadFloat, self.VAL,
-         'ReadBytes', self.LITTLE_FLOAT_VAL, 4],
-        ['double', decoder.Decoder.ReadDouble, self.VAL,
-         'ReadBytes', self.LITTLE_DOUBLE_VAL, 8],
-        ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1],
-        ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23],
-        ['string', decoder.Decoder.ReadString,
-         unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
-         len(test_string)],
-        ['utf8-string', decoder.Decoder.ReadString,
-         unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
-         len(test_string)],
-        ['bytes', decoder.Decoder.ReadBytes,
-         test_string, 'ReadBytes', test_string, len(test_string)],
-        # We test zigzag decoding routines more extensively below.
-        ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1],
-        ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1],
-        ]
-    # Ensure that we're testing different Decoder methods and using
-    # different test names in all test cases above.
-    self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
-    self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
-    for args in scalar_tests:
-      self.ReadScalarTestHelper(*args)
-
-  def testReadMessageInto(self):
-    length = 23
-    def Test(simulate_error):
-      d = decoder.Decoder('')
-      d._stream = self.mock_stream
-      self.mock_stream.ReadVarUInt32().AndReturn(length)
-      sub_buffer = object()
-      self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer)
-
-      if simulate_error:
-        self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1)
-        self.mox.ReplayAll()
-        self.assertRaises(
-            message.DecodeError, d.ReadMessageInto, self.mock_message)
-      else:
-        self.mock_message.MergeFromString(sub_buffer).AndReturn(length)
-        self.mock_stream.SkipBytes(length)
-        self.mox.ReplayAll()
-        d.ReadMessageInto(self.mock_message)
-
-      self.mox.VerifyAll()
-      self.mox.ResetAll()
-
-    Test(simulate_error=False)
-    Test(simulate_error=True)
-
-  def testReadGroupInto_Success(self):
-    # Test both the empty and nonempty cases.
-    for num_bytes in (5, 0):
-      field_number = expected_field_number = 10
-      d = decoder.Decoder('')
-      d._stream = self.mock_stream
-      sub_buffer = object()
-      self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
-      self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes)
-      self.mock_stream.SkipBytes(num_bytes)
-      self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
-          field_number, wire_format.WIRETYPE_END_GROUP))
-      self.mox.ReplayAll()
-      d.ReadGroupInto(expected_field_number, self.mock_message)
-      self.mox.VerifyAll()
-      self.mox.ResetAll()
-
-  def ReadGroupInto_FailureTestHelper(self, bytes_read):
-    d = decoder.Decoder('')
-    d._stream = self.mock_stream
-    sub_buffer = object()
-    self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
-    self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read)
-    return d
-
-  def testReadGroupInto_NegativeBytesReported(self):
-    expected_field_number = 10
-    d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1)
-    self.mox.ReplayAll()
-    self.assertRaises(message.DecodeError,
-                      d.ReadGroupInto, expected_field_number,
-                      self.mock_message)
-    self.mox.VerifyAll()
-
-  def testReadGroupInto_NoEndGroupTag(self):
-    field_number = expected_field_number = 10
-    num_bytes = 5
-    d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
-    self.mock_stream.SkipBytes(num_bytes)
-    # Right field number, wrong wire type.
-    self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
-        field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
-    self.mox.ReplayAll()
-    self.assertRaises(message.DecodeError,
-                      d.ReadGroupInto, expected_field_number,
-                      self.mock_message)
-    self.mox.VerifyAll()
-
-  def testReadGroupInto_WrongFieldNumberInEndGroupTag(self):
-    expected_field_number = 10
-    field_number = expected_field_number + 1
-    num_bytes = 5
-    d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
-    self.mock_stream.SkipBytes(num_bytes)
-    # Wrong field number, right wire type.
-    self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
-        field_number, wire_format.WIRETYPE_END_GROUP))
-    self.mox.ReplayAll()
-    self.assertRaises(message.DecodeError,
-                      d.ReadGroupInto, expected_field_number,
-                      self.mock_message)
-    self.mox.VerifyAll()
-
-  def testSkipBytes(self):
-    d = decoder.Decoder('')
-    num_bytes = 1024
-    self.mock_stream.SkipBytes(num_bytes)
-    d._stream = self.mock_stream
-    self.mox.ReplayAll()
-    d.SkipBytes(num_bytes)
-    self.mox.VerifyAll()
-
-if __name__ == '__main__':
-  unittest.main()

+ 224 - 3
python/google/protobuf/internal/descriptor_test.py

@@ -35,16 +35,30 @@
 __author__ = 'robinson@google.com (Will Robinson)'
 
 import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_pb2
 from google.protobuf import descriptor_pb2
 from google.protobuf import descriptor
+from google.protobuf import text_format
+
+
+TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """
+name: 'TestEmptyMessage'
+"""
+
 
 class DescriptorTest(unittest.TestCase):
 
   def setUp(self):
+    self.my_file = descriptor.FileDescriptor(
+        name='some/filename/some.proto',
+        package='protobuf_unittest'
+        )
     self.my_enum = descriptor.EnumDescriptor(
         name='ForeignEnum',
         full_name='protobuf_unittest.ForeignEnum',
-        filename='ForeignEnum',
+        filename=None,
+        file=self.my_file,
         values=[
           descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
           descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
@@ -53,7 +67,8 @@ class DescriptorTest(unittest.TestCase):
     self.my_message = descriptor.Descriptor(
         name='NestedMessage',
         full_name='protobuf_unittest.TestAllTypes.NestedMessage',
-        filename='some/filename/some.proto',
+        filename=None,
+        file=self.my_file,
         containing_type=None,
         fields=[
           descriptor.FieldDescriptor(
@@ -61,7 +76,7 @@ class DescriptorTest(unittest.TestCase):
             full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
             index=0, number=1,
             type=5, cpp_type=1, label=1,
-            default_value=0,
+            has_default_value=False, default_value=0,
             message_type=None, enum_type=None, containing_type=None,
             is_extension=False, extension_scope=None),
         ],
@@ -80,6 +95,7 @@ class DescriptorTest(unittest.TestCase):
     self.my_service = descriptor.ServiceDescriptor(
         name='TestServiceWithOptions',
         full_name='protobuf_unittest.TestServiceWithOptions',
+        file=self.my_file,
         index=0,
         methods=[
             self.my_method
@@ -109,5 +125,210 @@ class DescriptorTest(unittest.TestCase):
     self.assertEqual(self.my_service.GetOptions(),
                      descriptor_pb2.ServiceOptions())
 
+  def testFileDescriptorReferences(self):
+    self.assertEqual(self.my_enum.file, self.my_file)
+    self.assertEqual(self.my_message.file, self.my_file)
+
+  def testFileDescriptor(self):
+    self.assertEqual(self.my_file.name, 'some/filename/some.proto')
+    self.assertEqual(self.my_file.package, 'protobuf_unittest')
+
+
+class DescriptorCopyToProtoTest(unittest.TestCase):
+  """Tests for CopyTo functions of Descriptor."""
+
+  def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
+    """Asserts that actual_proto equals the proto described by expected_ascii.
+
+    Args:
+      actual_proto: The proto message under test.
+      expected_class: Class instantiated to hold the expected proto.
+      expected_ascii: Text merged into the expected proto with
+        text_format.Merge.
+    """
+    expected_proto = expected_class()
+    text_format.Merge(expected_ascii, expected_proto)
+
+    # The failure message prints both protos so a mismatch can be read
+    # directly from the test output.
+    self.assertEqual(
+        actual_proto, expected_proto,
+        'Not equal,\nActual:\n%s\nExpected:\n%s\n'
+        % (str(actual_proto), str(expected_proto)))
+
+  def _InternalTestCopyToProto(self, desc, expected_proto_class,
+                               expected_proto_ascii):
+    """Copies desc into a fresh proto and compares it with the expected text.
+
+    Args:
+      desc: Descriptor object whose CopyToProto method is exercised.
+      expected_proto_class: Proto class used both as the CopyToProto target
+        and to build the expected value.
+      expected_proto_ascii: Text form of the expected proto.
+    """
+    actual = expected_proto_class()
+    desc.CopyToProto(actual)
+    self._AssertProtoEqual(
+        actual, expected_proto_class, expected_proto_ascii)
+
+  # Each test below pairs a descriptor from the generated unittest modules
+  # with the text form of the proto that CopyToProto is expected to produce.
+
+  def testCopyToProto_EmptyMessage(self):
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestEmptyMessage.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII)
+
+  def testCopyToProto_NestedMessage(self):
+    TEST_NESTED_MESSAGE_ASCII = """
+      name: 'NestedMessage'
+      field: <
+        name: 'bb'
+        number: 1
+        label: 1  # Optional
+        type: 5  # TYPE_INT32
+      >
+      """
+
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_NESTED_MESSAGE_ASCII)
+
+  def testCopyToProto_ForeignNestedMessage(self):
+    TEST_FOREIGN_NESTED_ASCII = """
+      name: 'TestForeignNested'
+      field: <
+        name: 'foreign_nested'
+        number: 1
+        label: 1  # Optional
+        type: 11  # TYPE_MESSAGE
+        type_name: '.protobuf_unittest.TestAllTypes.NestedMessage'
+      >
+      """
+
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestForeignNested.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_FOREIGN_NESTED_ASCII)
+
+  def testCopyToProto_ForeignEnum(self):
+    TEST_FOREIGN_ENUM_ASCII = """
+      name: 'ForeignEnum'
+      value: <
+        name: 'FOREIGN_FOO'
+        number: 4
+      >
+      value: <
+        name: 'FOREIGN_BAR'
+        number: 5
+      >
+      value: <
+        name: 'FOREIGN_BAZ'
+        number: 6
+      >
+      """
+
+    # Enum descriptors are copied into an EnumDescriptorProto rather than a
+    # DescriptorProto.
+    self._InternalTestCopyToProto(
+        unittest_pb2._FOREIGNENUM,
+        descriptor_pb2.EnumDescriptorProto,
+        TEST_FOREIGN_ENUM_ASCII)
+
+  def testCopyToProto_Options(self):
+    TEST_DEPRECATED_FIELDS_ASCII = """
+      name: 'TestDeprecatedFields'
+      field: <
+        name: 'deprecated_int32'
+        number: 1
+        label: 1  # Optional
+        type: 5  # TYPE_INT32
+        options: <
+          deprecated: true
+        >
+      >
+      """
+
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestDeprecatedFields.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_DEPRECATED_FIELDS_ASCII)
+
+  def testCopyToProto_AllExtensions(self):
+    TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """
+      name: 'TestEmptyMessageWithExtensions'
+      extension_range: <
+        start: 1
+        end: 536870912
+      >
+      """
+
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII)
+
+  def testCopyToProto_SeveralExtensions(self):
+    TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """
+      name: 'TestMultipleExtensionRanges'
+      extension_range: <
+        start: 42
+        end: 43
+      >
+      extension_range: <
+        start: 4143
+        end: 4244
+      >
+      extension_range: <
+        start: 65536
+        end: 536870912
+      >
+      """
+
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR,
+        descriptor_pb2.DescriptorProto,
+        TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
+
+  def testCopyToProto_FileDescriptor(self):
+    UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
+      name: 'google/protobuf/unittest_import.proto'
+      package: 'protobuf_unittest_import'
+      message_type: <
+        name: 'ImportMessage'
+        field: <
+          name: 'd'
+          number: 1
+          label: 1  # Optional
+          type: 5  # TYPE_INT32
+        >
+      >
+      """ +
+      """enum_type: <
+        name: 'ImportEnum'
+        value: <
+          name: 'IMPORT_FOO'
+          number: 7
+        >
+        value: <
+          name: 'IMPORT_BAR'
+          number: 8
+        >
+        value: <
+          name: 'IMPORT_BAZ'
+          number: 9
+        >
+      >
+      options: <
+        java_package: 'com.google.protobuf.test'
+        optimize_for: 1  # SPEED
+      >
+      """)
+
+    # A whole-file descriptor is copied into a FileDescriptorProto.
+    self._InternalTestCopyToProto(
+        unittest_import_pb2.DESCRIPTOR,
+        descriptor_pb2.FileDescriptorProto,
+        UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
+
+  def testCopyToProto_ServiceDescriptor(self):
+    TEST_SERVICE_ASCII = """
+      name: 'TestService'
+      method: <
+        name: 'Foo'
+        input_type: '.protobuf_unittest.FooRequest'
+        output_type: '.protobuf_unittest.FooResponse'
+      >
+      method: <
+        name: 'Bar'
+        input_type: '.protobuf_unittest.BarRequest'
+        output_type: '.protobuf_unittest.BarResponse'
+      >
+      """
+
+    # Service descriptors are copied into a ServiceDescriptorProto.
+    self._InternalTestCopyToProto(
+        unittest_pb2.TestService.DESCRIPTOR,
+        descriptor_pb2.ServiceDescriptorProto,
+        TEST_SERVICE_ASCII)
+
+
 if __name__ == '__main__':
   unittest.main()

+ 647 - 241
python/google/protobuf/internal/encoder.py

@@ -28,253 +28,659 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-"""Class for encoding protocol message primitives.
+"""Code for encoding protocol message primitives.
 
 Contains the logic for encoding every logical protocol field type
 into one of the 5 physical wire types.
+
+This code is designed to push the Python interpreter's performance to the
+limits.
+
+The basic idea is that at startup time, for every field (i.e. every
+FieldDescriptor) we construct two functions:  a "sizer" and an "encoder".  The
+sizer takes a value of this field's type and computes its byte size.  The
+encoder takes a writer function and a value.  It encodes the value into byte
+strings and invokes the writer function to write those strings.  Typically the
+writer function is the write() method of a cStringIO.
+
+We try to do as much work as possible when constructing the encoder and the
+sizer rather than when calling them.  In particular:
+* We copy any needed global functions to local variables, so that we do not need
+  to do costly global table lookups at runtime.
+* Similarly, we try to do any attribute lookups at startup time if possible.
+* Every field's tag is encoded to bytes at startup, since it can't change at
+  runtime.
+* Whatever component of the field size we can compute at startup, we do.
+* We *avoid* sharing code if doing so would make the code slower and not sharing
+  does not burden us too much.  For example, encoders for repeated fields do
+  not just call the encoders for singular fields in a loop because this would
+  add an extra function call overhead for every loop iteration; instead, we
+  manually inline the single-value encoder into the loop.
+* If a Python function lacks a return statement, Python actually generates
+  instructions to pop the result of the last statement off the stack, push
+  None onto the stack, and then return that.  If we really don't care what
+  value is returned, then we can save two instructions by returning the
+  result of the last statement.  It looks funny but it helps.
+* We assume that type and bounds checking has happened at a higher level.
 """
 
-__author__ = 'robinson@google.com (Will Robinson)'
+__author__ = 'kenton@google.com (Kenton Varda)'
 
 import struct
-from google.protobuf import message
 from google.protobuf.internal import wire_format
-from google.protobuf.internal import output_stream
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by WireFormat from the C++ proto2
-# implementation.
-
-
-class Encoder(object):
-
-  """Encodes logical protocol buffer fields to the wire format."""
-
-  def __init__(self):
-    self._stream = output_stream.OutputStream()
-
-  def ToString(self):
-    """Returns all values encoded in this object as a string."""
-    return self._stream.ToString()
-
-  # Append*NoTag methods.  These are necessary for serializing packed
-  # repeated fields.  The Append*() methods call these methods to do
-  # the actual serialization.
-  def AppendInt32NoTag(self, value):
-    """Appends a 32-bit integer to our buffer, varint-encoded."""
-    self._stream.AppendVarint32(value)
-
-  def AppendInt64NoTag(self, value):
-    """Appends a 64-bit integer to our buffer, varint-encoded."""
-    self._stream.AppendVarint64(value)
-
-  def AppendUInt32NoTag(self, unsigned_value):
-    """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
-    self._stream.AppendVarUInt32(unsigned_value)
-
-  def AppendUInt64NoTag(self, unsigned_value):
-    """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
-    self._stream.AppendVarUInt64(unsigned_value)
-
-  def AppendSInt32NoTag(self, value):
-    """Appends a 32-bit integer to our buffer, zigzag-encoded and then
-    varint-encoded.
-    """
-    zigzag_value = wire_format.ZigZagEncode(value)
-    self._stream.AppendVarUInt32(zigzag_value)
-
-  def AppendSInt64NoTag(self, value):
-    """Appends a 64-bit integer to our buffer, zigzag-encoded and then
-    varint-encoded.
-    """
-    zigzag_value = wire_format.ZigZagEncode(value)
-    self._stream.AppendVarUInt64(zigzag_value)
-
-  def AppendFixed32NoTag(self, unsigned_value):
-    """Appends an unsigned 32-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self._stream.AppendLittleEndian32(unsigned_value)
-
-  def AppendFixed64NoTag(self, unsigned_value):
-    """Appends an unsigned 64-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self._stream.AppendLittleEndian64(unsigned_value)
-
-  def AppendSFixed32NoTag(self, value):
-    """Appends a signed 32-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    sign = (value & 0x80000000) and -1 or 0
-    if value >> 32 != sign:
-      raise message.EncodeError('SFixed32 out of range: %d' % value)
-    self._stream.AppendLittleEndian32(value & 0xffffffff)
-
-  def AppendSFixed64NoTag(self, value):
-    """Appends a signed 64-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    sign = (value & 0x8000000000000000) and -1 or 0
-    if value >> 64 != sign:
-      raise message.EncodeError('SFixed64 out of range: %d' % value)
-    self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
-
-  def AppendFloatNoTag(self, value):
-    """Appends a floating-point number to our buffer."""
-    self._stream.AppendRawBytes(
-        struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
-
-  def AppendDoubleNoTag(self, value):
-    """Appends a double-precision floating-point number to our buffer."""
-    self._stream.AppendRawBytes(
-        struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
-
-  def AppendBoolNoTag(self, value):
-    """Appends a boolean to our buffer."""
-    self.AppendInt32NoTag(value)
-
-  def AppendEnumNoTag(self, value):
-    """Appends an enum value to our buffer."""
-    self.AppendInt32NoTag(value)
-
-
-  # All the Append*() methods below first append a tag+type pair to the buffer
-  # before appending the specified value.
-
-  def AppendInt32(self, field_number, value):
-    """Appends a 32-bit integer to our buffer, varint-encoded."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendInt32NoTag(value)
-
-  def AppendInt64(self, field_number, value):
-    """Appends a 64-bit integer to our buffer, varint-encoded."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendInt64NoTag(value)
-
-  def AppendUInt32(self, field_number, unsigned_value):
-    """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendUInt32NoTag(unsigned_value)
-
-  def AppendUInt64(self, field_number, unsigned_value):
-    """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendUInt64NoTag(unsigned_value)
-
-  def AppendSInt32(self, field_number, value):
-    """Appends a 32-bit integer to our buffer, zigzag-encoded and then
-    varint-encoded.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendSInt32NoTag(value)
-
-  def AppendSInt64(self, field_number, value):
-    """Appends a 64-bit integer to our buffer, zigzag-encoded and then
-    varint-encoded.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
-    self.AppendSInt64NoTag(value)
-
-  def AppendFixed32(self, field_number, unsigned_value):
-    """Appends an unsigned 32-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
-    self.AppendFixed32NoTag(unsigned_value)
-
-  def AppendFixed64(self, field_number, unsigned_value):
-    """Appends an unsigned 64-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
-    self.AppendFixed64NoTag(unsigned_value)
-
-  def AppendSFixed32(self, field_number, value):
-    """Appends a signed 32-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
-    self.AppendSFixed32NoTag(value)
-
-  def AppendSFixed64(self, field_number, value):
-    """Appends a signed 64-bit integer to our buffer, in little-endian
-    byte-order.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
-    self.AppendSFixed64NoTag(value)
-
-  def AppendFloat(self, field_number, value):
-    """Appends a floating-point number to our buffer."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
-    self.AppendFloatNoTag(value)
-
-  def AppendDouble(self, field_number, value):
-    """Appends a double-precision floating-point number to our buffer."""
-    self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
-    self.AppendDoubleNoTag(value)
-
-  def AppendBool(self, field_number, value):
-    """Appends a boolean to our buffer."""
-    self.AppendInt32(field_number, value)
-
-  def AppendEnum(self, field_number, value):
-    """Appends an enum value to our buffer."""
-    self.AppendInt32(field_number, value)
-
-  def AppendString(self, field_number, value):
-    """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer,
-    with the length varint-encoded.
-    """
-    self.AppendBytes(field_number, value.encode('utf-8'))
-
-  def AppendBytes(self, field_number, value):
-    """Appends a length-prefixed sequence of bytes to our buffer, with the
-    length varint-encoded.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
-    self._stream.AppendVarUInt32(len(value))
-    self._stream.AppendRawBytes(value)
-
-  # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to
-  # avoid the extra string copy here.  We can do so if we widen the Message
-  # interface to be able to serialize to a stream in addition to a string.  The
-  # challenge when thinking ahead to the Python/C API implementation of Message
-  # is finding a stream-like Python thing to which we can write raw bytes
-  # from C.  I'm not sure such a thing exists(?).  (array.array is pretty much
-  # what we want, but it's not directly exposed in the Python/C API).
-
-  def AppendGroup(self, field_number, group):
-    """Appends a group to our buffer.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
-    self._stream.AppendRawBytes(group.SerializeToString())
-    self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
-
-  def AppendMessage(self, field_number, msg):
-    """Appends a nested message to our buffer.
-    """
-    self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
-    self._stream.AppendVarUInt32(msg.ByteSize())
-    self._stream.AppendRawBytes(msg.SerializeToString())
-
-  def AppendMessageSetItem(self, field_number, msg):
-    """Appends an item using the message set wire format.
-
-    The message set message looks like this:
-      message MessageSet {
-        repeated group Item = 1 {
-          required int32 type_id = 2;
-          required string message = 3;
-        }
+
+
+def _VarintSize(value):
+  """Compute the size of a varint value."""
+  if value <= 0x7f: return 1
+  if value <= 0x3fff: return 2
+  if value <= 0x1fffff: return 3
+  if value <= 0xfffffff: return 4
+  if value <= 0x7ffffffff: return 5
+  if value <= 0x3ffffffffff: return 6
+  if value <= 0x1ffffffffffff: return 7
+  if value <= 0xffffffffffffff: return 8
+  if value <= 0x7fffffffffffffff: return 9
+  return 10
+
+
+def _SignedVarintSize(value):
+  """Compute the size of a signed varint value."""
+  if value < 0: return 10
+  if value <= 0x7f: return 1
+  if value <= 0x3fff: return 2
+  if value <= 0x1fffff: return 3
+  if value <= 0xfffffff: return 4
+  if value <= 0x7ffffffff: return 5
+  if value <= 0x3ffffffffff: return 6
+  if value <= 0x1ffffffffffff: return 7
+  if value <= 0xffffffffffffff: return 8
+  if value <= 0x7fffffffffffffff: return 9
+  return 10
+
+
+def _TagSize(field_number):
+  """Returns the number of bytes required to serialize a tag with this field
+  number."""
+  # Just pass in type 0, since the type won't affect the tag+type size.
+  return _VarintSize(wire_format.PackTag(field_number, 0))
+
+
+# --------------------------------------------------------------------
+# In this section we define some generic sizers.  Each of these functions
+# takes parameters specific to a particular field type, e.g. int32 or fixed64.
+# It returns another function which in turn takes parameters specific to a
+# particular field, e.g. the field number and whether it is repeated or packed.
+# Look at the next section to see how these are used.
+
+
+def _SimpleSizer(compute_value_size):
+  """A sizer which uses the function compute_value_size to compute the size of
+  each value.  Typically compute_value_size is _VarintSize."""
+
+  def SpecificSizer(field_number, is_repeated, is_packed):
+    tag_size = _TagSize(field_number)
+    if is_packed:
+      local_VarintSize = _VarintSize
+      def PackedFieldSize(value):
+        result = 0
+        for element in value:
+          result += compute_value_size(element)
+        return result + local_VarintSize(result) + tag_size
+      return PackedFieldSize
+    elif is_repeated:
+      def RepeatedFieldSize(value):
+        result = tag_size * len(value)
+        for element in value:
+          result += compute_value_size(element)
+        return result
+      return RepeatedFieldSize
+    else:
+      def FieldSize(value):
+        return tag_size + compute_value_size(value)
+      return FieldSize
+
+  return SpecificSizer
+
+
+def _ModifiedSizer(compute_value_size, modify_value):
+  """Like _SimpleSizer, but modify_value is invoked on each value before it is
+  passed to compute_value_size.  modify_value is typically ZigZagEncode."""
+
+  def SpecificSizer(field_number, is_repeated, is_packed):
+    tag_size = _TagSize(field_number)
+    if is_packed:
+      local_VarintSize = _VarintSize
+      def PackedFieldSize(value):
+        result = 0
+        for element in value:
+          result += compute_value_size(modify_value(element))
+        return result + local_VarintSize(result) + tag_size
+      return PackedFieldSize
+    elif is_repeated:
+      def RepeatedFieldSize(value):
+        result = tag_size * len(value)
+        for element in value:
+          result += compute_value_size(modify_value(element))
+        return result
+      return RepeatedFieldSize
+    else:
+      def FieldSize(value):
+        return tag_size + compute_value_size(modify_value(value))
+      return FieldSize
+
+  return SpecificSizer
+
+
+def _FixedSizer(value_size):
+  """Like _SimpleSizer except for a fixed-size field.  The input is the size
+  of one value."""
+
+  def SpecificSizer(field_number, is_repeated, is_packed):
+    tag_size = _TagSize(field_number)
+    if is_packed:
+      local_VarintSize = _VarintSize
+      def PackedFieldSize(value):
+        result = len(value) * value_size
+        return result + local_VarintSize(result) + tag_size
+      return PackedFieldSize
+    elif is_repeated:
+      element_size = value_size + tag_size
+      def RepeatedFieldSize(value):
+        return len(value) * element_size
+      return RepeatedFieldSize
+    else:
+      field_size = value_size + tag_size
+      def FieldSize(value):
+        return field_size
+      return FieldSize
+
+  return SpecificSizer
+
+
+# ====================================================================
+# Here we declare a sizer constructor for each field type.  Each "sizer
+# constructor" is a function that takes (field_number, is_repeated, is_packed)
+# as parameters and returns a sizer, which in turn takes a field value as
+# a parameter and returns its encoded size.
+
+
+Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
+
+UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
+
+SInt32Sizer = SInt64Sizer = _ModifiedSizer(
+    _SignedVarintSize, wire_format.ZigZagEncode)
+
+Fixed32Sizer = SFixed32Sizer = FloatSizer  = _FixedSizer(4)
+Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
+
+BoolSizer = _FixedSizer(1)
+
+
+def StringSizer(field_number, is_repeated, is_packed):
+  """Returns a sizer for a string field."""
+
+  tag_size = _TagSize(field_number)
+  local_VarintSize = _VarintSize
+  local_len = len
+  assert not is_packed
+  if is_repeated:
+    def RepeatedFieldSize(value):
+      result = tag_size * len(value)
+      for element in value:
+        l = local_len(element.encode('utf-8'))
+        result += local_VarintSize(l) + l
+      return result
+    return RepeatedFieldSize
+  else:
+    def FieldSize(value):
+      l = local_len(value.encode('utf-8'))
+      return tag_size + local_VarintSize(l) + l
+    return FieldSize
+
+
+def BytesSizer(field_number, is_repeated, is_packed):
+  """Returns a sizer for a bytes field."""
+
+  tag_size = _TagSize(field_number)
+  local_VarintSize = _VarintSize
+  local_len = len
+  assert not is_packed
+  if is_repeated:
+    def RepeatedFieldSize(value):
+      result = tag_size * len(value)
+      for element in value:
+        l = local_len(element)
+        result += local_VarintSize(l) + l
+      return result
+    return RepeatedFieldSize
+  else:
+    def FieldSize(value):
+      l = local_len(value)
+      return tag_size + local_VarintSize(l) + l
+    return FieldSize
+
+
+def GroupSizer(field_number, is_repeated, is_packed):
+  """Returns a sizer for a group field."""
+
+  tag_size = _TagSize(field_number) * 2
+  assert not is_packed
+  if is_repeated:
+    def RepeatedFieldSize(value):
+      result = tag_size * len(value)
+      for element in value:
+        result += element.ByteSize()
+      return result
+    return RepeatedFieldSize
+  else:
+    def FieldSize(value):
+      return tag_size + value.ByteSize()
+    return FieldSize
+
+
+def MessageSizer(field_number, is_repeated, is_packed):
+  """Returns a sizer for a message field."""
+
+  tag_size = _TagSize(field_number)
+  local_VarintSize = _VarintSize
+  assert not is_packed
+  if is_repeated:
+    def RepeatedFieldSize(value):
+      result = tag_size * len(value)
+      for element in value:
+        l = element.ByteSize()
+        result += local_VarintSize(l) + l
+      return result
+    return RepeatedFieldSize
+  else:
+    def FieldSize(value):
+      l = value.ByteSize()
+      return tag_size + local_VarintSize(l) + l
+    return FieldSize
+
+
+# --------------------------------------------------------------------
+# MessageSet is special.
+
+
+def MessageSetItemSizer(field_number):
+  """Returns a sizer for extensions of MessageSet.
+
+  The message set message looks like this:
+    message MessageSet {
+      repeated group Item = 1 {
+        required int32 type_id = 2;
+        required string message = 3;
+      }
+    }
+  """
+  static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
+                 _TagSize(3))
+  local_VarintSize = _VarintSize
+
+  def FieldSize(value):
+    l = value.ByteSize()
+    return static_size + local_VarintSize(l) + l
+
+  return FieldSize
+
+
+# ====================================================================
+# Encoders!
+
+
+def _VarintEncoder():
+  """Return an encoder for a basic varint value (does not include tag)."""
+
+  local_chr = chr
+  def EncodeVarint(write, value):
+    bits = value & 0x7f
+    value >>= 7
+    while value:
+      write(local_chr(0x80|bits))
+      bits = value & 0x7f
+      value >>= 7
+    return write(local_chr(bits))
+
+  return EncodeVarint
+
+
+def _SignedVarintEncoder():
+  """Return an encoder for a basic signed varint value (does not include
+  tag)."""
+
+  local_chr = chr
+  def EncodeSignedVarint(write, value):
+    if value < 0:
+      value += (1 << 64)
+    bits = value & 0x7f
+    value >>= 7
+    while value:
+      write(local_chr(0x80|bits))
+      bits = value & 0x7f
+      value >>= 7
+    return write(local_chr(bits))
+
+  return EncodeSignedVarint
+
+
+_EncodeVarint = _VarintEncoder()
+_EncodeSignedVarint = _SignedVarintEncoder()
+
+
+def _VarintBytes(value):
+  """Encode the given integer as a varint and return the bytes.  This is only
+  called at startup time so it doesn't need to be fast."""
+
+  pieces = []
+  _EncodeVarint(pieces.append, value)
+  return "".join(pieces)
+
+
+def TagBytes(field_number, wire_type):
+  """Encode the given tag and return the bytes.  Only called at startup."""
+
+  return _VarintBytes(wire_format.PackTag(field_number, wire_type))
+
+# --------------------------------------------------------------------
+# As with sizers (see above), we have a number of common encoder
+# implementations.
+
+
+def _SimpleEncoder(wire_type, encode_value, compute_value_size):
+  """Return a constructor for an encoder for fields of a particular type.
+
+  Args:
+      wire_type:  The field's wire type, for encoding tags.
+      encode_value:  A function which encodes an individual value, e.g.
+        _EncodeVarint().
+      compute_value_size:  A function which computes the size of an individual
+        value, e.g. _VarintSize().
+  """
+
+  def SpecificEncoder(field_number, is_repeated, is_packed):
+    if is_packed:
+      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+      local_EncodeVarint = _EncodeVarint
+      def EncodePackedField(write, value):
+        write(tag_bytes)
+        size = 0
+        for element in value:
+          size += compute_value_size(element)
+        local_EncodeVarint(write, size)
+        for element in value:
+          encode_value(write, element)
+      return EncodePackedField
+    elif is_repeated:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeRepeatedField(write, value):
+        for element in value:
+          write(tag_bytes)
+          encode_value(write, element)
+      return EncodeRepeatedField
+    else:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeField(write, value):
+        write(tag_bytes)
+        return encode_value(write, value)
+      return EncodeField
+
+  return SpecificEncoder
+
+
+def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
+  """Like _SimpleEncoder but additionally invokes modify_value on every value
+  before passing it to encode_value.  Usually modify_value is ZigZagEncode."""
+
+  def SpecificEncoder(field_number, is_repeated, is_packed):
+    if is_packed:
+      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+      local_EncodeVarint = _EncodeVarint
+      def EncodePackedField(write, value):
+        write(tag_bytes)
+        size = 0
+        for element in value:
+          size += compute_value_size(modify_value(element))
+        local_EncodeVarint(write, size)
+        for element in value:
+          encode_value(write, modify_value(element))
+      return EncodePackedField
+    elif is_repeated:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeRepeatedField(write, value):
+        for element in value:
+          write(tag_bytes)
+          encode_value(write, modify_value(element))
+      return EncodeRepeatedField
+    else:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeField(write, value):
+        write(tag_bytes)
+        return encode_value(write, modify_value(value))
+      return EncodeField
+
+  return SpecificEncoder
+
+
+def _StructPackEncoder(wire_type, format):
+  """Return a constructor for an encoder for a fixed-width field.
+
+  Args:
+      wire_type:  The field's wire type, for encoding tags.
+      format:  The format string to pass to struct.pack().
+  """
+
+  value_size = struct.calcsize(format)
+
+  def SpecificEncoder(field_number, is_repeated, is_packed):
+    local_struct_pack = struct.pack
+    if is_packed:
+      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+      local_EncodeVarint = _EncodeVarint
+      def EncodePackedField(write, value):
+        write(tag_bytes)
+        local_EncodeVarint(write, len(value) * value_size)
+        for element in value:
+          write(local_struct_pack(format, element))
+      return EncodePackedField
+    elif is_repeated:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeRepeatedField(write, value):
+        for element in value:
+          write(tag_bytes)
+          write(local_struct_pack(format, element))
+      return EncodeRepeatedField
+    else:
+      tag_bytes = TagBytes(field_number, wire_type)
+      def EncodeField(write, value):
+        write(tag_bytes)
+        return write(local_struct_pack(format, value))
+      return EncodeField
+
+  return SpecificEncoder
+
+
+# ====================================================================
+# Here we declare an encoder constructor for each field type.  These work
+# very similarly to sizer constructors, described earlier.
+
+
+Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
+    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
+
+UInt32Encoder = UInt64Encoder = _SimpleEncoder(
+    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
+
+SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
+    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
+    wire_format.ZigZagEncode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Encoder  = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Encoder  = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatEncoder    = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleEncoder   = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
+
+
+def BoolEncoder(field_number, is_repeated, is_packed):
+  """Returns an encoder for a boolean field."""
+
+  false_byte = chr(0)
+  true_byte = chr(1)
+  if is_packed:
+    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+    local_EncodeVarint = _EncodeVarint
+    def EncodePackedField(write, value):
+      write(tag_bytes)
+      local_EncodeVarint(write, len(value))
+      for element in value:
+        if element:
+          write(true_byte)
+        else:
+          write(false_byte)
+    return EncodePackedField
+  elif is_repeated:
+    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+    def EncodeRepeatedField(write, value):
+      for element in value:
+        write(tag_bytes)
+        if element:
+          write(true_byte)
+        else:
+          write(false_byte)
+    return EncodeRepeatedField
+  else:
+    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+    def EncodeField(write, value):
+      write(tag_bytes)
+      if value:
+        return write(true_byte)
+      return write(false_byte)
+    return EncodeField
+
+
+def StringEncoder(field_number, is_repeated, is_packed):
+  """Returns an encoder for a string field."""
+
+  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  local_EncodeVarint = _EncodeVarint
+  local_len = len
+  assert not is_packed
+  if is_repeated:
+    def EncodeRepeatedField(write, value):
+      for element in value:
+        encoded = element.encode('utf-8')
+        write(tag)
+        local_EncodeVarint(write, local_len(encoded))
+        write(encoded)
+    return EncodeRepeatedField
+  else:
+    def EncodeField(write, value):
+      encoded = value.encode('utf-8')
+      write(tag)
+      local_EncodeVarint(write, local_len(encoded))
+      return write(encoded)
+    return EncodeField
+
+
+def BytesEncoder(field_number, is_repeated, is_packed):
+  """Returns an encoder for a bytes field."""
+
+  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  local_EncodeVarint = _EncodeVarint
+  local_len = len
+  assert not is_packed
+  if is_repeated:
+    def EncodeRepeatedField(write, value):
+      for element in value:
+        write(tag)
+        local_EncodeVarint(write, local_len(element))
+        write(element)
+    return EncodeRepeatedField
+  else:
+    def EncodeField(write, value):
+      write(tag)
+      local_EncodeVarint(write, local_len(value))
+      return write(value)
+    return EncodeField
+
+
+def GroupEncoder(field_number, is_repeated, is_packed):
+  """Returns an encoder for a group field."""
+
+  start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
+  end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
+  assert not is_packed
+  if is_repeated:
+    def EncodeRepeatedField(write, value):
+      for element in value:
+        write(start_tag)
+        element._InternalSerialize(write)
+        write(end_tag)
+    return EncodeRepeatedField
+  else:
+    def EncodeField(write, value):
+      write(start_tag)
+      value._InternalSerialize(write)
+      return write(end_tag)
+    return EncodeField
+
+
+def MessageEncoder(field_number, is_repeated, is_packed):
+  """Returns an encoder for a message field."""
+
+  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  local_EncodeVarint = _EncodeVarint
+  assert not is_packed
+  if is_repeated:
+    def EncodeRepeatedField(write, value):
+      for element in value:
+        write(tag)
+        local_EncodeVarint(write, element.ByteSize())
+        element._InternalSerialize(write)
+    return EncodeRepeatedField
+  else:
+    def EncodeField(write, value):
+      write(tag)
+      local_EncodeVarint(write, value.ByteSize())
+      return value._InternalSerialize(write)
+    return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, MessageSet is special.
+
+
+def MessageSetItemEncoder(field_number):
+  """Encoder for extensions of MessageSet.
+
+  The message set message looks like this:
+    message MessageSet {
+      repeated group Item = 1 {
+        required int32 type_id = 2;
+        required string message = 3;
       }
-    """
-    self.AppendTag(1, wire_format.WIRETYPE_START_GROUP)
-    self.AppendInt32(2, field_number)
-    self.AppendMessage(3, msg)
-    self.AppendTag(1, wire_format.WIRETYPE_END_GROUP)
-
-  def AppendTag(self, field_number, wire_type):
-    """Appends a tag containing field number and wire type information."""
-    self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type))
+    }
+  """
+  start_bytes = "".join([
+      TagBytes(1, wire_format.WIRETYPE_START_GROUP),
+      TagBytes(2, wire_format.WIRETYPE_VARINT),
+      _VarintBytes(field_number),
+      TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
+  end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+  local_EncodeVarint = _EncodeVarint
+
+  def EncodeField(write, value):
+    write(start_bytes)
+    local_EncodeVarint(write, value.ByteSize())
+    value._InternalSerialize(write)
+    return write(end_bytes)
+
+  return EncodeField

+ 0 - 286
python/google/protobuf/internal/encoder_test.py

@@ -1,286 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.encoder."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import struct
-import logging
-import unittest
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import encoder
-from google.protobuf.internal import output_stream
-from google.protobuf import message
-import mox
-
-
-class EncoderTest(unittest.TestCase):
-
-  def setUp(self):
-    self.mox = mox.Mox()
-    self.encoder = encoder.Encoder()
-    self.mock_stream = self.mox.CreateMock(output_stream.OutputStream)
-    self.mock_message = self.mox.CreateMock(message.Message)
-    self.encoder._stream = self.mock_stream
-
-  def PackTag(self, field_number, wire_type):
-    return wire_format.PackTag(field_number, wire_type)
-
-  def AppendScalarTestHelper(self, test_name, encoder_method,
-                             expected_stream_method_name,
-                             wire_type, field_value,
-                             expected_value=None, expected_length=None,
-                             is_tag_test=True):
-    """Helper for testAppendScalars.
-
-    Calls one of the Encoder methods, and ensures that the Encoder
-    in turn makes the expected calls into its OutputStream.
-
-    Args:
-      test_name: Name of this test, used only for logging.
-      encoder_method: Callable on self.encoder. This is the Encoder
-        method we're testing.  If is_tag_test=True, the encoder method
-        accepts a field_number and field_value. if is_tag_test=False,
-        the encoder method accepts a field_value.
-      expected_stream_method_name: (string) Name of the OutputStream
-        method we expect Encoder to call to actually put the value
-        on the wire.
-      wire_type: The WIRETYPE_* constant we expect encoder to
-        use in the specified encoder_method.
-      field_value: The value we're trying to encode.  Passed
-        into encoder_method.
-      expected_value: The value we expect Encoder to pass into
-        the OutputStream method.  If None, we expect field_value
-        to pass through unmodified.
-      expected_length: The length we expect Encoder to pass to the
-        AppendVarUInt32 method. If None we expect the length of the
-        field_value.
-      is_tag_test: A Boolean.  If True (the default), we append the
-        the packed field number and wire_type to the stream before
-        the field value.
-    """
-    if expected_value is None:
-      expected_value = field_value
-
-    logging.info('Testing %s scalar output.\n'
-                 'Calling %r(%r), and expecting that to call the '
-                 'stream method %s(%r).' % (
-        test_name, encoder_method, field_value,
-        expected_stream_method_name, expected_value))
-
-    if is_tag_test:
-      field_number = 10
-      # Should first append the field number and type information.
-      self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
-      # If we're length-delimited, we should then append the length.
-      if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
-        if expected_length is None:
-          expected_length = len(field_value)
-        self.mock_stream.AppendVarUInt32(expected_length)
-
-    # Should then append the value itself.
-    # We have to use names instead of methods to work around some
-    # mox weirdness.  (ResetAll() is overzealous).
-    expected_stream_method = getattr(self.mock_stream,
-                                     expected_stream_method_name)
-    expected_stream_method(expected_value)
-
-    self.mox.ReplayAll()
-    if is_tag_test:
-      encoder_method(field_number, field_value)
-    else:
-      encoder_method(field_value)
-    self.mox.VerifyAll()
-    self.mox.ResetAll()
-
-  VAL = 1.125  # Perfectly representable as a float (no rounding error).
-  LITTLE_FLOAT_VAL = '\x00\x00\x90?'
-  LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
-
-  def testAppendScalars(self):
-    utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'
-    utf8_string = unicode(utf8_bytes, 'utf-8')
-    scalar_tests = [
-        ['int32', self.encoder.AppendInt32, 'AppendVarint32',
-         wire_format.WIRETYPE_VARINT, 0],
-        ['int64', self.encoder.AppendInt64, 'AppendVarint64',
-         wire_format.WIRETYPE_VARINT, 0],
-        ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32',
-         wire_format.WIRETYPE_VARINT, 0],
-        ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64',
-         wire_format.WIRETYPE_VARINT, 0],
-        ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32',
-         wire_format.WIRETYPE_FIXED32, 0],
-        ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64',
-         wire_format.WIRETYPE_FIXED64, 0],
-        ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32',
-         wire_format.WIRETYPE_FIXED32, -1, 0xffffffff],
-        ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64',
-         wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff],
-        ['float', self.encoder.AppendFloat, 'AppendRawBytes',
-         wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL],
-        ['double', self.encoder.AppendDouble, 'AppendRawBytes',
-         wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL],
-        ['bool', self.encoder.AppendBool, 'AppendVarint32',
-         wire_format.WIRETYPE_VARINT, False],
-        ['enum', self.encoder.AppendEnum, 'AppendVarint32',
-         wire_format.WIRETYPE_VARINT, 0],
-        ['string', self.encoder.AppendString, 'AppendRawBytes',
-         wire_format.WIRETYPE_LENGTH_DELIMITED,
-         "You're in a maze of twisty little passages, all alike."],
-        ['utf8-string', self.encoder.AppendString, 'AppendRawBytes',
-         wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string,
-         utf8_bytes, len(utf8_bytes)],
-        # We test zigzag encoding routines more extensively below.
-        ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32',
-         wire_format.WIRETYPE_VARINT, -1, 1],
-        ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64',
-         wire_format.WIRETYPE_VARINT, -1, 1],
-        ]
-    # Ensure that we're testing different Encoder methods and using
-    # different test names in all test cases above.
-    self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
-    self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
-    for args in scalar_tests:
-      self.AppendScalarTestHelper(*args)
-
-  def testAppendScalarsWithoutTags(self):
-    scalar_no_tag_tests = [
-        ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0],
-        ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0],
-        ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0],
-        ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0],
-        ['fixed32', self.encoder.AppendFixed32NoTag,
-         'AppendLittleEndian32', None, 0],
-        ['fixed64', self.encoder.AppendFixed64NoTag,
-         'AppendLittleEndian64', None, 0],
-        ['sfixed32', self.encoder.AppendSFixed32NoTag,
-         'AppendLittleEndian32', None, 0],
-        ['sfixed64', self.encoder.AppendSFixed64NoTag,
-         'AppendLittleEndian64', None, 0],
-        ['float', self.encoder.AppendFloatNoTag,
-         'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL],
-        ['double', self.encoder.AppendDoubleNoTag,
-         'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL],
-        ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
-        ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
-        ['sint32', self.encoder.AppendSInt32NoTag,
-         'AppendVarUInt32', None, -1, 1],
-        ['sint64', self.encoder.AppendSInt64NoTag,
-         'AppendVarUInt64', None, -1, 1],
-    ]
-
-    self.assertEqual(len(scalar_no_tag_tests),
-                     len(set(t[0] for t in scalar_no_tag_tests)))
-    self.assert_(len(scalar_no_tag_tests) >=
-                 len(set(t[1] for t in scalar_no_tag_tests)))
-    for args in scalar_no_tag_tests:
-      # For no tag tests, the wire_type is not used, so we put in None.
-      self.AppendScalarTestHelper(is_tag_test=False, *args)
-
-  def testAppendGroup(self):
-    field_number = 23
-    # Should first append the start-group marker.
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP))
-    # Should then serialize itself.
-    self.mock_message.SerializeToString().AndReturn('foo')
-    self.mock_stream.AppendRawBytes('foo')
-    # Should finally append the end-group marker.
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP))
-
-    self.mox.ReplayAll()
-    self.encoder.AppendGroup(field_number, self.mock_message)
-    self.mox.VerifyAll()
-
-  def testAppendMessage(self):
-    field_number = 23
-    byte_size = 42
-    # Should first append the field number and type information.
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
-    # Should then append its length.
-    self.mock_message.ByteSize().AndReturn(byte_size)
-    self.mock_stream.AppendVarUInt32(byte_size)
-    # Should then serialize itself to the encoder.
-    self.mock_message.SerializeToString().AndReturn('foo')
-    self.mock_stream.AppendRawBytes('foo')
-
-    self.mox.ReplayAll()
-    self.encoder.AppendMessage(field_number, self.mock_message)
-    self.mox.VerifyAll()
-
-  def testAppendMessageSetItem(self):
-    field_number = 23
-    byte_size = 42
-    # Should first append the field number and type information.
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(1, wire_format.WIRETYPE_START_GROUP))
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(2, wire_format.WIRETYPE_VARINT))
-    self.mock_stream.AppendVarint32(field_number)
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED))
-    # Should then append its length.
-    self.mock_message.ByteSize().AndReturn(byte_size)
-    self.mock_stream.AppendVarUInt32(byte_size)
-    # Should then serialize itself to the encoder.
-    self.mock_message.SerializeToString().AndReturn('foo')
-    self.mock_stream.AppendRawBytes('foo')
-    self.mock_stream.AppendVarUInt32(
-        self.PackTag(1, wire_format.WIRETYPE_END_GROUP))
-
-    self.mox.ReplayAll()
-    self.encoder.AppendMessageSetItem(field_number, self.mock_message)
-    self.mox.VerifyAll()
-
-  def testAppendSFixed(self):
-    # Most of our bounds-checking is done in output_stream.py,
-    # but encoder.py is responsible for transforming signed
-    # fixed-width integers into unsigned ones, so we test here
-    # to ensure that we're not losing any entropy when we do
-    # that conversion.
-    field_number = 10
-    self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
-        10, wire_format.UINT32_MAX + 1)
-    self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
-        10, -(1 << 32))
-    self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
-        10, wire_format.UINT64_MAX + 1)
-    self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
-        10, -(1 << 64))
-
-
-if __name__ == '__main__':
-  unittest.main()

+ 106 - 1
python/google/protobuf/internal/generator_test.py

@@ -35,15 +35,20 @@
 # indirect testing of the protocol compiler output.
 
 """Unittest that directly tests the output of the pure-Python protocol
-compiler.  See //net/proto2/internal/reflection_test.py for a test which
+compiler.  See //google/protobuf/reflection_test.py for a test which
 further ensures that we can use Python protocol message objects as we expect.
 """
 
 __author__ = 'robinson@google.com (Will Robinson)'
 
 import unittest
+from google.protobuf import unittest_import_pb2
 from google.protobuf import unittest_mset_pb2
 from google.protobuf import unittest_pb2
+from google.protobuf import unittest_no_generic_services_pb2
+
+
+MAX_EXTENSION = 536870912
 
 
 class GeneratorTest(unittest.TestCase):
@@ -71,6 +76,31 @@ class GeneratorTest(unittest.TestCase):
     self.assertEqual(3, proto.BAZ)
     self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
 
+  def testExtremeDefaultValues(self):
+    # Reads the infinity / NaN fields of a freshly constructed
+    # TestExtremeDefaultValues message, for both double and float variants.
+    message = unittest_pb2.TestExtremeDefaultValues()
+    self.assertEqual(float('inf'), message.inf_double)
+    self.assertEqual(float('-inf'), message.neg_inf_double)
+    # NaN never compares equal to itself, so x != x identifies it.
+    self.assertTrue(message.nan_double != message.nan_double)
+    self.assertEqual(float('inf'), message.inf_float)
+    self.assertEqual(float('-inf'), message.neg_inf_float)
+    self.assertTrue(message.nan_float != message.nan_float)
+
+  def testHasDefaultValues(self):
+    # Verifies FieldDescriptor.has_default_value for a sample of
+    # TestAllTypes fields.
+    desc = unittest_pb2.TestAllTypes.DESCRIPTOR
+
+    # Field name -> expected has_default_value flag.
+    expected_has_default_by_name = {
+        'optional_int32': False,
+        'repeated_int32': False,
+        'optional_nested_message': False,
+        'default_int32': True,
+    }
+
+    # Collect the actual flag for just the fields named above, so the two
+    # dicts can be compared directly.
+    has_default_by_name = dict(
+        [(f.name, f.has_default_value)
+         for f in desc.fields
+         if f.name in expected_has_default_by_name])
+    self.assertEqual(expected_has_default_by_name, has_default_by_name)
+
   def testContainingTypeBehaviorForExtensions(self):
     self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
                      unittest_pb2.TestAllExtensions.DESCRIPTOR)
@@ -95,6 +125,81 @@ class GeneratorTest(unittest.TestCase):
     proto = unittest_mset_pb2.TestMessageSet()
     self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
 
+  def testNestedTypes(self):
+    # Compared as sets, so the order of nested_types does not affect the
+    # first check.
+    self.assertEqual(
+        set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
+        set([
+            unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
+            unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR,
+            unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR,
+        ]))
+    # Messages without nested definitions expose an empty list.
+    self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, [])
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, [])
+
+  def testContainingType(self):
+    # Top-level messages have no containing type.
+    self.assertTrue(
+        unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None)
+    self.assertTrue(
+        unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None)
+    # Each type nested in TestAllTypes (the nested message and both groups)
+    # reports TestAllTypes as its containing type.
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
+        unittest_pb2.TestAllTypes.DESCRIPTOR)
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR.containing_type,
+        unittest_pb2.TestAllTypes.DESCRIPTOR)
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type,
+        unittest_pb2.TestAllTypes.DESCRIPTOR)
+
+  def testContainingTypeInEnumDescriptor(self):
+    # The file-level enum descriptor has no containing type; the enum
+    # descriptor nested in TestAllTypes points back at TestAllTypes.
+    self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None)
+    self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type,
+                     unittest_pb2.TestAllTypes.DESCRIPTOR)
+
+  def testPackage(self):
+    # Message descriptors (top-level, nested, and imported) expose the
+    # package name through their .file attribute.
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.DESCRIPTOR.file.package,
+        'protobuf_unittest')
+    desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR
+    self.assertEqual(desc.file.package, 'protobuf_unittest')
+    self.assertEqual(
+        unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package,
+        'protobuf_unittest_import')
+
+    # The same holds for enum descriptors.
+    self.assertEqual(
+        unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest')
+    self.assertEqual(
+        unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package,
+        'protobuf_unittest')
+    self.assertEqual(
+        unittest_import_pb2._IMPORTENUM.file.package,
+        'protobuf_unittest_import')
+
+  def testExtensionRange(self):
+    # extension_ranges is compared against a list of (start, end) tuples.
+    # MAX_EXTENSION is this module's constant 536870912 (1 << 29);
+    # presumably the end of each range is exclusive -- TODO confirm.
+    self.assertEqual(
+        unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, [])
+    self.assertEqual(
+        unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
+        [(1, MAX_EXTENSION)])
+    self.assertEqual(
+        unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
+        [(42, 43), (4143, 4244), (65536, MAX_EXTENSION)])
+
+  def testFileDescriptor(self):
+    # The generated module's DESCRIPTOR exposes the .proto path, the package
+    # name, and a non-None serialized_pb attribute.
+    self.assertEqual(unittest_pb2.DESCRIPTOR.name,
+                     'google/protobuf/unittest.proto')
+    self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
+    self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
+
+  def testNoGenericServices(self):
+    # unittest_no_generic_services.proto should contain defs for everything
+    # except services.
+    # The check is by attribute presence on the generated module: three names
+    # must exist and TestService must not.
+    self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
+    self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
+    self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
+    self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService"))
+
 
 if __name__ == '__main__':
   unittest.main()

+ 0 - 338
python/google/protobuf/internal/input_stream.py

@@ -1,338 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""InputStream is the primitive interface for reading bits from the wire.
-
-All protocol buffer deserialization can be expressed in terms of
-the InputStream primitives provided here.
-"""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import array
-import struct
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by CodedInputStream from the C++
-# proto2 implementation.
-
-
-class InputStreamBuffer(object):
-
-  """Contains all logic for reading bits, and dealing with stream position.
-
-  If an InputStream method ever raises an exception, the stream is left
-  in an indeterminate state and is not safe for further use.
-  """
-
-  def __init__(self, s):
-    # What we really want is something like array('B', s), where elements we
-    # read from the array are already given to us as one-byte integers.  BUT
-    # using array() instead of buffer() would force full string copies to result
-    # from each GetSubBuffer() call.
-    #
-    # So, if the N serialized bytes of a single protocol buffer object are
-    # split evenly between 2 child messages, and so on recursively, using
-    # array('B', s) instead of buffer() would incur an additional N*logN bytes
-    # copied during deserialization.
-    #
-    # The higher constant overhead of having to ord() for every byte we read
-    # from the buffer in _ReadVarintHelper() could definitely lead to worse
-    # performance in many real-world scenarios, even if the asymptotic
-    # complexity is better.  However, our real answer is that the mythical
-    # Python/C extension module output mode for the protocol compiler will
-    # be blazing-fast and will eliminate most use of this class anyway.
-    self._buffer = buffer(s)
-    self._pos = 0
-
-  def EndOfStream(self):
-    """Returns true iff we're at the end of the stream.
-    If this returns true, then a call to any other InputStream method
-    will raise an exception.
-    """
-    return self._pos >= len(self._buffer)
-
-  def Position(self):
-    """Returns the current position in the stream, or equivalently, the
-    number of bytes read so far.
-    """
-    return self._pos
-
-  def GetSubBuffer(self, size=None):
-    """Returns a sequence-like object that represents a portion of our
-    underlying sequence.
-
-    Position 0 in the returned object corresponds to self.Position()
-    in this stream.
-
-    If size is specified, then the returned object ends after the
-    next "size" bytes in this stream.  If size is not specified,
-    then the returned object ends at the end of this stream.
-
-    We guarantee that the returned object R supports the Python buffer
-    interface (and thus that the call buffer(R) will work).
-
-    Note that the returned buffer is read-only.
-
-    The intended use for this method is for nested-message and nested-group
-    deserialization, where we want to make a recursive MergeFromString()
-    call on the portion of the original sequence that contains the serialized
-    nested message.  (And we'd like to do so without making unnecessary string
-    copies).
-
-    REQUIRES: size is nonnegative.
-    """
-    # Note that buffer() doesn't perform any actual string copy.
-    if size is None:
-      return buffer(self._buffer, self._pos)
-    else:
-      if size < 0:
-        raise message.DecodeError('Negative size %d' % size)
-      return buffer(self._buffer, self._pos, size)
-
-  def SkipBytes(self, num_bytes):
-    """Skip num_bytes bytes ahead, or go to the end of the stream, whichever
-    comes first.
-
-    REQUIRES: num_bytes is nonnegative.
-    """
-    if num_bytes < 0:
-      raise message.DecodeError('Negative num_bytes %d' % num_bytes)
-    self._pos += num_bytes
-    self._pos = min(self._pos, len(self._buffer))
-
-  def ReadBytes(self, size):
-    """Reads up to 'size' bytes from the stream, stopping early
-    only if we reach the end of the stream.  Returns the bytes read
-    as a string.
-    """
-    if size < 0:
-      raise message.DecodeError('Negative size %d' % size)
-    s = (self._buffer[self._pos : self._pos + size])
-    self._pos += len(s)  # Only advance by the number of bytes actually read.
-    return s
-
-  def ReadLittleEndian32(self):
-    """Interprets the next 4 bytes of the stream as a little-endian
-    encoded, unsiged 32-bit integer, and returns that integer.
-    """
-    try:
-      i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
-                        self._buffer[self._pos : self._pos + 4])
-      self._pos += 4
-      return i[0]  # unpack() result is a 1-element tuple.
-    except struct.error, e:
-      raise message.DecodeError(e)
-
-  def ReadLittleEndian64(self):
-    """Interprets the next 8 bytes of the stream as a little-endian
-    encoded, unsiged 64-bit integer, and returns that integer.
-    """
-    try:
-      i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
-                        self._buffer[self._pos : self._pos + 8])
-      self._pos += 8
-      return i[0]  # unpack() result is a 1-element tuple.
-    except struct.error, e:
-      raise message.DecodeError(e)
-
-  def ReadVarint32(self):
-    """Reads a varint from the stream, interprets this varint
-    as a signed, 32-bit integer, and returns the integer.
-    """
-    i = self.ReadVarint64()
-    if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
-      raise message.DecodeError('Value out of range for int32: %d' % i)
-    return int(i)
-
-  def ReadVarUInt32(self):
-    """Reads a varint from the stream, interprets this varint
-    as an unsigned, 32-bit integer, and returns the integer.
-    """
-    i = self.ReadVarUInt64()
-    if i > wire_format.UINT32_MAX:
-      raise message.DecodeError('Value out of range for uint32: %d' % i)
-    return i
-
-  def ReadVarint64(self):
-    """Reads a varint from the stream, interprets this varint
-    as a signed, 64-bit integer, and returns the integer.
-    """
-    i = self.ReadVarUInt64()
-    if i > wire_format.INT64_MAX:
-      i -= (1 << 64)
-    return i
-
-  def ReadVarUInt64(self):
-    """Reads a varint from the stream, interprets this varint
-    as an unsigned, 64-bit integer, and returns the integer.
-    """
-    i = self._ReadVarintHelper()
-    if not 0 <= i <= wire_format.UINT64_MAX:
-      raise message.DecodeError('Value out of range for uint64: %d' % i)
-    return i
-
-  def _ReadVarintHelper(self):
-    """Helper for the various varint-reading methods above.
-    Reads an unsigned, varint-encoded integer from the stream and
-    returns this integer.
-
-    Does no bounds checking except to ensure that we read at most as many bytes
-    as could possibly be present in a varint-encoded 64-bit number.
-    """
-    result = 0
-    shift = 0
-    while 1:
-      if shift >= 64:
-        raise message.DecodeError('Too many bytes when decoding varint.')
-      try:
-        b = ord(self._buffer[self._pos])
-      except IndexError:
-        raise message.DecodeError('Truncated varint.')
-      self._pos += 1
-      result |= ((b & 0x7f) << shift)
-      shift += 7
-      if not (b & 0x80):
-        return result
-
-
-class InputStreamArray(object):
-
-  """Contains all logic for reading bits, and dealing with stream position.
-
-  If an InputStream method ever raises an exception, the stream is left
-  in an indeterminate state and is not safe for further use.
-
-  This alternative to InputStreamBuffer is used in environments where buffer()
-  is unavailble, such as Google App Engine.
-  """
-
-  def __init__(self, s):
-    self._buffer = array.array('B', s)
-    self._pos = 0
-
-  def EndOfStream(self):
-    return self._pos >= len(self._buffer)
-
-  def Position(self):
-    return self._pos
-
-  def GetSubBuffer(self, size=None):
-    if size is None:
-      return self._buffer[self._pos : ].tostring()
-    else:
-      if size < 0:
-        raise message.DecodeError('Negative size %d' % size)
-      return self._buffer[self._pos : self._pos + size].tostring()
-
-  def SkipBytes(self, num_bytes):
-    if num_bytes < 0:
-      raise message.DecodeError('Negative num_bytes %d' % num_bytes)
-    self._pos += num_bytes
-    self._pos = min(self._pos, len(self._buffer))
-
-  def ReadBytes(self, size):
-    if size < 0:
-      raise message.DecodeError('Negative size %d' % size)
-    s = self._buffer[self._pos : self._pos + size].tostring()
-    self._pos += len(s)  # Only advance by the number of bytes actually read.
-    return s
-
-  def ReadLittleEndian32(self):
-    try:
-      i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
-                        self._buffer[self._pos : self._pos + 4])
-      self._pos += 4
-      return i[0]  # unpack() result is a 1-element tuple.
-    except struct.error, e:
-      raise message.DecodeError(e)
-
-  def ReadLittleEndian64(self):
-    try:
-      i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
-                        self._buffer[self._pos : self._pos + 8])
-      self._pos += 8
-      return i[0]  # unpack() result is a 1-element tuple.
-    except struct.error, e:
-      raise message.DecodeError(e)
-
-  def ReadVarint32(self):
-    i = self.ReadVarint64()
-    if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
-      raise message.DecodeError('Value out of range for int32: %d' % i)
-    return int(i)
-
-  def ReadVarUInt32(self):
-    i = self.ReadVarUInt64()
-    if i > wire_format.UINT32_MAX:
-      raise message.DecodeError('Value out of range for uint32: %d' % i)
-    return i
-
-  def ReadVarint64(self):
-    i = self.ReadVarUInt64()
-    if i > wire_format.INT64_MAX:
-      i -= (1 << 64)
-    return i
-
-  def ReadVarUInt64(self):
-    i = self._ReadVarintHelper()
-    if not 0 <= i <= wire_format.UINT64_MAX:
-      raise message.DecodeError('Value out of range for uint64: %d' % i)
-    return i
-
-  def _ReadVarintHelper(self):
-    result = 0
-    shift = 0
-    while 1:
-      if shift >= 64:
-        raise message.DecodeError('Too many bytes when decoding varint.')
-      try:
-        b = self._buffer[self._pos]
-      except IndexError:
-        raise message.DecodeError('Truncated varint.')
-      self._pos += 1
-      result |= ((b & 0x7f) << shift)
-      shift += 7
-      if not (b & 0x80):
-        return result
-
-
-try:
-  buffer('')
-  InputStream = InputStreamBuffer
-except NotImplementedError:
-  # Google App Engine: dev_appserver.py
-  InputStream = InputStreamArray
-except RuntimeError:
-  # Google App Engine: production
-  InputStream = InputStreamArray

+ 0 - 314
python/google/protobuf/internal/input_stream_test.py

@@ -1,314 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.input_stream."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import unittest
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import input_stream
-
-
-class InputStreamBufferTest(unittest.TestCase):
-
-  def setUp(self):
-    self.__original_input_stream = input_stream.InputStream
-    input_stream.InputStream = input_stream.InputStreamBuffer
-
-  def tearDown(self):
-    input_stream.InputStream = self.__original_input_stream
-
-  def testEndOfStream(self):
-    stream = input_stream.InputStream('abcd')
-    self.assertFalse(stream.EndOfStream())
-    self.assertEqual('abcd', stream.ReadBytes(10))
-    self.assertTrue(stream.EndOfStream())
-
-  def testPosition(self):
-    stream = input_stream.InputStream('abcd')
-    self.assertEqual(0, stream.Position())
-    self.assertEqual(0, stream.Position())  # No side-effects.
-    stream.ReadBytes(1)
-    self.assertEqual(1, stream.Position())
-    stream.ReadBytes(1)
-    self.assertEqual(2, stream.Position())
-    stream.ReadBytes(10)
-    self.assertEqual(4, stream.Position())  # Can't go past end of stream.
-
-  def testGetSubBuffer(self):
-    stream = input_stream.InputStream('abcd')
-    # Try leaving out the size.
-    self.assertEqual('abcd', str(stream.GetSubBuffer()))
-    stream.SkipBytes(1)
-    # GetSubBuffer() always starts at current size.
-    self.assertEqual('bcd', str(stream.GetSubBuffer()))
-    # Try 0-size.
-    self.assertEqual('', str(stream.GetSubBuffer(0)))
-    # Negative sizes should raise an error.
-    self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
-    # Positive sizes should work as expected.
-    self.assertEqual('b', str(stream.GetSubBuffer(1)))
-    self.assertEqual('bc', str(stream.GetSubBuffer(2)))
-    # Sizes longer than remaining bytes in the buffer should
-    # return the whole remaining buffer.
-    self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))
-
-  def testSkipBytes(self):
-    stream = input_stream.InputStream('')
-    # Skipping bytes when at the end of stream
-    # should have no effect.
-    stream.SkipBytes(0)
-    stream.SkipBytes(1)
-    stream.SkipBytes(2)
-    self.assertTrue(stream.EndOfStream())
-    self.assertEqual(0, stream.Position())
-
-    # Try skipping within a stream.
-    stream = input_stream.InputStream('abcd')
-    self.assertEqual(0, stream.Position())
-    stream.SkipBytes(1)
-    self.assertEqual(1, stream.Position())
-    stream.SkipBytes(10)  # Can't skip past the end.
-    self.assertEqual(4, stream.Position())
-
-    # Ensure that a negative skip raises an exception.
-    stream = input_stream.InputStream('abcd')
-    stream.SkipBytes(1)
-    self.assertRaises(message.DecodeError, stream.SkipBytes, -1)
-
-  def testReadBytes(self):
-    s = 'abcd'
-    # Also test going past the total stream length.
-    for i in range(len(s) + 10):
-      stream = input_stream.InputStream(s)
-      self.assertEqual(s[:i], stream.ReadBytes(i))
-      self.assertEqual(min(i, len(s)), stream.Position())
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadBytes, -1)
-
-  def EnsureFailureOnEmptyStream(self, input_stream_method):
-    """Helper for integer-parsing tests below.
-    Ensures that the given InputStream method raises a DecodeError
-    if called on a stream with no bytes remaining.
-    """
-    stream = input_stream.InputStream('')
-    self.assertRaises(message.DecodeError, input_stream_method, stream)
-
-  def testReadLittleEndian32(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32)
-    s = ''
-    # Read 0.
-    s += '\x00\x00\x00\x00'
-    # Read 1.
-    s += '\x01\x00\x00\x00'
-    # Read a bunch of different bytes.
-    s += '\x01\x02\x03\x04'
-    # Read max unsigned 32-bit int.
-    s += '\xff\xff\xff\xff'
-    # Try a read with fewer than 4 bytes left in the stream.
-    s += '\x00\x00\x00'
-    stream = input_stream.InputStream(s)
-    self.assertEqual(0, stream.ReadLittleEndian32())
-    self.assertEqual(4, stream.Position())
-    self.assertEqual(1, stream.ReadLittleEndian32())
-    self.assertEqual(8, stream.Position())
-    self.assertEqual(0x04030201, stream.ReadLittleEndian32())
-    self.assertEqual(12, stream.Position())
-    self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32())
-    self.assertEqual(16, stream.Position())
-    self.assertRaises(message.DecodeError, stream.ReadLittleEndian32)
-
-  def testReadLittleEndian64(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64)
-    s = ''
-    # Read 0.
-    s += '\x00\x00\x00\x00\x00\x00\x00\x00'
-    # Read 1.
-    s += '\x01\x00\x00\x00\x00\x00\x00\x00'
-    # Read a bunch of different bytes.
-    s += '\x01\x02\x03\x04\x05\x06\x07\x08'
-    # Read max unsigned 64-bit int.
-    s += '\xff\xff\xff\xff\xff\xff\xff\xff'
-    # Try a read with fewer than 8 bytes left in the stream.
-    s += '\x00\x00\x00'
-    stream = input_stream.InputStream(s)
-    self.assertEqual(0, stream.ReadLittleEndian64())
-    self.assertEqual(8, stream.Position())
-    self.assertEqual(1, stream.ReadLittleEndian64())
-    self.assertEqual(16, stream.Position())
-    self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64())
-    self.assertEqual(24, stream.Position())
-    self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64())
-    self.assertEqual(32, stream.Position())
-    self.assertRaises(message.DecodeError, stream.ReadLittleEndian64)
-
-  def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method):
-    """Helper for tests below that test successful reads of various varints.
-
-    Args:
-      varints_and_ints: Iterable of (str, integer) pairs, where the string
-        gives the wire encoding and the integer gives the value we expect
-        to be returned by the read_method upon encountering this string.
-      read_method: Unbound InputStream method that is capable of reading
-        the encoded strings provided in the first elements of varints_and_ints.
-    """
-    s = ''.join(s for s, i in varints_and_ints)
-    stream = input_stream.InputStream(s)
-    expected_pos = 0
-    self.assertEqual(expected_pos, stream.Position())
-    for s, expected_int in varints_and_ints:
-      self.assertEqual(expected_int, read_method(stream))
-      expected_pos += len(s)
-      self.assertEqual(expected_pos, stream.Position())
-
-  def testReadVarint32Success(self):
-    varints_and_ints = [
-        ('\x00', 0),
-        ('\x01', 1),
-        ('\x7f', 127),
-        ('\x80\x01', 128),
-        ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
-        ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX),
-        ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN),
-        ]
-    self.ReadVarintSuccessTestHelper(varints_and_ints,
-                                     input_stream.InputStream.ReadVarint32)
-
-  def testReadVarint32Failure(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32)
-
-    # Try and fail to read INT32_MAX + 1.
-    s = '\x80\x80\x80\x80\x08'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
-    # Try and fail to read INT32_MIN - 1.
-    s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
-    # Try and fail to read something that looks like
-    # a varint with more than 10 bytes.
-    s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
-  def testReadVarUInt32Success(self):
-    varints_and_ints = [
-        ('\x00', 0),
-        ('\x01', 1),
-        ('\x7f', 127),
-        ('\x80\x01', 128),
-        ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX),
-        ]
-    self.ReadVarintSuccessTestHelper(varints_and_ints,
-                                     input_stream.InputStream.ReadVarUInt32)
-
-  def testReadVarUInt32Failure(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32)
-    # Try and fail to read UINT32_MAX + 1
-    s = '\x80\x80\x80\x80\x10'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
-
-    # Try and fail to read something that looks like
-    # a varint with more than 10 bytes.
-    s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
-
-  def testReadVarint64Success(self):
-    varints_and_ints = [
-        ('\x00', 0),
-        ('\x01', 1),
-        ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
-        ('\x7f', 127),
-        ('\x80\x01', 128),
-        ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX),
-        ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN),
-        ]
-    self.ReadVarintSuccessTestHelper(varints_and_ints,
-                                     input_stream.InputStream.ReadVarint64)
-
-  def testReadVarint64Failure(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64)
-    # Try and fail to read something with the mythical 64th bit set.
-    s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarint64)
-
-    # Try and fail to read something that looks like
-    # a varint with more than 10 bytes.
-    s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarint64)
-
-  def testReadVarUInt64Success(self):
-    varints_and_ints = [
-        ('\x00', 0),
-        ('\x01', 1),
-        ('\x7f', 127),
-        ('\x80\x01', 128),
-        ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63),
-        ]
-    self.ReadVarintSuccessTestHelper(varints_and_ints,
-                                     input_stream.InputStream.ReadVarUInt64)
-
-  def testReadVarUInt64Failure(self):
-    self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64)
-    # Try and fail to read something with the mythical 64th bit set.
-    s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
-
-    # Try and fail to read something that looks like
-    # a varint with more than 10 bytes.
-    s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
-    stream = input_stream.InputStream(s)
-    self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
-
-
-class InputStreamArrayTest(InputStreamBufferTest):
-
-  def setUp(self):
-    # Test InputStreamArray against the same tests in InputStreamBuffer
-    self.__original_input_stream = input_stream.InputStream
-    input_stream.InputStream = input_stream.InputStreamArray
-
-  def tearDown(self):
-    input_stream.InputStream = self.__original_input_stream
-
-
-if __name__ == '__main__':
-  unittest.main()

+ 25 - 16
python/google/protobuf/internal/message_listener.py

@@ -39,22 +39,34 @@ __author__ = 'robinson@google.com (Will Robinson)'
 
 class MessageListener(object):
 
-  """Listens for transitions to nonempty and for invalidations of cached
-  byte sizes.  Meant to be registered via Message._SetListener().
+  """Listens for modifications made to a message.  Meant to be registered via
+  Message._SetListener().
+
+  Attributes:
+    dirty:  If True, then calling Modified() would be a no-op.  This can be
+            used to avoid these calls entirely in the common case.
   """
 
-  def TransitionToNonempty(self):
-    """Called the *first* time that this message becomes nonempty.
-    Implementations are free (but not required) to call this method multiple
-    times after the message has become nonempty.
-    """
-    raise NotImplementedError
+  def Modified(self):
+    """Called every time the message is modified in such a way that the parent
+    message may need to be updated.  This currently means either:
+    (a) The message was modified for the first time, so the parent message
+        should henceforth mark the message as present.
+    (b) The message's cached byte size became dirty -- i.e. the message was
+        modified for the first time after a previous call to ByteSize().
+        Therefore the parent should also mark its byte size as dirty.
+    Note that (a) implies (b), since new objects start out with a clean cached
+    size (zero).  However, we document (a) explicitly because it is important.
+
+    Modified() will *only* be called in response to one of these two events --
+    not every time the sub-message is modified.
 
-  def ByteSizeDirty(self):
-    """Called *every* time the cached byte size value
-    for this object is invalidated (transitions from being
-    "clean" to "dirty").
+    Note that if the listener's |dirty| attribute is true, then calling
+    Modified at the moment would be a no-op, so it can be skipped.  Performance-
+    sensitive callers should check this attribute directly before calling since
+    it will be true most of the time.
     """
+
     raise NotImplementedError
 
 
@@ -62,8 +74,5 @@ class NullMessageListener(object):
 
   """No-op MessageListener implementation."""
 
-  def TransitionToNonempty(self):
-    pass
-
-  def ByteSizeDirty(self):
+  def Modified(self):
     pass

+ 39 - 3
python/google/protobuf/internal/message_test.py

@@ -30,7 +30,16 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-"""Tests python protocol buffers against the golden message."""
+"""Tests python protocol buffers against the golden message.
+
+Note that the golden messages exercise every known field type, thus this
+test ends up exercising and verifying nearly all of the parsing and
+serialization code in the whole library.
+
+TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
+sense to call this a test of the "message" module, which only declares an
+abstract interface.
+"""
 
 __author__ = 'gps@google.com (Gregory P. Smith)'
 
@@ -40,14 +49,41 @@ from google.protobuf import unittest_pb2
 from google.protobuf.internal import test_util
 
 
-class MessageTest(test_util.GoldenMessageTestCase):
+class MessageTest(unittest.TestCase):
 
   def testGoldenMessage(self):
     golden_data = test_util.GoldenFile('golden_message').read()
     golden_message = unittest_pb2.TestAllTypes()
     golden_message.ParseFromString(golden_data)
-    self.ExpectAllFieldsSet(golden_message)
+    test_util.ExpectAllFieldsSet(self, golden_message)
+    self.assertTrue(golden_message.SerializeToString() == golden_data)
+
+  def testGoldenExtensions(self):
+    golden_data = test_util.GoldenFile('golden_message').read()
+    golden_message = unittest_pb2.TestAllExtensions()
+    golden_message.ParseFromString(golden_data)
+    all_set = unittest_pb2.TestAllExtensions()
+    test_util.SetAllExtensions(all_set)
+    self.assertEquals(all_set, golden_message)
+    self.assertTrue(golden_message.SerializeToString() == golden_data)
+
+  def testGoldenPackedMessage(self):
+    golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+    golden_message = unittest_pb2.TestPackedTypes()
+    golden_message.ParseFromString(golden_data)
+    all_set = unittest_pb2.TestPackedTypes()
+    test_util.SetAllPackedFields(all_set)
+    self.assertEquals(all_set, golden_message)
+    self.assertTrue(all_set.SerializeToString() == golden_data)
 
+  def testGoldenPackedExtensions(self):
+    golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+    golden_message = unittest_pb2.TestPackedExtensions()
+    golden_message.ParseFromString(golden_data)
+    all_set = unittest_pb2.TestPackedExtensions()
+    test_util.SetAllPackedExtensions(all_set)
+    self.assertEquals(all_set, golden_message)
+    self.assertTrue(all_set.SerializeToString() == golden_data)
 
 if __name__ == '__main__':
   unittest.main()

+ 0 - 125
python/google/protobuf/internal/output_stream.py

@@ -1,125 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""OutputStream is the primitive interface for sticking bits on the wire.
-
-All protocol buffer serialization can be expressed in terms of
-the OutputStream primitives provided here.
-"""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import array
-import struct
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by CodedOutputStream from the C++
-# proto2 implementation.
-
-
-class OutputStream(object):
-
-  """Contains all logic for writing bits, and ToString() to get the result."""
-
-  def __init__(self):
-    self._buffer = array.array('B')
-
-  def AppendRawBytes(self, raw_bytes):
-    """Appends raw_bytes to our internal buffer."""
-    self._buffer.fromstring(raw_bytes)
-
-  def AppendLittleEndian32(self, unsigned_value):
-    """Appends an unsigned 32-bit integer to the internal buffer,
-    in little-endian byte order.
-    """
-    if not 0 <= unsigned_value <= wire_format.UINT32_MAX:
-      raise message.EncodeError(
-          'Unsigned 32-bit out of range: %d' % unsigned_value)
-    self._buffer.fromstring(struct.pack(
-        wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value))
-
-  def AppendLittleEndian64(self, unsigned_value):
-    """Appends an unsigned 64-bit integer to the internal buffer,
-    in little-endian byte order.
-    """
-    if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
-      raise message.EncodeError(
-          'Unsigned 64-bit out of range: %d' % unsigned_value)
-    self._buffer.fromstring(struct.pack(
-        wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value))
-
-  def AppendVarint32(self, value):
-    """Appends a signed 32-bit integer to the internal buffer,
-    encoded as a varint.  (Note that a negative varint32 will
-    always require 10 bytes of space.)
-    """
-    if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX:
-      raise message.EncodeError('Value out of range: %d' % value)
-    self.AppendVarint64(value)
-
-  def AppendVarUInt32(self, value):
-    """Appends an unsigned 32-bit integer to the internal buffer,
-    encoded as a varint.
-    """
-    if not 0 <= value <= wire_format.UINT32_MAX:
-      raise message.EncodeError('Value out of range: %d' % value)
-    self.AppendVarUInt64(value)
-
-  def AppendVarint64(self, value):
-    """Appends a signed 64-bit integer to the internal buffer,
-    encoded as a varint.
-    """
-    if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX:
-      raise message.EncodeError('Value out of range: %d' % value)
-    if value < 0:
-      value += (1 << 64)
-    self.AppendVarUInt64(value)
-
-  def AppendVarUInt64(self, unsigned_value):
-    """Appends an unsigned 64-bit integer to the internal buffer,
-    encoded as a varint.
-    """
-    if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
-      raise message.EncodeError('Value out of range: %d' % unsigned_value)
-    while True:
-      bits = unsigned_value & 0x7f
-      unsigned_value >>= 7
-      if not unsigned_value:
-        self._buffer.append(bits)
-        break
-      self._buffer.append(0x80|bits)
-
-  def ToString(self):
-    """Returns a string containing the bytes in our internal buffer."""
-    return self._buffer.tostring()

+ 0 - 178
python/google/protobuf/internal/output_stream_test.py

@@ -1,178 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc.  All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#     * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#     * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.output_stream."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import unittest
-from google.protobuf import message
-from google.protobuf.internal import output_stream
-from google.protobuf.internal import wire_format
-
-
-class OutputStreamTest(unittest.TestCase):
-
-  def setUp(self):
-    self.stream = output_stream.OutputStream()
-
-  def testAppendRawBytes(self):
-    # Empty string.
-    self.stream.AppendRawBytes('')
-    self.assertEqual('', self.stream.ToString())
-
-    # Nonempty string.
-    self.stream.AppendRawBytes('abc')
-    self.assertEqual('abc', self.stream.ToString())
-
-    # Ensure that we're actually appending.
-    self.stream.AppendRawBytes('def')
-    self.assertEqual('abcdef', self.stream.ToString())
-
-  def AppendNumericTestHelper(self, append_fn, values_and_strings):
-    """For each (value, expected_string) pair in values_and_strings,
-    calls an OutputStream.Append*(value) method on an OutputStream and ensures
-    that the string written to that stream matches expected_string.
-
-    Args:
-      append_fn: Unbound OutputStream method that takes an integer or
-        long value as input.
-      values_and_strings: Iterable of (value, expected_string) pairs.
-    """
-    for conversion in (int, long):
-      for value, string in values_and_strings:
-        stream = output_stream.OutputStream()
-        expected_string = ''
-        append_fn(stream, conversion(value))
-        expected_string += string
-        self.assertEqual(expected_string, stream.ToString())
-
-  def AppendOverflowTestHelper(self, append_fn, value):
-    """Calls an OutputStream.Append*(value) method and asserts
-    that the method raises message.EncodeError.
-
-    Args:
-      append_fn: Unbound OutputStream method that takes an integer or
-        long value as input.
-      value: Value to pass to append_fn which should cause an
-        message.EncodeError.
-    """
-    stream = output_stream.OutputStream()
-    self.assertRaises(message.EncodeError, append_fn, stream, value)
-
-  def testAppendLittleEndian32(self):
-    append_fn = output_stream.OutputStream.AppendLittleEndian32
-    values_and_expected_strings = [
-        (0, '\x00\x00\x00\x00'),
-        (1, '\x01\x00\x00\x00'),
-        ((1 << 32) - 1, '\xff\xff\xff\xff'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, 1 << 32)
-    self.AppendOverflowTestHelper(append_fn, -1)
-
-  def testAppendLittleEndian64(self):
-    append_fn = output_stream.OutputStream.AppendLittleEndian64
-    values_and_expected_strings = [
-        (0, '\x00\x00\x00\x00\x00\x00\x00\x00'),
-        (1, '\x01\x00\x00\x00\x00\x00\x00\x00'),
-        ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, 1 << 64)
-    self.AppendOverflowTestHelper(append_fn, -1)
-
-  def testAppendVarint32(self):
-    append_fn = output_stream.OutputStream.AppendVarint32
-    values_and_expected_strings = [
-        (0, '\x00'),
-        (1, '\x01'),
-        (127, '\x7f'),
-        (128, '\x80\x01'),
-        (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
-        (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'),
-        (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1)
-    self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1)
-
-  def testAppendVarUInt32(self):
-    append_fn = output_stream.OutputStream.AppendVarUInt32
-    values_and_expected_strings = [
-        (0, '\x00'),
-        (1, '\x01'),
-        (127, '\x7f'),
-        (128, '\x80\x01'),
-        (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, -1)
-    self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1)
-
-  def testAppendVarint64(self):
-    append_fn = output_stream.OutputStream.AppendVarint64
-    values_and_expected_strings = [
-        (0, '\x00'),
-        (1, '\x01'),
-        (127, '\x7f'),
-        (128, '\x80\x01'),
-        (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
-        (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'),
-        (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1)
-    self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1)
-
-  def testAppendVarUInt64(self):
-    append_fn = output_stream.OutputStream.AppendVarUInt64
-    values_and_expected_strings = [
-        (0, '\x00'),
-        (1, '\x01'),
-        (127, '\x7f'),
-        (128, '\x80\x01'),
-        (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
-        ]
-    self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
-    self.AppendOverflowTestHelper(append_fn, -1)
-    self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1)
-
-
-if __name__ == '__main__':
-  unittest.main()

+ 307 - 47
python/google/protobuf/internal/reflection_test.py

@@ -38,6 +38,7 @@ pure-Python protocol compiler.
 __author__ = 'robinson@google.com (Will Robinson)'
 
 import operator
+import struct
 
 import unittest
 # TODO(robinson): When we split this test in two, only some of these imports
@@ -56,6 +57,51 @@ from google.protobuf.internal import test_util
 from google.protobuf.internal import decoder
 
 
+class _MiniDecoder(object):
+  """Decodes a stream of values from a string.
+
+  Once upon a time we actually had a class called decoder.Decoder.  Then we
+  got rid of it during a redesign that made decoding much, much faster overall.
+  But a couple tests in this file used it to check that the serialized form of
+  a message was correct.  So, this class implements just the methods that were
+  used by said tests, so that we don't have to rewrite the tests.
+  """
+
+  def __init__(self, bytes):
+    self._bytes = bytes
+    self._pos = 0
+
+  def ReadVarint(self):
+    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
+    return result
+
+  ReadInt32 = ReadVarint
+  ReadInt64 = ReadVarint
+  ReadUInt32 = ReadVarint
+  ReadUInt64 = ReadVarint
+
+  def ReadSInt64(self):
+    return wire_format.ZigZagDecode(self.ReadVarint())
+
+  ReadSInt32 = ReadSInt64
+
+  def ReadFieldNumberAndWireType(self):
+    return wire_format.UnpackTag(self.ReadVarint())
+
+  def ReadFloat(self):
+    result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+    self._pos += 4
+    return result
+
+  def ReadDouble(self):
+    result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+    self._pos += 8
+    return result
+
+  def EndOfStream(self):
+    return self._pos == len(self._bytes)
+
+
 class ReflectionTest(unittest.TestCase):
 
   def assertIs(self, values, others):
@@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase):
     for i in range(len(values)):
       self.assertTrue(values[i] is others[i])
 
+  def testScalarConstructor(self):
+    # Constructor with only scalar types should succeed.
+    proto = unittest_pb2.TestAllTypes(
+        optional_int32=24,
+        optional_double=54.321,
+        optional_string='optional_string')
+
+    self.assertEqual(24, proto.optional_int32)
+    self.assertEqual(54.321, proto.optional_double)
+    self.assertEqual('optional_string', proto.optional_string)
+
+  def testRepeatedScalarConstructor(self):
+    # Constructor with only repeated scalar types should succeed.
+    proto = unittest_pb2.TestAllTypes(
+        repeated_int32=[1, 2, 3, 4],
+        repeated_double=[1.23, 54.321],
+        repeated_bool=[True, False, False],
+        repeated_string=["optional_string"])
+
+    self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
+    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+    self.assertEquals([True, False, False], list(proto.repeated_bool))
+    self.assertEquals(["optional_string"], list(proto.repeated_string))
+
+  def testRepeatedCompositeConstructor(self):
+    # Constructor with only repeated composite types should succeed.
+    proto = unittest_pb2.TestAllTypes(
+        repeated_nested_message=[
+            unittest_pb2.TestAllTypes.NestedMessage(
+                bb=unittest_pb2.TestAllTypes.FOO),
+            unittest_pb2.TestAllTypes.NestedMessage(
+                bb=unittest_pb2.TestAllTypes.BAR)],
+        repeated_foreign_message=[
+            unittest_pb2.ForeignMessage(c=-43),
+            unittest_pb2.ForeignMessage(c=45324),
+            unittest_pb2.ForeignMessage(c=12)],
+        repeatedgroup=[
+            unittest_pb2.TestAllTypes.RepeatedGroup(),
+            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
+
+    self.assertEquals(
+        [unittest_pb2.TestAllTypes.NestedMessage(
+            bb=unittest_pb2.TestAllTypes.FOO),
+         unittest_pb2.TestAllTypes.NestedMessage(
+             bb=unittest_pb2.TestAllTypes.BAR)],
+        list(proto.repeated_nested_message))
+    self.assertEquals(
+        [unittest_pb2.ForeignMessage(c=-43),
+         unittest_pb2.ForeignMessage(c=45324),
+         unittest_pb2.ForeignMessage(c=12)],
+        list(proto.repeated_foreign_message))
+    self.assertEquals(
+        [unittest_pb2.TestAllTypes.RepeatedGroup(),
+         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
+        list(proto.repeatedgroup))
+
+  def testMixedConstructor(self):
+    # Constructor with only mixed types should succeed.
+    proto = unittest_pb2.TestAllTypes(
+        optional_int32=24,
+        optional_string='optional_string',
+        repeated_double=[1.23, 54.321],
+        repeated_bool=[True, False, False],
+        repeated_nested_message=[
+            unittest_pb2.TestAllTypes.NestedMessage(
+                bb=unittest_pb2.TestAllTypes.FOO),
+            unittest_pb2.TestAllTypes.NestedMessage(
+                bb=unittest_pb2.TestAllTypes.BAR)],
+        repeated_foreign_message=[
+            unittest_pb2.ForeignMessage(c=-43),
+            unittest_pb2.ForeignMessage(c=45324),
+            unittest_pb2.ForeignMessage(c=12)])
+
+    self.assertEqual(24, proto.optional_int32)
+    self.assertEqual('optional_string', proto.optional_string)
+    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+    self.assertEquals([True, False, False], list(proto.repeated_bool))
+    self.assertEquals(
+        [unittest_pb2.TestAllTypes.NestedMessage(
+            bb=unittest_pb2.TestAllTypes.FOO),
+         unittest_pb2.TestAllTypes.NestedMessage(
+             bb=unittest_pb2.TestAllTypes.BAR)],
+        list(proto.repeated_nested_message))
+    self.assertEquals(
+        [unittest_pb2.ForeignMessage(c=-43),
+         unittest_pb2.ForeignMessage(c=45324),
+         unittest_pb2.ForeignMessage(c=12)],
+        list(proto.repeated_foreign_message))
+
   def testSimpleHasBits(self):
     # Test a scalar.
     proto = unittest_pb2.TestAllTypes()
@@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase):
     proto.optional_fixed32 = 1
     proto.optional_int32 = 5
     proto.optional_string = 'foo'
+    # Access sub-message but don't set it yet.
+    nested_message = proto.optional_nested_message
     self.assertEqual(
       [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
       proto.ListFields())
 
+    proto.optional_nested_message.bb = 123
+    self.assertEqual(
+      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
+        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
+        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
+        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
+             nested_message) ],
+      proto.ListFields())
+
   def testRepeatedListFields(self):
     proto = unittest_pb2.TestAllTypes()
     proto.repeated_fixed32.append(1)
@@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase):
     proto.repeated_string.append('baz')
     proto.repeated_string.extend(str(x) for x in xrange(2))
     proto.optional_int32 = 21
+    proto.repeated_bool  # Access but don't set anything; should not be listed.
     self.assertEqual(
       [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
         (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
@@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase):
     extendee_proto.ClearExtension(extension)
     extension_proto.foreign_message_int = 23
 
-    self.assertTrue(not toplevel.HasField('submessage'))
     self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
 
   def testExtensionFailureModes(self):
@@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase):
     empty_proto = unittest_pb2.TestAllExtensions()
     self.assertEquals(proto, empty_proto)
 
+  def assertInitialized(self, proto):
+    self.assertTrue(proto.IsInitialized())
+    # Neither method should raise an exception.
+    proto.SerializeToString()
+    proto.SerializePartialToString()
+
+  def assertNotInitialized(self, proto):
+    self.assertFalse(proto.IsInitialized())
+    self.assertRaises(message.EncodeError, proto.SerializeToString)
+    # "Partial" serialization doesn't care if message is uninitialized.
+    proto.SerializePartialToString()
+
   def testIsInitialized(self):
     # Trivial cases - all optional fields and extensions.
     proto = unittest_pb2.TestAllTypes()
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
     proto = unittest_pb2.TestAllExtensions()
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
 
     # The case of uninitialized required fields.
     proto = unittest_pb2.TestRequired()
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     proto.a = proto.b = proto.c = 2
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
 
     # The case of uninitialized submessage.
     proto = unittest_pb2.TestRequiredForeign()
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
     proto.optional_message.a = 1
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     proto.optional_message.b = 0
     proto.optional_message.c = 0
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
 
     # Uninitialized repeated submessage.
     message1 = proto.repeated_message.add()
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     message1.a = message1.b = message1.c = 0
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
 
     # Uninitialized repeated group in an extension.
     proto = unittest_pb2.TestAllExtensions()
     extension = unittest_pb2.TestRequired.multi
     message1 = proto.Extensions[extension].add()
     message2 = proto.Extensions[extension].add()
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     message1.a = 1
     message1.b = 1
     message1.c = 1
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     message2.a = 2
     message2.b = 2
     message2.c = 2
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
 
     # Uninitialized nonrepeated message in an extension.
     proto = unittest_pb2.TestAllExtensions()
     extension = unittest_pb2.TestRequired.single
     proto.Extensions[extension].a = 1
-    self.assertFalse(proto.IsInitialized())
+    self.assertNotInitialized(proto)
     proto.Extensions[extension].b = 2
     proto.Extensions[extension].c = 3
-    self.assertTrue(proto.IsInitialized())
+    self.assertInitialized(proto)
+
+    # Try passing an errors list.
+    errors = []
+    proto = unittest_pb2.TestRequired()
+    self.assertFalse(proto.IsInitialized(errors))
+    self.assertEqual(errors, ['a', 'b', 'c'])
 
   def testStringUTF8Encoding(self):
     proto = unittest_pb2.TestAllTypes()
@@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase):
         test_utf8_bytes, len(test_utf8_bytes) * '\xff')
     self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
 
+  def testEmptyNestedMessage(self):
+    proto = unittest_pb2.TestAllTypes()
+    proto.optional_nested_message.MergeFrom(
+        unittest_pb2.TestAllTypes.NestedMessage())
+    self.assertTrue(proto.HasField('optional_nested_message'))
+
+    proto = unittest_pb2.TestAllTypes()
+    proto.optional_nested_message.CopyFrom(
+        unittest_pb2.TestAllTypes.NestedMessage())
+    self.assertTrue(proto.HasField('optional_nested_message'))
+
+    proto = unittest_pb2.TestAllTypes()
+    proto.optional_nested_message.MergeFromString('')
+    self.assertTrue(proto.HasField('optional_nested_message'))
+
+    proto = unittest_pb2.TestAllTypes()
+    proto.optional_nested_message.ParseFromString('')
+    self.assertTrue(proto.HasField('optional_nested_message'))
+
+    serialized = proto.SerializeToString()
+    proto2 = unittest_pb2.TestAllTypes()
+    proto2.MergeFromString(serialized)
+    self.assertTrue(proto2.HasField('optional_nested_message'))
+
+  def testSetInParent(self):
+    proto = unittest_pb2.TestAllTypes()
+    self.assertFalse(proto.HasField('optionalgroup'))
+    proto.optionalgroup.SetInParent()
+    self.assertTrue(proto.HasField('optionalgroup'))
+
 
 #  Since we had so many tests for protocol buffer equality, we broke these out
 #  into separate TestCase classes.
@@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase):
     second_proto.MergeFromString(serialized)
     self.assertEqual(first_proto, second_proto)
 
+  def testSerializeNegativeValues(self):
+    first_proto = unittest_pb2.TestAllTypes()
+
+    first_proto.optional_int32 = -1
+    first_proto.optional_int64 = -(2 << 40)
+    first_proto.optional_sint32 = -3
+    first_proto.optional_sint64 = -(4 << 40)
+    first_proto.optional_sfixed32 = -5
+    first_proto.optional_sfixed64 = -(6 << 40)
+
+    second_proto = unittest_pb2.TestAllTypes.FromString(
+        first_proto.SerializeToString())
+
+    self.assertEqual(first_proto, second_proto)
+
+  def testParseTruncated(self):
+    first_proto = unittest_pb2.TestAllTypes()
+    test_util.SetAllFields(first_proto)
+    serialized = first_proto.SerializeToString()
+
+    for truncation_point in xrange(len(serialized) + 1):
+      try:
+        second_proto = unittest_pb2.TestAllTypes()
+        unknown_fields = unittest_pb2.TestEmptyMessage()
+        pos = second_proto._InternalParse(serialized, 0, truncation_point)
+        # If we didn't raise an error then we read exactly the amount expected.
+        self.assertEqual(truncation_point, pos)
+
+        # Parsing to unknown fields should not throw if parsing to known fields
+        # did not.
+        try:
+          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
+          self.assertEqual(truncation_point, pos2)
+        except message.DecodeError:
+          self.fail('Parsing unknown fields failed when parsing known fields '
+                    'did not.')
+      except message.DecodeError:
+        # Parsing unknown fields should also fail.
+        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
+                          serialized, 0, truncation_point)
+
   def testCanonicalSerializationOrder(self):
     proto = more_messages_pb2.OutOfOrderFields()
     # These are also their tag numbers.  Even though we're setting these in
@@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase):
     proto.optional_int32 = 1
     serialized = proto.SerializeToString()
     self.assertEqual(proto.ByteSize(), len(serialized))
-    d = decoder.Decoder(serialized)
+    d = _MiniDecoder(serialized)
     ReadTag = d.ReadFieldNumberAndWireType
     self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
     self.assertEqual(1, d.ReadInt32())
@@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase):
     self._CheckRaises(
         message.EncodeError,
         proto.SerializeToString,
-        'Required field protobuf_unittest.TestRequired.a is not set.')
+        'Message is missing required fields: a,b,c')
     # Shouldn't raise exceptions.
     partial = proto.SerializePartialToString()
 
@@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase):
     self._CheckRaises(
         message.EncodeError,
         proto.SerializeToString,
-        'Required field protobuf_unittest.TestRequired.b is not set.')
+        'Message is missing required fields: b,c')
     # Shouldn't raise exceptions.
     partial = proto.SerializePartialToString()
 
@@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase):
     self._CheckRaises(
         message.EncodeError,
         proto.SerializeToString,
-        'Required field protobuf_unittest.TestRequired.c is not set.')
+        'Message is missing required fields: c')
     # Shouldn't raise exceptions.
     partial = proto.SerializePartialToString()
 
@@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase):
     self.assertEqual(2, proto2.b)
     self.assertEqual(3, proto2.c)
 
+  def testSerializeUninitializedSubMessage(self):
+    proto = unittest_pb2.TestRequiredForeign()
+
+    # Sub-message doesn't exist yet, so this succeeds.
+    proto.SerializeToString()
+
+    proto.optional_message.a = 1
+    self._CheckRaises(
+        message.EncodeError,
+        proto.SerializeToString,
+        'Message is missing required fields: '
+        'optional_message.b,optional_message.c')
+
+    proto.optional_message.b = 2
+    proto.optional_message.c = 3
+    proto.SerializeToString()
+
+    proto.repeated_message.add().a = 1
+    proto.repeated_message.add().b = 2
+    self._CheckRaises(
+        message.EncodeError,
+        proto.SerializeToString,
+        'Message is missing required fields: '
+        'repeated_message[0].b,repeated_message[0].c,'
+        'repeated_message[1].a,repeated_message[1].c')
+
+    proto.repeated_message[0].b = 2
+    proto.repeated_message[0].c = 3
+    proto.repeated_message[1].a = 1
+    proto.repeated_message[1].c = 3
+    proto.SerializeToString()
+
   def testSerializeAllPackedFields(self):
     first_proto = unittest_pb2.TestPackedTypes()
     second_proto = unittest_pb2.TestPackedTypes()
@@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase):
     proto.packed_float.append(2.0)             # 4 bytes, will be before double
     serialized = proto.SerializeToString()
     self.assertEqual(proto.ByteSize(), len(serialized))
-    d = decoder.Decoder(serialized)
+    d = _MiniDecoder(serialized)
     ReadTag = d.ReadFieldNumberAndWireType
     self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
     self.assertEqual(1+1+1+2, d.ReadInt32())
@@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase):
     self.assertEqual(1000.0, d.ReadDouble())
     self.assertTrue(d.EndOfStream())
 
+  def testParsePackedFromUnpacked(self):
+    unpacked = unittest_pb2.TestUnpackedTypes()
+    test_util.SetAllUnpackedFields(unpacked)
+    packed = unittest_pb2.TestPackedTypes()
+    packed.MergeFromString(unpacked.SerializeToString())
+    expected = unittest_pb2.TestPackedTypes()
+    test_util.SetAllPackedFields(expected)
+    self.assertEqual(expected, packed)
+
+  def testParseUnpackedFromPacked(self):
+    packed = unittest_pb2.TestPackedTypes()
+    test_util.SetAllPackedFields(packed)
+    unpacked = unittest_pb2.TestUnpackedTypes()
+    unpacked.MergeFromString(packed.SerializeToString())
+    expected = unittest_pb2.TestUnpackedTypes()
+    test_util.SetAllUnpackedFields(expected)
+    self.assertEqual(expected, unpacked)
+
   def testFieldNumbers(self):
     proto = unittest_pb2.TestAllTypes()
     self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
@@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase):
                        field_descriptor.label)
 
 
-class UtilityTest(unittest.TestCase):
-
-  def testImergeSorted(self):
-    ImergeSorted = reflection._ImergeSorted
-    # Various types of emptiness.
-    self.assertEqual([], list(ImergeSorted()))
-    self.assertEqual([], list(ImergeSorted([])))
-    self.assertEqual([], list(ImergeSorted([], [])))
-
-    # One nonempty list.
-    self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
-    self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
-    self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
-
-    # Merging some nonempty lists together.
-    self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
-    self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
-    self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
-
-    # Elements repeated across component iterators.
-    self.assertEqual([1, 2, 2, 3, 3],
-                     list(ImergeSorted([1, 2], [3], [2, 3])))
-
-    # Elements repeated within an iterator.
-    self.assertEqual([1, 2, 2, 3, 3],
-                     list(ImergeSorted([1, 2, 2], [3], [3])))
-
 
 if __name__ == '__main__':
   unittest.main()

+ 245 - 222
python/google/protobuf/internal/test_util.py

@@ -31,14 +31,13 @@
 """Utilities for Python proto2 tests.
 
 This is intentionally modeled on C++ code in
-//net/proto2/internal/test_util.*.
+//google/protobuf/test_util.*.
 """
 
 __author__ = 'robinson@google.com (Will Robinson)'
 
 import os.path
 
-import unittest
 from google.protobuf import unittest_import_pb2
 from google.protobuf import unittest_pb2
 
@@ -353,198 +352,198 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized):
     raise ValueError('Expected %r, found %r' % (expected, serialized))
 
 
-class GoldenMessageTestCase(unittest.TestCase):
-  """This adds methods to TestCase useful for verifying our Golden Message."""
-
-  def ExpectAllFieldsSet(self, message):
-    """Check all fields for correct values have after Set*Fields() is called."""
-    self.assertTrue(message.HasField('optional_int32'))
-    self.assertTrue(message.HasField('optional_int64'))
-    self.assertTrue(message.HasField('optional_uint32'))
-    self.assertTrue(message.HasField('optional_uint64'))
-    self.assertTrue(message.HasField('optional_sint32'))
-    self.assertTrue(message.HasField('optional_sint64'))
-    self.assertTrue(message.HasField('optional_fixed32'))
-    self.assertTrue(message.HasField('optional_fixed64'))
-    self.assertTrue(message.HasField('optional_sfixed32'))
-    self.assertTrue(message.HasField('optional_sfixed64'))
-    self.assertTrue(message.HasField('optional_float'))
-    self.assertTrue(message.HasField('optional_double'))
-    self.assertTrue(message.HasField('optional_bool'))
-    self.assertTrue(message.HasField('optional_string'))
-    self.assertTrue(message.HasField('optional_bytes'))
-
-    self.assertTrue(message.HasField('optionalgroup'))
-    self.assertTrue(message.HasField('optional_nested_message'))
-    self.assertTrue(message.HasField('optional_foreign_message'))
-    self.assertTrue(message.HasField('optional_import_message'))
-
-    self.assertTrue(message.optionalgroup.HasField('a'))
-    self.assertTrue(message.optional_nested_message.HasField('bb'))
-    self.assertTrue(message.optional_foreign_message.HasField('c'))
-    self.assertTrue(message.optional_import_message.HasField('d'))
-
-    self.assertTrue(message.HasField('optional_nested_enum'))
-    self.assertTrue(message.HasField('optional_foreign_enum'))
-    self.assertTrue(message.HasField('optional_import_enum'))
-
-    self.assertTrue(message.HasField('optional_string_piece'))
-    self.assertTrue(message.HasField('optional_cord'))
-
-    self.assertEqual(101, message.optional_int32)
-    self.assertEqual(102, message.optional_int64)
-    self.assertEqual(103, message.optional_uint32)
-    self.assertEqual(104, message.optional_uint64)
-    self.assertEqual(105, message.optional_sint32)
-    self.assertEqual(106, message.optional_sint64)
-    self.assertEqual(107, message.optional_fixed32)
-    self.assertEqual(108, message.optional_fixed64)
-    self.assertEqual(109, message.optional_sfixed32)
-    self.assertEqual(110, message.optional_sfixed64)
-    self.assertEqual(111, message.optional_float)
-    self.assertEqual(112, message.optional_double)
-    self.assertEqual(True, message.optional_bool)
-    self.assertEqual('115', message.optional_string)
-    self.assertEqual('116', message.optional_bytes)
-
-    self.assertEqual(117, message.optionalgroup.a);
-    self.assertEqual(118, message.optional_nested_message.bb)
-    self.assertEqual(119, message.optional_foreign_message.c)
-    self.assertEqual(120, message.optional_import_message.d)
-
-    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
-                     message.optional_nested_enum)
-    self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum)
-    self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
-                     message.optional_import_enum)
-
-    # -----------------------------------------------------------------
-
-    self.assertEqual(2, len(message.repeated_int32))
-    self.assertEqual(2, len(message.repeated_int64))
-    self.assertEqual(2, len(message.repeated_uint32))
-    self.assertEqual(2, len(message.repeated_uint64))
-    self.assertEqual(2, len(message.repeated_sint32))
-    self.assertEqual(2, len(message.repeated_sint64))
-    self.assertEqual(2, len(message.repeated_fixed32))
-    self.assertEqual(2, len(message.repeated_fixed64))
-    self.assertEqual(2, len(message.repeated_sfixed32))
-    self.assertEqual(2, len(message.repeated_sfixed64))
-    self.assertEqual(2, len(message.repeated_float))
-    self.assertEqual(2, len(message.repeated_double))
-    self.assertEqual(2, len(message.repeated_bool))
-    self.assertEqual(2, len(message.repeated_string))
-    self.assertEqual(2, len(message.repeated_bytes))
-
-    self.assertEqual(2, len(message.repeatedgroup))
-    self.assertEqual(2, len(message.repeated_nested_message))
-    self.assertEqual(2, len(message.repeated_foreign_message))
-    self.assertEqual(2, len(message.repeated_import_message))
-    self.assertEqual(2, len(message.repeated_nested_enum))
-    self.assertEqual(2, len(message.repeated_foreign_enum))
-    self.assertEqual(2, len(message.repeated_import_enum))
-
-    self.assertEqual(2, len(message.repeated_string_piece))
-    self.assertEqual(2, len(message.repeated_cord))
-
-    self.assertEqual(201, message.repeated_int32[0])
-    self.assertEqual(202, message.repeated_int64[0])
-    self.assertEqual(203, message.repeated_uint32[0])
-    self.assertEqual(204, message.repeated_uint64[0])
-    self.assertEqual(205, message.repeated_sint32[0])
-    self.assertEqual(206, message.repeated_sint64[0])
-    self.assertEqual(207, message.repeated_fixed32[0])
-    self.assertEqual(208, message.repeated_fixed64[0])
-    self.assertEqual(209, message.repeated_sfixed32[0])
-    self.assertEqual(210, message.repeated_sfixed64[0])
-    self.assertEqual(211, message.repeated_float[0])
-    self.assertEqual(212, message.repeated_double[0])
-    self.assertEqual(True, message.repeated_bool[0])
-    self.assertEqual('215', message.repeated_string[0])
-    self.assertEqual('216', message.repeated_bytes[0])
-
-    self.assertEqual(217, message.repeatedgroup[0].a)
-    self.assertEqual(218, message.repeated_nested_message[0].bb)
-    self.assertEqual(219, message.repeated_foreign_message[0].c)
-    self.assertEqual(220, message.repeated_import_message[0].d)
-
-    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
-                     message.repeated_nested_enum[0])
-    self.assertEqual(unittest_pb2.FOREIGN_BAR,
-                     message.repeated_foreign_enum[0])
-    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
-                     message.repeated_import_enum[0])
-
-    self.assertEqual(301, message.repeated_int32[1])
-    self.assertEqual(302, message.repeated_int64[1])
-    self.assertEqual(303, message.repeated_uint32[1])
-    self.assertEqual(304, message.repeated_uint64[1])
-    self.assertEqual(305, message.repeated_sint32[1])
-    self.assertEqual(306, message.repeated_sint64[1])
-    self.assertEqual(307, message.repeated_fixed32[1])
-    self.assertEqual(308, message.repeated_fixed64[1])
-    self.assertEqual(309, message.repeated_sfixed32[1])
-    self.assertEqual(310, message.repeated_sfixed64[1])
-    self.assertEqual(311, message.repeated_float[1])
-    self.assertEqual(312, message.repeated_double[1])
-    self.assertEqual(False, message.repeated_bool[1])
-    self.assertEqual('315', message.repeated_string[1])
-    self.assertEqual('316', message.repeated_bytes[1])
-
-    self.assertEqual(317, message.repeatedgroup[1].a)
-    self.assertEqual(318, message.repeated_nested_message[1].bb)
-    self.assertEqual(319, message.repeated_foreign_message[1].c)
-    self.assertEqual(320, message.repeated_import_message[1].d)
-
-    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
-                     message.repeated_nested_enum[1])
-    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
-                     message.repeated_foreign_enum[1])
-    self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
-                     message.repeated_import_enum[1])
-
-    # -----------------------------------------------------------------
-
-    self.assertTrue(message.HasField('default_int32'))
-    self.assertTrue(message.HasField('default_int64'))
-    self.assertTrue(message.HasField('default_uint32'))
-    self.assertTrue(message.HasField('default_uint64'))
-    self.assertTrue(message.HasField('default_sint32'))
-    self.assertTrue(message.HasField('default_sint64'))
-    self.assertTrue(message.HasField('default_fixed32'))
-    self.assertTrue(message.HasField('default_fixed64'))
-    self.assertTrue(message.HasField('default_sfixed32'))
-    self.assertTrue(message.HasField('default_sfixed64'))
-    self.assertTrue(message.HasField('default_float'))
-    self.assertTrue(message.HasField('default_double'))
-    self.assertTrue(message.HasField('default_bool'))
-    self.assertTrue(message.HasField('default_string'))
-    self.assertTrue(message.HasField('default_bytes'))
-
-    self.assertTrue(message.HasField('default_nested_enum'))
-    self.assertTrue(message.HasField('default_foreign_enum'))
-    self.assertTrue(message.HasField('default_import_enum'))
-
-    self.assertEqual(401, message.default_int32)
-    self.assertEqual(402, message.default_int64)
-    self.assertEqual(403, message.default_uint32)
-    self.assertEqual(404, message.default_uint64)
-    self.assertEqual(405, message.default_sint32)
-    self.assertEqual(406, message.default_sint64)
-    self.assertEqual(407, message.default_fixed32)
-    self.assertEqual(408, message.default_fixed64)
-    self.assertEqual(409, message.default_sfixed32)
-    self.assertEqual(410, message.default_sfixed64)
-    self.assertEqual(411, message.default_float)
-    self.assertEqual(412, message.default_double)
-    self.assertEqual(False, message.default_bool)
-    self.assertEqual('415', message.default_string)
-    self.assertEqual('416', message.default_bytes)
-
-    self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum)
-    self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum)
-    self.assertEqual(unittest_import_pb2.IMPORT_FOO,
-                     message.default_import_enum)
+def ExpectAllFieldsSet(test_case, message):
+  """Check all fields for correct values have after Set*Fields() is called."""
+  test_case.assertTrue(message.HasField('optional_int32'))
+  test_case.assertTrue(message.HasField('optional_int64'))
+  test_case.assertTrue(message.HasField('optional_uint32'))
+  test_case.assertTrue(message.HasField('optional_uint64'))
+  test_case.assertTrue(message.HasField('optional_sint32'))
+  test_case.assertTrue(message.HasField('optional_sint64'))
+  test_case.assertTrue(message.HasField('optional_fixed32'))
+  test_case.assertTrue(message.HasField('optional_fixed64'))
+  test_case.assertTrue(message.HasField('optional_sfixed32'))
+  test_case.assertTrue(message.HasField('optional_sfixed64'))
+  test_case.assertTrue(message.HasField('optional_float'))
+  test_case.assertTrue(message.HasField('optional_double'))
+  test_case.assertTrue(message.HasField('optional_bool'))
+  test_case.assertTrue(message.HasField('optional_string'))
+  test_case.assertTrue(message.HasField('optional_bytes'))
+
+  test_case.assertTrue(message.HasField('optionalgroup'))
+  test_case.assertTrue(message.HasField('optional_nested_message'))
+  test_case.assertTrue(message.HasField('optional_foreign_message'))
+  test_case.assertTrue(message.HasField('optional_import_message'))
+
+  test_case.assertTrue(message.optionalgroup.HasField('a'))
+  test_case.assertTrue(message.optional_nested_message.HasField('bb'))
+  test_case.assertTrue(message.optional_foreign_message.HasField('c'))
+  test_case.assertTrue(message.optional_import_message.HasField('d'))
+
+  test_case.assertTrue(message.HasField('optional_nested_enum'))
+  test_case.assertTrue(message.HasField('optional_foreign_enum'))
+  test_case.assertTrue(message.HasField('optional_import_enum'))
+
+  test_case.assertTrue(message.HasField('optional_string_piece'))
+  test_case.assertTrue(message.HasField('optional_cord'))
+
+  test_case.assertEqual(101, message.optional_int32)
+  test_case.assertEqual(102, message.optional_int64)
+  test_case.assertEqual(103, message.optional_uint32)
+  test_case.assertEqual(104, message.optional_uint64)
+  test_case.assertEqual(105, message.optional_sint32)
+  test_case.assertEqual(106, message.optional_sint64)
+  test_case.assertEqual(107, message.optional_fixed32)
+  test_case.assertEqual(108, message.optional_fixed64)
+  test_case.assertEqual(109, message.optional_sfixed32)
+  test_case.assertEqual(110, message.optional_sfixed64)
+  test_case.assertEqual(111, message.optional_float)
+  test_case.assertEqual(112, message.optional_double)
+  test_case.assertEqual(True, message.optional_bool)
+  test_case.assertEqual('115', message.optional_string)
+  test_case.assertEqual('116', message.optional_bytes)
+
+  test_case.assertEqual(117, message.optionalgroup.a)
+  test_case.assertEqual(118, message.optional_nested_message.bb)
+  test_case.assertEqual(119, message.optional_foreign_message.c)
+  test_case.assertEqual(120, message.optional_import_message.d)
+
+  test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+                        message.optional_nested_enum)
+  test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
+                        message.optional_foreign_enum)
+  test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+                        message.optional_import_enum)
+
+  # -----------------------------------------------------------------
+
+  test_case.assertEqual(2, len(message.repeated_int32))
+  test_case.assertEqual(2, len(message.repeated_int64))
+  test_case.assertEqual(2, len(message.repeated_uint32))
+  test_case.assertEqual(2, len(message.repeated_uint64))
+  test_case.assertEqual(2, len(message.repeated_sint32))
+  test_case.assertEqual(2, len(message.repeated_sint64))
+  test_case.assertEqual(2, len(message.repeated_fixed32))
+  test_case.assertEqual(2, len(message.repeated_fixed64))
+  test_case.assertEqual(2, len(message.repeated_sfixed32))
+  test_case.assertEqual(2, len(message.repeated_sfixed64))
+  test_case.assertEqual(2, len(message.repeated_float))
+  test_case.assertEqual(2, len(message.repeated_double))
+  test_case.assertEqual(2, len(message.repeated_bool))
+  test_case.assertEqual(2, len(message.repeated_string))
+  test_case.assertEqual(2, len(message.repeated_bytes))
+
+  test_case.assertEqual(2, len(message.repeatedgroup))
+  test_case.assertEqual(2, len(message.repeated_nested_message))
+  test_case.assertEqual(2, len(message.repeated_foreign_message))
+  test_case.assertEqual(2, len(message.repeated_import_message))
+  test_case.assertEqual(2, len(message.repeated_nested_enum))
+  test_case.assertEqual(2, len(message.repeated_foreign_enum))
+  test_case.assertEqual(2, len(message.repeated_import_enum))
+
+  test_case.assertEqual(2, len(message.repeated_string_piece))
+  test_case.assertEqual(2, len(message.repeated_cord))
+
+  test_case.assertEqual(201, message.repeated_int32[0])
+  test_case.assertEqual(202, message.repeated_int64[0])
+  test_case.assertEqual(203, message.repeated_uint32[0])
+  test_case.assertEqual(204, message.repeated_uint64[0])
+  test_case.assertEqual(205, message.repeated_sint32[0])
+  test_case.assertEqual(206, message.repeated_sint64[0])
+  test_case.assertEqual(207, message.repeated_fixed32[0])
+  test_case.assertEqual(208, message.repeated_fixed64[0])
+  test_case.assertEqual(209, message.repeated_sfixed32[0])
+  test_case.assertEqual(210, message.repeated_sfixed64[0])
+  test_case.assertEqual(211, message.repeated_float[0])
+  test_case.assertEqual(212, message.repeated_double[0])
+  test_case.assertEqual(True, message.repeated_bool[0])
+  test_case.assertEqual('215', message.repeated_string[0])
+  test_case.assertEqual('216', message.repeated_bytes[0])
+
+  test_case.assertEqual(217, message.repeatedgroup[0].a)
+  test_case.assertEqual(218, message.repeated_nested_message[0].bb)
+  test_case.assertEqual(219, message.repeated_foreign_message[0].c)
+  test_case.assertEqual(220, message.repeated_import_message[0].d)
+
+  test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
+                        message.repeated_nested_enum[0])
+  test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
+                        message.repeated_foreign_enum[0])
+  test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
+                        message.repeated_import_enum[0])
+
+  test_case.assertEqual(301, message.repeated_int32[1])
+  test_case.assertEqual(302, message.repeated_int64[1])
+  test_case.assertEqual(303, message.repeated_uint32[1])
+  test_case.assertEqual(304, message.repeated_uint64[1])
+  test_case.assertEqual(305, message.repeated_sint32[1])
+  test_case.assertEqual(306, message.repeated_sint64[1])
+  test_case.assertEqual(307, message.repeated_fixed32[1])
+  test_case.assertEqual(308, message.repeated_fixed64[1])
+  test_case.assertEqual(309, message.repeated_sfixed32[1])
+  test_case.assertEqual(310, message.repeated_sfixed64[1])
+  test_case.assertEqual(311, message.repeated_float[1])
+  test_case.assertEqual(312, message.repeated_double[1])
+  test_case.assertEqual(False, message.repeated_bool[1])
+  test_case.assertEqual('315', message.repeated_string[1])
+  test_case.assertEqual('316', message.repeated_bytes[1])
+
+  test_case.assertEqual(317, message.repeatedgroup[1].a)
+  test_case.assertEqual(318, message.repeated_nested_message[1].bb)
+  test_case.assertEqual(319, message.repeated_foreign_message[1].c)
+  test_case.assertEqual(320, message.repeated_import_message[1].d)
+
+  test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+                        message.repeated_nested_enum[1])
+  test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
+                        message.repeated_foreign_enum[1])
+  test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+                        message.repeated_import_enum[1])
+
+  # -----------------------------------------------------------------
+
+  test_case.assertTrue(message.HasField('default_int32'))
+  test_case.assertTrue(message.HasField('default_int64'))
+  test_case.assertTrue(message.HasField('default_uint32'))
+  test_case.assertTrue(message.HasField('default_uint64'))
+  test_case.assertTrue(message.HasField('default_sint32'))
+  test_case.assertTrue(message.HasField('default_sint64'))
+  test_case.assertTrue(message.HasField('default_fixed32'))
+  test_case.assertTrue(message.HasField('default_fixed64'))
+  test_case.assertTrue(message.HasField('default_sfixed32'))
+  test_case.assertTrue(message.HasField('default_sfixed64'))
+  test_case.assertTrue(message.HasField('default_float'))
+  test_case.assertTrue(message.HasField('default_double'))
+  test_case.assertTrue(message.HasField('default_bool'))
+  test_case.assertTrue(message.HasField('default_string'))
+  test_case.assertTrue(message.HasField('default_bytes'))
+
+  test_case.assertTrue(message.HasField('default_nested_enum'))
+  test_case.assertTrue(message.HasField('default_foreign_enum'))
+  test_case.assertTrue(message.HasField('default_import_enum'))
+
+  test_case.assertEqual(401, message.default_int32)
+  test_case.assertEqual(402, message.default_int64)
+  test_case.assertEqual(403, message.default_uint32)
+  test_case.assertEqual(404, message.default_uint64)
+  test_case.assertEqual(405, message.default_sint32)
+  test_case.assertEqual(406, message.default_sint64)
+  test_case.assertEqual(407, message.default_fixed32)
+  test_case.assertEqual(408, message.default_fixed64)
+  test_case.assertEqual(409, message.default_sfixed32)
+  test_case.assertEqual(410, message.default_sfixed64)
+  test_case.assertEqual(411, message.default_float)
+  test_case.assertEqual(412, message.default_double)
+  test_case.assertEqual(False, message.default_bool)
+  test_case.assertEqual('415', message.default_string)
+  test_case.assertEqual('416', message.default_bytes)
+
+  test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
+                        message.default_nested_enum)
+  test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
+                        message.default_foreign_enum)
+  test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
+                        message.default_import_enum)
 
 def GoldenFile(filename):
   """Finds the given golden file and returns a file object representing it."""
@@ -570,21 +569,21 @@ def SetAllPackedFields(message):
   Args:
     message: A unittest_pb2.TestPackedTypes instance.
   """
-  message.packed_int32.extend([101, 102])
-  message.packed_int64.extend([103, 104])
-  message.packed_uint32.extend([105, 106])
-  message.packed_uint64.extend([107, 108])
-  message.packed_sint32.extend([109, 110])
-  message.packed_sint64.extend([111, 112])
-  message.packed_fixed32.extend([113, 114])
-  message.packed_fixed64.extend([115, 116])
-  message.packed_sfixed32.extend([117, 118])
-  message.packed_sfixed64.extend([119, 120])
-  message.packed_float.extend([121.0, 122.0])
-  message.packed_double.extend([122.0, 123.0])
+  message.packed_int32.extend([601, 701])
+  message.packed_int64.extend([602, 702])
+  message.packed_uint32.extend([603, 703])
+  message.packed_uint64.extend([604, 704])
+  message.packed_sint32.extend([605, 705])
+  message.packed_sint64.extend([606, 706])
+  message.packed_fixed32.extend([607, 707])
+  message.packed_fixed64.extend([608, 708])
+  message.packed_sfixed32.extend([609, 709])
+  message.packed_sfixed64.extend([610, 710])
+  message.packed_float.extend([611.0, 711.0])
+  message.packed_double.extend([612.0, 712.0])
   message.packed_bool.extend([True, False])
-  message.packed_enum.extend([unittest_pb2.FOREIGN_FOO,
-                              unittest_pb2.FOREIGN_BAR])
+  message.packed_enum.extend([unittest_pb2.FOREIGN_BAR,
+                              unittest_pb2.FOREIGN_BAZ])
 
 
 def SetAllPackedExtensions(message):
@@ -596,17 +595,41 @@ def SetAllPackedExtensions(message):
   extensions = message.Extensions
   pb2 = unittest_pb2
 
-  extensions[pb2.packed_int32_extension].append(101)
-  extensions[pb2.packed_int64_extension].append(102)
-  extensions[pb2.packed_uint32_extension].append(103)
-  extensions[pb2.packed_uint64_extension].append(104)
-  extensions[pb2.packed_sint32_extension].append(105)
-  extensions[pb2.packed_sint64_extension].append(106)
-  extensions[pb2.packed_fixed32_extension].append(107)
-  extensions[pb2.packed_fixed64_extension].append(108)
-  extensions[pb2.packed_sfixed32_extension].append(109)
-  extensions[pb2.packed_sfixed64_extension].append(110)
-  extensions[pb2.packed_float_extension].append(111.0)
-  extensions[pb2.packed_double_extension].append(112.0)
-  extensions[pb2.packed_bool_extension].append(True)
-  extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ)
+  extensions[pb2.packed_int32_extension].extend([601, 701])
+  extensions[pb2.packed_int64_extension].extend([602, 702])
+  extensions[pb2.packed_uint32_extension].extend([603, 703])
+  extensions[pb2.packed_uint64_extension].extend([604, 704])
+  extensions[pb2.packed_sint32_extension].extend([605, 705])
+  extensions[pb2.packed_sint64_extension].extend([606, 706])
+  extensions[pb2.packed_fixed32_extension].extend([607, 707])
+  extensions[pb2.packed_fixed64_extension].extend([608, 708])
+  extensions[pb2.packed_sfixed32_extension].extend([609, 709])
+  extensions[pb2.packed_sfixed64_extension].extend([610, 710])
+  extensions[pb2.packed_float_extension].extend([611.0, 711.0])
+  extensions[pb2.packed_double_extension].extend([612.0, 712.0])
+  extensions[pb2.packed_bool_extension].extend([True, False])
+  extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR,
+                                                unittest_pb2.FOREIGN_BAZ])
+
+
+def SetAllUnpackedFields(message):
+  """Sets every field in the message to a unique value.
+
+  Args:
+    message: A unittest_pb2.TestUnpackedTypes instance.
+  """
+  message.unpacked_int32.extend([601, 701])
+  message.unpacked_int64.extend([602, 702])
+  message.unpacked_uint32.extend([603, 703])
+  message.unpacked_uint64.extend([604, 704])
+  message.unpacked_sint32.extend([605, 705])
+  message.unpacked_sint64.extend([606, 706])
+  message.unpacked_fixed32.extend([607, 707])
+  message.unpacked_fixed64.extend([608, 708])
+  message.unpacked_sfixed32.extend([609, 709])
+  message.unpacked_sfixed64.extend([610, 710])
+  message.unpacked_float.extend([611.0, 711.0])
+  message.unpacked_double.extend([612.0, 712.0])
+  message.unpacked_bool.extend([True, False])
+  message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
+                                unittest_pb2.FOREIGN_BAZ])

+ 22 - 3
python/google/protobuf/internal/text_format_test.py

@@ -43,7 +43,7 @@ from google.protobuf import unittest_pb2
 from google.protobuf import unittest_mset_pb2
 
 
-class TextFormatTest(test_util.GoldenMessageTestCase):
+class TextFormatTest(unittest.TestCase):
   def ReadGolden(self, golden_filename):
     f = test_util.GoldenFile(golden_filename)
     golden_lines = f.readlines()
@@ -149,7 +149,7 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
     parsed_message = unittest_pb2.TestAllTypes()
     text_format.Merge(ascii_text, parsed_message)
     self.assertEqual(message, parsed_message)
-    self.ExpectAllFieldsSet(message)
+    test_util.ExpectAllFieldsSet(self, message)
 
   def testMergeAllExtensions(self):
     message = unittest_pb2.TestAllExtensions()
@@ -212,12 +212,18 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
         text_format.Merge, text, message)
 
   def testMergeBadExtension(self):
-    message = unittest_pb2.TestAllTypes()
+    message = unittest_pb2.TestAllExtensions()
     text = '[unknown_extension]: 8\n'
     self.assertRaisesWithMessage(
         text_format.ParseError,
         '1:2 : Extension "unknown_extension" not registered.',
         text_format.Merge, text, message)
+    message = unittest_pb2.TestAllTypes()
+    self.assertRaisesWithMessage(
+        text_format.ParseError,
+        ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+         'extensions.'),
+        text_format.Merge, text, message)
 
   def testMergeGroupNotClosed(self):
     message = unittest_pb2.TestAllTypes()
@@ -231,6 +237,19 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
         text_format.ParseError, '1:16 : Expected "}".',
         text_format.Merge, text, message)
 
+  def testMergeEmptyGroup(self):
+    message = unittest_pb2.TestAllTypes()
+    text = 'OptionalGroup: {}'
+    text_format.Merge(text, message)
+    self.assertTrue(message.HasField('optionalgroup'))
+
+    message.Clear()
+
+    message = unittest_pb2.TestAllTypes()
+    text = 'OptionalGroup: <>'
+    text_format.Merge(text, message)
+    self.assertTrue(message.HasField('optionalgroup'))
+
   def testMergeBadEnumValue(self):
     message = unittest_pb2.TestAllTypes()
     text = 'optional_nested_enum: BARR'

+ 63 - 64
python/google/protobuf/internal/type_checkers.py

@@ -192,47 +192,72 @@ TYPE_TO_BYTE_SIZE_FN = {
     }
 
 
-# Maps from field type to an unbound Encoder method F, such that
-# F(encoder, field_number, value) will append the serialization
-# of a value of this type to the encoder.
-_Encoder = encoder.Encoder
-TYPE_TO_SERIALIZE_METHOD = {
-    _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble,
-    _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat,
-    _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64,
-    _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64,
-    _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32,
-    _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64,
-    _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32,
-    _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool,
-    _FieldDescriptor.TYPE_STRING: _Encoder.AppendString,
-    _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup,
-    _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage,
-    _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes,
-    _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32,
-    _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum,
-    _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32,
-    _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64,
-    _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32,
-    _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64,
+# Maps from field types to encoder constructors.
+TYPE_TO_ENCODER = {
+    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
+    _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
+    _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
+    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
+    _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
+    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
+    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
+    _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
+    _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
+    _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
+    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
+    _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
+    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
+    _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
+    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
+    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
+    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
+    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
     }
 
 
-TYPE_TO_NOTAG_SERIALIZE_METHOD = {
-    _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag,
-    _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag,
-    _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag,
-    _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag,
-    _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag,
-    _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag,
-    _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag,
-    _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag,
-    _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag,
-    _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag,
-    _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag,
-    _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag,
-    _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag,
-    _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag,
+# Maps from field types to sizer constructors.
+TYPE_TO_SIZER = {
+    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
+    _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
+    _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
+    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
+    _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
+    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
+    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
+    _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
+    _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
+    _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
+    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
+    _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
+    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
+    _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
+    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
+    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
+    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
+    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
+    }
+
+
+# Maps from field type to a decoder constructor.
+TYPE_TO_DECODER = {
+    _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
+    _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
+    _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
+    _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
+    _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
+    _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
+    _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
+    _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
+    _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
+    _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
+    _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
+    _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
+    _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
+    _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
+    _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
+    _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
+    _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
+    _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
     }
 
 # Maps from field type to expected wiretype.
@@ -259,29 +284,3 @@ FIELD_TYPE_TO_WIRE_TYPE = {
     _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
     _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
     }
-
-
-# Maps from field type to an unbound Decoder method F,
-# such that F(decoder) will read a field of the requested type.
-#
-# Note that Message and Group are intentionally missing here.
-# They're handled by _RecursivelyMerge().
-_Decoder = decoder.Decoder
-TYPE_TO_DESERIALIZE_METHOD = {
-    _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble,
-    _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat,
-    _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64,
-    _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64,
-    _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32,
-    _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64,
-    _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32,
-    _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool,
-    _FieldDescriptor.TYPE_STRING: _Decoder.ReadString,
-    _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes,
-    _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32,
-    _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum,
-    _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32,
-    _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64,
-    _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32,
-    _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64,
-    }

+ 24 - 3
python/google/protobuf/internal/wire_format.py

@@ -33,16 +33,17 @@
 __author__ = 'robinson@google.com (Will Robinson)'
 
 import struct
+from google.protobuf import descriptor
 from google.protobuf import message
 
 
 TAG_TYPE_BITS = 3  # Number of bits used to hold type info in a proto tag.
-_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1  # 0x7
+TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1  # 0x7
 
 # These numbers identify the wire type of a protocol buffer value.
 # We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
 # tag-and-type to store one of these WIRETYPE_* constants.
-# These values must match WireType enum in //net/proto2/public/wire_format.h.
+# These values must match WireType enum in google/protobuf/wire_format.h.
 WIRETYPE_VARINT = 0
 WIRETYPE_FIXED64 = 1
 WIRETYPE_LENGTH_DELIMITED = 2
@@ -93,7 +94,7 @@ def UnpackTag(tag):
   """The inverse of PackTag().  Given an unsigned 32-bit number,
   returns a (field_number, wire_type) tuple.
   """
-  return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK)
+  return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
 
 
 def ZigZagEncode(value):
@@ -245,3 +246,23 @@ def _VarUInt64ByteSizeNoTag(uint64):
   if uint64 > UINT64_MAX:
     raise message.EncodeError('Value out of range: %d' % uint64)
   return 10
+
+
+NON_PACKABLE_TYPES = (
+  descriptor.FieldDescriptor.TYPE_STRING,
+  descriptor.FieldDescriptor.TYPE_GROUP,
+  descriptor.FieldDescriptor.TYPE_MESSAGE,
+  descriptor.FieldDescriptor.TYPE_BYTES
+)
+
+
+def IsTypePackable(field_type):
+  """Return true iff packed = true is valid for fields of this type.
+
+  Args:
+    field_type: a FieldDescriptor::Type value.
+
+  Returns:
+    True iff fields of this type are packable.
+  """
+  return field_type not in NON_PACKABLE_TYPES

+ 10 - 1
python/google/protobuf/message.py

@@ -99,7 +99,7 @@ class Message(object):
     Args:
       other_msg: Message to copy into the current one.
     """
-    if self == other_msg:
+    if self is other_msg:
       return
     self.Clear()
     self.MergeFrom(other_msg)
@@ -108,6 +108,15 @@ class Message(object):
     """Clears all data that was set in the message."""
     raise NotImplementedError
 
+  def SetInParent(self):
+    """Mark this as present in the parent.
+
+    This normally happens automatically when you assign a field of a
+    sub-message, but sometimes you want to make the sub-message
+    present while keeping it empty.  If you find yourself using this,
+    you may want to reconsider your design."""
+    raise NotImplementedError
+
   def IsInitialized(self):
     """Checks if the message is initialized.
 

文件差異過大導致無法顯示
+ 400 - 733
python/google/protobuf/reflection.py


+ 5 - 0
python/google/protobuf/text_format.py

@@ -149,6 +149,10 @@ def _MergeField(tokenizer, message):
       name.append(tokenizer.ConsumeIdentifier())
     name = '.'.join(name)
 
+    if not message_descriptor.is_extendable:
+      raise tokenizer.ParseErrorPreviousToken(
+          'Message type "%s" does not have extensions.' %
+          message_descriptor.full_name)
     field = message.Extensions._FindExtensionByName(name)
     if not field:
       raise tokenizer.ParseErrorPreviousToken(
@@ -198,6 +202,7 @@ def _MergeField(tokenizer, message):
         sub_message = message.Extensions[field]
       else:
         sub_message = getattr(message, field.name)
+        sub_message.SetInParent()
 
     while not tokenizer.TryConsume(end_token):
       if tokenizer.AtEnd():

+ 1 - 10
python/setup.py

@@ -58,16 +58,13 @@ def MakeTestSuite():
   generate_proto("../src/google/protobuf/unittest.proto")
   generate_proto("../src/google/protobuf/unittest_import.proto")
   generate_proto("../src/google/protobuf/unittest_mset.proto")
+  generate_proto("../src/google/protobuf/unittest_no_generic_services.proto")
   generate_proto("google/protobuf/internal/more_extensions.proto")
   generate_proto("google/protobuf/internal/more_messages.proto")
 
   import unittest
   import google.protobuf.internal.generator_test     as generator_test
-  import google.protobuf.internal.decoder_test       as decoder_test
   import google.protobuf.internal.descriptor_test    as descriptor_test
-  import google.protobuf.internal.encoder_test       as encoder_test
-  import google.protobuf.internal.input_stream_test  as input_stream_test
-  import google.protobuf.internal.output_stream_test as output_stream_test
   import google.protobuf.internal.reflection_test    as reflection_test
   import google.protobuf.internal.service_reflection_test \
     as service_reflection_test
@@ -77,11 +74,7 @@ def MakeTestSuite():
   loader = unittest.defaultTestLoader
   suite = unittest.TestSuite()
   for test in [ generator_test,
-                decoder_test,
                 descriptor_test,
-                encoder_test,
-                input_stream_test,
-                output_stream_test,
                 reflection_test,
                 service_reflection_test,
                 text_format_test,
@@ -114,9 +107,7 @@ if __name__ == '__main__':
           'google.protobuf.internal.containers',
           'google.protobuf.internal.decoder',
           'google.protobuf.internal.encoder',
-          'google.protobuf.internal.input_stream',
           'google.protobuf.internal.message_listener',
-          'google.protobuf.internal.output_stream',
           'google.protobuf.internal.type_checkers',
           'google.protobuf.internal.wire_format',
           'google.protobuf.descriptor',

+ 27 - 2
src/Makefile.am

@@ -24,7 +24,8 @@ AM_LDFLAGS = $(PTHREAD_CFLAGS)
 # If I say "dist_include_DATA", automake complains that $(includedir) is not
 # a "legitimate" directory for DATA.  Screw you, automake.
 protodir = $(includedir)
-nobase_dist_proto_DATA = google/protobuf/descriptor.proto
+nobase_dist_proto_DATA = google/protobuf/descriptor.proto \
+                         google/protobuf/compiler/plugin.proto
 
 # Not sure why these don't get cleaned automatically.
 clean-local:
@@ -66,6 +67,8 @@ nobase_include_HEADERS =                                       \
   google/protobuf/compiler/command_line_interface.h            \
   google/protobuf/compiler/importer.h                          \
   google/protobuf/compiler/parser.h                            \
+  google/protobuf/compiler/plugin.h                            \
+  google/protobuf/compiler/plugin.pb.h                         \
   google/protobuf/compiler/cpp/cpp_generator.h                 \
   google/protobuf/compiler/java/java_generator.h               \
   google/protobuf/compiler/python/python_generator.h
@@ -87,6 +90,7 @@ libprotobuf_lite_la_SOURCES =                                  \
   google/protobuf/repeated_field.cc                            \
   google/protobuf/wire_format_lite.cc                          \
   google/protobuf/io/coded_stream.cc                           \
+  google/protobuf/io/coded_stream_inl.h                        \
   google/protobuf/io/zero_copy_stream.cc                       \
   google/protobuf/io/zero_copy_stream_impl_lite.cc
 
@@ -123,6 +127,10 @@ libprotoc_la_LDFLAGS = -version-info 5:0:0
 libprotoc_la_SOURCES =                                         \
   google/protobuf/compiler/code_generator.cc                   \
   google/protobuf/compiler/command_line_interface.cc           \
+  google/protobuf/compiler/plugin.cc                           \
+  google/protobuf/compiler/plugin.pb.cc                        \
+  google/protobuf/compiler/subprocess.cc                       \
+  google/protobuf/compiler/subprocess.h                        \
   google/protobuf/compiler/cpp/cpp_enum.cc                     \
   google/protobuf/compiler/cpp/cpp_enum.h                      \
   google/protobuf/compiler/cpp/cpp_enum_field.cc               \
@@ -186,6 +194,7 @@ protoc_inputs =                                                \
   google/protobuf/unittest_lite.proto                          \
   google/protobuf/unittest_import_lite.proto                   \
   google/protobuf/unittest_lite_imports_nonlite.proto          \
+  google/protobuf/unittest_no_generic_services.proto           \
   google/protobuf/compiler/cpp/cpp_test_bad_identifiers.proto
 
 EXTRA_DIST =                                                   \
@@ -226,6 +235,8 @@ protoc_outputs =                                               \
   google/protobuf/unittest_custom_options.pb.h                 \
   google/protobuf/unittest_lite_imports_nonlite.pb.cc          \
   google/protobuf/unittest_lite_imports_nonlite.pb.h           \
+  google/protobuf/unittest_no_generic_services.pb.cc           \
+  google/protobuf/unittest_no_generic_services.pb.h            \
   google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.cc  \
   google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h
 
@@ -265,7 +276,7 @@ COMMON_TEST_SOURCES =                                          \
   google/protobuf/testing/file.cc                              \
   google/protobuf/testing/file.h
 
-check_PROGRAMS = protobuf-test protobuf-lazy-descriptor-test protobuf-lite-test $(GZCHECKPROGRAMS)
+check_PROGRAMS = protobuf-test protobuf-lazy-descriptor-test protobuf-lite-test test_plugin $(GZCHECKPROGRAMS)
 protobuf_test_LDADD = $(PTHREAD_LIBS) libprotobuf.la libprotoc.la \
                       $(top_builddir)/gtest/lib/libgtest.la       \
                       $(top_builddir)/gtest/lib/libgtest_main.la
@@ -297,9 +308,14 @@ protobuf_test_SOURCES =                                        \
   google/protobuf/io/zero_copy_stream_unittest.cc              \
   google/protobuf/compiler/command_line_interface_unittest.cc  \
   google/protobuf/compiler/importer_unittest.cc                \
+  google/protobuf/compiler/mock_code_generator.cc              \
+  google/protobuf/compiler/mock_code_generator.h               \
   google/protobuf/compiler/parser_unittest.cc                  \
   google/protobuf/compiler/cpp/cpp_bootstrap_unittest.cc       \
   google/protobuf/compiler/cpp/cpp_unittest.cc                 \
+  google/protobuf/compiler/cpp/cpp_plugin_unittest.cc          \
+  google/protobuf/compiler/java/java_plugin_unittest.cc        \
+  google/protobuf/compiler/python/python_plugin_unittest.cc    \
   $(COMMON_TEST_SOURCES)
 nodist_protobuf_test_SOURCES = $(protoc_outputs)
 
@@ -325,6 +341,15 @@ protobuf_lite_test_SOURCES =                                           \
   google/protobuf/test_util_lite.h
 nodist_protobuf_lite_test_SOURCES = $(protoc_lite_outputs)
 
+# Test plugin binary.
+test_plugin_LDADD = $(PTHREAD_LIBS) libprotobuf.la libprotoc.la \
+                    $(top_builddir)/gtest/lib/libgtest.la
+test_plugin_SOURCES =                                          \
+  google/protobuf/compiler/mock_code_generator.cc              \
+  google/protobuf/testing/file.cc                              \
+  google/protobuf/testing/file.h                               \
+  google/protobuf/compiler/test_plugin.cc
+
 if HAVE_ZLIB
 zcgzip_LDADD = $(PTHREAD_LIBS) libprotobuf.la
 zcgzip_SOURCES = google/protobuf/testing/zcgzip.cc

+ 8 - 1
src/google/protobuf/compiler/code_generator.cc

@@ -34,6 +34,7 @@
 
 #include <google/protobuf/compiler/code_generator.h>
 
+#include <google/protobuf/stubs/common.h>
 #include <google/protobuf/stubs/strutil.h>
 
 namespace google {
@@ -43,9 +44,15 @@ namespace compiler {
 CodeGenerator::~CodeGenerator() {}
 OutputDirectory::~OutputDirectory() {}
 
+// Default implementation used by OutputDirectory subclasses that do not
+// override OpenForInsert(): logs a FATAL message saying insertion is not
+// supported.  Neither argument is examined.
+io::ZeroCopyOutputStream* OutputDirectory::OpenForInsert(
+    const string& filename, const string& insertion_point) {
+  GOOGLE_LOG(FATAL) << "This OutputDirectory does not support insertion.";
+  return NULL;  // Not expected to be reached; satisfies the return type.
+}
+
 // Parses a set of comma-delimited name/value pairs.
 void ParseGeneratorParameter(const string& text,
-			     vector<pair<string, string> >* output) {
+                             vector<pair<string, string> >* output) {
   vector<string> parts;
   SplitStringUsing(text, ",", &parts);
 

+ 8 - 1
src/google/protobuf/compiler/code_generator.h

@@ -103,6 +103,13 @@ class LIBPROTOC_EXPORT OutputDirectory {
   // contain "." or ".." components.
   virtual io::ZeroCopyOutputStream* Open(const string& filename) = 0;
 
+  // Creates a ZeroCopyOutputStream which will insert code into the given file
+  // at the given insertion point.  See plugin.proto for more information on
+  // insertion points.  The default implementation assert-fails -- it exists
+  // only for backwards-compatibility.
+  virtual io::ZeroCopyOutputStream* OpenForInsert(
+      const string& filename, const string& insertion_point);
+
  private:
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OutputDirectory);
 };
@@ -114,7 +121,7 @@ class LIBPROTOC_EXPORT OutputDirectory {
 // parses to the pairs:
 //   ("foo", "bar"), ("baz", ""), ("qux", "corge")
 extern void ParseGeneratorParameter(const string&,
-				    vector<pair<string, string> >*);
+            vector<pair<string, string> >*);
 
 }  // namespace compiler
 }  // namespace protobuf

+ 432 - 36
src/google/protobuf/compiler/command_line_interface.cc

@@ -32,6 +32,8 @@
 //  Based on original Protocol Buffers design by
 //  Sanjay Ghemawat, Jeff Dean, and others.
 
+#include <google/protobuf/compiler/command_line_interface.h>
+
 #include <stdio.h>
 #include <sys/types.h>
 #include <sys/stat.h>
@@ -46,15 +48,19 @@
 #include <iostream>
 #include <ctype.h>
 
-#include <google/protobuf/compiler/command_line_interface.h>
 #include <google/protobuf/compiler/importer.h>
 #include <google/protobuf/compiler/code_generator.h>
+#include <google/protobuf/compiler/plugin.pb.h>
+#include <google/protobuf/compiler/subprocess.h>
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/text_format.h>
 #include <google/protobuf/dynamic_message.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/io/printer.h>
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/stubs/strutil.h>
+#include <google/protobuf/stubs/substitute.h>
+#include <google/protobuf/stubs/map-util.h>
 
 
 namespace google {
@@ -182,6 +188,8 @@ class CommandLineInterface::DiskOutputDirectory : public OutputDirectory {
 
   // implements OutputDirectory --------------------------------------
   io::ZeroCopyOutputStream* Open(const string& filename);
+  io::ZeroCopyOutputStream* OpenForInsert(
+      const string& filename, const string& insertion_point);
 
  private:
   string root_;
@@ -209,11 +217,45 @@ class CommandLineInterface::ErrorReportingFileOutput
 
  private:
   scoped_ptr<io::FileOutputStream> file_stream_;
-  int file_descriptor_;
   string filename_;
   DiskOutputDirectory* directory_;
 };
 
+// Kind of like ErrorReportingFileOutput, but used when inserting
+// (OutputDirectory::OpenForInsert()).  In this case, we are writing to a
+// temporary file, since we must copy data from the original.  We copy the
+// data up to the insertion point in the constructor, and the remainder in the
+// destructor.  We then replace the original file with the temporary, also in
+// the destructor.  The net effect is that everything written through this
+// stream ends up immediately before the line containing the insertion point.
+class CommandLineInterface::InsertionOutputStream
+    : public io::ZeroCopyOutputStream {
+ public:
+  // filename is the file being inserted into and temp_filename is the scratch
+  // file that replaces it on success.  Failures are printed to stderr and
+  // recorded on directory via set_had_error().
+  InsertionOutputStream(
+      const string& filename,
+      const string& temp_filename,
+      const string& insertion_point,
+      int original_file_descriptor,            // Takes ownership.
+      int temp_file_descriptor,                // Takes ownership.
+      DiskOutputDirectory* directory);         // Does not take ownership.
+  ~InsertionOutputStream();
+
+  // implements ZeroCopyOutputStream ---------------------------------
+  // All three calls are forwarded unchanged to the temp file's stream.
+  bool Next(void** data, int* size) { return temp_file_->Next(data, size); }
+  void BackUp(int count)            {        temp_file_->BackUp(count);    }
+  int64 ByteCount() const           { return temp_file_->ByteCount();      }
+
+ private:
+  // Input side: the file being inserted into.  The constructor resets this to
+  // NULL when it fails, which is how the destructor detects that failure.
+  scoped_ptr<io::FileInputStream> original_file_;
+  // Output side: the temporary file that receives the merged contents.
+  scoped_ptr<io::FileOutputStream> temp_file_;
+
+  string filename_;                 // Path of the original file.
+  string temp_filename_;            // Path of the temporary file.
+  DiskOutputDirectory* directory_;  // Not owned; receives the error flag.
+
+  // The contents of the line containing the insertion point.  Held back by
+  // the constructor and written out by the destructor.
+  string magic_line_;
+};
+
+
 // -------------------------------------------------------------------
 
 CommandLineInterface::DiskOutputDirectory::DiskOutputDirectory(
@@ -242,6 +284,8 @@ bool CommandLineInterface::DiskOutputDirectory::VerifyExistence() {
   return true;
 }
 
+// -------------------------------------------------------------------
+
 io::ZeroCopyOutputStream* CommandLineInterface::DiskOutputDirectory::Open(
     const string& filename) {
   // Recursively create parent directories to the output file.
@@ -286,7 +330,6 @@ CommandLineInterface::ErrorReportingFileOutput::ErrorReportingFileOutput(
     const string& filename,
     DiskOutputDirectory* directory)
   : file_stream_(new io::FileOutputStream(file_descriptor)),
-    file_descriptor_(file_descriptor),
     filename_(filename),
     directory_(directory) {}
 
@@ -304,6 +347,201 @@ CommandLineInterface::ErrorReportingFileOutput::~ErrorReportingFileOutput() {
   }
 }
 
+// -------------------------------------------------------------------
+
+// Opens the existing file root_ + filename so that content can be inserted at
+// insertion_point.  The original is opened read-only and a sibling temp file
+// is created to receive the merged output; both descriptors are handed to a
+// new InsertionOutputStream, which is returned.  If either file cannot be
+// opened, the error is printed to stderr, had_error_ is set, and a dummy
+// zero-length stream is returned instead.
+io::ZeroCopyOutputStream*
+CommandLineInterface::DiskOutputDirectory::OpenForInsert(
+    const string& filename, const string& insertion_point) {
+  string path = root_ + filename;
+
+  // Put the temp file in the same directory so that we can simply rename() it
+  // into place later.
+  string temp_path = path + ".protoc_temp";
+
+  // Open the original file, retrying while the call fails with EINTR.
+  int original_file;
+  do {
+    original_file = open(path.c_str(), O_RDONLY | O_BINARY);
+  } while (original_file < 0 && errno == EINTR);
+
+  if (original_file < 0) {
+    // Failed to open.
+    cerr << path << ": " << strerror(errno) << endl;
+    had_error_ = true;
+    // Return a dummy stream.
+    return new io::ArrayOutputStream(NULL, 0);
+  }
+
+  // Create the temp file (truncating any existing one), again retrying on
+  // EINTR.
+  int temp_file;
+  do {
+    temp_file =
+      open(temp_path.c_str(),
+           O_WRONLY | O_CREAT | O_TRUNC | O_BINARY, 0666);
+  } while (temp_file < 0 && errno == EINTR);
+
+  if (temp_file < 0) {
+    // Failed to open.
+    cerr << temp_path << ": " << strerror(errno) << endl;
+    had_error_ = true;
+    // The original was opened successfully above, so release it here.
+    close(original_file);
+    // Return a dummy stream.
+    return new io::ArrayOutputStream(NULL, 0);
+  }
+
+  return new InsertionOutputStream(
+      path, temp_path, insertion_point, original_file, temp_file, this);
+}
+
+namespace {
+
+// Helper for reading lines from a ZeroCopyInputStream.
+//
+// The reader pulls buffers from the stream with Next() and hands out one line
+// at a time.  Whatever part of the current buffer has not been handed out when
+// the reader is destroyed is returned to the stream with BackUp(), so the
+// stream can be read further by someone else afterwards.
+// TODO(kenton):  Put somewhere reusable?
+class LineReader {
+ public:
+  // Does not take ownership of input, which must outlive the reader.
+  explicit LineReader(io::ZeroCopyInputStream* input)
+      : input_(input), buffer_(NULL), size_(0) {}
+
+  ~LineReader() {
+    // Give back the bytes we fetched but never returned to the caller.
+    if (size_ > 0) {
+      input_->BackUp(size_);
+    }
+  }
+
+  // Reads the next line into *line, including its terminating '\n' if it has
+  // one.  Returns true if a line was read.  A final line that is not
+  // terminated by '\n' is still returned (with true); the following call then
+  // returns false with *line empty.  Returns false once the stream has nothing
+  // more to give, whether because of end-of-input or a read error.
+  bool ReadLine(string* line) {
+    line->clear();
+
+    while (true) {
+      // Look for a newline in what is left of the current buffer.
+      for (int i = 0; i < size_; i++) {
+        if (buffer_[i] == '\n') {
+          line->append(buffer_, i + 1);
+          buffer_ += i + 1;
+          size_ -= i + 1;
+          return true;
+        }
+      }
+
+      // No newline yet: keep what we have and fetch another buffer.  The
+      // guard avoids handing a NULL pointer to append() before the first
+      // Next() call.
+      if (size_ > 0) {
+        line->append(buffer_, size_);
+      }
+
+      const void* void_buffer;
+      if (!input_->Next(&void_buffer, &size_)) {
+        buffer_ = NULL;
+        size_ = 0;
+        // Do not drop a trailing line that lacks a final newline.
+        return !line->empty();
+      }
+
+      buffer_ = reinterpret_cast<const char*>(void_buffer);
+    }
+  }
+
+ private:
+  io::ZeroCopyInputStream* input_;  // Not owned.
+  const char* buffer_;              // Unconsumed part of the current buffer.
+  int size_;                        // Number of unconsumed bytes at buffer_.
+};
+
+}  // namespace
+
+// Copies the original file into the temp file line by line until the line
+// containing "@@protoc_insertion_point(<insertion_point>)" is found.  That
+// line is not copied; it is saved in magic_line_ so that the destructor can
+// write it after whatever the caller inserts.
+//
+// If the original file ends (or cannot be read) before the insertion point is
+// found, an error is printed and original_file_ is reset to NULL, which tells
+// the destructor that construction failed.
+CommandLineInterface::InsertionOutputStream::InsertionOutputStream(
+    const string& filename,
+    const string& temp_filename,
+    const string& insertion_point,
+    int original_file_descriptor,
+    int temp_file_descriptor,
+    DiskOutputDirectory* directory)
+    : original_file_(new io::FileInputStream(original_file_descriptor)),
+      temp_file_(new io::FileOutputStream(temp_file_descriptor)),
+      filename_(filename),
+      temp_filename_(temp_filename),
+      directory_(directory) {
+  string magic_string = strings::Substitute(
+      "@@protoc_insertion_point($0)", insertion_point);
+
+  LineReader reader(original_file_.get());
+  io::Printer writer(temp_file_.get(), '$');
+  string line;
+
+  while (true) {
+    if (!reader.ReadLine(&line)) {
+      // Reading stopped, so the stream to ask about the cause is the input
+      // file.  Write errors on the temp file are checked in the destructor.
+      int error = original_file_->GetErrno();
+      if (error == 0) {
+        // Clean end of file without ever seeing the magic string.
+        cerr << filename << ": Insertion point not found: "
+             << insertion_point << endl;
+      } else {
+        cerr << filename << ": " << strerror(error) << endl;
+      }
+      original_file_->Close();
+      original_file_.reset();
+      // Will finish handling error in the destructor.
+      break;
+    }
+
+    if (line.find(magic_string) != string::npos) {
+      // Found the magic line.  Since we want to insert before it, save it for
+      // later.
+      magic_line_ = line;
+      break;
+    }
+
+    writer.PrintRaw(line);
+  }
+}
+
+// Finishes the insertion: appends the saved insertion-point line and the rest
+// of the original file to the temp file, closes both, and renames the temp
+// file over the original.  On any failure the temp file is removed and the
+// error is recorded on the owning directory.
+CommandLineInterface::InsertionOutputStream::~InsertionOutputStream() {
+  // Failures are accumulated in this flag so that every cleanup step below
+  // still runs.
+  bool had_error = false;
+
+  if (original_file_ == NULL) {
+    // We had an error in the constructor.
+    had_error = true;
+  } else {
+    // Use CodedOutputStream for convenience, so we don't have to deal with
+    // copying buffers ourselves.
+    io::CodedOutputStream out(temp_file_.get());
+    out.WriteRaw(magic_line_.data(), magic_line_.size());
+
+    // Write the rest of the original file.
+    const void* buffer;
+    int size;
+    while (original_file_->Next(&buffer, &size)) {
+      out.WriteRaw(buffer, size);
+    }
+
+    // Close the original file.
+    if (!original_file_->Close()) {
+      cerr << filename_ << ": " << strerror(original_file_->GetErrno()) << endl;
+      had_error = true;
+    }
+  }
+
+  // Check if we had any errors while writing.
+  if (temp_file_->GetErrno() != 0) {
+    cerr << filename_ << ": " << strerror(temp_file_->GetErrno()) << endl;
+    had_error = true;
+  }
+
+  // Close the temp file.
+  if (!temp_file_->Close()) {
+    cerr << filename_ << ": " << strerror(temp_file_->GetErrno()) << endl;
+    had_error = true;
+  }
+
+  // If everything was successful, overwrite the original file with the temp
+  // file.
+  if (!had_error) {
+    if (rename(temp_filename_.c_str(), filename_.c_str()) < 0) {
+      cerr << filename_ << ": rename: " << strerror(errno) << endl;
+      had_error = true;
+    }
+  }
+
+  if (had_error) {
+    // We had some sort of error so let's try to delete the temp file.
+    remove(temp_filename_.c_str());
+    directory_->set_had_error(true);
+  }
+}
+
 // ===================================================================
 
 CommandLineInterface::CommandLineInterface()
@@ -323,6 +561,10 @@ void CommandLineInterface::RegisterGenerator(const string& flag_name,
   generators_[flag_name] = info;
 }
 
+// Records the executable-name prefix for plugins.  A non-empty prefix is what
+// other code in this file checks to decide whether plugins are enabled.
+void CommandLineInterface::AllowPlugins(const string& exe_name_prefix) {
+  plugin_prefix_ = exe_name_prefix;
+}
+
 int CommandLineInterface::Run(int argc, const char* const argv[]) {
   Clear();
   if (!ParseArguments(argc, argv)) return 1;
@@ -346,7 +588,7 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
 
   vector<const FileDescriptor*> parsed_files;
 
-  // Parse each file and generate output.
+  // Parse each file.
   for (int i = 0; i < input_files_.size(); i++) {
     // Import the file.
     const FileDescriptor* parsed_file = importer.Import(input_files_[i]);
@@ -359,13 +601,13 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
               "--disallow_services was used." << endl;
       return 1;
     }
+  }
 
-    if (mode_ == MODE_COMPILE) {
-      // Generate output files.
-      for (int i = 0; i < output_directives_.size(); i++) {
-        if (!GenerateOutput(parsed_file, output_directives_[i])) {
-          return 1;
-        }
+  // Generate output.
+  if (mode_ == MODE_COMPILE) {
+    for (int i = 0; i < output_directives_.size(); i++) {
+      if (!GenerateOutput(parsed_files, output_directives_[i])) {
+        return 1;
       }
     }
   }
@@ -686,10 +928,37 @@ bool CommandLineInterface::InterpretArgument(const string& name,
       return false;
     }
 
+  } else if (name == "--plugin") {
+    if (plugin_prefix_.empty()) {
+      cerr << "This compiler does not support plugins." << endl;
+      return false;
+    }
+
+    string name;
+    string path;
+
+    string::size_type equals_pos = value.find_first_of('=');
+    if (equals_pos == string::npos) {
+      // Use the basename of the file.
+      string::size_type slash_pos = value.find_last_of('/');
+      if (slash_pos == string::npos) {
+        name = value;
+      } else {
+        name = value.substr(slash_pos + 1);
+      }
+      path = value;
+    } else {
+      name = value.substr(0, equals_pos);
+      path = value.substr(equals_pos + 1);
+    }
+
+    plugins_[name] = path;
+
   } else {
     // Some other flag.  Look it up in the generators list.
-    GeneratorMap::const_iterator iter = generators_.find(name);
-    if (iter == generators_.end()) {
+    const GeneratorInfo* generator_info = FindOrNull(generators_, name);
+    if (generator_info == NULL &&
+        (plugin_prefix_.empty() || !HasSuffixString(name, "_out"))) {
       cerr << "Unknown flag: " << name << endl;
       return false;
     }
@@ -703,7 +972,11 @@ bool CommandLineInterface::InterpretArgument(const string& name,
 
     OutputDirective directive;
     directive.name = name;
-    directive.generator = iter->second.generator;
+    if (generator_info == NULL) {
+      directive.generator = NULL;
+    } else {
+      directive.generator = generator_info->generator;
+    }
 
     // Split value at ':' to separate the generator parameter from the
     // filename.  However, avoid doing this if the colon is part of a valid
@@ -755,6 +1028,17 @@ void CommandLineInterface::PrintHelpText() {
 "  --error_format=FORMAT       Set the format in which to print errors.\n"
 "                              FORMAT may be 'gcc' (the default) or 'msvs'\n"
 "                              (Microsoft Visual Studio format)." << endl;
+  if (!plugin_prefix_.empty()) {
+    cerr <<
+"  --plugin=EXECUTABLE         Specifies a plugin executable to use.\n"
+"                              Normally, protoc searches the PATH for\n"
+"                              plugins, but you may specify additional\n"
+"                              executables not in the path using this flag.\n"
+"                              Additionally, EXECUTABLE may be of the form\n"
+"                              NAME=PATH, in which case the given plugin name\n"
+"                              is mapped to the given executable even if\n"
+"                              the executable's own name differs." << endl;
+  }
 
   for (GeneratorMap::iterator iter = generators_.begin();
        iter != generators_.end(); ++iter) {
@@ -768,7 +1052,7 @@ void CommandLineInterface::PrintHelpText() {
 }
 
 bool CommandLineInterface::GenerateOutput(
-    const FileDescriptor* parsed_file,
+    const vector<const FileDescriptor*>& parsed_files,
     const OutputDirective& output_directive) {
   // Create the output directory.
   DiskOutputDirectory output_directory(output_directive.output_location);
@@ -780,12 +1064,34 @@ bool CommandLineInterface::GenerateOutput(
 
   // Call the generator.
   string error;
-  if (!output_directive.generator->Generate(
-      parsed_file, output_directive.parameter, &output_directory, &error)) {
-    // Generator returned an error.
-    cerr << parsed_file->name() << ": " << output_directive.name << ": "
-         << error << endl;
-    return false;
+  if (output_directive.generator == NULL) {
+    // This is a plugin.
+    GOOGLE_CHECK(HasPrefixString(output_directive.name, "--") &&
+          HasSuffixString(output_directive.name, "_out"))
+        << "Bad name for plugin generator: " << output_directive.name;
+
+    // Strip the "--" and "_out" and add the plugin prefix.
+    string plugin_name = plugin_prefix_ + "gen-" +
+        output_directive.name.substr(2, output_directive.name.size() - 6);
+
+    if (!GeneratePluginOutput(parsed_files, plugin_name,
+                              output_directive.parameter,
+                              &output_directory, &error)) {
+      cerr << output_directive.name << ": " << error << endl;
+      return false;
+    }
+  } else {
+    // Regular generator.
+    for (int i = 0; i < parsed_files.size(); i++) {
+      if (!output_directive.generator->Generate(
+          parsed_files[i], output_directive.parameter,
+          &output_directory, &error)) {
+        // Generator returned an error.
+        cerr << output_directive.name << ": " << parsed_files[i]->name() << ": "
+             << error << endl;
+        return false;
+      }
+    }
   }
 
   // Check for write errors.
@@ -796,6 +1102,84 @@ bool CommandLineInterface::GenerateOutput(
   return true;
 }
 
+// Runs the plugin called plugin_name as a subprocess and writes the files it
+// returns into output_directory.
+//
+// The request sent to the plugin lists every file in parsed_files as a file
+// to generate, carries descriptors for those files and their transitive
+// dependencies, and includes the generator parameter if it is non-empty.
+//
+// Returns false and fills in *error if communicating with the plugin fails,
+// if the first chunk returned names no file, or if the plugin reports an
+// error of its own.
+bool CommandLineInterface::GeneratePluginOutput(
+    const vector<const FileDescriptor*>& parsed_files,
+    const string& plugin_name,
+    const string& parameter,
+    OutputDirectory* output_directory,
+    string* error) {
+  CodeGeneratorRequest request;
+  CodeGeneratorResponse response;
+
+  // Build the request.
+  if (!parameter.empty()) {
+    request.set_parameter(parameter);
+  }
+
+  // already_seen is shared across the loop so that a dependency common to
+  // several input files is added to the request only once.
+  set<const FileDescriptor*> already_seen;
+  for (int i = 0; i < parsed_files.size(); i++) {
+    request.add_file_to_generate(parsed_files[i]->name());
+    GetTransitiveDependencies(parsed_files[i], &already_seen,
+                              request.mutable_proto_file());
+  }
+
+  // Invoke the plugin.  An executable registered for this name in plugins_
+  // takes precedence over searching the PATH.
+  Subprocess subprocess;
+
+  if (plugins_.count(plugin_name) > 0) {
+    subprocess.Start(plugins_[plugin_name], Subprocess::EXACT_NAME);
+  } else {
+    subprocess.Start(plugin_name, Subprocess::SEARCH_PATH);
+  }
+
+  string communicate_error;
+  if (!subprocess.Communicate(request, &response, &communicate_error)) {
+    *error = strings::Substitute("$0: $1", plugin_name, communicate_error);
+    return false;
+  }
+
+  // Write the files.  We do this even if there was a generator error in order
+  // to match the behavior of a compiled-in generator.  A chunk that has
+  // neither an insertion point nor a name continues the output opened by an
+  // earlier chunk.
+  scoped_ptr<io::ZeroCopyOutputStream> current_output;
+  for (int i = 0; i < response.file_size(); i++) {
+    const CodeGeneratorResponse::File& output_file = response.file(i);
+
+    if (!output_file.insertion_point().empty()) {
+      // Open a file for insert.
+      // We reset current_output to NULL first so that the old file is closed
+      // before the new one is opened.
+      current_output.reset();
+      current_output.reset(output_directory->OpenForInsert(
+          output_file.name(), output_file.insertion_point()));
+    } else if (!output_file.name().empty()) {
+      // Starting a new file.  Open it.
+      // We reset current_output to NULL first so that the old file is closed
+      // before the new one is opened.
+      current_output.reset();
+      current_output.reset(output_directory->Open(output_file.name()));
+    } else if (current_output == NULL) {
+      // A continuation chunk with nothing open to continue.
+      *error = strings::Substitute(
+        "$0: First file chunk returned by plugin did not specify a file name.",
+        plugin_name);
+      return false;
+    }
+
+    // Use CodedOutputStream for convenience; otherwise we'd need to provide
+    // our own buffer-copying loop.
+    io::CodedOutputStream writer(current_output.get());
+    writer.WriteString(output_file.content());
+  }
+
+  // Check for errors.
+  if (!response.error().empty()) {
+    // Generator returned an error.
+    *error = response.error();
+    return false;
+  }
+
+  return true;
+}
+
 bool CommandLineInterface::EncodeOrDecode(const DescriptorPool* pool) {
   // Look up the type.
   const Descriptor* type = pool->FindMessageTypeByName(codec_type_);
@@ -862,22 +1246,16 @@ bool CommandLineInterface::EncodeOrDecode(const DescriptorPool* pool) {
 bool CommandLineInterface::WriteDescriptorSet(
     const vector<const FileDescriptor*> parsed_files) {
   FileDescriptorSet file_set;
-  set<const FileDescriptor*> already_added;
-  vector<const FileDescriptor*> to_add(parsed_files);
-
-  while (!to_add.empty()) {
-    const FileDescriptor* file = to_add.back();
-    to_add.pop_back();
-    if (already_added.insert(file).second) {
-      // This file was not already in the set.
-      file->CopyTo(file_set.add_file());
-
-      if (imports_in_descriptor_set_) {
-        // Add all of this file's dependencies.
-        for (int i = 0; i < file->dependency_count(); i++) {
-          to_add.push_back(file->dependency(i));
-        }
-      }
+
+  if (imports_in_descriptor_set_) {
+    set<const FileDescriptor*> already_seen;
+    for (int i = 0; i < parsed_files.size(); i++) {
+      GetTransitiveDependencies(
+          parsed_files[i], &already_seen, file_set.mutable_file());
+    }
+  } else {
+    for (int i = 0; i < parsed_files.size(); i++) {
+      parsed_files[i]->CopyTo(file_set.add_file());
     }
   }
 
@@ -906,6 +1284,24 @@ bool CommandLineInterface::WriteDescriptorSet(
   return true;
 }
 
+// Appends file and everything it transitively depends on to *output, skipping
+// any file already present in *already_seen.  Dependencies are visited before
+// the file itself, so each file is appended after the files it depends on.
+void CommandLineInterface::GetTransitiveDependencies(
+    const FileDescriptor* file,
+    set<const FileDescriptor*>* already_seen,
+    RepeatedPtrField<FileDescriptorProto>* output) {
+  // insert() reports whether the file was newly added to the set; if it was
+  // not, it has been handled before.
+  if (!already_seen->insert(file).second) {
+    // Already saw this file.  Skip.
+    return;
+  }
+
+  // Add all dependencies.
+  for (int i = 0; i < file->dependency_count(); i++) {
+    GetTransitiveDependencies(file->dependency(i), already_seen, output);
+  }
+
+  // Add this file.
+  file->CopyTo(output->Add());
+}
+
 
 }  // namespace compiler
 }  // namespace protobuf

+ 62 - 3
src/google/protobuf/compiler/command_line_interface.h

@@ -50,10 +50,13 @@ namespace protobuf {
 
 class FileDescriptor;        // descriptor.h
 class DescriptorPool;        // descriptor.h
+class FileDescriptorProto;   // descriptor.pb.h
+template<typename T> class RepeatedPtrField;  // repeated_field.h
 
 namespace compiler {
 
 class CodeGenerator;        // code_generator.h
+class OutputDirectory;      // code_generator.h
 class DiskSourceTree;       // importer.h
 
 // This class implements the command-line interface to the protocol compiler.
@@ -109,6 +112,37 @@ class LIBPROTOC_EXPORT CommandLineInterface {
                          CodeGenerator* generator,
                          const string& help_text);
 
+  // Enables "plugins".  In this mode, if a command-line flag ends with "_out"
+  // but does not match any registered generator, the compiler will attempt to
+  // find a "plugin" to implement the generator.  Plugins are just executables.
+  // They should live somewhere in the PATH, or be named explicitly with the
+  // --plugin flag.
+  //
+  // The compiler determines the executable name to search for by concatenating
+  // exe_name_prefix, "gen-", and the unrecognized flag name with its leading
+  // "--" and trailing "_out" removed.  So, for example, if exe_name_prefix is
+  // "protoc-" and you pass the flag --foo_out, the compiler will try to run
+  // the program "protoc-gen-foo".
+  //
+  // The compiler communicates with the plugin by sending it a
+  // CodeGeneratorRequest and reading back a CodeGeneratorResponse (see
+  // plugin.proto).  The request carries:
+  //   - the generator parameter, if any was given with the --foo_out flag;
+  //   - the names of the .proto files given on the compiler command-line,
+  //     which are the files the plugin is expected to generate code for;
+  //   - FileDescriptorProtos for those files and for everything they import,
+  //     ordered so that each file comes after the files it depends on.
+  // The plugin MUST NOT attempt to read the .proto files directly -- it must
+  // use these descriptors.
+  //
+  // The plugin does not write output files itself.  Each file entry in the
+  // response names an output file and supplies content for it; the compiler
+  // writes that content to the output directory, or inserts it at the named
+  // insertion point of an existing file.  If the response's error field is
+  // non-empty, the compiler still writes the returned files, then reports the
+  // error and fails.
+  void AllowPlugins(const string& exe_name_prefix);
+
   // Run the Protocol Compiler with the given command-line parameters.
   // Returns the error code which should be returned by main().
   //
@@ -142,6 +176,7 @@ class LIBPROTOC_EXPORT CommandLineInterface {
   class ErrorPrinter;
   class DiskOutputDirectory;
   class ErrorReportingFileOutput;
+  class InsertionOutputStream;
 
   // Clear state from previous Run().
   void Clear();
@@ -176,8 +211,13 @@ class LIBPROTOC_EXPORT CommandLineInterface {
 
   // Generate the given output file from the given input.
   struct OutputDirective;  // see below
-  bool GenerateOutput(const FileDescriptor* proto_file,
+  bool GenerateOutput(const vector<const FileDescriptor*>& parsed_files,
                       const OutputDirective& output_directive);
+  bool GeneratePluginOutput(const vector<const FileDescriptor*>& parsed_files,
+                            const string& plugin_name,
+                            const string& parameter,
+                            OutputDirectory* output_directory,
+                            string* error);
 
   // Implements --encode and --decode.
   bool EncodeOrDecode(const DescriptorPool* pool);
@@ -185,6 +225,17 @@ class LIBPROTOC_EXPORT CommandLineInterface {
   // Implements the --descriptor_set_out option.
   bool WriteDescriptorSet(const vector<const FileDescriptor*> parsed_files);
 
+  // Get all transitive dependencies of the given file (including the file
+  // itself), adding them to the given list of FileDescriptorProtos.  The
+  // protos will be ordered such that every file is listed before any file that
+  // depends on it, so that you can call DescriptorPool::BuildFile() on them
+  // in order.  Any files in *already_seen will not be added, and each file
+  // added will be inserted into *already_seen.
+  static void GetTransitiveDependencies(
+      const FileDescriptor* file,
+      set<const FileDescriptor*>* already_seen,
+      RepeatedPtrField<FileDescriptorProto>* output);
+
   // -----------------------------------------------------------------
 
   // The name of the executable as invoked (i.e. argv[0]).
@@ -201,6 +252,14 @@ class LIBPROTOC_EXPORT CommandLineInterface {
   typedef map<string, GeneratorInfo> GeneratorMap;
   GeneratorMap generators_;
 
+  // See AllowPlugins().  If this is empty, plugins aren't allowed.
+  string plugin_prefix_;
+
+  // Maps specific plugin names to files.  When executing a plugin, this map
+  // is searched first to find the plugin executable.  If not found here, the
+  // PATH (or other OS-specific search strategy) is searched.
+  map<string, string> plugins_;
+
   // Stuff parsed from command line.
   enum Mode {
     MODE_COMPILE,  // Normal mode:  parse .proto files and compile them.
@@ -223,8 +282,8 @@ class LIBPROTOC_EXPORT CommandLineInterface {
   // output_directives_ lists all the files we are supposed to output and what
   // generator to use for each.
   struct OutputDirective {
-    string name;
-    CodeGenerator* generator;
+    string name;                // E.g. "--foo_out"
+    CodeGenerator* generator;   // NULL for plugins
     string parameter;
     string output_location;
   };

文件差異過大導致無法顯示
+ 271 - 287
src/google/protobuf/compiler/command_line_interface_unittest.cc


+ 10 - 1
src/google/protobuf/compiler/cpp/cpp_bootstrap_unittest.cc

@@ -123,8 +123,11 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
   Importer importer(&source_tree, &error_collector);
   const FileDescriptor* proto_file =
     importer.Import("google/protobuf/descriptor.proto");
+  const FileDescriptor* plugin_proto_file =
+    importer.Import("google/protobuf/compiler/plugin.proto");
   EXPECT_EQ("", error_collector.text_);
   ASSERT_TRUE(proto_file != NULL);
+  ASSERT_TRUE(plugin_proto_file != NULL);
 
   CppGenerator generator;
   MockOutputDirectory output_directory;
@@ -133,11 +136,18 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
   parameter = "dllexport_decl=LIBPROTOBUF_EXPORT";
   ASSERT_TRUE(generator.Generate(proto_file, parameter,
                                  &output_directory, &error));
+  parameter = "dllexport_decl=LIBPROTOC_EXPORT";
+  ASSERT_TRUE(generator.Generate(plugin_proto_file, parameter,
+                                 &output_directory, &error));
 
   output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.h",
                                      "google/protobuf/descriptor.pb.h");
   output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.cc",
                                      "google/protobuf/descriptor.pb.cc");
+  output_directory.ExpectFileMatches("google/protobuf/compiler/plugin.pb.h",
+                                     "google/protobuf/compiler/plugin.pb.h");
+  output_directory.ExpectFileMatches("google/protobuf/compiler/plugin.pb.cc",
+                                     "google/protobuf/compiler/plugin.pb.cc");
 }
 
 }  // namespace
@@ -145,5 +155,4 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf
-
 }  // namespace google

+ 10 - 4
src/google/protobuf/compiler/cpp/cpp_enum.cc

@@ -98,6 +98,7 @@ void EnumGenerator::GenerateDefinition(io::Printer* printer) {
     "$dllexport$bool $classname$_IsValid(int value);\n"
     "const $classname$ $prefix$$short_name$_MIN = $prefix$$min_name$;\n"
     "const $classname$ $prefix$$short_name$_MAX = $prefix$$max_name$;\n"
+    "const int $prefix$$short_name$_ARRAYSIZE = $prefix$$short_name$_MAX + 1;\n"
     "\n");
 
   if (HasDescriptorMethods(descriptor_->file())) {
@@ -149,17 +150,21 @@ void EnumGenerator::GenerateSymbolImports(io::Printer* printer) {
     "static const $nested_name$ $nested_name$_MIN =\n"
     "  $classname$_$nested_name$_MIN;\n"
     "static const $nested_name$ $nested_name$_MAX =\n"
-    "  $classname$_$nested_name$_MAX;\n");
+    "  $classname$_$nested_name$_MAX;\n"
+    "static const int $nested_name$_ARRAYSIZE =\n"
+    "  $classname$_$nested_name$_ARRAYSIZE;\n");
 
   if (HasDescriptorMethods(descriptor_->file())) {
     printer->Print(vars,
       "static inline const ::google::protobuf::EnumDescriptor*\n"
       "$nested_name$_descriptor() {\n"
       "  return $classname$_descriptor();\n"
-      "}\n"
+      "}\n");
+    printer->Print(vars,
       "static inline const ::std::string& $nested_name$_Name($nested_name$ value) {\n"
       "  return $classname$_Name(value);\n"
-      "}\n"
+      "}\n");
+    printer->Print(vars,
       "static inline bool $nested_name$_Parse(const ::std::string& name,\n"
       "    $nested_name$* value) {\n"
       "  return $classname$_Parse(name, value);\n"
@@ -240,7 +245,8 @@ void EnumGenerator::GenerateMethods(io::Printer* printer) {
     }
     printer->Print(vars,
       "const $classname$ $parent$::$nested_name$_MIN;\n"
-      "const $classname$ $parent$::$nested_name$_MAX;\n");
+      "const $classname$ $parent$::$nested_name$_MAX;\n"
+      "const int $parent$::$nested_name$_ARRAYSIZE;\n");
 
     printer->Print("#endif  // _MSC_VER\n");
   }

+ 45 - 26
src/google/protobuf/compiler/cpp/cpp_enum_field.cc

@@ -114,7 +114,9 @@ void EnumFieldGenerator::
 GenerateMergeFromCodedStream(io::Printer* printer) const {
   printer->Print(variables_,
     "int value;\n"
-    "DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n"
+    "DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
+    "         int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
+    "       input, &value)));\n"
     "if ($type$_IsValid(value)) {\n"
     "  set_$name$(static_cast< $type$ >(value));\n");
   if (HasUnknownFields(descriptor_->file())) {
@@ -170,24 +172,17 @@ GeneratePrivateMembers(io::Printer* printer) const {
 void RepeatedEnumFieldGenerator::
 GenerateAccessorDeclarations(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedField<int>& $name$() const$deprecation$;\n"
-    "inline ::google::protobuf::RepeatedField<int>* mutable_$name$()$deprecation$;\n"
     "inline $type$ $name$(int index) const$deprecation$;\n"
     "inline void set_$name$(int index, $type$ value)$deprecation$;\n"
     "inline void add_$name$($type$ value)$deprecation$;\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedField<int>& $name$() const$deprecation$;\n"
+    "inline ::google::protobuf::RepeatedField<int>* mutable_$name$()$deprecation$;\n");
 }
 
 void RepeatedEnumFieldGenerator::
 GenerateInlineAccessorDefinitions(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedField<int>&\n"
-    "$classname$::$name$() const {\n"
-    "  return $name$_;\n"
-    "}\n"
-    "inline ::google::protobuf::RepeatedField<int>*\n"
-    "$classname$::mutable_$name$() {\n"
-    "  return &$name$_;\n"
-    "}\n"
     "inline $type$ $classname$::$name$(int index) const {\n"
     "  return static_cast< $type$ >($name$_.Get(index));\n"
     "}\n"
@@ -199,6 +194,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "  GOOGLE_DCHECK($type$_IsValid(value));\n"
     "  $name$_.Add(value);\n"
     "}\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedField<int>&\n"
+    "$classname$::$name$() const {\n"
+    "  return $name$_;\n"
+    "}\n"
+    "inline ::google::protobuf::RepeatedField<int>*\n"
+    "$classname$::mutable_$name$() {\n"
+    "  return &$name$_;\n"
+    "}\n");
 }
 
 void RepeatedEnumFieldGenerator::
@@ -223,7 +227,33 @@ GenerateConstructorCode(io::Printer* printer) const {
 
 void RepeatedEnumFieldGenerator::
 GenerateMergeFromCodedStream(io::Printer* printer) const {
-  if (descriptor_->options().packed()) {
+  // Don't use ReadRepeatedPrimitive here so that the enum can be validated.
+  printer->Print(variables_,
+    "int value;\n"
+    "DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
+    "         int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
+    "       input, &value)));\n"
+    "if ($type$_IsValid(value)) {\n"
+    "  add_$name$(static_cast< $type$ >(value));\n");
+  if (HasUnknownFields(descriptor_->file())) {
+    printer->Print(variables_,
+      "} else {\n"
+      "  mutable_unknown_fields()->AddVarint($number$, value);\n");
+  }
+  printer->Print("}\n");
+}
+
+void RepeatedEnumFieldGenerator::
+GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
+  if (!descriptor_->options().packed()) {
+    // We use a non-inlined implementation in this case, since this path will
+    // rarely be executed.
+    printer->Print(variables_,
+      "DO_((::google::protobuf::internal::WireFormatLite::ReadPackedEnumNoInline(\n"
+      "       input,\n"
+      "       &$type$_IsValid,\n"
+      "       this->mutable_$name$())));\n");
+  } else {
     printer->Print(variables_,
       "::google::protobuf::uint32 length;\n"
       "DO_(input->ReadVarint32(&length));\n"
@@ -231,25 +261,14 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
           "input->PushLimit(length);\n"
       "while (input->BytesUntilLimit() > 0) {\n"
       "  int value;\n"
-      "  DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n"
+      "  DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
+      "         int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
+      "       input, &value)));\n"
       "  if ($type$_IsValid(value)) {\n"
       "    add_$name$(static_cast< $type$ >(value));\n"
       "  }\n"
       "}\n"
       "input->PopLimit(limit);\n");
-  } else {
-    printer->Print(variables_,
-      "int value;\n"
-      "DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n"
-      "if ($type$_IsValid(value)) {\n"
-      "  add_$name$(static_cast< $type$ >(value));\n");
-    if (HasUnknownFields(descriptor_->file())) {
-      printer->Print(variables_,
-        "} else {\n"
-        "  mutable_unknown_fields()->AddVarint($number$, value);\n");
-    }
-    printer->Print(variables_,
-      "}\n");
   }
 }
 

+ 1 - 0
src/google/protobuf/compiler/cpp/cpp_enum_field.h

@@ -83,6 +83,7 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
   void GenerateSwappingCode(io::Printer* printer) const;
   void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
+  void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;

+ 17 - 4
src/google/protobuf/compiler/cpp/cpp_extension.cc

@@ -33,6 +33,7 @@
 //  Sanjay Ghemawat, Jeff Dean, and others.
 
 #include <google/protobuf/compiler/cpp/cpp_extension.h>
+#include <map>
 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
 #include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/io/printer.h>
@@ -43,6 +44,18 @@ namespace protobuf {
 namespace compiler {
 namespace cpp {
 
+namespace {
+
+// Returns the fully-qualified class name of the message that this field
+// extends. This function is used in the Google-internal code to handle some
+// legacy cases.
+string ExtendeeClassName(const FieldDescriptor* descriptor) {
+  const Descriptor* extendee = descriptor->containing_type();
+  return ClassName(extendee, true);
+}
+
+}  // anonymous namespace
+
 ExtensionGenerator::ExtensionGenerator(const FieldDescriptor* descriptor,
                                        const string& dllexport_decl)
   : descriptor_(descriptor),
@@ -80,7 +93,7 @@ ExtensionGenerator::~ExtensionGenerator() {}
 
 void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
   map<string, string> vars;
-  vars["extendee"     ] = ClassName(descriptor_->containing_type(), true);
+  vars["extendee"     ] = ExtendeeClassName(descriptor_);
   vars["number"       ] = SimpleItoa(descriptor_->number());
   vars["type_traits"  ] = type_traits_;
   vars["name"         ] = descriptor_->name();
@@ -106,6 +119,7 @@ void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
     "    ::google::protobuf::internal::$type_traits$, $field_type$, $packed$ >\n"
     "  $name$;\n"
     );
+
 }
 
 void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
@@ -115,7 +129,7 @@ void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
   string name = scope + descriptor_->name();
 
   map<string, string> vars;
-  vars["extendee"     ] = ClassName(descriptor_->containing_type(), true);
+  vars["extendee"     ] = ExtendeeClassName(descriptor_);
   vars["type_traits"  ] = type_traits_;
   vars["name"         ] = name;
   vars["constant_name"] = FieldConstantName(descriptor_);
@@ -154,7 +168,7 @@ void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
 
 void ExtensionGenerator::GenerateRegistration(io::Printer* printer) {
   map<string, string> vars;
-  vars["extendee"   ] = ClassName(descriptor_->containing_type(), true);
+  vars["extendee"   ] = ExtendeeClassName(descriptor_);
   vars["number"     ] = SimpleItoa(descriptor_->number());
   vars["field_type" ] = SimpleItoa(static_cast<int>(descriptor_->type()));
   vars["is_repeated"] = descriptor_->is_repeated() ? "true" : "false";
@@ -193,5 +207,4 @@ void ExtensionGenerator::GenerateRegistration(io::Printer* printer) {
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf
-
 }  // namespace google

+ 26 - 3
src/google/protobuf/compiler/cpp/cpp_field.cc

@@ -40,6 +40,7 @@
 #include <google/protobuf/compiler/cpp/cpp_message_field.h>
 #include <google/protobuf/descriptor.pb.h>
 #include <google/protobuf/wire_format.h>
+#include <google/protobuf/io/printer.h>
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/stubs/strutil.h>
 
@@ -61,11 +62,24 @@ void SetCommonFieldVariables(const FieldDescriptor* descriptor,
   (*variables)["tag_size"] = SimpleItoa(
     WireFormat::TagSize(descriptor->number(), descriptor->type()));
   (*variables)["deprecation"] = descriptor->options().deprecated()
-      ? " DEPRECATED_PROTOBUF_FIELD" : "";
+      ? " PROTOBUF_DEPRECATED" : "";
+
 }
 
 FieldGenerator::~FieldGenerator() {}
 
+void FieldGenerator::
+GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
+  // Reaching here indicates a bug. Cases are:
+  //   - This FieldGenerator should support packing, in which case this method
+  //     should have been overridden.
+  //   - This FieldGenerator doesn't support packing, and this method should
+  //     never have been called.
+  GOOGLE_LOG(FATAL) << "GenerateMergeFromCodedStreamWithPacking() "
+             << "called on field generator that does not support packing.";
+
+}
+
 FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor)
   : descriptor_(descriptor),
     field_generators_(
@@ -82,7 +96,11 @@ FieldGenerator* FieldGeneratorMap::MakeGenerator(const FieldDescriptor* field) {
       case FieldDescriptor::CPPTYPE_MESSAGE:
         return new RepeatedMessageFieldGenerator(field);
       case FieldDescriptor::CPPTYPE_STRING:
-          return new RepeatedStringFieldGenerator(field);
+        switch (field->options().ctype()) {
+          default:  // RepeatedStringFieldGenerator handles unknown ctypes.
+          case FieldOptions::STRING:
+            return new RepeatedStringFieldGenerator(field);
+        }
       case FieldDescriptor::CPPTYPE_ENUM:
         return new RepeatedEnumFieldGenerator(field);
       default:
@@ -93,7 +111,11 @@ FieldGenerator* FieldGeneratorMap::MakeGenerator(const FieldDescriptor* field) {
       case FieldDescriptor::CPPTYPE_MESSAGE:
         return new MessageFieldGenerator(field);
       case FieldDescriptor::CPPTYPE_STRING:
-          return new StringFieldGenerator(field);
+        switch (field->options().ctype()) {
+          default:  // StringFieldGenerator handles unknown ctypes.
+          case FieldOptions::STRING:
+            return new StringFieldGenerator(field);
+        }
       case FieldDescriptor::CPPTYPE_ENUM:
         return new EnumFieldGenerator(field);
       default:
@@ -110,6 +132,7 @@ const FieldGenerator& FieldGeneratorMap::get(
   return *field_generators_[field->index()];
 }
 
+
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf

+ 6 - 0
src/google/protobuf/compiler/cpp/cpp_field.h

@@ -118,6 +118,11 @@ class FieldGenerator {
   // message's MergeFromCodedStream() method.
   virtual void GenerateMergeFromCodedStream(io::Printer* printer) const = 0;
 
+  // Generate lines to decode this field from a packed value, which will be
+  // placed inside the message's MergeFromCodedStream() method.
+  virtual void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer)
+      const;
+
   // Generate lines to serialize this field, which are placed within the
   // message's SerializeWithCachedSizes() method.
   virtual void GenerateSerializeWithCachedSizes(io::Printer* printer) const = 0;
@@ -153,6 +158,7 @@ class FieldGeneratorMap {
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldGeneratorMap);
 };
 
+
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf

+ 45 - 14
src/google/protobuf/compiler/cpp/cpp_file.cc

@@ -38,6 +38,7 @@
 #include <google/protobuf/compiler/cpp/cpp_extension.h>
 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
 #include <google/protobuf/compiler/cpp/cpp_message.h>
+#include <google/protobuf/compiler/cpp/cpp_field.h>
 #include <google/protobuf/io/printer.h>
 #include <google/protobuf/descriptor.pb.h>
 #include <google/protobuf/stubs/strutil.h>
@@ -93,12 +94,14 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
   // Generate top of header.
   printer->Print(
     "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n"
+    "// source: $filename$\n"
     "\n"
     "#ifndef PROTOBUF_$filename_identifier$__INCLUDED\n"
     "#define PROTOBUF_$filename_identifier$__INCLUDED\n"
     "\n"
     "#include <string>\n"
     "\n",
+    "filename", file_->name(),
     "filename_identifier", filename_identifier);
 
   printer->Print(
@@ -132,19 +135,23 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
   if (HasDescriptorMethods(file_)) {
     printer->Print(
       "#include <google/protobuf/generated_message_reflection.h>\n");
+  }
 
-    if (file_->service_count() > 0) {
-      printer->Print(
-        "#include <google/protobuf/service.h>\n");
-    }
+  if (HasGenericServices(file_)) {
+    printer->Print(
+      "#include <google/protobuf/service.h>\n");
   }
 
+
   for (int i = 0; i < file_->dependency_count(); i++) {
     printer->Print(
       "#include \"$dependency$.pb.h\"\n",
       "dependency", StripProto(file_->dependency(i)->name()));
   }
 
+  printer->Print(
+    "// @@protoc_insertion_point(includes)\n");
+
   // Open namespace.
   GenerateNamespaceOpeners(printer);
 
@@ -198,7 +205,7 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
   printer->Print(kThickSeparator);
   printer->Print("\n");
 
-  if (HasDescriptorMethods(file_)) {
+  if (HasGenericServices(file_)) {
     // Generate service definitions.
     for (int i = 0; i < file_->service_count(); i++) {
       if (i > 0) {
@@ -232,6 +239,10 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
     message_generators_[i]->GenerateInlineMethods(printer);
   }
 
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(namespace_scope)\n");
+
   // Close up namespace.
   GenerateNamespaceClosers(printer);
 
@@ -255,10 +266,14 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
     printer->Print(
         "\n"
         "}  // namespace google\n}  // namespace protobuf\n"
-        "#endif  // SWIG\n"
-        "\n");
+        "#endif  // SWIG\n");
   }
 
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(global_scope)\n"
+    "\n");
+
   printer->Print(
     "#endif  // PROTOBUF_$filename_identifier$__INCLUDED\n",
     "filename_identifier", filename_identifier);
@@ -285,6 +300,9 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
       "#include <google/protobuf/wire_format.h>\n");
   }
 
+  printer->Print(
+    "// @@protoc_insertion_point(includes)\n");
+
   GenerateNamespaceOpeners(printer);
 
   if (HasDescriptorMethods(file_)) {
@@ -300,10 +318,13 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
         "const ::google::protobuf::EnumDescriptor* $name$_descriptor_ = NULL;\n",
         "name", ClassName(file_->enum_type(i), false));
     }
-    for (int i = 0; i < file_->service_count(); i++) {
-      printer->Print(
-        "const ::google::protobuf::ServiceDescriptor* $name$_descriptor_ = NULL;\n",
-        "name", file_->service(i)->name());
+
+    if (HasGenericServices(file_)) {
+      for (int i = 0; i < file_->service_count(); i++) {
+        printer->Print(
+          "const ::google::protobuf::ServiceDescriptor* $name$_descriptor_ = NULL;\n",
+          "name", file_->service(i)->name());
+      }
     }
 
     printer->Print(
@@ -329,7 +350,7 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
     message_generators_[i]->GenerateClassMethods(printer);
   }
 
-  if (HasDescriptorMethods(file_)) {
+  if (HasGenericServices(file_)) {
     // Generate services.
     for (int i = 0; i < file_->service_count(); i++) {
       if (i == 0) printer->Print("\n");
@@ -344,7 +365,15 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
     extension_generators_[i]->GenerateDefinition(printer);
   }
 
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(namespace_scope)\n");
+
   GenerateNamespaceClosers(printer);
+
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(global_scope)\n");
 }
 
 void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
@@ -397,8 +426,10 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
     for (int i = 0; i < file_->enum_type_count(); i++) {
       enum_generators_[i]->GenerateDescriptorInitializer(printer, i);
     }
-    for (int i = 0; i < file_->service_count(); i++) {
-      service_generators_[i]->GenerateDescriptorInitializer(printer, i);
+    if (HasGenericServices(file_)) {
+      for (int i = 0; i < file_->service_count(); i++) {
+        service_generators_[i]->GenerateDescriptorInitializer(printer, i);
+      }
     }
 
     printer->Outdent();

+ 44 - 9
src/google/protobuf/compiler/cpp/cpp_helpers.cc

@@ -32,6 +32,7 @@
 //  Based on original Protocol Buffers design by
 //  Sanjay Ghemawat, Jeff Dean, and others.
 
+#include <limits>
 #include <vector>
 #include <google/protobuf/stubs/hash.h>
 
@@ -40,6 +41,7 @@
 #include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/stubs/substitute.h>
 
+
 namespace google {
 namespace protobuf {
 namespace compiler {
@@ -111,6 +113,7 @@ const char kThinSeparator[] =
   "// -------------------------------------------------------------------\n";
 
 string ClassName(const Descriptor* descriptor, bool qualified) {
+
   // Find "outer", the descriptor of the top-level message in which
   // "descriptor" is embedded.
   const Descriptor* outer = descriptor;
@@ -141,6 +144,12 @@ string ClassName(const EnumDescriptor* enum_descriptor, bool qualified) {
   }
 }
 
+
+string SuperClassName(const Descriptor* descriptor) {
+  return HasDescriptorMethods(descriptor->file()) ?
+      "::google::protobuf::Message" : "::google::protobuf::MessageLite";
+}
+
 string FieldName(const FieldDescriptor* field) {
   string result = field->name();
   LowerString(&result);
@@ -166,6 +175,12 @@ string FieldConstantName(const FieldDescriptor *field) {
   return result;
 }
 
+string FieldMessageTypeName(const FieldDescriptor* field) {
+  // Note:  The Google-internal version of Protocol Buffers uses this function
+  //   as a hook point for hacks to support legacy code.
+  return ClassName(field->message_type(), true);
+}
+
 string StripProto(const string& filename) {
   if (HasSuffixString(filename, ".protodevel")) {
     return StripSuffixString(filename, ".protodevel");
@@ -235,17 +250,37 @@ string DefaultValue(const FieldDescriptor* field) {
       return "GOOGLE_LONGLONG(" + SimpleItoa(field->default_value_int64()) + ")";
     case FieldDescriptor::CPPTYPE_UINT64:
       return "GOOGLE_ULONGLONG(" + SimpleItoa(field->default_value_uint64())+ ")";
-    case FieldDescriptor::CPPTYPE_DOUBLE:
-      return SimpleDtoa(field->default_value_double());
+    case FieldDescriptor::CPPTYPE_DOUBLE: {
+      double value = field->default_value_double();
+      if (value == numeric_limits<double>::infinity()) {
+        return "::google::protobuf::internal::Infinity()";
+      } else if (value == -numeric_limits<double>::infinity()) {
+        return "-::google::protobuf::internal::Infinity()";
+      } else if (value != value) {
+        return "::google::protobuf::internal::NaN()";
+      } else {
+        return SimpleDtoa(value);
+      }
+    }
     case FieldDescriptor::CPPTYPE_FLOAT:
       {
-        // If floating point value contains a period (.) or an exponent (either
-        // E or e), then append suffix 'f' to make it a floating-point literal.
-        string float_value = SimpleFtoa(field->default_value_float());
-        if (float_value.find_first_of(".eE") != string::npos) {
-          float_value.push_back('f');
+        float value = field->default_value_float();
+        if (value == numeric_limits<float>::infinity()) {
+          return "static_cast<float>(::google::protobuf::internal::Infinity())";
+        } else if (value == -numeric_limits<float>::infinity()) {
+          return "static_cast<float>(-::google::protobuf::internal::Infinity())";
+        } else if (value != value) {
+          return "static_cast<float>(::google::protobuf::internal::NaN())";
+        } else {
+          string float_value = SimpleFtoa(value);
+          // If floating point value contains a period (.) or an exponent
+          // (either E or e), then append suffix 'f' to make it a float
+          // literal.
+          if (float_value.find_first_of(".eE") != string::npos) {
+            float_value.push_back('f');
+          }
+          return float_value;
         }
-        return float_value;
       }
     case FieldDescriptor::CPPTYPE_BOOL:
       return field->default_value_bool() ? "true" : "false";
@@ -259,7 +294,7 @@ string DefaultValue(const FieldDescriptor* field) {
     case FieldDescriptor::CPPTYPE_STRING:
       return "\"" + CEscape(field->default_value_string()) + "\"";
     case FieldDescriptor::CPPTYPE_MESSAGE:
-      return ClassName(field->message_type(), true) + "::default_instance()";
+      return FieldMessageTypeName(field) + "::default_instance()";
   }
   // Can't actually get here; make compiler happy.  (We could add a default
   // case above but then we wouldn't get the nice compiler warning when a

+ 19 - 5
src/google/protobuf/compiler/cpp/cpp_helpers.h

@@ -60,6 +60,8 @@ extern const char kThinSeparator[];
 string ClassName(const Descriptor* descriptor, bool qualified);
 string ClassName(const EnumDescriptor* enum_descriptor, bool qualified);
 
+string SuperClassName(const Descriptor* descriptor);
+
 // Get the (unqualified) name that should be used for this field in C++ code.
 // The name is coerced to lower-case to emulate proto1 behavior.  People
 // should be using lowercase-with-underscores style for proto field names
@@ -77,6 +79,10 @@ inline const Descriptor* FieldScope(const FieldDescriptor* field) {
     field->extension_scope() : field->containing_type();
 }
 
+// Returns the fully-qualified type name of field->message_type().  Usually
+// this is just ClassName(field->message_type(), true).
+string FieldMessageTypeName(const FieldDescriptor* field);
+
 // Strips ".proto" or ".protodevel" from the end of a filename.
 string StripProto(const string& filename);
 
@@ -107,33 +113,41 @@ string GlobalAssignDescriptorsName(const string& filename);
 string GlobalShutdownFileName(const string& filename);
 
 // Do message classes in this file keep track of unknown fields?
-inline const bool HasUnknownFields(const FileDescriptor *file) {
+inline bool HasUnknownFields(const FileDescriptor *file) {
   return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
 }
 
 // Does this file have generated parsing, serialization, and other
 // standard methods for which reflection-based fallback implementations exist?
-inline const bool HasGeneratedMethods(const FileDescriptor *file) {
+inline bool HasGeneratedMethods(const FileDescriptor *file) {
   return file->options().optimize_for() != FileOptions::CODE_SIZE;
 }
 
 // Do message classes in this file have descriptor and reflection methods?
-inline const bool HasDescriptorMethods(const FileDescriptor *file) {
+inline bool HasDescriptorMethods(const FileDescriptor *file) {
   return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
 }
 
+// Should we generate generic services for this file?
+inline bool HasGenericServices(const FileDescriptor *file) {
+  return file->service_count() > 0 &&
+         file->options().optimize_for() != FileOptions::LITE_RUNTIME &&
+         file->options().cc_generic_services();
+}
+
 // Should string fields in this file verify that their contents are UTF-8?
-inline const bool HasUtf8Verification(const FileDescriptor* file) {
+inline bool HasUtf8Verification(const FileDescriptor* file) {
   return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
 }
 
 // Should we generate a separate, super-optimized code path for serializing to
 // flat arrays?  We don't do this in Lite mode because we'd rather reduce code
 // size.
-inline const bool HasFastArraySerialization(const FileDescriptor* file) {
+inline bool HasFastArraySerialization(const FileDescriptor* file) {
   return file->options().optimize_for() == FileOptions::SPEED;
 }
 
+
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf

+ 72 - 27
src/google/protobuf/compiler/cpp/cpp_message.cc

@@ -308,11 +308,10 @@ GenerateClassDefinition(io::Printer* printer) {
   } else {
     vars["dllexport"] = dllexport_decl_ + " ";
   }
-  vars["superclass"] = HasDescriptorMethods(descriptor_->file()) ?
-                       "Message" : "MessageLite";
+  vars["superclass"] = SuperClassName(descriptor_);
 
   printer->Print(vars,
-    "class $dllexport$$classname$ : public ::google::protobuf::$superclass$ {\n"
+    "class $dllexport$$classname$ : public $superclass$ {\n"
     " public:\n");
   printer->Indent();
 
@@ -349,6 +348,10 @@ GenerateClassDefinition(io::Printer* printer) {
 
   printer->Print(vars,
     "static const $classname$& default_instance();\n"
+    "\n");
+
+
+  printer->Print(vars,
     "void Swap($classname$* other);\n"
     "\n"
     "// implements Message ----------------------------------------------\n"
@@ -387,7 +390,7 @@ GenerateClassDefinition(io::Printer* printer) {
     "private:\n"
     "void SharedCtor();\n"
     "void SharedDtor();\n"
-    "void SetCachedSize(int size) const { _cached_size_ = size; }\n"
+    "void SetCachedSize(int size) const;\n"
     "public:\n"
     "\n");
 
@@ -436,6 +439,11 @@ GenerateClassDefinition(io::Printer* printer) {
     extension_generators_[i]->GenerateDeclaration(printer);
   }
 
+
+  printer->Print(
+    "// @@protoc_insertion_point(class_scope:$full_name$)\n",
+    "full_name", descriptor_->full_name());
+
   // Generate private members for fields.
   printer->Outdent();
   printer->Print(" private:\n");
@@ -623,6 +631,7 @@ GenerateDefaultInstanceAllocator(io::Printer* printer) {
   for (int i = 0; i < descriptor_->nested_type_count(); i++) {
     nested_generators_[i]->GenerateDefaultInstanceAllocator(printer);
   }
+
 }
 
 void MessageGenerator::
@@ -751,6 +760,7 @@ GenerateClassMethods(io::Printer* printer) {
       "classname", classname_,
       "type_name", descriptor_->full_name());
   }
+
 }
 
 void MessageGenerator::
@@ -833,9 +843,8 @@ GenerateSharedDestructorCode(io::Printer* printer) {
 
 void MessageGenerator::
 GenerateStructors(io::Printer* printer) {
-  string superclass = HasDescriptorMethods(descriptor_->file()) ?
-      "Message" : "MessageLite";
-  
+  string superclass = SuperClassName(descriptor_);
+
   // Generate the default constructor.
   printer->Print(
     "$classname$::$classname$()\n"
@@ -864,7 +873,7 @@ GenerateStructors(io::Printer* printer) {
       printer->Print(
           "  $name$_ = const_cast< $type$*>(&$type$::default_instance());\n",
           "name", FieldName(field),
-          "type", ClassName(field->message_type(), true));
+          "type", FieldMessageTypeName(field));
     }
   }
   printer->Print(
@@ -896,6 +905,15 @@ GenerateStructors(io::Printer* printer) {
   // Generate the shared destructor code.
   GenerateSharedDestructorCode(printer);
 
+  // Generate SetCachedSize.
+  printer->Print(
+    "void $classname$::SetCachedSize(int size) const {\n"
+    "  GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
+    "  _cached_size_ = size;\n"
+    "  GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
+    "}\n",
+    "classname", classname_);
+
   // Only generate this member if it's not disabled.
   if (HasDescriptorMethods(descriptor_->file()) &&
       !descriptor_->options().no_standard_descriptor_accessor()) {
@@ -924,6 +942,7 @@ GenerateStructors(io::Printer* printer) {
     "classname", classname_,
     "adddescriptorsname",
     GlobalAddDescriptorsName(descriptor_->file()->name()));
+
 }
 
 void MessageGenerator::
@@ -1237,12 +1256,15 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
       PrintFieldComment(printer, field);
 
       printer->Print(
-        "case $number$: {\n"
-        "  if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) !=\n"
-        "      ::google::protobuf::internal::WireFormatLite::WIRETYPE_$wiretype$) {\n"
-        "    goto handle_uninterpreted;\n"
-        "  }\n",
-        "number", SimpleItoa(field->number()),
+        "case $number$: {\n",
+        "number", SimpleItoa(field->number()));
+      printer->Indent();
+      const FieldGenerator& field_generator = field_generators_.get(field);
+
+      // Emit code to parse the common, expected case.
+      printer->Print(
+        "if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) ==\n"
+        "    ::google::protobuf::internal::WireFormatLite::WIRETYPE_$wiretype$) {\n",
         "wiretype", kWireTypeNames[WireFormat::WireTypeForField(field)]);
 
       if (i > 0 || (field->is_repeated() && !field->options().packed())) {
@@ -1252,8 +1274,38 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
       }
 
       printer->Indent();
+      if (field->options().packed()) {
+        field_generator.GenerateMergeFromCodedStreamWithPacking(printer);
+      } else {
+        field_generator.GenerateMergeFromCodedStream(printer);
+      }
+      printer->Outdent();
 
-      field_generators_.get(field).GenerateMergeFromCodedStream(printer);
+      // Emit code to parse unexpectedly packed or unpacked values.
+      if (field->is_packable() && field->options().packed()) {
+        printer->Print(
+          "} else if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag)\n"
+          "           == ::google::protobuf::internal::WireFormatLite::\n"
+          "              WIRETYPE_$wiretype$) {\n",
+          "wiretype",
+          kWireTypeNames[WireFormat::WireTypeForFieldType(field->type())]);
+        printer->Indent();
+        field_generator.GenerateMergeFromCodedStream(printer);
+        printer->Outdent();
+      } else if (field->is_packable() && !field->options().packed()) {
+        printer->Print(
+          "} else if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag)\n"
+          "           == ::google::protobuf::internal::WireFormatLite::\n"
+          "              WIRETYPE_LENGTH_DELIMITED) {\n");
+        printer->Indent();
+        field_generator.GenerateMergeFromCodedStreamWithPacking(printer);
+        printer->Outdent();
+      }
+
+      printer->Print(
+        "} else {\n"
+        "  goto handle_uninterpreted;\n"
+        "}\n");
 
       // switch() is slow since it can't be predicted well.  Insert some if()s
       // here that attempt to predict the next tag.
@@ -1434,18 +1486,6 @@ GenerateSerializeWithCachedSizes(io::Printer* printer) {
     "classname", classname_);
   printer->Indent();
 
-  if (HasFastArraySerialization(descriptor_->file())) {
-    printer->Print(
-      "::google::protobuf::uint8* raw_buffer = "
-        "output->GetDirectBufferForNBytesAndAdvance(_cached_size_);\n"
-      "if (raw_buffer != NULL) {\n"
-      "  $classname$::SerializeWithCachedSizesToArray(raw_buffer);\n"
-      "  return;\n"
-      "}\n"
-      "\n",
-      "classname", classname_);
-  }
-
   GenerateSerializeWithCachedSizesBody(printer, false);
 
   printer->Outdent();
@@ -1555,7 +1595,9 @@ GenerateByteSize(io::Printer* printer) {
         "      ComputeUnknownMessageSetItemsSize(unknown_fields());\n");
     }
     printer->Print(
+      "  GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
       "  _cached_size_ = total_size;\n"
+      "  GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
       "  return total_size;\n"
       "}\n");
     return;
@@ -1647,7 +1689,9 @@ GenerateByteSize(io::Printer* printer) {
   // exact same value, it works on all common processors.  In a future version
   // of C++, _cached_size_ should be made into an atomic<int>.
   printer->Print(
+    "GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
     "_cached_size_ = total_size;\n"
+    "GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
     "return total_size;\n");
 
   printer->Outdent();
@@ -1719,6 +1763,7 @@ GenerateIsInitialized(io::Printer* printer) {
     "}\n");
 }
 
+
 }  // namespace cpp
 }  // namespace compiler
 }  // namespace protobuf

+ 1 - 0
src/google/protobuf/compiler/cpp/cpp_message.h

@@ -150,6 +150,7 @@ class MessageGenerator {
       io::Printer* printer, const Descriptor::ExtensionRange* range,
       bool unbounded);
 
+
   const Descriptor* descriptor_;
   string classname_;
   string dllexport_decl_;

+ 21 - 15
src/google/protobuf/compiler/cpp/cpp_message_field.cc

@@ -47,7 +47,11 @@ namespace {
 void SetMessageVariables(const FieldDescriptor* descriptor,
                          map<string, string>* variables) {
   SetCommonFieldVariables(descriptor, variables);
-  (*variables)["type"] = ClassName(descriptor->message_type(), true);
+  (*variables)["type"] = FieldMessageTypeName(descriptor);
+  (*variables)["stream_writer"] = (*variables)["declared_type"] +
+      (HasFastArraySerialization(descriptor->message_type()->file()) ?
+       "MaybeToArray" :
+       "");
 }
 
 }  // namespace
@@ -125,7 +129,7 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
 void MessageFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
-    "::google::protobuf::internal::WireFormatLite::Write$declared_type$NoVirtual(\n"
+    "::google::protobuf::internal::WireFormatLite::Write$stream_writer$(\n"
     "  $number$, this->$name$(), output);\n");
 }
 
@@ -164,26 +168,19 @@ GeneratePrivateMembers(io::Printer* printer) const {
 void RepeatedMessageFieldGenerator::
 GenerateAccessorDeclarations(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedPtrField< $type$ >& $name$() const"
-                 "$deprecation$;\n"
-    "inline ::google::protobuf::RepeatedPtrField< $type$ >* mutable_$name$()"
-                 "$deprecation$;\n"
     "inline const $type$& $name$(int index) const$deprecation$;\n"
     "inline $type$* mutable_$name$(int index)$deprecation$;\n"
     "inline $type$* add_$name$()$deprecation$;\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
+    "    $name$() const$deprecation$;\n"
+    "inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
+    "    mutable_$name$()$deprecation$;\n");
 }
 
 void RepeatedMessageFieldGenerator::
 GenerateInlineAccessorDefinitions(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
-    "$classname$::$name$() const {\n"
-    "  return $name$_;\n"
-    "}\n"
-    "inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
-    "$classname$::mutable_$name$() {\n"
-    "  return &$name$_;\n"
-    "}\n"
     "inline const $type$& $classname$::$name$(int index) const {\n"
     "  return $name$_.Get(index);\n"
     "}\n"
@@ -193,6 +190,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "inline $type$* $classname$::add_$name$() {\n"
     "  return $name$_.Add();\n"
     "}\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
+    "$classname$::$name$() const {\n"
+    "  return $name$_;\n"
+    "}\n"
+    "inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
+    "$classname$::mutable_$name$() {\n"
+    "  return &$name$_;\n"
+    "}\n");
 }
 
 void RepeatedMessageFieldGenerator::
@@ -232,7 +238,7 @@ void RepeatedMessageFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
     "for (int i = 0; i < this->$name$_size(); i++) {\n"
-    "  ::google::protobuf::internal::WireFormatLite::Write$declared_type$NoVirtual(\n"
+    "  ::google::protobuf::internal::WireFormatLite::Write$stream_writer$(\n"
     "    $number$, this->$name$(i), output);\n"
     "}\n");
 }

+ 41 - 37
src/google/protobuf/compiler/cpp/cpp_primitive_field.cc

@@ -84,10 +84,14 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
   SetCommonFieldVariables(descriptor, variables);
   (*variables)["type"] = PrimitiveTypeName(descriptor->cpp_type());
   (*variables)["default"] = DefaultValue(descriptor);
+  (*variables)["tag"] = SimpleItoa(internal::WireFormat::MakeTag(descriptor));
   int fixed_size = FixedSize(descriptor->type());
   if (fixed_size != -1) {
     (*variables)["fixed_size"] = SimpleItoa(fixed_size);
   }
+  (*variables)["wire_format_field_type"] =
+      "::google::protobuf::internal::WireFormatLite::" + FieldDescriptorProto_Type_Name(
+          static_cast<FieldDescriptorProto_Type>(descriptor->type()));
 }
 
 }  // namespace
@@ -149,8 +153,9 @@ GenerateConstructorCode(io::Printer* printer) const {
 void PrimitiveFieldGenerator::
 GenerateMergeFromCodedStream(io::Printer* printer) const {
   printer->Print(variables_,
-    "DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n"
-    "      input, &$name$_));\n"
+    "DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
+    "         $type$, $wire_format_field_type$>(\n"
+    "       input, &$name$_)));\n"
     "_set_bit($index$);\n");
 }
 
@@ -188,6 +193,14 @@ RepeatedPrimitiveFieldGenerator::
 RepeatedPrimitiveFieldGenerator(const FieldDescriptor* descriptor)
   : descriptor_(descriptor) {
   SetPrimitiveVariables(descriptor, &variables_);
+
+  if (descriptor->options().packed()) {
+    variables_["packed_reader"] = "ReadPackedPrimitive";
+    variables_["repeated_reader"] = "ReadRepeatedPrimitiveNoInline";
+  } else {
+    variables_["packed_reader"] = "ReadPackedPrimitiveNoInline";
+    variables_["repeated_reader"] = "ReadRepeatedPrimitive";
+  }
 }
 
 RepeatedPrimitiveFieldGenerator::~RepeatedPrimitiveFieldGenerator() {}
@@ -205,25 +218,19 @@ GeneratePrivateMembers(io::Printer* printer) const {
 void RepeatedPrimitiveFieldGenerator::
 GenerateAccessorDeclarations(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedField< $type$ >& $name$() const\n"
-    "    $deprecation$;\n"
-    "inline ::google::protobuf::RepeatedField< $type$ >* mutable_$name$()$deprecation$;\n"
     "inline $type$ $name$(int index) const$deprecation$;\n"
     "inline void set_$name$(int index, $type$ value)$deprecation$;\n"
     "inline void add_$name$($type$ value)$deprecation$;\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedField< $type$ >&\n"
+    "    $name$() const$deprecation$;\n"
+    "inline ::google::protobuf::RepeatedField< $type$ >*\n"
+    "    mutable_$name$()$deprecation$;\n");
 }
 
 void RepeatedPrimitiveFieldGenerator::
 GenerateInlineAccessorDefinitions(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedField< $type$ >&\n"
-    "$classname$::$name$() const {\n"
-    "  return $name$_;\n"
-    "}\n"
-    "inline ::google::protobuf::RepeatedField< $type$ >*\n"
-    "$classname$::mutable_$name$() {\n"
-    "  return &$name$_;\n"
-    "}\n"
     "inline $type$ $classname$::$name$(int index) const {\n"
     "  return $name$_.Get(index);\n"
     "}\n"
@@ -233,6 +240,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "inline void $classname$::add_$name$($type$ value) {\n"
     "  $name$_.Add(value);\n"
     "}\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedField< $type$ >&\n"
+    "$classname$::$name$() const {\n"
+    "  return $name$_;\n"
+    "}\n"
+    "inline ::google::protobuf::RepeatedField< $type$ >*\n"
+    "$classname$::mutable_$name$() {\n"
+    "  return &$name$_;\n"
+    "}\n");
 }
 
 void RepeatedPrimitiveFieldGenerator::
@@ -257,30 +273,18 @@ GenerateConstructorCode(io::Printer* printer) const {
 
 void RepeatedPrimitiveFieldGenerator::
 GenerateMergeFromCodedStream(io::Printer* printer) const {
-  if (descriptor_->options().packed()) {
-    printer->Print("{\n");
-    printer->Indent();
-    printer->Print(variables_,
-      "::google::protobuf::uint32 length;\n"
-      "DO_(input->ReadVarint32(&length));\n"
-      "::google::protobuf::io::CodedInputStream::Limit limit =\n"
-      "    input->PushLimit(length);\n"
-      "while (input->BytesUntilLimit() > 0) {\n"
-      "  $type$ value;\n"
-      "  DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n"
-      "        input, &value));\n"
-      "  add_$name$(value);\n"
-      "}\n"
-      "input->PopLimit(limit);\n");
-    printer->Outdent();
-    printer->Print("}\n");
-  } else {
-    printer->Print(variables_,
-      "$type$ value;\n"
-      "DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n"
-      "      input, &value));\n"
-      "add_$name$(value);\n");
-  }
+  printer->Print(variables_,
+    "DO_((::google::protobuf::internal::WireFormatLite::$repeated_reader$<\n"
+    "         $type$, $wire_format_field_type$>(\n"
+    "       $tag_size$, $tag$, input, this->mutable_$name$())));\n");
+}
+
+void RepeatedPrimitiveFieldGenerator::
+GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
+  printer->Print(variables_,
+    "DO_((::google::protobuf::internal::WireFormatLite::$packed_reader$<\n"
+    "         $type$, $wire_format_field_type$>(\n"
+    "       input, this->mutable_$name$())));\n");
 }
 
 void RepeatedPrimitiveFieldGenerator::

+ 1 - 0
src/google/protobuf/compiler/cpp/cpp_primitive_field.h

@@ -83,6 +83,7 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
   void GenerateSwappingCode(io::Printer* printer) const;
   void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
+  void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;

+ 19 - 16
src/google/protobuf/compiler/cpp/cpp_string_field.cc

@@ -91,7 +91,7 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
   // files that applied the ctype.  The field can still be accessed via the
   // reflection interface since the reflection interface is independent of
   // the string's underlying representation.
-  if (descriptor_->options().has_ctype()) {
+  if (descriptor_->options().ctype() != FieldOptions::STRING) {
     printer->Outdent();
     printer->Print(
       " private:\n"
@@ -107,7 +107,7 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
                  "$deprecation$;\n"
     "inline ::std::string* mutable_$name$()$deprecation$;\n");
 
-  if (descriptor_->options().has_ctype()) {
+  if (descriptor_->options().ctype() != FieldOptions::STRING) {
     printer->Outdent();
     printer->Print(" public:\n");
     printer->Indent();
@@ -278,7 +278,7 @@ GeneratePrivateMembers(io::Printer* printer) const {
 void RepeatedStringFieldGenerator::
 GenerateAccessorDeclarations(io::Printer* printer) const {
   // See comment above about unknown ctypes.
-  if (descriptor_->options().has_ctype()) {
+  if (descriptor_->options().ctype() != FieldOptions::STRING) {
     printer->Outdent();
     printer->Print(
       " private:\n"
@@ -287,10 +287,6 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
   }
 
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedPtrField< ::std::string>& $name$() const"
-                 "$deprecation$;\n"
-    "inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_$name$()"
-                 "$deprecation$;\n"
     "inline const ::std::string& $name$(int index) const$deprecation$;\n"
     "inline ::std::string* mutable_$name$(int index)$deprecation$;\n"
     "inline void set_$name$(int index, const ::std::string& value)$deprecation$;\n"
@@ -304,7 +300,13 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
     "inline void add_$name$(const $pointer_type$* value, size_t size)"
                  "$deprecation$;\n");
 
-  if (descriptor_->options().has_ctype()) {
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedPtrField< ::std::string>& $name$() const"
+                 "$deprecation$;\n"
+    "inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_$name$()"
+                 "$deprecation$;\n");
+
+  if (descriptor_->options().ctype() != FieldOptions::STRING) {
     printer->Outdent();
     printer->Print(" public:\n");
     printer->Indent();
@@ -314,14 +316,6 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
 void RepeatedStringFieldGenerator::
 GenerateInlineAccessorDefinitions(io::Printer* printer) const {
   printer->Print(variables_,
-    "inline const ::google::protobuf::RepeatedPtrField< ::std::string>&\n"
-    "$classname$::$name$() const {\n"
-    "  return $name$_;\n"
-    "}\n"
-    "inline ::google::protobuf::RepeatedPtrField< ::std::string>*\n"
-    "$classname$::mutable_$name$() {\n"
-    "  return &$name$_;\n"
-    "}\n"
     "inline const ::std::string& $classname$::$name$(int index) const {\n"
     "  return $name$_.Get(index);\n"
     "}\n"
@@ -353,6 +347,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "$classname$::add_$name$(const $pointer_type$* value, size_t size) {\n"
     "  $name$_.Add()->assign(reinterpret_cast<const char*>(value), size);\n"
     "}\n");
+  printer->Print(variables_,
+    "inline const ::google::protobuf::RepeatedPtrField< ::std::string>&\n"
+    "$classname$::$name$() const {\n"
+    "  return $name$_;\n"
+    "}\n"
+    "inline ::google::protobuf::RepeatedPtrField< ::std::string>*\n"
+    "$classname$::mutable_$name$() {\n"
+    "  return &$name$_;\n"
+    "}\n");
 }
 
 void RepeatedStringFieldGenerator::

+ 69 - 4
src/google/protobuf/compiler/cpp/cpp_unittest.cc

@@ -49,6 +49,7 @@
 #include <google/protobuf/unittest.pb.h>
 #include <google/protobuf/unittest_optimize_for.pb.h>
 #include <google/protobuf/unittest_embed_optimize_for.pb.h>
+#include <google/protobuf/unittest_no_generic_services.pb.h>
 #include <google/protobuf/test_util.h>
 #include <google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h>
 #include <google/protobuf/compiler/importer.h>
@@ -154,6 +155,16 @@ TEST(GeneratedMessageTest, FloatingPointDefaults) {
   EXPECT_EQ(-1.5f, extreme_default.negative_float());
   EXPECT_EQ(2.0e8f, extreme_default.large_float());
   EXPECT_EQ(-8e-28f, extreme_default.small_negative_float());
+  EXPECT_EQ(numeric_limits<double>::infinity(),
+            extreme_default.inf_double());
+  EXPECT_EQ(-numeric_limits<double>::infinity(),
+            extreme_default.neg_inf_double());
+  EXPECT_TRUE(extreme_default.nan_double() != extreme_default.nan_double());
+  EXPECT_EQ(numeric_limits<float>::infinity(),
+            extreme_default.inf_float());
+  EXPECT_EQ(-numeric_limits<float>::infinity(),
+            extreme_default.neg_inf_float());
+  EXPECT_TRUE(extreme_default.nan_float() != extreme_default.nan_float());
 }
 
 TEST(GeneratedMessageTest, Accessors) {
@@ -779,22 +790,39 @@ TEST(GeneratedEnumTest, IsValidValue) {
 }
 
 TEST(GeneratedEnumTest, MinAndMax) {
-  EXPECT_EQ(unittest::TestAllTypes::FOO,unittest::TestAllTypes::NestedEnum_MIN);
-  EXPECT_EQ(unittest::TestAllTypes::BAZ,unittest::TestAllTypes::NestedEnum_MAX);
+  EXPECT_EQ(unittest::TestAllTypes::FOO,
+            unittest::TestAllTypes::NestedEnum_MIN);
+  EXPECT_EQ(unittest::TestAllTypes::BAZ,
+            unittest::TestAllTypes::NestedEnum_MAX);
+  EXPECT_EQ(4, unittest::TestAllTypes::NestedEnum_ARRAYSIZE);
 
   EXPECT_EQ(unittest::FOREIGN_FOO, unittest::ForeignEnum_MIN);
   EXPECT_EQ(unittest::FOREIGN_BAZ, unittest::ForeignEnum_MAX);
+  EXPECT_EQ(7, unittest::ForeignEnum_ARRAYSIZE);
 
   EXPECT_EQ(1, unittest::TestEnumWithDupValue_MIN);
   EXPECT_EQ(3, unittest::TestEnumWithDupValue_MAX);
+  EXPECT_EQ(4, unittest::TestEnumWithDupValue_ARRAYSIZE);
 
   EXPECT_EQ(unittest::SPARSE_E, unittest::TestSparseEnum_MIN);
   EXPECT_EQ(unittest::SPARSE_C, unittest::TestSparseEnum_MAX);
+  EXPECT_EQ(12589235, unittest::TestSparseEnum_ARRAYSIZE);
 
-  // Make sure we can use _MIN and _MAX as switch cases.
-  switch(unittest::SPARSE_A) {
+  // Make sure we can take the address of _MIN, _MAX and _ARRAYSIZE.
+  void* nullptr = 0;  // NULL may be integer-type, not pointer-type.
+  EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_MIN);
+  EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_MAX);
+  EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_ARRAYSIZE);
+
+  EXPECT_NE(nullptr, &unittest::ForeignEnum_MIN);
+  EXPECT_NE(nullptr, &unittest::ForeignEnum_MAX);
+  EXPECT_NE(nullptr, &unittest::ForeignEnum_ARRAYSIZE);
+
+  // Make sure we can use _MIN, _MAX and _ARRAYSIZE as switch cases.
+  switch (unittest::SPARSE_A) {
     case unittest::TestSparseEnum_MIN:
     case unittest::TestSparseEnum_MAX:
+    case unittest::TestSparseEnum_ARRAYSIZE:
       break;
     default:
       break;
@@ -1136,6 +1164,43 @@ TEST_F(GeneratedServiceTest, NotImplemented) {
   EXPECT_TRUE(controller.called_);
 }
 
+}  // namespace cpp_unittest
+}  // namespace cpp
+}  // namespace compiler
+
+namespace no_generic_services_test {
+  // Verify that no class called "TestService" was defined in
+  // unittest_no_generic_services.pb.h by defining a different type by the same
+  // name.  If such a service was generated, this will not compile.
+  struct TestService {
+    int i;
+  };
+}
+
+namespace compiler {
+namespace cpp {
+namespace cpp_unittest {
+
+TEST_F(GeneratedServiceTest, NoGenericServices) {
+  // Verify that non-services in unittest_no_generic_services.proto were
+  // generated.
+  no_generic_services_test::TestMessage message;
+  message.set_a(1);
+  message.SetExtension(no_generic_services_test::test_extension, 123);
+  no_generic_services_test::TestEnum e = no_generic_services_test::FOO;
+  EXPECT_EQ(e, 1);
+
+  // Verify that a ServiceDescriptor is generated for the service even if the
+  // class itself is not.
+  const FileDescriptor* file =
+      no_generic_services_test::TestMessage::descriptor()->file();
+
+  ASSERT_EQ(1, file->service_count());
+  EXPECT_EQ("TestService", file->service(0)->name());
+  ASSERT_EQ(1, file->service(0)->method_count());
+  EXPECT_EQ("Foo", file->service(0)->method(0)->name());
+}
+
 #endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
 
 // ===================================================================

+ 5 - 0
src/google/protobuf/compiler/java/java_enum.cc

@@ -223,6 +223,11 @@ void EnumGenerator::Generate(io::Printer* printer) {
       "file", ClassName(descriptor_->file()));
   }
 
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(enum_scope:$full_name$)\n",
+    "full_name", descriptor_->full_name());
+
   printer->Outdent();
   printer->Print("}\n\n");
 }

+ 29 - 17
src/google/protobuf/compiler/java/java_enum_field.cc

@@ -62,7 +62,7 @@ void SetEnumVariables(const FieldDescriptor* descriptor,
   (*variables)["default"] = DefaultValue(descriptor);
   (*variables)["tag"] = SimpleItoa(internal::WireFormat::MakeTag(descriptor));
   (*variables)["tag_size"] = SimpleItoa(
-      internal::WireFormat::TagSize(descriptor->number(), descriptor->type()));
+      internal::WireFormat::TagSize(descriptor->number(), GetType(descriptor)));
 }
 
 }  // namespace
@@ -81,7 +81,7 @@ void EnumFieldGenerator::
 GenerateMembers(io::Printer* printer) const {
   printer->Print(variables_,
     "private boolean has$capitalized_name$;\n"
-    "private $type$ $name$_ = $default$;\n"
+    "private $type$ $name$_;\n"
     "public boolean has$capitalized_name$() { return has$capitalized_name$; }\n"
     "public $type$ get$capitalized_name$() { return $name$_; }\n");
 }
@@ -110,6 +110,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void EnumFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  printer->Print(variables_, "$name$_ = $default$;\n");
+}
+
 void EnumFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -240,6 +245,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void RepeatedEnumFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  // Initialized inline.
+}
+
 void RepeatedEnumFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -262,15 +272,6 @@ GenerateBuildingCode(io::Printer* printer) const {
 
 void RepeatedEnumFieldGenerator::
 GenerateParsingCode(io::Printer* printer) const {
-  // If packed, set up the while loop
-  if (descriptor_->options().packed()) {
-    printer->Print(variables_,
-      "int length = input.readRawVarint32();\n"
-      "int oldLimit = input.pushLimit(length);\n"
-      "while(input.getBytesUntilLimit() > 0) {\n");
-    printer->Indent();
-  }
-
   // Read and store the enum
   printer->Print(variables_,
     "int rawValue = input.readEnum();\n"
@@ -287,13 +288,24 @@ GenerateParsingCode(io::Printer* printer) const {
   printer->Print(variables_,
     "  add$capitalized_name$(value);\n"
     "}\n");
+}
 
-  if (descriptor_->options().packed()) {
-    printer->Outdent();
-    printer->Print(variables_,
-      "}\n"
-      "input.popLimit(oldLimit);\n");
-  }
+void RepeatedEnumFieldGenerator::
+GenerateParsingCodeFromPacked(io::Printer* printer) const {
+  // Wrap GenerateParsingCode's contents with a while loop.
+
+  printer->Print(variables_,
+    "int length = input.readRawVarint32();\n"
+    "int oldLimit = input.pushLimit(length);\n"
+    "while(input.getBytesUntilLimit() > 0) {\n");
+  printer->Indent();
+
+  GenerateParsingCode(printer);
+
+  printer->Outdent();
+  printer->Print(variables_,
+    "}\n"
+    "input.popLimit(oldLimit);\n");
 }
 
 void RepeatedEnumFieldGenerator::

+ 3 - 0
src/google/protobuf/compiler/java/java_enum_field.h

@@ -52,6 +52,7 @@ class EnumFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;
@@ -75,9 +76,11 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;
+  void GenerateParsingCodeFromPacked(io::Printer* printer) const;
   void GenerateSerializationCode(io::Printer* printer) const;
   void GenerateSerializedSizeCode(io::Printer* printer) const;
 

+ 1 - 2
src/google/protobuf/compiler/java/java_extension.cc

@@ -133,7 +133,7 @@ void ExtensionGenerator::GenerateInitializationCode(io::Printer* printer) {
   vars["extendee"] = ClassName(descriptor_->containing_type());
   vars["default"] = descriptor_->is_repeated() ? "" : DefaultValue(descriptor_);
   vars["number"] = SimpleItoa(descriptor_->number());
-  vars["type_constant"] = TypeName(descriptor_->type());
+  vars["type_constant"] = TypeName(GetType(descriptor_));
   vars["packed"] = descriptor_->options().packed() ? "true" : "false";
   vars["enum_map"] = "null";
   vars["prototype"] = "null";
@@ -208,5 +208,4 @@ void ExtensionGenerator::GenerateRegistrationCode(io::Printer* printer) {
 }  // namespace java
 }  // namespace compiler
 }  // namespace protobuf
-
 }  // namespace google

+ 10 - 0
src/google/protobuf/compiler/java/java_field.cc

@@ -46,6 +46,16 @@ namespace java {
 
 FieldGenerator::~FieldGenerator() {}
 
+void FieldGenerator::GenerateParsingCodeFromPacked(io::Printer* printer) const {
+  // Reaching here indicates a bug. Cases are:
+  //   - This FieldGenerator should support packing, but this method should be
+  //     overridden.
+  //   - This FieldGenerator doesn't support packing, and this method should
+  //     never have been called.
+  GOOGLE_LOG(FATAL) << "GenerateParsingCodeFromPacked() "
+             << "called on field generator that does not support packing.";
+}
+
 FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor)
   : descriptor_(descriptor),
     field_generators_(

+ 2 - 0
src/google/protobuf/compiler/java/java_field.h

@@ -57,9 +57,11 @@ class FieldGenerator {
 
   virtual void GenerateMembers(io::Printer* printer) const = 0;
   virtual void GenerateBuilderMembers(io::Printer* printer) const = 0;
+  virtual void GenerateInitializationCode(io::Printer* printer) const = 0;
   virtual void GenerateMergingCode(io::Printer* printer) const = 0;
   virtual void GenerateBuildingCode(io::Printer* printer) const = 0;
   virtual void GenerateParsingCode(io::Printer* printer) const = 0;
+  virtual void GenerateParsingCodeFromPacked(io::Printer* printer) const;
   virtual void GenerateSerializationCode(io::Printer* printer) const = 0;
   virtual void GenerateSerializedSizeCode(io::Printer* printer) const = 0;
 

+ 29 - 11
src/google/protobuf/compiler/java/java_file.cc

@@ -64,7 +64,7 @@ bool UsesExtensions(const Message& message) {
   for (int i = 0; i < fields.size(); i++) {
     if (fields[i]->is_extension()) return true;
 
-    if (fields[i]->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+    if (GetJavaType(fields[i]) == JAVATYPE_MESSAGE) {
       if (fields[i]->is_repeated()) {
         int size = reflection->FieldSize(message, fields[i]);
         for (int j = 0; j < size; j++) {
@@ -82,6 +82,7 @@ bool UsesExtensions(const Message& message) {
   return false;
 }
 
+
 }  // namespace
 
 FileGenerator::FileGenerator(const FileDescriptor* file)
@@ -134,7 +135,9 @@ void FileGenerator::Generate(io::Printer* printer) {
   // fully-qualified names in the generated source.
   printer->Print(
     "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n"
-    "\n");
+    "// source: $filename$\n"
+    "\n",
+    "filename", file_->name());
   if (!java_package_.empty()) {
     printer->Print(
       "package $package$;\n"
@@ -178,8 +181,10 @@ void FileGenerator::Generate(io::Printer* printer) {
     for (int i = 0; i < file_->message_type_count(); i++) {
       MessageGenerator(file_->message_type(i)).Generate(printer);
     }
-    for (int i = 0; i < file_->service_count(); i++) {
-      ServiceGenerator(file_->service(i)).Generate(printer);
+    if (HasGenericServices(file_)) {
+      for (int i = 0; i < file_->service_count(); i++) {
+        ServiceGenerator(file_->service(i)).Generate(printer);
+      }
     }
   }
 
@@ -228,6 +233,10 @@ void FileGenerator::Generate(io::Printer* printer) {
     "\n"
     "public static void internalForceInit() {}\n");
 
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(outer_class_scope)\n");
+
   printer->Outdent();
   printer->Print("}\n");
 }
@@ -245,6 +254,7 @@ void FileGenerator::GenerateEmbeddedDescriptor(io::Printer* printer) {
   // embedded raw, which is what we want.
   FileDescriptorProto file_proto;
   file_->CopyTo(&file_proto);
+
   string file_data;
   file_proto.SerializeToString(&file_data);
 
@@ -343,9 +353,11 @@ void FileGenerator::GenerateEmbeddedDescriptor(io::Printer* printer) {
     "    new com.google.protobuf.Descriptors.FileDescriptor[] {\n");
 
   for (int i = 0; i < file_->dependency_count(); i++) {
-    printer->Print(
-      "      $dependency$.getDescriptor(),\n",
-      "dependency", ClassName(file_->dependency(i)));
+    if (ShouldIncludeDependency(file_->dependency(i))) {
+      printer->Print(
+        "      $dependency$.getDescriptor(),\n",
+        "dependency", ClassName(file_->dependency(i)));
+    }
   }
 
   printer->Print(
@@ -396,14 +408,20 @@ void FileGenerator::GenerateSiblings(const string& package_dir,
                                         file_->message_type(i),
                                         output_directory, file_list);
     }
-    for (int i = 0; i < file_->service_count(); i++) {
-      GenerateSibling<ServiceGenerator>(package_dir, java_package_,
-                                        file_->service(i),
-                                        output_directory, file_list);
+    if (HasGenericServices(file_)) {
+      for (int i = 0; i < file_->service_count(); i++) {
+        GenerateSibling<ServiceGenerator>(package_dir, java_package_,
+                                          file_->service(i),
+                                          output_directory, file_list);
+      }
     }
   }
 }
 
+bool FileGenerator::ShouldIncludeDependency(const FileDescriptor* descriptor) {
+  return true;
+}
+
 }  // namespace java
 }  // namespace compiler
 }  // namespace protobuf

+ 5 - 0
src/google/protobuf/compiler/java/java_file.h

@@ -77,6 +77,11 @@ class FileGenerator {
   const string& classname()    { return classname_;    }
 
  private:
+  // Returns whether the dependency should be included in the output file.
+  // Always returns true for opensource, but used internally at Google to help
+  // improve compatibility with version 1 of protocol buffers.
+  bool ShouldIncludeDependency(const FileDescriptor* descriptor);
+
   const FileDescriptor* file_;
   string java_package_;
   string classname_;

+ 1 - 0
src/google/protobuf/compiler/java/java_generator.cc

@@ -45,6 +45,7 @@ namespace protobuf {
 namespace compiler {
 namespace java {
 
+
 JavaGenerator::JavaGenerator() {}
 JavaGenerator::~JavaGenerator() {}
 

+ 34 - 9
src/google/protobuf/compiler/java/java_helpers.cc

@@ -32,6 +32,7 @@
 //  Based on original Protocol Buffers design by
 //  Sanjay Ghemawat, Jeff Dean, and others.
 
+#include <limits>
 #include <vector>
 
 #include <google/protobuf/compiler/java/java_helpers.h>
@@ -57,7 +58,7 @@ const string& FieldName(const FieldDescriptor* field) {
   // Groups are hacky:  The name of the field is just the lower-cased name
   // of the group type.  In Java, though, we would like to retain the original
   // capitalization of the type name.
-  if (field->type() == FieldDescriptor::TYPE_GROUP) {
+  if (GetType(field) == FieldDescriptor::TYPE_GROUP) {
     return field->message_type()->name();
   } else {
     return field->name();
@@ -178,8 +179,12 @@ string FieldConstantName(const FieldDescriptor *field) {
   return name;
 }
 
-JavaType GetJavaType(FieldDescriptor::Type field_type) {
-  switch (field_type) {
+FieldDescriptor::Type GetType(const FieldDescriptor* field) {
+  return field->type();
+}
+
+JavaType GetJavaType(const FieldDescriptor* field) {
+  switch (GetType(field)) {
     case FieldDescriptor::TYPE_INT32:
     case FieldDescriptor::TYPE_UINT32:
     case FieldDescriptor::TYPE_SINT32:
@@ -254,7 +259,7 @@ bool AllAscii(const string& text) {
 }
 
 string DefaultValue(const FieldDescriptor* field) {
-  // Switch on cpp_type since we need to know which default_value_* method
+  // Switch on CppType since we need to know which default_value_* method
   // of FieldDescriptor to call.
   switch (field->cpp_type()) {
     case FieldDescriptor::CPPTYPE_INT32:
@@ -267,14 +272,34 @@ string DefaultValue(const FieldDescriptor* field) {
     case FieldDescriptor::CPPTYPE_UINT64:
       return SimpleItoa(static_cast<int64>(field->default_value_uint64())) +
              "L";
-    case FieldDescriptor::CPPTYPE_DOUBLE:
-      return SimpleDtoa(field->default_value_double()) + "D";
-    case FieldDescriptor::CPPTYPE_FLOAT:
-      return SimpleFtoa(field->default_value_float()) + "F";
+    case FieldDescriptor::CPPTYPE_DOUBLE: {
+      double value = field->default_value_double();
+      if (value == numeric_limits<double>::infinity()) {
+        return "Double.POSITIVE_INFINITY";
+      } else if (value == -numeric_limits<double>::infinity()) {
+        return "Double.NEGATIVE_INFINITY";
+      } else if (value != value) {
+        return "Double.NaN";
+      } else {
+        return SimpleDtoa(value) + "D";
+      }
+    }
+    case FieldDescriptor::CPPTYPE_FLOAT: {
+      float value = field->default_value_float();
+      if (value == numeric_limits<float>::infinity()) {
+        return "Float.POSITIVE_INFINITY";
+      } else if (value == -numeric_limits<float>::infinity()) {
+        return "Float.NEGATIVE_INFINITY";
+      } else if (value != value) {
+        return "Float.NaN";
+      } else {
+        return SimpleFtoa(value) + "F";
+      }
+    }
     case FieldDescriptor::CPPTYPE_BOOL:
       return field->default_value_bool() ? "true" : "false";
     case FieldDescriptor::CPPTYPE_STRING:
-      if (field->type() == FieldDescriptor::TYPE_BYTES) {
+      if (GetType(field) == FieldDescriptor::TYPE_BYTES) {
         if (field->has_default_value()) {
           // See comments in Internal.java for gory details.
           return strings::Substitute(

+ 13 - 5
src/google/protobuf/compiler/java/java_helpers.h

@@ -93,6 +93,11 @@ string ClassName(const FileDescriptor* descriptor);
 // number constant.
 string FieldConstantName(const FieldDescriptor *field);
 
+// Returns the type of the FieldDescriptor.
+// This does nothing interesting for the open source release, but is used for
+// hacks that improve compatibility with version 1 protocol buffers at Google.
+FieldDescriptor::Type GetType(const FieldDescriptor* field);
+
 enum JavaType {
   JAVATYPE_INT,
   JAVATYPE_LONG,
@@ -105,11 +110,7 @@ enum JavaType {
   JAVATYPE_MESSAGE
 };
 
-JavaType GetJavaType(FieldDescriptor::Type field_type);
-
-inline JavaType GetJavaType(const FieldDescriptor* field) {
-  return GetJavaType(field->type());
-}
+JavaType GetJavaType(const FieldDescriptor* field);
 
 // Get the fully-qualified class name for a boxed primitive type, e.g.
 // "java.lang.Integer" for JAVATYPE_INT.  Returns NULL for enum and message
@@ -145,6 +146,13 @@ inline bool HasDescriptorMethods(const FileDescriptor* descriptor) {
            FileOptions::LITE_RUNTIME;
 }
 
+// Should we generate generic services for this file?
+inline bool HasGenericServices(const FileDescriptor *file) {
+  return file->service_count() > 0 &&
+         file->options().optimize_for() != FileOptions::LITE_RUNTIME &&
+         file->options().java_generic_services();
+}
+
 }  // namespace java
 }  // namespace compiler
 }  // namespace protobuf

+ 62 - 21
src/google/protobuf/compiler/java/java_message.cc

@@ -127,7 +127,7 @@ static bool HasRequiredFields(
     if (field->is_required()) {
       return true;
     }
-    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+    if (GetJavaType(field) == JAVATYPE_MESSAGE) {
       if (HasRequiredFields(field->message_type(), already_seen)) {
         return true;
       }
@@ -292,9 +292,14 @@ void MessageGenerator::Generate(io::Printer* printer) {
   printer->Indent();
   printer->Print(
     "// Use $classname$.newBuilder() to construct.\n"
-    "private $classname$() {}\n"
+    "private $classname$() {\n"
+    "  initFields();\n"
+    "}\n"
+    // Used when constructing the default instance, which cannot be initialized
+    // immediately because it may cyclically refer to other default instances.
+    "private $classname$(boolean noInit) {}\n"
     "\n"
-    "private static final $classname$ defaultInstance = new $classname$();\n"
+    "private static final $classname$ defaultInstance;\n"
     "public static $classname$ getDefaultInstance() {\n"
     "  return defaultInstance;\n"
     "}\n"
@@ -344,6 +349,17 @@ void MessageGenerator::Generate(io::Printer* printer) {
     printer->Print("\n");
   }
 
+  // Called by the constructor, except in the case of the default instance,
+  // in which case this is called by static init code later on.
+  printer->Print("private void initFields() {\n");
+  printer->Indent();
+  for (int i = 0; i < descriptor_->field_count(); i++) {
+    field_generators_.get(descriptor_->field(i))
+                     .GenerateInitializationCode(printer);
+  }
+  printer->Outdent();
+  printer->Print("}\n");
+
   if (HasGeneratedMethods(descriptor_)) {
     GenerateIsInitialized(printer);
     GenerateMessageSerializationMethods(printer);
@@ -352,25 +368,23 @@ void MessageGenerator::Generate(io::Printer* printer) {
   GenerateParseFromMethods(printer);
   GenerateBuilder(printer);
 
-  if (HasDescriptorMethods(descriptor_)) {
-    // Force the static initialization code for the file to run, since it may
-    // initialize static variables declared in this class.
-    printer->Print(
-      "\n"
-      "static {\n"
-      "  $file$.getDescriptor();\n"
-      "}\n",
-      "file", ClassName(descriptor_->file()));
-  }
-
   // Force initialization of outer class.  Otherwise, nested extensions may
-  // not be initialized.
+  // not be initialized.  Also carefully initialize the default instance in
+  // such a way that it doesn't conflict with other initialization.
   printer->Print(
     "\n"
     "static {\n"
+    "  defaultInstance = new $classname$(true);\n"
     "  $file$.internalForceInit();\n"
+    "  defaultInstance.initFields();\n"
     "}\n",
-    "file", ClassName(descriptor_->file()));
+    "file", ClassName(descriptor_->file()),
+    "classname", descriptor_->name());
+
+  printer->Print(
+    "\n"
+    "// @@protoc_insertion_point(class_scope:$full_name$)\n",
+    "full_name", descriptor_->full_name());
 
   printer->Outdent();
   printer->Print("}\n\n");
@@ -529,14 +543,23 @@ GenerateParseFromMethods(io::Printer* printer) {
     "}\n"
     "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
     "    throws java.io.IOException {\n"
-    "  return newBuilder().mergeDelimitedFrom(input).buildParsed();\n"
+    "  Builder builder = newBuilder();\n"
+    "  if (builder.mergeDelimitedFrom(input)) {\n"
+    "    return builder.buildParsed();\n"
+    "  } else {\n"
+    "    return null;\n"
+    "  }\n"
     "}\n"
     "public static $classname$ parseDelimitedFrom(\n"
     "    java.io.InputStream input,\n"
     "    com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
     "    throws java.io.IOException {\n"
-    "  return newBuilder().mergeDelimitedFrom(input, extensionRegistry)\n"
-    "           .buildParsed();\n"
+    "  Builder builder = newBuilder();\n"
+    "  if (builder.mergeDelimitedFrom(input, extensionRegistry)) {\n"
+    "    return builder.buildParsed();\n"
+    "  } else {\n"
+    "    return null;\n"
+    "  }\n"
     "}\n"
     "public static $classname$ parseFrom(\n"
     "    com.google.protobuf.CodedInputStream input)\n"
@@ -827,7 +850,7 @@ void MessageGenerator::GenerateBuilderParsingMethods(io::Printer* printer) {
   for (int i = 0; i < descriptor_->field_count(); i++) {
     const FieldDescriptor* field = sorted_fields[i];
     uint32 tag = WireFormatLite::MakeTag(field->number(),
-      WireFormat::WireTypeForField(field));
+      WireFormat::WireTypeForFieldType(field->type()));
 
     printer->Print(
       "case $tag$: {\n",
@@ -840,6 +863,24 @@ void MessageGenerator::GenerateBuilderParsingMethods(io::Printer* printer) {
     printer->Print(
       "  break;\n"
       "}\n");
+
+    if (field->is_packable()) {
+      // To make packed = true wire compatible, we generate parsing code from a
+      // packed version of this field regardless of field->options().packed().
+      uint32 packed_tag = WireFormatLite::MakeTag(field->number(),
+        WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
+      printer->Print(
+        "case $tag$: {\n",
+        "tag", SimpleItoa(packed_tag));
+      printer->Indent();
+
+      field_generators_.get(field).GenerateParsingCodeFromPacked(printer);
+
+      printer->Outdent();
+      printer->Print(
+        "  break;\n"
+        "}\n");
+    }
   }
 
   printer->Outdent();
@@ -875,7 +916,7 @@ void MessageGenerator::GenerateIsInitialized(io::Printer* printer) {
   // Now check that all embedded messages are initialized.
   for (int i = 0; i < descriptor_->field_count(); i++) {
     const FieldDescriptor* field = descriptor_->field(i);
-    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
+    if (GetJavaType(field) == JAVATYPE_MESSAGE &&
         HasRequiredFields(field->message_type())) {
       switch (field->label()) {
         case FieldDescriptor::LABEL_REQUIRED:

+ 14 - 4
src/google/protobuf/compiler/java/java_message_field.cc

@@ -59,7 +59,7 @@ void SetMessageVariables(const FieldDescriptor* descriptor,
   (*variables)["number"] = SimpleItoa(descriptor->number());
   (*variables)["type"] = ClassName(descriptor->message_type());
   (*variables)["group_or_message"] =
-    (descriptor->type() == FieldDescriptor::TYPE_GROUP) ?
+    (GetType(descriptor) == FieldDescriptor::TYPE_GROUP) ?
     "Group" : "Message";
 }
 
@@ -79,7 +79,7 @@ void MessageFieldGenerator::
 GenerateMembers(io::Printer* printer) const {
   printer->Print(variables_,
     "private boolean has$capitalized_name$;\n"
-    "private $type$ $name$_ = $type$.getDefaultInstance();\n"
+    "private $type$ $name$_;\n"
     "public boolean has$capitalized_name$() { return has$capitalized_name$; }\n"
     "public $type$ get$capitalized_name$() { return $name$_; }\n");
 }
@@ -124,6 +124,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void MessageFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  printer->Print(variables_, "$name$_ = $type$.getDefaultInstance();\n");
+}
+
 void MessageFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -145,7 +150,7 @@ GenerateParsingCode(io::Printer* printer) const {
     "  subBuilder.mergeFrom(get$capitalized_name$());\n"
     "}\n");
 
-  if (descriptor_->type() == FieldDescriptor::TYPE_GROUP) {
+  if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) {
     printer->Print(variables_,
       "input.readGroup($number$, subBuilder, extensionRegistry);\n");
   } else {
@@ -261,6 +266,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void RepeatedMessageFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  // Initialized inline.
+}
+
 void RepeatedMessageFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -286,7 +296,7 @@ GenerateParsingCode(io::Printer* printer) const {
   printer->Print(variables_,
     "$type$.Builder subBuilder = $type$.newBuilder();\n");
 
-  if (descriptor_->type() == FieldDescriptor::TYPE_GROUP) {
+  if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) {
     printer->Print(variables_,
       "input.readGroup($number$, subBuilder, extensionRegistry);\n");
   } else {

+ 2 - 0
src/google/protobuf/compiler/java/java_message_field.h

@@ -52,6 +52,7 @@ class MessageFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;
@@ -75,6 +76,7 @@ class RepeatedMessageFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;

+ 29 - 17
src/google/protobuf/compiler/java/java_primitive_field.cc

@@ -93,7 +93,7 @@ bool IsReferenceType(JavaType type) {
 }
 
 const char* GetCapitalizedType(const FieldDescriptor* field) {
-  switch (field->type()) {
+  switch (GetType(field)) {
     case FieldDescriptor::TYPE_INT32   : return "Int32"   ;
     case FieldDescriptor::TYPE_UINT32  : return "UInt32"  ;
     case FieldDescriptor::TYPE_SINT32  : return "SInt32"  ;
@@ -166,7 +166,7 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
   (*variables)["capitalized_type"] = GetCapitalizedType(descriptor);
   (*variables)["tag"] = SimpleItoa(WireFormat::MakeTag(descriptor));
   (*variables)["tag_size"] = SimpleItoa(
-      WireFormat::TagSize(descriptor->number(), descriptor->type()));
+      WireFormat::TagSize(descriptor->number(), GetType(descriptor)));
   if (IsReferenceType(GetJavaType(descriptor))) {
     (*variables)["null_check"] =
         "  if (value == null) {\n"
@@ -175,7 +175,7 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
   } else {
     (*variables)["null_check"] = "";
   }
-  int fixed_size = FixedSize(descriptor->type());
+  int fixed_size = FixedSize(GetType(descriptor));
   if (fixed_size != -1) {
     (*variables)["fixed_size"] = SimpleItoa(fixed_size);
   }
@@ -218,7 +218,8 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n"
     "public Builder clear$capitalized_name$() {\n"
     "  result.has$capitalized_name$ = false;\n");
-  if (descriptor_->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
+  JavaType type = GetJavaType(descriptor_);
+  if (type == JAVATYPE_STRING || type == JAVATYPE_BYTES) {
     // The default value is not a simple literal so we want to avoid executing
     // it multiple times.  Instead, get the default out of the default instance.
     printer->Print(variables_,
@@ -232,6 +233,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void PrimitiveFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  // Initialized inline.
+}
+
 void PrimitiveFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -345,6 +351,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
     "}\n");
 }
 
+void RepeatedPrimitiveFieldGenerator::
+GenerateInitializationCode(io::Printer* printer) const {
+  // Initialized inline.
+}
+
 void RepeatedPrimitiveFieldGenerator::
 GenerateMergingCode(io::Printer* printer) const {
   printer->Print(variables_,
@@ -367,18 +378,19 @@ GenerateBuildingCode(io::Printer* printer) const {
 
 void RepeatedPrimitiveFieldGenerator::
 GenerateParsingCode(io::Printer* printer) const {
-  if (descriptor_->options().packed()) {
-    printer->Print(variables_,
-      "int length = input.readRawVarint32();\n"
-      "int limit = input.pushLimit(length);\n"
-      "while (input.getBytesUntilLimit() > 0) {\n"
-      "  add$capitalized_name$(input.read$capitalized_type$());\n"
-      "}\n"
-      "input.popLimit(limit);\n");
-  } else {
-    printer->Print(variables_,
-      "add$capitalized_name$(input.read$capitalized_type$());\n");
-  }
+  printer->Print(variables_,
+    "add$capitalized_name$(input.read$capitalized_type$());\n");
+}
+
+void RepeatedPrimitiveFieldGenerator::
+GenerateParsingCodeFromPacked(io::Printer* printer) const {
+  printer->Print(variables_,
+    "int length = input.readRawVarint32();\n"
+    "int limit = input.pushLimit(length);\n"
+    "while (input.getBytesUntilLimit() > 0) {\n"
+    "  add$capitalized_name$(input.read$capitalized_type$());\n"
+    "}\n"
+    "input.popLimit(limit);\n");
 }
 
 void RepeatedPrimitiveFieldGenerator::
@@ -407,7 +419,7 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
     "  int dataSize = 0;\n");
   printer->Indent();
 
-  if (FixedSize(descriptor_->type()) == -1) {
+  if (FixedSize(GetType(descriptor_)) == -1) {
     printer->Print(variables_,
       "for ($type$ element : get$capitalized_name$List()) {\n"
       "  dataSize += com.google.protobuf.CodedOutputStream\n"

+ 3 - 0
src/google/protobuf/compiler/java/java_primitive_field.h

@@ -52,6 +52,7 @@ class PrimitiveFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;
@@ -75,9 +76,11 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
   // implements FieldGenerator ---------------------------------------
   void GenerateMembers(io::Printer* printer) const;
   void GenerateBuilderMembers(io::Printer* printer) const;
+  void GenerateInitializationCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateBuildingCode(io::Printer* printer) const;
   void GenerateParsingCode(io::Printer* printer) const;
+  void GenerateParsingCodeFromPacked(io::Printer* printer) const;
   void GenerateSerializationCode(io::Printer* printer) const;
   void GenerateSerializedSizeCode(io::Printer* printer) const;
 

+ 1 - 0
src/google/protobuf/compiler/main.cc

@@ -39,6 +39,7 @@
 int main(int argc, char* argv[]) {
 
   google::protobuf::compiler::CommandLineInterface cli;
+  cli.AllowPlugins("protoc-");
 
   // Proto2 C++
   google::protobuf::compiler::cpp::CppGenerator cpp_generator;

+ 10 - 1
src/google/protobuf/compiler/parser.cc

@@ -34,8 +34,9 @@
 //
 // Recursive descent FTW.
 
-#include <google/protobuf/stubs/hash.h>
 #include <float.h>
+#include <google/protobuf/stubs/hash.h>
+#include <limits>
 
 
 #include <google/protobuf/compiler/parser.h>
@@ -206,6 +207,14 @@ bool Parser::ConsumeNumber(double* output, const char* error) {
     *output = value;
     input_->Next();
     return true;
+  } else if (LookingAt("inf")) {
+    *output = numeric_limits<double>::infinity();
+    input_->Next();
+    return true;
+  } else if (LookingAt("nan")) {
+    *output = numeric_limits<double>::quiet_NaN();
+    input_->Next();
+    return true;
   } else {
     AddError(error);
     return false;

+ 6 - 0
src/google/protobuf/compiler/parser_unittest.cc

@@ -336,6 +336,9 @@ TEST_F(ParseMessageTest, FieldDefaults) {
     "  required double foo = 1 [default= 10.5];\n"
     "  required double foo = 1 [default=-11.5];\n"
     "  required double foo = 1 [default= 12  ];\n"
+    "  required double foo = 1 [default= inf ];\n"
+    "  required double foo = 1 [default=-inf ];\n"
+    "  required double foo = 1 [default= nan ];\n"
     "  required string foo = 1 [default='13\\001'];\n"
     "  required string foo = 1 [default='a' \"b\" \n \"c\"];\n"
     "  required bytes  foo = 1 [default='14\\002'];\n"
@@ -367,6 +370,9 @@ TEST_F(ParseMessageTest, FieldDefaults) {
     "  field { type:TYPE_DOUBLE  default_value:\"10.5\"      "ETC" }"
     "  field { type:TYPE_DOUBLE  default_value:\"-11.5\"     "ETC" }"
     "  field { type:TYPE_DOUBLE  default_value:\"12\"        "ETC" }"
+    "  field { type:TYPE_DOUBLE  default_value:\"inf\"       "ETC" }"
+    "  field { type:TYPE_DOUBLE  default_value:\"-inf\"      "ETC" }"
+    "  field { type:TYPE_DOUBLE  default_value:\"nan\"       "ETC" }"
     "  field { type:TYPE_STRING  default_value:\"13\\001\"   "ETC" }"
     "  field { type:TYPE_STRING  default_value:\"abc\"       "ETC" }"
     "  field { type:TYPE_BYTES   default_value:\"14\\\\002\" "ETC" }"

+ 197 - 30
src/google/protobuf/compiler/python/python_generator.cc

@@ -42,8 +42,9 @@
 // performance-minded Python code leverage the fast C++ implementation
 // directly.
 
-#include <utility>
+#include <limits>
 #include <map>
+#include <utility>
 #include <string>
 #include <vector>
 
@@ -105,6 +106,13 @@ string NamePrefixedWithNestedTypes(const DescriptorT& descriptor,
 const char kDescriptorKey[] = "DESCRIPTOR";
 
 
+// Should we generate generic services for this file?
+inline bool HasGenericServices(const FileDescriptor *file) {
+  return file->service_count() > 0 &&
+         file->options().py_generic_services();
+}
+
+
 // Prints the common boilerplate needed at the top of every .py
 // file output by this generator.
 void PrintTopBoilerplate(
@@ -115,14 +123,21 @@ void PrintTopBoilerplate(
       "\n"
       "from google.protobuf import descriptor\n"
       "from google.protobuf import message\n"
-      "from google.protobuf import reflection\n"
-      "from google.protobuf import service\n"
-      "from google.protobuf import service_reflection\n");
+      "from google.protobuf import reflection\n");
+  if (HasGenericServices(file)) {
+    printer->Print(
+        "from google.protobuf import service\n"
+        "from google.protobuf import service_reflection\n");
+  }
+
   // Avoid circular imports if this module is descriptor_pb2.
   if (!descriptor_proto) {
     printer->Print(
         "from google.protobuf import descriptor_pb2\n");
   }
+  printer->Print(
+    "# @@protoc_insertion_point(imports)\n");
+  printer->Print("\n\n");
 }
 
 
@@ -150,10 +165,30 @@ string StringifyDefaultValue(const FieldDescriptor& field) {
       return SimpleItoa(field.default_value_int64());
     case FieldDescriptor::CPPTYPE_UINT64:
       return SimpleItoa(field.default_value_uint64());
-    case FieldDescriptor::CPPTYPE_DOUBLE:
-      return SimpleDtoa(field.default_value_double());
-    case FieldDescriptor::CPPTYPE_FLOAT:
-      return SimpleFtoa(field.default_value_float());
+    case FieldDescriptor::CPPTYPE_DOUBLE: {
+      double value = field.default_value_double();
+      if (value == numeric_limits<double>::infinity()) {
+        return "float('inf')";
+      } else if (value == -numeric_limits<double>::infinity()) {
+        return "float('-inf')";
+      } else if (value != value) {
+        return "float('nan')";
+      } else {
+        return SimpleDtoa(value);
+      }
+    }
+    case FieldDescriptor::CPPTYPE_FLOAT: {
+      float value = field.default_value_float();
+      if (value == numeric_limits<float>::infinity()) {
+        return "float('inf')";
+      } else if (value == -numeric_limits<float>::infinity()) {
+        return "float('-inf')";
+      } else if (value != value) {
+        return "float('nan')";
+      } else {
+        return SimpleFtoa(value);
+      }
+    }
     case FieldDescriptor::CPPTYPE_BOOL:
       return field.default_value_bool() ? "True" : "False";
     case FieldDescriptor::CPPTYPE_ENUM:
@@ -204,6 +239,10 @@ bool Generator::Generate(const FileDescriptor* file,
   StripString(&filename, ".", '/');
   filename += ".py";
 
+  FileDescriptorProto fdp;
+  file_->CopyTo(&fdp);
+  fdp.SerializeToString(&file_descriptor_serialized_);
+
 
   scoped_ptr<io::ZeroCopyOutputStream> output(output_directory->Open(filename));
   GOOGLE_CHECK(output.get());
@@ -211,6 +250,7 @@ bool Generator::Generate(const FileDescriptor* file,
   printer_ = &printer;
 
   PrintTopBoilerplate(printer_, file_, GeneratingDescriptorProto());
+  PrintFileDescriptor();
   PrintTopLevelEnums();
   PrintTopLevelExtensions();
   PrintAllNestedEnumsInFile();
@@ -224,7 +264,13 @@ bool Generator::Generate(const FileDescriptor* file,
   // since they need to call static RegisterExtension() methods on these
   // classes.
   FixForeignFieldsInExtensions();
-  PrintServices();
+  if (HasGenericServices(file)) {
+    PrintServices();
+  }
+
+  printer.Print(
+    "# @@protoc_insertion_point(module_scope)\n");
+
   return !printer.failed();
 }
 
@@ -238,6 +284,30 @@ void Generator::PrintImports() const {
   printer_->Print("\n");
 }
 
+// Prints the single file descriptor for this file.
+void Generator::PrintFileDescriptor() const {
+  map<string, string> m;
+  m["descriptor_name"] = kDescriptorKey;
+  m["name"] = file_->name();
+  m["package"] = file_->package();
+  const char file_descriptor_template[] =
+      "$descriptor_name$ = descriptor.FileDescriptor(\n"
+      "  name='$name$',\n"
+      "  package='$package$',\n";
+  printer_->Print(m, file_descriptor_template);
+  printer_->Indent();
+  printer_->Print(
+      "serialized_pb='$value$'",
+      "value", strings::CHexEscape(file_descriptor_serialized_));
+
+  // TODO(falk): Also print options and fix the message_type, enum_type,
+  //             service and extension later in the generation.
+
+  printer_->Outdent();
+  printer_->Print(")\n");
+  printer_->Print("\n");
+}
+
 // Prints descriptors and module-level constants for all top-level
 // enums defined in |file|.
 void Generator::PrintTopLevelEnums() const {
@@ -277,12 +347,13 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
   m["descriptor_name"] = ModuleLevelDescriptorName(enum_descriptor);
   m["name"] = enum_descriptor.name();
   m["full_name"] = enum_descriptor.full_name();
-  m["filename"] = enum_descriptor.name();
+  m["file"] = kDescriptorKey;
   const char enum_descriptor_template[] =
       "$descriptor_name$ = descriptor.EnumDescriptor(\n"
       "  name='$name$',\n"
       "  full_name='$full_name$',\n"
-      "  filename='$filename$',\n"
+      "  filename=None,\n"
+      "  file=$file$,\n"
       "  values=[\n";
   string options_string;
   enum_descriptor.options().SerializeToString(&options_string);
@@ -295,9 +366,12 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
   }
   printer_->Outdent();
   printer_->Print("],\n");
+  printer_->Print("containing_type=None,\n");
   printer_->Print("options=$options_value$,\n",
                   "options_value",
                   OptionsValue("EnumOptions", CEscape(options_string)));
+  EnumDescriptorProto edp;
+  PrintSerializedPbInterval(enum_descriptor, edp);
   printer_->Outdent();
   printer_->Print(")\n");
   printer_->Print("\n");
@@ -362,15 +436,21 @@ void Generator::PrintServiceDescriptor(
   map<string, string> m;
   m["name"] = descriptor.name();
   m["full_name"] = descriptor.full_name();
+  m["file"] = kDescriptorKey;
   m["index"] = SimpleItoa(descriptor.index());
   m["options_value"] = OptionsValue("ServiceOptions", options_string);
   const char required_function_arguments[] =
       "name='$name$',\n"
       "full_name='$full_name$',\n"
+      "file=$file$,\n"
       "index=$index$,\n"
-      "options=$options_value$,\n"
-      "methods=[\n";
+      "options=$options_value$,\n";
   printer_->Print(m, required_function_arguments);
+
+  ServiceDescriptorProto sdp;
+  PrintSerializedPbInterval(descriptor, sdp);
+
+  printer_->Print("methods=[\n");
   for (int i = 0; i < descriptor.method_count(); ++i) {
     const MethodDescriptor* method = descriptor.method(i);
     string options_string;
@@ -444,17 +524,27 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
   map<string, string> m;
   m["name"] = message_descriptor.name();
   m["full_name"] = message_descriptor.full_name();
-  m["filename"] = message_descriptor.file()->name();
+  m["file"] = kDescriptorKey;
   const char required_function_arguments[] =
       "name='$name$',\n"
       "full_name='$full_name$',\n"
-      "filename='$filename$',\n"
-      "containing_type=None,\n";  // TODO(robinson): Implement containing_type.
+      "filename=None,\n"
+      "file=$file$,\n"
+      "containing_type=None,\n";
   printer_->Print(m, required_function_arguments);
   PrintFieldsInDescriptor(message_descriptor);
   PrintExtensionsInDescriptor(message_descriptor);
-  // TODO(robinson): implement printing of nested_types.
-  printer_->Print("nested_types=[],  # TODO(robinson): Implement.\n");
+
+  // Nested types
+  printer_->Print("nested_types=[");
+  for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
+    const string nested_name = ModuleLevelDescriptorName(
+        *message_descriptor.nested_type(i));
+    printer_->Print("$name$, ", "name", nested_name);
+  }
+  printer_->Print("],\n");
+
+  // Enum types
   printer_->Print("enum_types=[\n");
   printer_->Indent();
   for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
@@ -468,8 +558,28 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
   string options_string;
   message_descriptor.options().SerializeToString(&options_string);
   printer_->Print(
-      "options=$options_value$",
-      "options_value", OptionsValue("MessageOptions", options_string));
+      "options=$options_value$,\n"
+      "is_extendable=$extendable$",
+      "options_value", OptionsValue("MessageOptions", options_string),
+      "extendable", message_descriptor.extension_range_count() > 0 ?
+                      "True" : "False");
+  printer_->Print(",\n");
+
+  // Extension ranges
+  printer_->Print("extension_ranges=[");
+  for (int i = 0; i < message_descriptor.extension_range_count(); ++i) {
+    const Descriptor::ExtensionRange* range =
+        message_descriptor.extension_range(i);
+    printer_->Print("($start$, $end$), ",
+                    "start", SimpleItoa(range->start),
+                    "end", SimpleItoa(range->end));
+  }
+  printer_->Print("],\n");
+
+  // Serialization of proto
+  DescriptorProto edp;
+  PrintSerializedPbInterval(message_descriptor, edp);
+
   printer_->Outdent();
   printer_->Print(")\n");
 }
@@ -511,6 +621,12 @@ void Generator::PrintMessage(
   m["descriptor_key"] = kDescriptorKey;
   m["descriptor_name"] = ModuleLevelDescriptorName(message_descriptor);
   printer_->Print(m, "$descriptor_key$ = $descriptor_name$\n");
+
+  printer_->Print(
+    "\n"
+    "# @@protoc_insertion_point(class_scope:$full_name$)\n",
+    "full_name", message_descriptor.full_name());
+
   printer_->Outdent();
 }
 
@@ -527,16 +643,27 @@ void Generator::PrintNestedMessages(
 // Recursively fixes foreign fields in all nested types in |descriptor|, then
 // sets the message_type and enum_type of all message and enum fields to point
 // to their respective descriptors.
+// Args:
+//   descriptor: descriptor to print fields for.
+//   containing_descriptor: if descriptor is a nested type, this is its
+//       containing type, or NULL if this is a root/top-level type.
 void Generator::FixForeignFieldsInDescriptor(
-    const Descriptor& descriptor) const {
+    const Descriptor& descriptor,
+    const Descriptor* containing_descriptor) const {
   for (int i = 0; i < descriptor.nested_type_count(); ++i) {
-    FixForeignFieldsInDescriptor(*descriptor.nested_type(i));
+    FixForeignFieldsInDescriptor(*descriptor.nested_type(i), &descriptor);
   }
 
   for (int i = 0; i < descriptor.field_count(); ++i) {
     const FieldDescriptor& field_descriptor = *descriptor.field(i);
     FixForeignFieldsInField(&descriptor, field_descriptor, "fields_by_name");
   }
+
+  FixContainingTypeInDescriptor(descriptor, containing_descriptor);
+  for (int i = 0; i < descriptor.enum_type_count(); ++i) {
+    const EnumDescriptor& enum_descriptor = *descriptor.enum_type(i);
+    FixContainingTypeInDescriptor(enum_descriptor, &descriptor);
+  }
 }
 
 // Sets any necessary message_type and enum_type attributes
@@ -593,13 +720,29 @@ string Generator::FieldReferencingExpression(
       python_dict_name, field.name());
 }
 
+// Emits the Python statement that assigns containing_type on a nested
+// message descriptor or an enum descriptor.  Nothing is printed when
+// |containing_descriptor| is NULL.
+template <typename DescriptorT>
+void Generator::FixContainingTypeInDescriptor(
+    const DescriptorT& descriptor,
+    const Descriptor* containing_descriptor) const {
+  if (containing_descriptor == NULL) {
+    return;
+  }
+  const string child = ModuleLevelDescriptorName(descriptor);
+  const string parent = ModuleLevelDescriptorName(*containing_descriptor);
+  printer_->Print(
+      "$nested_name$.containing_type = $parent_name$;\n",
+      "nested_name", child,
+      "parent_name", parent);
+}
+
 // Prints statements setting the message_type and enum_type fields in the
 // Python descriptor objects we've already output in this file.  We must
 // do this in a separate step due to circular references (otherwise, we'd
 // just set everything in the initial assignment statements).
 void Generator::FixForeignFieldsInDescriptors() const {
   for (int i = 0; i < file_->message_type_count(); ++i) {
-    FixForeignFieldsInDescriptor(*file_->message_type(i));
+    FixForeignFieldsInDescriptor(*file_->message_type(i), NULL);
   }
   printer_->Print("\n");
 }
@@ -696,6 +839,7 @@ void Generator::PrintFieldDescriptor(
   m["type"] = SimpleItoa(field.type());
   m["cpp_type"] = SimpleItoa(field.cpp_type());
   m["label"] = SimpleItoa(field.label());
+  m["has_default_value"] = field.has_default_value() ? "True" : "False";
   m["default_value"] = StringifyDefaultValue(field);
   m["is_extension"] = is_extension ? "True" : "False";
   m["options"] = OptionsValue("FieldOptions", options_string);
@@ -703,13 +847,13 @@ void Generator::PrintFieldDescriptor(
   // these fields in correctly after all referenced descriptors have been
   // defined and/or imported (see FixForeignFieldsInDescriptors()).
   const char field_descriptor_decl[] =
-      "descriptor.FieldDescriptor(\n"
-      "  name='$name$', full_name='$full_name$', index=$index$,\n"
-      "  number=$number$, type=$type$, cpp_type=$cpp_type$, label=$label$,\n"
-      "  default_value=$default_value$,\n"
-      "  message_type=None, enum_type=None, containing_type=None,\n"
-      "  is_extension=$is_extension$, extension_scope=None,\n"
-      "  options=$options$)";
+    "descriptor.FieldDescriptor(\n"
+    "  name='$name$', full_name='$full_name$', index=$index$,\n"
+    "  number=$number$, type=$type$, cpp_type=$cpp_type$, label=$label$,\n"
+    "  has_default_value=$has_default_value$, default_value=$default_value$,\n"
+    "  message_type=None, enum_type=None, containing_type=None,\n"
+    "  is_extension=$is_extension$, extension_scope=None,\n"
+    "  options=$options$)";
   printer_->Print(m, field_descriptor_decl);
 }
 
@@ -811,6 +955,29 @@ string Generator::ModuleLevelServiceDescriptorName(
   return name;
 }
 
+// Prints the standard constructor arguments serialized_start and
+// serialized_end: the byte range at which the serialized form of
+// |descriptor| is found inside file_descriptor_serialized_.
+// Args:
+//   descriptor: The cpp descriptor to have a serialized reference.
+//   proto: A descriptor proto of the type matching |descriptor|.  It is
+//       filled in from |descriptor| via CopyTo() and then serialized.
+// Example printer output:
+// serialized_start=41,
+// serialized_end=43,
+//
+template <typename DescriptorT, typename DescriptorProtoT>
+void Generator::PrintSerializedPbInterval(
+    const DescriptorT& descriptor, DescriptorProtoT& proto) const {
+  descriptor.CopyTo(&proto);
+  string sp;
+  proto.SerializeToString(&sp);
+  // string::find() reports failure as string::npos, an unsigned value.  Keep
+  // the result in the unsigned size_type and compare against npos explicitly
+  // instead of narrowing it to int and testing the sign.
+  const string::size_type offset = file_descriptor_serialized_.find(sp);
+  GOOGLE_CHECK(offset != string::npos);
+
+  printer_->Print("serialized_start=$serialized_start$,\n"
+                  "serialized_end=$serialized_end$,\n",
+                  "serialized_start", SimpleItoa(offset),
+                  "serialized_end", SimpleItoa(offset + sp.size()));
+}
+
 }  // namespace python
 }  // namespace compiler
 }  // namespace protobuf

+ 14 - 2
src/google/protobuf/compiler/python/python_generator.h

@@ -71,6 +71,7 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
 
  private:
   void PrintImports() const;
+  void PrintFileDescriptor() const;
   void PrintTopLevelEnums() const;
   void PrintAllNestedEnumsInFile() const;
   void PrintNestedEnums(const Descriptor& descriptor) const;
@@ -97,13 +98,19 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
   void PrintNestedMessages(const Descriptor& containing_descriptor) const;
 
   void FixForeignFieldsInDescriptors() const;
-  void FixForeignFieldsInDescriptor(const Descriptor& descriptor) const;
+  void FixForeignFieldsInDescriptor(
+      const Descriptor& descriptor,
+      const Descriptor* containing_descriptor) const;
   void FixForeignFieldsInField(const Descriptor* containing_type,
                                const FieldDescriptor& field,
                                const string& python_dict_name) const;
   string FieldReferencingExpression(const Descriptor* containing_type,
                                     const FieldDescriptor& field,
                                     const string& python_dict_name) const;
+  template <typename DescriptorT>
+  void FixContainingTypeInDescriptor(
+      const DescriptorT& descriptor,
+      const Descriptor* containing_descriptor) const;
 
   void FixForeignFieldsInExtensions() const;
   void FixForeignFieldsInExtension(
@@ -126,10 +133,15 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
   string ModuleLevelServiceDescriptorName(
       const ServiceDescriptor& descriptor) const;
 
+  template <typename DescriptorT, typename DescriptorProtoT>
+  void PrintSerializedPbInterval(
+      const DescriptorT& descriptor, DescriptorProtoT& proto) const;
+
   // Very coarse-grained lock to ensure that Generate() is reentrant.
-  // Guards file_ and printer_.
+  // Guards file_, printer_ and file_descriptor_serialized_.
   mutable Mutex mutex_;
   mutable const FileDescriptor* file_;  // Set in Generate().  Under mutex_.
+  mutable string file_descriptor_serialized_;
   mutable io::Printer* printer_;  // Set in Generate().  Under mutex_.
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator);

+ 8 - 12
src/google/protobuf/descriptor.cc

@@ -796,9 +796,10 @@ bool DescriptorPool::InternalIsFileLoaded(const string& filename) const {
 
 namespace {
 
+
 EncodedDescriptorDatabase* generated_database_ = NULL;
 DescriptorPool* generated_pool_ = NULL;
-GOOGLE_PROTOBUF_DECLARE_ONCE(generated_pool_init_);
+GoogleOnceType generated_pool_init_;
 
 void DeleteGeneratedPool() {
   delete generated_database_;
@@ -810,6 +811,7 @@ void DeleteGeneratedPool() {
 void InitGeneratedPool() {
   generated_database_ = new EncodedDescriptorDatabase;
   generated_pool_ = new DescriptorPool(generated_database_);
+
   internal::OnShutdown(&DeleteGeneratedPool);
 }
 
@@ -3651,17 +3653,11 @@ void DescriptorBuilder::ValidateFieldOptions(FieldDescriptor* field,
   }
 
   // Only repeated primitive fields may be packed.
-  if (field->options().packed()) {
-    if (!field->is_repeated() ||
-        field->type() == FieldDescriptor::TYPE_STRING ||
-        field->type() == FieldDescriptor::TYPE_GROUP ||
-        field->type() == FieldDescriptor::TYPE_MESSAGE ||
-        field->type() == FieldDescriptor::TYPE_BYTES) {
-      AddError(
-        field->full_name(), proto,
-        DescriptorPool::ErrorCollector::TYPE,
-        "[packed = true] can only be specified for repeated primitive fields.");
-    }
+  if (field->options().packed() && !field->is_packable()) {
+    AddError(
+      field->full_name(), proto,
+      DescriptorPool::ErrorCollector::TYPE,
+      "[packed = true] can only be specified for repeated primitive fields.");
   }
 
   // Note:  Default instance may not yet be initialized here, so we have to

+ 19 - 4
src/google/protobuf/descriptor.h

@@ -395,6 +395,8 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
   bool is_required() const;      // shorthand for label() == LABEL_REQUIRED
   bool is_optional() const;      // shorthand for label() == LABEL_OPTIONAL
   bool is_repeated() const;      // shorthand for label() == LABEL_REPEATED
+  bool is_packable() const;      // shorthand for is_repeated() &&
+                                 //               IsTypePackable(type())
 
   // Index of this field within the message's field array, or the file or
   // extension scope's extensions array.
@@ -474,6 +476,9 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
   // Helper method to get the CppType for a particular Type.
   static CppType TypeToCppType(Type type);
 
+  // Return true iff [packed = true] is valid for fields of this type.
+  static inline bool IsTypePackable(Type field_type);
+
  private:
   typedef FieldOptions OptionsType;
 
@@ -1069,10 +1074,6 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
   // These methods may contain hidden pitfalls and may be removed in a
   // future library version.
 
-  // DEPRECATED:  Use of underlays can lead to many subtle gotchas.  Instead,
-  //   try to formulate what you want to do in terms of DescriptorDatabases.
-  //   This constructor will be removed soon.
-  //
   // Create a DescriptorPool which is overlaid on top of some other pool.
   // If you search for a descriptor in the overlay and it is not found, the
   // underlay will be searched as a backup.  If the underlay has its own
@@ -1090,6 +1091,9 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
   // types directly into generated_pool(): this is not allowed, and would be
   // bad design anyway.  So, instead, you could use generated_pool() as an
   // underlay for a new DescriptorPool in which you add only the new file.
+  //
+  // WARNING:  Use of underlays can lead to many subtle gotchas.  Instead,
+  //   try to formulate what you want to do in terms of DescriptorDatabases.
   explicit DescriptorPool(const DescriptorPool* underlay);
 
   // Called by generated classes at init time to add their descriptors to
@@ -1294,6 +1298,10 @@ inline bool FieldDescriptor::is_repeated() const {
   return label() == LABEL_REPEATED;
 }
 
+// True iff this field is repeated and its type passes IsTypePackable().
+inline bool FieldDescriptor::is_packable() const {
+  if (!is_repeated()) return false;
+  return IsTypePackable(type());
+}
+
 // To save space, index() is computed by looking at the descriptor's position
 // in the parent's array of children.
 inline int FieldDescriptor::index() const {
@@ -1342,6 +1350,13 @@ inline FieldDescriptor::CppType FieldDescriptor::TypeToCppType(Type type) {
   return kTypeToCppTypeMap[type];
 }
 
+// Returns false for the string, group, message and bytes types, and true
+// for every other type.
+inline bool FieldDescriptor::IsTypePackable(Type field_type) {
+  switch (field_type) {
+    case FieldDescriptor::TYPE_STRING:
+    case FieldDescriptor::TYPE_GROUP:
+    case FieldDescriptor::TYPE_MESSAGE:
+    case FieldDescriptor::TYPE_BYTES:
+      return false;
+    default:
+      return true;
+  }
+}
+
 inline const FileDescriptor* FileDescriptor::dependency(int index) const {
   return dependencies_[index];
 }

文件差異過大導致無法顯示
+ 270 - 218
src/google/protobuf/descriptor.pb.cc


文件差異過大導致無法顯示
+ 211 - 98
src/google/protobuf/descriptor.pb.h


+ 21 - 1
src/google/protobuf/descriptor.proto

@@ -256,6 +256,22 @@ message FileOptions {
 
 
 
+
+  // Should generic services be generated in each language?  "Generic" services
+  // are not specific to any particular RPC system.  They are generated by the
+  // main code generators in each language (without additional plugins).
+  // Generic services were the only kind of service generation supported by
+  // early versions of proto2.
+  //
+  // Generic services are now considered deprecated in favor of using plugins
+  // that generate code specific to your particular RPC system.  If you are
+  // using such a plugin, set these to false.  In the future, we may change
+  // the default to false, so if you explicitly want generic services, you
+  // should explicitly set these to true.
+  optional bool cc_generic_services = 16 [default=true];
+  optional bool java_generic_services = 17 [default=true];
+  optional bool py_generic_services = 18 [default=true];
+
   // The parser stores options it doesn't recognize here. See above.
   repeated UninterpretedOption uninterpreted_option = 999;
 
@@ -301,8 +317,11 @@ message FieldOptions {
   // representation of the field than it normally would.  See the specific
   // options below.  This option is not yet implemented in the open source
   // release -- sorry, we'll try to include it in a future version!
-  optional CType ctype = 1;
+  optional CType ctype = 1 [default = STRING];
   enum CType {
+    // Default mode.
+    STRING = 0;
+
     CORD = 1;
 
     STRING_PIECE = 2;
@@ -313,6 +332,7 @@ message FieldOptions {
   // a single length-delimited blob.
   optional bool packed = 2;
 
+
   // Is this field deprecated?
   // Depending on the target platform, this can emit Deprecated annotations
   // for accessors, or it will be completely ignored; at the very least, this

+ 30 - 0
src/google/protobuf/descriptor_database.cc

@@ -37,6 +37,7 @@
 #include <set>
 
 #include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/wire_format_lite_inl.h>
 #include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/stubs/stl_util-inl.h>
 #include <google/protobuf/stubs/map-util.h>
@@ -336,6 +337,35 @@ bool EncodedDescriptorDatabase::FindFileContainingSymbol(
   return MaybeParse(index_.FindSymbol(symbol_name), output);
 }
 
+// Finds the encoded file that index_ associates with |symbol_name| and
+// stores that file's name in |*output|.  Returns false when the index has no
+// entry for the symbol or when the encoded data cannot be read.
+bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
+    const string& symbol_name,
+    string* output) {
+  const pair<const void*, int> encoded = index_.FindSymbol(symbol_name);
+  if (encoded.first == NULL) return false;
+
+  // Fast path:  The name should be the first field in the encoded message,
+  //   so try to read just that field directly.
+  io::CodedInputStream input(reinterpret_cast<const uint8*>(encoded.first),
+                             encoded.second);
+
+  const uint32 kNameTag = internal::WireFormatLite::MakeTag(
+      FileDescriptorProto::kNameFieldNumber,
+      internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
+
+  if (input.ReadTag() == kNameTag) {
+    // The first tag is the name field; read its value and finish.
+    return internal::WireFormatLite::ReadString(&input, output);
+  }
+
+  // Slow path.  Parse whole message.
+  FileDescriptorProto parsed;
+  if (!parsed.ParseFromArray(encoded.first, encoded.second)) {
+    return false;
+  }
+  *output = parsed.name();
+  return true;
+}
+
 bool EncodedDescriptorDatabase::FindFileContainingExtension(
     const string& containing_type,
     int field_number,

+ 4 - 0
src/google/protobuf/descriptor_database.h

@@ -280,6 +280,10 @@ class LIBPROTOBUF_EXPORT EncodedDescriptorDatabase : public DescriptorDatabase {
   // need to keep it around.
   bool AddCopy(const void* encoded_file_descriptor, int size);
 
+  // Like FindFileContainingSymbol but returns only the name of the file.
+  bool FindNameOfFileContainingSymbol(const string& symbol_name,
+                                      string* output);
+
   // implements DescriptorDatabase -----------------------------------
   bool FindFileByName(const string& filename,
                       FileDescriptorProto* output);

部分文件因文件數量過多而無法顯示