Selaa lähdekoodia

Merge pull request #7644 from JamesNK/jamesnk/parsestring-multisegment

Optimize reading strings across segments
Jan Tattermusch 5 vuotta sitten
vanhempi
commit
96a62f7cc7

+ 3 - 0
.gitignore

@@ -209,3 +209,6 @@ cmake/cmake-build-debug/
 # IntelliJ
 .idea
 *.iml
+
+# BenchmarkDotNet
+BenchmarkDotNet.Artifacts/

+ 4 - 0
csharp/src/Google.Protobuf.Benchmarks/Google.Protobuf.Benchmarks.csproj

@@ -11,6 +11,10 @@
     <DebugSymbols>true</DebugSymbols>
   </PropertyGroup>
 
+  <ItemGroup>
+    <Compile Include="..\Google.Protobuf.Test\ReadOnlySequenceFactory.cs" Link="ReadOnlySequenceFactory.cs" />
+  </ItemGroup>
+
   <ItemGroup>
     <PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
     <ProjectReference Include="..\Google.Protobuf\Google.Protobuf.csproj" />

+ 41 - 1
csharp/src/Google.Protobuf.Benchmarks/ParseRawPrimitivesBenchmark.cs

@@ -54,10 +54,12 @@ namespace Google.Protobuf.Benchmarks
 
         // key is the encodedSize of string values
         Dictionary<int, byte[]> stringInputBuffers;
+        Dictionary<int, ReadOnlySequence<byte>> stringInputBuffersSegmented;
 
         Random random = new Random(417384220);  // random but deterministic seed
 
         public IEnumerable<int> StringEncodedSizes => new[] { 1, 4, 10, 105, 10080 };
+        public IEnumerable<int> StringSegmentedEncodedSizes => new[] { 105, 10080 };
 
         [GlobalSetup]
         public void GlobalSetup()
@@ -78,11 +80,18 @@ namespace Google.Protobuf.Benchmarks
             fixedIntInputBuffer = CreateBufferWithRandomData(random, BytesToParse / sizeof(long), sizeof(long), paddingValueCount);
 
             stringInputBuffers = new Dictionary<int, byte[]>();
-            foreach(var encodedSize in StringEncodedSizes)
+            foreach (var encodedSize in StringEncodedSizes)
             {
                 byte[] buffer = CreateBufferWithStrings(BytesToParse / encodedSize, encodedSize, encodedSize < 10 ? 10 : 1 );
                 stringInputBuffers.Add(encodedSize, buffer);
             }
+
+            stringInputBuffersSegmented = new Dictionary<int, ReadOnlySequence<byte>>();
+            foreach (var encodedSize in StringSegmentedEncodedSizes)
+            {
+                byte[] buffer = CreateBufferWithStrings(BytesToParse / encodedSize, encodedSize, encodedSize < 10 ? 10 : 1);
+                stringInputBuffersSegmented.Add(encodedSize, ReadOnlySequenceFactory.CreateWithContent(buffer, segmentSize: 128, addEmptySegmentDelimiters: false));
+            }
         }
 
         // Total number of bytes that each benchmark will parse.
@@ -300,6 +309,19 @@ namespace Google.Protobuf.Benchmarks
             return sum;
         }
 
+        [Benchmark]
+        [ArgumentsSource(nameof(StringSegmentedEncodedSizes))] // one benchmark case per size in StringSegmentedEncodedSizes
+        public int ParseString_ParseContext_MultipleSegments(int encodedSize)
+        {
+            InitializeParseContext(stringInputBuffersSegmented[encodedSize], out ParseContext ctx); // input is the ReadOnlySequence created with segmentSize: 128 and no empty delimiter segments
+            int sum = 0;
+            for (int i = 0; i < BytesToParse / encodedSize; i++) // same value count that was passed to CreateBufferWithStrings for this size
+            {
+                sum += ctx.ReadString().Length;
+            }
+            return sum; // total length of all strings read
+        }
+
         [Benchmark]
         [ArgumentsSource(nameof(StringEncodedSizes))]
         public int ParseBytes_CodedInputStream(int encodedSize)
@@ -326,11 +348,29 @@ namespace Google.Protobuf.Benchmarks
             return sum;
         }
 
+        [Benchmark]
+        [ArgumentsSource(nameof(StringSegmentedEncodedSizes))] // one benchmark case per size in StringSegmentedEncodedSizes
+        public int ParseBytes_ParseContext_MultipleSegments(int encodedSize)
+        {
+            InitializeParseContext(stringInputBuffersSegmented[encodedSize], out ParseContext ctx); // reuses the segmented string buffers; each value is read back as bytes instead of a string
+            int sum = 0;
+            for (int i = 0; i < BytesToParse / encodedSize; i++) // same value count that was passed to CreateBufferWithStrings for this size
+            {
+                sum += ctx.ReadBytes().Length;
+            }
+            return sum; // total length of all byte values read
+        }
+
         private static void InitializeParseContext(byte[] buffer, out ParseContext ctx) // overload for input held in a single byte array
         {
             ParseContext.Initialize(new ReadOnlySequence<byte>(buffer), out ctx); // wrap the whole array in a ReadOnlySequence<byte> before initializing the context
         }
 
+        private static void InitializeParseContext(ReadOnlySequence<byte> buffer, out ParseContext ctx) // overload for callers that already hold a ReadOnlySequence, such as the segmented benchmark inputs
+        {
+            ParseContext.Initialize(buffer, out ctx); // the sequence is passed through unchanged
+        }
+
         private static byte[] CreateBufferWithRandomVarints(Random random, int valueCount, int encodedSize, int paddingValueCount)
         {
             MemoryStream ms = new MemoryStream();

+ 83 - 1
csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs

@@ -324,7 +324,25 @@ namespace Google.Protobuf
                 Assert.AreEqual(message, message2);
             }
         }
-                
+
+        [Test]
+        public void ReadWholeMessage_VaryingBlockSizes_FromSequence()
+        {
+            TestAllTypes message = SampleMessages.CreateFullTestAllTypes();
+
+            byte[] rawBytes = message.ToByteArray();
+            Assert.AreEqual(rawBytes.Length, message.CalculateSize());
+            TestAllTypes message2 = TestAllTypes.Parser.ParseFrom(rawBytes); // baseline: round-trip through the contiguous byte array first
+            Assert.AreEqual(message, message2);
+
+            // Try different block sizes: powers of two from 1 up to 128. addEmptySegmentDelimiters is left at its default (true), so empty segments are interleaved as well.
+            for (int blockSize = 1; blockSize < 256; blockSize *= 2)
+            {
+                message2 = TestAllTypes.Parser.ParseFrom(ReadOnlySequenceFactory.CreateWithContent(rawBytes, blockSize));
+                Assert.AreEqual(message, message2);
+            }
+        }
+
         [Test]
         public void ReadHugeBlob()
         {
@@ -365,6 +383,70 @@ namespace Google.Protobuf
             Assert.Throws<InvalidProtocolBufferException>(() => input.ReadBytes());
         }
 
+        [Test]
+        public void ReadBlobGreaterThanCurrentLimit()
+        {
+            MemoryStream ms = new MemoryStream();
+            CodedOutputStream output = new CodedOutputStream(ms);
+            uint tag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
+            output.WriteRawVarint32(tag);
+            output.WriteRawVarint32(4); // length prefix of the blob
+            output.WriteRawBytes(new byte[4]); // Payload: four zero-valued bytes (new byte[4]), matching the length prefix above.
+            output.Flush();
+            ms.Position = 0;
+
+            CodedInputStream input = new CodedInputStream(ms);
+            Assert.AreEqual(tag, input.ReadTag());
+
+            // Push a limit of 3, one byte short of the 4-byte payload, so the read must fail.
+            input.PushLimit(3);
+            Assert.Throws<InvalidProtocolBufferException>(() => input.ReadBytes());
+
+            AssertReadFromParseContext(new ReadOnlySequence<byte>(ms.ToArray()), (ref ParseContext ctx) => // repeat the same scenario through ParseContext
+            {
+                Assert.AreEqual(tag, ctx.ReadTag());
+                SegmentedBufferHelper.PushLimit(ref ctx.state, 3);
+                try // NOTE(review): presumably try/catch rather than Assert.Throws because a nested lambda cannot capture the ref parameter ctx
+                {
+                    ctx.ReadBytes();
+                    Assert.Fail(); // reached only if ReadBytes did not throw
+                }
+                catch (InvalidProtocolBufferException) {} // expected outcome
+            }, true);
+        }
+
+        [Test]
+        public void ReadStringGreaterThanCurrentLimit()
+        {
+            MemoryStream ms = new MemoryStream();
+            CodedOutputStream output = new CodedOutputStream(ms);
+            uint tag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
+            output.WriteRawVarint32(tag);
+            output.WriteRawVarint32(4); // length prefix of the string
+            output.WriteRawBytes(new byte[4]); // Payload: four zero-valued bytes (new byte[4]), matching the length prefix above.
+            output.Flush();
+            ms.Position = 0;
+
+            CodedInputStream input = new CodedInputStream(ms.ToArray()); // array-backed here, unlike the stream-backed input in ReadBlobGreaterThanCurrentLimit
+            Assert.AreEqual(tag, input.ReadTag());
+
+            // Push a limit of 3, one byte short of the 4-byte payload, so the read must fail.
+            input.PushLimit(3);
+            Assert.Throws<InvalidProtocolBufferException>(() => input.ReadString());
+
+            AssertReadFromParseContext(new ReadOnlySequence<byte>(ms.ToArray()), (ref ParseContext ctx) => // repeat the same scenario through ParseContext
+            {
+                Assert.AreEqual(tag, ctx.ReadTag());
+                SegmentedBufferHelper.PushLimit(ref ctx.state, 3);
+                try // NOTE(review): presumably try/catch rather than Assert.Throws because a nested lambda cannot capture the ref parameter ctx
+                {
+                    ctx.ReadString();
+                    Assert.Fail(); // reached only if ReadString did not throw
+                }
+                catch (InvalidProtocolBufferException) { } // expected outcome
+            }, true);
+        }
+
         // Representations of a tag for field 0 with various wire types
         [Test]
         [TestCase(0)]

+ 14 - 3
csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs

@@ -41,11 +41,18 @@ namespace Google.Protobuf
 {
     internal static class ReadOnlySequenceFactory
     {
-        public static ReadOnlySequence<byte> CreateWithContent(byte[] data, int segmentSize = 1)
+        /// <summary>
+        /// Create a sequence from the specified data. The data will be divided up into segments in the sequence.
+        /// </summary>
+        public static ReadOnlySequence<byte> CreateWithContent(byte[] data, int segmentSize = 1, bool addEmptySegmentDelimiters = true)
         {
             var segments = new List<byte[]>();
 
-            segments.Add(new byte[0]);
+            if (addEmptySegmentDelimiters)
+            {
+                segments.Add(new byte[0]);
+            }
+
             var currentIndex = 0;
             while (currentIndex < data.Length)
             {
@@ -55,7 +62,11 @@ namespace Google.Protobuf
                     segment.Add(data[currentIndex++]);
                 }
                 segments.Add(segment.ToArray());
-                segments.Add(new byte[0]);
+
+                if (addEmptySegmentDelimiters)
+                {
+                    segments.Add(new byte[0]);
+                }
             }
 
             return CreateSegments(segments.ToArray());

+ 1 - 19
csharp/src/Google.Protobuf/Collections/RepeatedField.cs

@@ -133,7 +133,7 @@ namespace Google.Protobuf.Collections
                     //
                     // Check that the supplied length doesn't exceed the underlying buffer.
                     // That prevents a malicious length from initializing a very large collection.
-                    if (codec.FixedSize > 0 && length % codec.FixedSize == 0 && IsDataAvailable(ref ctx, length))
+                    if (codec.FixedSize > 0 && length % codec.FixedSize == 0 && ParsingPrimitives.IsDataAvailable(ref ctx.state, length))
                     {
                         EnsureSize(count + (length / codec.FixedSize));
 
@@ -167,24 +167,6 @@ namespace Google.Protobuf.Collections
             }
         }
 
-        private bool IsDataAvailable(ref ParseContext ctx, int size)
-        {
-            // Data fits in remaining buffer
-            if (size <= ctx.state.bufferSize - ctx.state.bufferPos)
-            {
-                return true;
-            }
-
-            // Data fits in remaining source data.
-            // Note that this will never be true when reading from a stream as the total length is unknown.
-            if (size < ctx.state.segmentedBufferHelper.TotalLength - ctx.state.totalBytesRetired - ctx.state.bufferPos)
-            {
-                return true;
-            }
-
-            return false;
-        }
-
         /// <summary>
         /// Calculates the size of this collection based on the given codec.
         /// </summary>

+ 6 - 6
csharp/src/Google.Protobuf/ParserInternalState.cs

@@ -43,7 +43,7 @@ using Google.Protobuf.Collections;
 
 namespace Google.Protobuf
 {
-    
+
     // warning: this is a mutable struct, so it needs to be only passed as a ref!
     internal struct ParserInternalState
     {
@@ -54,12 +54,12 @@ namespace Google.Protobuf
         /// The position within the current buffer (i.e. the next byte to read)
         /// </summary>
         internal int bufferPos;
-        
+
         /// <summary>
         /// Size of the current buffer
         /// </summary>
         internal int bufferSize;
-        
+
         /// <summary>
         /// If we are currently inside a length-delimited block, this is the number of
         /// bytes in the buffer that are still available once we leave the delimited block.
@@ -79,9 +79,9 @@ namespace Google.Protobuf
         internal int totalBytesRetired;
 
         internal int recursionDepth;  // current recursion depth
-        
+
         internal SegmentedBufferHelper segmentedBufferHelper;
-        
+
         /// <summary>
         /// The last tag we read. 0 indicates we've read to the end of the stream
         /// (or haven't read anything yet).
@@ -101,7 +101,7 @@ namespace Google.Protobuf
         // If non-null, the top level parse method was started with given coded input stream as an argument
         // which also means we can potentially fallback to calling MergeFrom(CodedInputStream cis) if needed.
         internal CodedInputStream CodedInputStream => segmentedBufferHelper.CodedInputStream;
-        
+
         /// <summary>
         /// Internal-only property; when set to true, unknown fields will be discarded while parsing.
         /// </summary>

+ 131 - 51
csharp/src/Google.Protobuf/ParsingPrimitives.cs

@@ -34,6 +34,7 @@ using System;
 using System.Buffers;
 using System.Buffers.Binary;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.IO;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
@@ -49,6 +50,7 @@ namespace Google.Protobuf
     [SecuritySafeCritical]
     internal static class ParsingPrimitives
     {
+        private const int StackallocThreshold = 256;
 
         /// <summary>
         /// Reads a length for length-delimited data.
@@ -58,7 +60,6 @@ namespace Google.Protobuf
         /// to make the calling code clearer.
         /// </remarks>
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-
         public static int ParseLength(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
         {
             return (int)ParseRawVarint32(ref buffer, ref state);
@@ -437,14 +438,6 @@ namespace Google.Protobuf
                 throw InvalidProtocolBufferException.NegativeSize();
             }
 
-            if (state.totalBytesRetired + state.bufferPos + size > state.currentLimit)
-            {
-                // Read to the end of the stream (up to the current limit) anyway.
-                SkipRawBytes(ref buffer, ref state, state.currentLimit - state.totalBytesRetired - state.bufferPos);
-                // Then fail.
-                throw InvalidProtocolBufferException.TruncatedMessage();
-            }
-
             if (size <= state.bufferSize - state.bufferPos)
             {
                 // We have all the bytes we need already.
@@ -453,36 +446,22 @@ namespace Google.Protobuf
                 state.bufferPos += size;
                 return bytes;
             }
-            else if (size < buffer.Length || size < state.segmentedBufferHelper.TotalLength)
+
+            return ReadRawBytesSlow(ref buffer, ref state, size);
+        }
+
+        private static byte[] ReadRawBytesSlow(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state, int size)
+        {
+            ValidateCurrentLimit(ref buffer, ref state, size);
+
+            if ((!state.segmentedBufferHelper.TotalLength.HasValue && size < buffer.Length) ||
+                IsDataAvailableInSource(ref state, size))
             {
                 // Reading more bytes than are in the buffer, but not an excessive number
                 // of bytes.  We can safely allocate the resulting array ahead of time.
 
-                // First copy what we have.
                 byte[] bytes = new byte[size];
-                var bytesSpan = new Span<byte>(bytes);
-                int pos = state.bufferSize - state.bufferPos;
-                buffer.Slice(state.bufferPos, pos).CopyTo(bytesSpan.Slice(0, pos));
-                state.bufferPos = state.bufferSize;
-
-                // We want to use RefillBuffer() and then copy from the buffer into our
-                // byte array rather than reading directly into our byte array because
-                // the input may be unbuffered.
-                state.segmentedBufferHelper.RefillBuffer(ref buffer, ref state, true);
-
-                while (size - pos > state.bufferSize)
-                {
-                    buffer.Slice(0, state.bufferSize)
-                        .CopyTo(bytesSpan.Slice(pos, state.bufferSize));
-                    pos += state.bufferSize;
-                    state.bufferPos = state.bufferSize;
-                    state.segmentedBufferHelper.RefillBuffer(ref buffer, ref state, true);
-                }
-
-                buffer.Slice(0, size - pos)
-                        .CopyTo(bytesSpan.Slice(pos, size - pos));
-                state.bufferPos = size - pos;
-
+                ReadRawBytesIntoSpan(ref buffer, ref state, size, bytes);
                 return bytes;
             }
             else
@@ -518,7 +497,7 @@ namespace Google.Protobuf
                 }
 
                 // OK, got everything.  Now concatenate it all into one buffer.
-                byte[] bytes = new byte[size];          
+                byte[] bytes = new byte[size];
                 int newPos = 0;
                 foreach (byte[] chunk in chunks)
                 {
@@ -543,13 +522,7 @@ namespace Google.Protobuf
                 throw InvalidProtocolBufferException.NegativeSize();
             }
 
-            if (state.totalBytesRetired + state.bufferPos + size > state.currentLimit)
-            {
-                // Read to the end of the stream anyway.
-                SkipRawBytes(ref buffer, ref state, state.currentLimit - state.totalBytesRetired - state.bufferPos);
-                // Then fail.
-                throw InvalidProtocolBufferException.TruncatedMessage();
-            }
+            ValidateCurrentLimit(ref buffer, ref state, size);
 
             if (size <= state.bufferSize - state.bufferPos)
             {
@@ -619,7 +592,7 @@ namespace Google.Protobuf
             }
 
 #if GOOGLE_PROTOBUF_SUPPORT_FAST_STRING
-            if (length <= state.bufferSize - state.bufferPos && length > 0)
+            if (length <= state.bufferSize - state.bufferPos)
             {
                 // Fast path: all bytes to decode appear in the same span.
                 ReadOnlySpan<byte> data = buffer.Slice(state.bufferPos, length);
@@ -638,20 +611,76 @@ namespace Google.Protobuf
             }
 #endif
 
-            var decoder = WritingPrimitives.Utf8Encoding.GetDecoder();
+            return ReadStringSlow(ref buffer, ref state, length);
+        }
 
-            // TODO: even if GOOGLE_PROTOBUF_SUPPORT_FAST_STRING is not supported,
-            // we could still create a string efficiently by using Utf8Encoding.GetString(byte[] bytes, int index, int count)
-            // whenever the buffer is backed by a byte array (and avoid creating a new byte array), but the problem is
-            // there is no way to get the underlying byte array from a span.
+        /// <summary>
+        /// Reads a string of <paramref name="length"/> bytes assuming that it is spread across multiple spans in a <see cref="ReadOnlySequence{T}"/>. The current limit is validated first; the bytes are then gathered into a temporary buffer and decoded, or read through ReadRawBytes when the remaining data size is not known.
+        /// </summary>
+        private static string ReadStringSlow(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state, int length)
+        {
+            ValidateCurrentLimit(ref buffer, ref state, length);
+
+#if GOOGLE_PROTOBUF_SUPPORT_FAST_STRING
+            if (IsDataAvailable(ref state, length))
+            {
+                // Read string data into a temporary buffer: stackalloc'ed when length <= StackallocThreshold, otherwise rented from ArrayPool.
+                // Once all data is read, call Encoding.GetString on the buffer; a rented array is returned to the pool in the finally block.
 
-            // TODO: in case the string spans multiple buffer segments, creating a char[] and decoding into it and then
-            // creating a string from that array might be more efficient than creating a string from the copied bytes.
+                byte[] byteArray = null;
+                Span<byte> byteSpan = length <= StackallocThreshold ?
+                    stackalloc byte[length] :
+                    (byteArray = ArrayPool<byte>.Shared.Rent(length)); // Rent may return an array longer than length; only the first length bytes are filled and decoded below
+
+                try
+                {
+                    unsafe
+                    {
+                        fixed (byte* pByteSpan = &MemoryMarshal.GetReference(byteSpan))
+                        {
+                            // Compiler doesn't like that a potentially stackalloc'd Span<byte> is being used
+                            // in a method with a "ref ReadOnlySpan<byte> buffer" argument. If the stackalloc'd span was assigned
+                            // to the ref argument then bad things would happen. We'll never do that so it is ok.
+                            // Make compiler happy by passing a new span created from pointer.
+                            var tempSpan = new Span<byte>(pByteSpan, byteSpan.Length);
+                            ReadRawBytesIntoSpan(ref buffer, ref state, length, tempSpan);
+
+                            return WritingPrimitives.Utf8Encoding.GetString(pByteSpan, length);
+                        }
+                    }
+                }
+                finally
+                {
+                    if (byteArray != null) // null when the stackalloc branch was taken, so there is nothing to return
+                    {
+                        ArrayPool<byte>.Shared.Return(byteArray);
+                    }
+                }
+            }
+#endif
 
             // Slow path: Build a byte array first then decode it.
+            // This will be called when reading from a Stream because we don't know the length of the stream,
+            // or there is not enough data in the sequence. If there is not enough data then ReadRawBytes will
+            // throw an exception.
             return WritingPrimitives.Utf8Encoding.GetString(ReadRawBytes(ref buffer, ref state, length), 0, length);
         }
 
+        /// <summary>
+        /// Validates that reading <paramref name="size"/> bytes from the current position (totalBytesRetired + bufferPos) doesn't exceed the current limit.
+        /// If it does, the bytes remaining before the limit are skipped and the exception from InvalidProtocolBufferException.TruncatedMessage() is thrown.
+        /// </summary>
+        private static void ValidateCurrentLimit(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state, int size)
+        {
+            if (state.totalBytesRetired + state.bufferPos + size > state.currentLimit)
+            {
+                // Read to the end of the stream (up to the current limit) anyway, i.e. skip exactly the bytes left before the limit.
+                SkipRawBytes(ref buffer, ref state, state.currentLimit - state.totalBytesRetired - state.bufferPos);
+                // Then fail.
+                throw InvalidProtocolBufferException.TruncatedMessage();
+            }
+        }
+
         [SecuritySafeCritical]
         private static byte ReadRawByte(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
         {
@@ -731,5 +760,56 @@ namespace Google.Protobuf
         {
             return (long)(n >> 1) ^ -(long)(n & 1);
         }
+
+        /// <summary>
+        /// Checks whether there is known data available of the specified size remaining to parse: first in the unread part of the current buffer, then in the underlying source.
+        /// When parsing from a Stream this can return false because we have no knowledge of the amount
+        /// of data remaining in the stream until it is read.
+        /// </summary>
+        public static bool IsDataAvailable(ref ParserInternalState state, int size)
+        {
+            // Data fits in remaining buffer (bufferSize - bufferPos unread bytes)
+            if (size <= state.bufferSize - state.bufferPos)
+            {
+                return true;
+            }
+
+            return IsDataAvailableInSource(ref state, size); // otherwise compare against the source's total length, when it is known
+        }
+
+        /// <summary>
+        /// Checks whether there is known data available of the specified size remaining to parse
+        /// in the underlying data source, i.e. whether size fits in TotalLength minus the bytes already consumed (totalBytesRetired + bufferPos).
+        /// When parsing from a Stream this will return false because we have no knowledge of the amount
+        /// of data remaining in the stream until it is read.
+        /// </summary>
+        private static bool IsDataAvailableInSource(ref ParserInternalState state, int size)
+        {
+            // Data fits in remaining source data.
+            // Note that this will never be true when reading from a stream as the total length is unknown.
+            return size <= state.segmentedBufferHelper.TotalLength - state.totalBytesRetired - state.bufferPos; // TotalLength is nullable (see the HasValue check in ReadRawBytesSlow); when it has no value the lifted <= comparison evaluates to false
+        }
+
+        /// <summary>
+        /// Read raw bytes of the specified length into a span, refilling the buffer whenever it is exhausted. The amount of data available and the current limit should
+        /// be checked before calling this method. The destination span must be able to hold at least <paramref name="length"/> bytes.
+        /// </summary>
+        private static void ReadRawBytesIntoSpan(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state, int length, Span<byte> byteSpan)
+        {
+            int remainingByteLength = length;
+            while (remainingByteLength > 0)
+            {
+                if (state.bufferSize - state.bufferPos == 0) // nothing unread in the current buffer
+                {
+                    state.segmentedBufferHelper.RefillBuffer(ref buffer, ref state, true); // NOTE(review): third argument presumably means the refill must succeed - confirm against SegmentedBufferHelper
+                }
+
+                ReadOnlySpan<byte> unreadSpan = buffer.Slice(state.bufferPos, Math.Min(remainingByteLength, state.bufferSize - state.bufferPos)); // take as much as is still needed, capped at what the buffer holds
+                unreadSpan.CopyTo(byteSpan.Slice(length - remainingByteLength)); // destination offset = number of bytes copied so far
+
+                remainingByteLength -= unreadSpan.Length;
+                state.bufferPos += unreadSpan.Length;
+            }
+        }
     }
-}
+}