Explorar o código

Validate that end-group tags match their corresponding start-group tags

This detects:
- An end-group tag with the wrong field number (doesn't match the start-group field)
- An end-group tag with no preceding start-group tag

Fixes issue #688.
Jon Skeet %!s(int64=9) %!d(string=hai) anos
pai
achega
9bdc848832

+ 48 - 2
csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs

@@ -469,6 +469,52 @@ namespace Google.Protobuf
             Assert.AreEqual("field 3", input.ReadString());
         }
 
+        [Test]
+        public void SkipGroup_WrongEndGroupTag()
+        {
+            // Create an output stream with:
+            // Field 1: string "field 1"
+            // Start group 2
+            //   Field 3: fixed int32
+            // End group 4 (should give an error)
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(1, WireFormat.WireType.LengthDelimited);
+            output.WriteString("field 1");
+
+            // The outer group...
+            output.WriteTag(2, WireFormat.WireType.StartGroup);
+            output.WriteTag(3, WireFormat.WireType.Fixed32);
+            output.WriteFixed32(100);
+            output.WriteTag(4, WireFormat.WireType.EndGroup);
+            output.Flush();
+            stream.Position = 0;
+
+            // Now act like a generated client
+            var input = new CodedInputStream(stream);
+            Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
+            Assert.AreEqual("field 1", input.ReadString());
+            Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
+        }
+
+        [Test]
+        public void RogueEndGroupTag()
+        {
+            // If we have an end-group tag without a leading start-group tag, generated
+            // code will just call SkipLastField... so that should fail.
+
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(1, WireFormat.WireType.EndGroup);
+            output.Flush();
+            stream.Position = 0;
+
+            var input = new CodedInputStream(stream);
+            Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.EndGroup), input.ReadTag());
+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
+        }
+
         [Test]
         public void EndOfStreamReachedWhileSkippingGroup()
         {
@@ -484,7 +530,7 @@ namespace Google.Protobuf
             // Now act like a generated client
             var input = new CodedInputStream(stream);
             input.ReadTag();
-            Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
         }
 
         [Test]
@@ -506,7 +552,7 @@ namespace Google.Protobuf
             // Now act like a generated client
             var input = new CodedInputStream(stream);
             Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
-            Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
         }
 
         [Test]

+ 3 - 4
csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs

@@ -679,21 +679,20 @@ namespace Google.Protobuf
         /// for details; we may want to change this.
         /// </summary>
         [Test]
-        public void ExtraEndGroupSkipped()
+        public void ExtraEndGroupThrows()
         {
             var message = SampleMessages.CreateFullTestAllTypes();
             var stream = new MemoryStream();
             var output = new CodedOutputStream(stream);
 
-            output.WriteTag(100, WireFormat.WireType.EndGroup);
             output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32);
             output.WriteFixed32(123);
+            output.WriteTag(100, WireFormat.WireType.EndGroup);
 
             output.Flush();
 
             stream.Position = 0;
-            var parsed = TestAllTypes.Parser.ParseFrom(stream);
-            Assert.AreEqual(new TestAllTypes { SingleFixed32 = 123 }, parsed);
+            Assert.Throws<InvalidProtocolBufferException>(() => TestAllTypes.Parser.ParseFrom(stream));
         }
 
         [Test]

+ 26 - 6
csharp/src/Google.Protobuf/CodedInputStream.cs

@@ -349,6 +349,14 @@ namespace Google.Protobuf
         /// This should be called directly after <see cref="ReadTag"/>, when
         /// the caller wishes to skip an unknown field.
         /// </summary>
+        /// <remarks>
+        /// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.
+        /// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the
+        /// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly
+        /// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.
+        /// </remarks>
+        /// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>
+        /// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>
         public void SkipLastField()
         {
             if (lastTag == 0)
@@ -358,11 +366,11 @@ namespace Google.Protobuf
             switch (WireFormat.GetTagWireType(lastTag))
             {
                 case WireFormat.WireType.StartGroup:
-                    SkipGroup();
+                    SkipGroup(lastTag);
                     break;
                 case WireFormat.WireType.EndGroup:
-                    // Just ignore; there's no data following the tag.
-                    break;
+                    throw new InvalidProtocolBufferException(
+                        "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
                 case WireFormat.WireType.Fixed32:
                     ReadFixed32();
                     break;
@@ -379,7 +387,7 @@ namespace Google.Protobuf
             }
         }
 
-        private void SkipGroup()
+        private void SkipGroup(uint startGroupTag)
         {
             // Note: Currently we expect this to be the way that groups are read. We could put the recursion
             // depth changes into the ReadTag method instead, potentially...
@@ -389,16 +397,28 @@ namespace Google.Protobuf
                 throw InvalidProtocolBufferException.RecursionLimitExceeded();
             }
             uint tag;
-            do
+            while (true)
             {
                 tag = ReadTag();
                 if (tag == 0)
                 {
                     throw InvalidProtocolBufferException.TruncatedMessage();
                 }
+                // Can't call SkipLastField for this case- that would throw.
+                if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
+                {
+                    break;
+                }
                 // This recursion will allow us to handle nested groups.
                 SkipLastField();
-            } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);
+            }
+            int startField = WireFormat.GetTagFieldNumber(startGroupTag);
+            int endField = WireFormat.GetTagFieldNumber(tag);
+            if (startField != endField)
+            {
+                throw new InvalidProtocolBufferException(
+                    $"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");
+            }
             recursionDepth--;
         }