|
|
@@ -18,11 +18,18 @@ import unittest
|
|
|
import grpc
|
|
|
|
|
|
from grpc.experimental import aio
|
|
|
-from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
|
|
|
+from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY
|
|
|
+from tests_aio.unit import _constants
|
|
|
+from tests_aio.unit import _common
|
|
|
from tests_aio.unit._test_base import AioTestBase
|
|
|
-from src.proto.grpc.testing import messages_pb2
|
|
|
+from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
|
|
|
+
|
|
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
|
|
|
+_INITIAL_METADATA_TO_INJECT = (
|
|
|
+ (_INITIAL_METADATA_KEY, 'extra info'),
|
|
|
+ (_TRAILING_METADATA_KEY, b'\x13\x37'),
|
|
|
+)
|
|
|
|
|
|
|
|
|
class TestUnaryUnaryClientInterceptor(AioTestBase):
|
|
|
@@ -124,7 +131,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
|
|
|
client_call_details, request):
|
|
|
new_client_call_details = aio.ClientCallDetails(
|
|
|
method=client_call_details.method,
|
|
|
- timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
|
|
|
+ timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
|
|
|
metadata=client_call_details.metadata,
|
|
|
credentials=client_call_details.credentials)
|
|
|
return await continuation(new_client_call_details, request)
|
|
|
@@ -165,7 +172,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
|
|
|
|
|
|
new_client_call_details = aio.ClientCallDetails(
|
|
|
method=client_call_details.method,
|
|
|
- timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
|
|
|
+ timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
|
|
|
metadata=client_call_details.metadata,
|
|
|
credentials=client_call_details.credentials)
|
|
|
|
|
|
@@ -342,8 +349,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString)
|
|
|
|
|
|
- call = multicallable(messages_pb2.SimpleRequest(),
|
|
|
- timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
|
|
|
+ call = multicallable(
|
|
|
+ messages_pb2.SimpleRequest(),
|
|
|
+ timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context:
|
|
|
await call
|
|
|
@@ -375,8 +383,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString)
|
|
|
|
|
|
- call = multicallable(messages_pb2.SimpleRequest(),
|
|
|
- timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
|
|
|
+ call = multicallable(
|
|
|
+ messages_pb2.SimpleRequest(),
|
|
|
+ timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context:
|
|
|
await call
|
|
|
@@ -532,6 +541,42 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
|
|
|
self.assertEqual(await call.initial_metadata(), tuple())
|
|
|
self.assertEqual(await call.trailing_metadata(), None)
|
|
|
|
|
|
+ async def test_initial_metadata_modification(self):
|
|
|
+
|
|
|
+ class Interceptor(aio.UnaryUnaryClientInterceptor):
|
|
|
+
|
|
|
+ async def intercept_unary_unary(self, continuation,
|
|
|
+ client_call_details, request):
|
|
|
+ if client_call_details.metadata is not None:
|
|
|
+ new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT
|
|
|
+ else:
|
|
|
+ new_metadata = _INITIAL_METADATA_TO_INJECT
|
|
|
+ new_details = aio.ClientCallDetails(
|
|
|
+ method=client_call_details.method,
|
|
|
+ timeout=client_call_details.timeout,
|
|
|
+ metadata=new_metadata,
|
|
|
+ credentials=client_call_details.credentials,
|
|
|
+ )
|
|
|
+ return await continuation(new_details, request)
|
|
|
+
|
|
|
+ async with aio.insecure_channel(self._server_target,
|
|
|
+ interceptors=[Interceptor()
|
|
|
+ ]) as channel:
|
|
|
+ stub = test_pb2_grpc.TestServiceStub(channel)
|
|
|
+ call = stub.UnaryCall(messages_pb2.SimpleRequest())
|
|
|
+
|
|
|
+ # Expected to see the echoed initial metadata
|
|
|
+ self.assertTrue(
|
|
|
+ _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await
|
|
|
+ call.initial_metadata()))
|
|
|
+
|
|
|
+ # Expected to see the echoed trailing metadata
|
|
|
+ self.assertTrue(
|
|
|
+ _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await
|
|
|
+ call.trailing_metadata()))
|
|
|
+
|
|
|
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
logging.basicConfig()
|