Bladeren bron

Add metadata auth plugin API support

Masood Malekghassemi 10 jaren geleden
bovenliggende
commit
0f1bf32387

+ 48 - 0
src/python/grpcio/grpc/_adapter/_implementations.py

@@ -0,0 +1,48 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import collections
+
+from grpc.beta import interfaces
+
+class AuthMetadataContext(collections.namedtuple(
+    'AuthMetadataContext', [
+        'service_url',
+        'method_name'
+    ]), interfaces.GRPCAuthMetadataContext):
+  pass
+
+
+class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext):
+
+  def __init__(self, callback):
+    self._callback = callback
+
+  def __call__(self, metadata, error):
+    self._callback(metadata, error)

+ 4 - 21
src/python/grpcio/grpc/_adapter/_intermediary_low.py

@@ -173,20 +173,17 @@ class Call(object):
     return self._internal.peer()
     return self._internal.peer()
 
 
   def set_credentials(self, creds):
   def set_credentials(self, creds):
-    return self._internal.set_credentials(creds._internal)
+    return self._internal.set_credentials(creds)
 
 
 
 
 class Channel(object):
 class Channel(object):
   """Adapter from old _low.Channel interface to new _low.Channel."""
   """Adapter from old _low.Channel interface to new _low.Channel."""
 
 
-  def __init__(self, hostport, client_credentials, server_host_override=None):
+  def __init__(self, hostport, channel_credentials, server_host_override=None):
     args = []
     args = []
     if server_host_override:
     if server_host_override:
       args.append((_types.GrpcChannelArgumentKeys.SSL_TARGET_NAME_OVERRIDE.value, server_host_override))
       args.append((_types.GrpcChannelArgumentKeys.SSL_TARGET_NAME_OVERRIDE.value, server_host_override))
-    creds = None
-    if client_credentials:
-      creds = client_credentials._internal
-    self._internal = _low.Channel(hostport, args, creds)
+    self._internal = _low.Channel(hostport, args, channel_credentials)
 
 
 
 
 class CompletionQueue(object):
 class CompletionQueue(object):
@@ -245,7 +242,7 @@ class Server(object):
     if server_credentials is None:
     if server_credentials is None:
       return self._internal.add_http2_port(addr, None)
       return self._internal.add_http2_port(addr, None)
     else:
     else:
-      return self._internal.add_http2_port(addr, server_credentials._internal)
+      return self._internal.add_http2_port(addr, server_credentials)
 
 
   def start(self):
   def start(self):
     return self._internal.start()
     return self._internal.start()
@@ -259,17 +256,3 @@ class Server(object):
   def stop(self):
   def stop(self):
     return self._internal.shutdown(_TagAdapter(None, Event.Kind.STOP))
     return self._internal.shutdown(_TagAdapter(None, Event.Kind.STOP))
 
 
-
-class ClientCredentials(object):
-  """Adapter from old _low.ClientCredentials interface to new _low.ChannelCredentials."""
-
-  def __init__(self, root_certificates, private_key, certificate_chain):
-    self._internal = _low.channel_credentials_ssl(root_certificates, private_key, certificate_chain)
-
-
-class ServerCredentials(object):
-  """Adapter from old _low.ServerCredentials interface to new _low.ServerCredentials."""
-
-  def __init__(self, root_credentials, pair_sequence, force_client_auth):
-    self._internal = _low.server_credentials_ssl(
-        root_credentials, pair_sequence, force_client_auth)

+ 80 - 0
src/python/grpcio/grpc/_adapter/_low.py

@@ -27,8 +27,11 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
+import threading
+
 from grpc import _grpcio_metadata
 from grpc import _grpcio_metadata
 from grpc._cython import cygrpc
 from grpc._cython import cygrpc
+from grpc._adapter import _implementations
 from grpc._adapter import _types
 from grpc._adapter import _types
 
 
 _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
 _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
@@ -37,6 +40,9 @@ ChannelCredentials = cygrpc.ChannelCredentials
 CallCredentials = cygrpc.CallCredentials
 CallCredentials = cygrpc.CallCredentials
 ServerCredentials = cygrpc.ServerCredentials
 ServerCredentials = cygrpc.ServerCredentials
 
 
+channel_credentials_composite = cygrpc.channel_credentials_composite
+call_credentials_composite = cygrpc.call_credentials_composite
+
 def server_credentials_ssl(root_credentials, pair_sequence, force_client_auth):
 def server_credentials_ssl(root_credentials, pair_sequence, force_client_auth):
   return cygrpc.server_credentials_ssl(
   return cygrpc.server_credentials_ssl(
       root_credentials,
       root_credentials,
@@ -51,6 +57,80 @@ def channel_credentials_ssl(
   return cygrpc.channel_credentials_ssl(root_certificates, pair)
   return cygrpc.channel_credentials_ssl(root_certificates, pair)
 
 
 
 
+class _WrappedCygrpcCallback(object):
+
+  def __init__(self, cygrpc_callback):
+    self.is_called = False
+    self.error = None
+    self.is_called_lock = threading.Lock()
+    self.cygrpc_callback = cygrpc_callback
+
+  def _invoke_failure(self, error):
+    # TODO(atash) translate different Exception superclasses into different
+    # status codes.
+    self.cygrpc_callback(
+        cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message)
+
+  def _invoke_success(self, metadata):
+    try:
+      cygrpc_metadata = cygrpc.Metadata(
+          cygrpc.Metadatum(key, value)
+          for key, value in metadata)
+    except Exception as error:
+      self._invoke_failure(error)
+      return
+    self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
+
+  def __call__(self, metadata, error):
+    with self.is_called_lock:
+      if self.is_called:
+        raise RuntimeError('callback should only ever be invoked once')
+      if self.error:
+        self._invoke_failure(self.error)
+        return
+      self.is_called = True
+    if error is None:
+      self._invoke_success(metadata)
+    else:
+      self._invoke_failure(error)
+
+  def notify_failure(self, error):
+    with self.is_called_lock:
+      if not self.is_called:
+        self.error = error
+
+
+class _WrappedPlugin(object):
+
+  def __init__(self, plugin):
+    self.plugin = plugin
+
+  def __call__(self, context, cygrpc_callback):
+    wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
+    wrapped_context = _implementations.AuthMetadataContext(context.service_url,
+                                                           context.method_name)
+    try:
+      self.plugin(
+          wrapped_context,
+          _implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback))
+    except Exception as error:
+      wrapped_cygrpc_callback.notify_failure(error)
+      raise
+
+
+def call_credentials_metadata_plugin(plugin, name):
+  """
+  Args:
+    plugin: A callable accepting a _types.AuthMetadataContext
+      object and a callback (itself accepting a list of metadata key/value
+      2-tuples and a None-able exception value). The callback must be eventually
+      called, but need not be called in plugin's invocation.
+      plugin's invocation must be non-blocking.
+  """
+  return cygrpc.call_credentials_metadata_plugin(
+      cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
+
+
 class CompletionQueue(_types.CompletionQueue):
 class CompletionQueue(_types.CompletionQueue):
 
 
   def __init__(self):
   def __init__(self):

+ 23 - 0
src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd

@@ -27,7 +27,10 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
+cimport cpython
+
 from grpc._cython._cygrpc cimport grpc
 from grpc._cython._cygrpc cimport grpc
+from grpc._cython._cygrpc cimport records
 
 
 
 
 cdef class ChannelCredentials:
 cdef class ChannelCredentials:
@@ -49,3 +52,23 @@ cdef class ServerCredentials:
   cdef grpc.grpc_ssl_pem_key_cert_pair *c_ssl_pem_key_cert_pairs
   cdef grpc.grpc_ssl_pem_key_cert_pair *c_ssl_pem_key_cert_pairs
   cdef size_t c_ssl_pem_key_cert_pairs_count
   cdef size_t c_ssl_pem_key_cert_pairs_count
   cdef list references
   cdef list references
+
+
+cdef class CredentialsMetadataPlugin:
+
+  cdef object plugin_callback
+  cdef str plugin_name
+
+  cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self)
+
+
+cdef class AuthMetadataContext:
+
+  cdef grpc.grpc_auth_metadata_context context
+
+
+cdef void plugin_get_metadata(
+    void *state, grpc.grpc_auth_metadata_context context,
+    grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil
+
+cdef void plugin_destroy_c_plugin_state(void *state)

+ 71 - 0
src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx

@@ -27,6 +27,8 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
+cimport cpython
+
 from grpc._cython._cygrpc cimport grpc
 from grpc._cython._cygrpc cimport grpc
 from grpc._cython._cygrpc cimport records
 from grpc._cython._cygrpc cimport records
 
 
@@ -78,6 +80,66 @@ cdef class ServerCredentials:
       grpc.grpc_server_credentials_release(self.c_credentials)
       grpc.grpc_server_credentials_release(self.c_credentials)
 
 
 
 
+cdef class CredentialsMetadataPlugin:
+
+  def __cinit__(self, object plugin_callback, str name):
+    """
+    Args:
+      plugin_callback (callable): Callback accepting a service URL (str/bytes)
+        and callback object (accepting a records.Metadata,
+        grpc.grpc_status_code, and a str/bytes error message). This argument
+        when called should be non-blocking and eventually call the callback
+        object with the appropriate status code/details and metadata (if
+        successful).
+      name (str): Plugin name.
+    """
+    if not callable(plugin_callback):
+      raise ValueError('expected callable plugin_callback')
+    self.plugin_callback = plugin_callback
+    self.plugin_name = name
+
+  @staticmethod
+  cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self):
+    cdef grpc.grpc_metadata_credentials_plugin result
+    result.get_metadata = plugin_get_metadata
+    result.destroy = plugin_destroy_c_plugin_state
+    result.state = <void *>self
+    result.type = self.plugin_name
+    cpython.Py_INCREF(self)
+    return result
+
+
+cdef class AuthMetadataContext:
+
+  def __cinit__(self):
+    self.context.service_url = NULL
+    self.context.method_name = NULL
+
+  @property
+  def service_url(self):
+    return self.context.service_url
+
+  @property
+  def method_name(self):
+    return self.context.method_name
+
+
+cdef void plugin_get_metadata(
+    void *state, grpc.grpc_auth_metadata_context context,
+    grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil:
+  def python_callback(
+      records.Metadata metadata, grpc.grpc_status_code status,
+      const char *error_details):
+    cb(user_data, metadata.c_metadata_array.metadata,
+       metadata.c_metadata_array.count, status, error_details)
+  cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state
+  cdef AuthMetadataContext cy_context = AuthMetadataContext()
+  cy_context.context = context
+  self.plugin_callback(cy_context, python_callback)
+
+cdef void plugin_destroy_c_plugin_state(void *state):
+  cpython.Py_DECREF(<CredentialsMetadataPlugin>state)
+
 def channel_credentials_google_default():
 def channel_credentials_google_default():
   cdef ChannelCredentials credentials = ChannelCredentials();
   cdef ChannelCredentials credentials = ChannelCredentials();
   credentials.c_credentials = grpc.grpc_google_default_credentials_create()
   credentials.c_credentials = grpc.grpc_google_default_credentials_create()
@@ -185,6 +247,15 @@ def call_credentials_google_iam(authorization_token, authority_selector):
   credentials.references.append(authority_selector)
   credentials.references.append(authority_selector)
   return credentials
   return credentials
 
 
+def call_credentials_metadata_plugin(CredentialsMetadataPlugin plugin):
+  cdef CallCredentials credentials = CallCredentials()
+  credentials.c_credentials = (
+      grpc.grpc_metadata_credentials_create_from_plugin(plugin.make_c_plugin(),
+                                                        NULL))
+  # TODO(atash): the following held reference is *probably* never necessary
+  credentials.references.append(plugin)
+  return credentials
+
 def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs,
 def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs,
                            bint force_client_auth):
                            bint force_client_auth):
   cdef char *c_pem_root_certs = NULL
   cdef char *c_pem_root_certs = NULL

+ 24 - 2
src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd

@@ -137,8 +137,6 @@ cdef extern from "grpc/grpc.h":
   const char *GRPC_ARG_MAX_CONCURRENT_STREAMS
   const char *GRPC_ARG_MAX_CONCURRENT_STREAMS
   const char *GRPC_ARG_MAX_MESSAGE_LENGTH
   const char *GRPC_ARG_MAX_MESSAGE_LENGTH
   const char *GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
   const char *GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
-  const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER
-  const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER
   const char *GRPC_ARG_DEFAULT_AUTHORITY
   const char *GRPC_ARG_DEFAULT_AUTHORITY
   const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING
   const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING
   const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
   const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
@@ -396,3 +394,27 @@ cdef extern from "grpc/grpc_security.h":
 
 
   grpc_call_error grpc_call_set_credentials(grpc_call *call,
   grpc_call_error grpc_call_set_credentials(grpc_call *call,
                                             grpc_call_credentials *creds)
                                             grpc_call_credentials *creds)
+
+  ctypedef struct grpc_auth_context:
+    # We don't care about the internals (and in fact don't know them)
+    pass
+
+  ctypedef struct grpc_auth_metadata_context:
+    const char *service_url
+    const char *method_name
+    const grpc_auth_context *channel_auth_context
+
+  ctypedef void (*grpc_credentials_plugin_metadata_cb)(
+      void *user_data, const grpc_metadata *creds_md, size_t num_creds_md,
+      grpc_status_code status, const char *error_details)
+
+  ctypedef struct grpc_metadata_credentials_plugin:
+    void (*get_metadata)(
+        void *state, grpc_auth_metadata_context context,
+        grpc_credentials_plugin_metadata_cb cb, void *user_data)
+    void (*destroy)(void *state)
+    void *state
+    const char *type
+
+  grpc_call_credentials *grpc_metadata_credentials_create_from_plugin(
+      grpc_metadata_credentials_plugin plugin, void *reserved)

+ 0 - 2
src/python/grpcio/grpc/_cython/_cygrpc/records.pyx

@@ -45,8 +45,6 @@ class ChannelArgKey:
   max_concurrent_streams = grpc.GRPC_ARG_MAX_CONCURRENT_STREAMS
   max_concurrent_streams = grpc.GRPC_ARG_MAX_CONCURRENT_STREAMS
   max_message_length = grpc.GRPC_ARG_MAX_MESSAGE_LENGTH
   max_message_length = grpc.GRPC_ARG_MAX_MESSAGE_LENGTH
   http2_initial_sequence_number = grpc.GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
   http2_initial_sequence_number = grpc.GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER
-  http2_hpack_table_size_decoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER
-  http2_hpack_table_size_encoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER
   default_authority = grpc.GRPC_ARG_DEFAULT_AUTHORITY
   default_authority = grpc.GRPC_ARG_DEFAULT_AUTHORITY
   primary_user_agent_string = grpc.GRPC_ARG_PRIMARY_USER_AGENT_STRING
   primary_user_agent_string = grpc.GRPC_ARG_PRIMARY_USER_AGENT_STRING
   secondary_user_agent_string = grpc.GRPC_ARG_SECONDARY_USER_AGENT_STRING
   secondary_user_agent_string = grpc.GRPC_ARG_SECONDARY_USER_AGENT_STRING

+ 3 - 0
src/python/grpcio/grpc/_cython/cygrpc.pyx

@@ -76,6 +76,8 @@ Operations = records.Operations
 CallCredentials = credentials.CallCredentials
 CallCredentials = credentials.CallCredentials
 ChannelCredentials = credentials.ChannelCredentials
 ChannelCredentials = credentials.ChannelCredentials
 ServerCredentials = credentials.ServerCredentials
 ServerCredentials = credentials.ServerCredentials
+CredentialsMetadataPlugin = credentials.CredentialsMetadataPlugin
+AuthMetadataContext = credentials.AuthMetadataContext
 
 
 channel_credentials_google_default = (
 channel_credentials_google_default = (
     credentials.channel_credentials_google_default)
     credentials.channel_credentials_google_default)
@@ -91,6 +93,7 @@ call_credentials_jwt_access = (
 call_credentials_refresh_token = (
 call_credentials_refresh_token = (
     credentials.call_credentials_google_refresh_token)
     credentials.call_credentials_google_refresh_token)
 call_credentials_google_iam = credentials.call_credentials_google_iam
 call_credentials_google_iam = credentials.call_credentials_google_iam
+call_credentials_metadata_plugin = credentials.call_credentials_metadata_plugin
 server_credentials_ssl = credentials.server_credentials_ssl
 server_credentials_ssl = credentials.server_credentials_ssl
 
 
 CompletionQueue = completion_queue.CompletionQueue
 CompletionQueue = completion_queue.CompletionQueue

+ 1 - 1
src/python/grpcio/grpc/_links/invocation.py

@@ -262,7 +262,7 @@ class _Kernel(object):
         self._channel, self._completion_queue, '/%s/%s' % (group, method),
         self._channel, self._completion_queue, '/%s/%s' % (group, method),
         self._host, time.time() + timeout)
         self._host, time.time() + timeout)
     if options is not None and options.credentials is not None:
     if options is not None and options.credentials is not None:
-      call.set_credentials(options.credentials._intermediary_low_credentials)
+      call.set_credentials(options.credentials._low_credentials)
     if transformed_initial_metadata is not None:
     if transformed_initial_metadata is not None:
       for metadata_key, metadata_value in transformed_initial_metadata:
       for metadata_key, metadata_value in transformed_initial_metadata:
         call.add_metadata(metadata_key, metadata_value)
         call.add_metadata(metadata_key, metadata_value)

+ 1 - 1
src/python/grpcio/grpc/beta/_server.py

@@ -170,7 +170,7 @@ class _Server(interfaces.Server):
     with self._lock:
     with self._lock:
       if self._end_link is None:
       if self._end_link is None:
         return self._grpc_link.add_port(
         return self._grpc_link.add_port(
-            address, server_credentials._intermediary_low_credentials)  # pylint: disable=protected-access
+            address, server_credentials._low_credentials)  # pylint: disable=protected-access
       else:
       else:
         raise ValueError('Can\'t add port to serving server!')
         raise ValueError('Can\'t add port to serving server!')
 
 

+ 77 - 19
src/python/grpcio/grpc/beta/implementations.py

@@ -36,6 +36,7 @@ import threading  # pylint: disable=unused-import
 
 
 # cardinality and face are referenced from specification in this module.
 # cardinality and face are referenced from specification in this module.
 from grpc._adapter import _intermediary_low
 from grpc._adapter import _intermediary_low
+from grpc._adapter import _low
 from grpc._adapter import _types
 from grpc._adapter import _types
 from grpc.beta import _connectivity_channel
 from grpc.beta import _connectivity_channel
 from grpc.beta import _server
 from grpc.beta import _server
@@ -48,7 +49,7 @@ _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
     'Exception calling channel subscription callback!')
     'Exception calling channel subscription callback!')
 
 
 
 
-class ClientCredentials(object):
+class ChannelCredentials(object):
   """A value encapsulating the data required to create a secure Channel.
   """A value encapsulating the data required to create a secure Channel.
 
 
   This class and its instances have no supported interface - it exists to define
   This class and its instances have no supported interface - it exists to define
@@ -56,13 +57,12 @@ class ClientCredentials(object):
   functions.
   functions.
   """
   """
 
 
-  def __init__(self, low_credentials, intermediary_low_credentials):
+  def __init__(self, low_credentials):
     self._low_credentials = low_credentials
     self._low_credentials = low_credentials
-    self._intermediary_low_credentials = intermediary_low_credentials
 
 
 
 
-def ssl_client_credentials(root_certificates, private_key, certificate_chain):
-  """Creates a ClientCredentials for use with an SSL-enabled Channel.
+def ssl_channel_credentials(root_certificates, private_key, certificate_chain):
+  """Creates a ChannelCredentials for use with an SSL-enabled Channel.
 
 
   Args:
   Args:
     root_certificates: The PEM-encoded root certificates or None to ask for
     root_certificates: The PEM-encoded root certificates or None to ask for
@@ -73,12 +73,73 @@ def ssl_client_credentials(root_certificates, private_key, certificate_chain):
       certificate chain should be used.
       certificate chain should be used.
 
 
   Returns:
   Returns:
-    A ClientCredentials for use with an SSL-enabled Channel.
+    A ChannelCredentials for use with an SSL-enabled Channel.
   """
   """
-  intermediary_low_credentials = _intermediary_low.ClientCredentials(
-      root_certificates, private_key, certificate_chain)
-  return ClientCredentials(
-      intermediary_low_credentials._internal, intermediary_low_credentials)  # pylint: disable=protected-access
+  return ChannelCredentials(_low.channel_credentials_ssl(
+      root_certificates, private_key, certificate_chain))
+
+
+class CallCredentials(object):
+  """A value encapsulating data asserting an identity over an *established*
+  channel. May be composed with ChannelCredentials to always assert identity for
+  every call over that channel.
+
+  This class and its instances have no supported interface - it exists to define
+  the type of its instances and its instances exist to be passed to other
+  functions.
+  """
+
+  def __init__(self, low_credentials):
+    self._low_credentials = low_credentials
+
+
+def metadata_call_credentials(metadata_plugin, name=None):
+  """Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin.
+
+  Args:
+    metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing
+      the CallCredentials object.
+
+  Returns:
+    A CallCredentials object for use in a GRPCCallOptions object.
+  """
+  if name is None:
+    name = metadata_plugin.__name__
+  return CallCredentials(
+      _low.call_credentials_metadata_plugin(metadata_plugin, name))
+
+def composite_call_credentials(call_credentials, additional_call_credentials):
+  """Compose two CallCredentials to make a new one.
+
+  Args:
+    call_credentials: A CallCredentials object.
+    additional_call_credentials: Another CallCredentials object to compose on
+      top of call_credentials.
+
+  Returns:
+    A CallCredentials object for use in a GRPCCallOptions object.
+  """
+  return CallCredentials(
+      _low.call_credentials_composite(
+          call_credentials._low_credentials,
+          additional_call_credentials._low_credentials))
+
+def composite_channel_credentials(channel_credentials,
+                                 additional_call_credentials):
+  """Compose ChannelCredentials on top of client credentials to make a new one.
+
+  Args:
+    channel_credentials: A ChannelCredentials object.
+    additional_call_credentials: A CallCredentials object to compose on
+      top of channel_credentials.
+
+  Returns:
+    A ChannelCredentials object for use in a GRPCCallOptions object.
+  """
+  return ChannelCredentials(
+      _low.channel_credentials_composite(
+          channel_credentials._low_credentials,
+          additional_call_credentials._low_credentials))
 
 
 
 
 class Channel(object):
 class Channel(object):
@@ -135,19 +196,19 @@ def insecure_channel(host, port):
   return Channel(intermediary_low_channel._internal, intermediary_low_channel)  # pylint: disable=protected-access
   return Channel(intermediary_low_channel._internal, intermediary_low_channel)  # pylint: disable=protected-access
 
 
 
 
-def secure_channel(host, port, client_credentials):
+def secure_channel(host, port, channel_credentials):
   """Creates a secure Channel to a remote host.
   """Creates a secure Channel to a remote host.
 
 
   Args:
   Args:
     host: The name of the remote host to which to connect.
     host: The name of the remote host to which to connect.
     port: The port of the remote host to which to connect.
     port: The port of the remote host to which to connect.
-    client_credentials: A ClientCredentials.
+    channel_credentials: A ChannelCredentials.
 
 
   Returns:
   Returns:
     A secure Channel to the remote host through which RPCs may be conducted.
     A secure Channel to the remote host through which RPCs may be conducted.
   """
   """
   intermediary_low_channel = _intermediary_low.Channel(
   intermediary_low_channel = _intermediary_low.Channel(
-      '%s:%d' % (host, port), client_credentials._intermediary_low_credentials)
+      '%s:%d' % (host, port), channel_credentials._low_credentials)
   return Channel(intermediary_low_channel._internal, intermediary_low_channel)  # pylint: disable=protected-access
   return Channel(intermediary_low_channel._internal, intermediary_low_channel)  # pylint: disable=protected-access
 
 
 
 
@@ -251,9 +312,8 @@ class ServerCredentials(object):
   functions.
   functions.
   """
   """
 
 
-  def __init__(self, low_credentials, intermediary_low_credentials):
+  def __init__(self, low_credentials):
     self._low_credentials = low_credentials
     self._low_credentials = low_credentials
-    self._intermediary_low_credentials = intermediary_low_credentials
 
 
 
 
 def ssl_server_credentials(
 def ssl_server_credentials(
@@ -282,11 +342,9 @@ def ssl_server_credentials(
     raise ValueError(
     raise ValueError(
         'Illegal to require client auth without providing root certificates!')
         'Illegal to require client auth without providing root certificates!')
   else:
   else:
-    intermediary_low_credentials = _intermediary_low.ServerCredentials(
+    return ServerCredentials(_low.server_credentials_ssl(
         root_certificates, private_key_certificate_chain_pairs,
         root_certificates, private_key_certificate_chain_pairs,
-        require_client_auth)
-    return ServerCredentials(
-        intermediary_low_credentials._internal, intermediary_low_credentials)  # pylint: disable=protected-access
+        require_client_auth))
 
 
 
 
 class ServerOptions(object):
 class ServerOptions(object):

+ 45 - 4
src/python/grpcio/grpc/beta/interfaces.py

@@ -100,14 +100,55 @@ def grpc_call_options(disable_compression=False, credentials=None):
     disable_compression: A boolean indicating whether or not compression should
     disable_compression: A boolean indicating whether or not compression should
       be disabled for the request object of the RPC. Only valid for
       be disabled for the request object of the RPC. Only valid for
       request-unary RPCs.
       request-unary RPCs.
-    credentials: Reserved for gRPC per-call credentials. The type for this does
-      not exist yet at the Python level.
+    credentials: A CallCredentials object to use for the invoked RPC.
   """
   """
-  if credentials is not None:
-    raise ValueError('`credentials` is a reserved argument')
   return GRPCCallOptions(disable_compression, None, credentials)
   return GRPCCallOptions(disable_compression, None, credentials)
 
 
 
 
+class GRPCAuthMetadataContext(object):
+  """Provides information to call credentials metadata plugins.
+
+  Attributes:
+    service_url: A string URL of the service being called into.
+    method_name: A string of the fully qualified method name being called.
+  """
+  __metaclass__ = abc.ABCMeta
+
+
+class GRPCAuthMetadataPluginCallback(object):
+  """Callback object received by a metadata plugin."""
+  __metaclass__ = abc.ABCMeta
+
+  def __call__(self, metadata, error):
+    """Inform the gRPC runtime of the metadata to construct a CallCredentials.
+
+    Args:
+      metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value
+        pairs.
+      error: An Exception to indicate error or None to indicate success.
+    """
+    raise NotImplementedError()
+
+
+class GRPCAuthMetadataPlugin(object):
+  """
+  """
+  __metaclass__ = abc.ABCMeta
+
+  def __call__(self, context, callback):
+    """Invoke the plugin.
+
+    Must not block. Need only be called by the gRPC runtime.
+
+    Args:
+      context: A GRPCAuthMetadataContext providing information on what the
+        plugin is being used for.
+      callback: A GRPCAuthMetadataPluginCallback to be invoked either
+        synchronously or asynchronously.
+    """
+    raise NotImplementedError()
+
+
 class GRPCServicerContext(object):
 class GRPCServicerContext(object):
   """Exposes gRPC-specific options and behaviors to code servicing RPCs."""
   """Exposes gRPC-specific options and behaviors to code servicing RPCs."""
   __metaclass__ = abc.ABCMeta
   __metaclass__ = abc.ABCMeta

+ 1 - 1
src/python/grpcio/tests/interop/_secure_interop_test.py

@@ -55,7 +55,7 @@ class SecureInteropTest(
     self.server.start()
     self.server.start()
     self.stub = test_pb2.beta_create_TestService_stub(
     self.stub = test_pb2.beta_create_TestService_stub(
         test_utilities.not_really_secure_channel(
         test_utilities.not_really_secure_channel(
-            '[::]', port, implementations.ssl_client_credentials(
+            '[::]', port, implementations.ssl_channel_credentials(
                 resources.test_root_certificates(), None, None),
                 resources.test_root_certificates(), None, None),
                 _SERVER_HOST_OVERRIDE))
                 _SERVER_HOST_OVERRIDE))
 
 

+ 1 - 1
src/python/grpcio/tests/interop/client.py

@@ -94,7 +94,7 @@ def _stub(args):
 
 
     channel = test_utilities.not_really_secure_channel(
     channel = test_utilities.not_really_secure_channel(
         args.server_host, args.server_port,
         args.server_host, args.server_port,
-        implementations.ssl_client_credentials(root_certificates, None, None),
+        implementations.ssl_channel_credentials(root_certificates, None, None),
         args.server_host_override)
         args.server_host_override)
     stub = test_pb2.beta_create_TestService_stub(
     stub = test_pb2.beta_create_TestService_stub(
         channel, metadata_transformer=metadata_transformer)
         channel, metadata_transformer=metadata_transformer)

+ 188 - 0
src/python/grpcio/tests/unit/_cython/cygrpc_test.py

@@ -28,11 +28,24 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 import time
 import time
+import threading
 import unittest
 import unittest
 
 
 from grpc._cython import cygrpc
 from grpc._cython import cygrpc
 from tests.unit._cython import test_utilities
 from tests.unit._cython import test_utilities
 from tests.unit import test_common
 from tests.unit import test_common
+from tests.unit import resources
+
+
+_SSL_HOST_OVERRIDE = 'foo.test.google.fr'
+_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
+_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
+
+def _metadata_plugin_callback(context, callback):
+  callback(cygrpc.Metadata(
+      [cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
+                        _CALL_CREDENTIALS_METADATA_VALUE)]),
+      cygrpc.StatusCode.ok, '')
 
 
 
 
 class TypeSmokeTest(unittest.TestCase):
 class TypeSmokeTest(unittest.TestCase):
@@ -89,6 +102,17 @@ class TypeSmokeTest(unittest.TestCase):
     channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([]))
     channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([]))
     del channel
     del channel
 
 
+  def testCredentialsMetadataPluginUpDown(self):
+    plugin = cygrpc.CredentialsMetadataPlugin(
+        lambda ignored_a, ignored_b: None, '')
+    del plugin
+
+  def testCallCredentialsFromPluginUpDown(self):
+    plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
+    call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
+    del plugin
+    del call_credentials
+
   def testServerStartNoExplicitShutdown(self):
   def testServerStartNoExplicitShutdown(self):
     server = cygrpc.Server()
     server = cygrpc.Server()
     completion_queue = cygrpc.CompletionQueue()
     completion_queue = cygrpc.CompletionQueue()
@@ -260,5 +284,169 @@ class InsecureServerInsecureClient(unittest.TestCase):
     del server_call
     del server_call
 
 
 
 
+class SecureServerSecureClient(unittest.TestCase):
+
+  def setUp(self):
+    server_credentials = cygrpc.server_credentials_ssl(
+        None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
+                                        resources.certificate_chain())], False)
+    channel_credentials = cygrpc.channel_credentials_ssl(
+        resources.test_root_certificates(), None)
+    self.server_completion_queue = cygrpc.CompletionQueue()
+    self.server = cygrpc.Server()
+    self.server.register_completion_queue(self.server_completion_queue)
+    self.port = self.server.add_http2_port('[::]:0', server_credentials)
+    self.server.start()
+    self.client_completion_queue = cygrpc.CompletionQueue()
+    client_channel_arguments = cygrpc.ChannelArgs([
+        cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
+                          _SSL_HOST_OVERRIDE)])
+    self.client_channel = cygrpc.Channel(
+        'localhost:{}'.format(self.port), client_channel_arguments,
+        channel_credentials)
+
+  def tearDown(self):
+    del self.server
+    del self.client_completion_queue
+    del self.server_completion_queue
+
+  def testEcho(self):
+    DEADLINE = time.time()+5
+    DEADLINE_TOLERANCE = 0.25
+    CLIENT_METADATA_ASCII_KEY = b'key'
+    CLIENT_METADATA_ASCII_VALUE = b'val'
+    CLIENT_METADATA_BIN_KEY = b'key-bin'
+    CLIENT_METADATA_BIN_VALUE = b'\0'*1000
+    SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
+    SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
+    SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
+    SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
+    SERVER_STATUS_CODE = cygrpc.StatusCode.ok
+    SERVER_STATUS_DETAILS = b'our work is never over'
+    REQUEST = b'in death a member of project mayhem has a name'
+    RESPONSE = b'his name is robert paulson'
+    METHOD = b'/twinkies'
+    HOST = None  # Default host
+
+    cygrpc_deadline = cygrpc.Timespec(DEADLINE)
+
+    server_request_tag = object()
+    request_call_result = self.server.request_call(
+        self.server_completion_queue, self.server_completion_queue,
+        server_request_tag)
+
+    self.assertEqual(cygrpc.CallError.ok, request_call_result)
+
+    plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
+    call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
+
+    client_call_tag = object()
+    client_call = self.client_channel.create_call(
+        None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
+    client_call.set_credentials(call_credentials)
+    client_initial_metadata = cygrpc.Metadata([
+        cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
+                         CLIENT_METADATA_ASCII_VALUE),
+        cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
+    client_start_batch_result = client_call.start_batch(cygrpc.Operations([
+        cygrpc.operation_send_initial_metadata(client_initial_metadata),
+        cygrpc.operation_send_message(REQUEST),
+        cygrpc.operation_send_close_from_client(),
+        cygrpc.operation_receive_initial_metadata(),
+        cygrpc.operation_receive_message(),
+        cygrpc.operation_receive_status_on_client()
+    ]), client_call_tag)
+    self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
+    client_event_future = test_utilities.CompletionQueuePollFuture(
+        self.client_completion_queue, cygrpc_deadline)
+
+    request_event = self.server_completion_queue.poll(cygrpc_deadline)
+    self.assertEqual(cygrpc.CompletionType.operation_complete,
+                      request_event.type)
+    self.assertIsInstance(request_event.operation_call, cygrpc.Call)
+    self.assertIs(server_request_tag, request_event.tag)
+    self.assertEqual(0, len(request_event.batch_operations))
+    client_metadata_with_credentials = list(client_initial_metadata) + [
+        (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
+    self.assertTrue(
+        test_common.metadata_transmitted(client_metadata_with_credentials,
+                                         request_event.request_metadata))
+    self.assertEqual(METHOD, request_event.request_call_details.method)
+    self.assertEqual(_SSL_HOST_OVERRIDE,
+                     request_event.request_call_details.host)
+    self.assertLess(
+        abs(DEADLINE - float(request_event.request_call_details.deadline)),
+        DEADLINE_TOLERANCE)
+
+    server_call_tag = object()
+    server_call = request_event.operation_call
+    server_initial_metadata = cygrpc.Metadata([
+        cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
+                         SERVER_INITIAL_METADATA_VALUE)])
+    server_trailing_metadata = cygrpc.Metadata([
+        cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
+                         SERVER_TRAILING_METADATA_VALUE)])
+    server_start_batch_result = server_call.start_batch([
+        cygrpc.operation_send_initial_metadata(server_initial_metadata),
+        cygrpc.operation_receive_message(),
+        cygrpc.operation_send_message(RESPONSE),
+        cygrpc.operation_receive_close_on_server(),
+        cygrpc.operation_send_status_from_server(
+            server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
+    ], server_call_tag)
+    self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
+
+    client_event = client_event_future.result()
+    server_event = self.server_completion_queue.poll(cygrpc_deadline)
+
+    self.assertEqual(6, len(client_event.batch_operations))
+    found_client_op_types = set()
+    for client_result in client_event.batch_operations:
+      # we expect each op type to be unique
+      self.assertNotIn(client_result.type, found_client_op_types)
+      found_client_op_types.add(client_result.type)
+      if client_result.type == cygrpc.OperationType.receive_initial_metadata:
+        self.assertTrue(
+            test_common.metadata_transmitted(server_initial_metadata,
+                                             client_result.received_metadata))
+      elif client_result.type == cygrpc.OperationType.receive_message:
+        self.assertEqual(RESPONSE, client_result.received_message.bytes())
+      elif client_result.type == cygrpc.OperationType.receive_status_on_client:
+        self.assertTrue(
+            test_common.metadata_transmitted(server_trailing_metadata,
+                                             client_result.received_metadata))
+        self.assertEqual(SERVER_STATUS_DETAILS,
+                         client_result.received_status_details)
+        self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
+    self.assertEqual(set([
+          cygrpc.OperationType.send_initial_metadata,
+          cygrpc.OperationType.send_message,
+          cygrpc.OperationType.send_close_from_client,
+          cygrpc.OperationType.receive_initial_metadata,
+          cygrpc.OperationType.receive_message,
+          cygrpc.OperationType.receive_status_on_client
+      ]), found_client_op_types)
+
+    self.assertEqual(5, len(server_event.batch_operations))
+    found_server_op_types = set()
+    for server_result in server_event.batch_operations:
+      self.assertNotIn(client_result.type, found_server_op_types)
+      found_server_op_types.add(server_result.type)
+      if server_result.type == cygrpc.OperationType.receive_message:
+        self.assertEqual(REQUEST, server_result.received_message.bytes())
+      elif server_result.type == cygrpc.OperationType.receive_close_on_server:
+        self.assertFalse(server_result.received_cancelled)
+    self.assertEqual(set([
+          cygrpc.OperationType.send_initial_metadata,
+          cygrpc.OperationType.receive_message,
+          cygrpc.OperationType.send_message,
+          cygrpc.OperationType.receive_close_on_server,
+          cygrpc.OperationType.send_status_from_server
+      ]), found_server_op_types)
+
+    del client_call
+    del server_call
+
+
 if __name__ == '__main__':
 if __name__ == '__main__':
   unittest.main(verbosity=2)
   unittest.main(verbosity=2)

+ 50 - 8
src/python/grpcio/tests/unit/beta/_beta_features_test.py

@@ -42,6 +42,9 @@ from tests.unit.framework.common import test_constants
 
 
 _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
 _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
 
 
+_PER_RPC_CREDENTIALS_METADATA_KEY = 'my-call-credentials-metadata-key'
+_PER_RPC_CREDENTIALS_METADATA_VALUE = 'my-call-credentials-metadata-value'
+
 _GROUP = 'group'
 _GROUP = 'group'
 _UNARY_UNARY = 'unary-unary'
 _UNARY_UNARY = 'unary-unary'
 _UNARY_STREAM = 'unary-stream'
 _UNARY_STREAM = 'unary-stream'
@@ -63,6 +66,7 @@ class _Servicer(object):
     with self._condition:
     with self._condition:
       self._request = request
       self._request = request
       self._peer = context.protocol_context().peer()
       self._peer = context.protocol_context().peer()
+      self._invocation_metadata = context.invocation_metadata()
       context.protocol_context().disable_next_response_compression()
       context.protocol_context().disable_next_response_compression()
       self._serviced = True
       self._serviced = True
       self._condition.notify_all()
       self._condition.notify_all()
@@ -72,6 +76,7 @@ class _Servicer(object):
     with self._condition:
     with self._condition:
       self._request = request
       self._request = request
       self._peer = context.protocol_context().peer()
       self._peer = context.protocol_context().peer()
+      self._invocation_metadata = context.invocation_metadata()
       context.protocol_context().disable_next_response_compression()
       context.protocol_context().disable_next_response_compression()
       self._serviced = True
       self._serviced = True
       self._condition.notify_all()
       self._condition.notify_all()
@@ -83,6 +88,7 @@ class _Servicer(object):
       self._request = request
       self._request = request
     with self._condition:
     with self._condition:
       self._peer = context.protocol_context().peer()
       self._peer = context.protocol_context().peer()
+      self._invocation_metadata = context.invocation_metadata()
       context.protocol_context().disable_next_response_compression()
       context.protocol_context().disable_next_response_compression()
       self._serviced = True
       self._serviced = True
       self._condition.notify_all()
       self._condition.notify_all()
@@ -95,6 +101,7 @@ class _Servicer(object):
         context.protocol_context().disable_next_response_compression()
         context.protocol_context().disable_next_response_compression()
         yield _RESPONSE
         yield _RESPONSE
     with self._condition:
     with self._condition:
+      self._invocation_metadata = context.invocation_metadata()
       self._serviced = True
       self._serviced = True
       self._condition.notify_all()
       self._condition.notify_all()
 
 
@@ -137,6 +144,11 @@ class _BlockingIterator(object):
       self._condition.notify_all()
       self._condition.notify_all()
 
 
 
 
+def _metadata_plugin(context, callback):
+  callback([(_PER_RPC_CREDENTIALS_METADATA_KEY,
+             _PER_RPC_CREDENTIALS_METADATA_VALUE)], None)
+
+
 class BetaFeaturesTest(unittest.TestCase):
 class BetaFeaturesTest(unittest.TestCase):
 
 
   def setUp(self):
   def setUp(self):
@@ -167,10 +179,12 @@ class BetaFeaturesTest(unittest.TestCase):
         [(resources.private_key(), resources.certificate_chain(),),])
         [(resources.private_key(), resources.certificate_chain(),),])
     port = self._server.add_secure_port('[::]:0', server_credentials)
     port = self._server.add_secure_port('[::]:0', server_credentials)
     self._server.start()
     self._server.start()
-    self._client_credentials = implementations.ssl_client_credentials(
+    self._channel_credentials = implementations.ssl_channel_credentials(
         resources.test_root_certificates(), None, None)
         resources.test_root_certificates(), None, None)
+    self._call_credentials = implementations.metadata_call_credentials(
+        _metadata_plugin)
     channel = test_utilities.not_really_secure_channel(
     channel = test_utilities.not_really_secure_channel(
-        'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
+        'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
     stub_options = implementations.stub_options(
     stub_options = implementations.stub_options(
         thread_pool_size=test_constants.POOL_SIZE)
         thread_pool_size=test_constants.POOL_SIZE)
     self._dynamic_stub = implementations.dynamic_stub(
     self._dynamic_stub = implementations.dynamic_stub(
@@ -181,21 +195,36 @@ class BetaFeaturesTest(unittest.TestCase):
     self._server.stop(test_constants.SHORT_TIMEOUT).wait()
     self._server.stop(test_constants.SHORT_TIMEOUT).wait()
 
 
   def test_unary_unary(self):
   def test_unary_unary(self):
-    call_options = interfaces.grpc_call_options(disable_compression=True)
+    call_options = interfaces.grpc_call_options(
+        disable_compression=True, credentials=self._call_credentials)
     response = getattr(self._dynamic_stub, _UNARY_UNARY)(
     response = getattr(self._dynamic_stub, _UNARY_UNARY)(
         _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
         _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
     self.assertEqual(_RESPONSE, response)
     self.assertEqual(_RESPONSE, response)
     self.assertIsNotNone(self._servicer.peer())
     self.assertIsNotNone(self._servicer.peer())
+    invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
+                           self._servicer._invocation_metadata]
+    self.assertIn(
+        (_PER_RPC_CREDENTIALS_METADATA_KEY,
+         _PER_RPC_CREDENTIALS_METADATA_VALUE),
+        invocation_metadata)
 
 
   def test_unary_stream(self):
   def test_unary_stream(self):
-    call_options = interfaces.grpc_call_options(disable_compression=True)
+    call_options = interfaces.grpc_call_options(
+        disable_compression=True, credentials=self._call_credentials)
     response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
     response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
         _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
         _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
     self._servicer.block_until_serviced()
     self._servicer.block_until_serviced()
     self.assertIsNotNone(self._servicer.peer())
     self.assertIsNotNone(self._servicer.peer())
+    invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
+                           self._servicer._invocation_metadata]
+    self.assertIn(
+        (_PER_RPC_CREDENTIALS_METADATA_KEY,
+         _PER_RPC_CREDENTIALS_METADATA_VALUE),
+        invocation_metadata)
 
 
   def test_stream_unary(self):
   def test_stream_unary(self):
-    call_options = interfaces.grpc_call_options()
+    call_options = interfaces.grpc_call_options(
+        credentials=self._call_credentials)
     request_iterator = _BlockingIterator(iter((_REQUEST,)))
     request_iterator = _BlockingIterator(iter((_REQUEST,)))
     response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
     response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
         request_iterator, test_constants.LONG_TIMEOUT,
         request_iterator, test_constants.LONG_TIMEOUT,
@@ -207,9 +236,16 @@ class BetaFeaturesTest(unittest.TestCase):
     self._servicer.block_until_serviced()
     self._servicer.block_until_serviced()
     self.assertIsNotNone(self._servicer.peer())
     self.assertIsNotNone(self._servicer.peer())
     self.assertEqual(_RESPONSE, response_future.result())
     self.assertEqual(_RESPONSE, response_future.result())
+    invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
+                           self._servicer._invocation_metadata]
+    self.assertIn(
+        (_PER_RPC_CREDENTIALS_METADATA_KEY,
+         _PER_RPC_CREDENTIALS_METADATA_VALUE),
+        invocation_metadata)
 
 
   def test_stream_stream(self):
   def test_stream_stream(self):
-    call_options = interfaces.grpc_call_options()
+    call_options = interfaces.grpc_call_options(
+        credentials=self._call_credentials)
     request_iterator = _BlockingIterator(iter((_REQUEST,)))
     request_iterator = _BlockingIterator(iter((_REQUEST,)))
     response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
     response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
         request_iterator, test_constants.SHORT_TIMEOUT,
         request_iterator, test_constants.SHORT_TIMEOUT,
@@ -222,6 +258,12 @@ class BetaFeaturesTest(unittest.TestCase):
     self._servicer.block_until_serviced()
     self._servicer.block_until_serviced()
     self.assertIsNotNone(self._servicer.peer())
     self.assertIsNotNone(self._servicer.peer())
     self.assertEqual(_RESPONSE, response)
     self.assertEqual(_RESPONSE, response)
+    invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
+                           self._servicer._invocation_metadata]
+    self.assertIn(
+        (_PER_RPC_CREDENTIALS_METADATA_KEY,
+         _PER_RPC_CREDENTIALS_METADATA_VALUE),
+        invocation_metadata)
 
 
 
 
 class ContextManagementAndLifecycleTest(unittest.TestCase):
 class ContextManagementAndLifecycleTest(unittest.TestCase):
@@ -250,7 +292,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase):
         thread_pool_size=test_constants.POOL_SIZE)
         thread_pool_size=test_constants.POOL_SIZE)
     self._server_credentials = implementations.ssl_server_credentials(
     self._server_credentials = implementations.ssl_server_credentials(
         [(resources.private_key(), resources.certificate_chain(),),])
         [(resources.private_key(), resources.certificate_chain(),),])
-    self._client_credentials = implementations.ssl_client_credentials(
+    self._channel_credentials = implementations.ssl_channel_credentials(
         resources.test_root_certificates(), None, None)
         resources.test_root_certificates(), None, None)
     self._stub_options = implementations.stub_options(
     self._stub_options = implementations.stub_options(
         thread_pool_size=test_constants.POOL_SIZE)
         thread_pool_size=test_constants.POOL_SIZE)
@@ -262,7 +304,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase):
     server.start()
     server.start()
 
 
     channel = test_utilities.not_really_secure_channel(
     channel = test_utilities.not_really_secure_channel(
-        'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
+        'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
     dynamic_stub = implementations.dynamic_stub(
     dynamic_stub = implementations.dynamic_stub(
         channel, _GROUP, self._cardinalities, options=self._stub_options)
         channel, _GROUP, self._cardinalities, options=self._stub_options)
     for _ in range(100):
     for _ in range(100):

+ 2 - 2
src/python/grpcio/tests/unit/beta/_face_interface_test.py

@@ -91,10 +91,10 @@ class _Implementation(test_interfaces.Implementation):
         [(resources.private_key(), resources.certificate_chain(),),])
         [(resources.private_key(), resources.certificate_chain(),),])
     port = server.add_secure_port('[::]:0', server_credentials)
     port = server.add_secure_port('[::]:0', server_credentials)
     server.start()
     server.start()
-    client_credentials = implementations.ssl_client_credentials(
+    channel_credentials = implementations.ssl_channel_credentials(
         resources.test_root_certificates(), None, None)
         resources.test_root_certificates(), None, None)
     channel = test_utilities.not_really_secure_channel(
     channel = test_utilities.not_really_secure_channel(
-        'localhost', port, client_credentials, _SERVER_HOST_OVERRIDE)
+        'localhost', port, channel_credentials, _SERVER_HOST_OVERRIDE)
     stub_options = implementations.stub_options(
     stub_options = implementations.stub_options(
         request_serializers=serialization_behaviors.request_serializers,
         request_serializers=serialization_behaviors.request_serializers,
         response_deserializers=serialization_behaviors.response_deserializers,
         response_deserializers=serialization_behaviors.response_deserializers,

+ 3 - 3
src/python/grpcio/tests/unit/beta/test_utilities.py

@@ -34,13 +34,13 @@ from grpc.beta import implementations
 
 
 
 
 def not_really_secure_channel(
 def not_really_secure_channel(
-    host, port, client_credentials, server_host_override):
+    host, port, channel_credentials, server_host_override):
   """Creates an insecure Channel to a remote host.
   """Creates an insecure Channel to a remote host.
 
 
   Args:
   Args:
     host: The name of the remote host to which to connect.
     host: The name of the remote host to which to connect.
     port: The port of the remote host to which to connect.
     port: The port of the remote host to which to connect.
-    client_credentials: The implementations.ClientCredentials with which to
+    channel_credentials: The implementations.ChannelCredentials with which to
       connect.
       connect.
     server_host_override: The target name used for SSL host name checking.
     server_host_override: The target name used for SSL host name checking.
 
 
@@ -50,7 +50,7 @@ def not_really_secure_channel(
   """
   """
   hostport = '%s:%d' % (host, port)
   hostport = '%s:%d' % (host, port)
   intermediary_low_channel = _intermediary_low.Channel(
   intermediary_low_channel = _intermediary_low.Channel(
-      hostport, client_credentials._intermediary_low_credentials,
+      hostport, channel_credentials._low_credentials,
       server_host_override=server_host_override)
       server_host_override=server_host_override)
   return implementations.Channel(
   return implementations.Channel(
       intermediary_low_channel._internal, intermediary_low_channel)
       intermediary_low_channel._internal, intermediary_low_channel)