123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- # Copyright 2020 gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import re
- from typing import Optional, ClassVar, Dict
- # Workaround: `grpc` must be imported before `google.protobuf.json_format`,
- # to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
- import grpc
- from google.protobuf import json_format
- import google.protobuf.message
# Module-wide logger; the RPC debug lines emitted below are written to it.
logger = logging.getLogger(name=__name__)

# Type aliases
# Short name for the protobuf base message class, used to annotate RPC
# requests and responses in this module.
Message = google.protobuf.message.Message
class GrpcClientHelper:
    """Wraps a gRPC stub to issue unary calls with a deadline and debug logs.

    Attributes:
      channel: The channel the stub was built on.
      stub: Instance of ``stub_class`` bound to ``channel``.
      log_service_name: Stub class name with a trailing ``Stub`` removed,
        used only to make RPC log lines read like method calls.
    """
    channel: grpc.Channel
    DEFAULT_CONNECTION_TIMEOUT_SEC = 60
    DEFAULT_WAIT_FOR_READY_SEC = 60

    def __init__(self, channel: grpc.Channel, stub_class: type):
        """Build ``stub_class(channel)`` and derive the logging name.

        Args:
          channel: Channel handed to the stub constructor.
          stub_class: Callable (normally a generated ``*Stub`` class) that
            accepts the channel and returns an object exposing RPC methods
            as attributes.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # This is purely cosmetic to make RPC logs look like method calls.
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
            connection_timeout_sec: Optional[
                int] = DEFAULT_CONNECTION_TIMEOUT_SEC) -> Message:
        """Invoke stub method ``rpc`` with ``req`` and return its result.

        The call is made with ``wait_for_ready=True`` and a timeout equal to
        ``wait_for_ready_sec + connection_timeout_sec``.

        Args:
          rpc: Name of the stub attribute to call (e.g. ``'GetStats'``).
          req: Request message passed as the single positional argument.
          wait_for_ready_sec: Seconds budgeted for waiting on the channel;
            ``None`` selects ``DEFAULT_WAIT_FOR_READY_SEC``.
          connection_timeout_sec: Additional seconds added to the deadline;
            ``None`` selects ``DEFAULT_CONNECTION_TIMEOUT_SEC``.

        Returns:
          Whatever the stub callable returns for ``req``.

        Raises:
          AttributeError: If the stub has no attribute named ``rpc``.
          Any exception raised by the stub call propagates unchanged.
        """
        if wait_for_ready_sec is None:
            wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
        if connection_timeout_sec is None:
            connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC

        timeout_sec = wait_for_ready_sec + connection_timeout_sec
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
        self._log_debug(rpc, req, call_kwargs)
        return rpc_callable(req, **call_kwargs)

    def _log_debug(self, rpc, req, call_kwargs):
        """Log the outgoing RPC, its request as a dict, and its call kwargs."""
        # MessageToDict walks the whole request; skip it when the record
        # would be discarded anyway.
        if not logger.isEnabledFor(logging.DEBUG):
            return
        # Generator (not a set) so kwargs appear in a stable, insertion order.
        logger.debug('RPC %s.%s(request=%s(%r), %s)',
                     self.log_service_name, rpc, req.__class__.__name__,
                     json_format.MessageToDict(req),
                     ', '.join(f'{k}={v}' for k, v in call_kwargs.items()))
class GrpcApp:
    """Holds per-port gRPC channels to the application at ``rpc_host``.

    Channels are created lazily by ``_make_channel`` and cached by port. They
    are closed by ``close()``, on leaving a ``with`` block, or when the object
    is garbage collected.
    """
    # Cache of open channels keyed by port.
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

    def __init__(self, rpc_host):
        """Args:
          rpc_host: Host part of the ``host:port`` target for every channel.
        """
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = {}

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached insecure channel for ``port``, creating it once."""
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close every cached channel and remove it from the cache.

        Because closed channels are dropped from the cache, calling this
        again (e.g. from ``__exit__`` and later ``__del__``) does nothing.
        A subsequent ``_make_channel`` call also opens a fresh channel
        instead of returning a closed one.
        """
        # getattr: __del__ can run on an instance whose __init__ never
        # assigned `channels` (e.g. a subclass __init__ raised early).
        channels = getattr(self, 'channels', None)
        while channels:
            # Pop before closing so each channel is closed at most once. If a
            # close() raises, the channels still cached can be closed by a
            # later call.
            port = next(iter(channels))
            channels.pop(port).close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Do not suppress exceptions raised inside the `with` body.
        return False

    def __del__(self):
        self.close()
|