|
@@ -0,0 +1,99 @@
|
|
|
|
+# Copyright 2020 gRPC authors.
|
|
|
|
+#
|
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
|
+# You may obtain a copy of the License at
|
|
|
|
+#
|
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
+#
|
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
|
+# limitations under the License.
|
|
|
|
+import logging
|
|
|
|
+import re
|
|
|
|
+from typing import Optional, ClassVar, Dict
|
|
|
|
+
|
|
|
|
+# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
|
|
|
|
+# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
|
|
|
|
+import grpc
|
|
|
|
+from google.protobuf import json_format
|
|
|
|
+import google.protobuf.message
|
|
|
|
+
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
+
|
|
|
|
+# Type aliases
|
|
|
|
+Message = google.protobuf.message.Message
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class GrpcClientHelper:
|
|
|
|
+ channel: grpc.Channel
|
|
|
|
+ DEFAULT_CONNECTION_TIMEOUT_SEC = 60
|
|
|
|
+ DEFAULT_WAIT_FOR_READY_SEC = 60
|
|
|
|
+
|
|
|
|
+ def __init__(self, channel: grpc.Channel, stub_class: ClassVar):
|
|
|
|
+ self.channel = channel
|
|
|
|
+ self.stub = stub_class(channel)
|
|
|
|
+ # This is purely cosmetic to make RPC logs look like method calls.
|
|
|
|
+ self.log_service_name = re.sub('Stub$', '',
|
|
|
|
+ self.stub.__class__.__name__)
|
|
|
|
+
|
|
|
|
+ def call_unary_with_deadline(
|
|
|
|
+ self,
|
|
|
|
+ *,
|
|
|
|
+ rpc: str,
|
|
|
|
+ req: Message,
|
|
|
|
+ wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
|
|
|
|
+ connection_timeout_sec: Optional[
|
|
|
|
+ int] = DEFAULT_CONNECTION_TIMEOUT_SEC) -> Message:
|
|
|
|
+ if wait_for_ready_sec is None:
|
|
|
|
+ wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
|
|
|
|
+ if connection_timeout_sec is None:
|
|
|
|
+ connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC
|
|
|
|
+
|
|
|
|
+ timeout_sec = wait_for_ready_sec + connection_timeout_sec
|
|
|
|
+ rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
|
|
|
|
+
|
|
|
|
+ call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
|
|
|
|
+ self._log_debug(rpc, req, call_kwargs)
|
|
|
|
+ return rpc_callable(req, **call_kwargs)
|
|
|
|
+
|
|
|
|
+ def _log_debug(self, rpc, req, call_kwargs):
|
|
|
|
+ logger.debug('RPC %s.%s(request=%s(%r), %s)',
|
|
|
|
+ self.log_service_name, rpc, req.__class__.__name__,
|
|
|
|
+ json_format.MessageToDict(req),
|
|
|
|
+ ', '.join({f'{k}={v}' for k, v in call_kwargs.items()}))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class GrpcApp:
|
|
|
|
+ channels: Dict[int, grpc.Channel]
|
|
|
|
+
|
|
|
|
+ class NotFound(Exception):
|
|
|
|
+ """Requested resource not found"""
|
|
|
|
+
|
|
|
|
+ def __init__(self, rpc_host):
|
|
|
|
+ self.rpc_host = rpc_host
|
|
|
|
+ # Cache gRPC channels per port
|
|
|
|
+ self.channels = dict()
|
|
|
|
+
|
|
|
|
+ def _make_channel(self, port) -> grpc.Channel:
|
|
|
|
+ if port not in self.channels:
|
|
|
|
+ target = f'{self.rpc_host}:{port}'
|
|
|
|
+ self.channels[port] = grpc.insecure_channel(target)
|
|
|
|
+ return self.channels[port]
|
|
|
|
+
|
|
|
|
+ def close(self):
|
|
|
|
+ # Close all channels
|
|
|
|
+ for channel in self.channels.values():
|
|
|
|
+ channel.close()
|
|
|
|
+
|
|
|
|
+ def __enter__(self):
|
|
|
|
+ return self
|
|
|
|
+
|
|
|
|
+ def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
+ self.close()
|
|
|
|
+ return False
|
|
|
|
+
|
|
|
|
+ def __del__(self):
|
|
|
|
+ self.close()
|