#region Copyright notice and license
// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion
using System;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Internal;
namespace Grpc.Core.Interceptors
{
    /// <summary>
    /// Provides a base class for generic interceptor implementations that raises
    /// events and hooks to control the RPC lifecycle.
    /// </summary>
    internal abstract class GenericInterceptor : Interceptor
    {
        /// <summary>
        /// Provides hooks through which an invocation should be intercepted.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for the intercepted invocation.</typeparam>
        /// <typeparam name="TResponse">Response message type for the intercepted invocation.</typeparam>
        public sealed class ClientCallHooks<TRequest, TResponse>
            where TRequest : class
            where TResponse : class
        {
            /// <summary>
            /// Returns a shallow copy of this instance. The interceptor works on the copy, so mutating
            /// the original hooks object afterwards does not change an invocation that has already started.
            /// </summary>
            internal ClientCallHooks<TRequest, TResponse> Freeze()
            {
                return (ClientCallHooks<TRequest, TResponse>)MemberwiseClone();
            }

            /// <summary>
            /// Override the context for the outgoing invocation. Null keeps the original context.
            /// </summary>
            public ClientInterceptorContext<TRequest, TResponse>? ContextOverride { get; set; }

            /// <summary>
            /// Override the request for the outgoing invocation for non-client-streaming invocations.
            /// Null keeps the original request.
            /// </summary>
            public TRequest UnaryRequestOverride { get; set; }

            /// <summary>
            /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it.
            /// The value it returns is the response the caller observes.
            /// </summary>
            public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

            /// <summary>
            /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
            /// Returning null leaves the message unchanged.
            /// </summary>
            public Func<TRequest, TRequest> OnRequestMessage { get; set; }

            /// <summary>
            /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
            /// Returning null leaves the message unchanged.
            /// </summary>
            public Func<TResponse, TResponse> OnResponseMessage { get; set; }

            /// <summary>
            /// Callback that gets invoked when the response stream reports that it has no more messages.
            /// </summary>
            public Action OnResponseStreamEnd { get; set; }

            /// <summary>
            /// Callback that gets invoked after the request stream has been completed successfully.
            /// </summary>
            public Action OnRequestStreamEnd { get; set; }
        }

        /// <summary>
        /// Intercepts an outgoing call from the client side.
        /// Derived classes that intend to intercept outgoing invocations from the client side should
        /// override this and return the appropriate hooks in the form of a ClientCallHooks instance.
        /// </summary>
        /// <param name="context">The context of the outgoing invocation.</param>
        /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
        /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
        /// <param name="request">
        /// The request message for non-client-streaming invocations (unary and server-streaming);
        /// null for client-streaming and duplex-streaming invocations.
        /// </param>
        /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
        /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
        /// <returns>
        /// The hooks that control the trajectory of the invocation, or null to let the invocation
        /// proceed without any modification. The base implementation returns null.
        /// </returns>
        protected virtual ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
            where TRequest : class
            where TResponse : class
        {
            return null;
        }

        /// <summary>
        /// Provides hooks through which a server-side handler should be intercepted.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for the intercepted handler.</typeparam>
        /// <typeparam name="TResponse">Response message type for the intercepted handler.</typeparam>
        public sealed class ServerCallHooks<TRequest, TResponse>
            where TRequest : class
            where TResponse : class
        {
            /// <summary>
            /// Returns a shallow copy of this instance. The interceptor works on the copy, so mutating
            /// the original hooks object afterwards does not change a handler invocation that has already started.
            /// </summary>
            internal ServerCallHooks<TRequest, TResponse> Freeze()
            {
                return (ServerCallHooks<TRequest, TResponse>)MemberwiseClone();
            }

            /// <summary>
            /// Override the request passed to the handler for non-client-streaming (incoming) invocations.
            /// Null keeps the original request.
            /// </summary>
            public TRequest UnaryRequestOverride { get; set; }

            /// <summary>
            /// Delegate that intercepts the response of a non-server-streaming handler and optionally overrides it.
            /// The value it returns is the response sent back by the server.
            /// </summary>
            public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

            /// <summary>
            /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
            /// Returning null leaves the message unchanged.
            /// </summary>
            public Func<TRequest, TRequest> OnRequestMessage { get; set; }

            /// <summary>
            /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
            /// Returning null leaves the message unchanged.
            /// </summary>
            public Func<TResponse, TResponse> OnResponseMessage { get; set; }

            /// <summary>
            /// Callback that gets invoked when the handler (and any response hook) has finished executing without throwing.
            /// </summary>
            public Action OnHandlerEnd { get; set; }

            /// <summary>
            /// Callback that gets invoked when the request stream reports that it has no more messages.
            /// </summary>
            public Action OnRequestStreamEnd { get; set; }
        }

        /// <summary>
        /// Intercepts an incoming service handler invocation on the server side.
        /// Derived classes that intend to intercept incoming handlers on the server side should
        /// override this and return the appropriate hooks in the form of a ServerCallHooks instance.
        /// </summary>
        /// <param name="context">The context of the incoming invocation.</param>
        /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
        /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
        /// <param name="request">
        /// The request message for non-client-streaming invocations (unary and server-streaming);
        /// null for client-streaming and duplex-streaming invocations.
        /// </param>
        /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
        /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
        /// <returns>
        /// A task producing the hooks that control the trajectory of the handler, or producing null
        /// to let the handler run without any modification. The base implementation produces null.
        /// </returns>
        protected virtual Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
            where TRequest : class
            where TResponse : class
        {
            return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
        }

        /// <summary>
        /// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly.
        /// </summary>
        public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, false, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a simple remote call and dispatches the events accordingly.
        /// </summary>
        public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, false, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = new AsyncUnaryCall<TResponse>(MapResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse),
                    response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a streaming remote call and dispatches the events accordingly.
        /// </summary>
        public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, true, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
            {
                response = new AsyncServerStreamingCall<TResponse>(
                    new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd),
                    response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a client streaming call and dispatches the events accordingly.
        /// </summary>
        public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, true, false, null)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            var response = continuation(context);
            // Only the request-stream hooks and the unary-response hook apply to a client-streaming call;
            // the call is rebuilt only when at least one of them is set.
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnUnaryResponse != null)
            {
                var requestStream = response.RequestStream;
                if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                {
                    requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                }
                var responseAsync = response.ResponseAsync;
                if (hooks.OnUnaryResponse != null)
                {
                    responseAsync = MapResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse);
                }
                response = new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a duplex streaming call and dispatches the events accordingly.
        /// </summary>
        public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, true, true, null)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            var response = continuation(context);
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
            {
                var requestStream = response.RequestStream;
                if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                {
                    requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                }
                var responseStream = response.ResponseStream;
                if (hooks.OnResponseMessage != null || hooks.OnResponseStreamEnd != null)
                {
                    responseStream = new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd);
                }
                response = new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting unary calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, false, request))?.Freeze();
            request = hooks?.UnaryRequestOverride ?? request;
            var response = await continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            // Hooks may be returned without an OnHandlerEnd callback; guard the delegate as well.
            hooks?.OnHandlerEnd?.Invoke();
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting client streaming call.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, false, null))?.Freeze();
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
            {
                requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
            }
            var response = await continuation(requestStream, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            hooks?.OnHandlerEnd?.Invoke();
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting server streaming calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, true, request))?.Freeze();
            request = hooks?.UnaryRequestOverride ?? request;
            if (hooks?.OnResponseMessage != null)
            {
                responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
            }
            await continuation(request, responseStream, context);
            hooks?.OnHandlerEnd?.Invoke();
        }

        /// <summary>
        /// Server-side handler for intercepting bidi streaming calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, true, null))?.Freeze();
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
            {
                requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
            }
            if (hooks?.OnResponseMessage != null)
            {
                responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
            }
            await continuation(requestStream, responseStream, context);
            hooks?.OnHandlerEnd?.Invoke();
        }

        /// <summary>
        /// Awaits <paramref name="responseTask"/> and passes its result through <paramref name="map"/>.
        /// Awaiting (instead of reading <c>Task.Result</c> inside <c>ContinueWith</c>) rethrows the original
        /// exception (e.g. an <c>RpcException</c>) rather than an <c>AggregateException</c>, and keeps a
        /// cancelled call cancelled rather than faulted.
        /// </summary>
        private static async Task<T> MapResponseAsync<T>(Task<T> responseTask, Func<T, T> map)
        {
            var response = await responseTask.ConfigureAwait(false);
            return map(response);
        }

        /// <summary>
        /// Stream reader that applies an optional per-message mapping and raises a callback once the
        /// underlying reader reports the end of the stream.
        /// </summary>
        private class WrappedAsyncStreamReader<T> : IAsyncStreamReader<T>
        {
            readonly IAsyncStreamReader<T> reader;
            readonly Func<T, T> onMessage;
            readonly Action onStreamEnd;
            private T current;

            public WrappedAsyncStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd)
            {
                this.reader = reader;
                this.onMessage = onMessage;
                this.onStreamEnd = onStreamEnd;
            }

            /// <summary>
            /// Disposes the underlying reader if it is disposable; otherwise does nothing.
            /// </summary>
            public void Dispose() => (reader as IDisposable)?.Dispose();

            public T Current
            {
                get
                {
                    if (current == null)
                    {
                        throw new InvalidOperationException("No current element is available.");
                    }
                    return current;
                }
            }

            public async Task<bool> MoveNext(CancellationToken token)
            {
                if (await reader.MoveNext(token))
                {
                    var message = reader.Current;
                    if (onMessage != null)
                    {
                        // A null result from the hook means "no override".
                        var mappedValue = onMessage(message);
                        if (mappedValue != null)
                        {
                            message = mappedValue;
                        }
                    }
                    current = message;
                    return true;
                }
                onStreamEnd?.Invoke();
                return false;
            }
        }

        /// <summary>
        /// Client request-stream writer that applies an optional per-message mapping and raises a
        /// callback after the request stream has been completed successfully.
        /// </summary>
        private class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
        {
            readonly IClientStreamWriter<T> writer;
            readonly Func<T, T> onMessage;
            readonly Action onRequestStreamEnd;

            public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onRequestStreamEnd)
            {
                this.writer = writer;
                this.onMessage = onMessage;
                this.onRequestStreamEnd = onRequestStreamEnd;
            }

            public Task CompleteAsync()
            {
                if (onRequestStreamEnd == null)
                {
                    return writer.CompleteAsync();
                }
                return CompleteAndNotifyAsync();
            }

            /// <summary>
            /// Completes the underlying writer and only then raises the stream-end callback. Unlike a
            /// <c>ContinueWith</c> continuation, a failure of <c>CompleteAsync</c> propagates to the caller.
            /// </summary>
            private async Task CompleteAndNotifyAsync()
            {
                await writer.CompleteAsync().ConfigureAwait(false);
                onRequestStreamEnd();
            }

            public Task WriteAsync(T message)
            {
                if (onMessage != null)
                {
                    // A null result from the hook means "no override", matching WrappedAsyncStreamReader.
                    var mappedValue = onMessage(message);
                    if (mappedValue != null)
                    {
                        message = mappedValue;
                    }
                }
                return writer.WriteAsync(message);
            }

            public WriteOptions WriteOptions
            {
                get
                {
                    return writer.WriteOptions;
                }
                set
                {
                    writer.WriteOptions = value;
                }
            }
        }

        /// <summary>
        /// Server response-stream writer that applies an optional per-message mapping.
        /// </summary>
        private class WrappedAsyncStreamWriter<T> : IServerStreamWriter<T>
        {
            readonly IAsyncStreamWriter<T> writer;
            readonly Func<T, T> onMessage;

            public WrappedAsyncStreamWriter(IAsyncStreamWriter<T> writer, Func<T, T> onMessage)
            {
                this.writer = writer;
                this.onMessage = onMessage;
            }

            public Task WriteAsync(T message)
            {
                if (onMessage != null)
                {
                    // A null result from the hook means "no override", matching WrappedAsyncStreamReader.
                    var mappedValue = onMessage(message);
                    if (mappedValue != null)
                    {
                        message = mappedValue;
                    }
                }
                return writer.WriteAsync(message);
            }

            public WriteOptions WriteOptions
            {
                get
                {
                    return writer.WriteOptions;
                }
                set
                {
                    writer.WriteOptions = value;
                }
            }
        }
    }
}