Bläddra i källkod

Optimize measurement and serialization of nano protos.

Measuring the serialized size of nano protos is now a zero-alloc operation, and serializing a proto now allocates no memory (other than the output buffer) instead of O(total length of strings).

Change-Id: Id5e2ac3bdc4ac56c0bf13d725472da3a00c9baec
Signed-off-by: Charles Munger <clm@google.com>
Charles Munger 10 år sedan
förälder
incheckning
54511f701f

+ 215 - 23
javanano/src/main/java/com/google/protobuf/nano/CodedOutputByteBufferNano.java

@@ -31,6 +31,9 @@
 package com.google.protobuf.nano;
 
 import java.io.IOException;
+import java.nio.BufferOverflowException;
+import java.nio.ByteBuffer;
+import java.nio.ReadOnlyBufferException;
 
 /**
  * Encodes and writes protocol message fields.
@@ -47,15 +50,17 @@ import java.io.IOException;
  * @author kneton@google.com Kenton Varda
  */
 public final class CodedOutputByteBufferNano {
-  private final byte[] buffer;
-  private final int limit;
-  private int position;
+  /* max bytes per java UTF-16 char in UTF-8 */
+  private static final int MAX_UTF8_EXPANSION = 3;
+  private final ByteBuffer buffer;
 
   private CodedOutputByteBufferNano(final byte[] buffer, final int offset,
                             final int length) {
+    this(ByteBuffer.wrap(buffer, offset, length));
+  }
+
+  private CodedOutputByteBufferNano(final ByteBuffer buffer) {
     this.buffer = buffer;
-    position = offset;
-    limit = offset + length;
   }
 
   /**
@@ -287,14 +292,204 @@ public final class CodedOutputByteBufferNano {
 
   /** Write a {@code string} field to the stream. */
   public void writeStringNoTag(final String value) throws IOException {
-    // Unfortunately there does not appear to be any way to tell Java to encode
-    // UTF-8 directly into our buffer, so we have to let it create its own byte
-    // array and then copy.
-    final byte[] bytes = value.getBytes(InternalNano.UTF_8);
-    writeRawVarint32(bytes.length);
-    writeRawBytes(bytes);
+    // UTF-8 byte length of the string is at least its UTF-16 code unit length (value.length()),
+    // and at most 3 times of it. Optimize for the case where we know this length results in a
+    // constant varint length - saves measuring length of the string.
+    try {
+      final int minLengthVarIntSize = computeRawVarint32Size(value.length());
+      final int maxLengthVarIntSize = computeRawVarint32Size(value.length() * MAX_UTF8_EXPANSION);
+      if (minLengthVarIntSize == maxLengthVarIntSize) {
+        int oldPosition = buffer.position();
+        buffer.position(oldPosition + minLengthVarIntSize);
+        encode(value, buffer);
+        int newPosition = buffer.position();
+        buffer.position(oldPosition);
+        writeRawVarint32(newPosition - oldPosition - minLengthVarIntSize);
+        buffer.position(newPosition);
+      } else {
+        writeRawVarint32(encodedLength(value));
+        encode(value, buffer);
+      }
+    } catch (BufferOverflowException e) {
+      throw new OutOfSpaceException(buffer.position(), buffer.limit());
+    }
+  }
+
+  // These UTF-8 handling methods are copied from Guava's Utf8 class.
+  /**
+   * Returns the number of bytes in the UTF-8-encoded form of {@code sequence}. For a string,
+   * this method is equivalent to {@code string.getBytes(UTF_8).length}, but is more efficient in
+   * both time and space.
+   *
+   * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired
+   *     surrogates)
+   */
+  private static int encodedLength(CharSequence sequence) {
+    // Warning to maintainers: this implementation is highly optimized.
+    int utf16Length = sequence.length();
+    int utf8Length = utf16Length;
+    int i = 0;
+
+    // This loop optimizes for pure ASCII.
+    while (i < utf16Length && sequence.charAt(i) < 0x80) {
+      i++;
+    }
+
+    // This loop optimizes for chars less than 0x800.
+    for (; i < utf16Length; i++) {
+      char c = sequence.charAt(i);
+      if (c < 0x800) {
+        utf8Length += ((0x7f - c) >>> 31);  // branch free!
+      } else {
+        utf8Length += encodedLengthGeneral(sequence, i);
+        break;
+      }
+    }
+
+    if (utf8Length < utf16Length) {
+      // Necessary and sufficient condition for overflow because of maximum 3x expansion
+      throw new IllegalArgumentException("UTF-8 length does not fit in int: "
+              + (utf8Length + (1L << 32)));
+    }
+    return utf8Length;
+  }
+
+  private static int encodedLengthGeneral(CharSequence sequence, int start) {
+    int utf16Length = sequence.length();
+    int utf8Length = 0;
+    for (int i = start; i < utf16Length; i++) {
+      char c = sequence.charAt(i);
+      if (c < 0x800) {
+        utf8Length += (0x7f - c) >>> 31; // branch free!
+      } else {
+        utf8Length += 2;
+        // jdk7+: if (Character.isSurrogate(c)) {
+        if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) {
+          // Check that we have a well-formed surrogate pair.
+          int cp = Character.codePointAt(sequence, i);
+          if (cp < Character.MIN_SUPPLEMENTARY_CODE_POINT) {
+            throw new IllegalArgumentException("Unpaired surrogate at index " + i);
+          }
+          i++;
+        }
+      }
+    }
+    return utf8Length;
   }
 
+  /**
+   * Encodes {@code sequence} into UTF-8, in {@code byteBuffer}. For a string, this method is
+   * equivalent to {@code buffer.put(string.getBytes(UTF_8))}, but is more efficient in both time
+   * and space. Bytes are written starting at the current position. This method requires paired
+   * surrogates, and therefore does not support chunking.
+   *
+   * <p>To ensure sufficient space in the output buffer, either call {@link #encodedLength} to
+   * compute the exact amount needed, or leave room for {@code 3 * sequence.length()}, which is the
+   * largest possible number of bytes that any input can be encoded to.
+   *
+   * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired
+   *     surrogates)
+   * @throws BufferOverflowException if {@code sequence} encoded in UTF-8 does not fit in
+   *     {@code byteBuffer}'s remaining space.
+   * @throws ReadOnlyBufferException if {@code byteBuffer} is a read-only buffer.
+   */
+  private static void encode(CharSequence sequence, ByteBuffer byteBuffer) {
+    if (byteBuffer.isReadOnly()) {
+      throw new ReadOnlyBufferException();
+    } else if (byteBuffer.hasArray()) {
+      try {
+        int encoded = encode(sequence,
+                byteBuffer.array(),
+                byteBuffer.arrayOffset() + byteBuffer.position(),
+                byteBuffer.remaining());
+        byteBuffer.position(encoded - byteBuffer.arrayOffset());
+      } catch (ArrayIndexOutOfBoundsException e) {
+        BufferOverflowException boe = new BufferOverflowException();
+        boe.initCause(e);
+        throw boe;
+      }
+    } else {
+      encodeDirect(sequence, byteBuffer);
+    }
+  }
+
+  private static void encodeDirect(CharSequence sequence, ByteBuffer byteBuffer) {
+    int utf16Length = sequence.length();
+    for (int i = 0; i < utf16Length; i++) {
+      final char c = sequence.charAt(i);
+      if (c < 0x80) { // ASCII
+        byteBuffer.put((byte) c);
+      } else if (c < 0x800) { // 11 bits, two UTF-8 bytes
+        byteBuffer.put((byte) ((0xF << 6) | (c >>> 6)));
+        byteBuffer.put((byte) (0x80 | (0x3F & c)));
+      } else if (c < Character.MIN_SURROGATE || Character.MAX_SURROGATE < c) {
+        // Maximium single-char code point is 0xFFFF, 16 bits, three UTF-8 bytes
+        byteBuffer.put((byte) ((0xF << 5) | (c >>> 12)));
+        byteBuffer.put((byte) (0x80 | (0x3F & (c >>> 6))));
+        byteBuffer.put((byte) (0x80 | (0x3F & c)));
+      } else {
+        final char low;
+        if (i + 1 == sequence.length()
+                || !Character.isSurrogatePair(c, (low = sequence.charAt(++i)))) {
+          throw new IllegalArgumentException("Unpaired surrogate at index " + (i - 1));
+        }
+        int codePoint = Character.toCodePoint(c, low);
+        byteBuffer.put((byte) ((0xF << 4) | (codePoint >>> 18)));
+        byteBuffer.put((byte) (0x80 | (0x3F & (codePoint >>> 12))));
+        byteBuffer.put((byte) (0x80 | (0x3F & (codePoint >>> 6))));
+        byteBuffer.put((byte) (0x80 | (0x3F & codePoint)));
+      }
+    }
+  }
+
+  private static int encode(CharSequence sequence, byte[] bytes, int offset, int length) {
+    int utf16Length = sequence.length();
+    int j = offset;
+    int i = 0;
+    int limit = offset + length;
+    // Designed to take advantage of
+    // https://wikis.oracle.com/display/HotSpotInternals/RangeCheckElimination
+    for (char c; i < utf16Length && i + j < limit && (c = sequence.charAt(i)) < 0x80; i++) {
+      bytes[j + i] = (byte) c;
+    }
+    if (i == utf16Length) {
+      return j + utf16Length;
+    }
+    j += i;
+    for (char c; i < utf16Length; i++) {
+      c = sequence.charAt(i);
+      if (c < 0x80 && j < limit) {
+        bytes[j++] = (byte) c;
+      } else if (c < 0x800 && j <= limit - 2) { // 11 bits, two UTF-8 bytes
+        bytes[j++] = (byte) ((0xF << 6) | (c >>> 6));
+        bytes[j++] = (byte) (0x80 | (0x3F & c));
+      } else if ((c < Character.MIN_SURROGATE || Character.MAX_SURROGATE < c) && j <= limit - 3) {
+        // Maximum single-char code point is 0xFFFF, 16 bits, three UTF-8 bytes
+        bytes[j++] = (byte) ((0xF << 5) | (c >>> 12));
+        bytes[j++] = (byte) (0x80 | (0x3F & (c >>> 6)));
+        bytes[j++] = (byte) (0x80 | (0x3F & c));
+      } else if (j <= limit - 4) {
+        // Minimum code point represented by a surrogate pair is 0x10000, 17 bits, four UTF-8 bytes
+        final char low;
+        if (i + 1 == sequence.length()
+                || !Character.isSurrogatePair(c, (low = sequence.charAt(++i)))) {
+          throw new IllegalArgumentException("Unpaired surrogate at index " + (i - 1));
+        }
+        int codePoint = Character.toCodePoint(c, low);
+        bytes[j++] = (byte) ((0xF << 4) | (codePoint >>> 18));
+        bytes[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 12)));
+        bytes[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 6)));
+        bytes[j++] = (byte) (0x80 | (0x3F & codePoint));
+      } else {
+        throw new ArrayIndexOutOfBoundsException("Failed writing " + c + " at index " + j);
+      }
+    }
+    return j;
+  }
+
+  // End guava UTF-8 methods
+
+
   /** Write a {@code group} field to the stream. */
   public void writeGroupNoTag(final MessageNano value) throws IOException {
     value.writeTo(this);
@@ -602,9 +797,8 @@ public final class CodedOutputByteBufferNano {
    * {@code string} field.
    */
   public static int computeStringSizeNoTag(final String value) {
-    final byte[] bytes = value.getBytes(InternalNano.UTF_8);
-    return computeRawVarint32Size(bytes.length) +
-           bytes.length;
+    final int length = encodedLength(value);
+    return computeRawVarint32Size(length) + length;
   }
 
   /**
@@ -687,7 +881,7 @@ public final class CodedOutputByteBufferNano {
    * Otherwise, throws {@code UnsupportedOperationException}.
    */
   public int spaceLeft() {
-    return limit - position;
+    return buffer.remaining();
   }
 
   /**
@@ -720,12 +914,12 @@ public final class CodedOutputByteBufferNano {
 
   /** Write a single byte. */
   public void writeRawByte(final byte value) throws IOException {
-    if (position == limit) {
+    if (!buffer.hasRemaining()) {
       // We're writing to a single buffer.
-      throw new OutOfSpaceException(position, limit);
+      throw new OutOfSpaceException(buffer.position(), buffer.limit());
     }
 
-    buffer[position++] = value;
+    buffer.put(value);
   }
 
   /** Write a single byte, represented by an integer value. */
@@ -741,13 +935,11 @@ public final class CodedOutputByteBufferNano {
   /** Write part of an array of bytes. */
   public void writeRawBytes(final byte[] value, int offset, int length)
                             throws IOException {
-    if (limit - position >= length) {
-      // We have room in the current buffer.
-      System.arraycopy(value, offset, buffer, position, length);
-      position += length;
+    if (buffer.remaining() >= length) {
+      buffer.put(value, offset, length);
     } else {
       // We're writing to a single buffer.
-      throw new OutOfSpaceException(position, limit);
+      throw new OutOfSpaceException(buffer.position(), buffer.limit());
     }
   }
 

+ 36 - 0
javanano/src/test/java/com/google/protobuf/nano/NanoTest.java

@@ -2300,6 +2300,42 @@ public class NanoTest extends TestCase {
     }
   }
 
+  public void testDifferentStringLengthsNano() throws Exception {
+    // Test string serialization roundtrip using strings of the following lengths,
+    // with ASCII and Unicode characters requiring different UTF-8 byte counts per
+    // char, hence causing the length delimiter varint to sometimes require more
+    // bytes for the Unicode strings than the ASCII string of the same length.
+    int[] lengths = new int[] {
+            0,
+            1,
+            (1 << 4) - 1,  // 1 byte for ASCII and Unicode
+            (1 << 7) - 1,  // 1 byte for ASCII, 2 bytes for Unicode
+            (1 << 11) - 1, // 2 bytes for ASCII and Unicode
+            (1 << 14) - 1, // 2 bytes for ASCII, 3 bytes for Unicode
+            (1 << 17) - 1, // 3 bytes for ASCII and Unicode
+    };
+    for (int i : lengths) {
+      testEncodingOfString('q', i);      // 1 byte per char
+      testEncodingOfString('\u07FF', i); // 2 bytes per char
+      testEncodingOfString('\u0981', i); // 3 bytes per char
+    }
+  }
+
+  private void testEncodingOfString(char c, int length) throws InvalidProtocolBufferNanoException {
+    TestAllTypesNano testAllTypesNano = new TestAllTypesNano();
+    final String fullString = fullString(c, length);
+    testAllTypesNano.optionalString = fullString;
+    final TestAllTypesNano resultNano = new TestAllTypesNano();
+    MessageNano.mergeFrom(resultNano, MessageNano.toByteArray(testAllTypesNano));
+    assertEquals(fullString, resultNano.optionalString);
+  }
+
+  private String fullString(char c, int length) {
+    char[] result = new char[length];
+    Arrays.fill(result, c);
+    return new String(result);
+  }
+
   public void testNanoWithHasParseFrom() throws Exception {
     TestAllTypesNanoHas msg = null;
     // Test false on creation, after clear and upon empty parse.