Browse Source

enforce recursion depth checking for unknown fields

Jan Tattermusch 5 years ago
parent
commit
f20be83927

+ 62 - 1
csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs

@@ -33,6 +33,7 @@
 using System;
 using System.IO;
 using Google.Protobuf.TestProtos;
+using Proto2 = Google.Protobuf.TestProtos.Proto2;
 using NUnit.Framework;
 
 namespace Google.Protobuf
@@ -337,6 +338,66 @@ namespace Google.Protobuf
             CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
             Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
         }
+        
+        private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
+        {
+            // generate recursively nested groups that will be parsed as unknown fields
+            int unknownFieldNumber = 14;  // an unused field number
+            MemoryStream ms = new MemoryStream();
+            CodedOutputStream output = new CodedOutputStream(ms);
+            for (int i = 0; i < recursionDepth; i++)
+            {
+                output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
+            }
+            for (int i = 0; i < recursionDepth; i++)
+            {
+                output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
+            }
+            output.Flush();
+            return ms.ToArray();
+        }
+
+        [Test]
+        public void MaliciousRecursion_UnknownFields()
+        {
+            byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
+            byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);
+            
+            Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
+            Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
+        }
+
+        [Test]
+        public void ReadGroup_WrongEndGroupTag()
+        {
+            int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;
+
+            // write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
+            MemoryStream ms = new MemoryStream();
+            CodedOutputStream output = new CodedOutputStream(ms);
+            output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
+            output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
+            // end group with different field number
+            output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
+            output.Flush();
+            var payload = ms.ToArray();
+
+            Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
+        }
+
+        [Test]
+        public void ReadGroup_UnknownFields_WrongEndGroupTag()
+        {
+            MemoryStream ms = new MemoryStream();
+            CodedOutputStream output = new CodedOutputStream(ms);
+            output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
+            // end group with different field number
+            output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
+            output.Flush();
+            var payload = ms.ToArray();
+
+            Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
+        }
 
         [Test]
         public void SizeLimit()
@@ -735,4 +796,4 @@ namespace Google.Protobuf
             }
         }
     }
-}
+}

+ 29 - 2
csharp/src/Google.Protobuf/CodedInputStream.cs

@@ -307,10 +307,17 @@ namespace Google.Protobuf
                 throw InvalidProtocolBufferException.MoreDataAvailable();
             }
         }
-        #endregion
 
+        internal void CheckLastTagWas(uint expectedTag)
+        {
+           if (lastTag != expectedTag) {
+                throw InvalidProtocolBufferException.InvalidEndTag();
+           }
+        }
+        #endregion
+
         #region Reading of tags etc
-
+
         /// <summary>
         /// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
         /// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
@@ -636,7 +643,27 @@ namespace Google.Protobuf
                 throw InvalidProtocolBufferException.RecursionLimitExceeded();
             }
             ++recursionDepth;
+
+            uint tag = lastTag;
+            int fieldNumber = WireFormat.GetTagFieldNumber(tag);
+
             builder.MergeFrom(this);
+            CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
+            --recursionDepth;
+        }
+
+        /// <summary>
+        /// Reads an embedded group unknown field from the stream.
+        /// </summary>
+        internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
+        {
+            if (recursionDepth >= recursionLimit)
+            {
+                throw InvalidProtocolBufferException.RecursionLimitExceeded();
+            }
+            ++recursionDepth;
+            set.MergeGroupFrom(this);
+            CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
             --recursionDepth;
         }
 

+ 17 - 5
csharp/src/Google.Protobuf/UnknownFieldSet.cs

@@ -215,12 +215,8 @@ namespace Google.Protobuf
                     }
                 case WireFormat.WireType.StartGroup:
                     {
-                        uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
                         UnknownFieldSet set = new UnknownFieldSet();
-                        while (input.ReadTag() != endTag)
-                        {
-                            set.MergeFieldFrom(input);
-                        }
+                        input.ReadGroup(number, set);
                         GetOrAddField(number).AddGroup(set);
                         return true;
                     }
@@ -233,6 +229,22 @@ namespace Google.Protobuf
             }
         }
 
+        internal void MergeGroupFrom(CodedInputStream input)
+        {
+            while (true)
+            {
+                uint tag = input.ReadTag();
+                if (tag == 0)
+                {
+                    break;
+                }
+                if (!MergeFieldFrom(input))
+                {
+                    break;
+                }
+            }
+        }
+
         /// <summary>
         /// Create a new UnknownFieldSet if unknownFields is null.
         /// Parse a single field from <paramref name="input"/> and merge it