grpc.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright 2020 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import re
  16. from typing import Optional, ClassVar, Dict
  17. # Workaround: `grpc` must be imported before `google.protobuf.json_format`,
  18. # to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
  19. import grpc
  20. from google.protobuf import json_format
  21. import google.protobuf.message
  22. logger = logging.getLogger(__name__)
  23. # Type aliases
  24. Message = google.protobuf.message.Message
  25. class GrpcClientHelper:
  26. channel: grpc.Channel
  27. DEFAULT_CONNECTION_TIMEOUT_SEC = 60
  28. DEFAULT_WAIT_FOR_READY_SEC = 60
  29. def __init__(self, channel: grpc.Channel, stub_class: ClassVar):
  30. self.channel = channel
  31. self.stub = stub_class(channel)
  32. # This is purely cosmetic to make RPC logs look like method calls.
  33. self.log_service_name = re.sub('Stub$', '',
  34. self.stub.__class__.__name__)
  35. def call_unary_with_deadline(
  36. self,
  37. *,
  38. rpc: str,
  39. req: Message,
  40. wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
  41. connection_timeout_sec: Optional[
  42. int] = DEFAULT_CONNECTION_TIMEOUT_SEC) -> Message:
  43. if wait_for_ready_sec is None:
  44. wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
  45. if connection_timeout_sec is None:
  46. connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC
  47. timeout_sec = wait_for_ready_sec + connection_timeout_sec
  48. rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
  49. call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
  50. self._log_debug(rpc, req, call_kwargs)
  51. return rpc_callable(req, **call_kwargs)
  52. def _log_debug(self, rpc, req, call_kwargs):
  53. logger.debug('RPC %s.%s(request=%s(%r), %s)',
  54. self.log_service_name, rpc, req.__class__.__name__,
  55. json_format.MessageToDict(req),
  56. ', '.join({f'{k}={v}' for k, v in call_kwargs.items()}))
  57. class GrpcApp:
  58. channels: Dict[int, grpc.Channel]
  59. class NotFound(Exception):
  60. """Requested resource not found"""
  61. def __init__(self, rpc_host):
  62. self.rpc_host = rpc_host
  63. # Cache gRPC channels per port
  64. self.channels = dict()
  65. def _make_channel(self, port) -> grpc.Channel:
  66. if port not in self.channels:
  67. target = f'{self.rpc_host}:{port}'
  68. self.channels[port] = grpc.insecure_channel(target)
  69. return self.channels[port]
  70. def close(self):
  71. # Close all channels
  72. for channel in self.channels.values():
  73. channel.close()
  74. def __enter__(self):
  75. return self
  76. def __exit__(self, exc_type, exc_val, exc_tb):
  77. self.close()
  78. return False
  79. def __del__(self):
  80. self.close()