Selaa lähdekoodia

Integrate recent changes from Google-internal code tree. See CHANGES.txt
for details.

kenton@google.com 16 vuotta sitten
vanhempi
commit
d37d46dfbc
97 muutettua tiedostoa jossa 7878 lisäystä ja 3349 poistoa
  1. 76 5
      CHANGES.txt
  2. 2 1
      CONTRIBUTORS.txt
  3. 3 0
      Makefile.am
  4. 57 1
      java/src/main/java/com/google/protobuf/AbstractMessage.java
  5. 51 0
      java/src/main/java/com/google/protobuf/BlockingRpcChannel.java
  6. 64 0
      java/src/main/java/com/google/protobuf/BlockingService.java
  7. 60 2
      java/src/main/java/com/google/protobuf/CodedInputStream.java
  8. 1 1
      java/src/main/java/com/google/protobuf/Descriptors.java
  9. 6 1
      java/src/main/java/com/google/protobuf/DynamicMessage.java
  10. 41 2
      java/src/main/java/com/google/protobuf/Message.java
  11. 3 3
      java/src/main/java/com/google/protobuf/RpcUtil.java
  12. 42 0
      java/src/main/java/com/google/protobuf/ServiceException.java
  13. 45 0
      java/src/main/java/com/google/protobuf/UnknownFieldSet.java
  14. 30 8
      java/src/test/java/com/google/protobuf/AbstractMessageTest.java
  15. 61 0
      java/src/test/java/com/google/protobuf/CodedInputStreamTest.java
  16. 12 0
      java/src/test/java/com/google/protobuf/DynamicMessageTest.java
  17. 11 0
      java/src/test/java/com/google/protobuf/GeneratedMessageTest.java
  18. 96 3
      java/src/test/java/com/google/protobuf/ServiceTest.java
  19. 96 3
      java/src/test/java/com/google/protobuf/UnknownFieldSetTest.java
  20. 20 0
      java/src/test/java/com/google/protobuf/WireFormatTest.java
  21. 114 1
      python/google/protobuf/internal/input_stream.py
  22. 20 1
      python/google/protobuf/internal/input_stream_test.py
  23. 9 0
      python/google/protobuf/internal/reflection_test.py
  24. 4 0
      python/google/protobuf/reflection.py
  25. 6 0
      src/Makefile.am
  26. 21 3
      src/google/protobuf/compiler/cpp/cpp_enum.cc
  27. 52 20
      src/google/protobuf/compiler/cpp/cpp_enum_field.cc
  28. 4 2
      src/google/protobuf/compiler/cpp/cpp_enum_field.h
  29. 72 9
      src/google/protobuf/compiler/cpp/cpp_extension.cc
  30. 3 0
      src/google/protobuf/compiler/cpp/cpp_extension.h
  31. 13 10
      src/google/protobuf/compiler/cpp/cpp_field.h
  32. 115 37
      src/google/protobuf/compiler/cpp/cpp_file.cc
  33. 44 3
      src/google/protobuf/compiler/cpp/cpp_helpers.cc
  34. 8 2
      src/google/protobuf/compiler/cpp/cpp_helpers.h
  35. 225 230
      src/google/protobuf/compiler/cpp/cpp_message.cc
  36. 26 3
      src/google/protobuf/compiler/cpp/cpp_message.h
  37. 26 8
      src/google/protobuf/compiler/cpp/cpp_message_field.cc
  38. 4 2
      src/google/protobuf/compiler/cpp/cpp_message_field.h
  39. 52 47
      src/google/protobuf/compiler/cpp/cpp_primitive_field.cc
  40. 4 2
      src/google/protobuf/compiler/cpp/cpp_primitive_field.h
  41. 4 2
      src/google/protobuf/compiler/cpp/cpp_service.cc
  42. 52 48
      src/google/protobuf/compiler/cpp/cpp_string_field.cc
  43. 4 2
      src/google/protobuf/compiler/cpp/cpp_string_field.h
  44. 136 4
      src/google/protobuf/compiler/cpp/cpp_unittest.cc
  45. 26 8
      src/google/protobuf/compiler/importer.cc
  46. 10 0
      src/google/protobuf/compiler/importer.h
  47. 28 0
      src/google/protobuf/compiler/importer_unittest.cc
  48. 18 2
      src/google/protobuf/compiler/java/java_message.cc
  49. 243 38
      src/google/protobuf/compiler/java/java_service.cc
  50. 33 0
      src/google/protobuf/compiler/java/java_service.h
  51. 6 3
      src/google/protobuf/compiler/parser.cc
  52. 14 1
      src/google/protobuf/compiler/parser.h
  53. 38 6
      src/google/protobuf/compiler/parser_unittest.cc
  54. 346 193
      src/google/protobuf/descriptor.cc
  55. 86 11
      src/google/protobuf/descriptor.h
  56. 402 238
      src/google/protobuf/descriptor.pb.cc
  57. 189 381
      src/google/protobuf/descriptor.pb.h
  58. 8 1
      src/google/protobuf/descriptor.proto
  59. 300 97
      src/google/protobuf/descriptor_database.cc
  60. 186 21
      src/google/protobuf/descriptor_database.h
  61. 255 158
      src/google/protobuf/descriptor_database_unittest.cc
  62. 310 3
      src/google/protobuf/descriptor_unittest.cc
  63. 2 2
      src/google/protobuf/dynamic_message.cc
  64. 471 257
      src/google/protobuf/extension_set.cc
  65. 314 127
      src/google/protobuf/extension_set.h
  66. 82 4
      src/google/protobuf/extension_set_unittest.cc
  67. 51 16
      src/google/protobuf/generated_message_reflection.cc
  68. 26 0
      src/google/protobuf/generated_message_reflection.h
  69. 155 103
      src/google/protobuf/io/coded_stream.cc
  70. 145 34
      src/google/protobuf/io/coded_stream.h
  71. 12 8
      src/google/protobuf/io/coded_stream_unittest.cc
  72. 3 2
      src/google/protobuf/io/zero_copy_stream_impl.cc
  73. 90 22
      src/google/protobuf/message.cc
  74. 37 21
      src/google/protobuf/message.h
  75. 16 20
      src/google/protobuf/reflection_ops_unittest.cc
  76. 124 0
      src/google/protobuf/repeated_field.h
  77. 149 0
      src/google/protobuf/repeated_field_unittest.cc
  78. 6 0
      src/google/protobuf/stubs/common.h
  79. 27 0
      src/google/protobuf/stubs/hash.h
  80. 15 0
      src/google/protobuf/stubs/map-util.h
  81. 82 0
      src/google/protobuf/stubs/once.cc
  82. 122 0
      src/google/protobuf/stubs/once.h
  83. 253 0
      src/google/protobuf/stubs/once_unittest.cc
  84. 22 0
      src/google/protobuf/test_util.cc
  85. 6 0
      src/google/protobuf/test_util.h
  86. 178 93
      src/google/protobuf/text_format.cc
  87. 70 29
      src/google/protobuf/text_format.h
  88. 73 10
      src/google/protobuf/text_format_unittest.cc
  89. 8 0
      src/google/protobuf/unittest.proto
  90. 37 0
      src/google/protobuf/unittest_empty.proto
  91. 108 136
      src/google/protobuf/unknown_field_set.cc
  92. 94 321
      src/google/protobuf/unknown_field_set.h
  93. 133 228
      src/google/protobuf/unknown_field_set_unittest.cc
  94. 169 128
      src/google/protobuf/wire_format.cc
  95. 136 47
      src/google/protobuf/wire_format.h
  96. 310 98
      src/google/protobuf/wire_format_inl.h
  97. 59 16
      src/google/protobuf/wire_format_unittest.cc

+ 76 - 5
CHANGES.txt

@@ -1,4 +1,4 @@
-????-??-?? version 2.0.4:
+????-??-?? version 2.1.0:
 
   General
  * Repeated fields of primitive types (types other than string, group, and
@@ -20,22 +20,70 @@
   * Updated bundled Google Test to version 1.3.0.  Google Test is now bundled
     in its verbatim form as a nested autoconf package, so you can drop in any
     other version of Google Test if needed.
+  * optimize_for = SPEED is now the default, by popular demand.  Use
+    optimize_for = CODE_SIZE if code size is more important in your app.
+  * It is now an error to define a default value for a repeated field.
+    Previously, this was silently ignored (it had no effect on the generated
+    code).
+  * Fields can now be marked deprecated like:
+      optional int32 foo = 1 [deprecated = true];
+    Currently this does not have any actual effect, but in the future the code
+    generators may generate deprecation annotations in each language.
 
   protoc
   * --error_format=msvs option causes errors to be printed in Visual Studio
     format, which should allow them to be clicked on in the build log to go
-    directly to the error location. 
+    directly to the error location.
+  * The type name resolver will no longer resolve type names to fields.  For
+    example, this now works:
+      message Foo {}
+      message Bar {
+        optional int32 Foo = 1;
+        optional Foo baz = 2;
+      }
+    Previously, the type of "baz" would resolve to "Bar.Foo", and you'd get
+    an error because Bar.Foo is a field, not a type.  Now the type of "baz"
+    resolves to the message type Foo.  This change is unlikely to make a
+    difference to anyone who follows the Protocol Buffers style guide.
 
   C++
-  * UnknownFieldSet now supports STL-like iteration.
+  * Several optimizations, including but not limited to:
+    - Serialization, especially to flat arrays, is 10%-50% faster, possibly
+      more for small objects.
+    - Several descriptor operations which previously required locking no longer
+      do.
+    - Descriptors are now constructed lazily on first use, rather than at
+      process startup time.  This should save memory in programs which do not
+      use descriptors or reflection.
+    - UnknownFieldSet completely redesigned to be more efficient (especially in
+      terms of memory usage).
+    - Various optimizations to reduce code size (though the serialization speed
+      optimizations increased code size).
   * Message interface has method ParseFromBoundedZeroCopyStream() which parses
     a limited number of bytes from an input stream rather than parsing until
     EOF.
   * GzipInputStream and GzipOutputStream support reading/writing gzip- or
     zlib-compressed streams if zlib is available.
     (google/protobuf/io/gzip_stream.h)
-  * Generated constructors explicitly initialize all fields (to avoid warnings
-    with certain compiler settings).
+  * DescriptorPool::FindAllExtensions() and corresponding
+    DescriptorDatabase::FindAllExtensions() can be used to enumerate all
+    extensions of a given type.
+  * For each enum type Foo, protoc will generate functions:
+      const string& Foo_Name(Foo value);
+      bool Foo_Parse(const string& name, Foo* result);
+    The former returns the name of the enum constant corresponding to the given
+    value while the latter finds the value corresponding to a name.
+  * RepeatedField and RepeatedPtrField now have back-insertion iterators.
+  * String fields now have setters that take a char* and a size, in addition
+    to the existing ones that took char* or const string&.
+  * DescriptorPool::AllowUnknownDependencies() may be used to tell
+    DescriptorPool to create placeholder descriptors for unknown entities
+    referenced in a FileDescriptorProto.  This can allow you to parse a .proto
+    file without having access to other .proto files that it imports, for
+    example.
+  * Updated gtest to latest version.  The gtest package is now included as a
+    nested autoconf package, so it should be able to drop new versions into the
+    "gtest" subdirectory without modification.
 
   Java
   * Fixed bug where Message.mergeFrom(Message) failed to merge extensions.
@@ -48,6 +96,28 @@
     regex implementation (which unfortunately uses recursive backtracking
     rather than building an NFA).  Worked around by making use of possessive
     quantifiers.
+  * Generated service classes now also generate pure interfaces.  For a service
+    Foo, Foo.Interface is a pure interface containing all of the service's
+    defined methods.  Foo.newReflectiveService() can be called to wrap an
+    instance of this interface in a class that implements the generic
+    RpcService interface, which provides reflection support that is usually
+    needed by RPC server implementations.
+  * RPC interfaces now support blocking operation in addition to non-blocking.
+    The protocol compiler generates separate blocking and non-blocking stubs
+    which operate against separate blocking and non-blocking RPC interfaces.
+    RPC implementations will have to implement the new interfaces in order to
+    support blocking mode.
+  * New I/O methods parseDelimitedFrom(), mergeDelimitedFrom(), and
+    writeDelimitedTo() read and write "delimited" messages from/to a stream,
+    meaning that the message size precedes the data.  This way, you can write
+    multiple messages to a stream without having to worry about delimiting
+    them yourself.
+  * Throw a more descriptive exception when build() is double-called.
+  * Add a method to query whether CodedInputStream is at the end of the input
+    stream.
+  * Add a method to reset a CodedInputStream's size counter; useful when
+    reading many messages with the same stream.
+  * equals() and hashCode() now account for unknown fields.
 
   Python
   * Added slicing support for repeated scalar fields. Added slice retrieval and
@@ -58,6 +128,7 @@
     object will be returned directly to the caller.  This interface change
     cannot be used in practice until RPC implementations are updated to
     implement it.
+  * Changes to input_stream.py should make protobuf compatible with appengine.
 
 2008-11-25 version 2.0.3:
 

+ 2 - 1
CONTRIBUTORS.txt

@@ -18,6 +18,7 @@ Proto2 Python primary authors:
   Petar Petrov <petar@google.com>
 
 Large code contributions:
+  Jason Hsueh <jasonh@google.com>
   Joseph Schorr <jschorr@google.com>
   Wenbo Zhu <wenboz@google.com>
 
@@ -38,7 +39,7 @@ Patch contributors:
   Kevin Ko <kevin.s.ko@gmail.com>
     * Small patch to handle trailing slashes in --proto_path flag.
   Johan Euphrosine <proppy@aminche.com>
-    * Small patch to fix Pyhton CallMethod().
+    * Small patch to fix Python CallMethod().
   Ulrich Kunitz <kune@deine-taler.de>
     * Small optimizations to Python serialization.
   Leandro Lucarella <llucax@gmail.com>

+ 3 - 0
Makefile.am

@@ -61,6 +61,8 @@ EXTRA_DIST =                                                                 \
   examples/add_person.py                                                     \
   examples/list_people.py                                                    \
   java/src/main/java/com/google/protobuf/AbstractMessage.java                \
+  java/src/main/java/com/google/protobuf/BlockingRpcChannel.java             \
+  java/src/main/java/com/google/protobuf/BlockingService.java                \
   java/src/main/java/com/google/protobuf/ByteString.java                     \
   java/src/main/java/com/google/protobuf/CodedInputStream.java               \
   java/src/main/java/com/google/protobuf/CodedOutputStream.java              \
@@ -77,6 +79,7 @@ EXTRA_DIST =                                                                 \
   java/src/main/java/com/google/protobuf/RpcController.java                  \
   java/src/main/java/com/google/protobuf/RpcUtil.java                        \
   java/src/main/java/com/google/protobuf/Service.java                        \
+  java/src/main/java/com/google/protobuf/ServiceException.java               \
   java/src/main/java/com/google/protobuf/TextFormat.java                     \
   java/src/main/java/com/google/protobuf/UninitializedMessageException.java  \
   java/src/main/java/com/google/protobuf/UnknownFieldSet.java                \

+ 57 - 1
java/src/main/java/com/google/protobuf/AbstractMessage.java

@@ -32,6 +32,7 @@ package com.google.protobuf;
 
 import com.google.protobuf.Descriptors.FieldDescriptor;
 
+import java.io.FilterInputStream;
 import java.io.InputStream;
 import java.io.IOException;
 import java.io.OutputStream;
@@ -152,6 +153,13 @@ public abstract class AbstractMessage implements Message {
     codedOutput.flush();
   }
 
+  public void writeDelimitedTo(OutputStream output) throws IOException {
+    CodedOutputStream codedOutput = CodedOutputStream.newInstance(output);
+    codedOutput.writeRawVarint32(getSerializedSize());
+    writeTo(codedOutput);
+    codedOutput.flush();
+  }
+
   private int memoizedSize = -1;
 
   public int getSerializedSize() {
@@ -207,7 +215,8 @@ public abstract class AbstractMessage implements Message {
     if (getDescriptorForType() != otherMessage.getDescriptorForType()) {
       return false;
     }
-    return getAllFields().equals(otherMessage.getAllFields());
+    return getAllFields().equals(otherMessage.getAllFields()) &&
+        getUnknownFields().equals(otherMessage.getUnknownFields());
   }
 
   @Override
@@ -215,6 +224,7 @@ public abstract class AbstractMessage implements Message {
     int hash = 41;
     hash = (19 * hash) + getDescriptorForType().hashCode();
     hash = (53 * hash) + getAllFields().hashCode();
+    hash = (29 * hash) + getUnknownFields().hashCode();
     return hash;
   }
 
@@ -397,5 +407,51 @@ public abstract class AbstractMessage implements Message {
       codedInput.checkLastTagWas(0);
       return (BuilderType) this;
     }
+
+    public BuilderType mergeDelimitedFrom(InputStream input,
+                                          ExtensionRegistry extensionRegistry)
+                                          throws IOException {
+      final int size = CodedInputStream.readRawVarint32(input);
+
+      // A stream which will not read more than |size| bytes.
+      InputStream limitedInput = new FilterInputStream(input) {
+        int limit = size;
+
+        @Override
+        public int available() throws IOException {
+          return Math.min(super.available(), limit);
+        }
+
+        @Override
+        public int read() throws IOException {
+          if (limit <= 0) return -1;
+          int result = super.read();
+          if (result >= 0) --limit;
+          return result;
+        }
+
+        @Override
+        public int read(byte[] b, int off, int len) throws IOException {
+          if (limit <= 0) return -1;
+          len = Math.min(len, limit);
+          int result = super.read(b, off, len);
+          if (result >= 0) limit -= result;
+          return result;
+        }
+
+        @Override
+        public long skip(long n) throws IOException {
+          long result = super.skip(Math.min(n, limit));
+          if (result >= 0) limit -= result;
+          return result;
+        }
+      };
+      return mergeFrom(limitedInput, extensionRegistry);
+    }
+
+    public BuilderType mergeDelimitedFrom(InputStream input)
+        throws IOException {
+      return mergeDelimitedFrom(input, ExtensionRegistry.getEmptyRegistry());
+    }
   }
 }

+ 51 - 0
java/src/main/java/com/google/protobuf/BlockingRpcChannel.java

@@ -0,0 +1,51 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package com.google.protobuf;
+
+/**
+ * <p>Abstract interface for a blocking RPC channel.  {@code BlockingRpcChannel}
+ * is the blocking equivalent to {@link RpcChannel}.
+ *
+ * @author kenton@google.com Kenton Varda
+ * @author cpovirk@google.com Chris Povirk
+ */
+public interface BlockingRpcChannel {
+  /**
+   * Calls the given method of the remote service and blocks until it returns.
+   * {@code callBlockingMethod()} is the blocking equivalent to
+   * {@link RpcChannel#callMethod}.
+   */
+  Message callBlockingMethod(
+      Descriptors.MethodDescriptor method,
+      RpcController controller,
+      Message request,
+      Message responsePrototype) throws ServiceException;
+}

+ 64 - 0
java/src/main/java/com/google/protobuf/BlockingService.java

@@ -0,0 +1,64 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package com.google.protobuf;
+
+/**
+ * Blocking equivalent to {@link Service}.
+ *
+ * @author kenton@google.com Kenton Varda
+ * @author cpovirk@google.com Chris Povirk
+ */
+public interface BlockingService {
+  /**
+   * Equivalent to {@link Service#getDescriptorForType}.
+   */
+  Descriptors.ServiceDescriptor getDescriptorForType();
+
+  /**
+   * Equivalent to {@link Service#callMethod}, except that
+   * {@code callBlockingMethod()} returns the result of the RPC or throws a
+   * {@link ServiceException} if there is a failure, rather than passing the
+   * information to a callback.
+   */
+  Message callBlockingMethod(Descriptors.MethodDescriptor method,
+                             RpcController controller,
+                             Message request) throws ServiceException;
+
+  /**
+   * Equivalent to {@link Service#getRequestPrototype}.
+   */
+  Message getRequestPrototype(Descriptors.MethodDescriptor method);
+
+  /**
+   * Equivalent to {@link Service#getResponsePrototype}.
+   */
+  Message getResponsePrototype(Descriptors.MethodDescriptor method);
+}

+ 60 - 2
java/src/main/java/com/google/protobuf/CodedInputStream.java

@@ -77,7 +77,7 @@ public final class CodedInputStream {
    * may legally end wherever a tag occurs, and zero is not a valid tag number.
    */
   public int readTag() throws IOException {
-    if (bufferPos == bufferSize && !refillBuffer(false)) {
+    if (isAtEnd()) {
       lastTag = 0;
       return 0;
     }
@@ -383,6 +383,39 @@ public final class CodedInputStream {
     return result;
   }
 
+  /**
+   * Reads a varint from the input one byte at a time, so that it does not
+   * read any bytes after the end of the varint.  If you simply wrapped the
+   * stream in a CodedInputStream and used {@link #readRawVarint32()}
+   * then you would probably end up reading past the end of the varint since
+   * CodedInputStream buffers its input.
+   */
+  static int readRawVarint32(InputStream input) throws IOException {
+    int result = 0;
+    int offset = 0;
+    for (; offset < 32; offset += 7) {
+      int b = input.read();
+      if (b == -1) {
+        throw InvalidProtocolBufferException.truncatedMessage();
+      }
+      result |= (b & 0x7f) << offset;
+      if ((b & 0x80) == 0) {
+        return result;
+      }
+    }
+    // Keep reading up to 64 bits.
+    for (; offset < 64; offset += 7) {
+      int b = input.read();
+      if (b == -1) {
+        throw InvalidProtocolBufferException.truncatedMessage();
+      }
+      if ((b & 0x80) == 0) {
+        return result;
+      }
+    }
+    throw InvalidProtocolBufferException.malformedVarint();
+  }
+
   /** Read a raw Varint from the stream. */
   public long readRawVarint64() throws IOException {
     int shift = 0;
@@ -526,6 +559,10 @@ public final class CodedInputStream {
    * size limits only apply when reading from an {@code InputStream}, not
    * when constructed around a raw byte array (nor with
    * {@link ByteString#newCodedInput}).
+   * <p>
+   * If you want to read several messages from a single CodedInputStream, you
+   * could call {@link #resetSizeCounter()} after each one to avoid hitting the
+   * size limit.
    *
    * @return the old limit.
    */
@@ -539,6 +576,13 @@ public final class CodedInputStream {
     return oldLimit;
   }
 
+  /**
+   * Resets the current size counter to zero (see {@link #setSizeLimit(int)}).
+   */
+  public void resetSizeCounter() {
+    totalBytesRetired = 0;
+  }
+
   /**
    * Sets {@code currentLimit} to (current position) + {@code byteLimit}.  This
    * is called when descending into a length-delimited embedded message.
@@ -596,6 +640,15 @@ public final class CodedInputStream {
     return currentLimit - currentAbsolutePosition;
   }
 
+  /**
+   * Returns true if the stream has reached the end of the input.  This is the
+   * case if either the end of the underlying input source has been reached or
+   * if the stream has reached a limit created using {@link #pushLimit(int)}.
+   */
+  public boolean isAtEnd() throws IOException {
+    return bufferPos == bufferSize && !refillBuffer(false);
+  }
+
   /**
    * Called when {@code this.buffer} is empty to read more bytes from the
    * input.  If {@code mustSucceed} is true, refillBuffer() guarantees that
@@ -622,6 +675,11 @@ public final class CodedInputStream {
 
     bufferPos = 0;
     bufferSize = (input == null) ? -1 : input.read(buffer);
+    if (bufferSize == 0 || bufferSize < -1) {
+      throw new IllegalStateException(
+          "InputStream#read(byte[]) returned invalid result: " + bufferSize +
+          "\nThe InputStream implementation is buggy.");
+    }
     if (bufferSize == -1) {
       bufferSize = 0;
       if (mustSucceed) {
@@ -778,7 +836,7 @@ public final class CodedInputStream {
       throw InvalidProtocolBufferException.truncatedMessage();
     }
 
-    if (size < bufferSize - bufferPos) {
+    if (size <= bufferSize - bufferPos) {
       // We have all the bytes we need already.
       bufferPos += size;
     } else {

+ 1 - 1
java/src/main/java/com/google/protobuf/Descriptors.java

@@ -1689,7 +1689,7 @@ public final class Descriptors {
 
       GenericDescriptor old =
         descriptorsByName.put(fullName,
-          new PackageDescriptor(fullName, name, file));
+          new PackageDescriptor(name, fullName, file));
       if (old != null) {
         descriptorsByName.put(fullName, old);
         if (!(old instanceof PackageDescriptor)) {

+ 6 - 1
java/src/main/java/com/google/protobuf/DynamicMessage.java

@@ -260,7 +260,8 @@ public final class DynamicMessage extends AbstractMessage {
     }
 
     public DynamicMessage build() {
-      if (!isInitialized()) {
+      // If fields == null, we'll throw an appropriate exception later.
+      if (fields != null && !isInitialized()) {
         throw new UninitializedMessageException(
           new DynamicMessage(type, fields, unknownFields));
       }
@@ -282,6 +283,10 @@ public final class DynamicMessage extends AbstractMessage {
     }
 
     public DynamicMessage buildPartial() {
+      if (fields == null) {
+        throw new IllegalStateException(
+            "build() has already been called on this Builder.");
+      }
       fields.makeImmutable();
       DynamicMessage result =
         new DynamicMessage(type, fields, unknownFields);

+ 41 - 2
java/src/main/java/com/google/protobuf/Message.java

@@ -183,9 +183,28 @@ public interface Message {
    * Serializes the message and writes it to {@code output}.  This is just a
    * trivial wrapper around {@link #writeTo(CodedOutputStream)}.  This does
    * not flush or close the stream.
+   * <p>
+   * NOTE:  Protocol Buffers are not self-delimiting.  Therefore, if you write
+   * any more data to the stream after the message, you must somehow ensure
+   * that the parser on the receiving end does not interpret this as being
+   * part of the protocol message.  This can be done e.g. by writing the size
+   * of the message before the data, then making sure to limit the input to
+   * that size on the receiving end (e.g. by wrapping the InputStream in one
+   * which limits the input).  Alternatively, just use
+   * {@link #writeDelimitedTo(OutputStream)}.
    */
   void writeTo(OutputStream output) throws IOException;
 
+  /**
+   * Like {@link #writeTo(OutputStream)}, but writes the size of the message
+   * as a varint before writing the data.  This allows more data to be written
+   * to the stream after the message without the need to delimit the message
+   * data yourself.  Use {@link Builder#mergeDelimitedFrom(InputStream)} (or
+   * the static method {@code YourMessageType.parseDelimitedFrom(InputStream)})
+   * to parse messages written by this method.
+   */
+  void writeDelimitedTo(OutputStream output) throws IOException;
+
   // =================================================================
   // Builders
 
@@ -434,8 +453,11 @@ public interface Message {
      * {@link #mergeFrom(CodedInputStream)}.  Note that this method always
      * reads the <i>entire</i> input (unless it throws an exception).  If you
      * want it to stop earlier, you will need to wrap your input in some
-     * wrapper stream that limits reading.  Despite usually reading the entire
-     * input, this does not close the stream.
+     * wrapper stream that limits reading.  Or, use
+     * {@link Message#writeDelimitedTo(OutputStream)} to write your message and
+     * {@link #mergeDelimitedFrom(InputStream)} to read it.
+     * <p>
+     * Despite usually reading the entire input, this does not close the stream.
      */
     Builder mergeFrom(InputStream input) throws IOException;
 
@@ -447,5 +469,22 @@ public interface Message {
     Builder mergeFrom(InputStream input,
                       ExtensionRegistry extensionRegistry)
                       throws IOException;
+
+    /**
+     * Like {@link #mergeFrom(InputStream)}, but does not read until EOF.
+     * Instead, the size of the message (encoded as a varint) is read first,
+     * then the message data.  Use
+     * {@link Message#writeDelimitedTo(OutputStream)} to write messages in this
+     * format.
+     */
+    Builder mergeDelimitedFrom(InputStream input)
+                               throws IOException;
+
+    /**
+     * Like {@link #mergeDelimitedFrom(InputStream)} but supporting extensions.
+     */
+    Builder mergeDelimitedFrom(InputStream input,
+                               ExtensionRegistry extensionRegistry)
+                               throws IOException;
   }
 }

+ 3 - 3
java/src/main/java/com/google/protobuf/RpcUtil.java

@@ -39,7 +39,7 @@ public final class RpcUtil {
   private RpcUtil() {}
 
   /**
-   * Take an {@code RcpCallabck<Message>} and convert it to an
+   * Take an {@code RpcCallback<Message>} and convert it to an
    * {@code RpcCallback} accepting a specific message type.  This is always
    * type-safe (parameter type contravariance).
    */
@@ -58,8 +58,8 @@ public final class RpcUtil {
   }
 
   /**
-   * Take an {@code RcpCallabck} accepting a specific message type and convert
-   * it to an {@code RcpCallabck<Message>}.  The generalized callback will
+   * Take an {@code RpcCallback} accepting a specific message type and convert
+   * it to an {@code RpcCallback<Message>}.  The generalized callback will
    * accept any message object which has the same descriptor, and will convert
    * it to the correct class before calling the original callback.  However,
    * if the generalized callback is given a message with a different descriptor,

+ 42 - 0
java/src/main/java/com/google/protobuf/ServiceException.java

@@ -0,0 +1,42 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package com.google.protobuf;
+
+/**
+ * Thrown by blocking RPC methods when a failure occurs.
+ * 
+ * @author cpovirk@google.com (Chris Povirk)
+ */
+public final class ServiceException extends Exception {
+  public ServiceException(String message) {
+    super(message);
+  }
+}

+ 45 - 0
java/src/main/java/com/google/protobuf/UnknownFieldSet.java

@@ -34,6 +34,7 @@ import java.io.InputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.TreeMap;
 import java.util.List;
@@ -85,6 +86,20 @@ public final class UnknownFieldSet {
   }
   private Map<Integer, Field> fields;
 
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) {
+      return true;
+    }
+    return (other instanceof UnknownFieldSet) &&
+        this.fields.equals(((UnknownFieldSet) other).fields);
+  }
+
+  @Override
+  public int hashCode() {
+    return this.fields.hashCode();
+  }
+
   /** Get a map of fields in the set by number. */
   public Map<Integer, Field> asMap() {
     return fields;
@@ -540,6 +555,36 @@ public final class UnknownFieldSet {
      */
     public List<UnknownFieldSet> getGroupList()      { return group;           }
 
+    @Override
+    public boolean equals(Object other) {
+      if (this == other) {
+        return true;
+      }
+      if (!(other instanceof Field)) {
+        return false;
+      }
+      return Arrays.equals(this.getIdentityArray(),
+          ((Field) other).getIdentityArray());
+    }
+
+    @Override
+    public int hashCode() {
+      return Arrays.hashCode(getIdentityArray());
+    }
+
+    /**
+     * Returns the array of objects to be used to uniquely identify this
+     * {@link UnknownFieldSet.Field} instance.
+     */
+    private Object[] getIdentityArray() {
+      return new Object[] {
+          this.varint,
+          this.fixed32,
+          this.fixed64,
+          this.lengthDelimited,
+          this.group};
+    }
+
     /**
      * Serializes the field, including field number, and writes it to
      * {@code output}.

+ 30 - 8
java/src/test/java/com/google/protobuf/AbstractMessageTest.java

@@ -30,14 +30,14 @@
 
 package com.google.protobuf;
 
+import protobuf_unittest.UnittestOptimizeFor.TestOptimizedForSize;
 import protobuf_unittest.UnittestProto;
 import protobuf_unittest.UnittestProto.ForeignMessage;
-import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestAllExtensions;
+import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestPackedTypes;
 import protobuf_unittest.UnittestProto.TestRequired;
 import protobuf_unittest.UnittestProto.TestRequiredForeign;
-import protobuf_unittest.UnittestOptimizeFor.TestOptimizedForSize;
 
 import junit.framework.TestCase;
 
@@ -329,7 +329,7 @@ public class AbstractMessageTest extends TestCase {
   // -----------------------------------------------------------------
   // Tests for equals and hashCode
   
-  public void testEqualsAndHashCode() {
+  public void testEqualsAndHashCode() throws Exception {
     TestAllTypes a = TestUtil.getAllSet();
     TestAllTypes b = TestAllTypes.newBuilder().build();
     TestAllTypes c = TestAllTypes.newBuilder(b).addRepeatedString("x").build();
@@ -337,7 +337,7 @@ public class AbstractMessageTest extends TestCase {
     TestAllExtensions e = TestUtil.getAllExtensionsSet();
     TestAllExtensions f = TestAllExtensions.newBuilder(e)
         .addExtension(UnittestProto.repeatedInt32Extension, 999).build();
-      
+
     checkEqualsIsConsistent(a);
     checkEqualsIsConsistent(b);
     checkEqualsIsConsistent(c);
@@ -364,10 +364,25 @@ public class AbstractMessageTest extends TestCase {
     checkNotEqual(d, f);
 
     checkNotEqual(e, f);
+
+    // Deserializing into the TestEmptyMessage such that every field
+    // is an {@link UnknownFieldSet.Field}.
+    UnittestProto.TestEmptyMessage eUnknownFields =
+        UnittestProto.TestEmptyMessage.parseFrom(e.toByteArray());
+    UnittestProto.TestEmptyMessage fUnknownFields =
+        UnittestProto.TestEmptyMessage.parseFrom(f.toByteArray());
+    checkNotEqual(eUnknownFields, fUnknownFields);
+    checkEqualsIsConsistent(eUnknownFields);
+    checkEqualsIsConsistent(fUnknownFields);
+
+    // Subsequent reconstitutions should be identical
+    UnittestProto.TestEmptyMessage eUnknownFields2 =
+        UnittestProto.TestEmptyMessage.parseFrom(e.toByteArray());
+    checkEqualsIsConsistent(eUnknownFields, eUnknownFields2);
   }
   
   /**
-   * Asserts that the given protos are equal and have the same hash code.
+   * Asserts that the given proto has symmetric equals and hashCode methods.
    */
   private void checkEqualsIsConsistent(Message message) {
     // Object should be equal to itself.
@@ -375,9 +390,16 @@ public class AbstractMessageTest extends TestCase {
     
     // Object should be equal to a dynamic copy of itself.
     DynamicMessage dynamic = DynamicMessage.newBuilder(message).build();
-    assertEquals(message, dynamic);
-    assertEquals(dynamic, message);
-    assertEquals(dynamic.hashCode(), message.hashCode());
+    checkEqualsIsConsistent(message, dynamic);
+  }
+
+  /**
+   * Asserts that the given protos are equal and have the same hash code.
+   */
+  private void checkEqualsIsConsistent(Message message1, Message message2) {
+    assertEquals(message1, message2);
+    assertEquals(message2, message1);
+    assertEquals(message2.hashCode(), message1.hashCode());
   }
 
   /**

+ 61 - 0
java/src/test/java/com/google/protobuf/CodedInputStreamTest.java

@@ -95,6 +95,7 @@ public class CodedInputStreamTest extends TestCase {
 
     input = CodedInputStream.newInstance(data);
     assertEquals(value, input.readRawVarint64());
+    assertTrue(input.isAtEnd());
 
     // Try different block sizes.
     for (int blockSize = 1; blockSize <= 16; blockSize *= 2) {
@@ -105,7 +106,17 @@ public class CodedInputStreamTest extends TestCase {
       input = CodedInputStream.newInstance(
         new SmallBlockInputStream(data, blockSize));
       assertEquals(value, input.readRawVarint64());
+      assertTrue(input.isAtEnd());
     }
+
+    // Try reading direct from an InputStream.  We want to verify that it
+    // doesn't read past the end of the input, so we copy to a new, bigger
+    // array first.
+    byte[] longerData = new byte[data.length + 1];
+    System.arraycopy(data, 0, longerData, 0, data.length);
+    InputStream rawInput = new ByteArrayInputStream(longerData);
+    assertEquals((int)value, CodedInputStream.readRawVarint32(rawInput));
+    assertEquals(1, rawInput.available());
   }
 
   /**
@@ -131,6 +142,14 @@ public class CodedInputStreamTest extends TestCase {
     } catch (InvalidProtocolBufferException e) {
       assertEquals(expected.getMessage(), e.getMessage());
     }
+
+    // Make sure we get the same error when reading direct from an InputStream.
+    try {
+      CodedInputStream.readRawVarint32(new ByteArrayInputStream(data));
+      fail("Should have thrown an exception.");
+    } catch (InvalidProtocolBufferException e) {
+      assertEquals(expected.getMessage(), e.getMessage());
+    }
   }
 
   /** Tests readRawVarint32() and readRawVarint64(). */
@@ -180,12 +199,14 @@ public class CodedInputStreamTest extends TestCase {
                                         throws Exception {
     CodedInputStream input = CodedInputStream.newInstance(data);
     assertEquals(value, input.readRawLittleEndian32());
+    assertTrue(input.isAtEnd());
 
     // Try different block sizes.
     for (int blockSize = 1; blockSize <= 16; blockSize *= 2) {
       input = CodedInputStream.newInstance(
         new SmallBlockInputStream(data, blockSize));
       assertEquals(value, input.readRawLittleEndian32());
+      assertTrue(input.isAtEnd());
     }
   }
 
@@ -197,12 +218,14 @@ public class CodedInputStreamTest extends TestCase {
                                         throws Exception {
     CodedInputStream input = CodedInputStream.newInstance(data);
     assertEquals(value, input.readRawLittleEndian64());
+    assertTrue(input.isAtEnd());
 
     // Try different block sizes.
     for (int blockSize = 1; blockSize <= 16; blockSize *= 2) {
       input = CodedInputStream.newInstance(
         new SmallBlockInputStream(data, blockSize));
       assertEquals(value, input.readRawLittleEndian64());
+      assertTrue(input.isAtEnd());
     }
   }
 
@@ -288,6 +311,20 @@ public class CodedInputStreamTest extends TestCase {
     }
   }
 
+  /**
+   * Test that a bug in skipRawBytes() has been fixed:  if the skip skips
+   * exactly up to a limit, this should not break things.
+   */
+  public void testSkipRawBytesBug() throws Exception {
+    byte[] rawBytes = new byte[] { 1, 2 };
+    CodedInputStream input = CodedInputStream.newInstance(rawBytes);
+
+    int limit = input.pushLimit(1);
+    input.skipRawBytes(1);
+    input.popLimit(limit);
+    assertEquals(2, input.readRawByte());
+  }
+
   public void testReadHugeBlob() throws Exception {
     // Allocate and initialize a 1MB blob.
     byte[] blob = new byte[1 << 20];
@@ -392,6 +429,30 @@ public class CodedInputStreamTest extends TestCase {
     }
   }
 
+  public void testResetSizeCounter() throws Exception {
+    CodedInputStream input = CodedInputStream.newInstance(
+        new SmallBlockInputStream(new byte[256], 8));
+    input.setSizeLimit(16);
+    input.readRawBytes(16);
+
+    try {
+      input.readRawByte();
+      fail("Should have thrown an exception!");
+    } catch (InvalidProtocolBufferException e) {
+      // success.
+    }
+
+    input.resetSizeCounter();
+    input.readRawByte();  // No exception thrown.
+
+    try {
+      input.readRawBytes(16);  // Hits limit again.
+      fail("Should have thrown an exception!");
+    } catch (InvalidProtocolBufferException e) {
+      // success.
+    }
+  }
+
   /**
    * Tests that if we read an string that contains invalid UTF-8, no exception
    * is thrown.  Instead, the invalid bytes are replaced with the Unicode

+ 12 - 0
java/src/test/java/com/google/protobuf/DynamicMessageTest.java

@@ -61,6 +61,18 @@ public class DynamicMessageTest extends TestCase {
     reflectionTester.assertAllFieldsSetViaReflection(message);
   }
 
+  public void testDoubleBuildError() throws Exception {
+    Message.Builder builder =
+      DynamicMessage.newBuilder(TestAllTypes.getDescriptor());
+    builder.build();
+    try {
+      builder.build();
+      fail("Should have thrown exception.");
+    } catch (IllegalStateException e) {
+      // Success.
+    }
+  }
+
   public void testDynamicMessageSettersRejectNull() throws Exception {
     Message.Builder builder =
       DynamicMessage.newBuilder(TestAllTypes.getDescriptor());

+ 11 - 0
java/src/test/java/com/google/protobuf/GeneratedMessageTest.java

@@ -71,6 +71,17 @@ public class GeneratedMessageTest extends TestCase {
     TestUtil.assertAllFieldsSet(message);
   }
 
+  public void testDoubleBuildError() throws Exception {
+    TestAllTypes.Builder builder = TestAllTypes.newBuilder();
+    builder.build();
+    try {
+      builder.build();
+      fail("Should have thrown exception.");
+    } catch (IllegalStateException e) {
+      // Success.
+    }
+  }
+
   public void testSettersRejectNull() throws Exception {
     TestAllTypes.Builder builder = TestAllTypes.newBuilder();
     try {

+ 96 - 3
java/src/test/java/com/google/protobuf/ServiceTest.java

@@ -30,6 +30,10 @@
 
 package com.google.protobuf;
 
+import com.google.protobuf.Descriptors.MethodDescriptor;
+import protobuf_unittest.MessageWithNoOuter;
+import protobuf_unittest.ServiceWithNoOuter;
+import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestService;
 import protobuf_unittest.UnittestProto.FooRequest;
 import protobuf_unittest.UnittestProto.FooResponse;
@@ -56,6 +60,7 @@ public class ServiceTest extends TestCase {
   private final Descriptors.MethodDescriptor barDescriptor =
     TestService.getDescriptor().getMethods().get(1);
 
+  @Override
   protected void setUp() throws Exception {
     super.setUp();
     control = EasyMock.createStrictControl();
@@ -127,6 +132,94 @@ public class ServiceTest extends TestCase {
     control.verify();
   }
 
+  /** Tests generated blocking stubs. */
+  public void testBlockingStub() throws Exception {
+    FooRequest fooRequest = FooRequest.newBuilder().build();
+    BarRequest barRequest = BarRequest.newBuilder().build();
+    BlockingRpcChannel mockChannel =
+        control.createMock(BlockingRpcChannel.class);
+    TestService.BlockingInterface stub =
+        TestService.newBlockingStub(mockChannel);
+
+    FooResponse fooResponse = FooResponse.newBuilder().build();
+    BarResponse barResponse = BarResponse.newBuilder().build();
+
+    EasyMock.expect(mockChannel.callBlockingMethod(
+      EasyMock.same(fooDescriptor),
+      EasyMock.same(mockController),
+      EasyMock.same(fooRequest),
+      EasyMock.same(FooResponse.getDefaultInstance()))).andReturn(fooResponse);
+    EasyMock.expect(mockChannel.callBlockingMethod(
+      EasyMock.same(barDescriptor),
+      EasyMock.same(mockController),
+      EasyMock.same(barRequest),
+      EasyMock.same(BarResponse.getDefaultInstance()))).andReturn(barResponse);
+    control.replay();
+
+    assertSame(fooResponse, stub.foo(mockController, fooRequest));
+    assertSame(barResponse, stub.bar(mockController, barRequest));
+    control.verify();
+  }
+
+  public void testNewReflectiveService() {
+    ServiceWithNoOuter.Interface impl =
+        control.createMock(ServiceWithNoOuter.Interface.class);
+    RpcController controller = control.createMock(RpcController.class);
+    Service service = ServiceWithNoOuter.newReflectiveService(impl);
+
+    MethodDescriptor fooMethod =
+        ServiceWithNoOuter.getDescriptor().findMethodByName("Foo");
+    MessageWithNoOuter request = MessageWithNoOuter.getDefaultInstance();
+    RpcCallback<Message> callback = new RpcCallback<Message>() {
+      public void run(Message parameter) {
+        // No reason this should be run.
+        fail();
+      }
+    };
+    RpcCallback<TestAllTypes> specializedCallback =
+        RpcUtil.specializeCallback(callback);
+
+    impl.foo(EasyMock.same(controller), EasyMock.same(request),
+        EasyMock.same(specializedCallback));
+    EasyMock.expectLastCall();
+
+    control.replay();
+
+    service.callMethod(fooMethod, controller, request, callback);
+
+    control.verify();
+  }
+
+  public void testNewReflectiveBlockingService() throws ServiceException {
+    ServiceWithNoOuter.BlockingInterface impl =
+        control.createMock(ServiceWithNoOuter.BlockingInterface.class);
+    RpcController controller = control.createMock(RpcController.class);
+    BlockingService service =
+        ServiceWithNoOuter.newReflectiveBlockingService(impl);
+
+    MethodDescriptor fooMethod =
+        ServiceWithNoOuter.getDescriptor().findMethodByName("Foo");
+    MessageWithNoOuter request = MessageWithNoOuter.getDefaultInstance();
+    RpcCallback<Message> callback = new RpcCallback<Message>() {
+      public void run(Message parameter) {
+        // No reason this should be run.
+        fail();
+      }
+    };
+
+    TestAllTypes expectedResponse = TestAllTypes.getDefaultInstance();
+    EasyMock.expect(impl.foo(EasyMock.same(controller), EasyMock.same(request)))
+        .andReturn(expectedResponse);
+
+    control.replay();
+
+    Message response =
+        service.callBlockingMethod(fooMethod, controller, request);
+    assertEquals(expectedResponse, response);
+
+    control.verify();
+  }
+
   // =================================================================
 
   /**
@@ -135,7 +228,7 @@ public class ServiceTest extends TestCase {
    * In other words, c wraps the given callback.
    */
   private <Type extends Message> RpcCallback<Type> wrapsCallback(
-      MockCallback callback) {
+      MockCallback<?> callback) {
     EasyMock.reportMatcher(new WrapsCallback(callback));
     return null;
   }
@@ -153,9 +246,9 @@ public class ServiceTest extends TestCase {
 
   /** Implementation of the wrapsCallback() argument matcher. */
   private static class WrapsCallback implements IArgumentMatcher {
-    private MockCallback callback;
+    private MockCallback<?> callback;
 
-    public WrapsCallback(MockCallback callback) {
+    public WrapsCallback(MockCallback<?> callback) {
       this.callback = callback;
     }
 

+ 96 - 3
java/src/test/java/com/google/protobuf/UnknownFieldSetTest.java

@@ -31,13 +31,13 @@
 package com.google.protobuf;
 
 import protobuf_unittest.UnittestProto;
-import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestAllExtensions;
+import protobuf_unittest.UnittestProto.TestAllTypes;
 import protobuf_unittest.UnittestProto.TestEmptyMessage;
-import protobuf_unittest.UnittestProto.
-  TestEmptyMessageWithExtensions;
+import protobuf_unittest.UnittestProto.TestEmptyMessageWithExtensions;
 
 import junit.framework.TestCase;
+
 import java.util.Arrays;
 import java.util.Map;
 
@@ -341,4 +341,97 @@ public class UnknownFieldSetTest extends TestCase {
     assertEquals(1, field.getVarintList().size());
     assertEquals(0x7FFFFFFFFFFFFFFFL, (long)field.getVarintList().get(0));
   }
+
+  public void testEqualsAndHashCode() {
+    UnknownFieldSet.Field fixed32Field =
+        UnknownFieldSet.Field.newBuilder()
+            .addFixed32(1)
+            .build();
+    UnknownFieldSet.Field fixed64Field =
+        UnknownFieldSet.Field.newBuilder()
+            .addFixed64(1)
+            .build();
+    UnknownFieldSet.Field varIntField =
+        UnknownFieldSet.Field.newBuilder()
+            .addVarint(1)
+            .build();
+    UnknownFieldSet.Field lengthDelimitedField =
+        UnknownFieldSet.Field.newBuilder()
+            .addLengthDelimited(ByteString.EMPTY)
+            .build();
+    UnknownFieldSet.Field groupField =
+        UnknownFieldSet.Field.newBuilder()
+            .addGroup(unknownFields)
+            .build();
+
+    UnknownFieldSet a =
+        UnknownFieldSet.newBuilder()
+            .addField(1, fixed32Field)
+            .build();
+    UnknownFieldSet b =
+        UnknownFieldSet.newBuilder()
+            .addField(1, fixed64Field)
+            .build();
+    UnknownFieldSet c =
+        UnknownFieldSet.newBuilder()
+            .addField(1, varIntField)
+            .build();
+    UnknownFieldSet d =
+        UnknownFieldSet.newBuilder()
+            .addField(1, lengthDelimitedField)
+            .build();
+    UnknownFieldSet e =
+        UnknownFieldSet.newBuilder()
+            .addField(1, groupField)
+            .build();
+
+    checkEqualsIsConsistent(a);
+    checkEqualsIsConsistent(b);
+    checkEqualsIsConsistent(c);
+    checkEqualsIsConsistent(d);
+    checkEqualsIsConsistent(e);
+
+    checkNotEqual(a, b);
+    checkNotEqual(a, c);
+    checkNotEqual(a, d);
+    checkNotEqual(a, e);
+    checkNotEqual(b, c);
+    checkNotEqual(b, d);
+    checkNotEqual(b, e);
+    checkNotEqual(c, d);
+    checkNotEqual(c, e);
+    checkNotEqual(d, e);
+  }
+
+  /**
+   * Asserts that the given field sets are not equal and have different
+   * hash codes.
+   *
+   * @warning It's valid for non-equal objects to have the same hash code, so
+   *   this test is stricter than it needs to be. However, this should happen
+   *   relatively rarely.
+   */
+  private void checkNotEqual(UnknownFieldSet s1, UnknownFieldSet s2) {
+    String equalsError = String.format("%s should not be equal to %s", s1, s2);
+    assertFalse(equalsError, s1.equals(s2));
+    assertFalse(equalsError, s2.equals(s1));
+
+    assertFalse(
+        String.format("%s should have a different hash code from %s", s1, s2),
+        s1.hashCode() == s2.hashCode());
+  }
+
+  /**
+   * Asserts that the set equals itself and a copy, with identical hash codes.
+   */
+  private void checkEqualsIsConsistent(UnknownFieldSet set) {
+    // Object should be equal to itself.
+    assertEquals(set, set);
+
+    // Object should be equal to a copy of itself.
+    UnknownFieldSet copy = UnknownFieldSet.newBuilder(set).build();
+    assertEquals(set, copy);
+    assertEquals(copy, set);
+    assertEquals(set.hashCode(), copy.hashCode());
+  }
 }

+ 20 - 0
java/src/test/java/com/google/protobuf/WireFormatTest.java

@@ -31,6 +31,10 @@
 package com.google.protobuf;
 
 import junit.framework.TestCase;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+
 import protobuf_unittest.UnittestProto;
 import protobuf_unittest.UnittestProto.TestAllExtensions;
 import protobuf_unittest.UnittestProto.TestAllTypes;
@@ -130,6 +134,22 @@ public class WireFormatTest extends TestCase {
                  TestUtil.getAllExtensionsSet().getSerializedSize());
   }
 
+  public void testSerializeDelimited() throws Exception {
+    ByteArrayOutputStream output = new ByteArrayOutputStream();
+    TestUtil.getAllSet().writeDelimitedTo(output);
+    output.write(12);
+    TestUtil.getPackedSet().writeDelimitedTo(output);
+    output.write(34);
+
+    ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
+
+    TestUtil.assertAllFieldsSet(TestAllTypes.parseDelimitedFrom(input));
+    assertEquals(12, input.read());
+    TestUtil.assertPackedFieldsSet(TestPackedTypes.parseDelimitedFrom(input));
+    assertEquals(34, input.read());
+    assertEquals(-1, input.read());
+  }
+
   private void assertFieldsInOrder(ByteString data) throws Exception {
     CodedInputStream input = data.newCodedInput();
     int previousTag = 0;

+ 114 - 1
python/google/protobuf/internal/input_stream.py

@@ -36,6 +36,7 @@ the InputStream primitives provided here.
 
 __author__ = 'robinson@google.com (Will Robinson)'
 
+import array
 import struct
 from google.protobuf import message
 from google.protobuf.internal import wire_format
@@ -46,7 +47,7 @@ from google.protobuf.internal import wire_format
 # proto2 implementation.
 
 
-class InputStream(object):
+class InputStreamBuffer(object):
 
   """Contains all logic for reading bits, and dealing with stream position.
 
@@ -223,3 +224,115 @@ class InputStream(object):
       shift += 7
       if not (b & 0x80):
         return result
+
+
+class InputStreamArray(object):
+
+  """Contains all logic for reading bits, and dealing with stream position.
+
+  If an InputStream method ever raises an exception, the stream is left
+  in an indeterminate state and is not safe for further use.
+
+  This alternative to InputStreamBuffer is used in environments where buffer()
+  is unavailable, such as Google App Engine.
+  """
+
+  def __init__(self, s):
+    self._buffer = array.array('B', s)
+    self._pos = 0
+
+  def EndOfStream(self):
+    return self._pos >= len(self._buffer)
+
+  def Position(self):
+    return self._pos
+
+  def GetSubBuffer(self, size=None):
+    if size is None:
+      return self._buffer[self._pos : ].tostring()
+    else:
+      if size < 0:
+        raise message.DecodeError('Negative size %d' % size)
+      return self._buffer[self._pos : self._pos + size].tostring()
+
+  def SkipBytes(self, num_bytes):
+    if num_bytes < 0:
+      raise message.DecodeError('Negative num_bytes %d' % num_bytes)
+    self._pos += num_bytes
+    self._pos = min(self._pos, len(self._buffer))
+
+  def ReadBytes(self, size):
+    if size < 0:
+      raise message.DecodeError('Negative size %d' % size)
+    s = self._buffer[self._pos : self._pos + size].tostring()
+    self._pos += len(s)  # Only advance by the number of bytes actually read.
+    return s
+
+  def ReadLittleEndian32(self):
+    try:
+      i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
+                        self._buffer[self._pos : self._pos + 4])
+      self._pos += 4
+      return i[0]  # unpack() result is a 1-element tuple.
+    except struct.error, e:
+      raise message.DecodeError(e)
+
+  def ReadLittleEndian64(self):
+    try:
+      i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
+                        self._buffer[self._pos : self._pos + 8])
+      self._pos += 8
+      return i[0]  # unpack() result is a 1-element tuple.
+    except struct.error, e:
+      raise message.DecodeError(e)
+
+  def ReadVarint32(self):
+    i = self.ReadVarint64()
+    if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
+      raise message.DecodeError('Value out of range for int32: %d' % i)
+    return int(i)
+
+  def ReadVarUInt32(self):
+    i = self.ReadVarUInt64()
+    if i > wire_format.UINT32_MAX:
+      raise message.DecodeError('Value out of range for uint32: %d' % i)
+    return i
+
+  def ReadVarint64(self):
+    i = self.ReadVarUInt64()
+    if i > wire_format.INT64_MAX:
+      i -= (1 << 64)
+    return i
+
+  def ReadVarUInt64(self):
+    i = self._ReadVarintHelper()
+    if not 0 <= i <= wire_format.UINT64_MAX:
+      raise message.DecodeError('Value out of range for uint64: %d' % i)
+    return i
+
+  def _ReadVarintHelper(self):
+    result = 0
+    shift = 0
+    while 1:
+      if shift >= 64:
+        raise message.DecodeError('Too many bytes when decoding varint.')
+      try:
+        b = self._buffer[self._pos]
+      except IndexError:
+        raise message.DecodeError('Truncated varint.')
+      self._pos += 1
+      result |= ((b & 0x7f) << shift)
+      shift += 7
+      if not (b & 0x80):
+        return result
+
+
+try:
+  buffer('')
+  InputStream = InputStreamBuffer
+except NotImplementedError:
+  # Google App Engine: dev_appserver.py
+  InputStream = InputStreamArray
+except RuntimeError:
+  # Google App Engine: production
+  InputStream = InputStreamArray

+ 20 - 1
python/google/protobuf/internal/input_stream_test.py

@@ -40,7 +40,14 @@ from google.protobuf.internal import wire_format
 from google.protobuf.internal import input_stream
 
 
-class InputStreamTest(unittest.TestCase):
+class InputStreamBufferTest(unittest.TestCase):
+
+  def setUp(self):
+    self.__original_input_stream = input_stream.InputStream
+    input_stream.InputStream = input_stream.InputStreamBuffer
+
+  def tearDown(self):
+    input_stream.InputStream = self.__original_input_stream
 
   def testEndOfStream(self):
     stream = input_stream.InputStream('abcd')
@@ -291,5 +298,17 @@ class InputStreamTest(unittest.TestCase):
     stream = input_stream.InputStream(s)
     self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
 
+
+class InputStreamArrayTest(InputStreamBufferTest):
+
+  def setUp(self):
+    # Test InputStreamArray against the same tests in InputStreamBuffer
+    self.__original_input_stream = input_stream.InputStream
+    input_stream.InputStream = input_stream.InputStreamArray
+
+  def tearDown(self):
+    input_stream.InputStream = self.__original_input_stream
+
+
 if __name__ == '__main__':
   unittest.main()

+ 9 - 0
python/google/protobuf/internal/reflection_test.py

@@ -1102,6 +1102,15 @@ class FullProtosEqualityTest(unittest.TestCase):
     test_util.SetAllFields(self.first_proto)
     test_util.SetAllFields(self.second_proto)
 
+  def testNoneNotEqual(self):
+    self.assertNotEqual(self.first_proto, None)
+    self.assertNotEqual(None, self.second_proto)
+
+  def testNotEqualToOtherMessage(self):
+    third_proto = unittest_pb2.TestRequired()
+    self.assertNotEqual(self.first_proto, third_proto)
+    self.assertNotEqual(third_proto, self.second_proto)
+
   def testAllFieldsFilledEquality(self):
     self.assertEqual(self.first_proto, self.second_proto)
 

+ 4 - 0
python/google/protobuf/reflection.py

@@ -599,6 +599,10 @@ def _AddHasExtensionMethod(cls):
 def _AddEqualsMethod(message_descriptor, cls):
   """Helper for _AddMessageMethods()."""
   def __eq__(self, other):
+    if (not isinstance(other, message_mod.Message) or
+        other.DESCRIPTOR != self.DESCRIPTOR):
+      return False
+
     if self is other:
       return True
 

+ 6 - 0
src/Makefile.am

@@ -35,6 +35,7 @@ MAINTAINERCLEANFILES =   \
 
 nobase_include_HEADERS =                                       \
   google/protobuf/stubs/common.h                               \
+  google/protobuf/stubs/once.h                                 \
   google/protobuf/descriptor.h                                 \
   google/protobuf/descriptor.pb.h                              \
   google/protobuf/descriptor_database.h                        \
@@ -69,6 +70,7 @@ libprotobuf_la_LIBADD = $(PTHREAD_LIBS)
 libprotobuf_la_LDFLAGS = -version-info 3:0:0
 libprotobuf_la_SOURCES =                                       \
   google/protobuf/stubs/common.cc                              \
+  google/protobuf/stubs/once.cc                                \
   google/protobuf/stubs/hash.cc                                \
   google/protobuf/stubs/hash.h                                 \
   google/protobuf/stubs/map-util.cc                            \
@@ -161,6 +163,7 @@ protoc_SOURCES = google/protobuf/compiler/main.cc
 
 protoc_inputs =                                                \
   google/protobuf/unittest.proto                               \
+  google/protobuf/unittest_empty.proto                         \
   google/protobuf/unittest_import.proto                        \
   google/protobuf/unittest_mset.proto                          \
   google/protobuf/unittest_optimize_for.proto                  \
@@ -184,6 +187,8 @@ EXTRA_DIST =                                                   \
 protoc_outputs =                                               \
   google/protobuf/unittest.pb.cc                               \
   google/protobuf/unittest.pb.h                                \
+  google/protobuf/unittest_empty.pb.cc                         \
+  google/protobuf/unittest_empty.pb.h                          \
   google/protobuf/unittest_import.pb.cc                        \
   google/protobuf/unittest_import.pb.h                         \
   google/protobuf/unittest_mset.pb.cc                          \
@@ -223,6 +228,7 @@ protobuf_test_CPPFLAGS = -I$(top_srcdir)/gtest/include         \
                          -I$(top_builddir)/gtest/include
 protobuf_test_SOURCES =                                        \
   google/protobuf/stubs/common_unittest.cc                     \
+  google/protobuf/stubs/once_unittest.cc                       \
   google/protobuf/stubs/strutil_unittest.cc                    \
   google/protobuf/stubs/structurally_valid_unittest.cc         \
   google/protobuf/descriptor_database_unittest.cc              \

+ 21 - 3
src/google/protobuf/compiler/cpp/cpp_enum.cc

@@ -100,6 +100,19 @@ void EnumGenerator::GenerateDefinition(io::Printer* printer) {
     "const $classname$ $prefix$$short_name$_MIN = $prefix$$min_name$;\n"
     "const $classname$ $prefix$$short_name$_MAX = $prefix$$max_name$;\n"
     "\n");
+
+  // The _Name and _Parse methods
+  printer->Print(vars,
+    "inline const ::std::string& $classname$_Name($classname$ value) {\n"
+    "  return ::google::protobuf::internal::NameOfEnum(\n"
+    "    $classname$_descriptor(), value);\n"
+    "}\n");
+  printer->Print(vars,
+    "inline bool $classname$_Parse(\n"
+    "    const ::std::string& name, $classname$* value) {\n"
+    "  return ::google::protobuf::internal::ParseNamedEnum<$classname$>(\n"
+    "    $classname$_descriptor(), name, value);\n"
+    "}\n");
 }
 
 void EnumGenerator::GenerateSymbolImports(io::Printer* printer) {
@@ -122,6 +135,13 @@ void EnumGenerator::GenerateSymbolImports(io::Printer* printer) {
     "static inline bool $nested_name$_IsValid(int value) {\n"
     "  return $classname$_IsValid(value);\n"
     "}\n"
+    "static inline const ::std::string& $nested_name$_Name($nested_name$ value) {\n"
+    "  return $classname$_Name(value);\n"
+    "}\n"
+    "static inline bool $nested_name$_Parse(const ::std::string& name,\n"
+    "    $nested_name$* value) {\n"
+    "  return $classname$_Parse(name, value);\n"
+    "}\n"
     "static const $nested_name$ $nested_name$_MIN =\n"
     "  $classname$_$nested_name$_MIN;\n"
     "static const $nested_name$ $nested_name$_MAX =\n"
@@ -147,12 +167,10 @@ void EnumGenerator::GenerateDescriptorInitializer(
 void EnumGenerator::GenerateMethods(io::Printer* printer) {
   map<string, string> vars;
   vars["classname"] = classname_;
-  vars["builddescriptorsname"] =
-      GlobalBuildDescriptorsName(descriptor_->file()->name());
 
   printer->Print(vars,
     "const ::google::protobuf::EnumDescriptor* $classname$_descriptor() {\n"
-    "  if ($classname$_descriptor_ == NULL) $builddescriptorsname$();\n"
+    "  protobuf_AssignDescriptorsOnce();\n"
     "  return $classname$_descriptor_;\n"
     "}\n"
     "bool $classname$_IsValid(int value) {\n"

+ 52 - 20
src/google/protobuf/compiler/cpp/cpp_enum_field.cc

@@ -116,8 +116,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void EnumFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_($default$)");
+GenerateConstructorCode(io::Printer* printer) const {
+  printer->Print(variables_, "$name$_ = $default$;\n");
 }
 
 void EnumFieldGenerator::
@@ -128,15 +128,22 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
     "if ($type$_IsValid(value)) {\n"
     "  set_$name$(static_cast< $type$ >(value));\n"
     "} else {\n"
-    "  mutable_unknown_fields()->AddField($number$)->add_varint(value);\n"
+    "  mutable_unknown_fields()->AddVarint($number$, value);\n"
     "}\n");
 }
 
 void EnumFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
-    "DO_(::google::protobuf::internal::WireFormat::WriteEnum("
-      "$number$, this->$name$(), output));\n");
+    "::google::protobuf::internal::WireFormat::WriteEnum("
+      "$number$, this->$name$(), output);\n");
+}
+
+void EnumFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "target = ::google::protobuf::internal::WireFormat::WriteEnumToArray("
+      "$number$, this->$name$(), target);\n");
 }
 
 void EnumFieldGenerator::
@@ -217,12 +224,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void RepeatedEnumFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_()");
-  if (descriptor_->options().packed() &&
-      descriptor_->file()->options().optimize_for() == FileOptions::SPEED) {
-    printer->Print(variables_, ",\n_$name$_cached_byte_size_()");
-  }
+GenerateConstructorCode(io::Printer* printer) const {
+  // Not needed for repeated fields.
 }
 
 void RepeatedEnumFieldGenerator::
@@ -248,7 +251,7 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
       "if ($type$_IsValid(value)) {\n"
       "  add_$name$(static_cast< $type$ >(value));\n"
       "} else {\n"
-      "  mutable_unknown_fields()->AddField($number$)->add_varint(value);\n"
+      "  mutable_unknown_fields()->AddVarint($number$, value);\n"
       "}\n");
   }
 }
@@ -259,22 +262,51 @@ GenerateSerializeWithCachedSizes(io::Printer* printer) const {
     // Write the tag and the size.
     printer->Print(variables_,
       "if (this->$name$_size() > 0) {\n"
-      "  DO_(::google::protobuf::internal::WireFormat::WriteTag("
-          "$number$, ::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED,"
-          "output));\n"
-      "  DO_(output->WriteVarint32(_$name$_cached_byte_size_));\n"
+      "  ::google::protobuf::internal::WireFormat::WriteTag("
+          "$number$, "
+          "::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED, "
+          "output);\n"
+      "  output->WriteVarint32(_$name$_cached_byte_size_);\n"
+      "}\n");
+  }
+  printer->Print(variables_,
+      "for (int i = 0; i < this->$name$_size(); i++) {\n");
+  if (descriptor_->options().packed()) {
+    printer->Print(variables_,
+      "  ::google::protobuf::internal::WireFormat::WriteEnumNoTag("
+          "this->$name$(i), output);\n");
+  } else {
+    printer->Print(variables_,
+      "  ::google::protobuf::internal::WireFormat::WriteEnum("
+          "$number$, this->$name$(i), output);\n");
+  }
+  printer->Print("}\n");
+}
+
+void RepeatedEnumFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  if (descriptor_->options().packed()) {
+    // Write the tag and the size.
+    printer->Print(variables_,
+      "if (this->$name$_size() > 0) {\n"
+      "  target = ::google::protobuf::internal::WireFormat::WriteTagToArray("
+          "$number$, "
+          "::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED, "
+          "target);\n"
+      "  target = ::google::protobuf::io::CodedOutputStream::WriteVarint32ToArray("
+          "_$name$_cached_byte_size_, target);\n"
       "}\n");
   }
   printer->Print(variables_,
       "for (int i = 0; i < this->$name$_size(); i++) {\n");
   if (descriptor_->options().packed()) {
     printer->Print(variables_,
-      "  DO_(::google::protobuf::internal::WireFormat::WriteEnumNoTag("
-          "this->$name$(i), output));\n");
+      "  target = ::google::protobuf::internal::WireFormat::WriteEnumNoTagToArray("
+          "this->$name$(i), target);\n");
   } else {
     printer->Print(variables_,
-      "  DO_(::google::protobuf::internal::WireFormat::WriteEnum("
-          "$number$, this->$name$(i), output));\n");
+      "  target = ::google::protobuf::internal::WireFormat::WriteEnumToArray("
+          "$number$, this->$name$(i), target);\n");
   }
   printer->Print("}\n");
 }

+ 4 - 2
src/google/protobuf/compiler/cpp/cpp_enum_field.h

@@ -56,9 +56,10 @@ class EnumFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:
@@ -80,9 +81,10 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:

+ 72 - 9
src/google/protobuf/compiler/cpp/cpp_extension.cc

@@ -36,6 +36,7 @@
 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
 #include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/io/printer.h>
+#include <google/protobuf/descriptor.pb.h>
 
 namespace google {
 namespace protobuf {
@@ -55,7 +56,9 @@ ExtensionGenerator::ExtensionGenerator(const FieldDescriptor* descriptor,
     case FieldDescriptor::CPPTYPE_ENUM:
       type_traits_.append("EnumTypeTraits< ");
       type_traits_.append(ClassName(descriptor_->enum_type(), true));
-      type_traits_.append(" >");
+      type_traits_.append(", ");
+      type_traits_.append(ClassName(descriptor_->enum_type(), true));
+      type_traits_.append("_IsValid>");
       break;
     case FieldDescriptor::CPPTYPE_STRING:
       type_traits_.append("StringTypeTraits");
@@ -81,6 +84,8 @@ void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
   vars["number"       ] = SimpleItoa(descriptor_->number());
   vars["type_traits"  ] = type_traits_;
   vars["name"         ] = descriptor_->name();
+  vars["field_type"   ] = SimpleItoa(static_cast<int>(descriptor_->type()));
+  vars["packed"       ] = descriptor_->options().packed() ? "true" : "false";
   vars["constant_name"] = FieldConstantName(descriptor_);
 
   // If this is a class member, it needs to be declared "static".  Otherwise,
@@ -95,19 +100,39 @@ void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
   printer->Print(vars,
     "static const int $constant_name$ = $number$;\n"
     "$qualifier$ ::google::protobuf::internal::ExtensionIdentifier< $extendee$,\n"
-    "  ::google::protobuf::internal::$type_traits$ > $name$;\n");
+    "    ::google::protobuf::internal::$type_traits$, $field_type$, $packed$ >\n"
+    "  $name$;\n"
+    );
 }
 
 void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
+  // If this is a class member, it needs to be declared in its class scope.
+  string scope = (descriptor_->extension_scope() == NULL) ? "" :
+    ClassName(descriptor_->extension_scope(), false) + "::";
+  string name = scope + descriptor_->name();
+
   map<string, string> vars;
   vars["extendee"     ] = ClassName(descriptor_->containing_type(), true);
   vars["type_traits"  ] = type_traits_;
-  vars["name"         ] = descriptor_->name();
+  vars["name"         ] = name;
   vars["constant_name"] = FieldConstantName(descriptor_);
-
-  // If this is a class member, it needs to be declared in its class scope.
-  vars["scope"] = (descriptor_->extension_scope() == NULL) ? "" :
-    ClassName(descriptor_->extension_scope(), false) + "::";
+  vars["default"      ] = DefaultValue(descriptor_);
+  vars["field_type"   ] = SimpleItoa(static_cast<int>(descriptor_->type()));
+  vars["packed"       ] = descriptor_->options().packed() ? "true" : "false";
+  vars["scope"        ] = scope;
+
+  if (descriptor_->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
+    // We need to declare a global string which will contain the default value.
+    // We cannot declare it at class scope because that would require exposing
+    // it in the header which would be annoying for other reasons.  So we
+    // replace :: with _ in the name and declare it as a global.
+    string global_name = StringReplace(name, "::", "_", true);
+    vars["global_name"] = global_name;
+    printer->Print(vars,
+      "const ::std::string $global_name$_default($default$);\n");
+    // Update the default to refer to the string global.
+    vars["default"] = global_name + "_default";
+  }
 
   // Likewise, class members need to declare the field constant variable.
   if (descriptor_->extension_scope() != NULL) {
@@ -119,8 +144,46 @@ void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
 
   printer->Print(vars,
     "::google::protobuf::internal::ExtensionIdentifier< $extendee$,\n"
-    "  ::google::protobuf::internal::$type_traits$ > $scope$$name$("
-      "$constant_name$);\n");
+    "    ::google::protobuf::internal::$type_traits$, $field_type$, $packed$ >\n"
+    "  $name$($constant_name$, $default$);\n");
+}
+
+void ExtensionGenerator::GenerateRegistration(io::Printer* printer) {
+  map<string, string> vars;
+  vars["extendee"   ] = ClassName(descriptor_->containing_type(), true);
+  vars["number"     ] = SimpleItoa(descriptor_->number());
+  vars["field_type" ] = SimpleItoa(static_cast<int>(descriptor_->type()));
+  vars["is_repeated"] = descriptor_->is_repeated() ? "true" : "false";
+  vars["is_packed"  ] = (descriptor_->is_repeated() &&
+                         descriptor_->options().packed())
+                        ? "true" : "false";
+
+  switch (descriptor_->cpp_type()) {
+    case FieldDescriptor::CPPTYPE_ENUM:
+      printer->Print(vars,
+        "::google::protobuf::internal::ExtensionSet::RegisterEnumExtension(\n"
+        "  &$extendee$::default_instance(),\n"
+        "  $number$, $field_type$, $is_repeated$, $is_packed$,\n");
+      printer->Print(
+        "  &$type$_IsValid);\n",
+        "type", ClassName(descriptor_->enum_type(), true));
+      break;
+    case FieldDescriptor::CPPTYPE_MESSAGE:
+      printer->Print(vars,
+        "::google::protobuf::internal::ExtensionSet::RegisterMessageExtension(\n"
+        "  &$extendee$::default_instance(),\n"
+        "  $number$, $field_type$, $is_repeated$, $is_packed$,\n");
+      printer->Print(
+        "  &$type$::default_instance());\n",
+        "type", ClassName(descriptor_->message_type(), true));
+      break;
+    default:
+      printer->Print(vars,
+        "::google::protobuf::internal::ExtensionSet::RegisterExtension(\n"
+        "  &$extendee$::default_instance(),\n"
+        "  $number$, $field_type$, $is_repeated$, $is_packed$);\n");
+      break;
+  }
 }
 
 }  // namespace cpp

+ 3 - 0
src/google/protobuf/compiler/cpp/cpp_extension.h

@@ -66,6 +66,9 @@ class ExtensionGenerator {
   // Source file stuff.
   void GenerateDefinition(io::Printer* printer);
 
+  // Generate code to register the extension.
+  void GenerateRegistration(io::Printer* printer);
+
  private:
   const FieldDescriptor* descriptor_;
   string type_traits_;

+ 13 - 10
src/google/protobuf/compiler/cpp/cpp_field.h

@@ -94,16 +94,13 @@ class FieldGenerator {
   // message.cc under the GenerateSwap method.
   virtual void GenerateSwappingCode(io::Printer* printer) const = 0;
 
-  // Generate any initializers needed for the private members declared by
-  // GeneratePrivateMembers().  These go into the message class's
-  // constructor's initializer list.  For each initializer, this method
-  // must print the comma and newline separating it from the *previous*
-  // initializer, not the *next* initailizer.  That is, print a ",\n" first,
-  // e.g.:
-  //   printer->Print(",\n$name$_($default$)");
-  virtual void GenerateInitializer(io::Printer* printer) const = 0;
-
-  // Generate any code that needs to go in the class's destructor.
+  // Generate initialization code for private members declared by
+  // GeneratePrivateMembers(). These go into the message class's SharedCtor()
+  // method, invoked by each of the generated constructors.
+  virtual void GenerateConstructorCode(io::Printer* printer) const = 0;
+
+  // Generate any code that needs to go in the class's SharedDtor() method,
+  // invoked by the destructor.
   // Most field types don't need this, so the default implementation is empty.
   virtual void GenerateDestructorCode(io::Printer* printer) const {}
 
@@ -115,6 +112,12 @@ class FieldGenerator {
   // message's SerializeWithCachedSizes() method.
   virtual void GenerateSerializeWithCachedSizes(io::Printer* printer) const = 0;
 
+  // Generate lines to serialize this field directly to the array "target",
+  // which are placed within the message's SerializeWithCachedSizesToArray()
+  // method. This must also advance "target" past the written bytes.
+  virtual void GenerateSerializeWithCachedSizesToArray(
+      io::Printer* printer) const = 0;
+
   // Generate lines to compute the serialized size of this field, which
   // are placed in the message's ByteSize() method.
   virtual void GenerateByteSize(io::Printer* printer) const = 0;

+ 115 - 37
src/google/protobuf/compiler/cpp/cpp_file.cc

@@ -143,17 +143,20 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
   // Open namespace.
   GenerateNamespaceOpeners(printer);
 
-  // Forward-declare the AssignGlobalDescriptors function, so that we can
-  // declare it to be a friend of each class.
+  // Forward-declare the AddDescriptors and AssignDescriptors functions, so
+  // that we can declare them to be friends of each class.
   printer->Print(
     "\n"
     "// Internal implementation detail -- do not call these.\n"
-    "void $dllexport_decl$ $builddescriptorsname$();\n"
-    "void $builddescriptorsname$_AssignGlobalDescriptors(\n"
-    "    ::google::protobuf::FileDescriptor* file);\n"
-    "\n",
-    "builddescriptorsname", GlobalBuildDescriptorsName(file_->name()),
+    "void $dllexport_decl$ $adddescriptorsname$();\n",
+    "adddescriptorsname", GlobalAddDescriptorsName(file_->name()),
     "dllexport_decl", dllexport_decl_);
+  printer->Print(
+    // Note that we don't put dllexport_decl on this because it is only called
+    // by the .pb.cc file in which it is defined.
+    "void $assigndescriptorsname$();\n"
+    "\n",
+    "assigndescriptorsname", GlobalAssignDescriptorsName(file_->name()));
 
   // Generate forward declarations of classes.
   for (int i = 0; i < file_->message_type_count(); i++) {
@@ -232,6 +235,7 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
     "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n"
     "\n"
     "#include \"$basename$.pb.h\"\n"
+    "#include <google/protobuf/stubs/once.h>\n"
     "#include <google/protobuf/descriptor.h>\n"
     "#include <google/protobuf/io/coded_stream.h>\n"
     "#include <google/protobuf/reflection_ops.h>\n"
@@ -296,23 +300,46 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
 }
 
 void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
-  // BuildDescriptors() is a file-level procedure which initializes all of
-  // the Descriptor objects for this file.  It runs the first time one of the
-  // descriptors is accessed.  This will always be at static initialization
-  // time, because every message has a statically-initialized default instance,
-  // and the constructor for a message class accesses its descriptor.  See the
-  // constructor and the descriptor() method of message classes.
+  // AddDescriptors() is a file-level procedure which adds the encoded
+  // FileDescriptorProto for this .proto file to the global DescriptorPool
+  // for generated files (DescriptorPool::generated_pool()).  It always runs
+  // at static initialization time, so all files will be registered before
+  // main() starts.  This procedure also constructs default instances and
+  // registers extensions.
   //
-  // We also construct the reflection object for each class inside
-  // BuildDescriptors().
+  // Its sibling, AssignDescriptors(), actually pulls the compiled
+  // FileDescriptor from the DescriptorPool and uses it to populate all of
+  // the global variables which store pointers to the descriptor objects.
+  // It also constructs the reflection objects.  It is called the first time
+  // anyone calls descriptor() or GetReflection() on one of the types defined
+  // in the file.
 
-  // First we generate a method to assign the global descriptors.
   printer->Print(
     "\n"
-    "void $builddescriptorsname$_AssignGlobalDescriptors("
-    "const ::google::protobuf::FileDescriptor* file) {\n",
-    "builddescriptorsname", GlobalBuildDescriptorsName(file_->name()));
+    "void $assigndescriptorsname$() {\n",
+    "assigndescriptorsname", GlobalAssignDescriptorsName(file_->name()));
   printer->Indent();
+
+  // Make sure the file has found its way into the pool.  If a descriptor
+  // is requested *during* static init then AddDescriptors() may not have
+  // been called yet, so we call it manually.  Note that it's fine if
+  // AddDescriptors() is called multiple times.
+  printer->Print(
+    "$adddescriptorsname$();\n",
+    "adddescriptorsname", GlobalAddDescriptorsName(file_->name()));
+
+  // Get the file's descriptor from the pool.
+  printer->Print(
+    "const ::google::protobuf::FileDescriptor* file =\n"
+    "  ::google::protobuf::DescriptorPool::generated_pool()->FindFileByName(\n"
+    "    \"$filename$\");\n"
+    // Note that this GOOGLE_CHECK is necessary to prevent a warning about "file"
+    // being unused when compiling an empty .proto file.
+    "GOOGLE_CHECK(file != NULL);\n",
+    "filename", file_->name());
+
+  // Go through all the stuff defined in this file and generate code to
+  // assign the global descriptor pointers based on the file descriptor.
   for (int i = 0; i < file_->message_type_count(); i++) {
     message_generators_[i]->GenerateDescriptorInitializer(printer, i);
   }
@@ -322,29 +349,63 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
   for (int i = 0; i < file_->service_count(); i++) {
     service_generators_[i]->GenerateDescriptorInitializer(printer, i);
   }
+
+  printer->Outdent();
+  printer->Print(
+    "}\n"
+    "\n");
+
+  // -----------------------------------------------------------------
+
+  // protobuf_AssignDescriptorsOnce():  The first time it is called, calls
+  // AssignDescriptors().  All later times, waits for the first call to
+  // complete and then returns.
+  printer->Print(
+    "namespace {\n"
+    "\n"
+    "GOOGLE_PROTOBUF_DECLARE_ONCE(protobuf_AssignDescriptors_once_);\n"
+    "inline void protobuf_AssignDescriptorsOnce() {\n"
+    "  ::google::protobuf::GoogleOnceInit(&protobuf_AssignDescriptors_once_,\n"
+    "                 &$assigndescriptorsname$);\n"
+    "}\n"
+    "\n",
+    "assigndescriptorsname", GlobalAssignDescriptorsName(file_->name()));
+
+  // protobuf_RegisterTypes():  Calls
+  // MessageFactory::InternalRegisterGeneratedType() for each message type.
+  printer->Print(
+    "void protobuf_RegisterTypes() {\n"
+    "  protobuf_AssignDescriptorsOnce();\n");
+  printer->Indent();
+
   for (int i = 0; i < file_->message_type_count(); i++) {
-    message_generators_[i]->GenerateDefaultInstanceInitializer(printer);
+    message_generators_[i]->GenerateTypeRegistrations(printer);
   }
 
   printer->Outdent();
   printer->Print(
-    "}\n");
+    "}\n"
+    "\n"
+    "}  // namespace\n");
 
+  // -----------------------------------------------------------------
+
+  // Now generate the AddDescriptors() function.
   printer->Print(
     "\n"
-    "void $builddescriptorsname$() {\n"
+    "void $adddescriptorsname$() {\n"
+    // We don't need any special synchronization here because this code is
+    // called at static init time before any threads exist.
     "  static bool already_here = false;\n"
     "  if (already_here) return;\n"
     "  already_here = true;\n"
     "  GOOGLE_PROTOBUF_VERIFY_VERSION;\n"
-    "  ::google::protobuf::DescriptorPool* pool =\n"
-    "    ::google::protobuf::DescriptorPool::internal_generated_pool();\n"
     "\n",
-    "builddescriptorsname", GlobalBuildDescriptorsName(file_->name()));
+    "adddescriptorsname", GlobalAddDescriptorsName(file_->name()));
   printer->Indent();
 
-  // Call the BuildDescriptors() methods for all of our dependencies, to make
-  // sure they get initialized first.
+  // Call the AddDescriptors() methods for all of our dependencies, to make
+  // sure they get added first.
   for (int i = 0; i < file_->dependency_count(); i++) {
     const FileDescriptor* dependency = file_->dependency(i);
     // Print the namespace prefix for the dependency.
@@ -355,10 +416,10 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
       printer->Print("$name$::",
                      "name", dependency_package_parts[i]);
     }
-    // Call its BuildDescriptors function.
+    // Call its AddDescriptors function.
     printer->Print(
       "$name$();\n",
-      "name", GlobalBuildDescriptorsName(dependency->name()));
+      "name", GlobalAddDescriptorsName(dependency->name()));
   }
 
   // Embed the descriptor.  We simply serialize the entire FileDescriptorProto
@@ -370,7 +431,7 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
   file_proto.SerializeToString(&file_data);
 
   printer->Print(
-    "pool->InternalBuildGeneratedFile(");
+    "::google::protobuf::DescriptorPool::InternalAddGeneratedFile(");
 
   // Only write 40 bytes per line.
   static const int kBytesPerLine = 40;
@@ -379,24 +440,41 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
       "data", CEscape(file_data.substr(i, kBytesPerLine)));
   }
   printer->Print(
-    ", $size$,\n"
-    "&$builddescriptorsname$_AssignGlobalDescriptors);\n",
-    "size", SimpleItoa(file_data.size()),
-    "builddescriptorsname", GlobalBuildDescriptorsName(file_->name()));
+    ", $size$);\n",
+    "size", SimpleItoa(file_data.size()));
+
+  // Call MessageFactory::InternalRegisterGeneratedFile().
+  printer->Print(
+    "::google::protobuf::MessageFactory::InternalRegisterGeneratedFile(\n"
+    "  \"$filename$\", &protobuf_RegisterTypes);\n",
+    "filename", file_->name());
+
+  // Allocate and initialize default instances.  This can't be done lazily
+  // since default instances are returned by simple accessors and are used with
+  // extensions.  Speaking of which, we also register extensions at this time.
+  for (int i = 0; i < file_->message_type_count(); i++) {
+    message_generators_[i]->GenerateDefaultInstanceAllocator(printer);
+  }
+  for (int i = 0; i < file_->extension_count(); i++) {
+    extension_generators_[i]->GenerateRegistration(printer);
+  }
+  for (int i = 0; i < file_->message_type_count(); i++) {
+    message_generators_[i]->GenerateDefaultInstanceInitializer(printer);
+  }
 
   printer->Outdent();
 
   printer->Print(
     "}\n"
     "\n"
-    "// Force BuildDescriptors() to be called at static initialization time.\n"
+    "// Force AddDescriptors() to be called at static initialization time.\n"
     "struct StaticDescriptorInitializer_$filename$ {\n"
     "  StaticDescriptorInitializer_$filename$() {\n"
-    "    $builddescriptorsname$();\n"
+    "    $adddescriptorsname$();\n"
     "  }\n"
     "} static_descriptor_initializer_$filename$_;\n"
     "\n",
-    "builddescriptorsname", GlobalBuildDescriptorsName(file_->name()),
+    "adddescriptorsname", GlobalAddDescriptorsName(file_->name()),
     "filename", FilenameIdentifier(file_->name()));
 }
 

+ 44 - 3
src/google/protobuf/compiler/cpp/cpp_helpers.cc

@@ -38,6 +38,7 @@
 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/stubs/strutil.h>
+#include <google/protobuf/stubs/substitute.h>
 
 namespace google {
 namespace protobuf {
@@ -213,6 +214,41 @@ const char* DeclaredTypeMethodName(FieldDescriptor::Type type) {
   return "";
 }
 
+string DefaultValue(const FieldDescriptor* field) {
+  switch (field->cpp_type()) {
+    case FieldDescriptor::CPPTYPE_INT32:
+      return SimpleItoa(field->default_value_int32());
+    case FieldDescriptor::CPPTYPE_UINT32:
+      return SimpleItoa(field->default_value_uint32()) + "u";
+    case FieldDescriptor::CPPTYPE_INT64:
+      return "GOOGLE_LONGLONG(" + SimpleItoa(field->default_value_int64()) + ")";
+    case FieldDescriptor::CPPTYPE_UINT64:
+      return "GOOGLE_ULONGLONG(" + SimpleItoa(field->default_value_uint64())+ ")";
+    case FieldDescriptor::CPPTYPE_DOUBLE:
+      return SimpleDtoa(field->default_value_double());
+    case FieldDescriptor::CPPTYPE_FLOAT:
+      return SimpleFtoa(field->default_value_float());
+    case FieldDescriptor::CPPTYPE_BOOL:
+      return field->default_value_bool() ? "true" : "false";
+    case FieldDescriptor::CPPTYPE_ENUM:
+      // Lazy:  Generate a static_cast because we don't have a helper function
+      //   that constructs the full name of an enum value.
+      return strings::Substitute(
+          "static_cast< $0 >($1)",
+          ClassName(field->enum_type(), true),
+          field->default_value_enum()->number());
+    case FieldDescriptor::CPPTYPE_STRING:
+      return "\"" + CEscape(field->default_value_string()) + "\"";
+    case FieldDescriptor::CPPTYPE_MESSAGE:
+      return ClassName(field->message_type(), true) + "::default_instance()";
+  }
+  // Can't actually get here; make compiler happy.  (We could add a default
+  // case above but then we wouldn't get the nice compiler warning when a
+  // new type is added.)
+  GOOGLE_LOG(FATAL) << "Can't get here.";
+  return "";
+}
+
 // Convert a file name into a valid identifier.
 string FilenameIdentifier(const string& filename) {
   string result;
@@ -230,9 +266,14 @@ string FilenameIdentifier(const string& filename) {
   return result;
 }
 
-// Return the name of the BuildDescriptors() function for a given file.
-string GlobalBuildDescriptorsName(const string& filename) {
-  return "protobuf_BuildDesc_" + FilenameIdentifier(filename);
+// Return the name of the AddDescriptors() function for a given file.
+string GlobalAddDescriptorsName(const string& filename) {
+  return "protobuf_AddDesc_" + FilenameIdentifier(filename);
+}
+
+// Return the name of the AssignDescriptors() function for a given file.
+string GlobalAssignDescriptorsName(const string& filename) {
+  return "protobuf_AssignDesc_" + FilenameIdentifier(filename);
 }
 
 }  // namespace cpp

+ 8 - 2
src/google/protobuf/compiler/cpp/cpp_helpers.h

@@ -90,11 +90,17 @@ const char* PrimitiveTypeName(FieldDescriptor::CppType type);
 // methods of WireFormat.  For example, TYPE_INT32 becomes "Int32".
 const char* DeclaredTypeMethodName(FieldDescriptor::Type type);
 
+// Get code that evaluates to the field's default value.
+string DefaultValue(const FieldDescriptor* field);
+
 // Convert a file name into a valid identifier.
 string FilenameIdentifier(const string& filename);
 
-// Return the name of the BuildDescriptors() function for a given file.
-string GlobalBuildDescriptorsName(const string& filename);
+// Return the name of the AddDescriptors() function for a given file.
+string GlobalAddDescriptorsName(const string& filename);
+
+// Return the name of the AssignDescriptors() function for a given file.
+string GlobalAssignDescriptorsName(const string& filename);
 
 }  // namespace cpp
 }  // namespace compiler

+ 225 - 230
src/google/protobuf/compiler/cpp/cpp_message.cc

@@ -223,104 +223,10 @@ GenerateFieldAccessorDeclarations(io::Printer* printer) {
   }
 
   if (descriptor_->extension_range_count() > 0) {
-    // Generate accessors for extensions.
-
-    // Normally I'd generate prototypes here and generate the actual
-    // definitions of these methods in GenerateFieldAccessorDefinitions, but
-    // the prototypes for these silly methods are so absurdly complicated that
-    // it meant way too much repitition.
-    //
-    // We use "_proto_TypeTraits" as a type name below because "TypeTraits"
-    // causes problems if the class has a nested message or enum type with that
-    // name and "_TypeTraits" is technically reserved for the C++ library since
-    // it starts with an underscore followed by a capital letter.
+    // Generate accessors for extensions.  We just call a macro located in
+    // extension_set.h since the accessors are about 80 lines of static code.
     printer->Print(
-      // Has, Size, Clear
-      "template <typename _proto_TypeTraits>\n"
-      "inline bool HasExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) const {\n"
-      "  return _extensions_.Has(id.number());\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline void ClearExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) {\n"
-      "  _extensions_.ClearExtension(id.number());\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline int ExtensionSize(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) const {\n"
-      "  return _extensions_.ExtensionSize(id.number());\n"
-      "}\n"
-      "\n"
-
-      // Singular accessors
-      "template <typename _proto_TypeTraits>\n"
-      "inline typename _proto_TypeTraits::ConstType GetExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) const {\n"
-      "  return _proto_TypeTraits::Get(id.number(), _extensions_);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline typename _proto_TypeTraits::MutableType MutableExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) {\n"
-      "  return _proto_TypeTraits::Mutable(id.number(), &_extensions_);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline void SetExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id,\n"
-      "    typename _proto_TypeTraits::ConstType value) {\n"
-      "  _proto_TypeTraits::Set(id.number(), value, &_extensions_);\n"
-      "}\n"
-      "\n"
-
-      // Repeated accessors
-      "template <typename _proto_TypeTraits>\n"
-      "inline typename _proto_TypeTraits::ConstType GetExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id,\n"
-      "    int index) const {\n"
-      "  return _proto_TypeTraits::Get(id.number(), _extensions_, index);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline typename _proto_TypeTraits::MutableType MutableExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id,\n"
-      "    int index) {\n"
-      "  return _proto_TypeTraits::Mutable(id.number(),index,&_extensions_);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline void SetExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id,\n"
-      "    int index, typename _proto_TypeTraits::ConstType value) {\n"
-      "  _proto_TypeTraits::Set(id.number(), index, value, &_extensions_);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline typename _proto_TypeTraits::MutableType AddExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id) {\n"
-      "  return _proto_TypeTraits::Add(id.number(), &_extensions_);\n"
-      "}\n"
-      "\n"
-      "template <typename _proto_TypeTraits>\n"
-      "inline void AddExtension(\n"
-      "    const ::google::protobuf::internal::ExtensionIdentifier<\n"
-      "      $classname$, _proto_TypeTraits>& id,\n"
-      "    typename _proto_TypeTraits::ConstType value) {\n"
-      "  _proto_TypeTraits::Add(id.number(), value, &_extensions_);\n"
-      "}\n",
+      "GOOGLE_PROTOBUF_EXTENSION_ACCESSORS($classname$)\n",
       "classname", classname_);
   }
 }
@@ -391,8 +297,6 @@ GenerateClassDefinition(io::Printer* printer) {
   } else {
     vars["dllexport"] = dllexport_decl_ + " ";
   }
-  vars["builddescriptorsname"] =
-    GlobalBuildDescriptorsName(descriptor_->file()->name());
 
   printer->Print(vars,
     "class $dllexport$$classname$ : public ::google::protobuf::Message {\n"
@@ -433,18 +337,30 @@ GenerateClassDefinition(io::Printer* printer) {
       "void CopyFrom(const $classname$& from);\n"
       "void MergeFrom(const $classname$& from);\n"
       "void Clear();\n"
-      "bool IsInitialized() const;\n"
-      "int ByteSize() const;\n"
-      "\n"
-      "bool MergePartialFromCodedStream(\n"
-      "    ::google::protobuf::io::CodedInputStream* input);\n"
-      "bool SerializeWithCachedSizes(\n"
-      "    ::google::protobuf::io::CodedOutputStream* output) const;\n");
+      "bool IsInitialized() const;\n");
+
+    if (!descriptor_->options().message_set_wire_format()) {
+      // For message_set_wire_format, we don't generate parsing or
+      // serialization code even if optimize_for = SPEED, since MessageSet
+      // encoding is somewhat more complicated than normal extension encoding
+      // and we'd like to avoid having to implement it in multiple places.
+      // WireFormat's implementation is probably good enough.
+      printer->Print(vars,
+        "\n"
+        "int ByteSize() const;\n"
+        "bool MergePartialFromCodedStream(\n"
+        "    ::google::protobuf::io::CodedInputStream* input);\n"
+        "void SerializeWithCachedSizes(\n"
+        "    ::google::protobuf::io::CodedOutputStream* output) const;\n"
+        "::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const;\n");
+    }
   }
 
   printer->Print(vars,
     "int GetCachedSize() const { return _cached_size_; }\n"
     "private:\n"
+    "void SharedCtor();\n"
+    "void SharedDtor();\n"
     "void SetCachedSize(int size) const { _cached_size_ = size; }\n"
     "public:\n"
     "\n"
@@ -505,11 +421,17 @@ GenerateClassDefinition(io::Printer* printer) {
                      .GeneratePrivateMembers(printer);
   }
 
-  // Generate offsets and _has_bits_ boilerplate.
-  printer->Print(vars,
-    "friend void $builddescriptorsname$_AssignGlobalDescriptors(\n"
-    "    const ::google::protobuf::FileDescriptor* file);\n");
+  // Declare AddDescriptors() and AssignDescriptors() as friends so that they
+  // can assign private static variables like default_instance_ and reflection_.
+  printer->Print(
+    "friend void $adddescriptorsname$();\n"
+    "friend void $assigndescriptorsname$();\n",
+    "adddescriptorsname",
+      GlobalAddDescriptorsName(descriptor_->file()->name()),
+    "assigndescriptorsname",
+      GlobalAssignDescriptorsName(descriptor_->file()->name()));
 
+  // Generate offsets and _has_bits_ boilerplate.
   if (descriptor_->field_count() > 0) {
     printer->Print(vars,
       "::google::protobuf::uint32 _has_bits_[($field_count$ + 31) / 32];\n");
@@ -592,12 +514,6 @@ GenerateDescriptorInitializer(io::Printer* printer, int index) {
         "$parent$_descriptor_->nested_type($index$);\n");
   }
 
-  // Construct the default instance.  We can't call InitAsDefaultInstance() yet
-  // because we need to make sure all default instances that this one might
-  // depend on are constructed first.
-  printer->Print(vars,
-    "$classname$::default_instance_ = new $classname$();\n");
-
   // Generate the offsets.
   GenerateOffsets(printer);
 
@@ -622,6 +538,7 @@ GenerateDescriptorInitializer(io::Printer* printer, int index) {
   }
   printer->Print(vars,
     "    ::google::protobuf::DescriptorPool::generated_pool(),\n"
+    "    ::google::protobuf::MessageFactory::generated_factory(),\n"
     "    sizeof($classname$));\n");
 
   // Handle nested types.
@@ -632,11 +549,35 @@ GenerateDescriptorInitializer(io::Printer* printer, int index) {
   for (int i = 0; i < descriptor_->enum_type_count(); i++) {
     enum_generators_[i]->GenerateDescriptorInitializer(printer, i);
   }
+}
 
+void MessageGenerator::
+GenerateTypeRegistrations(io::Printer* printer) {
   // Register this message type with the message factory.
-  printer->Print(vars,
+  printer->Print(
     "::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage(\n"
-    "  $classname$_descriptor_, $classname$::default_instance_);\n");
+    "  $classname$_descriptor_, &$classname$::default_instance());\n",
+    "classname", classname_);
+
+  // Handle nested types.
+  for (int i = 0; i < descriptor_->nested_type_count(); i++) {
+    nested_generators_[i]->GenerateTypeRegistrations(printer);
+  }
+}
+
+void MessageGenerator::
+GenerateDefaultInstanceAllocator(io::Printer* printer) {
+  // Construct the default instance.  We can't call InitAsDefaultInstance() yet
+  // because we need to make sure all default instances that this one might
+  // depend on are constructed first.
+  printer->Print(
+    "$classname$::default_instance_ = new $classname$();\n",
+    "classname", classname_);
+
+  // Handle nested types.
+  for (int i = 0; i < descriptor_->nested_type_count(); i++) {
+    nested_generators_[i]->GenerateDefaultInstanceAllocator(printer);
+  }
 }
 
 void MessageGenerator::
@@ -645,6 +586,11 @@ GenerateDefaultInstanceInitializer(io::Printer* printer) {
     "$classname$::default_instance_->InitAsDefaultInstance();\n",
     "classname", classname_);
 
+  // Register extensions.
+  for (int i = 0; i < descriptor_->extension_count(); i++) {
+    extension_generators_[i]->GenerateRegistration(printer);
+  }
+
   // Handle nested types.
   for (int i = 0; i < descriptor_->nested_type_count(); i++) {
     nested_generators_[i]->GenerateDefaultInstanceInitializer(printer);
@@ -695,14 +641,24 @@ GenerateClassMethods(io::Printer* printer) {
     GenerateClear(printer);
     printer->Print("\n");
 
-    GenerateMergeFromCodedStream(printer);
-    printer->Print("\n");
+    if (!descriptor_->options().message_set_wire_format()) {
+      // For message_set_wire_format, we don't generate parsing or
+      // serialization code even if optimize_for = SPEED, since MessageSet
+      // encoding is somewhat more complicated than normal extension encoding
+      // and we'd like to avoid having to implement it in multiple places.
+      // WireFormat's implementation is probably good enough.
+      GenerateMergeFromCodedStream(printer);
+      printer->Print("\n");
 
-    GenerateSerializeWithCachedSizes(printer);
-    printer->Print("\n");
+      GenerateSerializeWithCachedSizes(printer);
+      printer->Print("\n");
 
-    GenerateByteSize(printer);
-    printer->Print("\n");
+      GenerateSerializeWithCachedSizesToArray(printer);
+      printer->Print("\n");
+
+      GenerateByteSize(printer);
+      printer->Print("\n");
+    }
 
     GenerateMergeFrom(printer);
     printer->Print("\n");
@@ -723,12 +679,10 @@ GenerateClassMethods(io::Printer* printer) {
     "}\n"
     "\n"
     "const ::google::protobuf::Reflection* $classname$::GetReflection() const {\n"
-    "  if ($classname$_reflection_ == NULL) $builddescriptorsname$();\n"
+    "  protobuf_AssignDescriptorsOnce();\n"
     "  return $classname$_reflection_;\n"
     "}\n",
-    "classname", classname_,
-    "builddescriptorsname",
-      GlobalBuildDescriptorsName(descriptor_->file()->name()));
+    "classname", classname_);
 }
 
 void MessageGenerator::
@@ -757,28 +711,68 @@ GenerateInitializerList(io::Printer* printer) {
   printer->Indent();
 
   printer->Print(
-    "::google::protobuf::Message(),\n");
+    "::google::protobuf::Message()");
 
-  if (descriptor_->extension_range_count() > 0) {
-    printer->Print(
-      "_extensions_(&$classname$_descriptor_,\n"
-      "             ::google::protobuf::DescriptorPool::generated_pool(),\n"
-      "             ::google::protobuf::MessageFactory::generated_factory()),\n",
-      "classname", classname_);
-  }
+  printer->Outdent();
+  printer->Outdent();
+}
+
+void MessageGenerator::
+GenerateSharedConstructorCode(io::Printer* printer) {
+  printer->Print(
+    "void $classname$::SharedCtor() {\n",
+    "classname", classname_);
+  printer->Indent();
 
   printer->Print(
-    "_unknown_fields_(),\n"
-    "_cached_size_(0)");
+    "_cached_size_ = 0;\n");
 
-  // Write the initializers for each field.
   for (int i = 0; i < descriptor_->field_count(); i++) {
     field_generators_.get(descriptor_->field(i))
-                     .GenerateInitializer(printer);
+                     .GenerateConstructorCode(printer);
   }
 
+  printer->Print(
+    "::memset(_has_bits_, 0, sizeof(_has_bits_));\n");
+
   printer->Outdent();
+  printer->Print("}\n\n");
+}
+
+void MessageGenerator::
+GenerateSharedDestructorCode(io::Printer* printer) {
+  printer->Print(
+    "void $classname$::SharedDtor() {\n",
+    "classname", classname_);
+  printer->Indent();
+  // Write the destructors for each field.
+  for (int i = 0; i < descriptor_->field_count(); i++) {
+    field_generators_.get(descriptor_->field(i))
+                     .GenerateDestructorCode(printer);
+  }
+
+  printer->Print(
+    "if (this != default_instance_) {\n");
+
+  // We need to delete all embedded messages.
+  // TODO(kenton):  If we make unset messages point at default instances
+  //   instead of NULL, then it would make sense to move this code into
+  //   MessageFieldGenerator::GenerateDestructorCode().
+  for (int i = 0; i < descriptor_->field_count(); i++) {
+    const FieldDescriptor* field = descriptor_->field(i);
+
+    if (!field->is_repeated() &&
+        field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+      printer->Print("  delete $name$_;\n",
+                     "name", FieldName(field));
+    }
+  }
+
   printer->Outdent();
+  printer->Print(
+    "  }\n"
+    "}\n"
+    "\n");
 }
 
 void MessageGenerator::
@@ -790,7 +784,7 @@ GenerateStructors(io::Printer* printer) {
       "classname", classname_);
   GenerateInitializerList(printer);
   printer->Print(" {\n"
-    "  ::memset(_has_bits_, 0, sizeof(_has_bits_));\n"
+    "  SharedCtor();\n"
     "}\n");
 
   printer->Print(
@@ -826,54 +820,33 @@ GenerateStructors(io::Printer* printer) {
       "classname", classname_);
   GenerateInitializerList(printer);
   printer->Print(" {\n"
-    "  ::memset(_has_bits_, 0, sizeof(_has_bits_));\n"
+    "  SharedCtor();\n"
     "  MergeFrom(from);\n"
     "}\n"
     "\n");
 
+  // Generate the shared constructor code.
+  GenerateSharedConstructorCode(printer);
+
   // Generate the destructor.
   printer->Print(
-    "$classname$::~$classname$() {\n",
+    "$classname$::~$classname$() {\n"
+    "  SharedDtor();\n"
+    "}\n"
+    "\n",
     "classname", classname_);
 
-  printer->Indent();
-
-  // Write the destructors for each field.
-  for (int i = 0; i < descriptor_->field_count(); i++) {
-    field_generators_.get(descriptor_->field(i))
-                     .GenerateDestructorCode(printer);
-  }
+  // Generate the shared destructor code.
+  GenerateSharedDestructorCode(printer);
 
   printer->Print(
-    "if (this != default_instance_) {\n");
-
-  // We need to delete all embedded messages.
-  // TODO(kenton):  If we make unset messages point at default instances
-  //   instead of NULL, then it would make sense to move this code into
-  //   MessageFieldGenerator::GenerateDestructorCode().
-  for (int i = 0; i < descriptor_->field_count(); i++) {
-    const FieldDescriptor* field = descriptor_->field(i);
-
-    if (!field->is_repeated() &&
-        field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
-      printer->Print("  delete $name$_;\n",
-                     "name", FieldName(field));
-    }
-  }
-
-  printer->Outdent();
-
-  printer->Print(
-    "  }\n"
-    "}\n"
-    "\n"
     "const ::google::protobuf::Descriptor* $classname$::descriptor() {\n"
-    "  if ($classname$_descriptor_ == NULL) $builddescriptorsname$();\n"
+    "  protobuf_AssignDescriptorsOnce();\n"
     "  return $classname$_descriptor_;\n"
     "}\n"
     "\n"
     "const $classname$& $classname$::default_instance() {\n"
-    "  if (default_instance_ == NULL) $builddescriptorsname$();\n"
+    "  if (default_instance_ == NULL) $adddescriptorsname$();"
     "  return *default_instance_;\n"
     "}\n"
     "\n"
@@ -883,8 +856,8 @@ GenerateStructors(io::Printer* printer) {
     "  return new $classname$;\n"
     "}\n",
     "classname", classname_,
-    "builddescriptorsname",
-    GlobalBuildDescriptorsName(descriptor_->file()->name()));
+    "adddescriptorsname",
+    GlobalAddDescriptorsName(descriptor_->file()->name()));
 }
 
 void MessageGenerator::
@@ -1127,24 +1100,6 @@ GenerateCopyFrom(io::Printer* printer) {
 
 void MessageGenerator::
 GenerateMergeFromCodedStream(io::Printer* printer) {
-  if (descriptor_->options().message_set_wire_format()) {
-    // For message_set_wire_format, we don't generate a parser, for two
-    // reasons:
-    // - WireFormat already needs to special-case this, and we'd like to
-    //   avoid having multiple implementations of MessageSet wire format
-    //   lying around the code base.
-    // - All fields are extensions, and extension parsing falls back to
-    //   reflection anyway, so it wouldn't be any faster.
-    printer->Print(
-      "bool $classname$::MergePartialFromCodedStream(\n"
-      "    ::google::protobuf::io::CodedInputStream* input) {\n"
-      "  return ::google::protobuf::internal::WireFormat::ParseAndMergePartial(\n"
-      "    input, this);\n"
-      "}\n",
-      "classname", classname_);
-    return;
-  }
-
   printer->Print(
     "bool $classname$::MergePartialFromCodedStream(\n"
     "    ::google::protobuf::io::CodedInputStream* input) {\n"
@@ -1267,7 +1222,8 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
       }
     }
     printer->Print(") {\n"
-      "  DO_(_extensions_.ParseField(tag, input, this));\n"
+      "  DO_(_extensions_.ParseField(tag, input, default_instance_,\n"
+      "                              mutable_unknown_fields()));\n"
       "  continue;\n"
       "}\n");
   }
@@ -1295,7 +1251,7 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
 }
 
 void MessageGenerator::GenerateSerializeOneField(
-    io::Printer* printer, const FieldDescriptor* field) {
+    io::Printer* printer, const FieldDescriptor* field, bool to_array) {
   PrintFieldComment(printer, field);
 
   if (!field->is_repeated()) {
@@ -1305,7 +1261,12 @@ void MessageGenerator::GenerateSerializeOneField(
     printer->Indent();
   }
 
-  field_generators_.get(field).GenerateSerializeWithCachedSizes(printer);
+  if (to_array) {
+    field_generators_.get(field).GenerateSerializeWithCachedSizesToArray(
+        printer);
+  } else {
+    field_generators_.get(field).GenerateSerializeWithCachedSizes(printer);
+  }
 
   if (!field->is_repeated()) {
     printer->Outdent();
@@ -1315,25 +1276,66 @@ void MessageGenerator::GenerateSerializeOneField(
 }
 
 void MessageGenerator::GenerateSerializeOneExtensionRange(
-    io::Printer* printer, const Descriptor::ExtensionRange* range) {
+    io::Printer* printer, const Descriptor::ExtensionRange* range,
+    bool to_array) {
   map<string, string> vars;
   vars["start"] = SimpleItoa(range->start);
   vars["end"] = SimpleItoa(range->end);
   printer->Print(vars,
-    "// Extension range [$start$, $end$)\n"
-    "DO_(_extensions_.SerializeWithCachedSizes(\n"
-    "    $start$, $end$, *this, output));\n\n");
+    "// Extension range [$start$, $end$)\n");
+  if (to_array) {
+    printer->Print(vars,
+      "target = _extensions_.SerializeWithCachedSizesToArray(\n"
+      "    $start$, $end$, target);\n\n");
+  } else {
+    printer->Print(vars,
+      "_extensions_.SerializeWithCachedSizes(\n"
+      "    $start$, $end$, output);\n\n");
+  }
 }
 
 void MessageGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) {
   printer->Print(
-    "bool $classname$::SerializeWithCachedSizes(\n"
-    "    ::google::protobuf::io::CodedOutputStream* output) const {\n"
-    "#define DO_(EXPRESSION) if (!(EXPRESSION)) return false\n",
+    "void $classname$::SerializeWithCachedSizes(\n"
+    "    ::google::protobuf::io::CodedOutputStream* output) const {\n",
+    "classname", classname_);
+  printer->Indent();
+
+  printer->Print(
+    "::google::protobuf::uint8* raw_buffer = "
+      "output->GetDirectBufferForNBytesAndAdvance(_cached_size_);\n"
+    "if (raw_buffer != NULL) {\n"
+    "  $classname$::SerializeWithCachedSizesToArray(raw_buffer);\n"
+    "  return;\n"
+    "}\n"
+    "\n",
+    "classname", classname_);
+  GenerateSerializeWithCachedSizesBody(printer, false);
+
+  printer->Outdent();
+  printer->Print(
+    "}\n");
+}
+
+void MessageGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) {
+  printer->Print(
+    "::google::protobuf::uint8* $classname$::SerializeWithCachedSizesToArray(\n"
+    "    ::google::protobuf::uint8* target) const {\n",
     "classname", classname_);
   printer->Indent();
 
+  GenerateSerializeWithCachedSizesBody(printer, true);
+
+  printer->Outdent();
+  printer->Print(
+    "  return target;\n"
+    "}\n");
+}
+
+void MessageGenerator::
+GenerateSerializeWithCachedSizesBody(io::Printer* printer, bool to_array) {
   scoped_array<const FieldDescriptor*> ordered_fields(
     SortFieldsByNumber(descriptor_));
 
@@ -1350,35 +1352,35 @@ GenerateSerializeWithCachedSizes(io::Printer* printer) {
        i < descriptor_->field_count() || j < sorted_extensions.size();
        ) {
     if (i == descriptor_->field_count()) {
-      GenerateSerializeOneExtensionRange(printer, sorted_extensions[j++]);
+      GenerateSerializeOneExtensionRange(printer,
+                                         sorted_extensions[j++],
+                                         to_array);
     } else if (j == sorted_extensions.size()) {
-      GenerateSerializeOneField(printer, ordered_fields[i++]);
+      GenerateSerializeOneField(printer, ordered_fields[i++], to_array);
     } else if (ordered_fields[i]->number() < sorted_extensions[j]->start) {
-      GenerateSerializeOneField(printer, ordered_fields[i++]);
+      GenerateSerializeOneField(printer, ordered_fields[i++], to_array);
     } else {
-      GenerateSerializeOneExtensionRange(printer, sorted_extensions[j++]);
+      GenerateSerializeOneExtensionRange(printer,
+                                         sorted_extensions[j++],
+                                         to_array);
     }
   }
 
   printer->Print("if (!unknown_fields().empty()) {\n");
   printer->Indent();
-  if (descriptor_->options().message_set_wire_format()) {
+  if (to_array) {
     printer->Print(
-      "DO_(::google::protobuf::internal::WireFormat::SerializeUnknownMessageSetItems(\n"
-      "    unknown_fields(), output));\n");
+      "target = "
+          "::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray(\n"
+      "    unknown_fields(), target);\n");
   } else {
     printer->Print(
-      "DO_(::google::protobuf::internal::WireFormat::SerializeUnknownFields(\n"
-      "    unknown_fields(), output));\n");
+      "::google::protobuf::internal::WireFormat::SerializeUnknownFields(\n"
+      "    unknown_fields(), output);\n");
   }
   printer->Outdent();
-  printer->Print(
-    "}\n"
-    "return true;\n");
 
-  printer->Outdent();
   printer->Print(
-    "#undef DO_\n"
     "}\n");
 }
 
@@ -1449,23 +1451,16 @@ GenerateByteSize(io::Printer* printer) {
 
   if (descriptor_->extension_range_count() > 0) {
     printer->Print(
-      "total_size += _extensions_.ByteSize(*this);\n"
+      "total_size += _extensions_.ByteSize();\n"
       "\n");
   }
 
   printer->Print("if (!unknown_fields().empty()) {\n");
   printer->Indent();
-  if (descriptor_->options().message_set_wire_format()) {
-    printer->Print(
-      "total_size +=\n"
-      "  ::google::protobuf::internal::WireFormat::ComputeUnknownMessageSetItemsSize(\n"
-      "    unknown_fields());\n");
-  } else {
-    printer->Print(
-      "total_size +=\n"
-      "  ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize(\n"
-      "    unknown_fields());\n");
-  }
+  printer->Print(
+    "total_size +=\n"
+    "  ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize(\n"
+    "    unknown_fields());\n");
   printer->Outdent();
   printer->Print("}\n");
 

+ 26 - 3
src/google/protobuf/compiler/cpp/cpp_message.h

@@ -86,7 +86,16 @@ class MessageGenerator {
   // descriptor.
   void GenerateDescriptorInitializer(io::Printer* printer, int index);
 
-  // Generates code that initializes the message's default instance.
+  // Generate code that calls MessageFactory::InternalRegisterGeneratedMessage()
+  // for all types.
+  void GenerateTypeRegistrations(io::Printer* printer);
+
+  // Generates code that allocates the message's default instance.
+  void GenerateDefaultInstanceAllocator(io::Printer* printer);
+
+  // Generates code that initializes the message's default instance.  This
+  // is separate from allocating because all default instances must be
+  // allocated before any can be initialized.
   void GenerateDefaultInstanceInitializer(io::Printer* printer);
 
   // Generate all non-inline methods for this class.
@@ -103,6 +112,15 @@ class MessageGenerator {
   // Generate constructors and destructor.
   void GenerateStructors(io::Printer* printer);
 
+  // The compiler typically generates multiple copies of each constructor and
+  // destructor: http://gcc.gnu.org/bugs.html#nonbugs_cxx
+  // Placing common code in a separate method reduces the generated code size.
+  //
+  // Generate the shared constructor code.
+  void GenerateSharedConstructorCode(io::Printer* printer);
+  // Generate the shared destructor code.
+  void GenerateSharedDestructorCode(io::Printer* printer);
+
   // Generate the member initializer list for the constructors. The member
   // initializer list is shared between the default constructor and the copy
   // constructor.
@@ -112,6 +130,9 @@ class MessageGenerator {
   void GenerateClear(io::Printer* printer);
   void GenerateMergeFromCodedStream(io::Printer* printer);
   void GenerateSerializeWithCachedSizes(io::Printer* printer);
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer);
+  void GenerateSerializeWithCachedSizesBody(io::Printer* printer,
+                                            bool to_array);
   void GenerateByteSize(io::Printer* printer);
   void GenerateMergeFrom(io::Printer* printer);
   void GenerateCopyFrom(io::Printer* printer);
@@ -120,9 +141,11 @@ class MessageGenerator {
 
   // Helpers for GenerateSerializeWithCachedSizes().
   void GenerateSerializeOneField(io::Printer* printer,
-                                 const FieldDescriptor* field);
+                                 const FieldDescriptor* field,
+                                 bool unbounded);
   void GenerateSerializeOneExtensionRange(
-      io::Printer* printer, const Descriptor::ExtensionRange* range);
+      io::Printer* printer, const Descriptor::ExtensionRange* range,
+      bool unbounded);
 
   const Descriptor* descriptor_;
   string classname_;

+ 26 - 8
src/google/protobuf/compiler/cpp/cpp_message_field.cc

@@ -116,8 +116,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void MessageFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_(NULL)");
+GenerateConstructorCode(io::Printer* printer) const {
+  printer->Print(variables_, "$name$_ = NULL;\n");
 }
 
 void MessageFieldGenerator::
@@ -136,8 +136,16 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
 void MessageFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
-    "DO_(::google::protobuf::internal::WireFormat::Write$declared_type$NoVirtual("
-      "$number$, this->$name$(), output));\n");
+    "::google::protobuf::internal::WireFormat::Write$declared_type$NoVirtual("
+      "$number$, this->$name$(), output);\n");
+}
+
+void MessageFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "target = ::google::protobuf::internal::WireFormat::"
+      "Write$declared_type$NoVirtualToArray("
+      "$number$, this->$name$(), target);\n");
 }
 
 void MessageFieldGenerator::
@@ -212,8 +220,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void RepeatedMessageFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_()");
+GenerateConstructorCode(io::Printer* printer) const {
+  // Not needed for repeated fields.
 }
 
 void RepeatedMessageFieldGenerator::
@@ -233,8 +241,18 @@ void RepeatedMessageFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
     "for (int i = 0; i < this->$name$_size(); i++) {\n"
-    "  DO_(::google::protobuf::internal::WireFormat::Write$declared_type$NoVirtual("
-        "$number$, this->$name$(i), output));\n"
+    "  ::google::protobuf::internal::WireFormat::Write$declared_type$NoVirtual("
+        "$number$, this->$name$(i), output);\n"
+    "}\n");
+}
+
+void RepeatedMessageFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "for (int i = 0; i < this->$name$_size(); i++) {\n"
+    "  target = ::google::protobuf::internal::WireFormat::"
+        "Write$declared_type$NoVirtualToArray("
+        "$number$, this->$name$(i), target);\n"
     "}\n");
 }
 

+ 4 - 2
src/google/protobuf/compiler/cpp/cpp_message_field.h

@@ -56,9 +56,10 @@ class MessageFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:
@@ -80,9 +81,10 @@ class RepeatedMessageFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:

+ 52 - 47
src/google/protobuf/compiler/cpp/cpp_primitive_field.cc

@@ -79,35 +79,6 @@ int FixedSize(FieldDescriptor::Type type) {
   return -1;
 }
 
-string DefaultValue(const FieldDescriptor* field) {
-  switch (field->cpp_type()) {
-    case FieldDescriptor::CPPTYPE_INT32:
-      return SimpleItoa(field->default_value_int32());
-    case FieldDescriptor::CPPTYPE_UINT32:
-      return SimpleItoa(field->default_value_uint32()) + "u";
-    case FieldDescriptor::CPPTYPE_INT64:
-      return "GOOGLE_LONGLONG(" + SimpleItoa(field->default_value_int64()) + ")";
-    case FieldDescriptor::CPPTYPE_UINT64:
-      return "GOOGLE_ULONGLONG(" + SimpleItoa(field->default_value_uint64())+ ")";
-    case FieldDescriptor::CPPTYPE_DOUBLE:
-      return SimpleDtoa(field->default_value_double());
-    case FieldDescriptor::CPPTYPE_FLOAT:
-      return SimpleFtoa(field->default_value_float());
-    case FieldDescriptor::CPPTYPE_BOOL:
-      return field->default_value_bool() ? "true" : "false";
-
-    case FieldDescriptor::CPPTYPE_ENUM:
-    case FieldDescriptor::CPPTYPE_STRING:
-    case FieldDescriptor::CPPTYPE_MESSAGE:
-      GOOGLE_LOG(FATAL) << "Shouldn't get here.";
-      return "";
-  }
-  // Can't actually get here; make compiler happy.  (We could add a default
-  // case above but then we wouldn't get the nice compiler warning when a
-  // new type is added.)
-  return "";
-}
-
 // TODO(kenton):  Factor out a "SetCommonFieldVariables()" to get rid of
 //   repeat code between this and the other field types.
 void SetPrimitiveVariables(const FieldDescriptor* descriptor,
@@ -180,8 +151,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void PrimitiveFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_($default$)");
+GenerateConstructorCode(io::Printer* printer) const {
+  printer->Print(variables_, "$name$_ = $default$;\n");
 }
 
 void PrimitiveFieldGenerator::
@@ -195,8 +166,15 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
 void PrimitiveFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
-    "DO_(::google::protobuf::internal::WireFormat::Write$declared_type$("
-      "$number$, this->$name$(), output));\n");
+    "::google::protobuf::internal::WireFormat::Write$declared_type$("
+      "$number$, this->$name$(), output);\n");
+}
+
+void PrimitiveFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "target = ::google::protobuf::internal::WireFormat::Write$declared_type$ToArray("
+      "$number$, this->$name$(), target);\n");
 }
 
 void PrimitiveFieldGenerator::
@@ -282,12 +260,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void RepeatedPrimitiveFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_()");
-  if (descriptor_->options().packed() &&
-      descriptor_->file()->options().optimize_for() == FileOptions::SPEED) {
-    printer->Print(variables_, ",\n_$name$_cached_byte_size_()");
-  }
+GenerateConstructorCode(io::Printer* printer) const {
+  // Not needed for repeated fields.
 }
 
 void RepeatedPrimitiveFieldGenerator::
@@ -324,22 +298,53 @@ GenerateSerializeWithCachedSizes(io::Printer* printer) const {
     // Write the tag and the size.
     printer->Print(variables_,
       "if (this->$name$_size() > 0) {\n"
-      "  DO_(::google::protobuf::internal::WireFormat::WriteTag("
-          "$number$, ::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED,"
-          "output));\n"
-      "  DO_(output->WriteVarint32(_$name$_cached_byte_size_));\n"
+      "  ::google::protobuf::internal::WireFormat::WriteTag("
+          "$number$, "
+          "::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED, "
+          "output);\n"
+      "  output->WriteVarint32(_$name$_cached_byte_size_);\n"
+      "}\n");
+  }
+  printer->Print(variables_,
+      "for (int i = 0; i < this->$name$_size(); i++) {\n");
+  if (descriptor_->options().packed()) {
+    printer->Print(variables_,
+      "  ::google::protobuf::internal::WireFormat::Write$declared_type$NoTag("
+          "this->$name$(i), output);\n");
+  } else {
+    printer->Print(variables_,
+      "  ::google::protobuf::internal::WireFormat::Write$declared_type$("
+          "$number$, this->$name$(i), output);\n");
+  }
+  printer->Print("}\n");
+}
+
+void RepeatedPrimitiveFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  if (descriptor_->options().packed()) {
+    // Write the tag and the size.
+    printer->Print(variables_,
+      "if (this->$name$_size() > 0) {\n"
+      "  target = ::google::protobuf::internal::WireFormat::WriteTagToArray("
+          "$number$, "
+          "::google::protobuf::internal::WireFormat::WIRETYPE_LENGTH_DELIMITED, "
+          "target);\n"
+      "  target = ::google::protobuf::io::CodedOutputStream::WriteVarint32ToArray("
+          "_$name$_cached_byte_size_, target);\n"
       "}\n");
   }
   printer->Print(variables_,
       "for (int i = 0; i < this->$name$_size(); i++) {\n");
   if (descriptor_->options().packed()) {
     printer->Print(variables_,
-      "  DO_(::google::protobuf::internal::WireFormat::Write$declared_type$NoTag("
-          "this->$name$(i), output));\n");
+      "  target = ::google::protobuf::internal::WireFormat::"
+          "Write$declared_type$NoTagToArray("
+          "this->$name$(i), target);\n");
   } else {
     printer->Print(variables_,
-      "  DO_(::google::protobuf::internal::WireFormat::Write$declared_type$("
-          "$number$, this->$name$(i), output));\n");
+      "  target = ::google::protobuf::internal::WireFormat::"
+          "Write$declared_type$ToArray("
+          "$number$, this->$name$(i), target);\n");
   }
   printer->Print("}\n");
 }

+ 4 - 2
src/google/protobuf/compiler/cpp/cpp_primitive_field.h

@@ -56,9 +56,10 @@ class PrimitiveFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:
@@ -80,9 +81,10 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:

+ 4 - 2
src/google/protobuf/compiler/cpp/cpp_service.cc

@@ -176,10 +176,12 @@ void ServiceGenerator::GenerateImplementation(io::Printer* printer) {
     "$classname$::~$classname$() {}\n"
     "\n"
     "const ::google::protobuf::ServiceDescriptor* $classname$::descriptor() {\n"
+    "  protobuf_AssignDescriptorsOnce();\n"
     "  return $classname$_descriptor_;\n"
     "}\n"
     "\n"
     "const ::google::protobuf::ServiceDescriptor* $classname$::GetDescriptor() {\n"
+    "  protobuf_AssignDescriptorsOnce();\n"
     "  return $classname$_descriptor_;\n"
     "}\n"
     "\n");
@@ -279,7 +281,7 @@ void ServiceGenerator::GenerateGetPrototype(RequestOrResponse which,
 
   printer->Print(vars_,
     "    const ::google::protobuf::MethodDescriptor* method) const {\n"
-    "  GOOGLE_DCHECK_EQ(method->service(), $classname$_descriptor_);\n"
+    "  GOOGLE_DCHECK_EQ(method->service(), descriptor());\n"
     "  switch(method->index()) {\n");
 
   for (int i = 0; i < descriptor_->method_count(); i++) {
@@ -320,7 +322,7 @@ void ServiceGenerator::GenerateStubMethods(io::Printer* printer) {
       "                              const $input_type$* request,\n"
       "                              $output_type$* response,\n"
       "                              ::google::protobuf::Closure* done) {\n"
-      "  channel_->CallMethod($classname$_descriptor_->method($index$),\n"
+      "  channel_->CallMethod(descriptor()->method($index$),\n"
       "                       controller, request, response, done);\n"
       "}\n");
   }

+ 52 - 48
src/google/protobuf/compiler/cpp/cpp_string_field.cc

@@ -61,6 +61,8 @@ void SetStringVariables(const FieldDescriptor* descriptor,
   (*variables)["declared_type"] = DeclaredTypeMethodName(descriptor->type());
   (*variables)["tag_size"] = SimpleItoa(
     WireFormat::TagSize(descriptor->number(), descriptor->type()));
+  (*variables)["pointer_type"] =
+      descriptor->type() == FieldDescriptor::TYPE_BYTES ? "void" : "char";
 }
 
 }  // namespace
@@ -111,13 +113,8 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
   printer->Print(variables_,
     "inline const ::std::string& $name$() const;\n"
     "inline void set_$name$(const ::std::string& value);\n"
-    "inline void set_$name$(const char* value);\n");
-  if (descriptor_->type() == FieldDescriptor::TYPE_BYTES) {
-    printer->Print(variables_,
-      "inline void set_$name$(const void* value, size_t size);\n");
-  }
-
-  printer->Print(variables_,
+    "inline void set_$name$(const char* value);\n"
+    "inline void set_$name$(const $pointer_type$* value, size_t size);\n"
     "inline ::std::string* mutable_$name$();\n");
 
   if (descriptor_->options().has_ctype()) {
@@ -146,20 +143,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "    $name$_ = new ::std::string;\n"
     "  }\n"
     "  $name$_->assign(value);\n"
-    "}\n");
-
-  if (descriptor_->type() == FieldDescriptor::TYPE_BYTES) {
-    printer->Print(variables_,
-      "inline void $classname$::set_$name$(const void* value, size_t size) {\n"
-      "  _set_bit($index$);\n"
-      "  if ($name$_ == &_default_$name$_) {\n"
-      "    $name$_ = new ::std::string;\n"
-      "  }\n"
-      "  $name$_->assign(reinterpret_cast<const char*>(value), size);\n"
-      "}\n");
-  }
-
-  printer->Print(variables_,
+    "}\n"
+    "inline "
+    "void $classname$::set_$name$(const $pointer_type$* value, size_t size) {\n"
+    "  _set_bit($index$);\n"
+    "  if ($name$_ == &_default_$name$_) {\n"
+    "    $name$_ = new ::std::string;\n"
+    "  }\n"
+    "  $name$_->assign(reinterpret_cast<const char*>(value), size);\n"
+    "}\n"
     "inline ::std::string* $classname$::mutable_$name$() {\n"
     "  _set_bit($index$);\n"
     "  if ($name$_ == &_default_$name$_) {\n");
@@ -213,9 +205,9 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void StringFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
+GenerateConstructorCode(io::Printer* printer) const {
   printer->Print(variables_,
-    ",\n$name$_(const_cast< ::std::string*>(&_default_$name$_))");
+    "$name$_ = const_cast< ::std::string*>(&_default_$name$_);\n");
 }
 
 void StringFieldGenerator::
@@ -236,8 +228,15 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
 void StringFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
-    "DO_(::google::protobuf::internal::WireFormat::Write$declared_type$("
-      "$number$, this->$name$(), output));\n");
+    "::google::protobuf::internal::WireFormat::Write$declared_type$("
+      "$number$, this->$name$(), output);\n");
+}
+
+void StringFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "target = ::google::protobuf::internal::WireFormat::Write$declared_type$ToArray("
+      "$number$, this->$name$(), target);\n");
 }
 
 void StringFieldGenerator::
@@ -281,15 +280,12 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
     "inline ::std::string* mutable_$name$(int index);\n"
     "inline void set_$name$(int index, const ::std::string& value);\n"
     "inline void set_$name$(int index, const char* value);\n"
+    "inline "
+    "void set_$name$(int index, const $pointer_type$* value, size_t size);\n"
     "inline ::std::string* add_$name$();\n"
     "inline void add_$name$(const ::std::string& value);\n"
-    "inline void add_$name$(const char* value);\n");
-
-  if (descriptor_->type() == FieldDescriptor::TYPE_BYTES) {
-    printer->Print(variables_,
-      "inline void set_$name$(int index, const void* value, size_t size);\n"
-      "inline void add_$name$(const void* value, size_t size);\n");
-  }
+    "inline void add_$name$(const char* value);\n"
+    "inline void add_$name$(const $pointer_type$* value, size_t size);\n");
 
   if (descriptor_->options().has_ctype()) {
     printer->Outdent();
@@ -321,6 +317,12 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "inline void $classname$::set_$name$(int index, const char* value) {\n"
     "  $name$_.Mutable(index)->assign(value);\n"
     "}\n"
+    "inline void "
+    "$classname$::set_$name$"
+    "(int index, const $pointer_type$* value, size_t size) {\n"
+    "  $name$_.Mutable(index)->assign(\n"
+    "    reinterpret_cast<const char*>(value), size);\n"
+    "}\n"
     "inline ::std::string* $classname$::add_$name$() {\n"
     "  return $name$_.Add();\n"
     "}\n"
@@ -329,19 +331,11 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
     "}\n"
     "inline void $classname$::add_$name$(const char* value) {\n"
     "  $name$_.Add()->assign(value);\n"
+    "}\n"
+    "inline void "
+    "$classname$::add_$name$(const $pointer_type$* value, size_t size) {\n"
+    "  $name$_.Add()->assign(reinterpret_cast<const char*>(value), size);\n"
     "}\n");
-
-  if (descriptor_->type() == FieldDescriptor::TYPE_BYTES) {
-    printer->Print(variables_,
-      "inline void "
-      "$classname$::set_$name$(int index, const void* value, size_t size) {\n"
-      "  $name$_.Mutable(index)->assign(\n"
-      "    reinterpret_cast<const char*>(value), size);\n"
-      "}\n"
-      "inline void $classname$::add_$name$(const void* value, size_t size) {\n"
-      "  $name$_.Add()->assign(reinterpret_cast<const char*>(value), size);\n"
-      "}\n");
-  }
 }
 
 void RepeatedStringFieldGenerator::
@@ -360,8 +354,8 @@ GenerateSwappingCode(io::Printer* printer) const {
 }
 
 void RepeatedStringFieldGenerator::
-GenerateInitializer(io::Printer* printer) const {
-  printer->Print(variables_, ",\n$name$_()");
+GenerateConstructorCode(io::Printer* printer) const {
+  // Not needed for repeated fields.
 }
 
 void RepeatedStringFieldGenerator::
@@ -375,8 +369,18 @@ void RepeatedStringFieldGenerator::
 GenerateSerializeWithCachedSizes(io::Printer* printer) const {
   printer->Print(variables_,
     "for (int i = 0; i < this->$name$_size(); i++) {\n"
-    "  DO_(::google::protobuf::internal::WireFormat::Write$declared_type$("
-        "$number$, this->$name$(i), output));\n"
+    "  ::google::protobuf::internal::WireFormat::Write$declared_type$("
+        "$number$, this->$name$(i), output);\n"
+    "}\n");
+}
+
+void RepeatedStringFieldGenerator::
+GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const {
+  printer->Print(variables_,
+    "for (int i = 0; i < this->$name$_size(); i++) {\n"
+    "  target = ::google::protobuf::internal::WireFormat::"
+        "Write$declared_type$ToArray("
+        "$number$, this->$name$(i), target);\n"
     "}\n");
 }
 

+ 4 - 2
src/google/protobuf/compiler/cpp/cpp_string_field.h

@@ -57,10 +57,11 @@ class StringFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateDestructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:
@@ -82,9 +83,10 @@ class RepeatedStringFieldGenerator : public FieldGenerator {
   void GenerateClearingCode(io::Printer* printer) const;
   void GenerateMergingCode(io::Printer* printer) const;
   void GenerateSwappingCode(io::Printer* printer) const;
-  void GenerateInitializer(io::Printer* printer) const;
+  void GenerateConstructorCode(io::Printer* printer) const;
   void GenerateMergeFromCodedStream(io::Printer* printer) const;
   void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
+  void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
   void GenerateByteSize(io::Printer* printer) const;
 
  private:

+ 136 - 4
src/google/protobuf/compiler/cpp/cpp_unittest.cc

@@ -52,6 +52,8 @@
 #include <google/protobuf/test_util.h>
 #include <google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h>
 #include <google/protobuf/compiler/importer.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/descriptor.pb.h>
 #include <google/protobuf/dynamic_message.h>
@@ -61,6 +63,7 @@
 #include <google/protobuf/stubs/substitute.h>
 #include <google/protobuf/testing/googletest.h>
 #include <gtest/gtest.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
 namespace protobuf {
@@ -86,6 +89,8 @@ class MockErrorCollector : public MultiFileErrorCollector {
   }
 };
 
+#ifndef PROTOBUF_TEST_NO_DESCRIPTORS
+
 // Test that generated code has proper descriptors:
 // Parse a descriptor directly (using google::protobuf::compiler::Importer) and
 // compare it to the one that was produced by generated code.
@@ -115,6 +120,8 @@ TEST(GeneratedDescriptorTest, IdenticalDescriptors) {
             generated_decsriptor_proto.DebugString());
 }
 
+#endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
+
 // ===================================================================
 
 TEST(GeneratedMessageTest, Defaults) {
@@ -222,6 +229,22 @@ TEST(GeneratedMessageTest, ClearOneField) {
   TestUtil::ExpectAllFieldsSet(message);
 }
 
+TEST(GeneratedMessageTest, StringCharStarLength) {
+  // Verify that we can use a char*,length to set one of the string fields.
+  unittest::TestAllTypes message;
+  message.set_optional_string("abcdef", 3);
+  EXPECT_EQ("abc", message.optional_string());
+
+  // Verify that we can use a char*,length to add to a repeated string field.
+  message.add_repeated_string("abcdef", 3);
+  EXPECT_EQ(1, message.repeated_string_size());
+  EXPECT_EQ("abc", message.repeated_string(0));
+
+  // Verify that we can use a char*,length to set a repeated string field.
+  message.set_repeated_string(0, "wxyz", 2);
+  EXPECT_EQ("wx", message.repeated_string(0));
+}
+
 
 TEST(GeneratedMessageTest, CopyFrom) {
   unittest::TestAllTypes message1, message2;
@@ -346,6 +369,8 @@ TEST(GeneratedMessageTest, UpcastCopyFrom) {
   TestUtil::ExpectAllFieldsSet(message2);
 }
 
+#ifndef PROTOBUF_TEST_NO_DESCRIPTORS
+
 TEST(GeneratedMessageTest, DynamicMessageCopyFrom) {
   // Test copying from a DynamicMessage, which must fall back to using
   // reflection.
@@ -366,6 +391,8 @@ TEST(GeneratedMessageTest, DynamicMessageCopyFrom) {
   TestUtil::ExpectAllFieldsSet(message2);
 }
 
+#endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
+
 TEST(GeneratedMessageTest, NonEmptyMergeFrom) {
   // Test merging with a non-empty message. Code is a modified form
   // of that found in google/protobuf/reflection_ops_unittest.cc.
@@ -403,24 +430,75 @@ TEST(GeneratedMessageTest, MergeFromSelf) {
 
 #endif  // GTEST_HAS_DEATH_TEST
 
-TEST(GeneratedMessageTest, Serialization) {
+// Test the generated SerializeWithCachedSizesToArray().
+TEST(GeneratedMessageTest, SerializationToArray) {
   unittest::TestAllTypes message1, message2;
   string data;
-
   TestUtil::SetAllFields(&message1);
-  message1.SerializeToString(&data);
+  int size = message1.ByteSize();
+  data.resize(size);
+  uint8* start = reinterpret_cast<uint8*>(string_as_array(&data));
+  uint8* end =
+      message1.TestAllTypes::SerializeWithCachedSizesToArray(start);
+  EXPECT_EQ(size, end - start);
   EXPECT_TRUE(message2.ParseFromString(data));
   TestUtil::ExpectAllFieldsSet(message2);
 
+}
 
+TEST(GeneratedMessageTest, PackedFieldsSerializationToArray) {
   unittest::TestPackedTypes packed_message1, packed_message2;
   string packed_data;
   TestUtil::SetPackedFields(&packed_message1);
-  packed_message1.SerializeToString(&packed_data);
+  int packed_size = packed_message1.ByteSize();
+  packed_data.resize(packed_size);
+  uint8* start = reinterpret_cast<uint8*>(string_as_array(&packed_data));
+  uint8* end =
+      packed_message1.TestPackedTypes::SerializeWithCachedSizesToArray(start);
+  EXPECT_EQ(packed_size, end - start);
   EXPECT_TRUE(packed_message2.ParseFromString(packed_data));
   TestUtil::ExpectPackedFieldsSet(packed_message2);
 }
 
+// Test the generated SerializeWithCachedSizes() by forcing the buffer to write
+// one byte at a time.
+TEST(GeneratedMessageTest, SerializationToStream) {
+  unittest::TestAllTypes message1, message2;
+  TestUtil::SetAllFields(&message1);
+  int size = message1.ByteSize();
+  string data;
+  data.resize(size);
+  {
+    // Allow the output stream to buffer only one byte at a time.
+    io::ArrayOutputStream array_stream(string_as_array(&data), size, 1);
+    io::CodedOutputStream output_stream(&array_stream);
+    message1.TestAllTypes::SerializeWithCachedSizes(&output_stream);
+    EXPECT_FALSE(output_stream.HadError());
+    EXPECT_EQ(size, output_stream.ByteCount());
+  }
+  EXPECT_TRUE(message2.ParseFromString(data));
+  TestUtil::ExpectAllFieldsSet(message2);
+
+}
+
+TEST(GeneratedMessageTest, PackedFieldsSerializationToStream) {
+  unittest::TestPackedTypes message1, message2;
+  TestUtil::SetPackedFields(&message1);
+  int size = message1.ByteSize();
+  string data;
+  data.resize(size);
+  {
+    // Allow the output stream to buffer only one byte at a time.
+    io::ArrayOutputStream array_stream(string_as_array(&data), size, 1);
+    io::CodedOutputStream output_stream(&array_stream);
+    message1.TestPackedTypes::SerializeWithCachedSizes(&output_stream);
+    EXPECT_FALSE(output_stream.HadError());
+    EXPECT_EQ(size, output_stream.ByteCount());
+  }
+  EXPECT_TRUE(message2.ParseFromString(data));
+  TestUtil::ExpectPackedFieldsSet(message2);
+}
+
 
 TEST(GeneratedMessageTest, Required) {
   // Test that IsInitialized() returns false if required fields are missing.
@@ -547,6 +625,8 @@ TEST(GeneratedMessageTest, TestConflictingSymbolNames) {
   EXPECT_EQ(5, message.friend_());
 }
 
+#ifndef PROTOBUF_TEST_NO_DESCRIPTORS
+
 TEST(GeneratedMessageTest, TestOptimizedForSize) {
   // We rely on the tests in reflection_ops_unittest and wire_format_unittest
   // to really test that reflection-based methods work.  Here we are mostly
@@ -614,6 +694,8 @@ TEST(GeneratedMessageTest, TestSpaceUsed) {
             message1.SpaceUsed());
 }
 
+#endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
+
 // ===================================================================
 
 TEST(GeneratedEnumTest, EnumValuesAsSwitchCases) {
@@ -682,8 +764,37 @@ TEST(GeneratedEnumTest, MinAndMax) {
   }
 }
 
+#ifndef PROTOBUF_TEST_NO_DESCRIPTORS
+
+TEST(GeneratedEnumTest, Name) {
+  // "Names" in the presence of dup values are a bit arbitrary.
+  EXPECT_EQ("FOO1", unittest::TestEnumWithDupValue_Name(unittest::FOO1));
+  EXPECT_EQ("FOO1", unittest::TestEnumWithDupValue_Name(unittest::FOO2));
+
+  EXPECT_EQ("SPARSE_A", unittest::TestSparseEnum_Name(unittest::SPARSE_A));
+  EXPECT_EQ("SPARSE_B", unittest::TestSparseEnum_Name(unittest::SPARSE_B));
+  EXPECT_EQ("SPARSE_C", unittest::TestSparseEnum_Name(unittest::SPARSE_C));
+  EXPECT_EQ("SPARSE_D", unittest::TestSparseEnum_Name(unittest::SPARSE_D));
+  EXPECT_EQ("SPARSE_E", unittest::TestSparseEnum_Name(unittest::SPARSE_E));
+  EXPECT_EQ("SPARSE_F", unittest::TestSparseEnum_Name(unittest::SPARSE_F));
+  EXPECT_EQ("SPARSE_G", unittest::TestSparseEnum_Name(unittest::SPARSE_G));
+}
+
+TEST(GeneratedEnumTest, Parse) {
+  unittest::TestEnumWithDupValue dup_value = unittest::FOO1;
+  EXPECT_TRUE(unittest::TestEnumWithDupValue_Parse("FOO1", &dup_value));
+  EXPECT_EQ(unittest::FOO1, dup_value);
+  EXPECT_TRUE(unittest::TestEnumWithDupValue_Parse("FOO2", &dup_value));
+  EXPECT_EQ(unittest::FOO2, dup_value);
+  EXPECT_FALSE(unittest::TestEnumWithDupValue_Parse("FOO", &dup_value));
+}
+
+#endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
+
 // ===================================================================
 
+#ifndef PROTOBUF_TEST_NO_DESCRIPTORS
+
 // Support code for testing services.
 class GeneratedServiceTest : public testing::Test {
  protected:
@@ -977,6 +1088,27 @@ TEST_F(GeneratedServiceTest, NotImplemented) {
   EXPECT_TRUE(controller.called_);
 }
 
+#endif  // !PROTOBUF_TEST_NO_DESCRIPTORS
+
+// ===================================================================
+
+// This test must run last.  It verifies that descriptors were or were not
+// initialized depending on whether PROTOBUF_TEST_NO_DESCRIPTORS was defined.
+// When this is defined, we skip all tests which are expected to trigger
+// descriptor initialization.  This verifies that everything else still works
+// if descriptors are not initialized.
+TEST(DescriptorInitializationTest, Initialized) {
+#ifdef PROTOBUF_TEST_NO_DESCRIPTORS
+  bool should_have_descriptors = false;
+#else
+  bool should_have_descriptors = true;
+#endif
+
+  EXPECT_EQ(should_have_descriptors,
+    DescriptorPool::generated_pool()->InternalIsFileLoaded(
+      "google/protobuf/unittest.proto"));
+}
+
 }  // namespace cpp_unittest
 
 }  // namespace cpp

+ 26 - 8
src/google/protobuf/compiler/importer.cc

@@ -387,9 +387,22 @@ DiskSourceTree::DiskFileToVirtualFile(
   return SUCCESS;
 }
 
+bool DiskSourceTree::VirtualFileToDiskFile(const string& virtual_file,
+                                           string* disk_file) {
+  scoped_ptr<io::ZeroCopyInputStream> stream(OpenVirtualFile(virtual_file,
+                                                             disk_file));
+  return stream != NULL;
+}
+
 io::ZeroCopyInputStream* DiskSourceTree::Open(const string& filename) {
-  if (filename != CanonicalizePath(filename) ||
-      ContainsParentReference(filename)) {
+  return OpenVirtualFile(filename, NULL);
+}
+
+io::ZeroCopyInputStream* DiskSourceTree::OpenVirtualFile(
+    const string& virtual_file,
+    string* disk_file) {
+  if (virtual_file != CanonicalizePath(virtual_file) ||
+      ContainsParentReference(virtual_file)) {
     // We do not allow importing of paths containing things like ".." or
     // consecutive slashes since the compiler expects files to be uniquely
     // identified by file name.
@@ -397,16 +410,21 @@ io::ZeroCopyInputStream* DiskSourceTree::Open(const string& filename) {
   }
 
   for (int i = 0; i < mappings_.size(); i++) {
-    string disk_file;
-    if (ApplyMapping(filename, mappings_[i].virtual_path,
-                     mappings_[i].disk_path, &disk_file)) {
-      io::ZeroCopyInputStream* stream = OpenDiskFile(disk_file);
-      if (stream != NULL) return stream;
+    string temp_disk_file;
+    if (ApplyMapping(virtual_file, mappings_[i].virtual_path,
+                     mappings_[i].disk_path, &temp_disk_file)) {
+      io::ZeroCopyInputStream* stream = OpenDiskFile(temp_disk_file);
+      if (stream != NULL) {
+        if (disk_file != NULL) {
+          *disk_file = temp_disk_file;
+        }
+        return stream;
+      }
 
       if (errno == EACCES) {
         // The file exists but is not readable.
         // TODO(kenton):  Find a way to report this more nicely.
-        GOOGLE_LOG(WARNING) << "Read access is denied for file: " << disk_file;
+        GOOGLE_LOG(WARNING) << "Read access is denied for file: " << temp_disk_file;
         return NULL;
       }
     }

+ 10 - 0
src/google/protobuf/compiler/importer.h

@@ -267,6 +267,11 @@ class LIBPROTOBUF_EXPORT DiskSourceTree : public SourceTree {
                           string* virtual_file,
                           string* shadowing_disk_file);
 
+  // Given a virtual path, find the path to the file on disk.
+  // Return true and update disk_file with the on-disk path if the file exists.
+  // Return false and leave disk_file untouched if the file doesn't exist.
+  bool VirtualFileToDiskFile(const string& virtual_file, string* disk_file);
+
   // implements SourceTree -------------------------------------------
   io::ZeroCopyInputStream* Open(const string& filename);
 
@@ -280,6 +285,11 @@ class LIBPROTOBUF_EXPORT DiskSourceTree : public SourceTree {
   };
   vector<Mapping> mappings_;
 
+  // Like Open(), but returns the on-disk path in disk_file if disk_file is
+  // non-NULL and the file could be successfully opened.
+  io::ZeroCopyInputStream* OpenVirtualFile(const string& virtual_file,
+                                           string* disk_file);
+
   // Like Open() but given the actual on-disk path.
   io::ZeroCopyInputStream* OpenDiskFile(const string& filename);
 

+ 28 - 0
src/google/protobuf/compiler/importer_unittest.cc

@@ -565,6 +565,34 @@ TEST_F(DiskSourceTreeTest, DiskFileToVirtualFileCanonicalization) {
   EXPECT_EQ("dir5/bar", virtual_file);
 }
 
+TEST_F(DiskSourceTreeTest, VirtualFileToDiskFile) {
+  // Test VirtualFileToDiskFile.
+
+  AddFile(dirnames_[0] + "/foo", "Hello World!");
+  AddFile(dirnames_[1] + "/foo", "This file should be hidden.");
+  AddFile(dirnames_[1] + "/quux", "This file should not be hidden.");
+  source_tree_.MapPath("bar", dirnames_[0]);
+  source_tree_.MapPath("bar", dirnames_[1]);
+
+  // Existent files, shadowed and non-shadowed case.
+  string disk_file;
+  EXPECT_TRUE(source_tree_.VirtualFileToDiskFile("bar/foo", &disk_file));
+  EXPECT_EQ(dirnames_[0] + "/foo", disk_file);
+  EXPECT_TRUE(source_tree_.VirtualFileToDiskFile("bar/quux", &disk_file));
+  EXPECT_EQ(dirnames_[1] + "/quux", disk_file);
+
+  // Nonexistent file in existent directory and vice versa.
+  string not_touched = "not touched";
+  EXPECT_FALSE(source_tree_.VirtualFileToDiskFile("bar/baz", &not_touched));
+  EXPECT_EQ("not touched", not_touched);
+  EXPECT_FALSE(source_tree_.VirtualFileToDiskFile("baz/foo", &not_touched));
+  EXPECT_EQ("not touched", not_touched);
+
+  // Accept NULL as output parameter.
+  EXPECT_TRUE(source_tree_.VirtualFileToDiskFile("bar/foo", NULL));
+  EXPECT_FALSE(source_tree_.VirtualFileToDiskFile("baz/foo", NULL));
+}
+
 }  // namespace
 
 }  // namespace compiler

+ 18 - 2
src/google/protobuf/compiler/java/java_message.cc

@@ -466,6 +466,17 @@ GenerateParseFromMethods(io::Printer* printer) {
     "  return newBuilder().mergeFrom(input, extensionRegistry)\n"
     "           .buildParsed();\n"
     "}\n"
+    "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
+    "    throws java.io.IOException {\n"
+    "  return newBuilder().mergeDelimitedFrom(input).buildParsed();\n"
+    "}\n"
+    "public static $classname$ parseDelimitedFrom(\n"
+    "    java.io.InputStream input,\n"
+    "    com.google.protobuf.ExtensionRegistry extensionRegistry)\n"
+    "    throws java.io.IOException {\n"
+    "  return newBuilder().mergeDelimitedFrom(input, extensionRegistry)\n"
+    "           .buildParsed();\n"
+    "}\n"
     "public static $classname$ parseFrom(\n"
     "    com.google.protobuf.CodedInputStream input)\n"
     "    throws java.io.IOException {\n"
@@ -579,7 +590,8 @@ void MessageGenerator::GenerateCommonBuilderMethods(io::Printer* printer) {
 
   printer->Print(
     "public $classname$ build() {\n"
-    "  if (!isInitialized()) {\n"
+    // If result == null, we'll throw an appropriate exception later.
+    "  if (result != null && !isInitialized()) {\n"
     "    throw new com.google.protobuf.UninitializedMessageException(\n"
     "      result);\n"
     "  }\n"
@@ -595,7 +607,11 @@ void MessageGenerator::GenerateCommonBuilderMethods(io::Printer* printer) {
     "  return buildPartial();\n"
     "}\n"
     "\n"
-    "public $classname$ buildPartial() {\n",
+    "public $classname$ buildPartial() {\n"
+    "  if (result == null) {\n"  // guard emitted into the generated Java
+    "    throw new IllegalStateException(\n"
+    "      \"build() has already been called on this Builder.\");\n"
+    "  }\n",
     "classname", ClassName(descriptor_));
   printer->Indent();
 

+ 243 - 38
src/google/protobuf/compiler/java/java_service.cc

@@ -57,45 +57,111 @@ void ServiceGenerator::Generate(io::Printer* printer) {
     "classname", descriptor_->name());
   printer->Indent();
 
-  // Generate abstract method declarations.
-  for (int i = 0; i < descriptor_->method_count(); i++) {
-    const MethodDescriptor* method = descriptor_->method(i);
-    map<string, string> vars;
-    vars["name"] = UnderscoresToCamelCase(method);
-    vars["input"] = ClassName(method->input_type());
-    vars["output"] = ClassName(method->output_type());
-    printer->Print(vars,
-      "public abstract void $name$(\n"
-      "    com.google.protobuf.RpcController controller,\n"
-      "    $input$ request,\n"
-      "    com.google.protobuf.RpcCallback<$output$> done);\n");
-  }
+  printer->Print(
+    "protected $classname$() {}\n\n",
+    "classname", descriptor_->name());
+
+  GenerateInterface(printer);
+
+  GenerateNewReflectiveServiceMethod(printer);
+  GenerateNewReflectiveBlockingServiceMethod(printer);
+
+  GenerateAbstractMethods(printer);
 
   // Generate getDescriptor() and getDescriptorForType().
   printer->Print(
-    "\n"
     "public static final\n"
     "    com.google.protobuf.Descriptors.ServiceDescriptor\n"
     "    getDescriptor() {\n"
     "  return $file$.getDescriptor().getServices().get($index$);\n"
-    "}\n"
-    "public final com.google.protobuf.Descriptors.ServiceDescriptor\n"
-    "    getDescriptorForType() {\n"
-    "  return getDescriptor();\n"
     "}\n",
     "file", ClassName(descriptor_->file()),
     "index", SimpleItoa(descriptor_->index()));
+  GenerateGetDescriptorForType(printer);
 
   // Generate more stuff.
   GenerateCallMethod(printer);
   GenerateGetPrototype(REQUEST, printer);
   GenerateGetPrototype(RESPONSE, printer);
   GenerateStub(printer);
+  GenerateBlockingStub(printer);
 
   printer->Outdent();
   printer->Print("}\n\n");
 }
 
+void ServiceGenerator::GenerateGetDescriptorForType(io::Printer* printer) {
+  printer->Print(  // Java accessor that just forwards to getDescriptor().
+    "public final com.google.protobuf.Descriptors.ServiceDescriptor\n"
+    "    getDescriptorForType() {\n"
+    "  return getDescriptor();\n"
+    "}\n");
+}
+
+void ServiceGenerator::GenerateInterface(io::Printer* printer) {
+  printer->Print("public interface Interface {\n");  // open nested interface
+  printer->Indent();
+  GenerateAbstractMethods(printer);  // one "public abstract" decl per method
+  printer->Outdent();
+  printer->Print("}\n\n");  // close the interface, leave a blank line after
+}
+
+void ServiceGenerator::GenerateNewReflectiveServiceMethod(
+    io::Printer* printer) {
+  printer->Print(  // static factory: wraps an Interface in an anonymous subclass
+    "public static com.google.protobuf.Service newReflectiveService(\n"
+    "    final Interface impl) {\n"
+    "  return new $classname$() {\n",
+    "classname", descriptor_->name());
+  printer->Indent();  // level 1: body of newReflectiveService()
+  printer->Indent();  // level 2: body of the anonymous class
+
+  for (int i = 0; i < descriptor_->method_count(); i++) {
+    const MethodDescriptor* method = descriptor_->method(i);
+    printer->Print("@Override\n");
+    GenerateMethodSignature(printer, method, IS_CONCRETE);
+    printer->Print(  // override body only forwards the call to impl
+      " {\n"
+      "  impl.$method$(controller, request, done);\n"
+      "}\n\n",
+      "method", UnderscoresToCamelCase(method));
+  }
+
+  printer->Outdent();
+  printer->Print("};\n");  // ends the anonymous class and the return statement
+  printer->Outdent();
+  printer->Print("}\n\n");  // ends newReflectiveService()
+}
+
+void ServiceGenerator::GenerateNewReflectiveBlockingServiceMethod(
+    io::Printer* printer) {
+  printer->Print(  // static factory: anonymous BlockingService around impl
+    "public static com.google.protobuf.BlockingService\n"
+    "    newReflectiveBlockingService(final BlockingInterface impl) {\n"
+    "  return new com.google.protobuf.BlockingService() {\n");
+  printer->Indent();  // level 1: body of newReflectiveBlockingService()
+  printer->Indent();  // level 2: body of the anonymous class
+
+  GenerateGetDescriptorForType(printer);
+
+  GenerateCallBlockingMethod(printer);  // switch on method index, calls impl
+  GenerateGetPrototype(REQUEST, printer);
+  GenerateGetPrototype(RESPONSE, printer);
+
+  printer->Outdent();
+  printer->Print("};\n");  // ends the anonymous class and the return statement
+  printer->Outdent();
+  printer->Print("}\n\n");  // ends newReflectiveBlockingService()
+}
+
+void ServiceGenerator::GenerateAbstractMethods(io::Printer* printer) {
+  for (int i = 0; i < descriptor_->method_count(); i++) {
+    const MethodDescriptor* method = descriptor_->method(i);
+    GenerateMethodSignature(printer, method, IS_ABSTRACT);  // no body emitted
+    printer->Print(";\n\n");  // terminate the declaration, then a blank line
+  }
+}
+
 void ServiceGenerator::GenerateCallMethod(io::Printer* printer) {
   printer->Print(
     "\n"
@@ -131,7 +197,49 @@ void ServiceGenerator::GenerateCallMethod(io::Printer* printer) {
 
   printer->Print(
     "default:\n"
-    "  throw new java.lang.RuntimeException(\"Can't get here.\");\n");
+    "  throw new java.lang.AssertionError(\"Can't get here.\");\n");
+
+  printer->Outdent();
+  printer->Outdent();
+
+  printer->Print(
+    "  }\n"
+    "}\n"
+    "\n");
+}
+
+void ServiceGenerator::GenerateCallBlockingMethod(io::Printer* printer) {
+  printer->Print(
+    "\n"
+    "public final com.google.protobuf.Message callBlockingMethod(\n"
+    "    com.google.protobuf.Descriptors.MethodDescriptor method,\n"
+    "    com.google.protobuf.RpcController controller,\n"
+    "    com.google.protobuf.Message request)\n"
+    "    throws com.google.protobuf.ServiceException {\n"
+    "  if (method.getService() != getDescriptor()) {\n"
+    "    throw new java.lang.IllegalArgumentException(\n"
+    "      \"Service.callBlockingMethod() given method descriptor for \" +\n"
+    "      \"wrong service type.\");\n"
+    "  }\n"
+    "  switch(method.getIndex()) {\n");
+  printer->Indent();
+  printer->Indent();
+
+  for (int i = 0; i < descriptor_->method_count(); i++) {
+    const MethodDescriptor* method = descriptor_->method(i);
+    map<string, string> vars;
+    vars["index"] = SimpleItoa(i);
+    vars["method"] = UnderscoresToCamelCase(method);
+    vars["input"] = ClassName(method->input_type());
+    vars["output"] = ClassName(method->output_type());
+    printer->Print(vars,
+      "case $index$:\n"
+      "  return impl.$method$(controller, ($input$)request);\n");
+  }
+
+  printer->Print(
+    "default:\n"
+    "  throw new java.lang.AssertionError(\"Can't get here.\");\n");
 
   printer->Outdent();
   printer->Outdent();
@@ -144,6 +252,10 @@ void ServiceGenerator::GenerateCallMethod(io::Printer* printer) {
 
 void ServiceGenerator::GenerateGetPrototype(RequestOrResponse which,
                                             io::Printer* printer) {
+  /*
+   * TODO(cpovirk): The exception message says "Service.foo" when it may be
+   * "BlockingService.foo."  Consider fixing.
+   */
   printer->Print(
     "public final com.google.protobuf.Message\n"
     "    get$request_or_response$Prototype(\n"
@@ -171,7 +283,7 @@ void ServiceGenerator::GenerateGetPrototype(RequestOrResponse which,
 
   printer->Print(
     "default:\n"
-    "  throw new java.lang.RuntimeException(\"Can't get here.\");\n");
+    "  throw new java.lang.AssertionError(\"Can't get here.\");\n");
 
   printer->Outdent();
   printer->Outdent();
@@ -189,7 +301,8 @@ void ServiceGenerator::GenerateStub(io::Printer* printer) {
     "  return new Stub(channel);\n"
     "}\n"
     "\n"
-    "public static final class Stub extends $classname$ {\n",
+    "public static final class Stub extends $classname$ implements Interface {"
+    "\n",
     "classname", ClassName(descriptor_));
   printer->Indent();
 
@@ -206,33 +319,125 @@ void ServiceGenerator::GenerateStub(io::Printer* printer) {
 
   for (int i = 0; i < descriptor_->method_count(); i++) {
     const MethodDescriptor* method = descriptor_->method(i);
+    printer->Print("\n");
+    GenerateMethodSignature(printer, method, IS_CONCRETE);
+    printer->Print(" {\n");
+    printer->Indent();
+
     map<string, string> vars;
     vars["index"] = SimpleItoa(i);
-    vars["method"] = UnderscoresToCamelCase(method);
-    vars["input"] = ClassName(method->input_type());
     vars["output"] = ClassName(method->output_type());
     printer->Print(vars,
-      "\n"
-      "public void $method$(\n"
-      "    com.google.protobuf.RpcController controller,\n"
-      "    $input$ request,\n"
-      "    com.google.protobuf.RpcCallback<$output$> done) {\n"
-      "  channel.callMethod(\n"
-      "    getDescriptor().getMethods().get($index$),\n"
-      "    controller,\n"
-      "    request,\n"
-      "    $output$.getDefaultInstance(),\n"
-      "    com.google.protobuf.RpcUtil.generalizeCallback(\n"
-      "      done,\n"
-      "      $output$.class,\n"
-      "      $output$.getDefaultInstance()));\n"
-      "}\n");
+      "channel.callMethod(\n"
+      "  getDescriptor().getMethods().get($index$),\n"
+      "  controller,\n"
+      "  request,\n"
+      "  $output$.getDefaultInstance(),\n"
+      "  com.google.protobuf.RpcUtil.generalizeCallback(\n"
+      "    done,\n"
+      "    $output$.class,\n"
+      "    $output$.getDefaultInstance()));\n");
+
+    printer->Outdent();
+    printer->Print("}\n");
+  }
+
+  printer->Outdent();
+  printer->Print(
+    "}\n"
+    "\n");
+}
+
+void ServiceGenerator::GenerateBlockingStub(io::Printer* printer) {
+  printer->Print(  // static factory returning the private BlockingStub
+    "public static BlockingInterface newBlockingStub(\n"
+    "    com.google.protobuf.BlockingRpcChannel channel) {\n"
+    "  return new BlockingStub(channel);\n"
+    "}\n"
+    "\n");
+
+  printer->Print(
+    "public interface BlockingInterface {");  // no "\n": signatures start with one
+  printer->Indent();
+
+  for (int i = 0; i < descriptor_->method_count(); i++) {
+    const MethodDescriptor* method = descriptor_->method(i);
+    GenerateBlockingMethodSignature(printer, method);
+    printer->Print(";\n");  // declaration only inside the interface
+  }
+
+  printer->Outdent();
+  printer->Print(
+    "}\n"
+    "\n");
+
+  printer->Print(
+    "private static final class BlockingStub implements BlockingInterface {\n");
+  printer->Indent();
+
+  printer->Print(  // constructor plus the channel field it assigns
+    "private BlockingStub(com.google.protobuf.BlockingRpcChannel channel) {\n"
+    "  this.channel = channel;\n"
+    "}\n"
+    "\n"
+    "private final com.google.protobuf.BlockingRpcChannel channel;\n");
+
+  for (int i = 0; i < descriptor_->method_count(); i++) {
+    const MethodDescriptor* method = descriptor_->method(i);
+    GenerateBlockingMethodSignature(printer, method);
+    printer->Print(" {\n");
+    printer->Indent();
+
+    map<string, string> vars;
+    vars["index"] = SimpleItoa(i);  // position of the method in the service
+    vars["output"] = ClassName(method->output_type());
+    printer->Print(vars,  // stub body: call the channel, cast result to $output$
+      "return ($output$) channel.callBlockingMethod(\n"
+      "  getDescriptor().getMethods().get($index$),\n"
+      "  controller,\n"
+      "  request,\n"
+      "  $output$.getDefaultInstance());\n");
+
+    printer->Outdent();
+    printer->Print(
+      "}\n"
+      "\n");
   }
 
   printer->Outdent();
   printer->Print("}\n");
 }
 
+void ServiceGenerator::GenerateMethodSignature(io::Printer* printer,
+                                               const MethodDescriptor* method,
+                                               IsAbstract is_abstract) {
+  map<string, string> vars;  // substitutions for the signature template below
+  vars["name"] = UnderscoresToCamelCase(method);
+  vars["input"] = ClassName(method->input_type());
+  vars["output"] = ClassName(method->output_type());
+  vars["abstract"] = (is_abstract == IS_ABSTRACT) ? "abstract " : "";
+  printer->Print(vars,  // no body or ';' is printed; the caller appends it
+    "public $abstract$void $name$(\n"
+    "    com.google.protobuf.RpcController controller,\n"
+    "    $input$ request,\n"
+    "    com.google.protobuf.RpcCallback<$output$> done)");
+}
+
+void ServiceGenerator::GenerateBlockingMethodSignature(
+    io::Printer* printer,
+    const MethodDescriptor* method) {
+  map<string, string> vars;  // substitutions for the signature template below
+  vars["method"] = UnderscoresToCamelCase(method);
+  vars["input"] = ClassName(method->input_type());
+  vars["output"] = ClassName(method->output_type());
+  printer->Print(vars,  // starts with "\n"; no body or ';' is printed here
+    "\n"
+    "public $output$ $method$(\n"
+    "    com.google.protobuf.RpcController controller,\n"
+    "    $input$ request)\n"
+    "    throws com.google.protobuf.ServiceException");
+}
+
 }  // namespace java
 }  // namespace compiler
 }  // namespace protobuf

+ 33 - 0
src/google/protobuf/compiler/java/java_service.h

@@ -57,9 +57,28 @@ class ServiceGenerator {
   void Generate(io::Printer* printer);
 
  private:
+
+  // Generate the getDescriptorForType() method.
+  void GenerateGetDescriptorForType(io::Printer* printer);
+
+  // Generate a Java interface for the service.
+  void GenerateInterface(io::Printer* printer);
+
+  // Generate newReflectiveService() method.
+  void GenerateNewReflectiveServiceMethod(io::Printer* printer);
+
+  // Generate newReflectiveBlockingService() method.
+  void GenerateNewReflectiveBlockingServiceMethod(io::Printer* printer);
+
+  // Generate abstract method declarations for all methods.
+  void GenerateAbstractMethods(io::Printer* printer);
+
   // Generate the implementation of Service.callMethod().
   void GenerateCallMethod(io::Printer* printer);
 
+  // Generate the implementation of BlockingService.callBlockingMethod().
+  void GenerateCallBlockingMethod(io::Printer* printer);
+
   // Generate the implementations of Service.get{Request,Response}Prototype().
   enum RequestOrResponse { REQUEST, RESPONSE };
   void GenerateGetPrototype(RequestOrResponse which, io::Printer* printer);
@@ -67,6 +86,20 @@ class ServiceGenerator {
   // Generate a stub implementation of the service.
   void GenerateStub(io::Printer* printer);
 
+  // Generate a method signature, possibly abstract, without body or trailing
+  // semicolon.
+  enum IsAbstract { IS_ABSTRACT, IS_CONCRETE };
+  void GenerateMethodSignature(io::Printer* printer,
+                               const MethodDescriptor* method,
+                               IsAbstract is_abstract);
+
+  // Generate a blocking stub interface and implementation of the service.
+  void GenerateBlockingStub(io::Printer* printer);
+
+  // Generate the method signature for one method of a blocking stub.
+  void GenerateBlockingMethodSignature(io::Printer* printer,
+                                       const MethodDescriptor* method);
+
   const ServiceDescriptor* descriptor_;
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ServiceGenerator);

+ 6 - 3
src/google/protobuf/compiler/parser.cc

@@ -97,7 +97,8 @@ Parser::Parser()
     error_collector_(NULL),
     source_location_table_(NULL),
     had_errors_(false),
-    require_syntax_identifier_(false) {
+    require_syntax_identifier_(false),
+    stop_after_syntax_identifier_(false) {
 }
 
 Parser::~Parser() {
@@ -309,10 +310,12 @@ bool Parser::Parse(io::Tokenizer* input, FileDescriptorProto* file) {
       // identifier.
       return false;
     }
-  } else {
+  } else if (!stop_after_syntax_identifier_) {
     syntax_identifier_ = "proto2";
   }
 
+  if (stop_after_syntax_identifier_) return !had_errors_;
+
   // Repeatedly parse statements until we reach the end of the file.
   while (!AtEnd()) {
     if (!ParseTopLevelStatement(file)) {
@@ -341,7 +344,7 @@ bool Parser::ParseSyntaxIdentifier() {
 
   syntax_identifier_ = syntax;
 
-  if (syntax != "proto2") {
+  if (syntax != "proto2" && !stop_after_syntax_identifier_) {
     AddError(syntax_token.line, syntax_token.column,
       "Unrecognized syntax identifier \"" + syntax + "\".  This parser "
       "only recognizes \"proto2\".");

+ 14 - 1
src/google/protobuf/compiler/parser.h

@@ -90,7 +90,7 @@ class LIBPROTOBUF_EXPORT Parser {
 
   // Returns the identifier used in the "syntax = " declaration, if one was
   // seen during the last call to Parse(), or the empty string otherwise.
-  const string& GetSyntaxIndentifier() { return syntax_identifier_; }
+  const string& GetSyntaxIdentifier() { return syntax_identifier_; }
 
   // If set true, input files will be required to begin with a syntax
   // identifier.  Otherwise, files may omit this.  If a syntax identifier
@@ -100,6 +100,18 @@ class LIBPROTOBUF_EXPORT Parser {
     require_syntax_identifier_ = value;
   }
 
+  // Call SetStopAfterSyntaxIdentifier(true) to tell the parser to stop
+  // parsing as soon as it has seen the syntax identifier, or lack thereof.
+  // This is useful for quickly identifying the syntax of the file without
+  // parsing the whole thing.  If this is enabled, no error will be recorded
+  // if the syntax identifier is something other than "proto2" (since
+  // presumably the caller intends to deal with that), but other kinds of
+  // errors (e.g. parse errors) will still be reported.  When this is enabled,
+  // you may pass a NULL FileDescriptorProto to Parse().
+  void SetStopAfterSyntaxIdentifier(bool value) {
+    stop_after_syntax_identifier_ = value;
+  }
+
  private:
   // =================================================================
   // Error recovery helpers
@@ -281,6 +293,7 @@ class LIBPROTOBUF_EXPORT Parser {
   SourceLocationTable* source_location_table_;
   bool had_errors_;
   bool require_syntax_identifier_;
+  bool stop_after_syntax_identifier_;
   string syntax_identifier_;
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Parser);

+ 38 - 6
src/google/protobuf/compiler/parser_unittest.cc

@@ -176,6 +176,38 @@ class ParserTest : public testing::Test {
 
 // ===================================================================
 
+TEST_F(ParserTest, StopAfterSyntaxIdentifier) {
+  SetupParser(
+    "// blah\n"
+    "syntax = \"foobar\";\n"
+    "this line will not be parsed\n");
+  parser_->SetStopAfterSyntaxIdentifier(true);
+  EXPECT_TRUE(parser_->Parse(input_.get(), NULL));  // NULL proto is allowed here
+  EXPECT_EQ("", error_collector_.text_);  // non-"proto2" syntax is not an error
+  EXPECT_EQ("foobar", parser_->GetSyntaxIdentifier());
+}
+
+TEST_F(ParserTest, StopAfterOmittedSyntaxIdentifier) {
+  SetupParser(
+    "// blah\n"
+    "this line will not be parsed\n");
+  parser_->SetStopAfterSyntaxIdentifier(true);
+  EXPECT_TRUE(parser_->Parse(input_.get(), NULL));  // NULL proto is allowed here
+  EXPECT_EQ("", error_collector_.text_);
+  EXPECT_EQ("", parser_->GetSyntaxIdentifier());  // no "proto2" default applied
+}
+
+TEST_F(ParserTest, StopAfterSyntaxIdentifierWithErrors) {
+  SetupParser(
+    "// blah\n"
+    "syntax = error;\n");  // identifier is not a quoted string
+  parser_->SetStopAfterSyntaxIdentifier(true);
+  EXPECT_FALSE(parser_->Parse(input_.get(), NULL));  // parse errors still fail
+  EXPECT_EQ("1:9: Expected syntax identifier.\n", error_collector_.text_);
+}
+
+// ===================================================================
+
 typedef ParserTest ParseMessageTest;
 
 TEST_F(ParseMessageTest, SimpleMessage) {
@@ -201,7 +233,7 @@ TEST_F(ParseMessageTest, ImplicitSyntaxIdentifier) {
     "  name: \"TestMessage\""
     "  field { name:\"foo\" label:LABEL_REQUIRED type:TYPE_INT32 number:1 }"
     "}");
-  EXPECT_EQ("proto2", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("proto2", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseMessageTest, ExplicitSyntaxIdentifier) {
@@ -215,7 +247,7 @@ TEST_F(ParseMessageTest, ExplicitSyntaxIdentifier) {
     "  name: \"TestMessage\""
     "  field { name:\"foo\" label:LABEL_REQUIRED type:TYPE_INT32 number:1 }"
     "}");
-  EXPECT_EQ("proto2", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("proto2", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseMessageTest, ExplicitRequiredSyntaxIdentifier) {
@@ -230,7 +262,7 @@ TEST_F(ParseMessageTest, ExplicitRequiredSyntaxIdentifier) {
     "  name: \"TestMessage\""
     "  field { name:\"foo\" label:LABEL_REQUIRED type:TYPE_INT32 number:1 }"
     "}");
-  EXPECT_EQ("proto2", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("proto2", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseMessageTest, SimpleFields) {
@@ -673,7 +705,7 @@ TEST_F(ParseErrorTest, MissingSyntaxIdentifier) {
   ExpectHasEarlyExitErrors(
     "message TestMessage {}",
     "0:0: File must begin with 'syntax = \"proto2\";'.\n");
-  EXPECT_EQ("", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseErrorTest, UnknownSyntaxIdentifier) {
@@ -681,14 +713,14 @@ TEST_F(ParseErrorTest, UnknownSyntaxIdentifier) {
     "syntax = \"no_such_syntax\";",
     "0:9: Unrecognized syntax identifier \"no_such_syntax\".  This parser "
       "only recognizes \"proto2\".\n");
-  EXPECT_EQ("no_such_syntax", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("no_such_syntax", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseErrorTest, SimpleSyntaxError) {
   ExpectHasErrors(
     "message TestMessage @#$ { blah }",
     "0:20: Expected \"{\".\n");
-  EXPECT_EQ("proto2", parser_->GetSyntaxIndentifier());
+  EXPECT_EQ("proto2", parser_->GetSyntaxIdentifier());
 }
 
 TEST_F(ParseErrorTest, ExpectedTopLevel) {

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 346 - 193
src/google/protobuf/descriptor.cc


+ 86 - 11
src/google/protobuf/descriptor.h

@@ -55,6 +55,7 @@
 #define GOOGLE_PROTOBUF_DESCRIPTOR_H__
 
 #include <string>
+#include <vector>
 #include <google/protobuf/stubs/common.h>
 
 
@@ -94,6 +95,7 @@ class Message;
 
 // Defined in descriptor.cc
 class DescriptorBuilder;
+class FileDescriptorTables;
 
 // Defined in unknown_field_set.h.
 class UnknownField;
@@ -246,6 +248,12 @@ class LIBPROTOBUF_EXPORT Descriptor {
   const FileDescriptor* file_;
   const Descriptor* containing_type_;
   const MessageOptions* options_;
+
+  // True if this is a placeholder for an unknown type.
+  bool is_placeholder_;
+  // True if this is a placeholder and the type name wasn't fully-qualified.
+  bool is_unqualified_placeholder_;
+
   int field_count_;
   FieldDescriptor* fields_;
   int nested_type_count_;
@@ -256,12 +264,16 @@ class LIBPROTOBUF_EXPORT Descriptor {
   ExtensionRange* extension_ranges_;
   int extension_count_;
   FieldDescriptor* extensions_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<Descriptor>() and AllocateArray<Descriptor>() in descriptor.cc
+  // and update them to initialize the field.
 
   // Must be constructed using DescriptorPool.
   Descriptor() {}
   friend class DescriptorBuilder;
   friend class EnumDescriptor;
   friend class FieldDescriptor;
+  friend class MethodDescriptor;
   friend class FileDescriptor;
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Descriptor);
 };
@@ -458,6 +470,10 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
 
   // See Descriptor::DebugString().
   string DebugString() const;
+
+  // Helper method to get the CppType for a particular Type.
+  static CppType TypeToCppType(Type type);
+
  private:
   typedef FieldOptions OptionsType;
 
@@ -484,6 +500,9 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
   const EnumDescriptor* enum_type_;
   const FieldDescriptor* experimental_map_key_;
   const FieldOptions* options_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<FieldDescriptor>() and AllocateArray<FieldDescriptor>() in
+  // descriptor.cc and update them to initialize the field.
 
   bool has_default_value_;
   union {
@@ -568,15 +587,25 @@ class LIBPROTOBUF_EXPORT EnumDescriptor {
   const string* name_;
   const string* full_name_;
   const FileDescriptor* file_;
-  int value_count_;
-  EnumValueDescriptor* values_;
   const Descriptor* containing_type_;
   const EnumOptions* options_;
 
+  // True if this is a placeholder for an unknown type.
+  bool is_placeholder_;
+  // True if this is a placeholder and the type name wasn't fully-qualified.
+  bool is_unqualified_placeholder_;
+
+  int value_count_;
+  EnumValueDescriptor* values_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<EnumDescriptor>() and AllocateArray<EnumDescriptor>() in
+  // descriptor.cc and update them to initialize the field.
+
   // Must be constructed using DescriptorPool.
   EnumDescriptor() {}
   friend class DescriptorBuilder;
   friend class Descriptor;
+  friend class FieldDescriptor;
   friend class EnumValueDescriptor;
   friend class FileDescriptor;
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(EnumDescriptor);
@@ -627,6 +656,9 @@ class LIBPROTOBUF_EXPORT EnumValueDescriptor {
   int number_;
   const EnumDescriptor* type_;
   const EnumValueOptions* options_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<EnumValueDescriptor>() and AllocateArray<EnumValueDescriptor>()
+  // in descriptor.cc and update them to initialize the field.
 
   // Must be constructed using DescriptorPool.
   EnumValueDescriptor() {}
@@ -685,6 +717,9 @@ class LIBPROTOBUF_EXPORT ServiceDescriptor {
   const ServiceOptions* options_;
   int method_count_;
   MethodDescriptor* methods_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<ServiceDescriptor>() and AllocateArray<ServiceDescriptor>() in
+  // descriptor.cc and update them to initialize the field.
 
   // Must be constructed using DescriptorPool.
   ServiceDescriptor() {}
@@ -740,6 +775,9 @@ class LIBPROTOBUF_EXPORT MethodDescriptor {
   const Descriptor* input_type_;
   const Descriptor* output_type_;
   const MethodOptions* options_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<MethodDescriptor>() and AllocateArray<MethodDescriptor>() in
+  // descriptor.cc and update them to initialize the field.
 
   // Must be constructed using DescriptorPool.
   MethodDescriptor() {}
@@ -846,6 +884,11 @@ class LIBPROTOBUF_EXPORT FileDescriptor {
   FieldDescriptor* extensions_;
   const FileOptions* options_;
 
+  const FileDescriptorTables* tables_;
+  // IMPORTANT:  If you add a new field, make sure to search for all instances
+  // of Allocate<FileDescriptor>() and AllocateArray<FileDescriptor>() in
+  // descriptor.cc and update them to initialize the field.
+
   FileDescriptor() {}
   friend class DescriptorBuilder;
   friend class Descriptor;
@@ -945,6 +988,14 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
   const FieldDescriptor* FindExtensionByNumber(const Descriptor* extendee,
                                                int number) const;
 
+  // Finds extensions of extendee. The extensions will be appended to
+  // out in an undefined order. Only extensions defined directly in
+  // this DescriptorPool or one of its underlays are guaranteed to be
+  // found: extensions defined in the fallback database might not be found
+  // depending on the database implementation.
+  void FindAllExtensions(const Descriptor* extendee,
+                         vector<const FieldDescriptor*>* out) const;
+
   // Building descriptors --------------------------------------------
 
   // When converting a FileDescriptorProto to a FileDescriptor, various
@@ -996,6 +1047,23 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
     const FileDescriptorProto& proto,
     ErrorCollector* error_collector);
 
+  // By default, it is an error if a FileDescriptorProto contains references
+  // to types or other files that are not found in the DescriptorPool (or its
+  // backing DescriptorDatabase, if any).  If you call
+  // AllowUnknownDependencies(), however, then unknown types and files
+  // will be replaced by placeholder descriptors.  This can allow you to
+  // perform some useful operations with a .proto file even if you do not
+  // have access to other .proto files on which it depends.  However, some
+  // heuristics must be used to fill in the gaps in information, and these
+  // can lead to descriptors which are inaccurate.  For example, the
+  // DescriptorPool may be forced to guess whether an unknown type is a message
+  // or an enum, as well as what package it resides in.  Furthermore,
+  // placeholder types will not be discoverable via FindMessageTypeByName()
+  // and similar methods, which could confuse some descriptor-based algorithms.
+  // Generally, the results of this option should only be relied upon for
+  // debugging purposes.
+  void AllowUnknownDependencies() { allow_unknown_ = true; }
+
   // Internal stuff --------------------------------------------------
   // These methods MUST NOT be called from outside the proto2 library.
   // These methods may contain hidden pitfalls and may be removed in a
@@ -1024,12 +1092,12 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
   // underlay for a new DescriptorPool in which you add only the new file.
   explicit DescriptorPool(const DescriptorPool* underlay);
 
-  // Called by generated classes at init time.  Do NOT call this in your own
-  // code! descriptor_assigner, if not NULL, is used to assign global
-  // descriptor pointers at the appropriate point during building.
-  typedef void (*InternalDescriptorAssigner)(const FileDescriptor*);
-  const FileDescriptor* InternalBuildGeneratedFile(
-    const void* data, int size, InternalDescriptorAssigner descriptor_assigner);
+  // Called by generated classes at init time to add their descriptors to
+  // generated_pool.  Do NOT call this in your own code!  The buffer at
+  // encoded_file_descriptor must be permanent (e.g. static data).
+  static void InternalAddGeneratedFile(
+      const void* encoded_file_descriptor, int size);
+
 
   // For internal use only:  Gets a non-const pointer to the generated pool.
   // This is called at static-initialization time only, so thread-safety is
@@ -1047,6 +1115,11 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
     underlay_ = underlay;
   }
 
+  // For internal (unit test) use only:  Returns true if a FileDescriptor has
+  // been constructed for the given file, false otherwise.  Useful for testing
+  // lazy descriptor initialization behavior.
+  bool InternalIsFileLoaded(const string& filename) const;
+
  private:
   friend class Descriptor;
   friend class FieldDescriptor;
@@ -1085,9 +1158,7 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
   scoped_ptr<Tables> tables_;
 
   bool enforce_dependencies_;
-
-  // See InternalBuildGeneratedFile().
-  const void* last_internal_build_generated_file_call_;
+  bool allow_unknown_;
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DescriptorPool);
 };
@@ -1267,6 +1338,10 @@ inline FieldDescriptor::CppType FieldDescriptor::cpp_type() const {
   return kTypeToCppTypeMap[type_];
 }
 
+inline FieldDescriptor::CppType FieldDescriptor::TypeToCppType(Type type) {
+  return kTypeToCppTypeMap[type];  // same table cpp_type() indexes with type_
+}
+
 inline const FileDescriptor* FileDescriptor::dependency(int index) const {
   return dependencies_[index];
 }

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 402 - 238
src/google/protobuf/descriptor.pb.cc


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 189 - 381
src/google/protobuf/descriptor.pb.h


+ 8 - 1
src/google/protobuf/descriptor.proto

@@ -250,7 +250,7 @@ message FileOptions {
     SPEED = 1;      // Generate complete code for parsing, serialization, etc.
     CODE_SIZE = 2;  // Use ReflectionOps to implement these methods.
   }
-  optional OptimizeMode optimize_for = 9 [default=CODE_SIZE];
+  optional OptimizeMode optimize_for = 9 [default=SPEED];
 
 
 
@@ -306,6 +306,12 @@ message FieldOptions {
   // a single length-delimited blob.
   optional bool packed = 2;
 
+  // Is this field deprecated?
+  // Depending on the target platform, this can emit Deprecated annotations
+  // for accessors, or it will be completely ignored; at the very least, this
+  // is a formalization for deprecating fields.
+  optional bool deprecated = 3 [default=false];
+
   // EXPERIMENTAL.  DO NOT USE.
   // For "map" fields, the name of the field in the enclosed type that
   // is the key for this map.  For example, suppose we have:
@@ -328,6 +334,7 @@ message FieldOptions {
 }
 
 message EnumOptions {
+
   // The parser stores options it doesn't recognize here. See above.
   repeated UninterpretedOption uninterpreted_option = 999;
 

+ 300 - 97
src/google/protobuf/descriptor_database.cc

@@ -33,7 +33,11 @@
 //  Sanjay Ghemawat, Jeff Dean, and others.
 
 #include <google/protobuf/descriptor_database.h>
+
+#include <set>
+
 #include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/stubs/stl_util-inl.h>
 #include <google/protobuf/stubs/map-util.h>
 
@@ -44,151 +48,312 @@ DescriptorDatabase::~DescriptorDatabase() {}
 
 // ===================================================================
 
-SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
-SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {
-  STLDeleteElements(&files_to_delete_);
-}
-
-void SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
-  FileDescriptorProto* new_file = new FileDescriptorProto;
-  new_file->CopyFrom(file);
-  AddAndOwn(new_file);
-}
-
-void SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
-  files_to_delete_.push_back(file);
-  InsertOrUpdate(&files_by_name_, file->name(), file);
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
+    const FileDescriptorProto& file,
+    Value value) {
+  if (!InsertIfNotPresent(&by_name_, file.name(), value)) {
+    GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
+    return false;
+  }
 
-  string path = file->package();
+  string path = file.package();
   if (!path.empty()) path += '.';
 
-  for (int i = 0; i < file->message_type_size(); i++) {
-    AddMessage(path, file->message_type(i), file);
+  for (int i = 0; i < file.message_type_size(); i++) {
+    if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
+    if (!AddNestedExtensions(file.message_type(i), value)) return false;
   }
-  for (int i = 0; i < file->enum_type_size(); i++) {
-    AddEnum(path, file->enum_type(i), file);
+  for (int i = 0; i < file.enum_type_size(); i++) {
+    if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
   }
-  for (int i = 0; i < file->extension_size(); i++) {
-    AddField(path, file->extension(i), file);
+  for (int i = 0; i < file.extension_size(); i++) {
+    if (!AddSymbol(path + file.extension(i).name(), value)) return false;
+    if (!AddExtension(file.extension(i), value)) return false;
   }
-  for (int i = 0; i < file->service_size(); i++) {
-    AddService(path, file->service(i), file);
+  for (int i = 0; i < file.service_size(); i++) {
+    if (!AddSymbol(path + file.service(i).name(), value)) return false;
   }
+
+  return true;
 }
 
-void SimpleDescriptorDatabase::AddMessage(
-    const string& path,
-    const DescriptorProto& message_type,
-    const FileDescriptorProto* file) {
-  string full_name = path + message_type.name();
-  InsertOrUpdate(&files_by_symbol_, full_name, file);
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
+    const string& name, Value value) {
+  // We need to make sure not to violate our map invariant.
 
-  string sub_path = full_name + '.';
-  for (int i = 0; i < message_type.nested_type_size(); i++) {
-    AddMessage(sub_path, message_type.nested_type(i), file);
+  // If the symbol name is invalid it could break our lookup algorithm (which
+  // relies on the fact that '.' sorts before all other characters that are
+  // valid in symbol names).
+  if (!ValidateSymbolName(name)) {
+    GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name;
+    return false;
+  }
+
+  // Try to look up the symbol to make sure a super-symbol doesn't already
+  // exist.
+  typename map<string, Value>::iterator iter = FindLastLessOrEqual(name);
+
+  if (iter == by_symbol_.end()) {
+    // Apparently the map is currently empty.  Just insert and be done with it.
+    by_symbol_.insert(make_pair(name, value));
+    return true;
   }
-  for (int i = 0; i < message_type.enum_type_size(); i++) {
-    AddEnum(sub_path, message_type.enum_type(i), file);
+
+  if (IsSubSymbol(iter->first, name)) {
+    GOOGLE_LOG(ERROR) << "Symbol name \"" << name << "\" conflicts with the existing "
+                  "symbol \"" << iter->first << "\".";
+    return false;
   }
-  for (int i = 0; i < message_type.field_size(); i++) {
-    AddField(sub_path, message_type.field(i), file);
+
+  // OK, that worked.  Now we have to make sure that no symbol in the map is
+  // a sub-symbol of the one we are inserting.  The only symbol which could
+  // be so is the first symbol that is greater than the new symbol.  Since
+  // |iter| points at the last symbol that is less than or equal, we just have
+  // to increment it.
+  ++iter;
+
+  if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
+    GOOGLE_LOG(ERROR) << "Symbol name \"" << name << "\" conflicts with the existing "
+                  "symbol \"" << iter->first << "\".";
+    return false;
+  }
+
+  // OK, no conflicts.
+
+  // Insert the new symbol using the iterator as a hint, the new entry will
+  // appear immediately before the one the iterator is pointing at.
+  by_symbol_.insert(iter, make_pair(name, value));
+
+  return true;
+}
+
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
+    const DescriptorProto& message_type,
+    Value value) {
+  for (int i = 0; i < message_type.nested_type_size(); i++) {
+    if (!AddNestedExtensions(message_type.nested_type(i), value)) return false;
   }
   for (int i = 0; i < message_type.extension_size(); i++) {
-    AddField(sub_path, message_type.extension(i), file);
+    if (!AddExtension(message_type.extension(i), value)) return false;
   }
+  return true;
 }
 
-void SimpleDescriptorDatabase::AddField(
-    const string& path,
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
     const FieldDescriptorProto& field,
-    const FileDescriptorProto* file) {
-  string full_name = path + field.name();
-  InsertOrUpdate(&files_by_symbol_, full_name, file);
-
-  if (field.has_extendee()) {
-    // This field is an extension.
-    if (!field.extendee().empty() && field.extendee()[0] == '.') {
-      // The extension is fully-qualified.  We can use it as a lookup key in
-      // the files_by_symbol_ table.
-      InsertOrUpdate(&files_by_extension_,
-                     make_pair(field.extendee().substr(1), field.number()),
-                     file);
-    } else {
-      // Not fully-qualified.  We can't really do anything here, unfortunately.
+    Value value) {
+  if (!field.extendee().empty() && field.extendee()[0] == '.') {
+    // The extension is fully-qualified.  We can use it as a lookup key in
+    // the by_extension_ table.
+    if (!InsertIfNotPresent(&by_extension_,
+                            make_pair(field.extendee().substr(1),
+                                      field.number()),
+                            value)) {
+      GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
+                    "extend " << field.extendee() << " { "
+                 << field.name() << " = " << field.number() << " }";
+      return false;
     }
+  } else {
+    // Not fully-qualified.  We can't really do anything here, unfortunately.
+    // We don't consider this an error, though, because the descriptor is
+    // valid.
   }
+  return true;
 }
 
-void SimpleDescriptorDatabase::AddEnum(
-    const string& path,
-    const EnumDescriptorProto& enum_type,
-    const FileDescriptorProto* file) {
-  string full_name = path + enum_type.name();
-  InsertOrUpdate(&files_by_symbol_, full_name, file);
+template <typename Value>
+Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
+    const string& filename) {
+  return FindWithDefault(by_name_, filename, Value());
+}
+
+template <typename Value>
+Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
+    const string& name) {
+  typename map<string, Value>::iterator iter = FindLastLessOrEqual(name);
+
+  return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name)) ?
+         iter->second : Value();
+}
 
-  string sub_path = full_name + '.';
-  for (int i = 0; i < enum_type.value_size(); i++) {
-    InsertOrUpdate(&files_by_symbol_,
-                   sub_path + enum_type.value(i).name(),
-                   file);
+template <typename Value>
+Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
+    const string& containing_type,
+    int field_number) {
+  return FindWithDefault(by_extension_,
+                         make_pair(containing_type, field_number),
+                         Value());
+}
+
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
+    const string& containing_type,
+    vector<int>* output) {
+  typename map<pair<string, int>, Value >::const_iterator it =
+      by_extension_.lower_bound(make_pair(containing_type, 0));
+  bool success = false;
+
+  for (; it != by_extension_.end() && it->first.first == containing_type;
+       ++it) {
+    output->push_back(it->first.second);
+    success = true;
   }
+
+  return success;
+}
+
+template <typename Value>
+typename map<string, Value>::iterator
+SimpleDescriptorDatabase::DescriptorIndex<Value>::FindLastLessOrEqual(
+    const string& name) {
+  // Find the last key in the map which sorts less than or equal to the
+  // symbol name.  Since upper_bound() returns the *first* key that sorts
+  // *greater* than the input, we want the element immediately before that.
+  typename map<string, Value>::iterator iter = by_symbol_.upper_bound(name);
+  if (iter != by_symbol_.begin()) --iter;
+  return iter;
 }
 
-void SimpleDescriptorDatabase::AddService(
-    const string& path,
-    const ServiceDescriptorProto& service,
-    const FileDescriptorProto* file) {
-  string full_name = path + service.name();
-  InsertOrUpdate(&files_by_symbol_, full_name, file);
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::IsSubSymbol(
+    const string& sub_symbol, const string& super_symbol) {
+  return sub_symbol == super_symbol ||
+         (HasPrefixString(super_symbol, sub_symbol) &&
+             super_symbol[sub_symbol.size()] == '.');
+}
 
-  string sub_path = full_name + '.';
-  for (int i = 0; i < service.method_size(); i++) {
-    InsertOrUpdate(&files_by_symbol_,
-                   sub_path + service.method(i).name(),
-                   file);
+template <typename Value>
+bool SimpleDescriptorDatabase::DescriptorIndex<Value>::ValidateSymbolName(
+    const string& name) {
+  for (int i = 0; i < name.size(); i++) {
+    // I don't trust ctype.h due to locales.  :(
+    if (name[i] != '.' && name[i] != '_' &&
+        (name[i] < '0' || name[i] > '9') &&
+        (name[i] < 'A' || name[i] > 'Z') &&
+        (name[i] < 'a' || name[i] > 'z')) {
+      return false;
+    }
   }
+  return true;
+}
+
+// -------------------------------------------------------------------
+
+SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
+SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {
+  STLDeleteElements(&files_to_delete_);
+}
+
+bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
+  FileDescriptorProto* new_file = new FileDescriptorProto;
+  new_file->CopyFrom(file);
+  return AddAndOwn(new_file);
+}
+
+bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
+  files_to_delete_.push_back(file);
+  return index_.AddFile(*file, file);
 }
 
 bool SimpleDescriptorDatabase::FindFileByName(
     const string& filename,
     FileDescriptorProto* output) {
-  const FileDescriptorProto* result = FindPtrOrNull(files_by_name_, filename);
-  if (result == NULL) {
-    return false;
-  } else {
-    output->CopyFrom(*result);
-    return true;
-  }
+  return MaybeCopy(index_.FindFile(filename), output);
 }
 
 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
     const string& symbol_name,
     FileDescriptorProto* output) {
-  const FileDescriptorProto* result =
-    FindPtrOrNull(files_by_symbol_, symbol_name);
-  if (result == NULL) {
-    return false;
-  } else {
-    output->CopyFrom(*result);
-    return true;
-  }
+  return MaybeCopy(index_.FindSymbol(symbol_name), output);
 }
 
 bool SimpleDescriptorDatabase::FindFileContainingExtension(
     const string& containing_type,
     int field_number,
     FileDescriptorProto* output) {
-  const FileDescriptorProto* result =
-    FindPtrOrNull(files_by_extension_,
-                  make_pair(containing_type, field_number));
-  if (result == NULL) {
-    return false;
+  return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
+}
+
+bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
+    const string& extendee_type,
+    vector<int>* output) {
+  return index_.FindAllExtensionNumbers(extendee_type, output);
+}
+
+bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
+                                         FileDescriptorProto* output) {
+  if (file == NULL) return false;
+  output->CopyFrom(*file);
+  return true;
+}
+
+// -------------------------------------------------------------------
+
+EncodedDescriptorDatabase::EncodedDescriptorDatabase() {}
+EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
+  for (int i = 0; i < files_to_delete_.size(); i++) {
+    operator delete(files_to_delete_[i]);
+  }
+}
+
+bool EncodedDescriptorDatabase::Add(
+    const void* encoded_file_descriptor, int size) {
+  FileDescriptorProto file;
+  if (file.ParseFromArray(encoded_file_descriptor, size)) {
+    return index_.AddFile(file, make_pair(encoded_file_descriptor, size));
   } else {
-    output->CopyFrom(*result);
-    return true;
+    GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to "
+                  "EncodedDescriptorDatabase::Add().";
+    return false;
   }
 }
 
+bool EncodedDescriptorDatabase::AddCopy(
+    const void* encoded_file_descriptor, int size) {
+  void* copy = operator new(size);
+  memcpy(copy, encoded_file_descriptor, size);
+  files_to_delete_.push_back(copy);
+  return Add(copy, size);
+}
+
+bool EncodedDescriptorDatabase::FindFileByName(
+    const string& filename,
+    FileDescriptorProto* output) {
+  return MaybeParse(index_.FindFile(filename), output);
+}
+
+bool EncodedDescriptorDatabase::FindFileContainingSymbol(
+    const string& symbol_name,
+    FileDescriptorProto* output) {
+  return MaybeParse(index_.FindSymbol(symbol_name), output);
+}
+
+bool EncodedDescriptorDatabase::FindFileContainingExtension(
+    const string& containing_type,
+    int field_number,
+    FileDescriptorProto* output) {
+  return MaybeParse(index_.FindExtension(containing_type, field_number),
+                    output);
+}
+
+bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
+    const string& extendee_type,
+    vector<int>* output) {
+  return index_.FindAllExtensionNumbers(extendee_type, output);
+}
+
+bool EncodedDescriptorDatabase::MaybeParse(
+    pair<const void*, int> encoded_file,
+    FileDescriptorProto* output) {
+  if (encoded_file.first == NULL) return false;
+  return output->ParseFromArray(encoded_file.first, encoded_file.second);
+}
+
 // ===================================================================
 
 DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool)
@@ -231,6 +396,22 @@ bool DescriptorPoolDatabase::FindFileContainingExtension(
   return true;
 }
 
+bool DescriptorPoolDatabase::FindAllExtensionNumbers(
+    const string& extendee_type,
+    vector<int>* output) {
+  const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
+  if (extendee == NULL) return false;
+
+  vector<const FieldDescriptor*> extensions;
+  pool_.FindAllExtensions(extendee, &extensions);
+
+  for (int i = 0; i < extensions.size(); ++i) {
+    output->push_back(extensions[i]->number());
+  }
+
+  return true;
+}
+
 // ===================================================================
 
 MergedDescriptorDatabase::MergedDescriptorDatabase(
@@ -301,5 +482,27 @@ bool MergedDescriptorDatabase::FindFileContainingExtension(
   return false;
 }
 
+bool MergedDescriptorDatabase::FindAllExtensionNumbers(
+    const string& extendee_type,
+    vector<int>* output) {
+  set<int> merged_results;
+  vector<int> results;
+  bool success = false;
+
+  for (int i = 0; i < sources_.size(); i++) {
+    if (sources_[i]->FindAllExtensionNumbers(extendee_type, &results)) {
+      copy(results.begin(), results.end(),
+           insert_iterator<set<int> >(merged_results, merged_results.begin()));
+      success = true;
+    }
+    results.clear();
+  }
+
+  copy(merged_results.begin(), merged_results.end(),
+       insert_iterator<vector<int> >(*output, output->end()));
+
+  return success;
+}
+
 }  // namespace protobuf
 }  // namespace google

+ 186 - 21
src/google/protobuf/descriptor_database.h

@@ -46,6 +46,13 @@
 namespace google {
 namespace protobuf {
 
+// Defined in this file.
+class DescriptorDatabase;
+class SimpleDescriptorDatabase;
+class EncodedDescriptorDatabase;
+class DescriptorPoolDatabase;
+class MergedDescriptorDatabase;
+
 // Abstract interface for a database of descriptors.
 //
 // This is useful if you want to create a DescriptorPool which loads
@@ -78,6 +85,21 @@ class LIBPROTOBUF_EXPORT DescriptorDatabase {
                                            int field_number,
                                            FileDescriptorProto* output) = 0;
 
+  // Finds the tag numbers used by all known extensions of
+  // extendee_type, and appends them to output in an undefined
+  // order. This method is best-effort: it's not guaranteed that the
+  // database will find all extensions, and it's not guaranteed that
+  // FindFileContainingExtension will return true on all of the found
+  // numbers. Returns true if the search was successful, otherwise
+  // returns false and leaves output unchanged.
+  //
+  // This method has a default implementation that always returns
+  // false.
+  virtual bool FindAllExtensionNumbers(const string& extendee_type,
+                                       vector<int>* output) {
+    return false;
+  }
+
  private:
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DescriptorDatabase);
 };
@@ -85,7 +107,11 @@ class LIBPROTOBUF_EXPORT DescriptorDatabase {
 // A DescriptorDatabase into which you can insert files manually.
 //
 // FindFileContainingSymbol() is fully-implemented.  When you add a file, its
-// symbols will be indexed for this purpose.
+// symbols will be indexed for this purpose.  Note that the implementation
+// may return false positives, but only if it isn't possible for the symbol
+// to be defined in any other file.  In particular, if a file defines a symbol
+// "Foo", then searching for "Foo.[anything]" will match that file.  This way,
+// the database does not need to aggressively index all children of a symbol.
 //
 // FindFileContainingExtension() is mostly-implemented.  It works if and only
 // if the original FieldDescriptorProto defining the extension has a
@@ -105,11 +131,13 @@ class LIBPROTOBUF_EXPORT SimpleDescriptorDatabase : public DescriptorDatabase {
   ~SimpleDescriptorDatabase();
 
   // Adds the FileDescriptorProto to the database, making a copy.  The object
-  // can be deleted after Add() returns.
-  void Add(const FileDescriptorProto& file);
+  // can be deleted after Add() returns.  Returns false if the file conflicted
+  // with a file already in the database, in which case an error will have
+  // been written to GOOGLE_LOG(ERROR).
+  bool Add(const FileDescriptorProto& file);
 
   // Adds the FileDescriptorProto to the database and takes ownership of it.
-  void AddAndOwn(const FileDescriptorProto* file);
+  bool AddAndOwn(const FileDescriptorProto* file);
 
   // implements DescriptorDatabase -----------------------------------
   bool FindFileByName(const string& filename,
@@ -119,31 +147,162 @@ class LIBPROTOBUF_EXPORT SimpleDescriptorDatabase : public DescriptorDatabase {
   bool FindFileContainingExtension(const string& containing_type,
                                    int field_number,
                                    FileDescriptorProto* output);
+  bool FindAllExtensionNumbers(const string& extendee_type,
+                               vector<int>* output);
 
  private:
-  // Helpers to recursively add particular descriptors and all their contents
-  // to the by-symbol and by-extension tables.
-  void AddMessage(const string& path,
-                  const DescriptorProto& message_type,
-                  const FileDescriptorProto* file);
-  void AddField(const string& path,
-                const FieldDescriptorProto& field,
-                const FileDescriptorProto* file);
-  void AddEnum(const string& path,
-               const EnumDescriptorProto& enum_type,
-               const FileDescriptorProto* file);
-  void AddService(const string& path,
-                  const ServiceDescriptorProto& service,
-                  const FileDescriptorProto* file);
+  // So that it can use DescriptorIndex.
+  friend class EncodedDescriptorDatabase;
+
+  // An index mapping file names, symbol names, and extension numbers to
+  // some sort of values.
+  template <typename Value>
+  class DescriptorIndex {
+   public:
+    // Helpers to recursively add particular descriptors and all their contents
+    // to the index.
+    bool AddFile(const FileDescriptorProto& file,
+                 Value value);
+    bool AddSymbol(const string& name, Value value);
+    bool AddNestedExtensions(const DescriptorProto& message_type,
+                             Value value);
+    bool AddExtension(const FieldDescriptorProto& field,
+                      Value value);
+
+    Value FindFile(const string& filename);
+    Value FindSymbol(const string& name);
+    Value FindExtension(const string& containing_type, int field_number);
+    bool FindAllExtensionNumbers(const string& containing_type,
+                                 vector<int>* output);
+
+   private:
+    map<string, Value> by_name_;
+    map<string, Value> by_symbol_;
+    map<pair<string, int>, Value> by_extension_;
+
+    // Invariant:  The by_symbol_ map does not contain any symbols which are
+    // prefixes of other symbols in the map.  For example, "foo.bar" is a
+    // prefix of "foo.bar.baz" (but is not a prefix of "foo.barbaz").
+    //
+    // This invariant is important because it means that given a symbol name,
+    // we can find a key in the map which is a prefix of the symbol in O(lg n)
+    // time, and we know that there is at most one such key.
+    //
+    // The prefix lookup algorithm works like so:
+    // 1) Find the last key in the map which is less than or equal to the
+    //    search key.
+    // 2) If the found key is a prefix of the search key, then return it.
+    //    Otherwise, there is no match.
+    //
+    // I am sure this algorithm has been described elsewhere, but since I
+    // wasn't able to find it quickly I will instead prove that it works
+    // myself.  The key to the algorithm is that if a match exists, step (1)
+    // will find it.  Proof:
+    // 1) Define the "search key" to be the key we are looking for, the "found
+    //    key" to be the key found in step (1), and the "match key" to be the
+    //    key which actually matches the search key (i.e. the key we're trying
+    //    to find).
+    // 2) The found key must be less than or equal to the search key by
+    //    definition.
+    // 3) The match key must also be less than or equal to the search key
+    //    (because it is a prefix).
+    // 4) The match key cannot be greater than the found key, because if it
+    //    were, then step (1) of the algorithm would have returned the match
+    //    key instead (since it finds the *greatest* key which is less than or
+    //    equal to the search key).
+    // 5) Therefore, the found key must be between the match key and the search
+    //    key, inclusive.
+    // 6) Since the search key must be a sub-symbol of the match key, if it is
+    //    not equal to the match key, then search_key[match_key.size()] must
+    //    be '.'.
+    // 7) Since '.' sorts before any other character that is valid in a symbol
+    //    name, then if the found key is not equal to the match key, then
+    //    found_key[match_key.size()] must also be '.', because any other value
+    //    would make it sort after the search key.
+    // 8) Therefore, if the found key is not equal to the match key, then the
+    //    found key must be a sub-symbol of the match key.  However, this would
+    //    contradict our map invariant which says that no symbol in the map is
+    //    a sub-symbol of any other.
+    // 9) Therefore, the found key must match the match key.
+    //
+    // The above proof assumes the match key exists.  In the case that the
+    // match key does not exist, then step (1) will return some other symbol.
+    // That symbol cannot be a super-symbol of the search key since if it were,
+    // then it would be a match, and we're assuming the match key doesn't exist.
+    // Therefore, step 2 will correctly return no match.
+
+    // Find the last entry in the by_symbol_ map whose key is less than or
+    // equal to the given name.
+    typename map<string, Value>::iterator FindLastLessOrEqual(
+        const string& name);
+
+    // True if either the arguments are equal or sub_symbol names a parent
+    // of super_symbol, i.e. super_symbol is sub_symbol followed by '.' and
+    // more characters (e.g. "foo.bar" is a parent of "foo.bar.baz", but not
+    // a parent of "foo.barbaz").
+    bool IsSubSymbol(const string& sub_symbol, const string& super_symbol);
 
+    // Returns true if and only if all characters in the name are alphanumerics,
+    // underscores, or periods.
+    bool ValidateSymbolName(const string& name);
+  };
+
+
+  DescriptorIndex<const FileDescriptorProto*> index_;
   vector<const FileDescriptorProto*> files_to_delete_;
-  map<string, const FileDescriptorProto*> files_by_name_;
-  map<string, const FileDescriptorProto*> files_by_symbol_;
-  map<pair<string, int>, const FileDescriptorProto*> files_by_extension_;
+
+  // If file is non-NULL, copy it into *output and return true, otherwise
+  // return false.
+  bool MaybeCopy(const FileDescriptorProto* file,
+                 FileDescriptorProto* output);
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(SimpleDescriptorDatabase);
 };
 
+// Very similar to SimpleDescriptorDatabase, but stores all the descriptors
+// as raw bytes and generally tries to use as little memory as possible.
+//
+// The same caveats regarding FindFileContainingExtension() apply as with
+// SimpleDescriptorDatabase.
+class LIBPROTOBUF_EXPORT EncodedDescriptorDatabase : public DescriptorDatabase {
+ public:
+  EncodedDescriptorDatabase();
+  ~EncodedDescriptorDatabase();
+
+  // Adds the FileDescriptorProto to the database.  The descriptor is provided
+  // in encoded form.  The database does not make a copy of the bytes, nor
+  // does it take ownership; it's up to the caller to make sure the bytes
+  // remain valid for the life of the database.  Returns false and logs an error
+  // if the bytes are not a valid FileDescriptorProto or if the file conflicted
+  // with a file already in the database.
+  bool Add(const void* encoded_file_descriptor, int size);
+
+  // Like Add(), but makes a copy of the data, so that the caller does not
+  // need to keep it around.
+  bool AddCopy(const void* encoded_file_descriptor, int size);
+
+  // implements DescriptorDatabase -----------------------------------
+  bool FindFileByName(const string& filename,
+                      FileDescriptorProto* output);
+  bool FindFileContainingSymbol(const string& symbol_name,
+                                FileDescriptorProto* output);
+  bool FindFileContainingExtension(const string& containing_type,
+                                   int field_number,
+                                   FileDescriptorProto* output);
+  bool FindAllExtensionNumbers(const string& extendee_type,
+                               vector<int>* output);
+
+ private:
+  SimpleDescriptorDatabase::DescriptorIndex<pair<const void*, int> > index_;
+  vector<void*> files_to_delete_;
+
+  // If encoded_file.first is non-NULL, parse the data into *output and return
+  // true, otherwise return false.
+  bool MaybeParse(pair<const void*, int> encoded_file,
+                  FileDescriptorProto* output);
+
+  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(EncodedDescriptorDatabase);
+};
+
 // A DescriptorDatabase that fetches files from a given pool.
 class LIBPROTOBUF_EXPORT DescriptorPoolDatabase : public DescriptorDatabase {
  public:
@@ -158,6 +317,8 @@ class LIBPROTOBUF_EXPORT DescriptorPoolDatabase : public DescriptorDatabase {
   bool FindFileContainingExtension(const string& containing_type,
                                    int field_number,
                                    FileDescriptorProto* output);
+  bool FindAllExtensionNumbers(const string& extendee_type,
+                               vector<int>* output);
 
  private:
   const DescriptorPool& pool_;
@@ -185,6 +346,10 @@ class LIBPROTOBUF_EXPORT MergedDescriptorDatabase : public DescriptorDatabase {
   bool FindFileContainingExtension(const string& containing_type,
                                    int field_number,
                                    FileDescriptorProto* output);
+  // Merges the results of calling all databases. Returns true iff any
+  // of the databases returned true.
+  bool FindAllExtensionNumbers(const string& extendee_type,
+                               vector<int>* output);
 
  private:
   vector<DescriptorDatabase*> sources_;

+ 255 - 158
src/google/protobuf/descriptor_database_unittest.cc

@@ -48,13 +48,6 @@ namespace google {
 namespace protobuf {
 namespace {
 
-static bool AddToPool(DescriptorPool* pool, const char* file_text) {
-  FileDescriptorProto file_proto;
-  if (!TextFormat::ParseFromString(file_text, &file_proto)) return false;
-  if (pool->BuildFile(file_proto) == NULL) return false;
-  return true;
-}
-
 static void AddToDatabase(SimpleDescriptorDatabase* database,
                           const char* file_text) {
   FileDescriptorProto file_proto;
@@ -74,25 +67,134 @@ static void ExpectContainsType(const FileDescriptorProto& proto,
 
 // ===================================================================
 
-TEST(SimpleDescriptorDatabaseTest, FindFileByName) {
-  SimpleDescriptorDatabase database;
-  AddToDatabase(&database,
+#if GTEST_HAS_PARAM_TEST
+
+// SimpleDescriptorDatabase, EncodedDescriptorDatabase, and
+// DescriptorPoolDatabase call for very similar tests.  Instead of writing
+// three nearly-identical sets of tests, we use parameterized tests to apply
+// the same code to all three.
+
+// The parameterized test runs against a DescriptorDatabaseTestCase.  We have
+// implementations for each of the three classes we want to test.
+class DescriptorDatabaseTestCase {
+ public:
+  virtual ~DescriptorDatabaseTestCase() {}
+
+  virtual DescriptorDatabase* GetDatabase() = 0;
+  virtual bool AddToDatabase(const FileDescriptorProto& file) = 0;
+};
+
+// Factory function type.
+typedef DescriptorDatabaseTestCase* DescriptorDatabaseTestCaseFactory();
+
+// Specialization for SimpleDescriptorDatabase.
+class SimpleDescriptorDatabaseTestCase : public DescriptorDatabaseTestCase {
+ public:
+  static DescriptorDatabaseTestCase* New() {
+    return new SimpleDescriptorDatabaseTestCase;
+  }
+
+  virtual ~SimpleDescriptorDatabaseTestCase() {}
+
+  virtual DescriptorDatabase* GetDatabase() {
+    return &database_;
+  }
+  virtual bool AddToDatabase(const FileDescriptorProto& file) {
+    return database_.Add(file);
+  }
+
+ private:
+  SimpleDescriptorDatabase database_;
+};
+
+// Specialization for EncodedDescriptorDatabase.
+class EncodedDescriptorDatabaseTestCase : public DescriptorDatabaseTestCase {
+ public:
+  static DescriptorDatabaseTestCase* New() {
+    return new EncodedDescriptorDatabaseTestCase;
+  }
+
+  virtual ~EncodedDescriptorDatabaseTestCase() {}
+
+  virtual DescriptorDatabase* GetDatabase() {
+    return &database_;
+  }
+  virtual bool AddToDatabase(const FileDescriptorProto& file) {
+    string data;
+    file.SerializeToString(&data);
+    return database_.AddCopy(data.data(), data.size());
+  }
+
+ private:
+  EncodedDescriptorDatabase database_;
+};
+
+// Specialization for DescriptorPoolDatabase.  Files are built into pool_, and
+// database_ is the DescriptorPoolDatabase constructed over that pool.
+//
+// NOTE(review): the shared parameterized tests feed some unqualified or
+// undefined extendee names; DescriptorPool::BuildFile() may treat those
+// differently from the other databases -- confirm the expectations of the
+// "Pool" instantiation once it runs against this class.
+class DescriptorPoolDatabaseTestCase : public DescriptorDatabaseTestCase {
+ public:
+  // Factory function matching DescriptorDatabaseTestCaseFactory.  Must create
+  // this class (not one of the sibling test cases) so that the instantiation
+  // using it really exercises DescriptorPoolDatabase.
+  static DescriptorDatabaseTestCase* New() {
+    return new DescriptorPoolDatabaseTestCase;
+  }
+
+  // pool_ is declared before database_, so it is already constructed when
+  // database_ is initialized with a reference to it.
+  DescriptorPoolDatabaseTestCase() : database_(pool_) {}
+  virtual ~DescriptorPoolDatabaseTestCase() {}
+
+  // Returns the database under test; it remains owned by this object.
+  virtual DescriptorDatabase* GetDatabase() {
+    return &database_;
+  }
+  // Builds |file| into the underlying pool.  Returns true iff BuildFile()
+  // produced a descriptor (i.e. returned non-NULL).
+  virtual bool AddToDatabase(const FileDescriptorProto& file) {
+    return pool_.BuildFile(file) != NULL;
+  }
+
+ private:
+  DescriptorPool pool_;
+  DescriptorPoolDatabase database_;
+};
+
+// -------------------------------------------------------------------
+
+class DescriptorDatabaseTest
+    : public testing::TestWithParam<DescriptorDatabaseTestCaseFactory*> {
+ protected:
+  virtual void SetUp() {
+    test_case_.reset(GetParam()());
+    database_ = test_case_->GetDatabase();
+  }
+
+  void AddToDatabase(const char* file_descriptor_text) {
+    FileDescriptorProto file_proto;
+    EXPECT_TRUE(TextFormat::ParseFromString(file_descriptor_text, &file_proto));
+    EXPECT_TRUE(test_case_->AddToDatabase(file_proto));
+  }
+
+  void AddToDatabaseWithError(const char* file_descriptor_text) {
+    FileDescriptorProto file_proto;
+    EXPECT_TRUE(TextFormat::ParseFromString(file_descriptor_text, &file_proto));
+    EXPECT_FALSE(test_case_->AddToDatabase(file_proto));
+  }
+
+  scoped_ptr<DescriptorDatabaseTestCase> test_case_;
+  DescriptorDatabase* database_;
+};
+
+TEST_P(DescriptorDatabaseTest, FindFileByName) {
+  AddToDatabase(
     "name: \"foo.proto\" "
     "message_type { name:\"Foo\" }");
-  AddToDatabase(&database,
+  AddToDatabase(
     "name: \"bar.proto\" "
     "message_type { name:\"Bar\" }");
 
   {
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileByName("foo.proto", &file));
+    EXPECT_TRUE(database_->FindFileByName("foo.proto", &file));
     EXPECT_EQ("foo.proto", file.name());
     ExpectContainsType(file, "Foo");
   }
 
   {
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileByName("bar.proto", &file));
+    EXPECT_TRUE(database_->FindFileByName("bar.proto", &file));
     EXPECT_EQ("bar.proto", file.name());
     ExpectContainsType(file, "Bar");
   }
@@ -100,13 +202,12 @@ TEST(SimpleDescriptorDatabaseTest, FindFileByName) {
   {
     // Fails to find undefined files.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileByName("baz.proto", &file));
+    EXPECT_FALSE(database_->FindFileByName("baz.proto", &file));
   }
 }
 
-TEST(SimpleDescriptorDatabaseTest, FindFileContainingSymbol) {
-  SimpleDescriptorDatabase database;
-  AddToDatabase(&database,
+TEST_P(DescriptorDatabaseTest, FindFileContainingSymbol) {
+  AddToDatabase(
     "name: \"foo.proto\" "
     "message_type { "
     "  name: \"Foo\" "
@@ -124,96 +225,95 @@ TEST(SimpleDescriptorDatabaseTest, FindFileContainingSymbol) {
     "  method { name: \"Thud\" } "
     "}"
     );
-  AddToDatabase(&database,
+  AddToDatabase(
     "name: \"bar.proto\" "
     "package: \"corge\" "
     "message_type { name: \"Bar\" }");
 
   {
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Foo", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find fields.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo.qux", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Foo.qux", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find nested types.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo.Grault", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Foo.Grault", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find nested enums.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo.Garply", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Foo.Garply", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find enum types.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Waldo", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Waldo", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find enum values.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Waldo.FRED", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Waldo.FRED", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find extensions.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("plugh", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("plugh", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find services.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Xyzzy", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Xyzzy", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find methods.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Xyzzy.Thud", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("Xyzzy.Thud", &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     // Can find things in packages.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("corge.Bar", &file));
+    EXPECT_TRUE(database_->FindFileContainingSymbol("corge.Bar", &file));
     EXPECT_EQ("bar.proto", file.name());
   }
 
   {
     // Fails to find undefined symbols.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingSymbol("Baz", &file));
+    EXPECT_FALSE(database_->FindFileContainingSymbol("Baz", &file));
   }
 
   {
     // Names must be fully-qualified.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingSymbol("Bar", &file));
+    EXPECT_FALSE(database_->FindFileContainingSymbol("Bar", &file));
   }
 }
 
-TEST(SimpleDescriptorDatabaseTest, FindFileContainingExtension) {
-  SimpleDescriptorDatabase database;
-  AddToDatabase(&database,
+TEST_P(DescriptorDatabaseTest, FindFileContainingExtension) {
+  AddToDatabase(
     "name: \"foo.proto\" "
     "message_type { "
     "  name: \"Foo\" "
@@ -221,7 +321,7 @@ TEST(SimpleDescriptorDatabaseTest, FindFileContainingExtension) {
     "  extension { name:\"qux\" label:LABEL_OPTIONAL type:TYPE_INT32 number:5 "
     "              extendee: \".Foo\" }"
     "}");
-  AddToDatabase(&database,
+  AddToDatabase(
     "name: \"bar.proto\" "
     "package: \"corge\" "
     "dependency: \"foo.proto\" "
@@ -235,20 +335,20 @@ TEST(SimpleDescriptorDatabaseTest, FindFileContainingExtension) {
 
   {
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("Foo", 5, &file));
+    EXPECT_TRUE(database_->FindFileContainingExtension("Foo", 5, &file));
     EXPECT_EQ("foo.proto", file.name());
   }
 
   {
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("Foo", 32, &file));
+    EXPECT_TRUE(database_->FindFileContainingExtension("Foo", 32, &file));
     EXPECT_EQ("bar.proto", file.name());
   }
 
   {
     // Can find extensions for qualified type names.
     FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("corge.Bar", 70, &file));
+    EXPECT_TRUE(database_->FindFileContainingExtension("corge.Bar", 70, &file));
     EXPECT_EQ("bar.proto", file.name());
   }
 
@@ -256,173 +356,127 @@ TEST(SimpleDescriptorDatabaseTest, FindFileContainingExtension) {
     // Can't find extensions whose extendee was not fully-qualified in the
     // FileDescriptorProto.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("Bar", 56, &file));
-    EXPECT_FALSE(database.FindFileContainingExtension("corge.Bar", 56, &file));
+    EXPECT_FALSE(database_->FindFileContainingExtension("Bar", 56, &file));
+    EXPECT_FALSE(
+        database_->FindFileContainingExtension("corge.Bar", 56, &file));
   }
 
   {
     // Can't find non-existent extension numbers.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("Foo", 12, &file));
+    EXPECT_FALSE(database_->FindFileContainingExtension("Foo", 12, &file));
   }
 
   {
     // Can't find extensions for non-existent types.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("NoSuchType", 5, &file));
+    EXPECT_FALSE(
+        database_->FindFileContainingExtension("NoSuchType", 5, &file));
   }
 
   {
     // Can't find extensions for unqualified type names.
     FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("Bar", 70, &file));
+    EXPECT_FALSE(database_->FindFileContainingExtension("Bar", 70, &file));
   }
 }
 
-// ===================================================================
-
-TEST(DescriptorPoolDatabaseTest, FindFileByName) {
-  DescriptorPool pool;
-  ASSERT_TRUE(AddToPool(&pool,
-    "name: \"foo.proto\" "
-    "message_type { name:\"Foo\" }"));
-  ASSERT_TRUE(AddToPool(&pool,
-    "name: \"bar.proto\" "
-    "message_type { name:\"Bar\" }"));
-
-  DescriptorPoolDatabase database(pool);
-
-  {
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileByName("foo.proto", &file));
-    EXPECT_EQ("foo.proto", file.name());
-    ExpectContainsType(file, "Foo");
-  }
-
-  {
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileByName("bar.proto", &file));
-    EXPECT_EQ("bar.proto", file.name());
-    ExpectContainsType(file, "Bar");
-  }
-
-  {
-    // Fails to find undefined files.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileByName("baz.proto", &file));
-  }
-}
-
-TEST(DescriptorPoolDatabaseTest, FindFileContainingSymbol) {
-  DescriptorPool pool;
-  ASSERT_TRUE(AddToPool(&pool,
+TEST_P(DescriptorDatabaseTest, FindAllExtensionNumbers) {
+  AddToDatabase(
     "name: \"foo.proto\" "
     "message_type { "
     "  name: \"Foo\" "
-    "  field { name:\"qux\" label:LABEL_OPTIONAL type:TYPE_INT32 number:1 }"
-    "}"));
-  ASSERT_TRUE(AddToPool(&pool,
+    "  extension_range { start: 1 end: 1000 } "
+    "  extension { name:\"qux\" label:LABEL_OPTIONAL type:TYPE_INT32 number:5 "
+    "              extendee: \".Foo\" }"
+    "}");
+  AddToDatabase(
     "name: \"bar.proto\" "
     "package: \"corge\" "
-    "message_type { name: \"Bar\" }"));
-
-  DescriptorPoolDatabase database(pool);
-
-  {
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo", &file));
-    EXPECT_EQ("foo.proto", file.name());
-  }
+    "dependency: \"foo.proto\" "
+    "message_type { "
+    "  name: \"Bar\" "
+    "  extension_range { start: 1 end: 1000 } "
+    "} "
+    "extension { name:\"grault\" extendee: \".Foo\"       number:32 } "
+    "extension { name:\"garply\" extendee: \".corge.Bar\" number:70 } "
+    "extension { name:\"waldo\"  extendee: \"Bar\"        number:56 } ");
 
   {
-    // Can find fields.
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("Foo.qux", &file));
-    EXPECT_EQ("foo.proto", file.name());
+    vector<int> numbers;
+    EXPECT_TRUE(database_->FindAllExtensionNumbers("Foo", &numbers));
+    ASSERT_EQ(2, numbers.size());
+    sort(numbers.begin(), numbers.end());
+    EXPECT_EQ(5, numbers[0]);
+    EXPECT_EQ(32, numbers[1]);
   }
 
   {
-    // Can find things in packages.
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingSymbol("corge.Bar", &file));
-    EXPECT_EQ("bar.proto", file.name());
+    vector<int> numbers;
+    EXPECT_TRUE(database_->FindAllExtensionNumbers("corge.Bar", &numbers));
+    // Note: won't find extension 56 due to the name not being fully qualified.
+    ASSERT_EQ(1, numbers.size());
+    EXPECT_EQ(70, numbers[0]);
   }
 
   {
-    // Fails to find undefined symbols.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingSymbol("Baz", &file));
+    // Can't find extensions for non-existent types.
+    vector<int> numbers;
+    EXPECT_FALSE(database_->FindAllExtensionNumbers("NoSuchType", &numbers));
   }
 
   {
-    // Names must be fully-qualified.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingSymbol("Bar", &file));
+    // Can't find extensions for unqualified types.
+    vector<int> numbers;
+    EXPECT_FALSE(database_->FindAllExtensionNumbers("Bar", &numbers));
   }
 }
 
-TEST(DescriptorPoolDatabaseTest, FindFileContainingExtension) {
-  DescriptorPool pool;
-  ASSERT_TRUE(AddToPool(&pool,
+TEST_P(DescriptorDatabaseTest, ConflictingFileError) {
+  AddToDatabase(
     "name: \"foo.proto\" "
     "message_type { "
     "  name: \"Foo\" "
-    "  extension_range { start: 1 end: 1000 } "
-    "  extension { name:\"qux\" label:LABEL_OPTIONAL type:TYPE_INT32 number:5 "
-    "              extendee: \"Foo\" }"
-    "}"));
-  ASSERT_TRUE(AddToPool(&pool,
-    "name: \"bar.proto\" "
-    "package: \"corge\" "
-    "dependency: \"foo.proto\" "
+    "}");
+  AddToDatabaseWithError(
+    "name: \"foo.proto\" "
     "message_type { "
     "  name: \"Bar\" "
-    "  extension_range { start: 1 end: 1000 } "
-    "} "
-    "extension { name:\"grault\" label:LABEL_OPTIONAL type:TYPE_BOOL number:32 "
-    "            extendee: \"Foo\" } "
-    "extension { name:\"garply\" label:LABEL_OPTIONAL type:TYPE_BOOL number:70 "
-    "            extendee: \"Bar\" } "));
-
-  DescriptorPoolDatabase database(pool);
-
-  {
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("Foo", 5, &file));
-    EXPECT_EQ("foo.proto", file.name());
-  }
-
-  {
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("Foo", 32, &file));
-    EXPECT_EQ("bar.proto", file.name());
-  }
+    "}");
+}
 
-  {
-    // Can find extensions for qualified type names..
-    FileDescriptorProto file;
-    EXPECT_TRUE(database.FindFileContainingExtension("corge.Bar", 70, &file));
-    EXPECT_EQ("bar.proto", file.name());
-  }
+TEST_P(DescriptorDatabaseTest, ConflictingTypeError) {
+  AddToDatabase(
+    "name: \"foo.proto\" "
+    "message_type { "
+    "  name: \"Foo\" "
+    "}");
+  AddToDatabaseWithError(
+    "name: \"bar.proto\" "
+    "message_type { "
+    "  name: \"Foo\" "
+    "}");
+}
 
-  {
-    // Can't find non-existent extension numbers.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("Foo", 12, &file));
-  }
+TEST_P(DescriptorDatabaseTest, ConflictingExtensionError) {
+  AddToDatabase(
+    "name: \"foo.proto\" "
+    "extension { name:\"foo\" label:LABEL_OPTIONAL type:TYPE_INT32 number:5 "
+    "            extendee: \".Foo\" }");
+  AddToDatabaseWithError(
+    "name: \"bar.proto\" "
+    "extension { name:\"bar\" label:LABEL_OPTIONAL type:TYPE_INT32 number:5 "
+    "            extendee: \".Foo\" }");
+}
 
-  {
-    // Can't find extensions for non-existent types.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("NoSuchType", 5, &file));
-  }
+INSTANTIATE_TEST_CASE_P(Simple, DescriptorDatabaseTest,
+    testing::Values(&SimpleDescriptorDatabaseTestCase::New));
+INSTANTIATE_TEST_CASE_P(MemoryConserving, DescriptorDatabaseTest,
+    testing::Values(&EncodedDescriptorDatabaseTestCase::New));
+INSTANTIATE_TEST_CASE_P(Pool, DescriptorDatabaseTest,
+    testing::Values(&DescriptorPoolDatabaseTestCase::New));
 
-  {
-    // Can't find extensions for unqualified type names.
-    FileDescriptorProto file;
-    EXPECT_FALSE(database.FindFileContainingExtension("Bar", 70, &file));
-  }
-}
+#endif  // GTEST_HAS_PARAM_TEST
 
 // ===================================================================
 
@@ -610,6 +664,49 @@ TEST_F(MergedDescriptorDatabaseTest, FindFileContainingExtension) {
   }
 }
 
+TEST_F(MergedDescriptorDatabaseTest, FindAllExtensionNumbers) {
+  {
+    // Message only has extension in database1_
+    vector<int> numbers;
+    EXPECT_TRUE(forward_merged_.FindAllExtensionNumbers("Foo", &numbers));
+    ASSERT_EQ(1, numbers.size());
+    EXPECT_EQ(3, numbers[0]);
+  }
+
+  {
+    // Message only has extension in database2_
+    vector<int> numbers;
+    EXPECT_TRUE(forward_merged_.FindAllExtensionNumbers("Bar", &numbers));
+    ASSERT_EQ(1, numbers.size());
+    EXPECT_EQ(5, numbers[0]);
+  }
+
+  {
+    // Merge results from the two databases.
+    vector<int> numbers;
+    EXPECT_TRUE(forward_merged_.FindAllExtensionNumbers("Baz", &numbers));
+    ASSERT_EQ(2, numbers.size());
+    sort(numbers.begin(), numbers.end());
+    EXPECT_EQ(12, numbers[0]);
+    EXPECT_EQ(13, numbers[1]);
+  }
+
+  {
+    vector<int> numbers;
+    EXPECT_TRUE(reverse_merged_.FindAllExtensionNumbers("Baz", &numbers));
+    ASSERT_EQ(2, numbers.size());
+    sort(numbers.begin(), numbers.end());
+    EXPECT_EQ(12, numbers[0]);
+    EXPECT_EQ(13, numbers[1]);
+  }
+
+  {
+    // Can't find extensions for a non-existent message.
+    vector<int> numbers;
+    EXPECT_FALSE(reverse_merged_.FindAllExtensionNumbers("Blah", &numbers));
+  }
+}
+
 }  // anonymous namespace
 }  // namespace protobuf
 }  // namespace google

+ 310 - 3
src/google/protobuf/descriptor_unittest.cc

@@ -346,6 +346,18 @@ TEST_F(FileDescriptorTest, FindExtensionByNumber) {
   EXPECT_TRUE(pool_.FindExtensionByNumber(foo_message_, 2) == NULL);
 }
 
+TEST_F(FileDescriptorTest, BuildAgain) {
+  // Test that if we call BuildFile again on the same input we get the same
+  // FileDescriptor back.
+  FileDescriptorProto file;
+  foo_file_->CopyTo(&file);
+  EXPECT_EQ(foo_file_, pool_.BuildFile(file));
+
+  // But if we change the file then it won't work.
+  file.set_package("some.other.package");
+  EXPECT_TRUE(pool_.BuildFile(file) == NULL);
+}
+
 // ===================================================================
 
 // Test simple flat messages and fields.
@@ -1492,6 +1504,16 @@ TEST_F(ExtensionDescriptorTest, FindExtensionByName) {
   EXPECT_TRUE(foo_->FindExtensionByName("foo_message") == NULL);
 }
 
+TEST_F(ExtensionDescriptorTest, FindAllExtensions) {
+  vector<const FieldDescriptor*> extensions;
+  pool_.FindAllExtensions(foo_, &extensions);
+  ASSERT_EQ(4, extensions.size());
+  EXPECT_EQ(10, extensions[0]->number());
+  EXPECT_EQ(19, extensions[1]->number());
+  EXPECT_EQ(30, extensions[2]->number());
+  EXPECT_EQ(39, extensions[3]->number());
+}
+
 // ===================================================================
 
 class MiscTest : public testing::Test {
@@ -1716,6 +1738,219 @@ TEST_F(MiscTest, FieldOptions) {
   EXPECT_EQ(FieldOptions::CORD, bar->options().ctype());
 }
 
+// ===================================================================
+
+class AllowUnknownDependenciesTest : public testing::Test {
+ protected:
+  virtual void SetUp() {
+    FileDescriptorProto foo_proto, bar_proto;
+
+    pool_.AllowUnknownDependencies();
+
+    ASSERT_TRUE(TextFormat::ParseFromString(
+      "name: 'foo.proto'"
+      "dependency: 'bar.proto'"
+      "dependency: 'baz.proto'"
+      "message_type {"
+      "  name: 'Foo'"
+      "  field { name:'bar' number:1 label:LABEL_OPTIONAL type_name:'Bar' }"
+      "  field { name:'baz' number:2 label:LABEL_OPTIONAL type_name:'Baz' }"
+      "  field { name:'qux' number:3 label:LABEL_OPTIONAL"
+      "    type_name: '.corge.Qux'"
+      "    type: TYPE_ENUM"
+      "    options {"
+      "      uninterpreted_option {"
+      "        name {"
+      "          name_part: 'grault'"
+      "          is_extension: true"
+      "        }"
+      "        positive_int_value: 1234"
+      "      }"
+      "    }"
+      "  }"
+      "}",
+      &foo_proto));
+    ASSERT_TRUE(TextFormat::ParseFromString(
+      "name: 'bar.proto'"
+      "message_type { name: 'Bar' }",
+      &bar_proto));
+
+    // Collect pointers to stuff.
+    bar_file_ = pool_.BuildFile(bar_proto);
+    ASSERT_TRUE(bar_file_ != NULL);
+
+    ASSERT_EQ(1, bar_file_->message_type_count());
+    bar_type_ = bar_file_->message_type(0);
+
+    foo_file_ = pool_.BuildFile(foo_proto);
+    ASSERT_TRUE(foo_file_ != NULL);
+
+    ASSERT_EQ(1, foo_file_->message_type_count());
+    foo_type_ = foo_file_->message_type(0);
+
+    ASSERT_EQ(3, foo_type_->field_count());
+    bar_field_ = foo_type_->field(0);
+    baz_field_ = foo_type_->field(1);
+    qux_field_ = foo_type_->field(2);
+  }
+
+  const FileDescriptor* bar_file_;
+  const Descriptor* bar_type_;
+  const FileDescriptor* foo_file_;
+  const Descriptor* foo_type_;
+  const FieldDescriptor* bar_field_;
+  const FieldDescriptor* baz_field_;
+  const FieldDescriptor* qux_field_;
+
+  DescriptorPool pool_;
+};
+
+TEST_F(AllowUnknownDependenciesTest, PlaceholderFile) {
+  ASSERT_EQ(2, foo_file_->dependency_count());
+  EXPECT_EQ(bar_file_, foo_file_->dependency(0));
+
+  const FileDescriptor* baz_file = foo_file_->dependency(1);
+  EXPECT_EQ("baz.proto", baz_file->name());
+  EXPECT_EQ(0, baz_file->message_type_count());
+
+  // Placeholder files should not be findable.
+  EXPECT_EQ(bar_file_, pool_.FindFileByName(bar_file_->name()));
+  EXPECT_TRUE(pool_.FindFileByName(baz_file->name()) == NULL);
+}
+
+TEST_F(AllowUnknownDependenciesTest, PlaceholderTypes) {
+  ASSERT_EQ(FieldDescriptor::TYPE_MESSAGE, bar_field_->type());
+  EXPECT_EQ(bar_type_, bar_field_->message_type());
+
+  ASSERT_EQ(FieldDescriptor::TYPE_MESSAGE, baz_field_->type());
+  const Descriptor* baz_type = baz_field_->message_type();
+  EXPECT_EQ("Baz", baz_type->name());
+  EXPECT_EQ("Baz", baz_type->full_name());
+  EXPECT_EQ("Baz.placeholder.proto", baz_type->file()->name());
+  EXPECT_EQ(0, baz_type->extension_range_count());
+
+  ASSERT_EQ(FieldDescriptor::TYPE_ENUM, qux_field_->type());
+  const EnumDescriptor* qux_type = qux_field_->enum_type();
+  EXPECT_EQ("Qux", qux_type->name());
+  EXPECT_EQ("corge.Qux", qux_type->full_name());
+  EXPECT_EQ("corge.Qux.placeholder.proto", qux_type->file()->name());
+
+  // Placeholder types should not be findable.
+  EXPECT_EQ(bar_type_, pool_.FindMessageTypeByName(bar_type_->full_name()));
+  EXPECT_TRUE(pool_.FindMessageTypeByName(baz_type->full_name()) == NULL);
+  EXPECT_TRUE(pool_.FindEnumTypeByName(qux_type->full_name()) == NULL);
+}
+
+TEST_F(AllowUnknownDependenciesTest, CopyTo) {
+  // FieldDescriptor::CopyTo() should write non-fully-qualified type names
+  // for placeholder types which were not originally fully-qualified.
+  FieldDescriptorProto proto;
+
+  // Bar is not a placeholder, so it is fully-qualified.
+  bar_field_->CopyTo(&proto);
+  EXPECT_EQ(".Bar", proto.type_name());
+  EXPECT_EQ(FieldDescriptorProto::TYPE_MESSAGE, proto.type());
+
+  // Baz is an unqualified placeholder.
+  proto.Clear();
+  baz_field_->CopyTo(&proto);
+  EXPECT_EQ("Baz", proto.type_name());
+  EXPECT_FALSE(proto.has_type());
+
+  // Qux is a fully-qualified placeholder.
+  proto.Clear();
+  qux_field_->CopyTo(&proto);
+  EXPECT_EQ(".corge.Qux", proto.type_name());
+  EXPECT_EQ(FieldDescriptorProto::TYPE_ENUM, proto.type());
+}
+
+TEST_F(AllowUnknownDependenciesTest, CustomOptions) {
+  // Qux should still have the uninterpreted option attached.
+  ASSERT_EQ(1, qux_field_->options().uninterpreted_option_size());
+  const UninterpretedOption& option =
+    qux_field_->options().uninterpreted_option(0);
+  ASSERT_EQ(1, option.name_size());
+  EXPECT_EQ("grault", option.name(0).name_part());
+}
+
+TEST_F(AllowUnknownDependenciesTest, UnknownExtendee) {
+  // Test that we can extend an unknown type.  This is slightly tricky because
+  // it means that the placeholder type must have an extension range.
+
+  FileDescriptorProto extension_proto;
+
+  ASSERT_TRUE(TextFormat::ParseFromString(
+    "name: 'extension.proto'"
+    "extension { extendee: 'UnknownType' name:'some_extension' number:123"
+    "            label:LABEL_OPTIONAL type:TYPE_INT32 }",
+    &extension_proto));
+  const FileDescriptor* file = pool_.BuildFile(extension_proto);
+
+  ASSERT_TRUE(file != NULL);
+
+  ASSERT_EQ(1, file->extension_count());
+  const Descriptor* extendee = file->extension(0)->containing_type();
+  EXPECT_EQ("UnknownType", extendee->name());
+  ASSERT_EQ(1, extendee->extension_range_count());
+  EXPECT_EQ(1, extendee->extension_range(0)->start);
+  EXPECT_EQ(FieldDescriptor::kMaxNumber + 1, extendee->extension_range(0)->end);
+}
+
+TEST_F(AllowUnknownDependenciesTest, CustomOption) {
+  // Test that we can use a custom option without having parsed
+  // descriptor.proto.
+
+  FileDescriptorProto option_proto;
+
+  ASSERT_TRUE(TextFormat::ParseFromString(
+    "name: \"unknown_custom_options.proto\" "
+    "dependency: \"google/protobuf/descriptor.proto\" "
+    "extension { "
+    "  extendee: \"google.protobuf.FileOptions\" "
+    "  name: \"some_option\" "
+    "  number: 123456 "
+    "  label: LABEL_OPTIONAL "
+    "  type: TYPE_INT32 "
+    "} "
+    "options { "
+    "  uninterpreted_option { "
+    "    name { "
+    "      name_part: \"some_option\" "
+    "      is_extension: true "
+    "    } "
+    "    positive_int_value: 1234 "
+    "  } "
+    "  uninterpreted_option { "
+    "    name { "
+    "      name_part: \"unknown_option\" "
+    "      is_extension: true "
+    "    } "
+    "    positive_int_value: 1234 "
+    "  } "
+    "  uninterpreted_option { "
+    "    name { "
+    "      name_part: \"optimize_for\" "
+    "      is_extension: false "
+    "    } "
+    "    identifier_value: \"SPEED\" "
+    "  } "
+    "}",
+    &option_proto));
+
+  const FileDescriptor* file = pool_.BuildFile(option_proto);
+  ASSERT_TRUE(file != NULL);
+
+  // Verify that no extension options were set, but they were left as
+  // uninterpreted_options.
+  vector<const FieldDescriptor*> fields;
+  file->options().GetReflection()->ListFields(file->options(), &fields);
+  ASSERT_EQ(2, fields.size());
+  EXPECT_TRUE(file->options().has_optimize_for());
+  EXPECT_EQ(2, file->options().uninterpreted_option_size());
+}
+
+// ===================================================================
+
 TEST(CustomOptions, OptionLocations) {
   const Descriptor* message =
       protobuf_unittest::TestMessageWithCustomOptions::descriptor();
@@ -2108,7 +2343,10 @@ TEST_F(ValidationErrorTest, DupeFile) {
   //   defined.
   BuildFileWithErrors(
     "name: \"foo.proto\" "
-    "message_type { name: \"Foo\" }",
+    "message_type { name: \"Foo\" } "
+    // Add another type so that the files aren't identical (in which case there
+    // would be no error).
+    "enum_type { name: \"Bar\" }",
 
     "foo.proto: foo.proto: OTHER: A file with this name is already in the "
       "pool.\n");
@@ -2174,6 +2412,10 @@ TEST_F(ValidationErrorTest, InvalidDefaults) {
     // we look up the type name.
     "  field { name: \"quux\" number: 5 label: LABEL_OPTIONAL"
     "          default_value: \"abc\" type_name: \"Foo\" }"
+
+    // Repeateds can't have defaults.
+    "  field { name: \"corge\" number: 6 label: LABEL_REPEATED type: TYPE_INT32"
+    "          default_value: \"1\" }"
     "}",
 
     "foo.proto: Foo.foo: DEFAULT_VALUE: Couldn't parse default value.\n"
@@ -2181,6 +2423,10 @@ TEST_F(ValidationErrorTest, InvalidDefaults) {
     "foo.proto: Foo.baz: DEFAULT_VALUE: Boolean default must be true or "
       "false.\n"
     "foo.proto: Foo.qux: DEFAULT_VALUE: Messages can't have default values.\n"
+    "foo.proto: Foo.corge: DEFAULT_VALUE: Repeated fields can't have default "
+      "values.\n"
+    // This ends up being reported later because the error is detected at
+    // cross-linking time.
     "foo.proto: Foo.quux: DEFAULT_VALUE: Messages can't have default "
       "values.\n");
 }
@@ -2473,6 +2719,24 @@ TEST_F(ValidationErrorTest, SearchMostLocalFirst) {
     "foo.proto: Foo.baz: TYPE: \"Bar.Baz\" is not defined.\n");
 }
 
+TEST_F(ValidationErrorTest, SearchMostLocalFirst2) {
+  // This test would find the most local "Bar" first, and does, but
+  // proceeds to find the outer one because the inner one's not an
+  // aggregate.
+  BuildFile(
+    "name: \"foo.proto\" "
+    "message_type {"
+    "  name: \"Bar\""
+    "  nested_type { name: \"Baz\" }"
+    "}"
+    "message_type {"
+    "  name: \"Foo\""
+    "  field { name: \"Bar\" number:1 type:TYPE_BYTES } "
+    "  field { name:\"baz\" number:2 label:LABEL_OPTIONAL"
+    "          type_name:\"Bar.Baz\" }"
+    "}");
+}
+
 TEST_F(ValidationErrorTest, PackageOriginallyDeclaredInTransitiveDependent) {
   // Imagine we have the following:
   //
@@ -2519,11 +2783,39 @@ TEST_F(ValidationErrorTest, FieldTypeNotAType) {
     "name: \"foo.proto\" "
     "message_type {"
     "  name: \"Foo\""
-    "  field { name:\"foo\" number:1 label:LABEL_OPTIONAL type_name:\"bar\" }"
+    "  field { name:\"foo\" number:1 label:LABEL_OPTIONAL "
+    "          type_name:\".Foo.bar\" }"
     "  field { name:\"bar\" number:2 label:LABEL_OPTIONAL type:TYPE_INT32 }"
     "}",
 
-    "foo.proto: Foo.foo: TYPE: \"bar\" is not a type.\n");
+    "foo.proto: Foo.foo: TYPE: \".Foo.bar\" is not a type.\n");
+}
+
+TEST_F(ValidationErrorTest, RelativeFieldTypeNotAType) {
+  BuildFileWithErrors(
+    "name: \"foo.proto\" "
+    "message_type {"
+    "  nested_type {"
+    "    name: \"Bar\""
+    "    field { name:\"Baz\" number:2 label:LABEL_OPTIONAL type:TYPE_INT32 }"
+    "  }"
+    "  name: \"Foo\""
+    "  field { name:\"foo\" number:1 label:LABEL_OPTIONAL "
+    "          type_name:\"Bar.Baz\" }"
+    "}",
+    "foo.proto: Foo.foo: TYPE: \"Bar.Baz\" is not a type.\n");
+}
+
+TEST_F(ValidationErrorTest, FieldTypeMayBeItsName) {
+  BuildFile(
+    "name: \"foo.proto\" "
+    "message_type {"
+    "  name: \"Bar\""
+    "}"
+    "message_type {"
+    "  name: \"Foo\""
+    "  field { name:\"Bar\" number:1 label:LABEL_OPTIONAL type_name:\"Bar\" }"
+    "}");
 }
 
 TEST_F(ValidationErrorTest, EnumFieldTypeIsMessage) {
@@ -3346,6 +3638,21 @@ TEST_F(DatabaseBackedPoolTest, FindExtensionByNumber) {
   EXPECT_TRUE(pool.FindExtensionByNumber(foo, 12) == NULL);
 }
 
+TEST_F(DatabaseBackedPoolTest, FindAllExtensions) {
+  DescriptorPool pool(&database_);
+
+  const Descriptor* foo = pool.FindMessageTypeByName("Foo");
+
+  for (int i = 0; i < 2; ++i) {
+    // Repeat the lookup twice, to check that we get consistent
+    // results despite the fallback database lookup mutating the pool.
+    vector<const FieldDescriptor*> extensions;
+    pool.FindAllExtensions(foo, &extensions);
+    ASSERT_EQ(1, extensions.size());
+    EXPECT_EQ(5, extensions[0]->number());
+  }
+}
+
 TEST_F(DatabaseBackedPoolTest, ErrorWithoutErrorCollector) {
   ErrorDescriptorDatabase error_database;
   DescriptorPool pool(&error_database);

+ 2 - 2
src/google/protobuf/dynamic_message.cc

@@ -229,8 +229,7 @@ DynamicMessage::DynamicMessage(const TypeInfo* type_info)
   new(OffsetToPointer(type_info_->unknown_fields_offset)) UnknownFieldSet;
 
   if (type_info_->extensions_offset != -1) {
-    new(OffsetToPointer(type_info_->extensions_offset))
-      ExtensionSet(&type_info_->type, type_info_->pool, type_info_->factory);
+    new(OffsetToPointer(type_info_->extensions_offset)) ExtensionSet;
   }
 
   for (int i = 0; i < descriptor->field_count(); i++) {
@@ -508,6 +507,7 @@ const Message* DynamicMessageFactory::GetPrototype(const Descriptor* type) {
       type_info->unknown_fields_offset,
       type_info->extensions_offset,
       type_info->pool,
+      this,
       type_info->size));
 
   // Cross link prototypes.

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 471 - 257
src/google/protobuf/extension_set.cc


+ 314 - 127
src/google/protobuf/extension_set.h

@@ -44,6 +44,7 @@
 #include <utility>
 #include <string>
 
+#include <google/protobuf/stubs/common.h>
 #include <google/protobuf/message.h>
 
 namespace google {
@@ -53,6 +54,7 @@ namespace protobuf {
   class DescriptorPool;                                // descriptor.h
   class Message;                                       // message.h
   class MessageFactory;                                // message.h
+  class UnknownFieldSet;                               // unknown_field_set.h
   namespace io {
     class CodedInputStream;                              // coded_stream.h
     class CodedOutputStream;                             // coded_stream.h
@@ -64,6 +66,12 @@ namespace protobuf {
 namespace protobuf {
 namespace internal {
 
+// Used to store values of type FieldDescriptor::Type without having to
+// #include descriptor.h.  Also, ensures that we use only one byte to store
+// these values, which is important to keep the layout of
+// ExtensionSet::Extension small.
+typedef uint8 FieldType;
+
 // This is an internal helper class intended for use within the protocol buffer
 // library and generated classes.  Clients should not use it directly.  Instead,
 // use the generated accessors such as GetExtension() of the class being
@@ -77,30 +85,42 @@ namespace internal {
 // off to the ExtensionSet for parsing.  Etc.
 class LIBPROTOBUF_EXPORT ExtensionSet {
  public:
-  // Construct an ExtensionSet.
-  //   extendee:  Descriptor for the type being extended. We pass in a pointer
-  //              to a pointer to the extendee to get around an initialization
-  //              problem: when we create the ExtensionSet for a message type,
-  //              its descriptor may not exist yet. But we know where that
-  //              descriptor pointer will be placed, and by the time it's used
-  //              by this ExtensionSet it will be fully initialized, so passing
-  //              a pointer to that location works. Note that this problem
-  //              will only occur for messages defined in descriptor.proto.
-  //   pool:      DescriptorPool to search for extension definitions.
-  //   factory:   MessageFactory used to construct implementations of messages
-  //              for extensions with message type.  This factory must be able
-  //              to construct any message type found in "pool".
-  // All three objects remain property of the caller and must outlive the
-  // ExtensionSet.
-  ExtensionSet(const Descriptor* const* extendee,
-               const DescriptorPool* pool,
-               MessageFactory* factory);
-
+  ExtensionSet();
   ~ExtensionSet();
 
+  // A function which, given an integer value, returns true if the number
+  // matches one of the defined values for the corresponding enum type.  This
+  // is used with RegisterEnumExtension, below.
+  typedef bool EnumValidityFunc(int number);
+
+  // These are called at startup by protocol-compiler-generated code to
+  // register known extensions.  The registrations are used by ParseField()
+  // to look up extensions for parsed field numbers.  Note that dynamic parsing
+  // does not use ParseField(); only protocol-compiler-generated parsing
+  // methods do.
+  static void RegisterExtension(const Message* containing_type,
+                                int number, FieldType type,
+                                bool is_repeated, bool is_packed);
+  static void RegisterEnumExtension(const Message* containing_type,
+                                    int number, FieldType type,
+                                    bool is_repeated, bool is_packed,
+                                    EnumValidityFunc* is_valid);
+  static void RegisterMessageExtension(const Message* containing_type,
+                                       int number, FieldType type,
+                                       bool is_repeated, bool is_packed,
+                                       const Message* prototype);
+
+  // =================================================================
+
   // Add all fields which are currently present to the given vector.  This
-  // is useful to implement Reflection::ListFields().
-  void AppendToList(vector<const FieldDescriptor*>* output) const;
+  // is useful to implement Reflection::ListFields().  The FieldDescriptors
+  // are looked up by number from the given pool.
+  //
+  // TODO(kenton): Looking up each field by number is somewhat unfortunate.
+  //   Is there a better way?
+  void AppendToList(const Descriptor* containing_type,
+                    const DescriptorPool* pool,
+                    vector<const FieldDescriptor*>* output) const;
 
   // =================================================================
   // Accessors
@@ -138,28 +158,34 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
 
   // singular fields -------------------------------------------------
 
-  int32  GetInt32 (int number) const;
-  int64  GetInt64 (int number) const;
-  uint32 GetUInt32(int number) const;
-  uint64 GetUInt64(int number) const;
-  float  GetFloat (int number) const;
-  double GetDouble(int number) const;
-  bool   GetBool  (int number) const;
-  int    GetEnum  (int number) const;
-  const string & GetString (int number) const;
-  const Message& GetMessage(int number) const;
-
-  void SetInt32 (int number, int32  value);
-  void SetInt64 (int number, int64  value);
-  void SetUInt32(int number, uint32 value);
-  void SetUInt64(int number, uint64 value);
-  void SetFloat (int number, float  value);
-  void SetDouble(int number, double value);
-  void SetBool  (int number, bool   value);
-  void SetEnum  (int number, int    value);
-  void SetString(int number, const string& value);
-  string * MutableString (int number);
-  Message* MutableMessage(int number);
+  int32  GetInt32 (int number, int32  default_value) const;
+  int64  GetInt64 (int number, int64  default_value) const;
+  uint32 GetUInt32(int number, uint32 default_value) const;
+  uint64 GetUInt64(int number, uint64 default_value) const;
+  float  GetFloat (int number, float  default_value) const;
+  double GetDouble(int number, double default_value) const;
+  bool   GetBool  (int number, bool   default_value) const;
+  int    GetEnum  (int number, int    default_value) const;
+  const string & GetString (int number, const string&  default_value) const;
+  const Message& GetMessage(int number, const Message& default_value) const;
+  const Message& GetMessage(int number, const Descriptor* message_type,
+                            MessageFactory* factory) const;
+
+  void SetInt32 (int number, FieldType type, int32  value);
+  void SetInt64 (int number, FieldType type, int64  value);
+  void SetUInt32(int number, FieldType type, uint32 value);
+  void SetUInt64(int number, FieldType type, uint64 value);
+  void SetFloat (int number, FieldType type, float  value);
+  void SetDouble(int number, FieldType type, double value);
+  void SetBool  (int number, FieldType type, bool   value);
+  void SetEnum  (int number, FieldType type, int    value);
+  void SetString(int number, FieldType type, const string& value);
+  string * MutableString (int number, FieldType type);
+  Message* MutableMessage(int number, FieldType type,
+                          const Message& prototype);
+  Message* MutableMessage(int number, FieldType type,
+                          const Descriptor* message_type,
+                          MessageFactory* factory);
 
   // repeated fields -------------------------------------------------
 
@@ -186,17 +212,21 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
   string * MutableRepeatedString (int number, int index);
   Message* MutableRepeatedMessage(int number, int index);
 
-  void AddInt32 (int number, int32  value);
-  void AddInt64 (int number, int64  value);
-  void AddUInt32(int number, uint32 value);
-  void AddUInt64(int number, uint64 value);
-  void AddFloat (int number, float  value);
-  void AddDouble(int number, double value);
-  void AddBool  (int number, bool   value);
-  void AddEnum  (int number, int    value);
-  void AddString(int number, const string& value);
-  string * AddString (int number);
-  Message* AddMessage(int number);
+  void AddInt32 (int number, FieldType type, bool packed, int32  value);
+  void AddInt64 (int number, FieldType type, bool packed, int64  value);
+  void AddUInt32(int number, FieldType type, bool packed, uint32 value);
+  void AddUInt64(int number, FieldType type, bool packed, uint64 value);
+  void AddFloat (int number, FieldType type, bool packed, float  value);
+  void AddDouble(int number, FieldType type, bool packed, double value);
+  void AddBool  (int number, FieldType type, bool packed, bool   value);
+  void AddEnum  (int number, FieldType type, bool packed, int    value);
+  void AddString(int number, FieldType type, const string& value);
+  string * AddString (int number, FieldType type);
+  Message* AddMessage(int number, FieldType type,
+                      const Message& prototype);
+  Message* AddMessage(int number, FieldType type,
+                      const Descriptor* message_type,
+                      MessageFactory* factory);
 
   // -----------------------------------------------------------------
   // TODO(kenton):  Hardcore memory management accessors
@@ -212,40 +242,41 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
   void Swap(ExtensionSet* other);
   bool IsInitialized() const;
 
-  // These parsing and serialization functions all want a pointer to the
-  // message object because they hand off the actual work to WireFormat,
-  // which works in terms of a reflection interface.  Yes, this means there
-  // are some redundant virtual function calls that end up being made, but
-  // it probably doesn't matter much in practice, and the alternative would
-  // involve reproducing a lot of WireFormat's functionality.
-
   // Parses a single extension from the input.  The input should start out
-  // positioned immediately after the tag.
-  bool ParseField(uint32 tag, io::CodedInputStream* input, Message* message);
+  // positioned immediately after the tag.  |containing_type| is the default
+  // instance for the containing message; it is used only to look up the
+  // extension by number.  See RegisterExtension(), above.  Unlike the other
+  // methods of ExtensionSet, this only works for generated message types --
+  // it looks up extensions registered using RegisterExtension().
+  bool ParseField(uint32 tag, io::CodedInputStream* input,
+                  const Message* containing_type,
+                  UnknownFieldSet* unknown_fields);
 
   // Write all extension fields with field numbers in the range
   //   [start_field_number, end_field_number)
   // to the output stream, using the cached sizes computed when ByteSize() was
   // last called.  Note that the range bounds are inclusive-exclusive.
-  bool SerializeWithCachedSizes(int start_field_number,
+  void SerializeWithCachedSizes(int start_field_number,
                                 int end_field_number,
-                                const Message& message,
                                 io::CodedOutputStream* output) const;
 
+  // Same as SerializeWithCachedSizes, but without any bounds checking.
+  // The caller must ensure that target has sufficient capacity for the
+  // serialized extensions.
+  //
+  // Returns a pointer past the last written byte.
+  uint8* SerializeWithCachedSizesToArray(int start_field_number,
+                                         int end_field_number,
+                                         uint8* target) const;
+
   // Returns the total serialized size of all the extensions.
-  int ByteSize(const Message& message) const;
+  int ByteSize() const;
 
   // Returns (an estimate of) the total number of bytes used for storing the
   // extensions in memory, excluding sizeof(*this).
   int SpaceUsedExcludingSelf() const;
 
  private:
-  // Like FindKnownExtension(), but GOOGLE_CHECK-fail if not found.
-  const FieldDescriptor* FindKnownExtensionOrDie(int number) const;
-
-  // Get the prototype for the message.
-  const Message* GetPrototype(const Descriptor* message_type) const;
-
   struct Extension {
     union {
       int32    int32_value;
@@ -271,7 +302,8 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
       RepeatedPtrField<Message>* repeated_message_value;
     };
 
-    const FieldDescriptor* descriptor;
+    FieldType type;
+    bool is_repeated;
 
     // For singular types, indicates if the extension is "cleared".  This
     // happens when an extension is set and then later cleared by the caller.
@@ -281,19 +313,29 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
     // simply becomes zero when cleared.
     bool is_cleared;
 
-    Extension(): descriptor(NULL), is_cleared(false) {}
+    // For repeated types, this indicates if the [packed=true] option is set.
+    bool is_packed;
+
+    // For packed fields, the size of the packed data is recorded here when
+    // ByteSize() is called, and is then used during serialization.
+    // TODO(kenton):  Use atomic<int> when C++ supports it.
+    mutable int cached_size;
 
     // Some helper methods for operations on a single Extension.
-    bool SerializeFieldWithCachedSizes(
-        const Message& message,
+    void SerializeFieldWithCachedSizes(
+        int number,
         io::CodedOutputStream* output) const;
-    int64 ByteSize(const Message& message) const;
+    int ByteSize(int number) const;
     void Clear();
     int GetSize() const;
     void Free();
     int SpaceUsedExcludingSelf() const;
   };
 
+  // Gets the extension with the given number, creating it if it does not
+  // already exist.  Returns true if the extension did not already exist.
+  bool MaybeNewExtension(int number, Extension** result);
+
   // The Extension struct is small enough to be passed by value, so we use it
   // directly as the value type in the map rather than use pointers.  We use
   // a map rather than hash_map here because we expect most ExtensionSets will
@@ -301,30 +343,26 @@ class LIBPROTOBUF_EXPORT ExtensionSet {
   // for 100 elements or more.  Also, we want AppendToList() to order fields
   // by field number.
   map<int, Extension> extensions_;
-  const Descriptor* const* extendee_;
-  const DescriptorPool* descriptor_pool_;
-  MessageFactory* message_factory_;
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ExtensionSet);
 };
 
 // These are just for convenience...
-inline void ExtensionSet::SetString(int number, const string& value) {
-  MutableString(number)->assign(value);
+inline void ExtensionSet::SetString(int number, FieldType type,
+                                    const string& value) {
+  MutableString(number, type)->assign(value);
 }
 inline void ExtensionSet::SetRepeatedString(int number, int index,
                                             const string& value) {
   MutableRepeatedString(number, index)->assign(value);
 }
-inline void ExtensionSet::AddString(int number, const string& value) {
-  AddString(number)->assign(value);
+inline void ExtensionSet::AddString(int number, FieldType type,
+                                    const string& value) {
+  AddString(number, type)->assign(value);
 }
 
 // ===================================================================
-// Implementation details
-//
-// DO NOT DEPEND ON ANYTHING BELOW THIS POINT.  This is for use from
-// generated code only.
+// Glue for generated extension accessors
 
 // -------------------------------------------------------------------
 // Template magic
@@ -377,8 +415,10 @@ class PrimitiveTypeTraits {
  public:
   typedef Type ConstType;
 
-  static inline ConstType Get(int number, const ExtensionSet& set);
-  static inline void Set(int number, ConstType value, ExtensionSet* set);
+  static inline ConstType Get(int number, const ExtensionSet& set,
+                              ConstType default_value);
+  static inline void Set(int number, FieldType field_type,
+                         ConstType value, ExtensionSet* set);
 };
 
 template <typename Type>
@@ -388,17 +428,18 @@ class RepeatedPrimitiveTypeTraits {
 
   static inline Type Get(int number, const ExtensionSet& set, int index);
   static inline void Set(int number, int index, Type value, ExtensionSet* set);
-  static inline void Add(int number, Type value, ExtensionSet* set);
+  static inline void Add(int number, FieldType field_type,
+                         bool is_packed, Type value, ExtensionSet* set);
 };
 
 #define PROTOBUF_DEFINE_PRIMITIVE_TYPE(TYPE, METHOD)                       \
 template<> inline TYPE PrimitiveTypeTraits<TYPE>::Get(                     \
-    int number, const ExtensionSet& set) {                                 \
-  return set.Get##METHOD(number);                                          \
+    int number, const ExtensionSet& set, TYPE default_value) {             \
+  return set.Get##METHOD(number, default_value);                           \
 }                                                                          \
 template<> inline void PrimitiveTypeTraits<TYPE>::Set(                     \
-    int number, ConstType value, ExtensionSet* set) {                      \
-  set->Set##METHOD(number, value);                                         \
+    int number, FieldType field_type, TYPE value, ExtensionSet* set) {     \
+  set->Set##METHOD(number, field_type, value);                             \
 }                                                                          \
                                                                            \
 template<> inline TYPE RepeatedPrimitiveTypeTraits<TYPE>::Get(             \
@@ -406,12 +447,13 @@ template<> inline TYPE RepeatedPrimitiveTypeTraits<TYPE>::Get(             \
   return set.GetRepeated##METHOD(number, index);                           \
 }                                                                          \
 template<> inline void RepeatedPrimitiveTypeTraits<TYPE>::Set(             \
-    int number, int index, ConstType value, ExtensionSet* set) {           \
+    int number, int index, TYPE value, ExtensionSet* set) {                \
   set->SetRepeated##METHOD(number, index, value);                          \
 }                                                                          \
 template<> inline void RepeatedPrimitiveTypeTraits<TYPE>::Add(             \
-    int number, ConstType value, ExtensionSet* set) {                      \
-  set->Add##METHOD(number, value);                                         \
+    int number, FieldType field_type, bool is_packed,                      \
+    TYPE value, ExtensionSet* set) {                                       \
+  set->Add##METHOD(number, field_type, is_packed, value);                  \
 }
 
 PROTOBUF_DEFINE_PRIMITIVE_TYPE( int32,  Int32)
@@ -433,14 +475,17 @@ class LIBPROTOBUF_EXPORT StringTypeTraits {
   typedef const string& ConstType;
   typedef string* MutableType;
 
-  static inline const string& Get(int number, const ExtensionSet& set) {
-    return set.GetString(number);
+  static inline const string& Get(int number, const ExtensionSet& set,
+                                  ConstType default_value) {
+    return set.GetString(number, default_value);
   }
-  static inline void Set(int number, const string& value, ExtensionSet* set) {
-    set->SetString(number, value);
+  static inline void Set(int number, FieldType field_type,
+                         const string& value, ExtensionSet* set) {
+    set->SetString(number, field_type, value);
   }
-  static inline string* Mutable(int number, ExtensionSet* set) {
-    return set->MutableString(number);
+  static inline string* Mutable(int number, FieldType field_type,
+                                ExtensionSet* set) {
+    return set->MutableString(number, field_type);
   }
 };
 
@@ -460,11 +505,14 @@ class LIBPROTOBUF_EXPORT RepeatedStringTypeTraits {
   static inline string* Mutable(int number, int index, ExtensionSet* set) {
     return set->MutableRepeatedString(number, index);
   }
-  static inline void Add(int number, const string& value, ExtensionSet* set) {
-    set->AddString(number, value);
+  static inline void Add(int number, FieldType field_type,
+                         bool is_packed, const string& value,
+                         ExtensionSet* set) {
+    set->AddString(number, field_type, value);
   }
-  static inline string* Add(int number, ExtensionSet* set) {
-    return set->AddString(number);
+  static inline string* Add(int number, FieldType field_type,
+                            ExtensionSet* set) {
+    return set->AddString(number, field_type);
   }
 };
 
@@ -473,20 +521,23 @@ class LIBPROTOBUF_EXPORT RepeatedStringTypeTraits {
 
 // ExtensionSet represents enums using integers internally, so we have to
 // static_cast around.
-template <typename Type>
+template <typename Type, bool IsValid(int)>
 class EnumTypeTraits {
  public:
   typedef Type ConstType;
 
-  static inline ConstType Get(int number, const ExtensionSet& set) {
-    return static_cast<Type>(set.GetEnum(number));
+  static inline ConstType Get(int number, const ExtensionSet& set,
+                              ConstType default_value) {
+    return static_cast<Type>(set.GetEnum(number, default_value));
   }
-  static inline void Set(int number, ConstType value, ExtensionSet* set) {
-    set->SetEnum(number, value);
+  static inline void Set(int number, FieldType field_type,
+                         ConstType value, ExtensionSet* set) {
+    GOOGLE_DCHECK(IsValid(value));
+    set->SetEnum(number, field_type, value);
   }
 };
 
-template <typename Type>
+template <typename Type, bool IsValid(int)>
 class RepeatedEnumTypeTraits {
  public:
   typedef Type ConstType;
@@ -496,10 +547,13 @@ class RepeatedEnumTypeTraits {
   }
   static inline void Set(int number, int index,
                          ConstType value, ExtensionSet* set) {
+    GOOGLE_DCHECK(IsValid(value));
     set->SetRepeatedEnum(number, index, value);
   }
-  static inline void Add(int number, ConstType value, ExtensionSet* set) {
-    set->AddEnum(number, value);
+  static inline void Add(int number, FieldType field_type,
+                         bool is_packed, ConstType value, ExtensionSet* set) {
+    GOOGLE_DCHECK(IsValid(value));
+    set->AddEnum(number, field_type, is_packed, value);
   }
 };
 
@@ -513,13 +567,17 @@ template <typename Type>
 class MessageTypeTraits {
  public:
   typedef const Type& ConstType;
- typedef Type* MutableType;
+  typedef Type* MutableType;
 
-  static inline ConstType Get(int number, const ExtensionSet& set) {
-    return static_cast<const Type&>(set.GetMessage(number));
+  static inline ConstType Get(int number, const ExtensionSet& set,
+                              ConstType default_value) {
+    return static_cast<const Type&>(
+        set.GetMessage(number, default_value));
   }
-  static inline MutableType Mutable(int number, ExtensionSet* set) {
-    return static_cast<Type*>(set->MutableMessage(number));
+  static inline MutableType Mutable(int number, FieldType field_type,
+                                    ExtensionSet* set) {
+    return static_cast<Type*>(
+        set->MutableMessage(number, field_type, Type::default_instance()));
   }
 };
 
@@ -535,8 +593,10 @@ class RepeatedMessageTypeTraits {
   static inline MutableType Mutable(int number, int index, ExtensionSet* set) {
     return static_cast<Type*>(set->MutableRepeatedMessage(number, index));
   }
-  static inline MutableType Add(int number, ExtensionSet* set) {
-    return static_cast<Type*>(set->AddMessage(number));
+  static inline MutableType Add(int number, FieldType field_type,
+                                ExtensionSet* set) {
+    return static_cast<Type*>(
+        set->AddMessage(number, field_type, Type::default_instance()));
   }
 };
 
@@ -546,7 +606,7 @@ class RepeatedMessageTypeTraits {
 // This is the type of actual extension objects.  E.g. if you have:
 //   extends Foo with optional int32 bar = 1234;
 // then "bar" will be defined in C++ as:
+//   ExtensionIdentifier<Foo, PrimitiveTypeTraits<int32>, 5, false> bar(1234, 0);
+//   ExtensionIdentifier<Foo, PrimitiveTypeTraits<int32>, 1, false> bar(1234);
 //
 // Note that we could, in theory, supply the field number as a template
 // parameter, and thus make an instance of ExtensionIdentifier have no
@@ -557,18 +617,145 @@ class RepeatedMessageTypeTraits {
 // but that would be bad because it would cause this extension to not be
 // registered at static initialization, and therefore using it would crash.
 
-template <typename ExtendeeType, typename TypeTraitsType>
-class ExtensionIdentifier {
+template <typename ExtendeeType, typename TypeTraitsType,
+          FieldType field_type, bool is_packed>
+class LIBPROTOBUF_EXPORT ExtensionIdentifier {
  public:
   typedef TypeTraitsType TypeTraits;
   typedef ExtendeeType Extendee;
 
-  ExtensionIdentifier(int number): number_(number) {}
+  ExtensionIdentifier(int number, typename TypeTraits::ConstType default_value)
+      : number_(number), default_value_(default_value) {}
   inline int number() const { return number_; }
+  typename TypeTraits::ConstType default_value() const {
+    return default_value_;
+  }
+
  private:
   const int number_;
+  const typename TypeTraits::ConstType default_value_;
 };
 
+// -------------------------------------------------------------------
+// Generated accessors
+
+// This macro should be expanded in the context of a generated type which
+// has extensions.
+//
+// We use "_proto_TypeTraits" as a type name below because "TypeTraits"
+// causes problems if the class has a nested message or enum type with that
+// name and "_TypeTraits" is technically reserved for the C++ library since
+// it starts with an underscore followed by a capital letter.
+#define GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(CLASSNAME)                        \
+  /* Has, Size, Clear */                                                      \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline bool HasExtension(                                                   \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) const {     \
+    return _extensions_.Has(id.number());                                     \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline void ClearExtension(                                                 \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) {           \
+    _extensions_.ClearExtension(id.number());                                 \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline int ExtensionSize(                                                   \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) const {     \
+    return _extensions_.ExtensionSize(id.number());                           \
+  }                                                                           \
+                                                                              \
+  /* Singular accessors */                                                    \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline typename _proto_TypeTraits::ConstType GetExtension(                  \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) const {     \
+    return _proto_TypeTraits::Get(id.number(), _extensions_,                  \
+                                  id.default_value());                        \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline typename _proto_TypeTraits::MutableType MutableExtension(            \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) {           \
+    return _proto_TypeTraits::Mutable(id.number(), field_type, &_extensions_);\
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline void SetExtension(                                                   \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id,             \
+      typename _proto_TypeTraits::ConstType value) {                          \
+    _proto_TypeTraits::Set(id.number(), field_type, value, &_extensions_);    \
+  }                                                                           \
+                                                                              \
+  /* Repeated accessors */                                                    \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline typename _proto_TypeTraits::ConstType GetExtension(                  \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id,             \
+      int index) const {                                                      \
+    return _proto_TypeTraits::Get(id.number(), _extensions_, index);          \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline typename _proto_TypeTraits::MutableType MutableExtension(            \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id,             \
+      int index) {                                                            \
+    return _proto_TypeTraits::Mutable(id.number(), index, &_extensions_);     \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline void SetExtension(                                                   \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id,             \
+      int index, typename _proto_TypeTraits::ConstType value) {               \
+    _proto_TypeTraits::Set(id.number(), index, value, &_extensions_);         \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline typename _proto_TypeTraits::MutableType AddExtension(                \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id) {           \
+    return _proto_TypeTraits::Add(id.number(), field_type, &_extensions_);    \
+  }                                                                           \
+                                                                              \
+  template <typename _proto_TypeTraits,                                       \
+            ::google::protobuf::internal::FieldType field_type,                         \
+            bool is_packed>                                                   \
+  inline void AddExtension(                                                   \
+      const ::google::protobuf::internal::ExtensionIdentifier<                          \
+        CLASSNAME, _proto_TypeTraits, field_type, is_packed>& id,             \
+      typename _proto_TypeTraits::ConstType value) {                          \
+    _proto_TypeTraits::Add(id.number(), field_type, is_packed,                \
+                           value, &_extensions_);                             \
+  }
+
 }  // namespace internal
 }  // namespace protobuf
 

+ 82 - 4
src/google/protobuf/extension_set_unittest.cc

@@ -35,10 +35,13 @@
 #include <google/protobuf/extension_set.h>
 #include <google/protobuf/unittest.pb.h>
 #include <google/protobuf/test_util.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
 
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/testing/googletest.h>
 #include <gtest/gtest.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
 namespace protobuf {
@@ -102,6 +105,11 @@ TEST(ExtensionSetTest, Clear) {
               unittest::optional_foreign_message_extension));
   EXPECT_NE(&unittest_import::ImportMessage::default_instance(),
             &message.GetExtension(unittest::optional_import_message_extension));
+
+  // Make sure setting stuff again after clearing works.  (This takes slightly
+  // different code paths since the objects are reused.)
+  TestUtil::SetAllExtensions(&message);
+  TestUtil::ExpectAllExtensionsSet(message);
 }
 
 TEST(ExtensionSetTest, ClearOneField) {
@@ -166,28 +174,90 @@ TEST(ExtensionSetTest, SwapWithSelf) {
   TestUtil::ExpectAllExtensionsSet(message);
 }
 
-TEST(ExtensionSetTest, Serialization) {
+TEST(ExtensionSetTest, SerializationToArray) {
   // Serialize as TestAllExtensions and parse as TestAllTypes to insure wire
   // compatibility of extensions.
+  //
+  // This checks serialization to a flat array by explicitly reserving space in
+  // the string and calling the generated message's
+  // SerializeWithCachedSizesToArray.
   unittest::TestAllExtensions source;
   unittest::TestAllTypes destination;
+  TestUtil::SetAllExtensions(&source);
+  int size = source.ByteSize();
   string data;
+  data.resize(size);
+  uint8* target = reinterpret_cast<uint8*>(string_as_array(&data));
+  uint8* end = source.SerializeWithCachedSizesToArray(target);
+  EXPECT_EQ(size, end - target);
+  EXPECT_TRUE(destination.ParseFromString(data));
+  TestUtil::ExpectAllFieldsSet(destination);
+}
 
+TEST(ExtensionSetTest, SerializationToStream) {
+  // Serialize as TestAllExtensions and parse as TestAllTypes to insure wire
+  // compatibility of extensions.
+  //
+  // This checks serialization to an output stream by creating an array output
+  // stream that can only buffer 1 byte at a time - this prevents the message
+  // from ever jumping to the fast path, ensuring that serialization happens via
+  // the CodedOutputStream.
+  unittest::TestAllExtensions source;
+  unittest::TestAllTypes destination;
   TestUtil::SetAllExtensions(&source);
-  source.SerializeToString(&data);
+  int size = source.ByteSize();
+  string data;
+  data.resize(size);
+  {
+    io::ArrayOutputStream array_stream(string_as_array(&data), size, 1);
+    io::CodedOutputStream output_stream(&array_stream);
+    source.SerializeWithCachedSizes(&output_stream);
+    ASSERT_FALSE(output_stream.HadError());
+  }
   EXPECT_TRUE(destination.ParseFromString(data));
   TestUtil::ExpectAllFieldsSet(destination);
 }
 
-TEST(ExtensionSetTest, PackedSerialization) {
+TEST(ExtensionSetTest, PackedSerializationToArray) {
   // Serialize as TestPackedExtensions and parse as TestPackedTypes to insure
   // wire compatibility of extensions.
+  //
+  // This checks serialization to a flat array by explicitly reserving space in
+  // the string and calling the generated message's
+  // SerializeWithCachedSizesToArray.
   unittest::TestPackedExtensions source;
   unittest::TestPackedTypes destination;
+  TestUtil::SetPackedExtensions(&source);
+  int size = source.ByteSize();
   string data;
+  data.resize(size);
+  uint8* target = reinterpret_cast<uint8*>(string_as_array(&data));
+  uint8* end = source.SerializeWithCachedSizesToArray(target);
+  EXPECT_EQ(size, end - target);
+  EXPECT_TRUE(destination.ParseFromString(data));
+  TestUtil::ExpectPackedFieldsSet(destination);
+}
 
+TEST(ExtensionSetTest, PackedSerializationToStream) {
+  // Serialize as TestPackedExtensions and parse as TestPackedTypes to insure
+  // wire compatibility of extensions.
+  //
+  // This checks serialization to an output stream by creating an array output
+  // stream that can only buffer 1 byte at a time - this prevents the message
+  // from ever jumping to the fast path, ensuring that serialization happens via
+  // the CodedOutputStream.
+  unittest::TestPackedExtensions source;
+  unittest::TestPackedTypes destination;
   TestUtil::SetPackedExtensions(&source);
-  source.SerializeToString(&data);
+  int size = source.ByteSize();
+  string data;
+  data.resize(size);
+  {
+    io::ArrayOutputStream array_stream(string_as_array(&data), size, 1);
+    io::CodedOutputStream output_stream(&array_stream);
+    source.SerializeWithCachedSizes(&output_stream);
+    ASSERT_FALSE(output_stream.HadError());
+  }
   EXPECT_TRUE(destination.ParseFromString(data));
   TestUtil::ExpectPackedFieldsSet(destination);
 }
@@ -395,6 +465,14 @@ TEST(ExtensionSetTest, SpaceUsedExcludingSelf) {
   }
 }
 
+TEST(ExtensionSetTest, InvalidEnumDeath) {
+  unittest::TestAllExtensions message;
+  EXPECT_DEBUG_DEATH(  // 53 is presumably not a declared ForeignEnum value.
+    message.SetExtension(unittest::optional_foreign_enum_extension,
+                         static_cast<unittest::ForeignEnum>(53)),
+    "IsValid");  // Regex matched against the death output.
+}
+
 }  // namespace
 }  // namespace internal
 }  // namespace protobuf

+ 51 - 16
src/google/protobuf/generated_message_reflection.cc

@@ -58,6 +58,21 @@ int StringSpaceUsedExcludingSelf(const string& str) {
   }
 }
 
+bool ParseNamedEnum(const EnumDescriptor* descriptor,
+                    const string& name,
+                    int* value) {
+  const EnumValueDescriptor* d = descriptor->FindValueByName(name);
+  if (d == NULL) return false;  // Lookup failed; *value is left untouched.
+  *value = d->number();  // Report the number of the matching entry.
+  return true;
+}
+
+const string& NameOfEnum(const EnumDescriptor* descriptor, int value) {
+  static string kEmptyString;  // Returned by reference when the lookup fails.
+  const EnumValueDescriptor* d = descriptor->FindValueByNumber(value);
+  return (d == NULL ? kEmptyString : d->name());
+}
+
 // ===================================================================
 // Helpers for reporting usage errors (e.g. trying to use GetInt32() on
 // a string field).
@@ -160,6 +175,7 @@ GeneratedMessageReflection::GeneratedMessageReflection(
     int unknown_fields_offset,
     int extensions_offset,
     const DescriptorPool* descriptor_pool,
+    MessageFactory* factory,
     int object_size)
   : descriptor_       (descriptor),
     default_instance_ (default_instance),
@@ -170,7 +186,8 @@ GeneratedMessageReflection::GeneratedMessageReflection(
     object_size_      (object_size),
     descriptor_pool_  ((descriptor_pool == NULL) ?
                          DescriptorPool::generated_pool() :
-                         descriptor_pool) {
+                         descriptor_pool),
+    message_factory_  (factory) {
 }
 
 GeneratedMessageReflection::~GeneratedMessageReflection() {}
@@ -365,7 +382,8 @@ void GeneratedMessageReflection::ListFields(
   }
 
   if (extensions_offset_ != -1) {
-    GetExtensionSet(message).AppendToList(output);
+    GetExtensionSet(message).AppendToList(descriptor_, descriptor_pool_,
+                                          output);
   }
 
   // ListFields() must sort output by field number.
@@ -380,7 +398,8 @@ void GeneratedMessageReflection::ListFields(
       const Message& message, const FieldDescriptor* field) const {          \
     USAGE_CHECK_ALL(Get##TYPENAME, SINGULAR, CPPTYPE);                       \
     if (field->is_extension()) {                                             \
-      return GetExtensionSet(message).Get##TYPENAME(field->number());        \
+      return GetExtensionSet(message).Get##TYPENAME(                         \
+        field->number(), field->default_value_##PASSTYPE());                 \
     } else {                                                                 \
       return GetField<TYPE>(message, field);                                 \
     }                                                                        \
@@ -392,7 +411,7 @@ void GeneratedMessageReflection::ListFields(
     USAGE_CHECK_ALL(Set##TYPENAME, SINGULAR, CPPTYPE);                       \
     if (field->is_extension()) {                                             \
       return MutableExtensionSet(message)->Set##TYPENAME(                    \
-        field->number(), value);                                             \
+        field->number(), field->type(), value);                              \
     } else {                                                                 \
       SetField<TYPE>(message, field, value);                                 \
     }                                                                        \
@@ -427,7 +446,8 @@ void GeneratedMessageReflection::ListFields(
       PASSTYPE value) const {                                                \
     USAGE_CHECK_ALL(Add##TYPENAME, REPEATED, CPPTYPE);                       \
     if (field->is_extension()) {                                             \
-      MutableExtensionSet(message)->Add##TYPENAME(field->number(), value);   \
+      MutableExtensionSet(message)->Add##TYPENAME(                           \
+        field->number(), field->type(), field->options().packed(), value);   \
     } else {                                                                 \
       AddField<TYPE>(message, field, value);                                 \
     }                                                                        \
@@ -448,7 +468,8 @@ string GeneratedMessageReflection::GetString(
     const Message& message, const FieldDescriptor* field) const {
   USAGE_CHECK_ALL(GetString, SINGULAR, STRING);
   if (field->is_extension()) {
-    return GetExtensionSet(message).GetString(field->number());
+    return GetExtensionSet(message).GetString(field->number(),
+                                              field->default_value_string());
   } else {
     return *GetField<const string*>(message, field);
   }
@@ -459,7 +480,8 @@ const string& GeneratedMessageReflection::GetStringReference(
     const FieldDescriptor* field, string* scratch) const {
   USAGE_CHECK_ALL(GetStringReference, SINGULAR, STRING);
   if (field->is_extension()) {
-    return GetExtensionSet(message).GetString(field->number());
+    return GetExtensionSet(message).GetString(field->number(),
+                                              field->default_value_string());
   } else {
     return *GetField<const string*>(message, field);
   }
@@ -471,7 +493,8 @@ void GeneratedMessageReflection::SetString(
     const string& value) const {
   USAGE_CHECK_ALL(SetString, SINGULAR, STRING);
   if (field->is_extension()) {
-    return MutableExtensionSet(message)->SetString(field->number(), value);
+    return MutableExtensionSet(message)->SetString(field->number(),
+                                                   field->type(), value);
   } else {
     string** ptr = MutableField<string*>(message, field);
     if (*ptr == DefaultRaw<const string*>(field)) {
@@ -523,7 +546,8 @@ void GeneratedMessageReflection::AddString(
     const string& value) const {
   USAGE_CHECK_ALL(AddString, REPEATED, STRING);
   if (field->is_extension()) {
-    MutableExtensionSet(message)->AddString(field->number(), value);
+    MutableExtensionSet(message)->AddString(field->number(),
+                                            field->type(), value);
   } else {
     AddField<string>(message, field, value);
   }
@@ -538,7 +562,8 @@ const EnumValueDescriptor* GeneratedMessageReflection::GetEnum(
 
   int value;
   if (field->is_extension()) {
-    value = GetExtensionSet(message).GetEnum(field->number());
+    value = GetExtensionSet(message).GetEnum(
+      field->number(), field->default_value_enum()->number());
   } else {
     value = GetField<int>(message, field);
   }
@@ -555,7 +580,8 @@ void GeneratedMessageReflection::SetEnum(
   USAGE_CHECK_ENUM_VALUE(SetEnum);
 
   if (field->is_extension()) {
-    MutableExtensionSet(message)->SetEnum(field->number(), value->number());
+    MutableExtensionSet(message)->SetEnum(field->number(), field->type(),
+                                          value->number());
   } else {
     SetField<int>(message, field, value->number());
   }
@@ -599,7 +625,9 @@ void GeneratedMessageReflection::AddEnum(
   USAGE_CHECK_ENUM_VALUE(AddEnum);
 
   if (field->is_extension()) {
-    MutableExtensionSet(message)->AddEnum(field->number(), value->number());
+    MutableExtensionSet(message)->AddEnum(field->number(), field->type(),
+                                          field->options().packed(),
+                                          value->number());
   } else {
     AddField<int>(message, field, value->number());
   }
@@ -612,7 +640,9 @@ const Message& GeneratedMessageReflection::GetMessage(
   USAGE_CHECK_ALL(GetMessage, SINGULAR, MESSAGE);
 
   if (field->is_extension()) {
-    return GetExtensionSet(message).GetMessage(field->number());
+    return GetExtensionSet(message).GetMessage(field->number(),
+                                               field->message_type(),
+                                               message_factory_);
   } else {
     const Message* result = GetRaw<const Message*>(message, field);
     if (result == NULL) {
@@ -627,13 +657,15 @@ Message* GeneratedMessageReflection::MutableMessage(
   USAGE_CHECK_ALL(MutableMessage, SINGULAR, MESSAGE);
 
   if (field->is_extension()) {
-    return MutableExtensionSet(message)->MutableMessage(field->number());
+    return MutableExtensionSet(message)->MutableMessage(field->number(),
+                                                        field->type(),
+                                                        field->message_type(),
+                                                        message_factory_);
   } else {
     Message** result = MutableField<Message*>(message, field);
     if (*result == NULL) {
       const Message* default_message = DefaultRaw<const Message*>(field);
       *result = default_message->New();
-      (*result)->CopyFrom(*default_message);
     }
     return *result;
   }
@@ -667,7 +699,10 @@ Message* GeneratedMessageReflection::AddMessage(
   USAGE_CHECK_ALL(AddMessage, REPEATED, MESSAGE);
 
   if (field->is_extension()) {
-    return MutableExtensionSet(message)->AddMessage(field->number());
+    return MutableExtensionSet(message)->AddMessage(field->number(),
+                                                    field->type(),
+                                                    field->message_type(),
+                                                    message_factory_);
   } else {
     return AddField<Message>(message, field);
   }

+ 26 - 0
src/google/protobuf/generated_message_reflection.h

@@ -116,6 +116,7 @@ class LIBPROTOBUF_EXPORT GeneratedMessageReflection : public Reflection {
   //   pool:          DescriptorPool to search for extension definitions.  Only
   //                  used by FindKnownExtensionByName() and
   //                  FindKnownExtensionByNumber().
+  //   factory:       MessageFactory to use to construct extension messages.
   //   object_size:   The size of a message object of this type, as measured
   //                  by sizeof().
   GeneratedMessageReflection(const Descriptor* descriptor,
@@ -125,6 +126,7 @@ class LIBPROTOBUF_EXPORT GeneratedMessageReflection : public Reflection {
                              int unknown_fields_offset,
                              int extensions_offset,
                              const DescriptorPool* pool,
+                             MessageFactory* factory,
                              int object_size);
   ~GeneratedMessageReflection();
 
@@ -274,6 +276,7 @@ class LIBPROTOBUF_EXPORT GeneratedMessageReflection : public Reflection {
   int object_size_;
 
   const DescriptorPool* descriptor_pool_;
+  MessageFactory* message_factory_;
 
   template <typename Type>
   inline const Type& GetRaw(const Message& message,
@@ -383,8 +386,30 @@ inline To dynamic_cast_if_available(From from) {
 // Compute the space used by a string, not including sizeof(string) itself.
 // This is slightly complicated because small strings store their data within
 // the string object but large strings do not.
-int StringSpaceUsedExcludingSelf(const string& str);
+LIBPROTOBUF_EXPORT int StringSpaceUsedExcludingSelf(const string& str);
 
+// Helper for EnumType_Parse functions: try to parse the string 'name' as an
+// enum name of the given type, returning true and filling in value on success,
+// or returning false and leaving value unchanged on failure.
+bool ParseNamedEnum(const EnumDescriptor* descriptor,
+                    const string& name,
+                    int* value);
+
+template<typename EnumType>
+bool ParseNamedEnum(const EnumDescriptor* descriptor,
+                    const string& name,
+                    EnumType* value) {
+  int tmp;  // The non-template overload reports the number as a plain int.
+  if (!ParseNamedEnum(descriptor, name, &tmp)) return false;
+  *value = static_cast<EnumType>(tmp);  // Only written on success.
+  return true;
+}
+
+// Returns the name of the enum value with the given number, or an empty
+// string if the lookup fails.  Deliberately not inlined, so that callers can
+// get a value's name without including descriptor.h.
+const string& NameOfEnum(const EnumDescriptor* descriptor, int value);
+
 
 }  // namespace internal
 }  // namespace protobuf

+ 155 - 103
src/google/protobuf/io/coded_stream.cc

@@ -71,18 +71,13 @@ CodedInputStream::CodedInputStream(ZeroCopyInputStream* input)
     buffer_size_(0),
     total_bytes_read_(0),
     overflow_bytes_(0),
-
     last_tag_(0),
     legitimate_message_end_(false),
-
     aliasing_enabled_(false),
-
     current_limit_(INT_MAX),
     buffer_size_after_limit_(0),
-
     total_bytes_limit_(kDefaultTotalBytesLimit),
     total_bytes_warning_threshold_(kDefaultTotalBytesWarningThreshold),
-
     recursion_depth_(0),
     recursion_limit_(kDefaultRecursionLimit) {
 }
@@ -514,7 +509,14 @@ CodedOutputStream::CodedOutputStream(ZeroCopyOutputStream* output)
   : output_(output),
     buffer_(NULL),
     buffer_size_(0),
-    total_bytes_(0) {
+    total_bytes_(0),
+    had_error_(false) {
+  // Eagerly Refresh() so buffer space is immediately available.
+  Refresh();
+  // The Refresh() may have failed. If the client doesn't write any data,
+  // though, don't consider this an error. If the client does write data, then
+  // another Refresh() will be attempted and it will set the error once again.
+  had_error_ = false;
 }
 
 CodedOutputStream::~CodedOutputStream() {
@@ -543,21 +545,26 @@ bool CodedOutputStream::GetDirectBufferPointer(void** data, int* size) {
   return true;
 }
 
-bool CodedOutputStream::WriteRaw(const void* data, int size) {
+void CodedOutputStream::WriteRaw(const void* data, int size) {
   while (buffer_size_ < size) {
     memcpy(buffer_, data, buffer_size_);
     size -= buffer_size_;
     data = reinterpret_cast<const uint8*>(data) + buffer_size_;
-    if (!Refresh()) return false;
+    if (!Refresh()) return;
   }
 
   memcpy(buffer_, data, size);
   Advance(size);
-  return true;
+}
+
+uint8* CodedOutputStream::WriteRawToArray(
+    const void* data, int size, uint8* target) {
+  memcpy(target, data, size);  // Unchecked copy: target must hold size bytes.
+  return target + size;  // Points just past the last byte written.
+}
 
 
-bool CodedOutputStream::WriteLittleEndian32(uint32 value) {
+void CodedOutputStream::WriteLittleEndian32(uint32 value) {
   uint8 bytes[sizeof(value)];
 
   bool use_fast = buffer_size_ >= sizeof(value);
@@ -570,13 +577,21 @@ bool CodedOutputStream::WriteLittleEndian32(uint32 value) {
 
   if (use_fast) {
     Advance(sizeof(value));
-    return true;
   } else {
-    return WriteRaw(bytes, sizeof(value));
+    WriteRaw(bytes, sizeof(value));
   }
 }
 
-bool CodedOutputStream::WriteLittleEndian64(uint64 value) {
+uint8* CodedOutputStream::WriteLittleEndian32ToArray(
+    uint32 value, uint8* target) {
+  target[0] = static_cast<uint8>(value      );  // Lowest-order byte first.
+  target[1] = static_cast<uint8>(value >>  8);
+  target[2] = static_cast<uint8>(value >> 16);
+  target[3] = static_cast<uint8>(value >> 24);  // Highest-order byte last.
+  return target + sizeof(value);  // Points just past the bytes written.
+}
+
+void CodedOutputStream::WriteLittleEndian64(uint64 value) {
   uint8 bytes[sizeof(value)];
 
   uint32 part0 = static_cast<uint32>(value);
@@ -596,46 +611,66 @@ bool CodedOutputStream::WriteLittleEndian64(uint64 value) {
 
   if (use_fast) {
     Advance(sizeof(value));
-    return true;
   } else {
-    return WriteRaw(bytes, sizeof(value));
+    WriteRaw(bytes, sizeof(value));
   }
 }
 
-bool CodedOutputStream::WriteVarint32Fallback(uint32 value) {
-  if (buffer_size_ >= kMaxVarint32Bytes) {
-    // Fast path:  We have enough bytes left in the buffer to guarantee that
-    // this write won't cross the end, so we can skip the checks.
-    uint8* target = buffer_;
+uint8* CodedOutputStream::WriteLittleEndian64ToArray(
+    uint64 value, uint8* target) {
+  uint32 part0 = static_cast<uint32>(value);  // Low 32 bits.
+  uint32 part1 = static_cast<uint32>(value >> 32);  // High 32 bits.
+
+  target[0] = static_cast<uint8>(part0      );  // Lowest-order byte first.
+  target[1] = static_cast<uint8>(part0 >>  8);
+  target[2] = static_cast<uint8>(part0 >> 16);
+  target[3] = static_cast<uint8>(part0 >> 24);
+  target[4] = static_cast<uint8>(part1      );
+  target[5] = static_cast<uint8>(part1 >>  8);
+  target[6] = static_cast<uint8>(part1 >> 16);
+  target[7] = static_cast<uint8>(part1 >> 24);  // Highest-order byte last.
+
+  return target + sizeof(value);  // Points just past the bytes written.
+}
 
-    target[0] = static_cast<uint8>(value | 0x80);
-    if (value >= (1 << 7)) {
-      target[1] = static_cast<uint8>((value >>  7) | 0x80);
-      if (value >= (1 << 14)) {
-        target[2] = static_cast<uint8>((value >> 14) | 0x80);
-        if (value >= (1 << 21)) {
-          target[3] = static_cast<uint8>((value >> 21) | 0x80);
-          if (value >= (1 << 28)) {
-            target[4] = static_cast<uint8>(value >> 28);
-            Advance(5);
-          } else {
-            target[3] &= 0x7F;
-            Advance(4);
-          }
+inline uint8* CodedOutputStream::WriteVarint32FallbackToArrayInline(
+    uint32 value, uint8* target) {
+  target[0] = static_cast<uint8>(value | 0x80);
+  if (value >= (1 << 7)) {
+    target[1] = static_cast<uint8>((value >>  7) | 0x80);
+    if (value >= (1 << 14)) {
+      target[2] = static_cast<uint8>((value >> 14) | 0x80);
+      if (value >= (1 << 21)) {
+        target[3] = static_cast<uint8>((value >> 21) | 0x80);
+        if (value >= (1 << 28)) {
+          target[4] = static_cast<uint8>(value >> 28);
+          return target + 5;
         } else {
-          target[2] &= 0x7F;
-          Advance(3);
+          target[3] &= 0x7F;
+          return target + 4;
         }
       } else {
-        target[1] &= 0x7F;
-        Advance(2);
+        target[2] &= 0x7F;
+        return target + 3;
       }
     } else {
-      target[0] &= 0x7F;
-      Advance(1);
+      target[1] &= 0x7F;
+      return target + 2;
     }
+  } else {
+    target[0] &= 0x7F;
+    return target + 1;
+  }
+}
 
-    return true;
+void CodedOutputStream::WriteVarint32(uint32 value) {
+  if (buffer_size_ >= kMaxVarint32Bytes) {
+    // Fast path:  We have enough bytes left in the buffer to guarantee that
+    // this write won't cross the end, so we can skip the checks.
+    uint8* target = buffer_;
+    uint8* end = WriteVarint32FallbackToArrayInline(value, target);
+    int size = end - target;
+    Advance(size);
   } else {
     // Slow path:  This write might cross the end of the buffer, so we
     // compose the bytes first then use WriteRaw().
@@ -646,85 +681,96 @@ bool CodedOutputStream::WriteVarint32Fallback(uint32 value) {
       value >>= 7;
     }
     bytes[size++] = static_cast<uint8>(value) & 0x7F;
-    return WriteRaw(bytes, size);
+    WriteRaw(bytes, size);
   }
 }
 
-bool CodedOutputStream::WriteVarint64(uint64 value) {
-  if (buffer_size_ >= kMaxVarintBytes) {
-    // Fast path:  We have enough bytes left in the buffer to guarantee that
-    // this write won't cross the end, so we can skip the checks.
-    uint8* target = buffer_;
+uint8* CodedOutputStream::WriteVarint32FallbackToArray(
+    uint32 value, uint8* target) {  // Thin wrapper over the inline helper.
+  return WriteVarint32FallbackToArrayInline(value, target);
+}
 
-    // Splitting into 32-bit pieces gives better performance on 32-bit
-    // processors.
-    uint32 part0 = static_cast<uint32>(value      );
-    uint32 part1 = static_cast<uint32>(value >> 28);
-    uint32 part2 = static_cast<uint32>(value >> 56);
-
-    int size;
-
-    // Here we can't really optimize for small numbers, since the value is
-    // split into three parts.  Cheking for numbers < 128, for instance,
-    // would require three comparisons, since you'd have to make sure part1
-    // and part2 are zero.  However, if the caller is using 64-bit integers,
-    // it is likely that they expect the numbers to often be very large, so
-    // we probably don't want to optimize for small numbers anyway.  Thus,
-    // we end up with a hardcoded binary search tree...
-    if (part2 == 0) {
-      if (part1 == 0) {
-        if (part0 < (1 << 14)) {
-          if (part0 < (1 << 7)) {
-            size = 1; goto size1;
-          } else {
-            size = 2; goto size2;
-          }
+inline uint8* CodedOutputStream::WriteVarint64ToArrayInline(
+    uint64 value, uint8* target) {
+  // Splitting into 32-bit pieces gives better performance on 32-bit
+  // processors.
+  uint32 part0 = static_cast<uint32>(value      );
+  uint32 part1 = static_cast<uint32>(value >> 28);
+  uint32 part2 = static_cast<uint32>(value >> 56);
+
+  int size;
+
+  // Here we can't really optimize for small numbers, since the value is
+  // split into three parts.  Checking for numbers < 128, for instance,
+  // would require three comparisons, since you'd have to make sure part1
+  // and part2 are zero.  However, if the caller is using 64-bit integers,
+  // it is likely that they expect the numbers to often be very large, so
+  // we probably don't want to optimize for small numbers anyway.  Thus,
+  // we end up with a hardcoded binary search tree...
+  if (part2 == 0) {
+    if (part1 == 0) {
+      if (part0 < (1 << 14)) {
+        if (part0 < (1 << 7)) {
+          size = 1; goto size1;
         } else {
-          if (part0 < (1 << 21)) {
-            size = 3; goto size3;
-          } else {
-            size = 4; goto size4;
-          }
+          size = 2; goto size2;
         }
       } else {
-        if (part1 < (1 << 14)) {
-          if (part1 < (1 << 7)) {
-            size = 5; goto size5;
-          } else {
-            size = 6; goto size6;
-          }
+        if (part0 < (1 << 21)) {
+          size = 3; goto size3;
         } else {
-          if (part1 < (1 << 21)) {
-            size = 7; goto size7;
-          } else {
-            size = 8; goto size8;
-          }
+          size = 4; goto size4;
         }
       }
     } else {
-      if (part2 < (1 << 7)) {
-        size = 9; goto size9;
+      if (part1 < (1 << 14)) {
+        if (part1 < (1 << 7)) {
+          size = 5; goto size5;
+        } else {
+          size = 6; goto size6;
+        }
       } else {
-        size = 10; goto size10;
+        if (part1 < (1 << 21)) {
+          size = 7; goto size7;
+        } else {
+          size = 8; goto size8;
+        }
       }
     }
+  } else {
+    if (part2 < (1 << 7)) {
+      size = 9; goto size9;
+    } else {
+      size = 10; goto size10;
+    }
+  }
 
-    GOOGLE_LOG(FATAL) << "Can't get here.";
+  GOOGLE_LOG(FATAL) << "Can't get here.";
+
+  size10: target[9] = static_cast<uint8>((part2 >>  7) | 0x80);
+  size9 : target[8] = static_cast<uint8>((part2      ) | 0x80);
+  size8 : target[7] = static_cast<uint8>((part1 >> 21) | 0x80);
+  size7 : target[6] = static_cast<uint8>((part1 >> 14) | 0x80);
+  size6 : target[5] = static_cast<uint8>((part1 >>  7) | 0x80);
+  size5 : target[4] = static_cast<uint8>((part1      ) | 0x80);
+  size4 : target[3] = static_cast<uint8>((part0 >> 21) | 0x80);
+  size3 : target[2] = static_cast<uint8>((part0 >> 14) | 0x80);
+  size2 : target[1] = static_cast<uint8>((part0 >>  7) | 0x80);
+  size1 : target[0] = static_cast<uint8>((part0      ) | 0x80);
+
+  target[size-1] &= 0x7F;
+  return target + size;
+}
 
-    size10: target[9] = static_cast<uint8>((part2 >>  7) | 0x80);
-    size9 : target[8] = static_cast<uint8>((part2      ) | 0x80);
-    size8 : target[7] = static_cast<uint8>((part1 >> 21) | 0x80);
-    size7 : target[6] = static_cast<uint8>((part1 >> 14) | 0x80);
-    size6 : target[5] = static_cast<uint8>((part1 >>  7) | 0x80);
-    size5 : target[4] = static_cast<uint8>((part1      ) | 0x80);
-    size4 : target[3] = static_cast<uint8>((part0 >> 21) | 0x80);
-    size3 : target[2] = static_cast<uint8>((part0 >> 14) | 0x80);
-    size2 : target[1] = static_cast<uint8>((part0 >>  7) | 0x80);
-    size1 : target[0] = static_cast<uint8>((part0      ) | 0x80);
+void CodedOutputStream::WriteVarint64(uint64 value) {
+  if (buffer_size_ >= kMaxVarintBytes) {
+    // Fast path:  We have enough bytes left in the buffer to guarantee that
+    // this write won't cross the end, so we can skip the checks.
+    uint8* target = buffer_;
 
-    target[size-1] &= 0x7F;
+    uint8* end = WriteVarint64ToArrayInline(value, target);
+    int size = end - target;
     Advance(size);
-    return true;
   } else {
     // Slow path:  This write might cross the end of the buffer, so we
     // compose the bytes first then use WriteRaw().
@@ -735,10 +781,15 @@ bool CodedOutputStream::WriteVarint64(uint64 value) {
       value >>= 7;
     }
     bytes[size++] = static_cast<uint8>(value) & 0x7F;
-    return WriteRaw(bytes, size);
+    WriteRaw(bytes, size);
   }
 }
 
+uint8* CodedOutputStream::WriteVarint64ToArray(
+    uint64 value, uint8* target) {
+  return WriteVarint64ToArrayInline(value, target);
+}
+
 bool CodedOutputStream::Refresh() {
   void* void_buffer;
   if (output_->Next(&void_buffer, &buffer_size_)) {
@@ -748,6 +799,7 @@ bool CodedOutputStream::Refresh() {
   } else {
     buffer_ = NULL;
     buffer_size_ = 0;
+    had_error_ = true;
     return false;
   }
 }

+ 145 - 34
src/google/protobuf/io/coded_stream.h

@@ -380,7 +380,46 @@ class LIBPROTOBUF_EXPORT CodedInputStream {
 //
 // Most methods of CodedOutputStream which return a bool return false if an
 // underlying I/O error occurs.  Once such a failure occurs, the
-// CodedOutputStream is broken and is no longer useful.
+// CodedOutputStream is broken and is no longer useful. The Write* methods do
+// not return the stream status, but will invalidate the stream if an error
+// occurs. The client can probe HadError() to determine the status.
+//
+// Note that every method of CodedOutputStream which writes some data has
+// a corresponding static "ToArray" version. These versions write directly
+// to the provided buffer, returning a pointer past the last written byte.
+// They require that the buffer has sufficient capacity for the encoded data.
+// This allows an optimization where we check if an output stream has enough
+// space for an entire message before we start writing and, if there is, we
+// call only the ToArray methods to avoid doing bound checks for each
+// individual value.
+// e.g., in the example above:
+//
+//   CodedOutputStream* coded_output = new CodedOutputStream(raw_output);
+//   int magic_number = 1234;
+//   char text[] = "Hello world!";
+//
+//   int coded_size = sizeof(magic_number) +
+//                    CodedOutputStream::VarintSize32(strlen(text)) +
+//                    strlen(text);
+//
+//   uint8* buffer =
+//       coded_output->GetDirectBufferForNBytesAndAdvance(coded_size);
+//   if (buffer != NULL) {
+//     // The output stream has enough space in the buffer: write directly to
+//     // the array.
+//     buffer = CodedOutputStream::WriteLittleEndian32ToArray(magic_number,
+//                                                            buffer);
+//     buffer = CodedOutputStream::WriteVarint32ToArray(strlen(text), buffer);
+//     buffer = CodedOutputStream::WriteRawToArray(text, strlen(text), buffer);
+//   } else {
+//     // Make bound-checked writes, which will ask the underlying stream for
+//     // more space as needed.
+//     coded_output->WriteLittleEndian32(magic_number);
+//     coded_output->WriteVarint32(strlen(text));
+//     coded_output->WriteRaw(text, strlen(text));
+//   }
+//
+//   delete coded_output;
 class LIBPROTOBUF_EXPORT CodedOutputStream {
  public:
   // Create a CodedOutputStream that writes to the given ZeroCopyOutputStream.
@@ -405,35 +444,65 @@ class LIBPROTOBUF_EXPORT CodedOutputStream {
   // CodedOutputStream interface.
   bool GetDirectBufferPointer(void** data, int* size);
 
+  // If there are at least "size" bytes available in the current buffer,
+  // returns a pointer directly into the buffer and advances over these bytes.
+  // The caller may then write directly into this buffer (e.g. using the
+  // *ToArray static methods) rather than go through CodedOutputStream.  If
+  // there are not enough bytes available, returns NULL.  The return pointer is
+  // invalidated as soon as any other non-const method of CodedOutputStream
+  // is called.
+  inline uint8* GetDirectBufferForNBytesAndAdvance(int size);
+
   // Write raw bytes, copying them from the given buffer.
-  bool WriteRaw(const void* buffer, int size);
+  void WriteRaw(const void* buffer, int size);
+  // Like WriteRaw()  but writing directly to the target array.
+  // This is _not_ inlined, as the compiler often optimizes memcpy into inline
+  // copy loops. Since this gets called by every field with string or bytes
+  // type, inlining may lead to a significant amount of code bloat, with only a
+  // minor performance gain.
+  static uint8* WriteRawToArray(const void* buffer, int size, uint8* target);
 
   // Equivalent to WriteRaw(str.data(), str.size()).
-  bool WriteString(const string& str);
+  void WriteString(const string& str);
+  // Like WriteString()  but writing directly to the target array.
+  static uint8* WriteStringToArray(const string& str, uint8* target);
 
 
   // Write a 32-bit little-endian integer.
-  bool WriteLittleEndian32(uint32 value);
+  void WriteLittleEndian32(uint32 value);
+  // Like WriteLittleEndian32()  but writing directly to the target array.
+  static uint8* WriteLittleEndian32ToArray(uint32 value, uint8* target);
   // Write a 64-bit little-endian integer.
-  bool WriteLittleEndian64(uint64 value);
+  void WriteLittleEndian64(uint64 value);
+  // Like WriteLittleEndian64()  but writing directly to the target array.
+  static uint8* WriteLittleEndian64ToArray(uint64 value, uint8* target);
 
   // Write an unsigned integer with Varint encoding.  Writing a 32-bit value
   // is equivalent to casting it to uint64 and writing it as a 64-bit value,
   // but may be more efficient.
-  bool WriteVarint32(uint32 value);
+  void WriteVarint32(uint32 value);
+  // Like WriteVarint32()  but writing directly to the target array.
+  static uint8* WriteVarint32ToArray(uint32 value, uint8* target);
   // Write an unsigned integer with Varint encoding.
-  bool WriteVarint64(uint64 value);
+  void WriteVarint64(uint64 value);
+  // Like WriteVarint64()  but writing directly to the target array.
+  static uint8* WriteVarint64ToArray(uint64 value, uint8* target);
 
   // Equivalent to WriteVarint32() except when the value is negative,
   // in which case it must be sign-extended to a full 10 bytes.
-  bool WriteVarint32SignExtended(int32 value);
+  void WriteVarint32SignExtended(int32 value);
+  // Like WriteVarint32SignExtended()  but writing directly to the target array.
+  static uint8* WriteVarint32SignExtendedToArray(int32 value, uint8* target);
 
   // This is identical to WriteVarint32(), but optimized for writing tags.
   // In particular, if the input is a compile-time constant, this method
   // compiles down to a couple instructions.
   // Always inline because otherwise the aforementioned optimization can't work,
   // but GCC by default doesn't want to inline this.
-  bool WriteTag(uint32 value) GOOGLE_ATTRIBUTE_ALWAYS_INLINE;
+  void WriteTag(uint32 value);
+  // Like WriteTag()  but writing directly to the target array.
+  static uint8* WriteTagToArray(
+      uint32 value, uint8* target) GOOGLE_ATTRIBUTE_ALWAYS_INLINE;
 
   // Returns the number of bytes needed to encode the given value as a varint.
   static int VarintSize32(uint32 value);
@@ -446,6 +515,10 @@ class LIBPROTOBUF_EXPORT CodedOutputStream {
   // Returns the total number of bytes written since this object was created.
   inline int ByteCount() const;
 
+  // Returns true if there was an underlying I/O error since this object was
+  // created.
+  bool HadError() const { return had_error_; }
+
  private:
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodedOutputStream);
 
@@ -453,6 +526,7 @@ class LIBPROTOBUF_EXPORT CodedOutputStream {
   uint8* buffer_;
   int buffer_size_;
   int total_bytes_;  // Sum of sizes of all buffers seen so far.
+  bool had_error_;   // Whether an error occurred during output.
 
   // Advance the buffer by a given number of bytes.
   void Advance(int amount);
@@ -461,7 +535,20 @@ class LIBPROTOBUF_EXPORT CodedOutputStream {
   // Advance(buffer_size_).
   bool Refresh();
 
-  bool WriteVarint32Fallback(uint32 value);
+  static uint8* WriteVarint32FallbackToArray(uint32 value, uint8* target);
+
+  // Always-inlined versions of WriteVarint* functions so that code can be
+  // reused, while still controlling size. For instance, WriteVarint32ToArray()
+  // should not directly call this: since it is inlined itself, doing so
+  // would greatly increase the size of generated code. Instead, it should call
+  // WriteVarint32FallbackToArray.  Meanwhile, WriteVarint32() is already
+  // out-of-line, so it should just invoke this directly to avoid any extra
+  // function call overhead.
+  static uint8* WriteVarint32FallbackToArrayInline(
+      uint32 value, uint8* target) GOOGLE_ATTRIBUTE_ALWAYS_INLINE;
+  static uint8* WriteVarint64ToArrayInline(
+      uint64 value, uint8* target) GOOGLE_ATTRIBUTE_ALWAYS_INLINE;
+
   static int VarintSize32Fallback(uint32 value);
 };
 
@@ -540,40 +627,59 @@ inline bool CodedInputStream::ExpectAtEnd() {
   }
 }
 
-inline bool CodedOutputStream::WriteVarint32(uint32 value) {
-  if (value < 0x80 && buffer_size_ > 0) {
-    *buffer_ = static_cast<uint8>(value);
-    Advance(1);
-    return true;
+inline uint8* CodedOutputStream::GetDirectBufferForNBytesAndAdvance(int size) {
+  if (buffer_size_ < size) {
+    return NULL;
+  } else {
+    uint8* result = buffer_;
+    Advance(size);
+    return result;
+  }
+}
+
+inline uint8* CodedOutputStream::WriteVarint32ToArray(uint32 value,
+                                                        uint8* target) {
+  if (value < 0x80) {
+    *target = value;
+    return target + 1;
   } else {
-    return WriteVarint32Fallback(value);
+    return WriteVarint32FallbackToArray(value, target);
   }
 }
 
-inline bool CodedOutputStream::WriteVarint32SignExtended(int32 value) {
+inline void CodedOutputStream::WriteVarint32SignExtended(int32 value) {
   if (value < 0) {
-    return WriteVarint64(static_cast<uint64>(value));
+    WriteVarint64(static_cast<uint64>(value));
   } else {
-    return WriteVarint32(static_cast<uint32>(value));
+    WriteVarint32(static_cast<uint32>(value));
   }
 }
 
-inline bool CodedOutputStream::WriteTag(uint32 value) {
+inline uint8* CodedOutputStream::WriteVarint32SignExtendedToArray(
+    int32 value, uint8* target) {
+  if (value < 0) {
+    return WriteVarint64ToArray(static_cast<uint64>(value), target);
+  } else {
+    return WriteVarint32ToArray(static_cast<uint32>(value), target);
+  }
+}
+
+inline void CodedOutputStream::WriteTag(uint32 value) {
+  WriteVarint32(value);
+}
+
+inline uint8* CodedOutputStream::WriteTagToArray(
+    uint32 value, uint8* target) {
   if (value < (1 << 7)) {
-    if (buffer_size_ != 0) {
-      buffer_[0] = static_cast<uint8>(value);
-      Advance(1);
-      return true;
-    }
+    target[0] = value;
+    return target + 1;
   } else if (value < (1 << 14)) {
-    if (buffer_size_ >= 2) {
-      buffer_[0] = static_cast<uint8>(value | 0x80);
-      buffer_[1] = static_cast<uint8>(value >> 7);
-      Advance(2);
-      return true;
-    }
+    target[0] = static_cast<uint8>(value | 0x80);
+    target[1] = static_cast<uint8>(value >> 7);
+    return target + 2;
+  } else {
+    return WriteVarint32FallbackToArray(value, target);
   }
-  return WriteVarint32Fallback(value);
 }
 
 inline int CodedOutputStream::VarintSize32(uint32 value) {
@@ -592,8 +698,13 @@ inline int CodedOutputStream::VarintSize32SignExtended(int32 value) {
   }
 }
 
-inline bool CodedOutputStream::WriteString(const string& str) {
-  return WriteRaw(str.data(), str.size());
+inline void CodedOutputStream::WriteString(const string& str) {
+  WriteRaw(str.data(), str.size());
+}
+
+inline uint8* CodedOutputStream::WriteStringToArray(
+    const string& str, uint8* target) {
+  return WriteRawToArray(str.data(), str.size(), target);
 }
 
 inline int CodedOutputStream::ByteCount() const {

+ 12 - 8
src/google/protobuf/io/coded_stream_unittest.cc

@@ -268,8 +268,8 @@ TEST_2D(CodedStreamTest, WriteVarint32, kVarintCases, kBlockSizes) {
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteVarint32(
-      static_cast<uint32>(kVarintCases_case.value)));
+    coded_output.WriteVarint32(static_cast<uint32>(kVarintCases_case.value));
+    EXPECT_FALSE(coded_output.HadError());
 
     EXPECT_EQ(kVarintCases_case.size, coded_output.ByteCount());
   }
@@ -285,7 +285,8 @@ TEST_2D(CodedStreamTest, WriteVarint64, kVarintCases, kBlockSizes) {
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteVarint64(kVarintCases_case.value));
+    coded_output.WriteVarint64(kVarintCases_case.value);
+    EXPECT_FALSE(coded_output.HadError());
 
     EXPECT_EQ(kVarintCases_case.size, coded_output.ByteCount());
   }
@@ -310,8 +311,8 @@ TEST_2D(CodedStreamTest, WriteVarint32SignExtended,
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteVarint32SignExtended(
-      kSignExtendedVarintCases_case));
+    coded_output.WriteVarint32SignExtended(kSignExtendedVarintCases_case);
+    EXPECT_FALSE(coded_output.HadError());
 
     if (kSignExtendedVarintCases_case < 0) {
       EXPECT_EQ(10, coded_output.ByteCount());
@@ -502,7 +503,8 @@ TEST_2D(CodedStreamTest, WriteLittleEndian32, kFixed32Cases, kBlockSizes) {
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteLittleEndian32(kFixed32Cases_case.value));
+    coded_output.WriteLittleEndian32(kFixed32Cases_case.value);
+    EXPECT_FALSE(coded_output.HadError());
 
     EXPECT_EQ(sizeof(uint32), coded_output.ByteCount());
   }
@@ -517,7 +519,8 @@ TEST_2D(CodedStreamTest, WriteLittleEndian64, kFixed64Cases, kBlockSizes) {
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteLittleEndian64(kFixed64Cases_case.value));
+    coded_output.WriteLittleEndian64(kFixed64Cases_case.value);
+    EXPECT_FALSE(coded_output.HadError());
 
     EXPECT_EQ(sizeof(uint64), coded_output.ByteCount());
   }
@@ -552,7 +555,8 @@ TEST_1D(CodedStreamTest, WriteRaw, kBlockSizes) {
   {
     CodedOutputStream coded_output(&output);
 
-    EXPECT_TRUE(coded_output.WriteRaw(kRawBytes, sizeof(kRawBytes)));
+    coded_output.WriteRaw(kRawBytes, sizeof(kRawBytes));
+    EXPECT_FALSE(coded_output.HadError());
 
     EXPECT_EQ(sizeof(kRawBytes), coded_output.ByteCount());
   }

+ 3 - 2
src/google/protobuf/io/zero_copy_stream_impl.cc

@@ -185,11 +185,12 @@ bool StringOutputStream::Next(void** data, int* size) {
   if (old_size < target_->capacity()) {
     // Resize the string to match its capacity, since we can get away
     // without a memory allocation this way.
-    target_->resize(target_->capacity());
+    STLStringResizeUninitialized(target_, target_->capacity());
   } else {
     // Size has reached capacity, so double the size.  Also make sure
     // that the new size is at least kMinimumSize.
-    target_->resize(
+    STLStringResizeUninitialized(
+      target_,
       max(old_size * 2,
           kMinimumSize + 0));  // "+ 0" works around GCC4 weirdness.
   }

+ 90 - 22
src/google/protobuf/message.cc

@@ -47,6 +47,7 @@
 #include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/stubs/substitute.h>
 #include <google/protobuf/stubs/map-util.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
 namespace protobuf {
@@ -205,9 +206,19 @@ bool Message::ParsePartialFromIstream(istream* input) {
 
 
 
-bool Message::SerializeWithCachedSizes(
+void Message::SerializeWithCachedSizes(
     io::CodedOutputStream* output) const {
-  return WireFormat::SerializeWithCachedSizes(*this, GetCachedSize(), output);
+  WireFormat::SerializeWithCachedSizes(*this, GetCachedSize(), output);
+}
+
+uint8* Message::SerializeWithCachedSizesToArray(uint8* target) const {
+  // We only optimize this when using optimize_for = SPEED.
+  int size = GetCachedSize();
+  io::ArrayOutputStream out(target, size);
+  io::CodedOutputStream coded_out(&out);
+  SerializeWithCachedSizes(&coded_out);
+  GOOGLE_CHECK(!coded_out.HadError());
+  return target + size;
 }
 
 int Message::ByteSize() const {
@@ -234,8 +245,8 @@ bool Message::SerializeToCodedStream(io::CodedOutputStream* output) const {
 bool Message::SerializePartialToCodedStream(
     io::CodedOutputStream* output) const {
   ByteSize();  // Force size to be cached.
-  if (!SerializeWithCachedSizes(output)) return false;
-  return true;
+  SerializeWithCachedSizes(output);
+  return !output->HadError();
 }
 
 bool Message::SerializeToZeroCopyStream(
@@ -256,19 +267,12 @@ bool Message::AppendToString(string* output) const {
 }
 
 bool Message::AppendPartialToString(string* output) const {
-  // For efficiency, we'd like to reserve the exact amount of space we need
-  // in the string.
-  int total_size = output->size() + ByteSize();
-  output->reserve(total_size);
-
-  io::StringOutputStream output_stream(output);
-
-  {
-    io::CodedOutputStream encoder(&output_stream);
-    if (!SerializeWithCachedSizes(&encoder)) return false;
-  }
-
-  GOOGLE_CHECK_EQ(output_stream.ByteCount(), total_size);
+  int old_size = output->size();
+  int byte_size = ByteSize();
+  STLStringResizeUninitialized(output, old_size + byte_size);
+  uint8* start = reinterpret_cast<uint8*>(string_as_array(output) + old_size);
+  uint8* end = SerializeWithCachedSizesToArray(start);
+  GOOGLE_CHECK_EQ(end, start + byte_size);
   return true;
 }
 
@@ -283,13 +287,17 @@ bool Message::SerializePartialToString(string* output) const {
 }
 
 bool Message::SerializeToArray(void* data, int size) const {
-  io::ArrayOutputStream output_stream(data, size);
-  return SerializeToZeroCopyStream(&output_stream);
+  GOOGLE_DCHECK(IsInitialized()) << InitializationErrorMessage("serialize", *this);
+  return SerializePartialToArray(data, size);
 }
 
 bool Message::SerializePartialToArray(void* data, int size) const {
-  io::ArrayOutputStream output_stream(data, size);
-  return SerializePartialToZeroCopyStream(&output_stream);
+  int byte_size = ByteSize();
+  if (size < byte_size) return false;
+  uint8* end =
+    SerializeWithCachedSizesToArray(reinterpret_cast<uint8*>(data));
+  GOOGLE_CHECK_EQ(end, reinterpret_cast<uint8*>(data) + byte_size);
+  return true;
 }
 
 bool Message::SerializeToFileDescriptor(int file_descriptor) const {
@@ -347,12 +355,20 @@ class GeneratedMessageFactory : public MessageFactory {
 
   static GeneratedMessageFactory* singleton();
 
+  typedef void RegistrationFunc();
+  void RegisterFile(const char* file, RegistrationFunc* registration_func);
   void RegisterType(const Descriptor* descriptor, const Message* prototype);
 
   // implements MessageFactory ---------------------------------------
   const Message* GetPrototype(const Descriptor* type);
 
  private:
+  // Only written at static init time, so does not require locking.
+  hash_map<const char*, RegistrationFunc*,
+           hash<const char*>, streq> file_map_;
+
+  // Initialized lazily, so requires locking.
+  Mutex mutex_;
   hash_map<const Descriptor*, const Message*> type_map_;
 };
 
@@ -366,19 +382,65 @@ GeneratedMessageFactory* GeneratedMessageFactory::singleton() {
   return &singleton;
 }
 
+void GeneratedMessageFactory::RegisterFile(
+    const char* file, RegistrationFunc* registration_func) {
+  if (!InsertIfNotPresent(&file_map_, file, registration_func)) {
+    GOOGLE_LOG(FATAL) << "File is already registered: " << file;
+  }
+}
+
 void GeneratedMessageFactory::RegisterType(const Descriptor* descriptor,
                                            const Message* prototype) {
   GOOGLE_DCHECK_EQ(descriptor->file()->pool(), DescriptorPool::generated_pool())
     << "Tried to register a non-generated type with the generated "
        "type registry.";
 
+  // This should only be called as a result of calling a file registration
+  // function during GetPrototype(), in which case we already have locked
+  // the mutex.
+  mutex_.AssertHeld();
   if (!InsertIfNotPresent(&type_map_, descriptor, prototype)) {
     GOOGLE_LOG(DFATAL) << "Type is already registered: " << descriptor->full_name();
   }
 }
 
 const Message* GeneratedMessageFactory::GetPrototype(const Descriptor* type) {
-  return FindPtrOrNull(type_map_, type);
+  {
+    ReaderMutexLock lock(&mutex_);
+    const Message* result = FindPtrOrNull(type_map_, type);
+    if (result != NULL) return result;
+  }
+
+  // If the type is not in the generated pool, then we can't possibly handle
+  // it.
+  if (type->file()->pool() != DescriptorPool::generated_pool()) return NULL;
+
+  // Apparently the file hasn't been registered yet.  Let's do that now.
+  RegistrationFunc* registration_func =
+      FindPtrOrNull(file_map_, type->file()->name().c_str());
+  if (registration_func == NULL) {
+    GOOGLE_LOG(DFATAL) << "File appears to be in generated pool but wasn't "
+                   "registered: " << type->file()->name();
+    return NULL;
+  }
+
+  WriterMutexLock lock(&mutex_);
+
+  // Check if another thread preempted us.
+  const Message* result = FindPtrOrNull(type_map_, type);
+  if (result == NULL) {
+    // Nope.  OK, register everything.
+    registration_func();
+    // Should be here now.
+    result = FindPtrOrNull(type_map_, type);
+  }
+
+  if (result == NULL) {
+    GOOGLE_LOG(DFATAL) << "Type appears to be in generated pool but wasn't "
+                << "registered: " << type->full_name();
+  }
+
+  return result;
 }
 
 }  // namespace
@@ -387,6 +449,12 @@ MessageFactory* MessageFactory::generated_factory() {
   return GeneratedMessageFactory::singleton();
 }
 
+void MessageFactory::InternalRegisterGeneratedFile(
+    const char* filename, void (*register_messages)()) {
+  GeneratedMessageFactory::singleton()->RegisterFile(filename,
+                                                     register_messages);
+}
+
 void MessageFactory::InternalRegisterGeneratedMessage(
     const Descriptor* descriptor, const Message* prototype) {
   GeneratedMessageFactory::singleton()->RegisterType(descriptor, prototype);

+ 37 - 21
src/google/protobuf/message.h

@@ -117,24 +117,24 @@
 #else
 #include <iosfwd>
 #endif
-
-#if defined(_WIN32) && defined(GetMessage)
-// windows.h defines GetMessage() as a macro.  Let's re-define it as an inline
-// function.  This is necessary because Reflection has a method called
-// GetMessage() which we don't want overridden.  The inline function should be
-// equivalent for C++ users.
-inline BOOL GetMessage_Win32(
-    LPMSG lpMsg, HWND hWnd,
-    UINT wMsgFilterMin, UINT wMsgFilterMax) {
-  return GetMessage(lpMsg, hWnd, wMsgFilterMin, wMsgFilterMax);
-}
-#undef GetMessage
-inline BOOL GetMessage(
-    LPMSG lpMsg, HWND hWnd,
-    UINT wMsgFilterMin, UINT wMsgFilterMax) {
-  return GetMessage_Win32(lpMsg, hWnd, wMsgFilterMin, wMsgFilterMax);
-}
-#endif
+
+#if defined(_WIN32) && defined(GetMessage)
+// windows.h defines GetMessage() as a macro.  Let's re-define it as an inline
+// function.  This is necessary because Reflection has a method called
+// GetMessage() which we don't want overridden.  The inline function should be
+// equivalent for C++ users.
+inline BOOL GetMessage_Win32(
+    LPMSG lpMsg, HWND hWnd,
+    UINT wMsgFilterMin, UINT wMsgFilterMax) {
+  return GetMessage(lpMsg, hWnd, wMsgFilterMin, wMsgFilterMax);
+}
+#undef GetMessage
+inline BOOL GetMessage(
+    LPMSG lpMsg, HWND hWnd,
+    UINT wMsgFilterMin, UINT wMsgFilterMax) {
+  return GetMessage_Win32(lpMsg, hWnd, wMsgFilterMin, wMsgFilterMax);
+}
+#endif
 
 #include <google/protobuf/stubs/common.h>
 
@@ -370,7 +370,12 @@ class LIBPROTOBUF_EXPORT Message {
   // Serializes the message without recomputing the size.  The message must
   // not have changed since the last call to ByteSize(); if it has, the results
   // are undefined.
-  virtual bool SerializeWithCachedSizes(io::CodedOutputStream* output) const;
+  virtual void SerializeWithCachedSizes(io::CodedOutputStream* output) const;
+
+  // Like SerializeWithCachedSizes, but writes directly to *target, returning
+  // a pointer to the byte immediately after the last byte written.  "target"
+  // must point at a byte array of at least ByteSize() bytes.
+  virtual uint8* SerializeWithCachedSizesToArray(uint8* target) const;
 
   // Returns the result of the last call to ByteSize().  An embedded message's
   // size is needed both to serialize it (because embedded messages are
@@ -731,8 +736,19 @@ class LIBPROTOBUF_EXPORT MessageFactory {
   // This factory is a singleton.  The caller must not delete the object.
   static MessageFactory* generated_factory();
 
-  // For internal use only:  Registers a message type at static initialization
-  // time, to be placed in generated_factory().
+  // For internal use only:  Registers a .proto file at static initialization
+  // time, to be placed in generated_factory.  The first time GetPrototype()
+  // is called with a descriptor from this file, |register_messages| will be
+  // called.  It must call InternalRegisterGeneratedMessage() (below) to
+  // register each message type in the file.  This strange mechanism is
+  // necessary because descriptors are built lazily, so we can't register
+  // types by their descriptor until we know that the descriptor exists.
+  static void InternalRegisterGeneratedFile(const char* filename,
+                                            void (*register_messages)());
+
+  // For internal use only:  Registers a message type.  Called only by the
+  // functions which are registered with InternalRegisterGeneratedFile(),
+  // above.
   static void InternalRegisterGeneratedMessage(const Descriptor* descriptor,
                                                const Message* prototype);
 

+ 16 - 20
src/google/protobuf/reflection_ops_unittest.cc

@@ -138,16 +138,18 @@ TEST(ReflectionOpsTest, MergeExtensions) {
 TEST(ReflectionOpsTest, MergeUnknown) {
   // Test that the messages' UnknownFieldSets are correctly merged.
   unittest::TestEmptyMessage message1, message2;
-  message1.mutable_unknown_fields()->AddField(1234)->add_varint(1);
-  message2.mutable_unknown_fields()->AddField(1234)->add_varint(2);
+  message1.mutable_unknown_fields()->AddVarint(1234, 1);
+  message2.mutable_unknown_fields()->AddVarint(1234, 2);
 
   ReflectionOps::Merge(message2, &message1);
 
-  ASSERT_EQ(1, message1.unknown_fields().field_count());
-  const UnknownField& field = message1.unknown_fields().field(0);
-  ASSERT_EQ(2, field.varint_size());
-  EXPECT_EQ(1, field.varint(0));
-  EXPECT_EQ(2, field.varint(1));
+  ASSERT_EQ(2, message1.unknown_fields().field_count());
+  ASSERT_EQ(UnknownField::TYPE_VARINT,
+            message1.unknown_fields().field(0).type());
+  EXPECT_EQ(1, message1.unknown_fields().field(0).varint());
+  ASSERT_EQ(UnknownField::TYPE_VARINT,
+            message1.unknown_fields().field(1).type());
+  EXPECT_EQ(2, message1.unknown_fields().field(1).varint());
 }
 
 #ifdef GTEST_HAS_DEATH_TEST
@@ -211,7 +213,7 @@ TEST(ReflectionOpsTest, ClearExtensions) {
 TEST(ReflectionOpsTest, ClearUnknown) {
   // Test that the message's UnknownFieldSet is correctly cleared.
   unittest::TestEmptyMessage message;
-  message.mutable_unknown_fields()->AddField(1234)->add_varint(1);
+  message.mutable_unknown_fields()->AddVarint(1234, 1);
 
   ReflectionOps::Clear(&message);
 
@@ -224,16 +226,13 @@ TEST(ReflectionOpsTest, DiscardUnknownFields) {
 
   // Set some unknown fields in message.
   message.mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
   message.mutable_optional_nested_message()
         ->mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
   message.mutable_repeated_nested_message(0)
         ->mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
 
   EXPECT_EQ(1, message.unknown_fields().field_count());
   EXPECT_EQ(1, message.optional_nested_message()
@@ -258,16 +257,13 @@ TEST(ReflectionOpsTest, DiscardUnknownExtensions) {
 
   // Set some unknown fields.
   message.mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
   message.MutableExtension(unittest::optional_nested_message_extension)
         ->mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
   message.MutableExtension(unittest::repeated_nested_message_extension, 0)
         ->mutable_unknown_fields()
-        ->AddField(123456)
-        ->add_varint(654321);
+        ->AddVarint(123456, 654321);
 
   EXPECT_EQ(1, message.unknown_fields().field_count());
   EXPECT_EQ(1,

+ 124 - 0
src/google/protobuf/repeated_field.h

@@ -93,6 +93,7 @@ class LIBPROTOBUF_EXPORT GenericRepeatedField {
 };
 
 // We need this (from generated_message_reflection.cc).
+LIBPROTOBUF_EXPORT int StringSpaceUsedExcludingSelf(const string& str);
 int StringSpaceUsedExcludingSelf(const string& str);
 
 }  // namespace internal
@@ -865,6 +866,129 @@ RepeatedPtrField<Element>::end() const {
   return iterator(elements_ + current_size_);
 }
 
+// Iterators and helper functions that follow the spirit of the STL
+// std::back_insert_iterator and std::back_inserter but are tailor-made
+// for RepeatedField and RepeatedPtrField. Typical usage would be:
+//
+//   std::copy(some_sequence.begin(), some_sequence.end(),
+//             google::protobuf::RepeatedFieldBackInserter(proto.mutable_sequence()));
+//
+// Ported by johannes from util/gtl/proto-array-iterators-inl.h
+
+namespace internal {
+// A back inserter for RepeatedField objects.
+template<typename T> class RepeatedFieldBackInsertIterator
+    : public std::iterator<std::output_iterator_tag, T> {
+ public:
+  explicit RepeatedFieldBackInsertIterator(
+      RepeatedField<T>* const mutable_field)
+      : field_(mutable_field) {
+  }
+  RepeatedFieldBackInsertIterator<T>& operator=(const T& value) {
+    field_->Add(value);
+    return *this;
+  }
+  RepeatedFieldBackInsertIterator<T>& operator*() {
+    return *this;
+  }
+  RepeatedFieldBackInsertIterator<T>& operator++() {
+    return *this;
+  }
+  RepeatedFieldBackInsertIterator<T>& operator++(int ignores_parameter) {
+    return *this;
+  }
+
+ private:
+  RepeatedField<T>* const field_;
+};
+
+// A back inserter for RepeatedPtrField objects.
+template<typename T> class RepeatedPtrFieldBackInsertIterator
+    : public std::iterator<std::output_iterator_tag, T> {
+ public:
+  RepeatedPtrFieldBackInsertIterator(
+      RepeatedPtrField<T>* const mutable_field)
+      : field_(mutable_field) {
+  }
+  RepeatedPtrFieldBackInsertIterator<T>& operator=(const T& value) {
+    *field_->Add() = value;
+    return *this;
+  }
+  RepeatedPtrFieldBackInsertIterator<T>& operator=(
+      const T* const ptr_to_value) {
+    *field_->Add() = *ptr_to_value;
+    return *this;
+  }
+  RepeatedPtrFieldBackInsertIterator<T>& operator*() {
+    return *this;
+  }
+  RepeatedPtrFieldBackInsertIterator<T>& operator++() {
+    return *this;
+  }
+  RepeatedPtrFieldBackInsertIterator<T>& operator++(int ignores_parameter) {
+    return *this;
+  }
+
+ private:
+  RepeatedPtrField<T>* const field_;
+};
+
+// A back inserter for RepeatedPtrFields that inserts by transferring ownership
+// of a pointer.
+template<typename T> class AllocatedRepeatedPtrFieldBackInsertIterator
+    : public std::iterator<std::output_iterator_tag, T> {
+ public:
+  explicit AllocatedRepeatedPtrFieldBackInsertIterator(
+      RepeatedPtrField<T>* const mutable_field)
+      : field_(mutable_field) {
+  }
+  AllocatedRepeatedPtrFieldBackInsertIterator<T>& operator=(
+      T* const ptr_to_value) {
+    field_->AddAllocated(ptr_to_value);
+    return *this;
+  }
+  AllocatedRepeatedPtrFieldBackInsertIterator<T>& operator*() {
+    return *this;
+  }
+  AllocatedRepeatedPtrFieldBackInsertIterator<T>& operator++() {
+    return *this;
+  }
+  AllocatedRepeatedPtrFieldBackInsertIterator<T>& operator++(
+      int ignores_parameter) {
+    return *this;
+  }
+
+ private:
+  RepeatedPtrField<T>* const field_;
+};
+}  // namespace internal
+
+// Provides a back insert iterator for RepeatedField instances,
+// similar to std::back_inserter(). Note the identically named
+// function for RepeatedPtrField instances.
+template<typename T> internal::RepeatedFieldBackInsertIterator<T>
+RepeatedFieldBackInserter(RepeatedField<T>* const mutable_field) {
+  return internal::RepeatedFieldBackInsertIterator<T>(mutable_field);
+}
+
+// Provides a back insert iterator for RepeatedPtrField instances,
+// similar to std::back_inserter(). Note the identically named
+// function for RepeatedField instances.
+template<typename T> internal::RepeatedPtrFieldBackInsertIterator<T>
+RepeatedFieldBackInserter(RepeatedPtrField<T>* const mutable_field) {
+  return internal::RepeatedPtrFieldBackInsertIterator<T>(mutable_field);
+}
+
+// Provides a back insert iterator for RepeatedPtrField instances
+// similar to std::back_inserter() which transfers the ownership while
+// copying elements.
+template<typename T> internal::AllocatedRepeatedPtrFieldBackInsertIterator<T>
+AllocatedRepeatedPtrFieldBackInserter(
+    RepeatedPtrField<T>* const mutable_field) {
+  return internal::AllocatedRepeatedPtrFieldBackInsertIterator<T>(
+      mutable_field);
+}
+
 }  // namespace protobuf
 
 }  // namespace google

+ 149 - 0
src/google/protobuf/repeated_field_unittest.cc

@@ -36,15 +36,22 @@
 //   other proto2 unittests.
 
 #include <algorithm>
+#include <list>
 
 #include <google/protobuf/repeated_field.h>
 
 #include <google/protobuf/stubs/common.h>
+#include <google/protobuf/unittest.pb.h>
+#include <google/protobuf/stubs/strutil.h>
 #include <google/protobuf/testing/googletest.h>
 #include <gtest/gtest.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
+using protobuf_unittest::TestAllTypes;
+
 namespace protobuf {
+namespace {
 
 // Test operations on a RepeatedField which is small enough that it does
 // not allocate a separate array for storage.
@@ -621,5 +628,147 @@ TEST_F(RepeatedPtrFieldIteratorTest, Mutation) {
   EXPECT_EQ("qux", proto_array_.Get(0));
 }
 
+// -----------------------------------------------------------------------------
+// Unit-tests for the insert iterators
+// google::protobuf::RepeatedFieldBackInserter,
+// google::protobuf::AllocatedRepeatedPtrFieldBackInserter
+// Ported from util/gtl/proto-array-iterators_unittest.
+
+class RepeatedFieldInsertionIteratorsTest : public testing::Test {
+ protected:
+  std::list<double> halves;
+  std::list<int> fibonacci;
+  std::vector<string> words;
+  typedef TestAllTypes::NestedMessage Nested;
+  Nested nesteds[2];
+  std::vector<Nested*> nested_ptrs;
+  TestAllTypes protobuffer;
+
+  virtual void SetUp() {
+    fibonacci.push_back(1);
+    fibonacci.push_back(1);
+    fibonacci.push_back(2);
+    fibonacci.push_back(3);
+    fibonacci.push_back(5);
+    fibonacci.push_back(8);
+    std::copy(fibonacci.begin(), fibonacci.end(),
+              RepeatedFieldBackInserter(protobuffer.mutable_repeated_int32()));
+
+    halves.push_back(1.0);
+    halves.push_back(0.5);
+    halves.push_back(0.25);
+    halves.push_back(0.125);
+    halves.push_back(0.0625);
+    std::copy(halves.begin(), halves.end(),
+              RepeatedFieldBackInserter(protobuffer.mutable_repeated_double()));
+
+    words.push_back("Able");
+    words.push_back("was");
+    words.push_back("I");
+    words.push_back("ere");
+    words.push_back("I");
+    words.push_back("saw");
+    words.push_back("Elba");
+    std::copy(words.begin(), words.end(),
+              RepeatedFieldBackInserter(protobuffer.mutable_repeated_string()));
+
+    nesteds[0].set_bb(17);
+    nesteds[1].set_bb(4711);
+    std::copy(&nesteds[0], &nesteds[2],
+              RepeatedFieldBackInserter(
+                  protobuffer.mutable_repeated_nested_message()));
+
+    nested_ptrs.push_back(new Nested);
+    nested_ptrs.back()->set_bb(170);
+    nested_ptrs.push_back(new Nested);
+    nested_ptrs.back()->set_bb(47110);
+    std::copy(nested_ptrs.begin(), nested_ptrs.end(),
+              RepeatedFieldBackInserter(
+                  protobuffer.mutable_repeated_nested_message()));
+
+  }
+
+  virtual void TearDown() {
+    STLDeleteContainerPointers(nested_ptrs.begin(), nested_ptrs.end());
+  }
+};
+
+TEST_F(RepeatedFieldInsertionIteratorsTest, Fibonacci) {
+  EXPECT_TRUE(std::equal(fibonacci.begin(),
+                         fibonacci.end(),
+                         protobuffer.repeated_int32().begin()));
+  EXPECT_TRUE(std::equal(protobuffer.repeated_int32().begin(),
+                         protobuffer.repeated_int32().end(),
+                         fibonacci.begin()));
+}
+
+TEST_F(RepeatedFieldInsertionIteratorsTest, Halves) {
+  EXPECT_TRUE(std::equal(halves.begin(),
+                         halves.end(),
+                         protobuffer.repeated_double().begin()));
+  EXPECT_TRUE(std::equal(protobuffer.repeated_double().begin(),
+                         protobuffer.repeated_double().end(),
+                         halves.begin()));
+}
+
+TEST_F(RepeatedFieldInsertionIteratorsTest, Words) {
+  ASSERT_EQ(words.size(), protobuffer.repeated_string_size());
+  EXPECT_EQ(words.at(0), protobuffer.repeated_string(0));
+  EXPECT_EQ(words.at(1), protobuffer.repeated_string(1));
+  EXPECT_EQ(words.at(2), protobuffer.repeated_string(2));
+  EXPECT_EQ(words.at(3), protobuffer.repeated_string(3));
+  EXPECT_EQ(words.at(4), protobuffer.repeated_string(4));
+  EXPECT_EQ(words.at(5), protobuffer.repeated_string(5));
+  EXPECT_EQ(words.at(6), protobuffer.repeated_string(6));
+}
+
+TEST_F(RepeatedFieldInsertionIteratorsTest, Nesteds) {
+  ASSERT_EQ(protobuffer.repeated_nested_message_size(), 4);
+  EXPECT_EQ(protobuffer.repeated_nested_message(0).bb(), 17);
+  EXPECT_EQ(protobuffer.repeated_nested_message(1).bb(), 4711);
+  EXPECT_EQ(protobuffer.repeated_nested_message(2).bb(), 170);
+  EXPECT_EQ(protobuffer.repeated_nested_message(3).bb(), 47110);
+}
+
+TEST_F(RepeatedFieldInsertionIteratorsTest,
+       AllocatedRepeatedPtrFieldWithStringIntData) {
+  vector<Nested*> data;
+  TestAllTypes goldenproto;
+  for (int i = 0; i < 10; ++i) {
+    Nested* new_data = new Nested;
+    new_data->set_bb(i);
+    data.push_back(new_data);
+
+    new_data = goldenproto.add_repeated_nested_message();
+    new_data->set_bb(i);
+  }
+  TestAllTypes testproto;
+  copy(data.begin(), data.end(),
+       AllocatedRepeatedPtrFieldBackInserter(
+           testproto.mutable_repeated_nested_message()));
+  EXPECT_EQ(testproto.DebugString(), goldenproto.DebugString());
+}
+
+TEST_F(RepeatedFieldInsertionIteratorsTest,
+       AllocatedRepeatedPtrFieldWithString) {
+  vector<string*> data;
+  TestAllTypes goldenproto;
+  for (int i = 0; i < 10; ++i) {
+    string* new_data = new string;
+    *new_data = "name-" + SimpleItoa(i);
+    data.push_back(new_data);
+
+    new_data = goldenproto.add_repeated_string();
+    *new_data = "name-" + SimpleItoa(i);
+  }
+  TestAllTypes testproto;
+  copy(data.begin(), data.end(),
+       AllocatedRepeatedPtrFieldBackInserter(
+           testproto.mutable_repeated_string()));
+  EXPECT_EQ(testproto.DebugString(), goldenproto.DebugString());
+}
+
+}  // namespace
+
 }  // namespace protobuf
 }  // namespace google

+ 6 - 0
src/google/protobuf/stubs/common.h

@@ -1039,6 +1039,10 @@ class LIBPROTOBUF_EXPORT MutexLock {
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MutexLock);
 };
 
+// TODO(kenton):  Implement these?  Hard to implement portably.  Until then
+// both names are plain typedefs of MutexLock and behave exactly like it.
+typedef MutexLock ReaderMutexLock;
+typedef MutexLock WriterMutexLock;
+
 // MutexLockMaybe is like MutexLock, but is a no-op when mu is NULL.
 class LIBPROTOBUF_EXPORT MutexLockMaybe {
  public:
@@ -1056,6 +1060,8 @@ class LIBPROTOBUF_EXPORT MutexLockMaybe {
 // but we don't want to stick "internal::" in front of them everywhere.
 using internal::Mutex;
 using internal::MutexLock;
+using internal::ReaderMutexLock;
+using internal::WriterMutexLock;
 using internal::MutexLockMaybe;
 
 // ===================================================================

+ 27 - 0
src/google/protobuf/stubs/hash.h

@@ -178,6 +178,33 @@ struct hash<string> {
   }
 };
 
+template <typename First, typename Second>
+struct hash<pair<First, Second> > {
+  inline size_t operator()(const pair<First, Second>& key) const {
+    size_t first_hash = hash<First>()(key.first);
+    size_t second_hash = hash<Second>()(key.second);
+
+    // FIXME(kenton):  What is the best way to compute this hash?  I have
+    // no idea!  This seems a bit better than an XOR.
+    return first_hash * ((1 << 16) - 1) + second_hash;
+  }
+
+  static const size_t bucket_size = 4;
+  static const size_t min_buckets = 8;
+  inline size_t operator()(const pair<First, Second>& a,
+                           const pair<First, Second>& b) const {
+    return a < b;
+  }
+};
+
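+// Equality functor for C strings: compares the pointed-to characters with
+// strcmp() rather than comparing the pointers themselves.
+// Used by GCC/SGI STL only.  (Why isn't this provided by the standard
+// library?  :( )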
+struct streq {
+  inline bool operator()(const char* a, const char* b) const {
+    return strcmp(a, b) == 0;
+  }
+};
+
 }  // namespace protobuf
 }  // namespace google
 

+ 15 - 0
src/google/protobuf/stubs/map-util.h

@@ -39,6 +39,21 @@
 namespace google {
 namespace protobuf {
 
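+// Perform a lookup in a map or hash_map.
+// If the key is present in the map then the value associated with that
+// key is returned, otherwise the value passed as a default is returned.
+// Note that a reference is returned: when the key is absent it refers to
+// the |value| argument itself, so the default must outlive every use of the
+// result (do not bind the result to a reference if a temporary was passed).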
+template <class Collection>
+const typename Collection::value_type::second_type&
+FindWithDefault(const Collection& collection,
+                const typename Collection::value_type::first_type& key,
+                const typename Collection::value_type::second_type& value) {
+  typename Collection::const_iterator it = collection.find(key);
+  if (it == collection.end()) {
+    return value;
+  }
+  return it->second;
+}
+
 // Perform a lookup in a map or hash_map.
 // If the key is present a const pointer to the associated value is returned,
 // otherwise a NULL pointer is returned.

+ 82 - 0
src/google/protobuf/stubs/once.cc

@@ -0,0 +1,82 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+//
+// emulates google3/base/once.h
+//
+// This file holds the out-of-line (Win32-only) part of the implementation
+// declared in once.h.  once.h is intended to be included only by internal
+// .cc files and generated .pb.cc files.  Users should not use this directly.
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#include <google/protobuf/stubs/once.h>
+
+namespace google {
+namespace protobuf {
+
+#ifdef _WIN32
+
+struct GoogleOnceInternal {
+  GoogleOnceInternal() {
+    InitializeCriticalSection(&critical_section);
+  }
+  ~GoogleOnceInternal() {
+    DeleteCriticalSection(&critical_section);
+  }
+  CRITICAL_SECTION critical_section;
+};
+
+GoogleOnceType::GoogleOnceType() {
+  // internal_ may be non-NULL if Init() was already called.
+  if (internal_ == NULL) internal_ = new GoogleOnceInternal;
+}
+
+void GoogleOnceType::Init(void (*init_func)()) {
+  // internal_ may be NULL if we're still in dynamic initialization and the
+  // constructor has not been called yet.  As mentioned in once.h, we assume
+  // that the program is still single-threaded at this time, and therefore it
+  // should be safe to initialize internal_ like so.
+  if (internal_ == NULL) internal_ = new GoogleOnceInternal;
+
+  EnterCriticalSection(&internal_->critical_section);
+  if (!initialized_) {
+    init_func();
+    initialized_ = true;
+  }
+  LeaveCriticalSection(&internal_->critical_section);
+}
+
+#endif
+
+}  // namespace protobuf
+}  // namespace google

+ 122 - 0
src/google/protobuf/stubs/once.h

@@ -0,0 +1,122 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+//
+// emulates google3/base/once.h
+//
+// This header is intended to be included only by internal .cc files and
+// generated .pb.cc files.  Users should not use this directly.
+//
+// This is basically a portable version of pthread_once().
+//
+// This header declares three things:
+// * A type called GoogleOnceType.
+// * A macro GOOGLE_PROTOBUF_DECLARE_ONCE() which declares a variable of type
+//   GoogleOnceType.  This is the only legal way to declare such a variable.
+//   The macro may only be used at the global scope (you cannot create local
+//   or class member variables of this type).
+// * A function GoogleOnceInit(GoogleOnceType* once, void (*init_func)()).
+//   This function, when invoked multiple times given the same GoogleOnceType
+//   object, will invoke init_func on the first call only, and will make sure
+//   none of the calls return before that first call to init_func has finished.
+//
+// This implements a way to perform lazy initialization.  It's more efficient
+// than using mutexes as no lock is needed if initialization has already
+// happened.
+//
+// Example usage:
+//   void Init();
+//   GOOGLE_PROTOBUF_DECLARE_ONCE(once_init);
+//
+//   // Calls Init() exactly once.
+//   void InitOnce() {
+//     GoogleOnceInit(&once_init, &Init);
+//   }
+//
+// Note that if GoogleOnceInit() is called before main() has begun, it must
+// only be called by the thread that will eventually call main() -- that is,
+// the thread that performs dynamic initialization.  In general this is a safe
+// assumption since people don't usually construct threads before main() starts,
+// but it is technically not guaranteed.  Unfortunately, Win32 provides no way
+// whatsoever to statically-initialize its synchronization primitives, so our
+// only choice is to assume that dynamic initialization is single-threaded.
+
+#ifndef GOOGLE_PROTOBUF_STUBS_ONCE_H__
+#define GOOGLE_PROTOBUF_STUBS_ONCE_H__
+
+#include <google/protobuf/stubs/common.h>
+
+#ifndef _WIN32
+#include <pthread.h>
+#endif
+
+namespace google {
+namespace protobuf {
+
+#ifdef _WIN32
+
+struct GoogleOnceInternal;
+
+struct GoogleOnceType {
+  GoogleOnceType();
+  void Init(void (*init_func)());
+
+  volatile bool initialized_;
+  GoogleOnceInternal* internal_;
+};
+
+#define GOOGLE_PROTOBUF_DECLARE_ONCE(NAME)                    \
+  ::google::protobuf::GoogleOnceType NAME
+
+inline void GoogleOnceInit(GoogleOnceType* once, void (*init_func)()) {
+  // Note:  Double-checked locking is safe on x86.
+  if (!once->initialized_) {
+    once->Init(init_func);
+  }
+}
+
+#else
+
+typedef pthread_once_t GoogleOnceType;
+
+#define GOOGLE_PROTOBUF_DECLARE_ONCE(NAME)                    \
+  pthread_once_t NAME = PTHREAD_ONCE_INIT
+
+inline void GoogleOnceInit(GoogleOnceType* once, void (*init_func)()) {
+  pthread_once(once, init_func);
+}
+
+#endif
+
+}  // namespace protobuf
+}  // namespace google
+
+#endif  // GOOGLE_PROTOBUF_STUBS_ONCE_H__

+ 253 - 0
src/google/protobuf/stubs/once_unittest.cc

@@ -0,0 +1,253 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <unistd.h>
+#include <pthread.h>
+#endif
+
+#include <google/protobuf/stubs/once.h>
+#include <google/protobuf/testing/googletest.h>
+#include <gtest/gtest.h>
+
+namespace google {
+namespace protobuf {
+namespace {
+
+class OnceInitTest : public testing::Test {
+ protected:
+  void SetUp() {
+    state_ = INIT_NOT_STARTED;
+    current_test_ = this;
+  }
+
+  // Since GoogleOnceType is only allowed to be allocated in static storage,
+  // each test must use a different pair of GoogleOnceType objects which it
+  // must declare itself.
+  void SetOnces(GoogleOnceType* once, GoogleOnceType* recursive_once) {
+    once_ = once;
+    recursive_once_ = recursive_once;
+  }
+
+  void InitOnce() {
+    GoogleOnceInit(once_, &InitStatic);
+  }
+  void InitRecursiveOnce() {
+    GoogleOnceInit(recursive_once_, &InitRecursiveStatic);
+  }
+
+  void BlockInit() { init_blocker_.Lock(); }
+  void UnblockInit() { init_blocker_.Unlock(); }
+
+  class TestThread {
+   public:
+    TestThread(Closure* callback)
+        : done_(false), joined_(false), callback_(callback) {
+#ifdef _WIN32
+      thread_ = CreateThread(NULL, 0, &Start, this, 0, NULL);
+#else
+      pthread_create(&thread_, NULL, &Start, this);
+#endif
+    }
+    ~TestThread() {
+      if (!joined_) Join();
+    }
+
+    bool IsDone() {
+      MutexLock lock(&done_mutex_);
+      return done_;
+    }
+    void Join() {
+      joined_ = true;
+#ifdef _WIN32
+      WaitForSingleObject(thread_, INFINITE);
+      CloseHandle(thread_);
+#else
+      pthread_join(thread_, NULL);
+#endif
+    }
+
+   private:
+#ifdef _WIN32
+    HANDLE thread_;
+#else
+    pthread_t thread_;
+#endif
+
+    Mutex done_mutex_;
+    bool done_;
+    bool joined_;
+    Closure* callback_;
+
+#ifdef _WIN32
+    static DWORD WINAPI Start(LPVOID arg) {
+#else
+    static void* Start(void* arg) {
+#endif
+      reinterpret_cast<TestThread*>(arg)->Run();
+      return 0;
+    }
+
+    void Run() {
+      callback_->Run();
+      MutexLock lock(&done_mutex_);
+      done_ = true;
+    }
+  };
+
+  TestThread* RunInitOnceInNewThread() {
+    return new TestThread(NewCallback(this, &OnceInitTest::InitOnce));
+  }
+  TestThread* RunInitRecursiveOnceInNewThread() {
+    return new TestThread(NewCallback(this, &OnceInitTest::InitRecursiveOnce));
+  }
+
+  enum State {
+    INIT_NOT_STARTED,
+    INIT_STARTED,
+    INIT_DONE
+  };
+  State CurrentState() {
+    MutexLock lock(&mutex_);
+    return state_;
+  }
+
+  void WaitABit() {
+#ifdef _WIN32
+    Sleep(1000);
+#else
+    sleep(1);
+#endif
+  }
+
+ private:
+  Mutex mutex_;
+  Mutex init_blocker_;
+  State state_;
+  GoogleOnceType* once_;
+  GoogleOnceType* recursive_once_;
+
+  void Init() {
+    MutexLock lock(&mutex_);
+    EXPECT_EQ(INIT_NOT_STARTED, state_);
+    state_ = INIT_STARTED;
+    mutex_.Unlock();
+    init_blocker_.Lock();
+    init_blocker_.Unlock();
+    mutex_.Lock();
+    state_ = INIT_DONE;
+  }
+
+  static OnceInitTest* current_test_;
+  static void InitStatic() { current_test_->Init(); }
+  static void InitRecursiveStatic() { current_test_->InitOnce(); }
+};
+
+OnceInitTest* OnceInitTest::current_test_ = NULL;
+
+GOOGLE_PROTOBUF_DECLARE_ONCE(simple_once);
+
+TEST_F(OnceInitTest, Simple) {
+  SetOnces(&simple_once, NULL);
+
+  EXPECT_EQ(INIT_NOT_STARTED, CurrentState());
+  InitOnce();
+  EXPECT_EQ(INIT_DONE, CurrentState());
+
+  // Calling again has no effect.
+  InitOnce();
+  EXPECT_EQ(INIT_DONE, CurrentState());
+}
+
+GOOGLE_PROTOBUF_DECLARE_ONCE(recursive_once1);
+GOOGLE_PROTOBUF_DECLARE_ONCE(recursive_once2);
+
+TEST_F(OnceInitTest, Recursive) {
+  SetOnces(&recursive_once1, &recursive_once2);
+
+  EXPECT_EQ(INIT_NOT_STARTED, CurrentState());
+  InitRecursiveOnce();
+  EXPECT_EQ(INIT_DONE, CurrentState());
+}
+
+GOOGLE_PROTOBUF_DECLARE_ONCE(multiple_threads_once);
+
+TEST_F(OnceInitTest, MultipleThreads) {
+  SetOnces(&multiple_threads_once, NULL);
+
+  scoped_ptr<TestThread> threads[4];
+  EXPECT_EQ(INIT_NOT_STARTED, CurrentState());
+  for (int i = 0; i < 4; i++) {
+    threads[i].reset(RunInitOnceInNewThread());
+  }
+  for (int i = 0; i < 4; i++) {
+    threads[i]->Join();
+  }
+  EXPECT_EQ(INIT_DONE, CurrentState());
+}
+
+GOOGLE_PROTOBUF_DECLARE_ONCE(multiple_threads_blocked_once1);
+GOOGLE_PROTOBUF_DECLARE_ONCE(multiple_threads_blocked_once2);
+
+TEST_F(OnceInitTest, MultipleThreadsBlocked) {
+  SetOnces(&multiple_threads_blocked_once1, &multiple_threads_blocked_once2);
+
+  scoped_ptr<TestThread> threads[8];
+  EXPECT_EQ(INIT_NOT_STARTED, CurrentState());
+
+  BlockInit();
+  for (int i = 0; i < 4; i++) {
+    threads[i].reset(RunInitOnceInNewThread());
+  }
+  for (int i = 4; i < 8; i++) {
+    threads[i].reset(RunInitRecursiveOnceInNewThread());
+  }
+
+  WaitABit();
+
+  // We should now have one thread blocked inside Init(), four blocked waiting
+  // for Init() to complete, and three blocked waiting for InitRecursive() to
+  // complete.
+  EXPECT_EQ(INIT_STARTED, CurrentState());
+  UnblockInit();
+
+  for (int i = 0; i < 8; i++) {
+    threads[i]->Join();
+  }
+  EXPECT_EQ(INIT_DONE, CurrentState());
+}
+
+}  // anonymous namespace
+}  // namespace protobuf
+}  // namespace google

+ 22 - 0
src/google/protobuf/test_util.cc

@@ -1871,6 +1871,15 @@ void TestUtil::ReflectionTester::SetPackedFieldsViaReflection(
 
 void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection(
     const Message& message) {
+  // We have to split this into three functions; otherwise it creates a stack
+  // frame so large that it triggers a warning.
+  ExpectAllFieldsSetViaReflection1(message);
+  ExpectAllFieldsSetViaReflection2(message);
+  ExpectAllFieldsSetViaReflection3(message);
+}
+
+void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection1(
+    const Message& message) {
   const Reflection* reflection = message.GetReflection();
   string scratch;
   const Message* sub_message;
@@ -1949,6 +1958,13 @@ void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection(
 
   EXPECT_EQ("125", reflection->GetString(message, F("optional_cord")));
   EXPECT_EQ("125", reflection->GetStringReference(message, F("optional_cord"), &scratch));
+}
+
+void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection2(
+    const Message& message) {
+  const Reflection* reflection = message.GetReflection();
+  string scratch;
+  const Message* sub_message;
 
   // -----------------------------------------------------------------
 
@@ -2060,6 +2076,12 @@ void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection(
   EXPECT_EQ("325", reflection->GetRepeatedString(message, F("repeated_cord"), 1));
   EXPECT_EQ("325", reflection->GetRepeatedStringReference(
                         message, F("repeated_cord"), 1, &scratch));
+}
+
+void TestUtil::ReflectionTester::ExpectAllFieldsSetViaReflection3(
+    const Message& message) {
+  const Reflection* reflection = message.GetReflection();
+  string scratch;
 
   // -----------------------------------------------------------------
 

+ 6 - 0
src/google/protobuf/test_util.h

@@ -137,6 +137,12 @@ class TestUtil {
     const EnumValueDescriptor* import_bar_;
     const EnumValueDescriptor* import_baz_;
 
+    // We have to split this into three functions; otherwise it creates a stack
+    // frame so large that it triggers a warning.
+    void ExpectAllFieldsSetViaReflection1(const Message& message);
+    void ExpectAllFieldsSetViaReflection2(const Message& message);
+    void ExpectAllFieldsSetViaReflection3(const Message& message);
+
     GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ReflectionTester);
   };
 

+ 178 - 93
src/google/protobuf/text_format.cc

@@ -62,29 +62,20 @@ string Message::DebugString() const {
 }
 
 string Message::ShortDebugString() const {
-  // TODO(kenton):  Make TextFormat support this natively instead of using
-  //   DebugString() and munging the result.
-  string result = DebugString();
-
-  // Replace each contiguous range of whitespace (including newlines, and
-  // starting with a newline) with a single space.
-  int out = 0;
-  for (int i = 0; i < result.size(); ++i) {
-    if (result[i] != '\n') {
-      result[out++] = result[i];
-    } else {
-      while (i < result.size() && isspace(result[i])) ++i;
-      --i;
-      result[out++] = ' ';
-    }
-  }
-  // Remove trailing space, if there is one.
-  if (out > 0 && isspace(result[out - 1])) {
-    --out;
+  string debug_string;
+  io::StringOutputStream output_stream(&debug_string);
+
+  TextFormat::Printer printer;
+  printer.SetSingleLineMode(true);
+
+  printer.Print(*this, &output_stream);
+  // Single line mode currently might have an extra space at the end.
+  if (debug_string.size() > 0 &&
+      debug_string[debug_string.size() - 1] == ' ') {
+    debug_string.resize(debug_string.size() - 1);
   }
-  result.resize(out);
 
-  return result;
+  return debug_string;
 }
 
 void Message::PrintDebugString() const {
@@ -429,9 +420,13 @@ class TextFormat::Parser::ParserImpl {
       return false;
     }
 
-    io::Tokenizer::ParseString(tokenizer_.current().text, text);
+    text->clear();
+    while (LookingAtType(io::Tokenizer::TYPE_STRING)) {
+      io::Tokenizer::ParseStringAppend(tokenizer_.current().text, text);
+
+      tokenizer_.Next();
+    }
 
-    tokenizer_.Next();
     return true;
   }
 
@@ -591,15 +586,18 @@ class TextFormat::Parser::ParserImpl {
 // ===========================================================================
 // Internal class for writing text to the io::ZeroCopyOutputStream. Adapted
 // from the Printer found in //google/protobuf/io/printer.h
-class TextFormat::TextGenerator {
+class TextFormat::Printer::TextGenerator {
  public:
-  explicit TextGenerator(io::ZeroCopyOutputStream* output)
+  explicit TextGenerator(io::ZeroCopyOutputStream* output,
+                         int initial_indent_level)
     : output_(output),
       buffer_(NULL),
       buffer_size_(0),
       at_start_of_line_(true),
       failed_(false),
-      indent_("") {
+      indent_(""),
+      initial_indent_level_(initial_indent_level) {
+    indent_.resize(initial_indent_level_ * 2, ' ');
   }
 
   ~TextGenerator() {
@@ -620,7 +618,8 @@ class TextFormat::TextGenerator {
   // Reduces the current indent level by two spaces, or crashes if the indent
   // level is zero.
   void Outdent() {
-    if (indent_.empty()) {
+    if (indent_.empty() ||
+        indent_.size() < initial_indent_level_ * 2) {
       GOOGLE_LOG(DFATAL) << " Outdent() without matching Indent().";
       return;
     }
@@ -699,6 +698,7 @@ class TextFormat::TextGenerator {
   bool failed_;
 
   string indent_;
+  int initial_indent_level_;
 };
 
 // ===========================================================================
@@ -770,8 +770,16 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   return Parser().MergeFromString(input, output);
 }
 
-/* static */ bool TextFormat::PrintToString(const Message& message,
-                                            string* output) {
+// ===========================================================================
+
+TextFormat::Printer::Printer()
+  : initial_indent_level_(0),
+    single_line_mode_(false) {}
+
+TextFormat::Printer::~Printer() {}
+
+bool TextFormat::Printer::PrintToString(const Message& message,
+                                        string* output) {
   GOOGLE_DCHECK(output) << "output specified is NULL";
 
   output->clear();
@@ -782,7 +790,7 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   return result;
 }
 
-/* static */ bool TextFormat::PrintUnknownFieldsToString(
+bool TextFormat::Printer::PrintUnknownFieldsToString(
     const UnknownFieldSet& unknown_fields,
     string* output) {
   GOOGLE_DCHECK(output) << "output specified is NULL";
@@ -792,9 +800,9 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   return PrintUnknownFields(unknown_fields, &output_stream);
 }
 
-/* static */ bool TextFormat::Print(const Message& message,
-                                    io::ZeroCopyOutputStream* output) {
-  TextGenerator generator(output);
+bool TextFormat::Printer::Print(const Message& message,
+                                io::ZeroCopyOutputStream* output) {
+  TextGenerator generator(output, initial_indent_level_);
 
   Print(message, generator);
 
@@ -802,10 +810,10 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   return !generator.failed();
 }
 
-/* static */ bool TextFormat::PrintUnknownFields(
+bool TextFormat::Printer::PrintUnknownFields(
     const UnknownFieldSet& unknown_fields,
     io::ZeroCopyOutputStream* output) {
-  TextGenerator generator(output);
+  TextGenerator generator(output, initial_indent_level_);
 
   PrintUnknownFields(unknown_fields, generator);
 
@@ -813,8 +821,8 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   return !generator.failed();
 }
 
-/* static */ void TextFormat::Print(const Message& message,
-                                    TextGenerator& generator) {
+void TextFormat::Printer::Print(const Message& message,
+                                TextGenerator& generator) {
   const Reflection* reflection = message.GetReflection();
   vector<const FieldDescriptor*> fields;
   reflection->ListFields(message, &fields);
@@ -824,7 +832,7 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   PrintUnknownFields(reflection->GetUnknownFields(message), generator);
 }
 
-/* static */ void TextFormat::PrintFieldValueToString(
+void TextFormat::Printer::PrintFieldValueToString(
     const Message& message,
     const FieldDescriptor* field,
     int index,
@@ -834,15 +842,15 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
 
   output->clear();
   io::StringOutputStream output_stream(output);
-  TextGenerator generator(&output_stream);
+  TextGenerator generator(&output_stream, initial_indent_level_);
 
   PrintFieldValue(message, message.GetReflection(), field, index, generator);
 }
 
-/* static */ void TextFormat::PrintField(const Message& message,
-                                         const Reflection* reflection,
-                                         const FieldDescriptor* field,
-                                         TextGenerator& generator) {
+void TextFormat::Printer::PrintField(const Message& message,
+                                     const Reflection* reflection,
+                                     const FieldDescriptor* field,
+                                     TextGenerator& generator) {
   int count = 0;
 
   if (field->is_repeated()) {
@@ -874,8 +882,12 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
     }
 
     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+      if (single_line_mode_) {
+        generator.Print(" { ");
+      } else {
         generator.Print(" {\n");
         generator.Indent();
+      }
     } else {
       generator.Print(": ");
     }
@@ -889,15 +901,21 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
     PrintFieldValue(message, reflection, field, field_index, generator);
 
     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+      if (!single_line_mode_) {
         generator.Outdent();
-        generator.Print("}");
+      }
+      generator.Print("}");
     }
 
-    generator.Print("\n");
+    if (single_line_mode_) {
+      generator.Print(" ");
+    } else {
+      generator.Print("\n");
+    }
   }
 }
 
-/* static */ void TextFormat::PrintFieldValue(
+void TextFormat::Printer::PrintFieldValue(
     const Message& message,
     const Reflection* reflection,
     const FieldDescriptor* field,
@@ -961,6 +979,35 @@ bool TextFormat::Parser::MergeUsingImpl(io::ZeroCopyInputStream* input,
   }
 }
 
+/* static */ bool TextFormat::Print(const Message& message,
+                                    io::ZeroCopyOutputStream* output) {
+  return Printer().Print(message, output);
+}
+
+/* static */ bool TextFormat::PrintUnknownFields(
+    const UnknownFieldSet& unknown_fields,
+    io::ZeroCopyOutputStream* output) {
+  return Printer().PrintUnknownFields(unknown_fields, output);
+}
+
+/* static */ bool TextFormat::PrintToString(
+    const Message& message, string* output) {
+  return Printer().PrintToString(message, output);
+}
+
+/* static */ bool TextFormat::PrintUnknownFieldsToString(
+    const UnknownFieldSet& unknown_fields, string* output) {
+  return Printer().PrintUnknownFieldsToString(unknown_fields, output);
+}
+
+/* static */ void TextFormat::PrintFieldValueToString(
+    const Message& message,
+    const FieldDescriptor* field,
+    int index,
+    string* output) {
+  return Printer().PrintFieldValueToString(message, field, index, output);
+}
+
 // Prints an integer as hex with a fixed number of digits dependent on the
 // integer type.
 template<typename IntType>
@@ -973,59 +1020,97 @@ static string PaddedHex(IntType value) {
   return result;
 }
 
-/* static */ void TextFormat::PrintUnknownFields(
+void TextFormat::Printer::PrintUnknownFields(
     const UnknownFieldSet& unknown_fields, TextGenerator& generator) {
   for (int i = 0; i < unknown_fields.field_count(); i++) {
     const UnknownField& field = unknown_fields.field(i);
     string field_number = SimpleItoa(field.number());
 
-    for (int j = 0; j < field.varint_size(); j++) {
-      generator.Print(field_number);
-      generator.Print(": ");
-      generator.Print(SimpleItoa(field.varint(j)));
-      generator.Print("\n");
-    }
-    for (int j = 0; j < field.fixed32_size(); j++) {
-      generator.Print(field_number);
-      generator.Print(": 0x");
-      char buffer[kFastToBufferSize];
-      generator.Print(FastHex32ToBuffer(field.fixed32(j), buffer));
-      generator.Print("\n");
-    }
-    for (int j = 0; j < field.fixed64_size(); j++) {
-      generator.Print(field_number);
-      generator.Print(": 0x");
-      char buffer[kFastToBufferSize];
-      generator.Print(FastHex64ToBuffer(field.fixed64(j), buffer));
-      generator.Print("\n");
-    }
-    for (int j = 0; j < field.length_delimited_size(); j++) {
-      generator.Print(field_number);
-      const string& value = field.length_delimited(j);
-      UnknownFieldSet embedded_unknown_fields;
-      if (!value.empty() && embedded_unknown_fields.ParseFromString(value)) {
-        // This field is parseable as a Message.
-        // So it is probably an embedded message.
-        generator.Print(" {\n");
-        generator.Indent();
-        PrintUnknownFields(embedded_unknown_fields, generator);
-        generator.Outdent();
-        generator.Print("}\n");
-      } else {
-        // This field is not parseable as a Message.
-        // So it is probably just a plain string.
-        generator.Print(": \"");
-        generator.Print(CEscape(value));
-        generator.Print("\"\n");
+    switch (field.type()) {
+      case UnknownField::TYPE_VARINT:
+        generator.Print(field_number);
+        generator.Print(": ");
+        generator.Print(SimpleItoa(field.varint()));
+        if (single_line_mode_) {
+          generator.Print(" ");
+        } else {
+          generator.Print("\n");
+        }
+        break;
+      case UnknownField::TYPE_FIXED32: {
+        generator.Print(field_number);
+        generator.Print(": 0x");
+        char buffer[kFastToBufferSize];
+        generator.Print(FastHex32ToBuffer(field.fixed32(), buffer));
+        if (single_line_mode_) {
+          generator.Print(" ");
+        } else {
+          generator.Print("\n");
+        }
+        break;
       }
-    }
-    for (int j = 0; j < field.group_size(); j++) {
-      generator.Print(field_number);
-      generator.Print(" {\n");
-      generator.Indent();
-      PrintUnknownFields(field.group(j), generator);
-      generator.Outdent();
-      generator.Print("}\n");
+      case UnknownField::TYPE_FIXED64: {
+        generator.Print(field_number);
+        generator.Print(": 0x");
+        char buffer[kFastToBufferSize];
+        generator.Print(FastHex64ToBuffer(field.fixed64(), buffer));
+        if (single_line_mode_) {
+          generator.Print(" ");
+        } else {
+          generator.Print("\n");
+        }
+        break;
+      }
+      case UnknownField::TYPE_LENGTH_DELIMITED: {
+        generator.Print(field_number);
+        const string& value = field.length_delimited();
+        UnknownFieldSet embedded_unknown_fields;
+        if (!value.empty() && embedded_unknown_fields.ParseFromString(value)) {
+          // This field is parseable as a Message.
+          // So it is probably an embedded message.
+          if (single_line_mode_) {
+            generator.Print(" { ");
+          } else {
+            generator.Print(" {\n");
+            generator.Indent();
+          }
+          PrintUnknownFields(embedded_unknown_fields, generator);
+          if (single_line_mode_) {
+            generator.Print("} ");
+          } else {
+            generator.Outdent();
+            generator.Print("}\n");
+          }
+        } else {
+          // This field is not parseable as a Message.
+          // So it is probably just a plain string.
+          generator.Print(": \"");
+          generator.Print(CEscape(value));
+          generator.Print("\"");
+          if (single_line_mode_) {
+            generator.Print(" ");
+          } else {
+            generator.Print("\n");
+          }
+        }
+        break;
+      }
+      case UnknownField::TYPE_GROUP:
+        generator.Print(field_number);
+        if (single_line_mode_) {
+          generator.Print(" { ");
+        } else {
+          generator.Print(" {\n");
+          generator.Indent();
+        }
+        PrintUnknownFields(field.group(), generator);
+        if (single_line_mode_) {
+          generator.Print("} ");
+        } else {
+          generator.Outdent();
+          generator.Print("}\n");
+        }
+        break;
     }
   }
 }

+ 70 - 29
src/google/protobuf/text_format.h

@@ -82,6 +82,76 @@ class LIBPROTOBUF_EXPORT TextFormat {
                                       int index,
                                       string* output);
 
+  // Class for users who require more fine-grained control over how a
+  // protocol buffer message is printed out.
+  class LIBPROTOBUF_EXPORT Printer {
+   public:
+    Printer();
+    ~Printer();
+
+    // Like TextFormat::Print
+    bool Print(const Message& message, io::ZeroCopyOutputStream* output);
+    // Like TextFormat::PrintUnknownFields
+    bool PrintUnknownFields(const UnknownFieldSet& unknown_fields,
+                            io::ZeroCopyOutputStream* output);
+    // Like TextFormat::PrintToString
+    bool PrintToString(const Message& message, string* output);
+    // Like TextFormat::PrintUnknownFieldsToString
+    bool PrintUnknownFieldsToString(const UnknownFieldSet& unknown_fields,
+                                    string* output);
+    // Like TextFormat::PrintFieldValueToString
+    void PrintFieldValueToString(const Message& message,
+                                 const FieldDescriptor* field,
+                                 int index,
+                                 string* output);
+
+    // Adjust the initial indent level of all output.  Each indent level is
+    // equal to two spaces.
+    void SetInitialIndentLevel(int indent_level) {
+      initial_indent_level_ = indent_level;
+    }
+
+    // If printing in single line mode, then the entire message will be output
+    // on a single line with no line breaks.
+    void SetSingleLineMode(bool single_line_mode) {
+      single_line_mode_ = single_line_mode;
+    }
+
+   private:
+    // Forward declaration of an internal class used to print the text
+    // output to the OutputStream (see text_format.cc for implementation).
+    class TextGenerator;
+
+    // Internal Print method, used for writing to the OutputStream via
+    // the TextGenerator class.
+    void Print(const Message& message,
+               TextGenerator& generator);
+
+    // Print a single field.
+    void PrintField(const Message& message,
+                    const Reflection* reflection,
+                    const FieldDescriptor* field,
+                    TextGenerator& generator);
+
+    // Outputs a textual representation of the value of the field supplied on
+    // the message supplied or the default value if not set.
+    void PrintFieldValue(const Message& message,
+                         const Reflection* reflection,
+                         const FieldDescriptor* field,
+                         int index,
+                         TextGenerator& generator);
+
+    // Print the fields in an UnknownFieldSet.  They are printed by tag number
+    // only.  Embedded messages are heuristically identified by attempting to
+    // parse them.
+    void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
+                            TextGenerator& generator);
+
+    int initial_indent_level_;
+
+    bool single_line_mode_;
+  };
+
   // Parses a text-format protocol message from the given input stream to
   // the given message object.  This function parses the format written
   // by Print().
@@ -138,35 +208,6 @@ class LIBPROTOBUF_EXPORT TextFormat {
   };
 
  private:
-  // Forward declaration of an internal class used to print the text
-  // output to the OutputStream (see text_format.cc for implementation).
-  class TextGenerator;
-
-  // Internal Print method, used for writing to the OutputStream via
-  // the TextGenerator class.
-  static void Print(const Message& message,
-                    TextGenerator& generator);
-
-  // Print a single field.
-  static void PrintField(const Message& message,
-                         const Reflection* reflection,
-                         const FieldDescriptor* field,
-                         TextGenerator& generator);
-
-  // Outputs a textual representation of the value of the field supplied on
-  // the message supplied or the default value if not set.
-  static void PrintFieldValue(const Message& message,
-                              const Reflection* reflection,
-                              const FieldDescriptor* field,
-                              int index,
-                              TextGenerator& generator);
-
-  // Print the fields in an UnknownFieldSet.  They are printed by tag number
-  // only.  Embedded messages are heuristically identified by attempting to
-  // parse them.
-  static void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
-                                 TextGenerator& generator);
-
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TextFormat);
 };
 

+ 73 - 10
src/google/protobuf/text_format_unittest.cc

@@ -163,18 +163,16 @@ TEST_F(TextFormatTest, PrintUnknownFields) {
 
   unittest::TestEmptyMessage message;
   UnknownFieldSet* unknown_fields = message.mutable_unknown_fields();
-  UnknownField* field5 = unknown_fields->AddField(5);
 
-  field5->add_varint(1);
-  field5->add_fixed32(2);
-  field5->add_fixed64(3);
-  field5->add_length_delimited("4");
-  field5->add_group()->AddField(10)->add_varint(5);
+  unknown_fields->AddVarint(5, 1);
+  unknown_fields->AddFixed32(5, 2);
+  unknown_fields->AddFixed64(5, 3);
+  unknown_fields->AddLengthDelimited(5, "4");
+  unknown_fields->AddGroup(5)->AddVarint(10, 5);
 
-  UnknownField* field8 = unknown_fields->AddField(8);
-  field8->add_varint(1);
-  field8->add_varint(2);
-  field8->add_varint(3);
+  unknown_fields->AddVarint(8, 1);
+  unknown_fields->AddVarint(8, 2);
+  unknown_fields->AddVarint(8, 3);
 
   EXPECT_EQ(
     "5: 1\n"
@@ -234,6 +232,48 @@ TEST_F(TextFormatTest, PrintUnknownMessage) {
     text);
 }
 
+TEST_F(TextFormatTest, PrintMessageWithIndent) {
+  // Test adding an initial indent to printing.
+
+  protobuf_unittest::TestAllTypes message;
+
+  message.add_repeated_string("abc");
+  message.add_repeated_string("def");
+  message.add_repeated_nested_message()->set_bb(123);
+
+  string text;
+  TextFormat::Printer printer;
+  printer.SetInitialIndentLevel(1);
+  EXPECT_TRUE(printer.PrintToString(message, &text));
+  EXPECT_EQ(
+    "  repeated_string: \"abc\"\n"
+    "  repeated_string: \"def\"\n"
+    "  repeated_nested_message {\n"
+    "    bb: 123\n"
+    "  }\n",
+    text);
+}
+
+TEST_F(TextFormatTest, PrintMessageSingleLine) {
+  // Test printing a message on a single line.
+
+  protobuf_unittest::TestAllTypes message;
+
+  message.add_repeated_string("abc");
+  message.add_repeated_string("def");
+  message.add_repeated_nested_message()->set_bb(123);
+
+  string text;
+  TextFormat::Printer printer;
+  printer.SetInitialIndentLevel(1);
+  printer.SetSingleLineMode(true);
+  EXPECT_TRUE(printer.PrintToString(message, &text));
+  EXPECT_EQ(
+    "  repeated_string: \"abc\" repeated_string: \"def\" "
+    "repeated_nested_message { bb: 123 } ",
+    text);
+}
+
 TEST_F(TextFormatTest, ParseBasic) {
   io::ArrayInputStream input_stream(proto_debug_string_.data(),
                                     proto_debug_string_.size());
@@ -262,6 +302,29 @@ TEST_F(TextFormatTest, ParseStringEscape) {
   EXPECT_EQ(kEscapeTestString, proto_.optional_string());
 }
 
+TEST_F(TextFormatTest, ParseConcatenatedString) {
+  // Create a parse string with multiple parts on one line.
+  string parse_string = "optional_string: \"foo\" \"bar\"\n";
+
+  io::ArrayInputStream input_stream1(parse_string.data(),
+                                    parse_string.size());
+  TextFormat::Parse(&input_stream1, &proto_);
+
+  // Compare.
+  EXPECT_EQ("foobar", proto_.optional_string());
+
+  // Create a parse string with multiple parts on separate lines.
+  parse_string = "optional_string: \"foo\"\n"
+                 "\"bar\"\n";
+
+  io::ArrayInputStream input_stream2(parse_string.data(),
+                                    parse_string.size());
+  TextFormat::Parse(&input_stream2, &proto_);
+
+  // Compare.
+  EXPECT_EQ("foobar", proto_.optional_string());
+}
+
 TEST_F(TextFormatTest, ParseFloatWithSuffix) {
   // Test that we can parse a floating-point value with 'f' appended to the
   // end.  This is needed for backwards-compatibility with proto1.

+ 8 - 0
src/google/protobuf/unittest.proto

@@ -269,6 +269,14 @@ extend TestAllExtensions {
   optional string default_cord_extension = 85 [ctype=CORD, default="123"];
 }
 
+message TestNestedExtension {
+  extend TestAllExtensions {
+    // Check for bug where string extensions declared in nested scope did not
+    // compile.
+    optional string test = 1002 [default="test"];
+  }
+}
+
 // We have separate messages for testing required fields because it's
 // annoying to have to fill in required fields in TestProto in order to
 // do anything with it.  Note that we don't need to test every type of

+ 37 - 0
src/google/protobuf/unittest_empty.proto

@@ -0,0 +1,37 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+//  Based on original Protocol Buffers design by
+//  Sanjay Ghemawat, Jeff Dean, and others.
+//
+// This file intentionally left blank.  (At one point this wouldn't compile
+// correctly.)
+

+ 108 - 136
src/google/protobuf/unknown_field_set.cc

@@ -43,38 +43,106 @@ namespace google {
 namespace protobuf {
 
 UnknownFieldSet::UnknownFieldSet()
-  : internal_(NULL) {}
+  : fields_(NULL) {}
 
 UnknownFieldSet::~UnknownFieldSet() {
-  if (internal_ != NULL) {
-    STLDeleteValues(&internal_->fields_);
-    delete internal_;
-  }
+  Clear();
+  delete fields_;
 }
 
 void UnknownFieldSet::Clear() {
-  if (internal_ == NULL) return;
-
-  if (internal_->fields_.size() > kMaxInactiveFields) {
-    STLDeleteValues(&internal_->fields_);
-  } else {
-    // Don't delete the UnknownField objects.  Just remove them from the active
-    // set.
-    for (int i = 0; i < internal_->active_fields_.size(); i++) {
-      internal_->active_fields_[i]->Clear();
-      internal_->active_fields_[i]->index_ = -1;
+  if (fields_ != NULL) {
+    for (int i = 0; i < fields_->size(); i++) {
+      (*fields_)[i].Delete();
     }
+    fields_->clear();
   }
-
-  internal_->active_fields_.clear();
 }
 
 void UnknownFieldSet::MergeFrom(const UnknownFieldSet& other) {
   for (int i = 0; i < other.field_count(); i++) {
-    AddField(other.field(i).number())->MergeFrom(other.field(i));
+    AddField(other.field(i));
   }
 }
 
+int UnknownFieldSet::SpaceUsedExcludingSelf() const {
+  if (fields_ == NULL) return 0;
+
+  int total_size = sizeof(*fields_) + sizeof(UnknownField) * fields_->size();
+  for (int i = 0; i < fields_->size(); i++) {
+    const UnknownField& field = (*fields_)[i];
+    switch (field.type()) {
+      case UnknownField::TYPE_LENGTH_DELIMITED:
+        total_size += sizeof(*field.length_delimited_) +
+          internal::StringSpaceUsedExcludingSelf(*field.length_delimited_);
+        break;
+      case UnknownField::TYPE_GROUP:
+        total_size += field.group_->SpaceUsed();
+        break;
+      default:
+        break;
+    }
+  }
+  return total_size;
+}
+
+int UnknownFieldSet::SpaceUsed() const {
+  return sizeof(*this) + SpaceUsedExcludingSelf();
+}
+
+void UnknownFieldSet::AddVarint(int number, uint64 value) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  UnknownField field;
+  field.number_ = number;
+  field.type_ = UnknownField::TYPE_VARINT;
+  field.varint_ = value;
+  fields_->push_back(field);
+}
+
+void UnknownFieldSet::AddFixed32(int number, uint32 value) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  UnknownField field;
+  field.number_ = number;
+  field.type_ = UnknownField::TYPE_FIXED32;
+  field.fixed32_ = value;
+  fields_->push_back(field);
+}
+
+void UnknownFieldSet::AddFixed64(int number, uint64 value) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  UnknownField field;
+  field.number_ = number;
+  field.type_ = UnknownField::TYPE_FIXED64;
+  field.fixed64_ = value;
+  fields_->push_back(field);
+}
+
+string* UnknownFieldSet::AddLengthDelimited(int number) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  UnknownField field;
+  field.number_ = number;
+  field.type_ = UnknownField::TYPE_LENGTH_DELIMITED;
+  field.length_delimited_ = new string;
+  fields_->push_back(field);
+  return field.length_delimited_;
+}
+
+UnknownFieldSet* UnknownFieldSet::AddGroup(int number) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  UnknownField field;
+  field.number_ = number;
+  field.type_ = UnknownField::TYPE_GROUP;
+  field.group_ = new UnknownFieldSet;
+  fields_->push_back(field);
+  return field.group_;
+}
+
+void UnknownFieldSet::AddField(const UnknownField& field) {
+  if (fields_ == NULL) fields_ = new vector<UnknownField>;
+  fields_->push_back(field);
+  fields_->back().DeepCopy();
+}
+
 bool UnknownFieldSet::MergeFromCodedStream(io::CodedInputStream* input) {
 
   UnknownFieldSet other;
@@ -103,129 +171,33 @@ bool UnknownFieldSet::ParseFromArray(const void* data, int size) {
   return ParseFromZeroCopyStream(&input);
 }
 
-const UnknownField* UnknownFieldSet::FindFieldByNumber(int number) const {
-  if (internal_ == NULL) return NULL;
-
-  map<int, UnknownField*>::iterator iter = internal_->fields_.find(number);
-  if (iter != internal_->fields_.end() && iter->second->index() != -1) {
-    return iter->second;
-  } else {
-    return NULL;
+void UnknownField::Delete() {
+  switch (type()) {
+    case UnknownField::TYPE_LENGTH_DELIMITED:
+      delete length_delimited_;
+      break;
+    case UnknownField::TYPE_GROUP:
+      delete group_;
+      break;
+    default:
+      break;
   }
 }
 
-UnknownField* UnknownFieldSet::AddField(int number) {
-  if (internal_ == NULL) internal_ = new Internal;
-
-  UnknownField** map_slot = &internal_->fields_[number];
-  if (*map_slot == NULL) {
-    *map_slot = new UnknownField(number);
-  }
-
-  UnknownField* field = *map_slot;
-  if (field->index() == -1) {
-    field->index_ = internal_->active_fields_.size();
-    internal_->active_fields_.push_back(field);
-  }
-  return field;
-}
-
-int UnknownFieldSet::SpaceUsedExcludingSelf() const {
-  int total_size = 0;
-  if (internal_ != NULL) {
-    total_size += sizeof(*internal_);
-    total_size += internal_->active_fields_.capacity() *
-                  sizeof(Internal::FieldVector::value_type);
-    total_size += internal_->fields_.size() *
-        sizeof(Internal::FieldMap::value_type);
-
-    // Account for the UnknownField objects themselves.
-    for (Internal::FieldMap::const_iterator it = internal_->fields_.begin(),
-         end = internal_->fields_.end();
-         it != end;
-         ++it) {
-      total_size += it->second->SpaceUsed();
+void UnknownField::DeepCopy() {
+  switch (type()) {
+    case UnknownField::TYPE_LENGTH_DELIMITED:
+      length_delimited_ = new string(*length_delimited_);
+      break;
+    case UnknownField::TYPE_GROUP: {
+      UnknownFieldSet* group = new UnknownFieldSet;
+      group->MergeFrom(*group_);
+      group_ = group;
+      break;
     }
+    default:
+      break;
   }
-  return total_size;
-}
-
-int UnknownFieldSet::SpaceUsed() const {
-  return sizeof(*this) + SpaceUsedExcludingSelf();
-}
-
-UnknownFieldSet::Internal::FieldMap UnknownFieldSet::kEmptyMap;
-const UnknownFieldSet::iterator UnknownFieldSet::kEmptyIterator(
-  kEmptyMap.end(), &kEmptyMap);
-const UnknownFieldSet::const_iterator UnknownFieldSet::kEmptyConstIterator(
-  kEmptyMap.end(), &kEmptyMap);
-
-void UnknownFieldSet::iterator::AdvanceToNonEmpty() {
-  while (inner_iterator_ != inner_map_->end() &&
-         (inner_iterator_->second->index() == -1 ||
-          inner_iterator_->second->empty())) {
-    ++inner_iterator_;
-  }
-}
-
-void UnknownFieldSet::const_iterator::AdvanceToNonEmpty() {
-  while (inner_iterator_ != inner_map_->end() &&
-         (inner_iterator_->second->index() == -1 ||
-          inner_iterator_->second->empty())) {
-    ++inner_iterator_;
-  }
-}
-
-UnknownFieldSet::iterator UnknownFieldSet::begin() {
-  if (internal_ == NULL) return kEmptyIterator;
-
-  UnknownFieldSet::iterator result(internal_->fields_.begin(),
-                                   &internal_->fields_);
-  result.AdvanceToNonEmpty();
-  return result;
-}
-
-UnknownFieldSet::const_iterator UnknownFieldSet::begin() const {
-  if (internal_ == NULL) return kEmptyIterator;
-
-  UnknownFieldSet::const_iterator result(internal_->fields_.begin(),
-                                         &internal_->fields_);
-  result.AdvanceToNonEmpty();
-  return result;
-}
-
-UnknownField::UnknownField(int number)
-  : number_(number),
-    index_(-1) {
-}
-
-UnknownField::~UnknownField() {
-}
-
-void UnknownField::Clear() {
-  clear_varint();
-  clear_fixed32();
-  clear_fixed64();
-  clear_length_delimited();
-  clear_group();
-}
-
-void UnknownField::MergeFrom(const UnknownField& other) {
-  varint_          .MergeFrom(other.varint_          );
-  fixed32_         .MergeFrom(other.fixed32_         );
-  fixed64_         .MergeFrom(other.fixed64_         );
-  length_delimited_.MergeFrom(other.length_delimited_);
-  group_           .MergeFrom(other.group_           );
-}
-
-int UnknownField::SpaceUsed() const {
-  int total_size = sizeof(*this);
-  total_size += varint_.SpaceUsedExcludingSelf();
-  total_size += fixed32_.SpaceUsedExcludingSelf();
-  total_size += fixed64_.SpaceUsedExcludingSelf();
-  total_size += length_delimited_.SpaceUsedExcludingSelf();
-  total_size += group_.SpaceUsedExcludingSelf();
-  return total_size;
 }
 
 }  // namespace protobuf

+ 94 - 321
src/google/protobuf/unknown_field_set.h

@@ -39,7 +39,6 @@
 #define GOOGLE_PROTOBUF_UNKNOWN_FIELD_SET_H__
 
 #include <string>
-#include <map>
 #include <vector>
 #include <google/protobuf/repeated_field.h>
 
@@ -70,13 +69,6 @@ class LIBPROTOBUF_EXPORT UnknownFieldSet {
   void Clear();
 
   // Is this set empty?
-  //
-  // Note that this is equivalent to field_count() == 0 but is NOT necessarily
-  // equivalent to begin() == end().  The iterator class skips fields which are
-  // themselves empty, so if field_count() is non-zero but field(i)->empty() is
-  // true for all i, then begin() will be equal to end() but empty() will return
-  // false.  This inconsistency almost never occurs in practice because typical
-  // code does not add empty fields to an UnknownFieldSet.
   inline bool empty() const;
 
   // Merge the contents of some other UnknownFieldSet with this one.
@@ -85,13 +77,6 @@ class LIBPROTOBUF_EXPORT UnknownFieldSet {
   // Swaps the contents of some other UnknownFieldSet with this one.
   inline void Swap(UnknownFieldSet* x);
 
-  // Find a field by field number.  Returns NULL if not found.
-  const UnknownField* FindFieldByNumber(int number) const;
-
-  // Add a field by field number.  If the field number already exists, returns
-  // the existing UnknownField.
-  UnknownField* AddField(int number);
-
   // Computes (an estimate of) the total number of bytes currently used for
   // storing the unknown fields in memory. Does NOT include
   // sizeof(*this) in the calculation.
@@ -100,111 +85,28 @@ class LIBPROTOBUF_EXPORT UnknownFieldSet {
   // Version of SpaceUsed() including sizeof(*this).
   int SpaceUsed() const;
 
-  // STL-style iteration ---------------------------------------------
-  // These iterate over the non-empty UnknownFields in order by field
-  // number.  All iterators are invalidated whenever the UnknownFieldSet
-  // is modified.
-
-  class const_iterator;
-
-  class LIBPROTOBUF_EXPORT iterator {
-   public:
-    iterator() {}
-
-    bool operator==(const iterator& other) {
-      return inner_iterator_ == other.inner_iterator_;
-    }
-    bool operator!=(const iterator& other) {
-      return inner_iterator_ != other.inner_iterator_;
-    }
-
-    UnknownField& operator*() { return *inner_iterator_->second; }
-    UnknownField* operator->() { return inner_iterator_->second; }
-    iterator& operator++() {
-      ++inner_iterator_;
-      AdvanceToNonEmpty();
-      return *this;
-    }
-    iterator operator++(int) {
-      iterator copy(*this);
-      ++*this;
-      return copy;
-    }
-
-   private:
-    friend class UnknownFieldSet;
-    friend class LIBPROTOBUF_EXPORT UnknownFieldSet::const_iterator;
-    iterator(map<int, UnknownField*>::iterator inner_iterator,
-             map<int, UnknownField*>* inner_map)
-      : inner_iterator_(inner_iterator), inner_map_(inner_map) {}
-
-    void AdvanceToNonEmpty();
-
-    map<int, UnknownField*>::iterator inner_iterator_;
-    map<int, UnknownField*>* inner_map_;
-  };
-
-  class LIBPROTOBUF_EXPORT const_iterator {
-   public:
-    const_iterator() {}
-    const_iterator(const iterator& other)
-      : inner_iterator_(other.inner_iterator_), inner_map_(other.inner_map_) {}
-
-    bool operator==(const const_iterator& other) {
-      return inner_iterator_ == other.inner_iterator_;
-    }
-    bool operator!=(const const_iterator& other) {
-      return inner_iterator_ != other.inner_iterator_;
-    }
-
-    UnknownField& operator*() { return *inner_iterator_->second; }
-    UnknownField* operator->() { return inner_iterator_->second; }
-    const_iterator& operator++() {
-      ++inner_iterator_;
-      AdvanceToNonEmpty();
-      return *this;
-    }
-    const_iterator operator++(int) {
-      const_iterator copy(*this);
-      ++*this;
-      return copy;
-    }
-
-   private:
-    friend class UnknownFieldSet;
-    const_iterator(map<int, UnknownField*>::const_iterator inner_iterator,
-                   const map<int, UnknownField*>* inner_map)
-      : inner_iterator_(inner_iterator), inner_map_(inner_map) {}
-
-    void AdvanceToNonEmpty();
-
-    map<int, UnknownField*>::const_iterator inner_iterator_;
-    const map<int, UnknownField*>* inner_map_;
-  };
-
-  iterator begin();
-  iterator end() {
-    return internal_ == NULL ? kEmptyIterator :
-      iterator(internal_->fields_.end(), &internal_->fields_);
-  }
-  const_iterator begin() const;
-  const_iterator end() const {
-    return internal_ == NULL ? kEmptyConstIterator :
-      const_iterator(internal_->fields_.end(), &internal_->fields_);
-  }
-
-  // Old-style iteration ---------------------------------------------
-  // New code should use begin() and end() rather than these methods.
-
   // Returns the number of fields present in the UnknownFieldSet.
   inline int field_count() const;
   // Get a field in the set, where 0 <= index < field_count().  The fields
-  // appear in arbitrary order.
+  // appear in the order in which they were added.
   inline const UnknownField& field(int index) const;
   // Get a mutable pointer to a field in the set, where
-  // 0 <= index < field_count().  The fields appear in arbitrary order.
+  // 0 <= index < field_count().  The fields appear in the order in which
+  // they were added.
   inline UnknownField* mutable_field(int index);
 
+  // Adding fields ---------------------------------------------------
+
+  void AddVarint(int number, uint64 value);
+  void AddFixed32(int number, uint32 value);
+  void AddFixed64(int number, uint64 value);
+  void AddLengthDelimited(int number, const string& value);
+  string* AddLengthDelimited(int number);
+  UnknownFieldSet* AddGroup(int number);
+
+  // Adds an unknown field from another set.
+  void AddField(const UnknownField& field);
+
   // Parsing helpers -------------------------------------------------
   // These work exactly like the similarly-named methods of Message.
 
@@ -217,268 +119,139 @@ class LIBPROTOBUF_EXPORT UnknownFieldSet {
   }
 
  private:
-  // "Active" fields are ones which have been added since the last time Clear()
-  // was called.  Inactive fields are objects we are keeping around incase
-  // they become active again.
-
-  struct Internal {
-    // Contains all UnknownFields that have been allocated for this
-    // UnknownFieldSet, including ones not currently active.  Keyed by
-    // field number.  We intentionally try to reuse UnknownField objects for
-    // the same field number they were used for originally because this makes
-    // it more likely that the previously-allocated memory will have the right
-    // layout.
-    typedef map<int, UnknownField*> FieldMap;
-    FieldMap fields_;
-
-    // Contains the fields from fields_ that are currently active.
-    typedef vector<UnknownField*> FieldVector;
-    FieldVector active_fields_;
-  };
-
-  // We want an UnknownFieldSet to use no more space than a single pointer
-  // until the first field is added.
-  Internal* internal_;
-
-  // Don't keep more inactive fields than this.
-  static const int kMaxInactiveFields = 100;
-
-  // Used by begin() and end() when internal_ is NULL.
-  static Internal::FieldMap kEmptyMap;
-  static const iterator kEmptyIterator;
-  static const const_iterator kEmptyConstIterator;
+  vector<UnknownField>* fields_;
 
   GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(UnknownFieldSet);
 };
 
 // Represents one field in an UnknownFieldSet.
-//
-// UnknownField's accessors are similar to those that would be produced by the
-// protocol compiler for the fields:
-//   repeated uint64 varint;
-//   repeated fixed32 fixed32;
-//   repeated fixed64 fixed64;
-//   repeated bytes length_delimited;
-//   repeated UnknownFieldSet group;
-// (OK, so the last one isn't actually a valid field type but you get the
-// idea.)
 class LIBPROTOBUF_EXPORT UnknownField {
  public:
-  ~UnknownField();
+  enum Type {
+    TYPE_VARINT,
+    TYPE_FIXED32,
+    TYPE_FIXED64,
+    TYPE_LENGTH_DELIMITED,
+    TYPE_GROUP
+  };
 
-  // Clears all fields.
-  void Clear();
+  // The field's tag number, as seen on the wire.
+  inline int number() const;
 
-  // Is this field empty?  (I.e. all of the *_size() methods return zero.)
-  inline bool empty() const;
+  // The field type.
+  inline Type type() const;
 
-  // Merge the contents of some other UnknownField with this one.  For each
-  // wire type, the values are simply concatenated.
-  void MergeFrom(const UnknownField& other);
+  // Accessors -------------------------------------------------------
+  // Each method may be called only when type() is the corresponding Type.
 
-  // The field's tag number, as seen on the wire.
-  inline int number() const;
+  inline uint64 varint() const;
+  inline uint32 fixed32() const;
+  inline uint64 fixed64() const;
+  inline const string& length_delimited() const;
+  inline const UnknownFieldSet& group() const;
 
-  // The index of this UnknownField within the UnknownFieldSet (e.g.
-  // set.field(field.index()) == field).
-  inline int index() const;
-
-  inline int varint_size          () const;
-  inline int fixed32_size         () const;
-  inline int fixed64_size         () const;
-  inline int length_delimited_size() const;
-  inline int group_size           () const;
-
-  inline uint64 varint (int index) const;
-  inline uint32 fixed32(int index) const;
-  inline uint64 fixed64(int index) const;
-  inline const string& length_delimited(int index) const;
-  inline const UnknownFieldSet& group(int index) const;
-
-  inline void set_varint (int index, uint64 value);
-  inline void set_fixed32(int index, uint32 value);
-  inline void set_fixed64(int index, uint64 value);
-  inline void set_length_delimited(int index, const string& value);
-  inline string* mutable_length_delimited(int index);
-  inline UnknownFieldSet* mutable_group(int index);
-
-  inline void add_varint (uint64 value);
-  inline void add_fixed32(uint32 value);
-  inline void add_fixed64(uint64 value);
-  inline void add_length_delimited(const string& value);
-  inline string* add_length_delimited();
-  inline UnknownFieldSet* add_group();
-
-  inline void clear_varint ();
-  inline void clear_fixed32();
-  inline void clear_fixed64();
-  inline void clear_length_delimited();
-  inline void clear_group();
-
-  inline const RepeatedField   <uint64         >& varint          () const;
-  inline const RepeatedField   <uint32         >& fixed32         () const;
-  inline const RepeatedField   <uint64         >& fixed64         () const;
-  inline const RepeatedPtrField<string         >& length_delimited() const;
-  inline const RepeatedPtrField<UnknownFieldSet>& group           () const;
-
-  inline RepeatedField   <uint64         >* mutable_varint          ();
-  inline RepeatedField   <uint32         >* mutable_fixed32         ();
-  inline RepeatedField   <uint64         >* mutable_fixed64         ();
-  inline RepeatedPtrField<string         >* mutable_length_delimited();
-  inline RepeatedPtrField<UnknownFieldSet>* mutable_group           ();
-
-  // Returns (an estimate of) the total number of bytes used to represent the
-  // unknown field.
-  int SpaceUsed() const;
+  inline void set_varint(uint64 value);
+  inline void set_fixed32(uint32 value);
+  inline void set_fixed64(uint64 value);
+  inline void set_length_delimited(const string& value);
+  inline string* mutable_length_delimited();
+  inline UnknownFieldSet* mutable_group();
 
  private:
   friend class UnknownFieldSet;
-  UnknownField(int number);
 
-  int number_;
-  int index_;
+  // If this UnknownField contains a pointer, delete it.
+  void Delete();
 
-  RepeatedField   <uint64         > varint_;
-  RepeatedField   <uint32         > fixed32_;
-  RepeatedField   <uint64         > fixed64_;
-  RepeatedPtrField<string         > length_delimited_;
-  RepeatedPtrField<UnknownFieldSet> group_;
+  // Make a deep copy of any pointers in this UnknownField.
+  void DeepCopy();
 
-  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(UnknownField);
+  unsigned int number_ : 29;
+  unsigned int type_   : 3;
+  union {
+    uint64 varint_;
+    uint32 fixed32_;
+    uint64 fixed64_;
+    string* length_delimited_;
+    UnknownFieldSet* group_;
+  };
 };
 
 // ===================================================================
 // inline implementations
 
 inline bool UnknownFieldSet::empty() const {
-  return internal_ == NULL || internal_->active_fields_.empty();
+  return fields_ == NULL || fields_->empty();
 }
 
 inline void UnknownFieldSet::Swap(UnknownFieldSet* x) {
-  std::swap(internal_, x->internal_);
+  std::swap(fields_, x->fields_);
 }
 
 inline int UnknownFieldSet::field_count() const {
-  return (internal_ == NULL) ? 0 : internal_->active_fields_.size();
+  return (fields_ == NULL) ? 0 : fields_->size();
 }
 inline const UnknownField& UnknownFieldSet::field(int index) const {
-  return *(internal_->active_fields_[index]);
+  return (*fields_)[index];
 }
 inline UnknownField* UnknownFieldSet::mutable_field(int index) {
-  return internal_->active_fields_[index];
+  return &(*fields_)[index];
 }
 
-inline bool UnknownField::empty() const {
-  return varint_.size() == 0 &&
-         fixed32_.size() == 0 &&
-         fixed64_.size() == 0 &&
-         length_delimited_.size() == 0 &&
-         group_.size() == 0;
+inline void UnknownFieldSet::AddLengthDelimited(
+    int number, const string& value) {
+  AddLengthDelimited(number)->assign(value);
 }
 
 inline int UnknownField::number() const { return number_; }
-inline int UnknownField::index () const { return index_; }
-
-inline int UnknownField::varint_size          () const {return varint_.size();}
-inline int UnknownField::fixed32_size         () const {return fixed32_.size();}
-inline int UnknownField::fixed64_size         () const {return fixed64_.size();}
-inline int UnknownField::length_delimited_size() const {
-  return length_delimited_.size();
+inline UnknownField::Type UnknownField::type() const {
+  return static_cast<Type>(type_);
 }
-inline int UnknownField::group_size           () const {return group_.size();}
 
-inline uint64 UnknownField::varint (int index) const {
-  return varint_.Get(index);
-}
-inline uint32 UnknownField::fixed32(int index) const {
-  return fixed32_.Get(index);
-}
-inline uint64 UnknownField::fixed64(int index) const {
-  return fixed64_.Get(index);
-}
-inline const string& UnknownField::length_delimited(int index) const {
-  return length_delimited_.Get(index);
-}
-inline const UnknownFieldSet& UnknownField::group(int index) const {
-  return group_.Get(index);
-}
-
-inline void UnknownField::set_varint (int index, uint64 value) {
-  varint_.Set(index, value);
-}
-inline void UnknownField::set_fixed32(int index, uint32 value) {
-  fixed32_.Set(index, value);
-}
-inline void UnknownField::set_fixed64(int index, uint64 value) {
-  fixed64_.Set(index, value);
-}
-inline void UnknownField::set_length_delimited(int index, const string& value) {
-  length_delimited_.Mutable(index)->assign(value);
-}
-inline string* UnknownField::mutable_length_delimited(int index) {
-  return length_delimited_.Mutable(index);
-}
-inline UnknownFieldSet* UnknownField::mutable_group(int index) {
-  return group_.Mutable(index);
-}
-
-inline void UnknownField::add_varint (uint64 value) {
-  varint_.Add(value);
-}
-inline void UnknownField::add_fixed32(uint32 value) {
-  fixed32_.Add(value);
-}
-inline void UnknownField::add_fixed64(uint64 value) {
-  fixed64_.Add(value);
-}
-inline void UnknownField::add_length_delimited(const string& value) {
-  length_delimited_.Add()->assign(value);
-}
-inline string* UnknownField::add_length_delimited() {
-  return length_delimited_.Add();
-}
-inline UnknownFieldSet* UnknownField::add_group() {
-  return group_.Add();
-}
-
-inline void UnknownField::clear_varint () { varint_.Clear(); }
-inline void UnknownField::clear_fixed32() { fixed32_.Clear(); }
-inline void UnknownField::clear_fixed64() { fixed64_.Clear(); }
-inline void UnknownField::clear_length_delimited() {
-  length_delimited_.Clear();
-}
-inline void UnknownField::clear_group() { group_.Clear(); }
-
-inline const RepeatedField<uint64>& UnknownField::varint () const {
+inline uint64 UnknownField::varint () const {
+  GOOGLE_DCHECK_EQ(type_, TYPE_VARINT);
   return varint_;
 }
-inline const RepeatedField<uint32>& UnknownField::fixed32() const {
+inline uint32 UnknownField::fixed32() const {
+  GOOGLE_DCHECK_EQ(type_, TYPE_FIXED32);
   return fixed32_;
 }
-inline const RepeatedField<uint64>& UnknownField::fixed64() const {
+inline uint64 UnknownField::fixed64() const {
+  GOOGLE_DCHECK_EQ(type_, TYPE_FIXED64);
   return fixed64_;
 }
-inline const RepeatedPtrField<string>& UnknownField::length_delimited() const {
-  return length_delimited_;
+inline const string& UnknownField::length_delimited() const {
+  GOOGLE_DCHECK_EQ(type_, TYPE_LENGTH_DELIMITED);
+  return *length_delimited_;
 }
-inline const RepeatedPtrField<UnknownFieldSet>& UnknownField::group() const {
-  return group_;
+inline const UnknownFieldSet& UnknownField::group() const {
+  GOOGLE_DCHECK_EQ(type_, TYPE_GROUP);
+  return *group_;
 }
 
-inline RepeatedField<uint64>* UnknownField::mutable_varint () {
-  return &varint_;
+inline void UnknownField::set_varint(uint64 value) {
+  GOOGLE_DCHECK_EQ(type_, TYPE_VARINT);
+  varint_ = value;
 }
-inline RepeatedField<uint32>* UnknownField::mutable_fixed32() {
-  return &fixed32_;
+inline void UnknownField::set_fixed32(uint32 value) {
+  GOOGLE_DCHECK_EQ(type_, TYPE_FIXED32);
+  fixed32_ = value;
 }
-inline RepeatedField<uint64>* UnknownField::mutable_fixed64() {
-  return &fixed64_;
+inline void UnknownField::set_fixed64(uint64 value) {
+  GOOGLE_DCHECK_EQ(type_, TYPE_FIXED64);
+  fixed64_ = value;
 }
-inline RepeatedPtrField<string>* UnknownField::mutable_length_delimited() {
-  return &length_delimited_;
+inline void UnknownField::set_length_delimited(const string& value) {
+  GOOGLE_DCHECK_EQ(type_, TYPE_LENGTH_DELIMITED);
+  length_delimited_->assign(value);
 }
-inline RepeatedPtrField<UnknownFieldSet>* UnknownField::mutable_group() {
-  return &group_;
+inline string* UnknownField::mutable_length_delimited() {
+  GOOGLE_DCHECK_EQ(type_, TYPE_LENGTH_DELIMITED);
+  return length_delimited_;
+}
+inline UnknownFieldSet* UnknownField::mutable_group() {
+  GOOGLE_DCHECK_EQ(type_, TYPE_GROUP);
+  return group_;
 }
 
 }  // namespace protobuf

+ 133 - 228
src/google/protobuf/unknown_field_set_unittest.cc

@@ -37,8 +37,8 @@
 
 #include <google/protobuf/unknown_field_set.h>
 #include <google/protobuf/descriptor.h>
-#include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/wire_format.h>
 #include <google/protobuf/unittest.pb.h>
 #include <google/protobuf/test_util.h>
@@ -46,6 +46,7 @@
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/testing/googletest.h>
 #include <gtest/gtest.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
 namespace protobuf {
@@ -67,7 +68,12 @@ class UnknownFieldSetTest : public testing::Test {
   const UnknownField* GetField(const string& name) {
     const FieldDescriptor* field = descriptor_->FindFieldByName(name);
     if (field == NULL) return NULL;
-    return unknown_fields_->FindFieldByNumber(field->number());
+    for (int i = 0; i < unknown_fields_->field_count(); i++) {
+      if (unknown_fields_->field(i).number() == field->number()) {
+        return &unknown_fields_->field(i);
+      }
+    }
+    return NULL;
   }
 
   // Constructs a protocol buffer which contains fields with all the same
@@ -79,12 +85,10 @@ class UnknownFieldSetTest : public testing::Test {
       bizarro_message.mutable_unknown_fields();
     for (int i = 0; i < unknown_fields_->field_count(); i++) {
       const UnknownField& unknown_field = unknown_fields_->field(i);
-      UnknownField* bizarro_field =
-        bizarro_unknown_fields->AddField(unknown_field.number());
-      if (unknown_field.varint_size() == 0) {
-        bizarro_field->add_varint(1);
+      if (unknown_field.type() == UnknownField::TYPE_VARINT) {
+        bizarro_unknown_fields->AddFixed32(unknown_field.number(), 1);
       } else {
-        bizarro_field->add_fixed32(1);
+        bizarro_unknown_fields->AddVarint(unknown_field.number(), 1);
       }
     }
 
@@ -103,71 +107,98 @@ class UnknownFieldSetTest : public testing::Test {
   UnknownFieldSet* unknown_fields_;
 };
 
-TEST_F(UnknownFieldSetTest, Index) {
-  for (int i = 0; i < unknown_fields_->field_count(); i++) {
-    EXPECT_EQ(i, unknown_fields_->field(i).index());
-  }
-}
+TEST_F(UnknownFieldSetTest, AllFieldsPresent) {
+  // All fields of TestAllTypes should be present, in numeric order (because
+  // that's the order we parsed them in).  Fields that are not valid field
+  // numbers of TestAllTypes should NOT be present.
 
-TEST_F(UnknownFieldSetTest, FindFieldByNumber) {
-  // All fields of TestAllTypes should be present.  Fields that are not valid
-  // field numbers of TestAllTypes should NOT be present.
+  int pos = 0;
 
   for (int i = 0; i < 1000; i++) {
-    if (descriptor_->FindFieldByNumber(i) == NULL) {
-      EXPECT_TRUE(unknown_fields_->FindFieldByNumber(i) == NULL);
-    } else {
-      EXPECT_TRUE(unknown_fields_->FindFieldByNumber(i) != NULL);
+    const FieldDescriptor* field = descriptor_->FindFieldByNumber(i);
+    if (field != NULL) {
+      ASSERT_LT(pos, unknown_fields_->field_count());
+      EXPECT_EQ(i, unknown_fields_->field(pos++).number());
+      if (field->is_repeated()) {
+        // Should have a second instance.
+        ASSERT_LT(pos, unknown_fields_->field_count());
+        EXPECT_EQ(i, unknown_fields_->field(pos++).number());
+      }
     }
   }
+  EXPECT_EQ(unknown_fields_->field_count(), pos);
 }
 
 TEST_F(UnknownFieldSetTest, Varint) {
   const UnknownField* field = GetField("optional_int32");
   ASSERT_TRUE(field != NULL);
 
-  ASSERT_EQ(1, field->varint_size());
-  EXPECT_EQ(all_fields_.optional_int32(), field->varint(0));
+  ASSERT_EQ(UnknownField::TYPE_VARINT, field->type());
+  EXPECT_EQ(all_fields_.optional_int32(), field->varint());
 }
 
 TEST_F(UnknownFieldSetTest, Fixed32) {
   const UnknownField* field = GetField("optional_fixed32");
   ASSERT_TRUE(field != NULL);
 
-  ASSERT_EQ(1, field->fixed32_size());
-  EXPECT_EQ(all_fields_.optional_fixed32(), field->fixed32(0));
+  ASSERT_EQ(UnknownField::TYPE_FIXED32, field->type());
+  EXPECT_EQ(all_fields_.optional_fixed32(), field->fixed32());
 }
 
 TEST_F(UnknownFieldSetTest, Fixed64) {
   const UnknownField* field = GetField("optional_fixed64");
   ASSERT_TRUE(field != NULL);
 
-  ASSERT_EQ(1, field->fixed64_size());
-  EXPECT_EQ(all_fields_.optional_fixed64(), field->fixed64(0));
+  ASSERT_EQ(UnknownField::TYPE_FIXED64, field->type());
+  EXPECT_EQ(all_fields_.optional_fixed64(), field->fixed64());
 }
 
 TEST_F(UnknownFieldSetTest, LengthDelimited) {
   const UnknownField* field = GetField("optional_string");
   ASSERT_TRUE(field != NULL);
 
-  ASSERT_EQ(1, field->length_delimited_size());
-  EXPECT_EQ(all_fields_.optional_string(), field->length_delimited(0));
+  ASSERT_EQ(UnknownField::TYPE_LENGTH_DELIMITED, field->type());
+  EXPECT_EQ(all_fields_.optional_string(), field->length_delimited());
 }
 
 TEST_F(UnknownFieldSetTest, Group) {
   const UnknownField* field = GetField("optionalgroup");
   ASSERT_TRUE(field != NULL);
 
-  ASSERT_EQ(1, field->group_size());
-  EXPECT_EQ(1, field->group(0).field_count());
+  ASSERT_EQ(UnknownField::TYPE_GROUP, field->type());
+  ASSERT_EQ(1, field->group().field_count());
 
-  const UnknownField& nested_field = field->group(0).field(0);
+  const UnknownField& nested_field = field->group().field(0);
   const FieldDescriptor* nested_field_descriptor =
     unittest::TestAllTypes::OptionalGroup::descriptor()->FindFieldByName("a");
   ASSERT_TRUE(nested_field_descriptor != NULL);
 
   EXPECT_EQ(nested_field_descriptor->number(), nested_field.number());
-  EXPECT_EQ(all_fields_.optionalgroup().a(), nested_field.varint(0));
+  ASSERT_EQ(UnknownField::TYPE_VARINT, nested_field.type());
+  EXPECT_EQ(all_fields_.optionalgroup().a(), nested_field.varint());
+}
+
+TEST_F(UnknownFieldSetTest, SerializeFastAndSlowAreEquivalent) {
+  int size = WireFormat::ComputeUnknownFieldsSize(
+      empty_message_.unknown_fields());
+  string slow_buffer;
+  string fast_buffer;
+  slow_buffer.resize(size);
+  fast_buffer.resize(size);
+
+  uint8* target = reinterpret_cast<uint8*>(string_as_array(&fast_buffer));
+  uint8* result = WireFormat::SerializeUnknownFieldsToArray(
+          empty_message_.unknown_fields(), target);
+  EXPECT_EQ(size, result - target);
+
+  {
+    io::ArrayOutputStream raw_stream(string_as_array(&slow_buffer), size, 1);
+    io::CodedOutputStream output_stream(&raw_stream);
+    WireFormat::SerializeUnknownFields(empty_message_.unknown_fields(),
+                                       &output_stream);
+    ASSERT_FALSE(output_stream.HadError());
+  }
+  EXPECT_TRUE(fast_buffer == slow_buffer);
 }
 
 TEST_F(UnknownFieldSetTest, Serialize) {
@@ -205,8 +236,8 @@ TEST_F(UnknownFieldSetTest, SerializeViaReflection) {
     io::StringOutputStream raw_output(&data);
     io::CodedOutputStream output(&raw_output);
     int size = WireFormat::ByteSize(empty_message_);
-    ASSERT_TRUE(
-      WireFormat::SerializeWithCachedSizes(empty_message_, size, &output));
+    WireFormat::SerializeWithCachedSizes(empty_message_, size, &output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Don't use EXPECT_EQ because we don't want to dump raw binary data to
@@ -249,10 +280,10 @@ TEST_F(UnknownFieldSetTest, SwapWithSelf) {
 TEST_F(UnknownFieldSetTest, MergeFrom) {
   unittest::TestEmptyMessage source, destination;
 
-  destination.mutable_unknown_fields()->AddField(1)->add_varint(1);
-  destination.mutable_unknown_fields()->AddField(3)->add_varint(2);
-  source.mutable_unknown_fields()->AddField(2)->add_varint(3);
-  source.mutable_unknown_fields()->AddField(3)->add_varint(4);
+  destination.mutable_unknown_fields()->AddVarint(1, 1);
+  destination.mutable_unknown_fields()->AddVarint(3, 2);
+  source.mutable_unknown_fields()->AddVarint(2, 3);
+  source.mutable_unknown_fields()->AddVarint(3, 4);
 
   destination.MergeFrom(source);
 
@@ -261,34 +292,22 @@ TEST_F(UnknownFieldSetTest, MergeFrom) {
     //   and merging, above.
     "1: 1\n"
     "3: 2\n"
-    "3: 4\n"
-    "2: 3\n",
+    "2: 3\n"
+    "3: 4\n",
     destination.DebugString());
 }
 
 TEST_F(UnknownFieldSetTest, Clear) {
-  // Get a pointer to a contained field object.
-  const UnknownField* field = GetField("optional_int32");
-  ASSERT_TRUE(field != NULL);
-  ASSERT_EQ(1, field->varint_size());
-  int number = field->number();
-
   // Clear the set.
   empty_message_.Clear();
   EXPECT_EQ(0, unknown_fields_->field_count());
-
-  // If we add that field again we should get the same object.
-  ASSERT_EQ(field, unknown_fields_->AddField(number));
-
-  // But it should be cleared.
-  EXPECT_EQ(0, field->varint_size());
 }
 
 TEST_F(UnknownFieldSetTest, ParseKnownAndUnknown) {
   // Test mixing known and unknown fields when parsing.
 
   unittest::TestEmptyMessage source;
-  source.mutable_unknown_fields()->AddField(123456)->add_varint(654321);
+  source.mutable_unknown_fields()->AddVarint(123456, 654321);
   string data;
   ASSERT_TRUE(source.SerializeToString(&data));
 
@@ -297,8 +316,9 @@ TEST_F(UnknownFieldSetTest, ParseKnownAndUnknown) {
 
   TestUtil::ExpectAllFieldsSet(destination);
   ASSERT_EQ(1, destination.unknown_fields().field_count());
-  ASSERT_EQ(1, destination.unknown_fields().field(0).varint_size());
-  EXPECT_EQ(654321, destination.unknown_fields().field(0).varint(0));
+  ASSERT_EQ(UnknownField::TYPE_VARINT,
+            destination.unknown_fields().field(0).type());
+  EXPECT_EQ(654321, destination.unknown_fields().field(0).varint());
 }
 
 TEST_F(UnknownFieldSetTest, WrongTypeTreatedAsUnknown) {
@@ -384,16 +404,12 @@ TEST_F(UnknownFieldSetTest, UnknownEnumValue) {
   {
     TestEmptyMessage empty_message;
     UnknownFieldSet* unknown_fields = empty_message.mutable_unknown_fields();
-    UnknownField* singular_unknown_field =
-      unknown_fields->AddField(singular_field->number());
-    singular_unknown_field->add_varint(TestAllTypes::BAR);
-    singular_unknown_field->add_varint(5);  // not valid
-    UnknownField* repeated_unknown_field =
-      unknown_fields->AddField(repeated_field->number());
-    repeated_unknown_field->add_varint(TestAllTypes::FOO);
-    repeated_unknown_field->add_varint(4);  // not valid
-    repeated_unknown_field->add_varint(TestAllTypes::BAZ);
-    repeated_unknown_field->add_varint(6);  // not valid
+    unknown_fields->AddVarint(singular_field->number(), TestAllTypes::BAR);
+    unknown_fields->AddVarint(singular_field->number(), 5);  // not valid
+    unknown_fields->AddVarint(repeated_field->number(), TestAllTypes::FOO);
+    unknown_fields->AddVarint(repeated_field->number(), 4);  // not valid
+    unknown_fields->AddVarint(repeated_field->number(), TestAllTypes::BAZ);
+    unknown_fields->AddVarint(repeated_field->number(), 6);  // not valid
     empty_message.SerializeToString(&data);
   }
 
@@ -406,18 +422,19 @@ TEST_F(UnknownFieldSetTest, UnknownEnumValue) {
     EXPECT_EQ(TestAllTypes::BAZ, message.repeated_nested_enum(1));
 
     const UnknownFieldSet& unknown_fields = message.unknown_fields();
-    ASSERT_EQ(2, unknown_fields.field_count());
-
-    const UnknownField& singular_unknown_field = unknown_fields.field(0);
-    ASSERT_EQ(singular_field->number(), singular_unknown_field.number());
-    ASSERT_EQ(1, singular_unknown_field.varint_size());
-    EXPECT_EQ(5, singular_unknown_field.varint(0));
-
-    const UnknownField& repeated_unknown_field = unknown_fields.field(1);
-    ASSERT_EQ(repeated_field->number(), repeated_unknown_field.number());
-    ASSERT_EQ(2, repeated_unknown_field.varint_size());
-    EXPECT_EQ(4, repeated_unknown_field.varint(0));
-    EXPECT_EQ(6, repeated_unknown_field.varint(1));
+    ASSERT_EQ(3, unknown_fields.field_count());
+
+    EXPECT_EQ(singular_field->number(), unknown_fields.field(0).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(0).type());
+    EXPECT_EQ(5, unknown_fields.field(0).varint());
+
+    EXPECT_EQ(repeated_field->number(), unknown_fields.field(1).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(1).type());
+    EXPECT_EQ(4, unknown_fields.field(1).varint());
+
+    EXPECT_EQ(repeated_field->number(), unknown_fields.field(2).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(2).type());
+    EXPECT_EQ(6, unknown_fields.field(2).varint());
   }
 
   {
@@ -435,173 +452,61 @@ TEST_F(UnknownFieldSetTest, UnknownEnumValue) {
               message.GetExtension(repeated_nested_enum_extension, 1));
 
     const UnknownFieldSet& unknown_fields = message.unknown_fields();
-    ASSERT_EQ(2, unknown_fields.field_count());
-
-    const UnknownField& singular_unknown_field = unknown_fields.field(0);
-    ASSERT_EQ(singular_field->number(), singular_unknown_field.number());
-    ASSERT_EQ(1, singular_unknown_field.varint_size());
-    EXPECT_EQ(5, singular_unknown_field.varint(0));
-
-    const UnknownField& repeated_unknown_field = unknown_fields.field(1);
-    ASSERT_EQ(repeated_field->number(), repeated_unknown_field.number());
-    ASSERT_EQ(2, repeated_unknown_field.varint_size());
-    EXPECT_EQ(4, repeated_unknown_field.varint(0));
-    EXPECT_EQ(6, repeated_unknown_field.varint(1));
-  }
-}
+    ASSERT_EQ(3, unknown_fields.field_count());
 
-TEST_F(UnknownFieldSetTest, SpaceUsedExcludingSelf) {
-  {
-    // Make sure an unknown field set has zero space used until a field is
-    // actually added.
-    unittest::TestEmptyMessage empty_message;
-    const int empty_message_size = empty_message.SpaceUsed();
-    UnknownFieldSet* unknown_fields = empty_message.mutable_unknown_fields();
-    EXPECT_EQ(empty_message_size, empty_message.SpaceUsed());
-    unknown_fields->AddField(1)->add_varint(0);
-    EXPECT_LT(empty_message_size, empty_message.SpaceUsed());
-  }
-  {
-    // Test varints.
-    UnknownFieldSet unknown_fields;
-    UnknownField* field = unknown_fields.AddField(1);
-    const int base_size = unknown_fields.SpaceUsedExcludingSelf();
-    for (int i = 0; i < 16; ++i) {
-      field->add_varint(i);
-    }
-    // Should just defer computation to the RepeatedField.
-    int expected_size = base_size + field->varint().SpaceUsedExcludingSelf();
-    EXPECT_EQ(expected_size, unknown_fields.SpaceUsedExcludingSelf());
-  }
-  {
-    // Test fixed32s.
-    UnknownFieldSet unknown_fields;
-    UnknownField* field = unknown_fields.AddField(1);
-    const int base_size = unknown_fields.SpaceUsedExcludingSelf();
-    for (int i = 0; i < 16; ++i) {
-      field->add_fixed32(i);
-    }
-    int expected_size = base_size + field->fixed32().SpaceUsedExcludingSelf();
-    EXPECT_EQ(expected_size, unknown_fields.SpaceUsedExcludingSelf());
-  }
-  {
-    // Test fixed64s.
-    UnknownFieldSet unknown_fields;
-    UnknownField* field = unknown_fields.AddField(1);
-    const int base_size = unknown_fields.SpaceUsedExcludingSelf();
-    for (int i = 0; i < 16; ++i) {
-      field->add_fixed64(i);
-    }
-    int expected_size = base_size + field->fixed64().SpaceUsedExcludingSelf();
-    EXPECT_EQ(expected_size, unknown_fields.SpaceUsedExcludingSelf());
-  }
-  {
-    // Test length-delimited types.
-    UnknownFieldSet unknown_fields;
-    UnknownField* field = unknown_fields.AddField(1);
-    const int base_size = unknown_fields.SpaceUsedExcludingSelf();
-    for (int i = 0; i < 16; ++i) {
-      field->add_length_delimited()->assign("my length delimited string");
-    }
-    int expected_size = base_size +
-        field->length_delimited().SpaceUsedExcludingSelf();
-    EXPECT_EQ(expected_size, unknown_fields.SpaceUsedExcludingSelf());
+    EXPECT_EQ(singular_field->number(), unknown_fields.field(0).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(0).type());
+    EXPECT_EQ(5, unknown_fields.field(0).varint());
+
+    EXPECT_EQ(repeated_field->number(), unknown_fields.field(1).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(1).type());
+    EXPECT_EQ(4, unknown_fields.field(1).varint());
+
+    EXPECT_EQ(repeated_field->number(), unknown_fields.field(2).number());
+    ASSERT_EQ(UnknownField::TYPE_VARINT, unknown_fields.field(2).type());
+    EXPECT_EQ(6, unknown_fields.field(2).varint());
   }
 }
 
 TEST_F(UnknownFieldSetTest, SpaceUsed) {
-  UnknownFieldSet unknown_fields;
-  const int expected_size = sizeof(unknown_fields) +
-      unknown_fields.SpaceUsedExcludingSelf();
-  EXPECT_EQ(expected_size, unknown_fields.SpaceUsed());
+  unittest::TestEmptyMessage empty_message;
+
+  // Obtaining the mutable UnknownFieldSet must not change the message's
+  // SpaceUsed(); the set should cost nothing until a field is actually added.
+  int base_size = empty_message.SpaceUsed();
+  UnknownFieldSet* unknown_fields = empty_message.mutable_unknown_fields();
+  EXPECT_EQ(base_size, empty_message.SpaceUsed());
+
+  // Make sure each thing we add to the set increases the SpaceUsed().
+  unknown_fields->AddVarint(1, 0);
+  EXPECT_LT(base_size, empty_message.SpaceUsed());
+  base_size = empty_message.SpaceUsed();
+
+  string* str = unknown_fields->AddLengthDelimited(1);
+  EXPECT_LT(base_size, empty_message.SpaceUsed());
+  base_size = empty_message.SpaceUsed();
+
+  str->assign(sizeof(string) + 1, 'x');
+  EXPECT_LT(base_size, empty_message.SpaceUsed());
+  base_size = empty_message.SpaceUsed();
+
+  UnknownFieldSet* group = unknown_fields->AddGroup(1);
+  EXPECT_LT(base_size, empty_message.SpaceUsed());
+  base_size = empty_message.SpaceUsed();
+
+  group->AddVarint(1, 0);
+  EXPECT_LT(base_size, empty_message.SpaceUsed());
 }
 
 TEST_F(UnknownFieldSetTest, Empty) {
   UnknownFieldSet unknown_fields;
   EXPECT_TRUE(unknown_fields.empty());
-  unknown_fields.AddField(6)->add_varint(123);
+  unknown_fields.AddVarint(6, 123);
   EXPECT_FALSE(unknown_fields.empty());
   unknown_fields.Clear();
   EXPECT_TRUE(unknown_fields.empty());
 }
 
-TEST_F(UnknownFieldSetTest, FieldEmpty) {
-  UnknownFieldSet unknown_fields;
-  UnknownField* field = unknown_fields.AddField(1);
-
-  EXPECT_TRUE(field->empty());
-
-  field->add_varint(1);
-  EXPECT_FALSE(field->empty());
-  field->Clear();
-  EXPECT_TRUE(field->empty());
-
-  field->add_fixed32(1);
-  EXPECT_FALSE(field->empty());
-  field->Clear();
-  EXPECT_TRUE(field->empty());
-
-  field->add_fixed64(1);
-  EXPECT_FALSE(field->empty());
-  field->Clear();
-  EXPECT_TRUE(field->empty());
-
-  field->add_length_delimited("foo");
-  EXPECT_FALSE(field->empty());
-  field->Clear();
-  EXPECT_TRUE(field->empty());
-
-  field->add_group();
-  EXPECT_FALSE(field->empty());
-  field->Clear();
-  EXPECT_TRUE(field->empty());
-}
-
-TEST_F(UnknownFieldSetTest, Iterator) {
-  UnknownFieldSet unknown_fields;
-  EXPECT_TRUE(unknown_fields.begin() == unknown_fields.end());
-
-  // Populate the UnknownFieldSet with some inactive fields by adding some
-  // fields and then clearing.
-  unknown_fields.AddField(6);
-  unknown_fields.AddField(4);
-  unknown_fields.Clear();
-
-  // Add a bunch of "active" fields.
-  UnknownField* a = unknown_fields.AddField(5);
-  unknown_fields.AddField(3);
-  unknown_fields.AddField(9);
-  unknown_fields.AddField(1);
-  UnknownField* b = unknown_fields.AddField(2);
-
-  // Only make some of them non-empty.
-  a->add_varint(1);
-  b->add_length_delimited("foo");
-
-  // Iterate!
-  {
-    UnknownFieldSet::iterator iter = unknown_fields.begin();
-    ASSERT_TRUE(iter != unknown_fields.end());
-    EXPECT_EQ(b, &*iter);
-    ++iter;
-    ASSERT_TRUE(iter != unknown_fields.end());
-    EXPECT_EQ(a, &*iter);
-    ++iter;
-    EXPECT_TRUE(iter == unknown_fields.end());
-  }
-
-  {
-    UnknownFieldSet::const_iterator iter = unknown_fields.begin();
-    ASSERT_TRUE(iter != unknown_fields.end());
-    EXPECT_EQ(b, &*iter);
-    ++iter;
-    ASSERT_TRUE(iter != unknown_fields.end());
-    EXPECT_EQ(a, &*iter);
-    ++iter;
-    EXPECT_TRUE(iter == unknown_fields.end());
-  }
-}
-
 }  // namespace
 }  // namespace protobuf
 }  // namespace google

+ 169 - 128
src/google/protobuf/wire_format.cc

@@ -109,29 +109,29 @@ WireFormat::kWireTypeForFieldType[FieldDescriptor::MAX_TYPE + 1] = {
 
 bool WireFormat::SkipField(io::CodedInputStream* input, uint32 tag,
                            UnknownFieldSet* unknown_fields) {
-  UnknownField* field = (unknown_fields == NULL) ? NULL :
-    unknown_fields->AddField(GetTagFieldNumber(tag));
+  int number = GetTagFieldNumber(tag);
 
   switch (GetTagWireType(tag)) {
     case WIRETYPE_VARINT: {
       uint64 value;
       if (!input->ReadVarint64(&value)) return false;
-      if (field != NULL) field->add_varint(value);
+      if (unknown_fields != NULL) unknown_fields->AddVarint(number, value);
       return true;
     }
     case WIRETYPE_FIXED64: {
       uint64 value;
       if (!input->ReadLittleEndian64(&value)) return false;
-      if (field != NULL) field->add_fixed64(value);
+      if (unknown_fields != NULL) unknown_fields->AddFixed64(number, value);
       return true;
     }
     case WIRETYPE_LENGTH_DELIMITED: {
       uint32 length;
       if (!input->ReadVarint32(&length)) return false;
-      if (field == NULL) {
+      if (unknown_fields == NULL) {
         if (!input->Skip(length)) return false;
       } else {
-        if (!input->ReadString(field->add_length_delimited(), length)) {
+        if (!input->ReadString(unknown_fields->AddLengthDelimited(number),
+                               length)) {
           return false;
         }
       }
@@ -139,7 +139,8 @@ bool WireFormat::SkipField(io::CodedInputStream* input, uint32 tag,
     }
     case WIRETYPE_START_GROUP: {
       if (!input->IncrementRecursionDepth()) return false;
-      if (!SkipMessage(input, (field == NULL) ? NULL : field->add_group())) {
+      if (!SkipMessage(input, (unknown_fields == NULL) ?
+                              NULL : unknown_fields->AddGroup(number))) {
         return false;
       }
       input->DecrementRecursionDepth();
@@ -156,7 +157,7 @@ bool WireFormat::SkipField(io::CodedInputStream* input, uint32 tag,
     case WIRETYPE_FIXED32: {
       uint32 value;
       if (!input->ReadLittleEndian32(&value)) return false;
-      if (field != NULL) field->add_fixed32(value);
+      if (unknown_fields != NULL) unknown_fields->AddFixed32(number, value);
       return true;
     }
     default: {
@@ -185,72 +186,130 @@ bool WireFormat::SkipMessage(io::CodedInputStream* input,
   }
 }
 
-bool WireFormat::SerializeUnknownFields(const UnknownFieldSet& unknown_fields,
+void WireFormat::SerializeUnknownFields(const UnknownFieldSet& unknown_fields,
                                         io::CodedOutputStream* output) {
   for (int i = 0; i < unknown_fields.field_count(); i++) {
     const UnknownField& field = unknown_fields.field(i);
-
-#define DO(EXPRESSION) if (!(EXPRESSION)) return false
-    for (int j = 0; j < field.varint_size(); j++) {
-      DO(output->WriteVarint32(MakeTag(field.number(), WIRETYPE_VARINT)));
-      DO(output->WriteVarint64(field.varint(j)));
-    }
-    for (int j = 0; j < field.fixed32_size(); j++) {
-      DO(output->WriteVarint32(MakeTag(field.number(), WIRETYPE_FIXED32)));
-      DO(output->WriteLittleEndian32(field.fixed32(j)));
-    }
-    for (int j = 0; j < field.fixed64_size(); j++) {
-      DO(output->WriteVarint32(MakeTag(field.number(), WIRETYPE_FIXED64)));
-      DO(output->WriteLittleEndian64(field.fixed64(j)));
-    }
-    for (int j = 0; j < field.length_delimited_size(); j++) {
-      DO(output->WriteVarint32(
-        MakeTag(field.number(), WIRETYPE_LENGTH_DELIMITED)));
-      DO(output->WriteVarint32(field.length_delimited(j).size()));
-      DO(output->WriteString(field.length_delimited(j)));
-    }
-    for (int j = 0; j < field.group_size(); j++) {
-      DO(output->WriteVarint32(MakeTag(field.number(), WIRETYPE_START_GROUP)));
-      DO(SerializeUnknownFields(field.group(j), output));
-      DO(output->WriteVarint32(MakeTag(field.number(), WIRETYPE_END_GROUP)));
+    switch (field.type()) {
+      case UnknownField::TYPE_VARINT:
+        output->WriteVarint32(MakeTag(field.number(), WIRETYPE_VARINT));
+        output->WriteVarint64(field.varint());
+        break;
+      case UnknownField::TYPE_FIXED32:
+        output->WriteVarint32(MakeTag(field.number(), WIRETYPE_FIXED32));
+        output->WriteLittleEndian32(field.fixed32());
+        break;
+      case UnknownField::TYPE_FIXED64:
+        output->WriteVarint32(MakeTag(field.number(), WIRETYPE_FIXED64));
+        output->WriteLittleEndian64(field.fixed64());
+        break;
+      case UnknownField::TYPE_LENGTH_DELIMITED:
+        output->WriteVarint32(
+            MakeTag(field.number(), WIRETYPE_LENGTH_DELIMITED));
+        output->WriteVarint32(field.length_delimited().size());
+        output->WriteString(field.length_delimited());
+        break;
+      case UnknownField::TYPE_GROUP:
+        output->WriteVarint32(MakeTag(field.number(),WIRETYPE_START_GROUP));
+        SerializeUnknownFields(field.group(), output);
+        output->WriteVarint32(MakeTag(field.number(), WIRETYPE_END_GROUP));
+        break;
     }
-#undef DO
   }
+}
 
-  return true;
+uint8* WireFormat::SerializeUnknownFieldsToArray(
+    const UnknownFieldSet& unknown_fields,
+    uint8* target) {
+  for (int i = 0; i < unknown_fields.field_count(); i++) {
+    const UnknownField& field = unknown_fields.field(i);
+
+    switch (field.type()) {
+      case UnknownField::TYPE_VARINT:
+        target = WriteInt64ToArray(field.number(), field.varint(), target);
+        break;
+      case UnknownField::TYPE_FIXED32:
+        target = WriteFixed32ToArray(field.number(), field.fixed32(), target);
+        break;
+      case UnknownField::TYPE_FIXED64:
+        target = WriteFixed64ToArray(field.number(), field.fixed64(), target);
+        break;
+      case UnknownField::TYPE_LENGTH_DELIMITED:
+        target =
+          WriteBytesToArray(field.number(), field.length_delimited(), target);
+        break;
+      case UnknownField::TYPE_GROUP:
+        target = WriteTagToArray(field.number(), WIRETYPE_START_GROUP, target);
+        target = SerializeUnknownFieldsToArray(field.group(), target);
+        target = WriteTagToArray(field.number(), WIRETYPE_END_GROUP, target);
+        break;
+    }
+  }
+  return target;
 }
 
-bool WireFormat::SerializeUnknownMessageSetItems(
+void WireFormat::SerializeUnknownMessageSetItems(
     const UnknownFieldSet& unknown_fields,
     io::CodedOutputStream* output) {
   for (int i = 0; i < unknown_fields.field_count(); i++) {
     const UnknownField& field = unknown_fields.field(i);
+    // The only unknown fields that are allowed to exist in a MessageSet are
+    // messages, which are length-delimited.
+    if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
+      const string& data = field.length_delimited();
+
+      // Start group.
+      output->WriteVarint32(kMessageSetItemStartTag);
+
+      // Write type ID.
+      output->WriteVarint32(kMessageSetTypeIdTag);
+      output->WriteVarint32(field.number());
+
+      // Write message.
+      output->WriteVarint32(kMessageSetMessageTag);
+      output->WriteVarint32(data.size());
+      output->WriteString(data);
+
+      // End group.
+      output->WriteVarint32(kMessageSetItemEndTag);
+    }
+  }
+}
+
+uint8* WireFormat::SerializeUnknownMessageSetItemsToArray(
+    const UnknownFieldSet& unknown_fields,
+    uint8* target) {
+  for (int i = 0; i < unknown_fields.field_count(); i++) {
+    const UnknownField& field = unknown_fields.field(i);
 
-#define DO(EXPRESSION) if (!(EXPRESSION)) return false
     // The only unknown fields that are allowed to exist in a MessageSet are
     // messages, which are length-delimited.
-    for (int j = 0; j < field.length_delimited_size(); j++) {
-      const string& data = field.length_delimited(j);
+    if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
+      const string& data = field.length_delimited();
 
       // Start group.
-      DO(output->WriteVarint32(kMessageSetItemStartTag));
+      target =
+        io::CodedOutputStream::WriteTagToArray(kMessageSetItemStartTag, target);
 
       // Write type ID.
-      DO(output->WriteVarint32(kMessageSetTypeIdTag));
-      DO(output->WriteVarint32(field.number()));
+      target =
+        io::CodedOutputStream::WriteTagToArray(kMessageSetTypeIdTag, target);
+      target =
+        io::CodedOutputStream::WriteVarint32ToArray(field.number(), target);
 
       // Write message.
-      DO(output->WriteVarint32(kMessageSetMessageTag));
-      DO(output->WriteVarint32(data.size()));
-      DO(output->WriteString(data));
+      target =
+        io::CodedOutputStream::WriteTagToArray(kMessageSetMessageTag, target);
+      target = io::CodedOutputStream::WriteVarint32ToArray(data.size(), target);
+      target = io::CodedOutputStream::WriteStringToArray(data, target);
 
       // End group.
-      DO(output->WriteVarint32(kMessageSetItemEndTag));
+      target =
+        io::CodedOutputStream::WriteTagToArray(kMessageSetItemEndTag, target);
     }
-#undef DO
   }
 
-  return true;
+  return target;
 }
 
 int WireFormat::ComputeUnknownFieldsSize(
@@ -259,34 +318,36 @@ int WireFormat::ComputeUnknownFieldsSize(
   for (int i = 0; i < unknown_fields.field_count(); i++) {
     const UnknownField& field = unknown_fields.field(i);
 
-    for (int j = 0; j < field.varint_size(); j++) {
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_VARINT));
-      size += io::CodedOutputStream::VarintSize64(field.varint(j));
-    }
-    for (int j = 0; j < field.fixed32_size(); j++) {
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_FIXED32));
-      size += sizeof(int32);
-    }
-    for (int j = 0; j < field.fixed64_size(); j++) {
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_FIXED64));
-      size += sizeof(int64);
-    }
-    for (int j = 0; j < field.length_delimited_size(); j++) {
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_LENGTH_DELIMITED));
-      size += io::CodedOutputStream::VarintSize32(
-        field.length_delimited(j).size());
-      size += field.length_delimited(j).size();
-    }
-    for (int j = 0; j < field.group_size(); j++) {
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_START_GROUP));
-      size += ComputeUnknownFieldsSize(field.group(j));
-      size += io::CodedOutputStream::VarintSize32(
-        MakeTag(field.number(), WIRETYPE_END_GROUP));
+    switch (field.type()) {
+      case UnknownField::TYPE_VARINT:
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_VARINT));
+        size += io::CodedOutputStream::VarintSize64(field.varint());
+        break;
+      case UnknownField::TYPE_FIXED32:
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_FIXED32));
+        size += sizeof(int32);
+        break;
+      case UnknownField::TYPE_FIXED64:
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_FIXED64));
+        size += sizeof(int64);
+        break;
+      case UnknownField::TYPE_LENGTH_DELIMITED:
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_LENGTH_DELIMITED));
+        size += io::CodedOutputStream::VarintSize32(
+          field.length_delimited().size());
+        size += field.length_delimited().size();
+        break;
+      case UnknownField::TYPE_GROUP:
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_START_GROUP));
+        size += ComputeUnknownFieldsSize(field.group());
+        size += io::CodedOutputStream::VarintSize32(
+          MakeTag(field.number(), WIRETYPE_END_GROUP));
+        break;
     }
   }
 
@@ -301,12 +362,12 @@ int WireFormat::ComputeUnknownMessageSetItemsSize(
 
     // The only unknown fields that are allowed to exist in a MessageSet are
     // messages, which are length-delimited.
-    for (int j = 0; j < field.length_delimited_size(); j++) {
+    if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
       size += kMessageSetItemTagsSize;
       size += io::CodedOutputStream::VarintSize32(field.number());
       size += io::CodedOutputStream::VarintSize32(
-        field.length_delimited(j).size());
-      size += field.length_delimited(j).size();
+        field.length_delimited().size());
+      size += field.length_delimited().size();
     }
   }
 
@@ -487,8 +548,8 @@ bool WireFormat::ParseAndMergeField(
           // UnknownFieldSet.
           int64 sign_extended_value = static_cast<int64>(value);
           message_reflection->MutableUnknownFields(message)
-                            ->AddField(GetTagFieldNumber(tag))
-                            ->add_varint(sign_extended_value);
+                            ->AddVarint(GetTagFieldNumber(tag),
+                                        sign_extended_value);
         }
         break;
       }
@@ -607,7 +668,7 @@ bool WireFormat::ParseAndMergeMessageSetItem(
 
 // ===================================================================
 
-bool WireFormat::SerializeWithCachedSizes(
+void WireFormat::SerializeWithCachedSizes(
     const Message& message,
     int size, io::CodedOutputStream* output) {
   const Descriptor* descriptor = message.GetDescriptor();
@@ -617,32 +678,24 @@ bool WireFormat::SerializeWithCachedSizes(
   vector<const FieldDescriptor*> fields;
   message_reflection->ListFields(message, &fields);
   for (int i = 0; i < fields.size(); i++) {
-    if (!SerializeFieldWithCachedSizes(fields[i], message, output)) {
-      return false;
-    }
+    SerializeFieldWithCachedSizes(fields[i], message, output);
   }
 
   if (descriptor->options().message_set_wire_format()) {
-    if (!SerializeUnknownMessageSetItems(
-           message_reflection->GetUnknownFields(message), output)) {
-      return false;
-    }
+    SerializeUnknownMessageSetItems(
+        message_reflection->GetUnknownFields(message), output);
   } else {
-    if (!SerializeUnknownFields(
-           message_reflection->GetUnknownFields(message), output)) {
-      return false;
-    }
+    SerializeUnknownFields(
+        message_reflection->GetUnknownFields(message), output);
   }
 
   GOOGLE_CHECK_EQ(output->ByteCount(), expected_endpoint)
     << ": Protocol message serialized to a size different from what was "
        "originally expected.  Perhaps it was modified by another thread "
        "during serialization?";
-
-  return true;
 }
 
-bool WireFormat::SerializeFieldWithCachedSizes(
+void WireFormat::SerializeFieldWithCachedSizes(
     const FieldDescriptor* field,
     const Message& message,
     io::CodedOutputStream* output) {
@@ -652,8 +705,8 @@ bool WireFormat::SerializeFieldWithCachedSizes(
       field->containing_type()->options().message_set_wire_format() &&
       field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
       !field->is_repeated()) {
-    return SerializeMessageSetItemWithCachedSizes(
-      field, message, output);
+    SerializeMessageSetItemWithCachedSizes(field, message, output);
+    return;
   }
 
   int count = 0;
@@ -666,10 +719,9 @@ bool WireFormat::SerializeFieldWithCachedSizes(
 
   const bool is_packed = field->options().packed();
   if (is_packed && count > 0) {
-    if (!WriteTag(field->number(), WIRETYPE_LENGTH_DELIMITED, output))
-      return false;
+    WriteTag(field->number(), WIRETYPE_LENGTH_DELIMITED, output);
     const int data_size = FieldDataOnlyByteSize(field, message);
-    if (!output->WriteVarint32(data_size)) return false;
+    output->WriteVarint32(data_size);
   }
 
   for (int j = 0; j < count; j++) {
@@ -682,13 +734,9 @@ bool WireFormat::SerializeFieldWithCachedSizes(
                               message_reflection->Get##CPPTYPE_METHOD(         \
                                 message, field);                               \
         if (is_packed) {                                                       \
-          if (!Write##TYPE_METHOD##NoTag(value, output)) {                     \
-            return false;                                                      \
-          }                                                                    \
+          Write##TYPE_METHOD##NoTag(value, output);                            \
         } else {                                                               \
-          if (!Write##TYPE_METHOD(field->number(), value, output)) {           \
-            return false;                                                      \
-          }                                                                    \
+          Write##TYPE_METHOD(field->number(), value, output);                  \
         }                                                                      \
         break;                                                                 \
       }
@@ -713,15 +761,13 @@ bool WireFormat::SerializeFieldWithCachedSizes(
 
 #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD)                       \
       case FieldDescriptor::TYPE_##TYPE:                                     \
-        if (!Write##TYPE_METHOD(                                             \
+        Write##TYPE_METHOD(                                                  \
               field->number(),                                               \
               field->is_repeated() ?                                         \
                 message_reflection->GetRepeated##CPPTYPE_METHOD(             \
                   message, field, j) :                                       \
                 message_reflection->Get##CPPTYPE_METHOD(message, field),     \
-              output)) {                                                     \
-          return false;                                                      \
-        }                                                                    \
+              output);                                                       \
         break;
 
       HANDLE_TYPE(GROUP  , Group  , Message)
@@ -733,10 +779,9 @@ bool WireFormat::SerializeFieldWithCachedSizes(
           message_reflection->GetRepeatedEnum(message, field, j) :
           message_reflection->GetEnum(message, field);
         if (is_packed) {
-          if (!WriteEnumNoTag(value->number(), output)) return false;
+          WriteEnumNoTag(value->number(), output);
         } else {
-          if (!WriteEnum(field->number(), value->number(), output))
-            return false;
+          WriteEnum(field->number(), value->number(), output);
         }
         break;
       }
@@ -749,7 +794,7 @@ bool WireFormat::SerializeFieldWithCachedSizes(
             message_reflection->GetRepeatedStringReference(
               message, field, j, &scratch) :
             message_reflection->GetStringReference(message, field, &scratch);
-          if (!WriteString(field->number(), value, output)) return false;
+          WriteString(field->number(), value, output);
         break;
       }
 
@@ -759,39 +804,35 @@ bool WireFormat::SerializeFieldWithCachedSizes(
             message_reflection->GetRepeatedStringReference(
               message, field, j, &scratch) :
             message_reflection->GetStringReference(message, field, &scratch);
-          if (!WriteBytes(field->number(), value, output)) return false;
+          WriteBytes(field->number(), value, output);
         break;
       }
     }
   }
-
-  return true;
 }
 
-bool WireFormat::SerializeMessageSetItemWithCachedSizes(
+void WireFormat::SerializeMessageSetItemWithCachedSizes(
     const FieldDescriptor* field,
     const Message& message,
     io::CodedOutputStream* output) {
   const Reflection* message_reflection = message.GetReflection();
 
   // Start group.
-  if (!output->WriteVarint32(kMessageSetItemStartTag)) return false;
+  output->WriteVarint32(kMessageSetItemStartTag);
 
   // Write type ID.
-  if (!output->WriteVarint32(kMessageSetTypeIdTag)) return false;
-  if (!output->WriteVarint32(field->number())) return false;
+  output->WriteVarint32(kMessageSetTypeIdTag);
+  output->WriteVarint32(field->number());
 
   // Write message.
-  if (!output->WriteVarint32(kMessageSetMessageTag)) return false;
+  output->WriteVarint32(kMessageSetMessageTag);
 
   const Message& sub_message = message_reflection->GetMessage(message, field);
-  if (!output->WriteVarint32(sub_message.GetCachedSize())) return false;
-  if (!sub_message.SerializeWithCachedSizes(output)) return false;
+  output->WriteVarint32(sub_message.GetCachedSize());
+  sub_message.SerializeWithCachedSizes(output);
 
   // End group.
-  if (!output->WriteVarint32(kMessageSetItemEndTag)) return false;
-
-  return true;
+  output->WriteVarint32(kMessageSetItemEndTag);
 }
 
 // ===================================================================

+ 136 - 47
src/google/protobuf/wire_format.h

@@ -93,7 +93,7 @@ class LIBPROTOBUF_EXPORT WireFormat {
   // a parameter to this procedure.
   //
-  // These return false iff the underlying stream returns a write error.
+  // Returns nothing; stream write errors are not signalled by return value.
-  static bool SerializeWithCachedSizes(
+  static void SerializeWithCachedSizes(
       const Message& message,
       int size, io::CodedOutputStream* output);
 
@@ -119,14 +119,30 @@ class LIBPROTOBUF_EXPORT WireFormat {
                           UnknownFieldSet* unknown_fields);
 
   // Write the contents of an UnknownFieldSet to the output.
-  static bool SerializeUnknownFields(const UnknownFieldSet& unknown_fields,
+  static void SerializeUnknownFields(const UnknownFieldSet& unknown_fields,
                                      io::CodedOutputStream* output);
+  // Same as above, except writing directly to the provided buffer.
+  // Requires that the buffer have sufficient capacity for
+  // ComputeUnknownFieldsSize(unknown_fields).
+  //
+  // Returns a pointer past the last written byte.
+  static uint8* SerializeUnknownFieldsToArray(
+      const UnknownFieldSet& unknown_fields,
+      uint8* target);
 
   // Same thing except for messages that have the message_set_wire_format
   // option.
-  static bool SerializeUnknownMessageSetItems(
+  static void SerializeUnknownMessageSetItems(
       const UnknownFieldSet& unknown_fields,
       io::CodedOutputStream* output);
+  // Same as above, except writing directly to the provided buffer.
+  // Requires that the buffer have sufficient capacity for
+  // ComputeUnknownMessageSetItemsSize(unknown_fields).
+  //
+  // Returns a pointer past the last written byte.
+  static uint8* SerializeUnknownMessageSetItemsToArray(
+      const UnknownFieldSet& unknown_fields,
+      uint8* target);
 
   // Compute the size of the UnknownFieldSet on the wire.
   static int ComputeUnknownFieldsSize(const UnknownFieldSet& unknown_fields);
@@ -210,7 +226,7 @@ class LIBPROTOBUF_EXPORT WireFormat {
       io::CodedInputStream* input);
 
   // Serialize a single field.
-  static bool SerializeFieldWithCachedSizes(
+  static void SerializeFieldWithCachedSizes(
       const FieldDescriptor* field,        // Cannot be NULL
       const Message& message,
       io::CodedOutputStream* output);
@@ -268,61 +284,139 @@ class LIBPROTOBUF_EXPORT WireFormat {
   // Write a tag.  The Write*() functions typically include the tag, so
   // normally there's no need to call this unless using the Write*NoTag()
   // variants.
-  static inline bool WriteTag(field_number, WireType type, output) INL;
+  static inline void WriteTag(field_number, WireType type, output) INL;
 
   // Write fields, without tags.
-  static inline bool WriteInt32NoTag   (int32 value, output) INL;
-  static inline bool WriteInt64NoTag   (int64 value, output) INL;
-  static inline bool WriteUInt32NoTag  (uint32 value, output) INL;
-  static inline bool WriteUInt64NoTag  (uint64 value, output) INL;
-  static inline bool WriteSInt32NoTag  (int32 value, output) INL;
-  static inline bool WriteSInt64NoTag  (int64 value, output) INL;
-  static inline bool WriteFixed32NoTag (uint32 value, output) INL;
-  static inline bool WriteFixed64NoTag (uint64 value, output) INL;
-  static inline bool WriteSFixed32NoTag(int32 value, output) INL;
-  static inline bool WriteSFixed64NoTag(int64 value, output) INL;
-  static inline bool WriteFloatNoTag   (float value, output) INL;
-  static inline bool WriteDoubleNoTag  (double value, output) INL;
-  static inline bool WriteBoolNoTag    (bool value, output) INL;
-  static inline bool WriteEnumNoTag    (int value, output) INL;
+  static inline void WriteInt32NoTag   (int32 value, output) INL;
+  static inline void WriteInt64NoTag   (int64 value, output) INL;
+  static inline void WriteUInt32NoTag  (uint32 value, output) INL;
+  static inline void WriteUInt64NoTag  (uint64 value, output) INL;
+  static inline void WriteSInt32NoTag  (int32 value, output) INL;
+  static inline void WriteSInt64NoTag  (int64 value, output) INL;
+  static inline void WriteFixed32NoTag (uint32 value, output) INL;
+  static inline void WriteFixed64NoTag (uint64 value, output) INL;
+  static inline void WriteSFixed32NoTag(int32 value, output) INL;
+  static inline void WriteSFixed64NoTag(int64 value, output) INL;
+  static inline void WriteFloatNoTag   (float value, output) INL;
+  static inline void WriteDoubleNoTag  (double value, output) INL;
+  static inline void WriteBoolNoTag    (bool value, output) INL;
+  static inline void WriteEnumNoTag    (int value, output) INL;
 
   // Write fields, including tags.
-  static inline bool WriteInt32   (field_number,  int32 value, output) INL;
-  static inline bool WriteInt64   (field_number,  int64 value, output) INL;
-  static inline bool WriteUInt32  (field_number, uint32 value, output) INL;
-  static inline bool WriteUInt64  (field_number, uint64 value, output) INL;
-  static inline bool WriteSInt32  (field_number,  int32 value, output) INL;
-  static inline bool WriteSInt64  (field_number,  int64 value, output) INL;
-  static inline bool WriteFixed32 (field_number, uint32 value, output) INL;
-  static inline bool WriteFixed64 (field_number, uint64 value, output) INL;
-  static inline bool WriteSFixed32(field_number,  int32 value, output) INL;
-  static inline bool WriteSFixed64(field_number,  int64 value, output) INL;
-  static inline bool WriteFloat   (field_number,  float value, output) INL;
-  static inline bool WriteDouble  (field_number, double value, output) INL;
-  static inline bool WriteBool    (field_number,   bool value, output) INL;
-  static inline bool WriteEnum    (field_number,    int value, output) INL;
-
-  static inline bool WriteString(field_number, const string& value, output) INL;
-  static inline bool WriteBytes (field_number, const string& value, output) INL;
-
-  static inline bool WriteGroup(field_number, const Message& value, output) INL;
-  static inline bool WriteMessage(
+  static inline void WriteInt32   (field_number,  int32 value, output) INL;
+  static inline void WriteInt64   (field_number,  int64 value, output) INL;
+  static inline void WriteUInt32  (field_number, uint32 value, output) INL;
+  static inline void WriteUInt64  (field_number, uint64 value, output) INL;
+  static inline void WriteSInt32  (field_number,  int32 value, output) INL;
+  static inline void WriteSInt64  (field_number,  int64 value, output) INL;
+  static inline void WriteFixed32 (field_number, uint32 value, output) INL;
+  static inline void WriteFixed64 (field_number, uint64 value, output) INL;
+  static inline void WriteSFixed32(field_number,  int32 value, output) INL;
+  static inline void WriteSFixed64(field_number,  int64 value, output) INL;
+  static inline void WriteFloat   (field_number,  float value, output) INL;
+  static inline void WriteDouble  (field_number, double value, output) INL;
+  static inline void WriteBool    (field_number,   bool value, output) INL;
+  static inline void WriteEnum    (field_number,    int value, output) INL;
+
+  static inline void WriteString(field_number, const string& value, output) INL;
+  static inline void WriteBytes (field_number, const string& value, output) INL;
+
+  static inline void WriteGroup(field_number, const Message& value, output) INL;
+  static inline void WriteMessage(
     field_number, const Message& value, output) INL;
 
   // Like above, but de-virtualize the call to SerializeWithCachedSizes().  The
   // pointer must point at an instance of MessageType, *not* a subclass (or
   // the subclass must not override SerializeWithCachedSizes()).
   template<typename MessageType>
-  static inline bool WriteGroupNoVirtual(
+  static inline void WriteGroupNoVirtual(
+    field_number, const MessageType& value, output) INL;
+  template<typename MessageType>
+  static inline void WriteMessageNoVirtual(
+    field_number, const MessageType& value, output) INL;
+
+#undef output
+#define output uint8* target
+
+  // Like above, but use only *ToArray methods of CodedOutputStream.
+  static inline uint8* WriteTagToArray(field_number, WireType type, output) INL;
+
+  // Write fields, without tags.
+  static inline uint8* WriteInt32NoTagToArray   (int32 value, output) INL;
+  static inline uint8* WriteInt64NoTagToArray   (int64 value, output) INL;
+  static inline uint8* WriteUInt32NoTagToArray  (uint32 value, output) INL;
+  static inline uint8* WriteUInt64NoTagToArray  (uint64 value, output) INL;
+  static inline uint8* WriteSInt32NoTagToArray  (int32 value, output) INL;
+  static inline uint8* WriteSInt64NoTagToArray  (int64 value, output) INL;
+  static inline uint8* WriteFixed32NoTagToArray (uint32 value, output) INL;
+  static inline uint8* WriteFixed64NoTagToArray (uint64 value, output) INL;
+  static inline uint8* WriteSFixed32NoTagToArray(int32 value, output) INL;
+  static inline uint8* WriteSFixed64NoTagToArray(int64 value, output) INL;
+  static inline uint8* WriteFloatNoTagToArray   (float value, output) INL;
+  static inline uint8* WriteDoubleNoTagToArray  (double value, output) INL;
+  static inline uint8* WriteBoolNoTagToArray    (bool value, output) INL;
+  static inline uint8* WriteEnumNoTagToArray    (int value, output) INL;
+
+  // Write fields, including tags.
+  static inline uint8* WriteInt32ToArray(
+    field_number, int32 value, output) INL;
+  static inline uint8* WriteInt64ToArray(
+    field_number, int64 value, output) INL;
+  static inline uint8* WriteUInt32ToArray(
+    field_number, uint32 value, output) INL;
+  static inline uint8* WriteUInt64ToArray(
+    field_number, uint64 value, output) INL;
+  static inline uint8* WriteSInt32ToArray(
+    field_number, int32 value, output) INL;
+  static inline uint8* WriteSInt64ToArray(
+    field_number, int64 value, output) INL;
+  static inline uint8* WriteFixed32ToArray(
+    field_number, uint32 value, output) INL;
+  static inline uint8* WriteFixed64ToArray(
+    field_number, uint64 value, output) INL;
+  static inline uint8* WriteSFixed32ToArray(
+    field_number, int32 value, output) INL;
+  static inline uint8* WriteSFixed64ToArray(
+    field_number, int64 value, output) INL;
+  static inline uint8* WriteFloatToArray(
+    field_number, float value, output) INL;
+  static inline uint8* WriteDoubleToArray(
+    field_number, double value, output) INL;
+  static inline uint8* WriteBoolToArray(
+    field_number, bool value, output) INL;
+  static inline uint8* WriteEnumToArray(
+    field_number, int value, output) INL;
+
+  static inline uint8* WriteStringToArray(
+    field_number, const string& value, output) INL;
+  static inline uint8* WriteBytesToArray(
+    field_number, const string& value, output) INL;
+
+  static inline uint8* WriteGroupToArray(
+      field_number, const Message& value, output) INL;
+  static inline uint8* WriteMessageToArray(
+      field_number, const Message& value, output) INL;
+
+  // Like above, but de-virtualize the call to SerializeWithCachedSizes().  The
+  // pointer must point at an instance of MessageType, *not* a subclass (or
+  // the subclass must not override SerializeWithCachedSizes()).
+  template<typename MessageType>
+  static inline uint8* WriteGroupNoVirtualToArray(
     field_number, const MessageType& value, output) INL;
   template<typename MessageType>
-  static inline bool WriteMessageNoVirtual(
+  static inline uint8* WriteMessageNoVirtualToArray(
     field_number, const MessageType& value, output) INL;
 
+#undef output
+#undef input
+#undef INL
+
   // Compute the byte size of a tag.  For groups, this includes both the start
   // and end tags.
   static inline int TagSize(field_number, FieldDescriptor::Type type);
 
+#undef field_number
+
   // Compute the byte size of a field.  The XxSize() functions do NOT include
   // the tag, so you must also call TagSize().  (This is because, for repeated
   // fields, you should only call TagSize() once and multiply it by the element
@@ -358,11 +452,6 @@ class LIBPROTOBUF_EXPORT WireFormat {
   template<typename MessageType>
   static inline int MessageSizeNoVirtual(const MessageType& value);
 
-#undef input
-#undef output
-#undef field_number
-#undef INL
-
  private:
   static const WireType kWireTypeForFieldType[];
 
@@ -371,7 +460,7 @@ class LIBPROTOBUF_EXPORT WireFormat {
   static bool ParseAndMergeMessageSetItem(
       io::CodedInputStream* input,
       Message* message);
-  static bool SerializeMessageSetItemWithCachedSizes(
+  static void SerializeMessageSetItemWithCachedSizes(
       const FieldDescriptor* field,
       const Message& message,
       io::CodedOutputStream* output);

+ 310 - 98
src/google/protobuf/wire_format_inl.h

@@ -164,7 +164,8 @@ inline bool WireFormat::ReadBytes(io::CodedInputStream* input, string* value) {
 }
 
 
-inline bool WireFormat::ReadGroup(int field_number, io::CodedInputStream* input,
+inline bool WireFormat::ReadGroup(int field_number,
+                                  io::CodedInputStream* input,
                                   Message* value) {
   if (!input->IncrementRecursionDepth()) return false;
   if (!value->MergePartialFromCodedStream(input)) return false;
@@ -175,7 +176,8 @@ inline bool WireFormat::ReadGroup(int field_number, io::CodedInputStream* input,
   }
   return true;
 }
-inline bool WireFormat::ReadMessage(io::CodedInputStream* input, Message* value) {
+inline bool WireFormat::ReadMessage(io::CodedInputStream* input,
+                                    Message* value) {
   uint32 length;
   if (!input->ReadVarint32(&length)) return false;
   if (!input->IncrementRecursionDepth()) return false;
@@ -220,140 +222,140 @@ inline bool WireFormat::ReadMessageNoVirtual(io::CodedInputStream* input,
 
 // ===================================================================
 
-inline bool WireFormat::WriteTag(int field_number, WireType type,
+inline void WireFormat::WriteTag(int field_number, WireType type,
                                  io::CodedOutputStream* output) {
-  return output->WriteTag(MakeTag(field_number, type));
+  output->WriteTag(MakeTag(field_number, type));
 }
 
-inline bool WireFormat::WriteInt32NoTag(int32 value,
+inline void WireFormat::WriteInt32NoTag(int32 value,
                                         io::CodedOutputStream* output) {
-  return output->WriteVarint32SignExtended(value);
+  output->WriteVarint32SignExtended(value);
 }
-inline bool WireFormat::WriteInt64NoTag(int64 value,
+inline void WireFormat::WriteInt64NoTag(int64 value,
                                         io::CodedOutputStream* output) {
-  return output->WriteVarint64(static_cast<uint64>(value));
+  output->WriteVarint64(static_cast<uint64>(value));
 }
-inline bool WireFormat::WriteUInt32NoTag(uint32 value,
+inline void WireFormat::WriteUInt32NoTag(uint32 value,
                                          io::CodedOutputStream* output) {
-  return output->WriteVarint32(value);
+  output->WriteVarint32(value);
 }
-inline bool WireFormat::WriteUInt64NoTag(uint64 value,
+inline void WireFormat::WriteUInt64NoTag(uint64 value,
                                          io::CodedOutputStream* output) {
-  return output->WriteVarint64(value);
+  output->WriteVarint64(value);
 }
-inline bool WireFormat::WriteSInt32NoTag(int32 value,
+inline void WireFormat::WriteSInt32NoTag(int32 value,
                                          io::CodedOutputStream* output) {
-  return output->WriteVarint32(ZigZagEncode32(value));
+  output->WriteVarint32(ZigZagEncode32(value));
 }
-inline bool WireFormat::WriteSInt64NoTag(int64 value,
+inline void WireFormat::WriteSInt64NoTag(int64 value,
                                          io::CodedOutputStream* output) {
-  return output->WriteVarint64(ZigZagEncode64(value));
+  output->WriteVarint64(ZigZagEncode64(value));
 }
-inline bool WireFormat::WriteFixed32NoTag(uint32 value,
+inline void WireFormat::WriteFixed32NoTag(uint32 value,
                                           io::CodedOutputStream* output) {
-  return output->WriteLittleEndian32(value);
+  output->WriteLittleEndian32(value);
 }
-inline bool WireFormat::WriteFixed64NoTag(uint64 value,
+inline void WireFormat::WriteFixed64NoTag(uint64 value,
                                           io::CodedOutputStream* output) {
-  return output->WriteLittleEndian64(value);
+  output->WriteLittleEndian64(value);
 }
-inline bool WireFormat::WriteSFixed32NoTag(int32 value,
+inline void WireFormat::WriteSFixed32NoTag(int32 value,
                                            io::CodedOutputStream* output) {
-  return output->WriteLittleEndian32(static_cast<uint32>(value));
+  output->WriteLittleEndian32(static_cast<uint32>(value));
 }
-inline bool WireFormat::WriteSFixed64NoTag(int64 value,
+inline void WireFormat::WriteSFixed64NoTag(int64 value,
                                            io::CodedOutputStream* output) {
-  return output->WriteLittleEndian64(static_cast<uint64>(value));
+  output->WriteLittleEndian64(static_cast<uint64>(value));
 }
-inline bool WireFormat::WriteFloatNoTag(float value,
+inline void WireFormat::WriteFloatNoTag(float value,
                                         io::CodedOutputStream* output) {
-  return output->WriteLittleEndian32(EncodeFloat(value));
+  output->WriteLittleEndian32(EncodeFloat(value));
 }
-inline bool WireFormat::WriteDoubleNoTag(double value,
+inline void WireFormat::WriteDoubleNoTag(double value,
                                          io::CodedOutputStream* output) {
-  return output->WriteLittleEndian64(EncodeDouble(value));
+  output->WriteLittleEndian64(EncodeDouble(value));
 }
-inline bool WireFormat::WriteBoolNoTag(bool value,
+inline void WireFormat::WriteBoolNoTag(bool value,
                                        io::CodedOutputStream* output) {
-  return output->WriteVarint32(value ? 1 : 0);
+  output->WriteVarint32(value ? 1 : 0);
 }
-inline bool WireFormat::WriteEnumNoTag(int value,
+inline void WireFormat::WriteEnumNoTag(int value,
                                        io::CodedOutputStream* output) {
-  return output->WriteVarint32SignExtended(value);
+  output->WriteVarint32SignExtended(value);
 }
 
-inline bool WireFormat::WriteInt32(int field_number, int32 value,
+inline void WireFormat::WriteInt32(int field_number, int32 value,
                                    io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteInt32NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteInt32NoTag(value, output);
 }
-inline bool WireFormat::WriteInt64(int field_number, int64 value,
+inline void WireFormat::WriteInt64(int field_number, int64 value,
                                    io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteInt64NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteInt64NoTag(value, output);
 }
-inline bool WireFormat::WriteUInt32(int field_number, uint32 value,
+inline void WireFormat::WriteUInt32(int field_number, uint32 value,
                                     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteUInt32NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteUInt32NoTag(value, output);
 }
-inline bool WireFormat::WriteUInt64(int field_number, uint64 value,
+inline void WireFormat::WriteUInt64(int field_number, uint64 value,
                                     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteUInt64NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteUInt64NoTag(value, output);
 }
-inline bool WireFormat::WriteSInt32(int field_number, int32 value,
+inline void WireFormat::WriteSInt32(int field_number, int32 value,
                                     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteSInt32NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteSInt32NoTag(value, output);
 }
-inline bool WireFormat::WriteSInt64(int field_number, int64 value,
+inline void WireFormat::WriteSInt64(int field_number, int64 value,
                                     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteSInt64NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteSInt64NoTag(value, output);
 }
-inline bool WireFormat::WriteFixed32(int field_number, uint32 value,
+inline void WireFormat::WriteFixed32(int field_number, uint32 value,
                                      io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED32, output) &&
-         WriteFixed32NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED32, output);
+  WriteFixed32NoTag(value, output);
 }
-inline bool WireFormat::WriteFixed64(int field_number, uint64 value,
+inline void WireFormat::WriteFixed64(int field_number, uint64 value,
                                      io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED64, output) &&
-         WriteFixed64NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED64, output);
+  WriteFixed64NoTag(value, output);
 }
-inline bool WireFormat::WriteSFixed32(int field_number, int32 value,
+inline void WireFormat::WriteSFixed32(int field_number, int32 value,
                                       io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED32, output) &&
-         WriteSFixed32NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED32, output);
+  WriteSFixed32NoTag(value, output);
 }
-inline bool WireFormat::WriteSFixed64(int field_number, int64 value,
+inline void WireFormat::WriteSFixed64(int field_number, int64 value,
                                       io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED64, output) &&
-         WriteSFixed64NoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED64, output);
+  WriteSFixed64NoTag(value, output);
 }
-inline bool WireFormat::WriteFloat(int field_number, float value,
+inline void WireFormat::WriteFloat(int field_number, float value,
                                    io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED32, output) &&
-         WriteFloatNoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED32, output);
+  WriteFloatNoTag(value, output);
 }
-inline bool WireFormat::WriteDouble(int field_number, double value,
+inline void WireFormat::WriteDouble(int field_number, double value,
                                     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_FIXED64, output) &&
-         WriteDoubleNoTag(value, output);
+  WriteTag(field_number, WIRETYPE_FIXED64, output);
+  WriteDoubleNoTag(value, output);
 }
-inline bool WireFormat::WriteBool(int field_number, bool value,
+inline void WireFormat::WriteBool(int field_number, bool value,
                                   io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteBoolNoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteBoolNoTag(value, output);
 }
-inline bool WireFormat::WriteEnum(int field_number, int value,
+inline void WireFormat::WriteEnum(int field_number, int value,
                                   io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_VARINT, output) &&
-         WriteEnumNoTag(value, output);
+  WriteTag(field_number, WIRETYPE_VARINT, output);
+  WriteEnumNoTag(value, output);
 }
 
-inline bool WireFormat::WriteString(int field_number, const string& value,
+inline void WireFormat::WriteString(int field_number, const string& value,
                                     io::CodedOutputStream* output) {
   // String is for UTF-8 text only
 #ifdef GOOGLE_PROTOBUF_UTF8_VALIDATION_ENABLED
@@ -363,46 +365,256 @@ inline bool WireFormat::WriteString(int field_number, const string& value,
                "use the 'bytes' type for raw bytes.";
   }
 #endif  // GOOGLE_PROTOBUF_UTF8_VALIDATION_ENABLED
-  return WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output) &&
-         output->WriteVarint32(value.size()) &&
-         output->WriteString(value);
+  WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output);
+  output->WriteVarint32(value.size());
+  output->WriteString(value);
 }
-inline bool WireFormat::WriteBytes(int field_number, const string& value,
+inline void WireFormat::WriteBytes(int field_number, const string& value,
                                    io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output) &&
-         output->WriteVarint32(value.size()) &&
-         output->WriteString(value);
+  WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output);
+  output->WriteVarint32(value.size());
+  output->WriteString(value);
 }
 
 
-inline bool WireFormat::WriteGroup(int field_number, const Message& value,
+inline void WireFormat::WriteGroup(int field_number, const Message& value,
                                    io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_START_GROUP, output) &&
-         value.SerializeWithCachedSizes(output) &&
-         WriteTag(field_number, WIRETYPE_END_GROUP, output);
+  WriteTag(field_number, WIRETYPE_START_GROUP, output);
+  value.SerializeWithCachedSizes(output);
+  WriteTag(field_number, WIRETYPE_END_GROUP, output);
 }
-inline bool WireFormat::WriteMessage(int field_number, const Message& value,
+inline void WireFormat::WriteMessage(int field_number, const Message& value,
                                      io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output) &&
-         output->WriteVarint32(value.GetCachedSize()) &&
-         value.SerializeWithCachedSizes(output);
+  WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output);
+  output->WriteVarint32(value.GetCachedSize());
+  value.SerializeWithCachedSizes(output);
 }
 
 template<typename MessageType>
-inline bool WireFormat::WriteGroupNoVirtual(
+inline void WireFormat::WriteGroupNoVirtual(
     int field_number, const MessageType& value,
     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_START_GROUP, output) &&
-         value.MessageType::SerializeWithCachedSizes(output) &&
-         WriteTag(field_number, WIRETYPE_END_GROUP, output);
+  WriteTag(field_number, WIRETYPE_START_GROUP, output);
+  value.MessageType::SerializeWithCachedSizes(output);
+  WriteTag(field_number, WIRETYPE_END_GROUP, output);
 }
 template<typename MessageType>
-inline bool WireFormat::WriteMessageNoVirtual(
+inline void WireFormat::WriteMessageNoVirtual(
     int field_number, const MessageType& value,
     io::CodedOutputStream* output) {
-  return WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output) &&
-         output->WriteVarint32(value.MessageType::GetCachedSize()) &&
-         value.MessageType::SerializeWithCachedSizes(output);
+  WriteTag(field_number, WIRETYPE_LENGTH_DELIMITED, output);
+  output->WriteVarint32(value.MessageType::GetCachedSize());
+  value.MessageType::SerializeWithCachedSizes(output);
+}
+
+// ===================================================================
+
+inline uint8* WireFormat::WriteTagToArray(int field_number,
+                                          WireType type,
+                                          uint8* target) {
+  return io::CodedOutputStream::WriteTagToArray(MakeTag(field_number, type),
+                                                target);
+}
+
+inline uint8* WireFormat::WriteInt32NoTagToArray(int32 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint32SignExtendedToArray(value, target);
+}
+inline uint8* WireFormat::WriteInt64NoTagToArray(int64 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint64ToArray(
+      static_cast<uint64>(value), target);
+}
+inline uint8* WireFormat::WriteUInt32NoTagToArray(uint32 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint32ToArray(value, target);
+}
+inline uint8* WireFormat::WriteUInt64NoTagToArray(uint64 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint64ToArray(value, target);
+}
+inline uint8* WireFormat::WriteSInt32NoTagToArray(int32 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint32ToArray(ZigZagEncode32(value),
+                                                     target);
+}
+inline uint8* WireFormat::WriteSInt64NoTagToArray(int64 value, uint8* target) {
+  return io::CodedOutputStream::WriteVarint64ToArray(ZigZagEncode64(value),
+                                                     target);
+}
+inline uint8* WireFormat::WriteFixed32NoTagToArray(uint32 value,
+                                                   uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian32ToArray(value, target);
+}
+inline uint8* WireFormat::WriteFixed64NoTagToArray(uint64 value,
+                                                   uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian64ToArray(value, target);
+}
+inline uint8* WireFormat::WriteSFixed32NoTagToArray(int32 value,
+                                                    uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian32ToArray(
+      static_cast<uint32>(value), target);
+}
+inline uint8* WireFormat::WriteSFixed64NoTagToArray(int64 value,
+                                                    uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian64ToArray(
+      static_cast<uint64>(value), target);
+}
+inline uint8* WireFormat::WriteFloatNoTagToArray(float value, uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian32ToArray(EncodeFloat(value),
+                                                           target);
+}
+inline uint8* WireFormat::WriteDoubleNoTagToArray(double value,
+                                                  uint8* target) {
+  return io::CodedOutputStream::WriteLittleEndian64ToArray(EncodeDouble(value),
+                                                           target);
+}
+inline uint8* WireFormat::WriteBoolNoTagToArray(bool value,
+                                                uint8* target) {
+  return io::CodedOutputStream::WriteVarint32ToArray(value ? 1 : 0, target);
+}
+inline uint8* WireFormat::WriteEnumNoTagToArray(int value,
+                                                uint8* target) {
+  return io::CodedOutputStream::WriteVarint32SignExtendedToArray(value, target);
+}
+
+inline uint8* WireFormat::WriteInt32ToArray(int field_number,
+                                            int32 value,
+                                            uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteInt32NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteInt64ToArray(int field_number,
+                                            int64 value,
+                                            uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteInt64NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteUInt32ToArray(int field_number,
+                                             uint32 value,
+                                             uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteUInt32NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteUInt64ToArray(int field_number,
+                                             uint64 value,
+                                             uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteUInt64NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteSInt32ToArray(int field_number,
+                                             int32 value,
+                                             uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteSInt32NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteSInt64ToArray(int field_number,
+                                             int64 value,
+                                             uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteSInt64NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteFixed32ToArray(int field_number,
+                                              uint32 value,
+                                              uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED32, target);
+  return WriteFixed32NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteFixed64ToArray(int field_number,
+                                              uint64 value,
+                                              uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED64, target);
+  return WriteFixed64NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteSFixed32ToArray(int field_number,
+                                               int32 value,
+                                               uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED32, target);
+  return WriteSFixed32NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteSFixed64ToArray(int field_number,
+                                               int64 value,
+                                               uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED64, target);
+  return WriteSFixed64NoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteFloatToArray(int field_number,
+                                            float value,
+                                            uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED32, target);
+  return WriteFloatNoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteDoubleToArray(int field_number,
+                                             double value,
+                                             uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_FIXED64, target);
+  return WriteDoubleNoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteBoolToArray(int field_number,
+                                           bool value,
+                                           uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteBoolNoTagToArray(value, target);
+}
+inline uint8* WireFormat::WriteEnumToArray(int field_number,
+                                           int value,
+                                           uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_VARINT, target);
+  return WriteEnumNoTagToArray(value, target);
+}
+
+inline uint8* WireFormat::WriteStringToArray(int field_number,
+                                             const string& value,
+                                             uint8* target) {
+  // String is for UTF-8 text only
+#ifdef GOOGLE_PROTOBUF_UTF8_VALIDATION_ENABLED
+  if (!IsStructurallyValidUTF8(value.data(), value.size())) {
+    GOOGLE_LOG(ERROR) << "Encountered string containing invalid UTF-8 data while "
+               "serializing protocol buffer. Strings must contain only UTF-8; "
+               "use the 'bytes' type for raw bytes.";
+  }
+#endif
+  // WARNING:  In wire_format.cc, both strings and bytes are handled by
+  //   WriteString() to avoid code duplication.  If the implementations become
+  //   different, you will need to update that usage.
+  target = WriteTagToArray(field_number, WIRETYPE_LENGTH_DELIMITED, target);
+  target = io::CodedOutputStream::WriteVarint32ToArray(value.size(), target);
+  return io::CodedOutputStream::WriteStringToArray(value, target);
+}
+inline uint8* WireFormat::WriteBytesToArray(int field_number,
+                                            const string& value,
+                                            uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_LENGTH_DELIMITED, target);
+  target = io::CodedOutputStream::WriteVarint32ToArray(value.size(), target);
+  return io::CodedOutputStream::WriteStringToArray(value, target);
+}
+
+
+inline uint8* WireFormat::WriteGroupToArray(int field_number,
+                                            const Message& value,
+                                            uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_START_GROUP, target);
+  target = value.SerializeWithCachedSizesToArray(target);
+  return WriteTagToArray(field_number, WIRETYPE_END_GROUP, target);
+}
+inline uint8* WireFormat::WriteMessageToArray(int field_number,
+                                              const Message& value,
+                                              uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_LENGTH_DELIMITED, target);
+  target = io::CodedOutputStream::WriteVarint32ToArray(
+    value.GetCachedSize(), target);
+  return value.SerializeWithCachedSizesToArray(target);
+}
+
+template<typename MessageType>
+inline uint8* WireFormat::WriteGroupNoVirtualToArray(
+    int field_number, const MessageType& value, uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_START_GROUP, target);
+  target = value.MessageType::SerializeWithCachedSizesToArray(target);
+  return WriteTagToArray(field_number, WIRETYPE_END_GROUP, target);
+}
+template<typename MessageType>
+inline uint8* WireFormat::WriteMessageNoVirtualToArray(
+    int field_number, const MessageType& value, uint8* target) {
+  target = WriteTagToArray(field_number, WIRETYPE_LENGTH_DELIMITED, target);
+  target = io::CodedOutputStream::WriteVarint32ToArray(
+    value.MessageType::GetCachedSize(), target);
+  return value.MessageType::SerializeWithCachedSizesToArray(target);
 }
 
 // ===================================================================

+ 59 - 16
src/google/protobuf/wire_format_unittest.cc

@@ -44,6 +44,7 @@
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/testing/googletest.h>
 #include <gtest/gtest.h>
+#include <google/protobuf/stubs/stl_util-inl.h>
 
 namespace google {
 namespace protobuf {
@@ -179,6 +180,7 @@ TEST(WireFormatTest, Serialize) {
     io::StringOutputStream raw_output(&generated_data);
     io::CodedOutputStream output(&raw_output);
     message.SerializeWithCachedSizes(&output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Serialize using WireFormat.
@@ -186,6 +188,7 @@ TEST(WireFormatTest, Serialize) {
     io::StringOutputStream raw_output(&dynamic_data);
     io::CodedOutputStream output(&raw_output);
     WireFormat::SerializeWithCachedSizes(message, size, &output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Should be the same.
@@ -207,6 +210,7 @@ TEST(WireFormatTest, SerializeExtensions) {
     io::StringOutputStream raw_output(&generated_data);
     io::CodedOutputStream output(&raw_output);
     message.SerializeWithCachedSizes(&output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Serialize using WireFormat.
@@ -214,6 +218,7 @@ TEST(WireFormatTest, SerializeExtensions) {
     io::StringOutputStream raw_output(&dynamic_data);
     io::CodedOutputStream output(&raw_output);
     WireFormat::SerializeWithCachedSizes(message, size, &output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Should be the same.
@@ -235,6 +240,7 @@ TEST(WireFormatTest, SerializeFieldsAndExtensions) {
     io::StringOutputStream raw_output(&generated_data);
     io::CodedOutputStream output(&raw_output);
     message.SerializeWithCachedSizes(&output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Serialize using WireFormat.
@@ -242,6 +248,7 @@ TEST(WireFormatTest, SerializeFieldsAndExtensions) {
     io::StringOutputStream raw_output(&dynamic_data);
     io::CodedOutputStream output(&raw_output);
     WireFormat::SerializeWithCachedSizes(message, size, &output);
+    ASSERT_FALSE(output.HadError());
   }
 
   // Should be the same.
@@ -287,8 +294,8 @@ TEST(WireFormatTest, SerializeMessageSet) {
     unittest::TestMessageSetExtension1::message_set_extension)->set_i(123);
   message_set.MutableExtension(
     unittest::TestMessageSetExtension2::message_set_extension)->set_str("foo");
-  message_set.mutable_unknown_fields()->AddField(kUnknownTypeId)
-                                      ->add_length_delimited("bar");
+  message_set.mutable_unknown_fields()->AddLengthDelimited(
+    kUnknownTypeId, "bar");
 
   string data;
   ASSERT_TRUE(message_set.SerializeToString(&data));
@@ -319,6 +326,43 @@ TEST(WireFormatTest, SerializeMessageSet) {
   EXPECT_EQ("bar", raw.item(2).message());
 }
 
+TEST(WireFormatTest, SerializeMessageSetToStreamAndArrayAreEqual) {
+  // Serialize a MessageSet to a stream and to a flat array and check that the
+  // results are equal.
+  // Set up a TestMessageSet with two known messages and an unknown one, as
+  // above.
+
+  unittest::TestMessageSet message_set;
+  message_set.MutableExtension(
+    unittest::TestMessageSetExtension1::message_set_extension)->set_i(123);
+  message_set.MutableExtension(
+    unittest::TestMessageSetExtension2::message_set_extension)->set_str("foo");
+  message_set.mutable_unknown_fields()->AddLengthDelimited(
+    kUnknownTypeId, "bar");
+
+  int size = message_set.ByteSize();
+  string flat_data;
+  string stream_data;
+  flat_data.resize(size);
+  stream_data.resize(size);
+  // Serialize to flat array
+  {
+    uint8* target = reinterpret_cast<uint8*>(string_as_array(&flat_data));
+    uint8* end = message_set.SerializeWithCachedSizesToArray(target);
+    EXPECT_EQ(size, end - target);
+  }
+
+  // Serialize to buffer
+  {
+    io::ArrayOutputStream array_stream(string_as_array(&stream_data), size, 1);
+    io::CodedOutputStream output_stream(&array_stream);
+    message_set.SerializeWithCachedSizes(&output_stream);
+    ASSERT_FALSE(output_stream.HadError());
+  }
+
+  EXPECT_TRUE(flat_data == stream_data);
+}
+
 TEST(WireFormatTest, ParseMessageSet) {
   // Set up a RawMessageSet with two known messages and an unknown one.
   unittest::RawMessageSet raw;
@@ -360,8 +404,9 @@ TEST(WireFormatTest, ParseMessageSet) {
     unittest::TestMessageSetExtension2::message_set_extension).str());
 
   ASSERT_EQ(1, message_set.unknown_fields().field_count());
-  ASSERT_EQ(1, message_set.unknown_fields().field(0).length_delimited_size());
-  EXPECT_EQ("bar", message_set.unknown_fields().field(0).length_delimited(0));
+  ASSERT_EQ(UnknownField::TYPE_LENGTH_DELIMITED,
+            message_set.unknown_fields().field(0).type());
+  EXPECT_EQ("bar", message_set.unknown_fields().field(0).length_delimited());
 }
 
 TEST(WireFormatTest, RecursionLimit) {
@@ -390,11 +435,11 @@ TEST(WireFormatTest, RecursionLimit) {
 TEST(WireFormatTest, UnknownFieldRecursionLimit) {
   unittest::TestEmptyMessage message;
   message.mutable_unknown_fields()
-        ->AddField(1234)->add_group()
-        ->AddField(1234)->add_group()
-        ->AddField(1234)->add_group()
-        ->AddField(1234)->add_group()
-        ->AddField(1234)->add_varint(123);
+        ->AddGroup(1234)
+        ->AddGroup(1234)
+        ->AddGroup(1234)
+        ->AddGroup(1234)
+        ->AddVarint(1234, 123);
   string data;
   message.SerializeToString(&data);
 
@@ -500,8 +545,7 @@ class WireFormatInvalidInputTest : public testing::Test {
       io::StringOutputStream raw_output(&result);
       io::CodedOutputStream output(&raw_output);
 
-      EXPECT_TRUE(WireFormat::WriteBytes(
-        field->number(), string(bytes, size), &output));
+      WireFormat::WriteBytes(field->number(), string(bytes, size), &output);
     }
 
     return result;
@@ -522,12 +566,11 @@ class WireFormatInvalidInputTest : public testing::Test {
       io::StringOutputStream raw_output(&result);
       io::CodedOutputStream output(&raw_output);
 
-      EXPECT_TRUE(output.WriteVarint32(WireFormat::MakeTag(field)));
-      EXPECT_TRUE(output.WriteString(string(bytes, size)));
+      output.WriteVarint32(WireFormat::MakeTag(field));
+      output.WriteString(string(bytes, size));
       if (include_end_tag) {
-        EXPECT_TRUE(
-          output.WriteVarint32(WireFormat::MakeTag(
-            field->number(), WireFormat::WIRETYPE_END_GROUP)));
+        output.WriteVarint32(WireFormat::MakeTag(
+          field->number(), WireFormat::WIRETYPE_END_GROUP));
       }
     }
 

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä