# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import collections
import enum
import json
import os
import threading

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc
from grpc.beta import implementations

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


def _maybe_echo_metadata(servicer_context):
    """Copies metadata from request to response if it is present."""
    invocation_metadata = dict(servicer_context.invocation_metadata())
    if _INITIAL_METADATA_KEY in invocation_metadata:
        initial_metadatum = (_INITIAL_METADATA_KEY,
                             invocation_metadata[_INITIAL_METADATA_KEY])
        servicer_context.send_initial_metadata((initial_metadatum,))
    if _TRAILING_METADATA_KEY in invocation_metadata:
        trailing_metadatum = (_TRAILING_METADATA_KEY,
                              invocation_metadata[_TRAILING_METADATA_KEY])
        servicer_context.set_trailing_metadata((trailing_metadatum,))


def _maybe_echo_status_and_message(request, servicer_context):
    """Sets the response context code and details if the request asks for them"""
    if request.HasField('response_status'):
        servicer_context.set_code(request.response_status.code)
        servicer_context.set_details(request.response_status.message)


class TestService(test_pb2_grpc.TestServiceServicer):
    """Server-side implementation of the interop TestService methods."""

    def EmptyCall(self, request, context):
        """Echoes any requested metadata and returns an Empty message."""
        _maybe_echo_metadata(context)
        return empty_pb2.Empty()

    def UnaryCall(self, request, context):
        """Returns a zero-filled payload of `request.response_size` bytes."""
        _maybe_echo_metadata(context)
        _maybe_echo_status_and_message(request, context)
        return messages_pb2.SimpleResponse(
            payload=messages_pb2.Payload(
                type=messages_pb2.COMPRESSABLE,
                body=b'\x00' * request.response_size))

    def StreamingOutputCall(self, request, context):
        """Yields one zero-filled response per requested response parameter."""
        _maybe_echo_status_and_message(request, context)
        for response_parameters in request.response_parameters:
            yield messages_pb2.StreamingOutputCallResponse(
                payload=messages_pb2.Payload(
                    type=request.response_type,
                    body=b'\x00' * response_parameters.size))

    def StreamingInputCall(self, request_iterator, context):
        """Returns the total size of all payload bodies received."""
        aggregate_size = 0
        for request in request_iterator:
            if request.payload is not None and request.payload.body:
                aggregate_size += len(request.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=aggregate_size)

    def FullDuplexCall(self, request_iterator, context):
        """Answers each incoming request with its requested responses."""
        _maybe_echo_metadata(context)
        for request in request_iterator:
            _maybe_echo_status_and_message(request, context)
            for response_parameters in request.response_parameters:
                yield messages_pb2.StreamingOutputCallResponse(
                    payload=messages_pb2.Payload(
                        type=request.payload.type,
                        body=b'\x00' * response_parameters.size))

    # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
    # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
    def HalfDuplexCall(self, request_iterator, context):
        """Delegates to FullDuplexCall."""
        return self.FullDuplexCall(request_iterator, context)


def _expect_status_code(call, expected_code):
    """Raises ValueError unless `call.code()` equals `expected_code`."""
    if call.code() != expected_code:
        raise ValueError('expected code %s, got %s' % (expected_code,
                                                       call.code()))


def _expect_status_details(call, expected_details):
    """Raises ValueError unless `call.details()` equals `expected_details`."""
    if call.details() != expected_details:
        raise ValueError('expected message %s, got %s' % (expected_details,
                                                          call.details()))


def _validate_status_code_and_details(call, expected_code, expected_details):
    """Checks both the status code and the status details of `call`."""
    _expect_status_code(call, expected_code)
    _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Raises ValueError if the response payload type or body size is wrong."""
    # Payload types are enum values; compare by equality, not identity, so the
    # check does not depend on interpreter-level integer caching.
    if response.payload.type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, response.payload.type))
    elif len(response.payload.body) != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, len(response.payload.body)))


def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Issues a large UnaryCall, validates the payload, returns the response.

    The request carries a 271828-byte body and asks for a 314159-byte
    COMPRESSABLE response. `call_credentials` is passed through to the call.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b'\x00' * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response_future = stub.UnaryCall.future(
        request, credentials=call_credentials)
    response = response_future.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


def _empty_unary(stub):
    """Calls EmptyCall and checks that an Empty message comes back."""
    response = stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError(
            'response is of type "%s", not empty_pb2.Empty!' % type(response))


def _large_unary(stub):
    """Runs the large unary exchange without credentials or auth fields."""
    _large_unary_common_behavior(stub, False, False, None)


def _client_streaming(stub):
    """Streams four payloads and checks the aggregated size reported back."""
    payload_body_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    payloads = (messages_pb2.Payload(body=b'\x00' * size)
                for size in payload_body_sizes)
    requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
                for payload in payloads)
    response = stub.StreamingInputCall(requests)
    # 74922 is the sum of payload_body_sizes.
    if response.aggregated_payload_size != 74922:
        raise ValueError(
            'incorrect size %d!' % response.aggregated_payload_size)


def _server_streaming(stub):
    """Requests four streamed responses and validates each one's payload."""
    sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(
            messages_pb2.ResponseParameters(size=sizes[0]),
            messages_pb2.ResponseParameters(size=sizes[1]),
            messages_pb2.ResponseParameters(size=sizes[2]),
            messages_pb2.ResponseParameters(size=sizes[3]),
        ))
    response_iterator = stub.StreamingOutputCall(request)
    for index, response in enumerate(response_iterator):
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          sizes[index])


class _Pipe(object):
    """A blocking iterator fed through add() and terminated by close().

    Iteration blocks while the pipe is open and empty. Once closed, remaining
    values are still delivered, after which iteration stops. Usable as a
    context manager; leaving the block closes the pipe.
    """

    def __init__(self):
        self._condition = threading.Condition()
        # FIFO of pending values; deque gives O(1) removal from the front.
        self._values = collections.deque()
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Returns the next value, blocking until one arrives or close()."""
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.popleft()
            else:
                raise StopIteration()

    def add(self, value):
        """Appends `value` and wakes a waiting reader."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed and wakes every waiting reader."""
        with self._condition:
            self._open = False
            # Closing is a terminal state change: all waiters must observe it.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()


def _ping_pong(stub):
    """Alternates requests and responses on a FullDuplexCall, validating each."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in zip(request_response_sizes,
                                               request_payload_sizes):
            request = messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size))
            pipe.add(request)
            response = next(response_iterator)
            _validate_payload_type_and_length(
                response, messages_pb2.COMPRESSABLE, response_size)


def _cancel_after_begin(stub):
    """Cancels a StreamingInputCall right after starting it."""
    with _Pipe() as pipe:
        response_future = stub.StreamingInputCall.future(pipe)
        response_future.cancel()
        if not response_future.cancelled():
            raise ValueError('expected cancelled method to return True')
        if response_future.code() is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')


def _cancel_after_first_response(stub):
    """Cancels a FullDuplexCall after receiving its first response."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)

        response_size = request_response_sizes[0]
        payload_size = request_payload_sizes[0]
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=response_size),),
            payload=messages_pb2.Payload(body=b'\x00' * payload_size))
        pipe.add(request)
        # We test the contents of the response in the Ping Pong test - don't
        # check them here.
        next(response_iterator)
        response_iterator.cancel()

        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')


def _timeout_on_sleeping_server(stub):
    """Expects DEADLINE_EXCEEDED from a FullDuplexCall with a 1 ms timeout."""
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b'\x00' * request_payload_size))
        pipe.add(request)
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
        else:
            raise ValueError('expected call to exceed deadline')


def _empty_stream(stub):
    """Closes the request stream immediately and expects zero responses."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.close()
        try:
            next(response_iterator)
        except StopIteration:
            pass
        else:
            raise ValueError('expected exactly 0 responses')


def _status_code_and_message(stub):
    """Asks the server to echo a status and checks it on unary and duplex."""
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)

    # Test with a FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(code=code, message=details))
        pipe.add(request)  # sends the initial request.
    # Dropping out of with block closes the pipe
    _validate_status_code_and_details(response_iterator, status, details)


def _unimplemented_method(test_service_stub):
    """Expects UNIMPLEMENTED when calling UnimplementedCall on the stub."""
    response_future = (test_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()))
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _unimplemented_service(unimplemented_service_stub):
    """Expects UNIMPLEMENTED when calling UnimplementedCall on the stub."""
    response_future = (unimplemented_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()))
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _custom_metadata(stub):
    """Sends echo metadata and checks it comes back on unary and duplex calls."""
    initial_metadata_value = "test_initial_metadata_value"
    # The trailing key ends in "-bin", i.e. it carries binary metadata, which
    # gRPC Python surfaces as bytes; compare against a bytes literal.
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        initial_metadata = dict(response.initial_metadata())
        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value,
                              initial_metadata[_INITIAL_METADATA_KEY]))
        trailing_metadata = dict(response.trailing_metadata())
        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
            # Report the value actually received in the trailing metadata.
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value,
                              trailing_metadata[_TRAILING_METADATA_KEY]))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)


def _service_account_email():
    """Returns 'client_email' from the JSON key file named by the environment.

    The file path is read from the environment variable whose name is
    `google_auth_environment_vars.CREDENTIALS`.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a context manager so the key file handle is closed promptly.
    with open(json_key_filename, 'rb') as json_key_file:
        return json.load(json_key_file)['client_email']


def _compute_engine_creds(stub, args):
    """Checks the echoed username equals `args.default_service_account`."""
    response = _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError('expected username %s, got %s' %
                         (args.default_service_account, response.username))


def _oauth2_auth_token(stub, args):
    """Checks the echoed username and that the echoed scope is expected."""
    wanted_email = _service_account_email()
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
    # The echoed scope must be a substring of the scope given in args.
    if args.oauth_scope.find(response.oauth_scope) == -1:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))


def _jwt_token_creds(stub, args):
    """Checks the echoed username matches the key file's client email."""
    wanted_email = _service_account_email()
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


def _per_rpc_creds(stub, args):
    """Attaches Google call credentials to one RPC and checks the username."""
    wanted_email = _service_account_email()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


def _special_status_message(stub, args):
    """Checks a status message with whitespace and non-ASCII text round-trips."""
    details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode(
        'utf-8')
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)


@enum.unique
class TestCase(enum.Enum):
    """The interop test cases this module can run against a stub."""

    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'

    def test_interoperability(self, stub, args):
        """Runs the test routine for this case against `stub`.

        Args:
          stub: The client stub handed to the test routine.
          args: Parsed options; only forwarded to the routines that take them.

        Raises:
          NotImplementedError: If this member has no associated routine.
        """
        if self is TestCase.EMPTY_UNARY:
            _empty_unary(stub)
        elif self is TestCase.LARGE_UNARY:
            _large_unary(stub)
        elif self is TestCase.SERVER_STREAMING:
            _server_streaming(stub)
        elif self is TestCase.CLIENT_STREAMING:
            _client_streaming(stub)
        elif self is TestCase.PING_PONG:
            _ping_pong(stub)
        elif self is TestCase.CANCEL_AFTER_BEGIN:
            _cancel_after_begin(stub)
        elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
            _cancel_after_first_response(stub)
        elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
            _timeout_on_sleeping_server(stub)
        elif self is TestCase.EMPTY_STREAM:
            _empty_stream(stub)
        elif self is TestCase.STATUS_CODE_AND_MESSAGE:
            _status_code_and_message(stub)
        elif self is TestCase.UNIMPLEMENTED_METHOD:
            _unimplemented_method(stub)
        elif self is TestCase.UNIMPLEMENTED_SERVICE:
            _unimplemented_service(stub)
        elif self is TestCase.CUSTOM_METADATA:
            _custom_metadata(stub)
        elif self is TestCase.COMPUTE_ENGINE_CREDS:
            _compute_engine_creds(stub, args)
        elif self is TestCase.OAUTH2_AUTH_TOKEN:
            _oauth2_auth_token(stub, args)
        elif self is TestCase.JWT_TOKEN_CREDS:
            _jwt_token_creds(stub, args)
        elif self is TestCase.PER_RPC_CREDS:
            _per_rpc_creds(stub, args)
        elif self is TestCase.SPECIAL_STATUS_MESSAGE:
            _special_status_message(stub, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name)