فهرست منبع

Merge pull request #7412 from JamesNK/jamesnk/repeated-fixed-parsing

Improve performance of parsing repeated fixed sized types
Jan Tattermusch 5 سال پیش
والد
کامیت
efbadb6c73

+ 39 - 0
csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs

@@ -37,6 +37,7 @@ using System.IO;
 using System.Linq;
 using System.Buffers;
 using Google.Protobuf.WellKnownTypes;
+using Benchmarks.Proto3;
 
 namespace Google.Protobuf.Benchmarks
 {
@@ -50,6 +51,7 @@ namespace Google.Protobuf.Benchmarks
 
         SubTest manyWrapperFieldsTest = new SubTest(CreateManyWrapperFieldsMessage(), ManyWrapperFieldsMessage.Parser, () => new ManyWrapperFieldsMessage(), MaxMessages);
         SubTest manyPrimitiveFieldsTest = new SubTest(CreateManyPrimitiveFieldsMessage(), ManyPrimitiveFieldsMessage.Parser, () => new ManyPrimitiveFieldsMessage(), MaxMessages);
+        SubTest repeatedFieldTest = new SubTest(CreateRepeatedFieldMessage(), GoogleMessage1.Parser, () => new GoogleMessage1(), MaxMessages);
         SubTest emptyMessageTest = new SubTest(new Empty(), Empty.Parser, () => new Empty(), MaxMessages);
 
         public IEnumerable<int> MessageCountValues => new[] { 10, 100 };
@@ -83,6 +85,18 @@ namespace Google.Protobuf.Benchmarks
             return manyPrimitiveFieldsTest.ParseFromReadOnlySequence();
         }
 
+        [Benchmark]
+        public IMessage RepeatedFieldMessage_ParseFromByteArray()
+        {
+            return repeatedFieldTest.ParseFromByteArray();
+        }
+
+        [Benchmark]
+        public IMessage RepeatedFieldMessage_ParseFromReadOnlySequence()
+        {
+            return repeatedFieldTest.ParseFromReadOnlySequence();
+        }
+
         [Benchmark]
         public IMessage EmptyMessage_ParseFromByteArray()
         {
@@ -123,6 +137,20 @@ namespace Google.Protobuf.Benchmarks
             manyPrimitiveFieldsTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount);
         }
 
+        [Benchmark]
+        [ArgumentsSource(nameof(MessageCountValues))]
+        public void RepeatedFieldMessage_ParseDelimitedMessagesFromByteArray(int messageCount)
+        {
+            repeatedFieldTest.ParseDelimitedMessagesFromByteArray(messageCount);
+        }
+
+        [Benchmark]
+        [ArgumentsSource(nameof(MessageCountValues))]
+        public void RepeatedFieldMessage_ParseDelimitedMessagesFromReadOnlySequence(int messageCount)
+        {
+            repeatedFieldTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount);
+        }
+
         private static ManyWrapperFieldsMessage CreateManyWrapperFieldsMessage()
         {
             // Example data match data of an internal benchmarks
@@ -157,6 +185,17 @@ namespace Google.Protobuf.Benchmarks
             };
         }
 
+        private static GoogleMessage1 CreateRepeatedFieldMessage()
+        {
+            // Message with a repeated fixed length item collection
+            var message = new GoogleMessage1();
+            for (ulong i = 0; i < 1000; i++)
+            {
+                message.Field5.Add(i);
+            }
+            return message;
+        }
+
         private class SubTest
         {
             private readonly IMessage message;

+ 89 - 0
csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs

@@ -595,6 +595,95 @@ namespace Google.Protobuf.Collections
             Assert.AreEqual(((SampleEnum)(-5)), values[5]);
         }
 
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNonDivisibleLength()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize - 1); // Length not divisible by FixedSize
+            output.WriteFixed32(uint.MaxValue);
+            output.Flush();
+            stream.Position = 0;
+
+            var input = new CodedInputStream(stream);
+            input.ReadTag();
+            input.ReadString();
+            input.ReadTag();
+            var field = new RepeatedField<uint>();
+            Assert.Throws<InvalidProtocolBufferException>(() => field.AddEntriesFrom(input, codec));
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsBuffer()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize);
+            // Note that there is no content for the packed field.
+            // The field length exceeds the remaining length of content.
+            output.Flush();
+            stream.Position = 0;
+
+            var input = new CodedInputStream(stream);
+            input.ReadTag();
+            input.ReadString();
+            input.ReadTag();
+            var field = new RepeatedField<uint>();
+            Assert.Throws<InvalidProtocolBufferException>(() => field.AddEntriesFrom(input, codec));
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsRemainingData()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize);
+            // Note that there is no content for the packed field.
+            // The field length exceeds the remaining length of the buffer.
+            output.Flush();
+            stream.Position = 0;
+
+            var sequence = ReadOnlySequenceFactory.CreateWithContent(stream.ToArray());
+            ParseContext.Initialize(sequence, out ParseContext ctx);
+
+            ctx.ReadTag();
+            ctx.ReadString();
+            ctx.ReadTag();
+            var field = new RepeatedField<uint>();
+            try
+            {
+                field.AddEntriesFrom(ref ctx, codec);
+                Assert.Fail();
+            }
+            catch (InvalidProtocolBufferException)
+            {
+            }
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
         // Fairly perfunctory tests for the non-generic IList implementation
         [Test]
         public void IList_Indexer()

+ 2 - 2
csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs

@@ -50,9 +50,9 @@ namespace Google.Protobuf
             while (currentIndex < data.Length)
             {
                 var segment = new List<byte>();
-                for (; currentIndex < Math.Min(currentIndex + segmentSize, data.Length); currentIndex++)
+                while (segment.Count < segmentSize && currentIndex < data.Length)
                 {
-                    segment.Add(data[currentIndex]);
+                    segment.Add(data[currentIndex++]);
                 }
                 segments.Add(segment.ToArray());
                 segments.Add(new byte[0]);

+ 43 - 2
csharp/src/Google.Protobuf/Collections/RepeatedField.cs

@@ -35,6 +35,7 @@ using System.Collections;
 using System.Collections.Generic;
 using System.IO;
 using System.Security;
+using System.Threading;
 
 namespace Google.Protobuf.Collections
 {
@@ -126,9 +127,31 @@ namespace Google.Protobuf.Collections
                 if (length > 0)
                 {
                     int oldLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, length);
-                    while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+
+                    // If the content is fixed size then we can calculate the length
+                    // of the repeated field and pre-initialize the underlying collection.
+                    //
+                    // Check that the supplied length doesn't exceed the underlying buffer.
+                    // That prevents a malicious length from initializing a very large collection.
+                    if (codec.FixedSize > 0 && length % codec.FixedSize == 0 && IsDataAvailable(ref ctx, length))
+                    {
+                        EnsureSize(count + (length / codec.FixedSize));
+
+                        while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+                        {
+                            // Only FieldCodecs with a fixed size can reach here, and they are all known
+                            // types that don't allow the user to specify a custom reader action.
+                            // reader action will never return null.
+                            array[count++] = reader(ref ctx);
+                        }
+                    }
+                    else
                     {
-                        Add(reader(ref ctx));
+                        // Content is variable size so add until we reach the limit.
+                        while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+                        {
+                            Add(reader(ref ctx));
+                        }
                     }
                     SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
                 }
@@ -144,6 +167,24 @@ namespace Google.Protobuf.Collections
             }
         }
 
+        private bool IsDataAvailable(ref ParseContext ctx, int size)
+        {
+            // Data fits in remaining buffer
+            if (size <= ctx.state.bufferSize - ctx.state.bufferPos)
+            {
+                return true;
+            }
+
+            // Data fits in remaining source data.
+            // Note that this will never be true when reading from a stream as the total length is unknown.
+            if (size < ctx.state.segmentedBufferHelper.TotalLength - ctx.state.totalBytesRetired - ctx.state.bufferPos)
+            {
+                return true;
+            }
+
+            return false;
+        }
+
         /// <summary>
         /// Calculates the size of this collection based on the given codec.
         /// </summary>