compression_test.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright 2020 The gRPC Authors
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests behavior around the compression mechanism."""
  15. import asyncio
  16. import logging
  17. import platform
  18. import random
  19. import unittest
  20. import grpc
  21. from grpc.experimental import aio
  22. from tests_aio.unit._test_base import AioTestBase
  23. from tests_aio.unit import _common
  24. _GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2)
  25. _GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset',
  26. 3)
  27. _DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
  28. 'grpc.compression_enabled_algorithms_bitset', 5)
  29. _TEST_UNARY_UNARY = '/test/TestUnaryUnary'
  30. _TEST_SET_COMPRESSION = '/test/TestSetCompression'
  31. _TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary'
  32. _TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream'
  33. _REQUEST = b'\x01' * 100
  34. _RESPONSE = b'\x02' * 100
  35. async def _test_unary_unary(unused_request, unused_context):
  36. return _RESPONSE
  37. async def _test_set_compression(unused_request_iterator, context):
  38. assert _REQUEST == await context.read()
  39. context.set_compression(grpc.Compression.Deflate)
  40. await context.write(_RESPONSE)
  41. try:
  42. context.set_compression(grpc.Compression.Deflate)
  43. except RuntimeError:
  44. # NOTE(lidiz) Testing if the servicer context raises exception when
  45. # the set_compression method is called after initial_metadata sent.
  46. # After the initial_metadata sent, the server-side has no control over
  47. # which compression algorithm it should use.
  48. pass
  49. else:
  50. raise ValueError(
  51. 'Expecting exceptions if set_compression is not effective')
  52. async def _test_disable_compression_unary(request, context):
  53. assert _REQUEST == request
  54. context.set_compression(grpc.Compression.Deflate)
  55. context.disable_next_message_compression()
  56. return _RESPONSE
  57. async def _test_disable_compression_stream(unused_request_iterator, context):
  58. assert _REQUEST == await context.read()
  59. context.set_compression(grpc.Compression.Deflate)
  60. await context.write(_RESPONSE)
  61. context.disable_next_message_compression()
  62. await context.write(_RESPONSE)
  63. await context.write(_RESPONSE)
  64. _ROUTING_TABLE = {
  65. _TEST_UNARY_UNARY:
  66. grpc.unary_unary_rpc_method_handler(_test_unary_unary),
  67. _TEST_SET_COMPRESSION:
  68. grpc.stream_stream_rpc_method_handler(_test_set_compression),
  69. _TEST_DISABLE_COMPRESSION_UNARY:
  70. grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary),
  71. _TEST_DISABLE_COMPRESSION_STREAM:
  72. grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream),
  73. }
  74. class _GenericHandler(grpc.GenericRpcHandler):
  75. def service(self, handler_call_details):
  76. return _ROUTING_TABLE.get(handler_call_details.method)
  77. async def _start_test_server(options=None):
  78. server = aio.server(options=options)
  79. port = server.add_insecure_port('[::]:0')
  80. server.add_generic_rpc_handlers((_GenericHandler(),))
  81. await server.start()
  82. return f'localhost:{port}', server
  83. class TestCompression(AioTestBase):
  84. async def setUp(self):
  85. server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
  86. self._address, self._server = await _start_test_server(server_options)
  87. self._channel = aio.insecure_channel(self._address)
  88. async def tearDown(self):
  89. await self._channel.close()
  90. await self._server.stop(None)
  91. async def test_channel_level_compression_baned_compression(self):
  92. # GZIP is disabled, this call should fail
  93. async with aio.insecure_channel(
  94. self._address, compression=grpc.Compression.Gzip) as channel:
  95. multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
  96. call = multicallable(_REQUEST)
  97. with self.assertRaises(aio.AioRpcError) as exception_context:
  98. await call
  99. rpc_error = exception_context.exception
  100. self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
  101. async def test_channel_level_compression_allowed_compression(self):
  102. # Deflate is allowed, this call should succeed
  103. async with aio.insecure_channel(
  104. self._address, compression=grpc.Compression.Deflate) as channel:
  105. multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
  106. call = multicallable(_REQUEST)
  107. self.assertEqual(grpc.StatusCode.OK, await call.code())
  108. async def test_client_call_level_compression_baned_compression(self):
  109. multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
  110. # GZIP is disabled, this call should fail
  111. call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
  112. with self.assertRaises(aio.AioRpcError) as exception_context:
  113. await call
  114. rpc_error = exception_context.exception
  115. self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
  116. async def test_client_call_level_compression_allowed_compression(self):
  117. multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
  118. # Deflate is allowed, this call should succeed
  119. call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
  120. self.assertEqual(grpc.StatusCode.OK, await call.code())
  121. async def test_server_call_level_compression(self):
  122. multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
  123. call = multicallable()
  124. await call.write(_REQUEST)
  125. await call.done_writing()
  126. self.assertEqual(_RESPONSE, await call.read())
  127. self.assertEqual(grpc.StatusCode.OK, await call.code())
  128. async def test_server_disable_compression_unary(self):
  129. multicallable = self._channel.unary_unary(
  130. _TEST_DISABLE_COMPRESSION_UNARY)
  131. call = multicallable(_REQUEST)
  132. self.assertEqual(_RESPONSE, await call)
  133. self.assertEqual(grpc.StatusCode.OK, await call.code())
  134. async def test_server_disable_compression_stream(self):
  135. multicallable = self._channel.stream_stream(
  136. _TEST_DISABLE_COMPRESSION_STREAM)
  137. call = multicallable()
  138. await call.write(_REQUEST)
  139. await call.done_writing()
  140. self.assertEqual(_RESPONSE, await call.read())
  141. self.assertEqual(_RESPONSE, await call.read())
  142. self.assertEqual(_RESPONSE, await call.read())
  143. self.assertEqual(grpc.StatusCode.OK, await call.code())
  144. async def test_server_default_compression_algorithm(self):
  145. server = aio.server(compression=grpc.Compression.Deflate)
  146. port = server.add_insecure_port('[::]:0')
  147. server.add_generic_rpc_handlers((_GenericHandler(),))
  148. await server.start()
  149. async with aio.insecure_channel(f'localhost:{port}') as channel:
  150. multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
  151. call = multicallable(_REQUEST)
  152. self.assertEqual(_RESPONSE, await call)
  153. self.assertEqual(grpc.StatusCode.OK, await call.code())
  154. await server.stop(None)
  155. if __name__ == '__main__':
  156. logging.basicConfig(level=logging.DEBUG)
  157. unittest.main(verbosity=2)