Sfoglia il codice sorgente

Merge pull request #12662 from kpayson64/fix_cpu_100

Explicitly unreference c-core call objects on the server
kpayson64 8 anni fa
parent
commit
2fdf1c00c2

+ 7 - 8
src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi

@@ -24,7 +24,7 @@ cdef class Call:
     self.references = []
 
   def _start_batch(self, operations, tag, retain_self):
-    if not self.is_valid:
+    if self.c_call == NULL:
       raise ValueError("invalid call object cannot be used from Python")
     cdef grpc_call_error result
     cdef Operations cy_operations = Operations(operations)
@@ -53,7 +53,7 @@ cdef class Call:
       self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE,
       details=None):
     details = str_to_bytes(details)
-    if not self.is_valid:
+    if self.c_call == NULL:
       raise ValueError("invalid call object cannot be used from Python")
     if (details is None) != (error_code == GRPC_STATUS__DO_NOT_USE):
       raise ValueError("if error_code is specified, so must details "
@@ -89,13 +89,12 @@ cdef class Call:
       gpr_free(peer)
     return result
 
+  def close(self):
+    if self.c_call != NULL:
+      grpc_call_unref(self.c_call)
+      self.c_call = NULL
+
   def __dealloc__(self):
     if self.c_call != NULL:
       grpc_call_unref(self.c_call)
     grpc_shutdown()
-
-  # The object *should* always be valid from Python. Used for debugging.
-  @property
-  def is_valid(self):
-    return self.c_call != NULL
-

+ 55 - 50
src/python/grpcio/grpc/_server.py

@@ -83,7 +83,8 @@ class _HandlerCallDetails(
 
 class _RPCState(object):
 
-    def __init__(self):
+    def __init__(self, call):
+        self.call = call
         self.condition = threading.Condition()
         self.due = set()
         self.request = None
@@ -106,7 +107,9 @@ def _raise_rpc_error(state):
 
 def _possibly_finish_call(state, token):
     state.due.remove(token)
+
     if (state.client is _CANCELLED or state.statused) and not state.due:
+        state.call.close()
         callbacks = state.callbacks
         state.callbacks = None
         return state, callbacks
@@ -207,10 +210,19 @@ def _send_message(state, token):
 
 class _Context(grpc.ServicerContext):
 
-    def __init__(self, rpc_event, state, request_deserializer):
+    def __init__(self, rpc_event, state):
         self._rpc_event = rpc_event
         self._state = state
-        self._request_deserializer = request_deserializer
+        self._peer = _common.decode(self._rpc_event.operation_call.peer())
+        self._peer_identities = cygrpc.peer_identities(
+            self._rpc_event.operation_call)
+        self._peer_identity_key = cygrpc.peer_identity_key(
+            self._rpc_event.operation_call)
+        self._auth_context = {
+            _common.decode(key): value
+            for key, value in six.iteritems(
+                cygrpc.auth_context(self._rpc_event.operation_call))
+        }
 
     def is_active(self):
         with self._state.condition:
@@ -240,21 +252,17 @@ class _Context(grpc.ServicerContext):
         return _common.to_application_metadata(self._rpc_event.request_metadata)
 
     def peer(self):
-        return _common.decode(self._rpc_event.operation_call.peer())
+        return self._peer
 
     def peer_identities(self):
-        return cygrpc.peer_identities(self._rpc_event.operation_call)
+        return self._peer_identities
 
     def peer_identity_key(self):
-        id_key = cygrpc.peer_identity_key(self._rpc_event.operation_call)
-        return id_key if id_key is None else _common.decode(id_key)
+        return self._peer_identity_key if self._peer_identity_key is None else _common.decode(
+            self._peer_identity_key)
 
     def auth_context(self):
-        return {
-            _common.decode(key): value
-            for key, value in six.iteritems(
-                cygrpc.auth_context(self._rpc_event.operation_call))
-        }
+        return self._auth_context
 
     def send_initial_metadata(self, initial_metadata):
         with self._state.condition:
@@ -370,8 +378,7 @@ def _unary_request(rpc_event, state, request_deserializer):
     return unary_request
 
 
-def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
-    context = _Context(rpc_event, state, request_deserializer)
+def _call_behavior(rpc_event, state, context, behavior, argument):
     try:
         return behavior(argument, context), True
     except Exception as e:  # pylint: disable=broad-except
@@ -461,12 +468,12 @@ def _status(rpc_event, state, serialized_response):
             state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
 
 
-def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
-                            request_deserializer, response_serializer):
+def _unary_response_in_pool(rpc_event, state, context, behavior, argument_thunk,
+                            response_serializer):
     argument = argument_thunk()
     if argument is not None:
-        response, proceed = _call_behavior(rpc_event, state, behavior, argument,
-                                           request_deserializer)
+        response, proceed = _call_behavior(rpc_event, state, context, behavior,
+                                           argument)
         if proceed:
             serialized_response = _serialize_response(
                 rpc_event, state, response, response_serializer)
@@ -474,12 +481,12 @@ def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
                 _status(rpc_event, state, serialized_response)
 
 
-def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
-                             request_deserializer, response_serializer):
+def _stream_response_in_pool(rpc_event, state, context, behavior,
+                             argument_thunk, response_serializer):
     argument = argument_thunk()
     if argument is not None:
-        response_iterator, proceed = _call_behavior(
-            rpc_event, state, behavior, argument, request_deserializer)
+        response_iterator, proceed = _call_behavior(rpc_event, state, context,
+                                                    behavior, argument)
         if proceed:
             while True:
                 response, proceed = _take_response_from_response_iterator(
@@ -502,40 +509,41 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
                     break
 
 
-def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
+def _handle_unary_unary(rpc_event, state, context, method_handler, thread_pool):
     unary_request = _unary_request(rpc_event, state,
                                    method_handler.request_deserializer)
     return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
-                              method_handler.unary_unary, unary_request,
-                              method_handler.request_deserializer,
-                              method_handler.response_serializer)
+                              context, method_handler.unary_unary,
+                              unary_request, method_handler.response_serializer)
 
 
-def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
+def _handle_unary_stream(rpc_event, state, context, method_handler,
+                         thread_pool):
     unary_request = _unary_request(rpc_event, state,
                                    method_handler.request_deserializer)
     return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
-                              method_handler.unary_stream, unary_request,
-                              method_handler.request_deserializer,
-                              method_handler.response_serializer)
+                              context, method_handler.unary_stream,
+                              unary_request, method_handler.response_serializer)
 
 
-def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
+def _handle_stream_unary(rpc_event, state, context, method_handler,
+                         thread_pool):
     request_iterator = _RequestIterator(state, rpc_event.operation_call,
                                         method_handler.request_deserializer)
-    return thread_pool.submit(
-        _unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
-        lambda: request_iterator, method_handler.request_deserializer,
-        method_handler.response_serializer)
+    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
+                              context, method_handler.stream_unary,
+                              lambda: request_iterator,
+                              method_handler.response_serializer)
 
 
-def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
+def _handle_stream_stream(rpc_event, state, context, method_handler,
+                          thread_pool):
     request_iterator = _RequestIterator(state, rpc_event.operation_call,
                                         method_handler.request_deserializer)
-    return thread_pool.submit(
-        _stream_response_in_pool, rpc_event, state,
-        method_handler.stream_stream, lambda: request_iterator,
-        method_handler.request_deserializer, method_handler.response_serializer)
+    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
+                              context, method_handler.stream_stream,
+                              lambda: request_iterator,
+                              method_handler.response_serializer)
 
 
 def _find_method_handler(rpc_event, generic_handlers):
@@ -556,14 +564,15 @@ def _reject_rpc(rpc_event, status, details):
                   cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
                   cygrpc.operation_send_status_from_server(
                       _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
-    rpc_state = _RPCState()
+    rpc_state = _RPCState(rpc_event.operation_call)
     rpc_event.operation_call.start_server_batch(
         operations, lambda ignored_event: (rpc_state, (),))
     return rpc_state
 
 
 def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
-    state = _RPCState()
+    state = _RPCState(rpc_event.operation_call)
+    context = _Context(rpc_event, state)
     with state.condition:
         rpc_event.operation_call.start_server_batch(
             cygrpc.Operations(
@@ -572,17 +581,17 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
         state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
         if method_handler.request_streaming:
             if method_handler.response_streaming:
-                return state, _handle_stream_stream(rpc_event, state,
+                return state, _handle_stream_stream(rpc_event, state, context,
                                                     method_handler, thread_pool)
             else:
-                return state, _handle_stream_unary(rpc_event, state,
+                return state, _handle_stream_unary(rpc_event, state, context,
                                                    method_handler, thread_pool)
         else:
             if method_handler.response_streaming:
-                return state, _handle_unary_stream(rpc_event, state,
+                return state, _handle_unary_stream(rpc_event, state, context,
                                                    method_handler, thread_pool)
             else:
-                return state, _handle_unary_unary(rpc_event, state,
+                return state, _handle_unary_unary(rpc_event, state, context,
                                                   method_handler, thread_pool)
 
 
@@ -706,10 +715,6 @@ def _serve(state):
                     state.rpc_states.remove(rpc_state)
                     if _stop_serving(state):
                         return
-        # We want to force the deletion of the previous event
-        # ~before~ we poll again; if the event has a reference
-        # to a shutdown Call object, this can induce spinlock.
-        event = None
 
 
 def _stop(state, grace):