|  | @@ -0,0 +1,222 @@
 | 
	
		
			
				|  |  | +# Copyright 2016, Google Inc.
 | 
	
		
			
				|  |  | +# All rights reserved.
 | 
	
		
			
				|  |  | +#
 | 
	
		
			
				|  |  | +# Redistribution and use in source and binary forms, with or without
 | 
	
		
			
				|  |  | +# modification, are permitted provided that the following conditions are
 | 
	
		
			
				|  |  | +# met:
 | 
	
		
			
				|  |  | +#
 | 
	
		
			
				|  |  | +#     * Redistributions of source code must retain the above copyright
 | 
	
		
			
				|  |  | +# notice, this list of conditions and the following disclaimer.
 | 
	
		
			
				|  |  | +#     * Redistributions in binary form must reproduce the above
 | 
	
		
			
				|  |  | +# copyright notice, this list of conditions and the following disclaimer
 | 
	
		
			
				|  |  | +# in the documentation and/or other materials provided with the
 | 
	
		
			
				|  |  | +# distribution.
 | 
	
		
			
				|  |  | +#     * Neither the name of Google Inc. nor the names of its
 | 
	
		
			
				|  |  | +# contributors may be used to endorse or promote products derived from
 | 
	
		
			
				|  |  | +# this software without specific prior written permission.
 | 
	
		
			
				|  |  | +#
 | 
	
		
			
				|  |  | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
	
		
			
				|  |  | +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
	
		
			
				|  |  | +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
	
		
			
				|  |  | +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
	
		
			
				|  |  | +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
	
		
			
				|  |  | +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
	
		
			
				|  |  | +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
	
		
			
				|  |  | +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
	
		
			
				|  |  | +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
	
		
			
				|  |  | +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
	
		
			
				|  |  | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +"""Test making many calls and immediately cancelling most of them."""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import threading
 | 
	
		
			
				|  |  | +import unittest
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from grpc._cython import cygrpc
 | 
	
		
			
				|  |  | +from grpc.framework.foundation import logging_pool
 | 
	
		
			
				|  |  | +from tests.unit.framework.common import test_constants
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
 | 
	
		
			
				|  |  | +_EMPTY_FLAGS = 0
 | 
	
		
			
				|  |  | +_EMPTY_METADATA = cygrpc.Metadata(())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_SERVER_SHUTDOWN_TAG = 'server_shutdown'
 | 
	
		
			
				|  |  | +_REQUEST_CALL_TAG = 'request_call'
 | 
	
		
			
				|  |  | +_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
 | 
	
		
			
				|  |  | +_RECEIVE_MESSAGE_TAG = 'receive_message'
 | 
	
		
			
				|  |  | +_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_SUCCESS_CALL_FRACTION = 1.0 / 8.0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class _State(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def __init__(self):
 | 
	
		
			
				|  |  | +    self.condition = threading.Condition()
 | 
	
		
			
				|  |  | +    self.handlers_released = False
 | 
	
		
			
				|  |  | +    self.parked_handlers = 0
 | 
	
		
			
				|  |  | +    self.handled_rpcs = 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _is_cancellation_event(event):
 | 
	
		
			
				|  |  | +  return (
 | 
	
		
			
				|  |  | +      event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
 | 
	
		
			
				|  |  | +      event.batch_operations[0].received_cancelled)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class _Handler(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def __init__(self, state, completion_queue, rpc_event):
 | 
	
		
			
				|  |  | +    self._state = state
 | 
	
		
			
				|  |  | +    self._lock = threading.Lock()
 | 
	
		
			
				|  |  | +    self._completion_queue = completion_queue
 | 
	
		
			
				|  |  | +    self._call = rpc_event.operation_call
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def __call__(self):
 | 
	
		
			
				|  |  | +    with self._state.condition:
 | 
	
		
			
				|  |  | +      self._state.parked_handlers += 1
 | 
	
		
			
				|  |  | +      if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
 | 
	
		
			
				|  |  | +        self._state.condition.notify_all()
 | 
	
		
			
				|  |  | +      while not self._state.handlers_released:
 | 
	
		
			
				|  |  | +        self._state.condition.wait()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with self._lock:
 | 
	
		
			
				|  |  | +      self._call.start_batch(
 | 
	
		
			
				|  |  | +          cygrpc.Operations(
 | 
	
		
			
				|  |  | +              (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
 | 
	
		
			
				|  |  | +          _RECEIVE_CLOSE_ON_SERVER_TAG)
 | 
	
		
			
				|  |  | +      self._call.start_batch(
 | 
	
		
			
				|  |  | +          cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
 | 
	
		
			
				|  |  | +          _RECEIVE_MESSAGE_TAG)
 | 
	
		
			
				|  |  | +    first_event = self._completion_queue.poll()
 | 
	
		
			
				|  |  | +    if _is_cancellation_event(first_event):
 | 
	
		
			
				|  |  | +      self._completion_queue.poll()
 | 
	
		
			
				|  |  | +    else:
 | 
	
		
			
				|  |  | +      with self._lock:
 | 
	
		
			
				|  |  | +        operations = (
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_initial_metadata(
 | 
	
		
			
				|  |  | +                _EMPTY_METADATA, _EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_status_from_server(
 | 
	
		
			
				|  |  | +                _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
 | 
	
		
			
				|  |  | +                _EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +        self._call.start_batch(
 | 
	
		
			
				|  |  | +            cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
 | 
	
		
			
				|  |  | +      self._completion_queue.poll()
 | 
	
		
			
				|  |  | +      self._completion_queue.poll()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _serve(state, server, server_completion_queue, thread_pool):
 | 
	
		
			
				|  |  | +  for _ in range(test_constants.RPC_CONCURRENCY):
 | 
	
		
			
				|  |  | +    call_completion_queue = cygrpc.CompletionQueue()
 | 
	
		
			
				|  |  | +    server.request_call(
 | 
	
		
			
				|  |  | +        call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG)
 | 
	
		
			
				|  |  | +    rpc_event = server_completion_queue.poll()
 | 
	
		
			
				|  |  | +    thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
 | 
	
		
			
				|  |  | +    with state.condition:
 | 
	
		
			
				|  |  | +      state.handled_rpcs += 1
 | 
	
		
			
				|  |  | +      if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
 | 
	
		
			
				|  |  | +        state.condition.notify_all()
 | 
	
		
			
				|  |  | +  server_completion_queue.poll()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class _QueueDriver(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def __init__(self, condition, completion_queue, due):
 | 
	
		
			
				|  |  | +    self._condition = condition
 | 
	
		
			
				|  |  | +    self._completion_queue = completion_queue
 | 
	
		
			
				|  |  | +    self._due = due
 | 
	
		
			
				|  |  | +    self._events = []
 | 
	
		
			
				|  |  | +    self._returned = False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def start(self):
 | 
	
		
			
				|  |  | +    def in_thread():
 | 
	
		
			
				|  |  | +      while True:
 | 
	
		
			
				|  |  | +        event = self._completion_queue.poll()
 | 
	
		
			
				|  |  | +        with self._condition:
 | 
	
		
			
				|  |  | +          self._events.append(event)
 | 
	
		
			
				|  |  | +          self._due.remove(event.tag)
 | 
	
		
			
				|  |  | +          self._condition.notify_all()
 | 
	
		
			
				|  |  | +          if not self._due:
 | 
	
		
			
				|  |  | +            self._returned = True
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  | +    thread = threading.Thread(target=in_thread)
 | 
	
		
			
				|  |  | +    thread.start()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def events(self, at_least):
 | 
	
		
			
				|  |  | +    with self._condition:
 | 
	
		
			
				|  |  | +      while len(self._events) < at_least:
 | 
	
		
			
				|  |  | +        self._condition.wait()
 | 
	
		
			
				|  |  | +      return tuple(self._events)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class CancelManyCallsTest(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  def testCancelManyCalls(self):
 | 
	
		
			
				|  |  | +    server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    server_completion_queue = cygrpc.CompletionQueue()
 | 
	
		
			
				|  |  | +    server = cygrpc.Server()
 | 
	
		
			
				|  |  | +    server.register_completion_queue(server_completion_queue)
 | 
	
		
			
				|  |  | +    port = server.add_http2_port('[::]:0')
 | 
	
		
			
				|  |  | +    server.start()
 | 
	
		
			
				|  |  | +    channel = cygrpc.Channel('localhost:{}'.format(port))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    state = _State()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    server_thread_args = (
 | 
	
		
			
				|  |  | +        state, server, server_completion_queue, server_thread_pool,)
 | 
	
		
			
				|  |  | +    server_thread = threading.Thread(target=_serve, args=server_thread_args)
 | 
	
		
			
				|  |  | +    server_thread.start()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    client_condition = threading.Condition()
 | 
	
		
			
				|  |  | +    client_due = set()
 | 
	
		
			
				|  |  | +    client_completion_queue = cygrpc.CompletionQueue()
 | 
	
		
			
				|  |  | +    client_driver = _QueueDriver(
 | 
	
		
			
				|  |  | +        client_condition, client_completion_queue, client_due)
 | 
	
		
			
				|  |  | +    client_driver.start()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with client_condition:
 | 
	
		
			
				|  |  | +      client_calls = []
 | 
	
		
			
				|  |  | +      for index in range(test_constants.RPC_CONCURRENCY):
 | 
	
		
			
				|  |  | +        client_call = channel.create_call(
 | 
	
		
			
				|  |  | +            None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
 | 
	
		
			
				|  |  | +            _INFINITE_FUTURE)
 | 
	
		
			
				|  |  | +        operations = (
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_initial_metadata(
 | 
	
		
			
				|  |  | +                _EMPTY_METADATA, _EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_receive_message(_EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +            cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +        tag = 'client_complete_call_{0:04d}_tag'.format(index)
 | 
	
		
			
				|  |  | +        client_call.start_batch(cygrpc.Operations(operations), tag)
 | 
	
		
			
				|  |  | +        client_due.add(tag)
 | 
	
		
			
				|  |  | +        client_calls.append(client_call)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with state.condition:
 | 
	
		
			
				|  |  | +      while True:
 | 
	
		
			
				|  |  | +        if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
 | 
	
		
			
				|  |  | +          state.condition.wait()
 | 
	
		
			
				|  |  | +        elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
 | 
	
		
			
				|  |  | +          state.condition.wait()
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +          state.handlers_released = True
 | 
	
		
			
				|  |  | +          state.condition.notify_all()
 | 
	
		
			
				|  |  | +          break
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    client_driver.events(
 | 
	
		
			
				|  |  | +        test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
 | 
	
		
			
				|  |  | +    with client_condition:
 | 
	
		
			
				|  |  | +      for client_call in client_calls:
 | 
	
		
			
				|  |  | +        client_call.cancel()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with state.condition:
 | 
	
		
			
				|  |  | +      server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +if __name__ == '__main__':
 | 
	
		
			
				|  |  | +  unittest.main(verbosity=2)
 |