#region Copyright notice and license

// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Internal;

namespace Grpc.Core.Interceptors
{
    /// <summary>
    /// Provides a base class for generic interceptor implementations that raises
    /// events and hooks to control the RPC lifecycle.
    /// </summary>
    internal abstract class GenericInterceptor : Interceptor
    {
        /// <summary>
        /// Provides hooks through which an invocation should be intercepted.
        /// </summary>
        public sealed class ClientCallHooks<TRequest, TResponse>
            where TRequest : class
            where TResponse : class
        {
            /// <summary>
            /// Returns a shallow copy of these hooks. The interceptor works from the copy, so
            /// property assignments made to the original instance afterwards do not affect the call.
            /// </summary>
            internal ClientCallHooks<TRequest, TResponse> Freeze()
            {
                return (ClientCallHooks<TRequest, TResponse>)MemberwiseClone();
            }

            /// <summary>
            /// Override the context for the outgoing invocation.
            /// </summary>
            public ClientInterceptorContext<TRequest, TResponse>? ContextOverride { get; set; }

            /// <summary>
            /// Override the request for the outgoing invocation for non-client-streaming invocations.
            /// </summary>
            public TRequest UnaryRequestOverride { get; set; }

            /// <summary>
            /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it.
            /// </summary>
            public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

            /// <summary>
            /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
            /// Returning null leaves the original message unchanged.
            /// </summary>
            public Func<TRequest, TRequest> OnRequestMessage { get; set; }

            /// <summary>
            /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
            /// Returning null leaves the original message unchanged.
            /// </summary>
            public Func<TResponse, TResponse> OnResponseMessage { get; set; }

            /// <summary>
            /// Callback that gets invoked when response stream is finished.
            /// </summary>
            public Action OnResponseStreamEnd { get; set; }

            /// <summary>
            /// Callback that gets invoked when request stream is finished
            /// (after the underlying <c>CompleteAsync</c> succeeds).
            /// </summary>
            public Action OnRequestStreamEnd { get; set; }
        }

        /// <summary>
        /// Intercepts an outgoing call from the client side.
        /// Derived classes that intend to intercept outgoing invocations from the client side should
        /// override this and return the appropriate hooks in the form of a ClientCallHooks instance.
        /// </summary>
        /// <param name="context">The context of the outgoing invocation.</param>
        /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
        /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
        /// <param name="request">The request message for client-unary invocations, null otherwise.</param>
        /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
        /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
        /// <returns>
        /// An instance of ClientCallHooks that controls the trajectory of the invocation,
        /// or null to let the invocation proceed without any interception.
        /// </returns>
        protected virtual ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
            where TRequest : class
            where TResponse : class
        {
            return null;
        }

        /// <summary>
        /// Provides hooks through which a server-side handler should be intercepted.
        /// </summary>
        public sealed class ServerCallHooks<TRequest, TResponse>
            where TRequest : class
            where TResponse : class
        {
            /// <summary>
            /// Returns a shallow copy of these hooks. The interceptor works from the copy, so
            /// property assignments made to the original instance afterwards do not affect the call.
            /// </summary>
            internal ServerCallHooks<TRequest, TResponse> Freeze()
            {
                return (ServerCallHooks<TRequest, TResponse>)MemberwiseClone();
            }

            /// <summary>
            /// Override the request passed to the handler for non-client-streaming invocations.
            /// </summary>
            public TRequest UnaryRequestOverride { get; set; }

            /// <summary>
            /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it.
            /// </summary>
            public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

            /// <summary>
            /// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
            /// Returning null leaves the original message unchanged.
            /// </summary>
            public Func<TRequest, TRequest> OnRequestMessage { get; set; }

            /// <summary>
            /// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
            /// Returning null leaves the original message unchanged.
            /// </summary>
            public Func<TResponse, TResponse> OnResponseMessage { get; set; }

            /// <summary>
            /// Callback that gets invoked when handler is finished executing.
            /// It is only invoked when the handler completes without throwing.
            /// </summary>
            public Action OnHandlerEnd { get; set; }

            /// <summary>
            /// Callback that gets invoked when request stream is finished.
            /// </summary>
            public Action OnRequestStreamEnd { get; set; }
        }

        /// <summary>
        /// Intercepts an incoming service handler invocation on the server side.
        /// Derived classes that intend to intercept incoming handlers on the server side should
        /// override this and return the appropriate hooks in the form of a ServerCallHooks instance.
        /// </summary>
        /// <param name="context">The context of the incoming invocation.</param>
        /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
        /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
        /// <param name="request">The request message for client-unary invocations, null otherwise.</param>
        /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
        /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
        /// <returns>
        /// A task yielding an instance of ServerCallHooks that controls the trajectory of the handler,
        /// or null to let the handler run without any interception.
        /// </returns>
        protected virtual Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
            where TRequest : class
            where TResponse : class
        {
            return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
        }

        /// <summary>
        /// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly.
        /// </summary>
        public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, false, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a simple remote call and dispatches the events accordingly.
        /// </summary>
        public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, false, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = new AsyncUnaryCall<TResponse>(
                    MapResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse),
                    response.ResponseHeadersAsync,
                    response.GetStatus,
                    response.GetTrailers,
                    response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a streaming remote call and dispatches the events accordingly.
        /// </summary>
        public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, false, true, request)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            request = hooks?.UnaryRequestOverride ?? request;
            var response = continuation(request, context);
            if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
            {
                response = new AsyncServerStreamingCall<TResponse>(
                    new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd),
                    response.ResponseHeadersAsync,
                    response.GetStatus,
                    response.GetTrailers,
                    response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a client streaming call and dispatches the events accordingly.
        /// </summary>
        public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, true, false, null)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            var response = continuation(context);
            // A client-streaming call has a request stream and a unary response, so the hooks that
            // matter here are the request-stream ones plus OnUnaryResponse.
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnUnaryResponse != null)
            {
                var requestStream = response.RequestStream;
                if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                {
                    requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                }
                var responseAsync = response.ResponseAsync;
                if (hooks.OnUnaryResponse != null)
                {
                    responseAsync = MapResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse);
                }
                response = new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Intercepts an asynchronous invocation of a duplex streaming call and dispatches the events accordingly.
        /// </summary>
        public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
        {
            var hooks = InterceptCall(context, true, true, null)?.Freeze();
            context = hooks?.ContextOverride ?? context;
            var response = continuation(context);
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
            {
                var requestStream = response.RequestStream;
                if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                {
                    requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                }
                var responseStream = response.ResponseStream;
                if (hooks.OnResponseMessage != null || hooks.OnResponseStreamEnd != null)
                {
                    responseStream = new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd);
                }
                response = new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
            }
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting unary calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, false, request))?.Freeze();
            request = hooks?.UnaryRequestOverride ?? request;
            var response = await continuation(request, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            // OnHandlerEnd is optional; `hooks?.OnHandlerEnd()` would throw when hooks exist but the callback is unset.
            hooks?.OnHandlerEnd?.Invoke();
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting client streaming call.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, false, null))?.Freeze();
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
            {
                requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
            }
            var response = await continuation(requestStream, context);
            if (hooks?.OnUnaryResponse != null)
            {
                response = hooks.OnUnaryResponse(response);
            }
            hooks?.OnHandlerEnd?.Invoke();
            return response;
        }

        /// <summary>
        /// Server-side handler for intercepting server streaming calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, true, request))?.Freeze();
            request = hooks?.UnaryRequestOverride ?? request;
            if (hooks?.OnResponseMessage != null)
            {
                responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
            }
            await continuation(request, responseStream, context);
            hooks?.OnHandlerEnd?.Invoke();
        }

        /// <summary>
        /// Server-side handler for intercepting bidi streaming calls.
        /// </summary>
        /// <typeparam name="TRequest">Request message type for this method.</typeparam>
        /// <typeparam name="TResponse">Response message type for this method.</typeparam>
        public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
        {
            var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, true, null))?.Freeze();
            if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
            {
                requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
            }
            if (hooks?.OnResponseMessage != null)
            {
                responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
            }
            await continuation(requestStream, responseStream, context);
            hooks?.OnHandlerEnd?.Invoke();
        }

        /// <summary>
        /// Awaits <paramref name="responseTask"/> and applies <paramref name="map"/> to its result.
        /// Awaiting (rather than <c>ContinueWith(t => map(t.Result))</c>) rethrows the original
        /// exception, e.g. <c>RpcException</c>, instead of wrapping it in an <c>AggregateException</c>,
        /// and keeps a canceled task canceled.
        /// </summary>
        private static async Task<T> MapResponseAsync<T>(Task<T> responseTask, Func<T, T> map)
        {
            var result = await responseTask.ConfigureAwait(false);
            return map(result);
        }

        /// <summary>
        /// Applies an optional per-message hook. A null hook, or a hook that returns null,
        /// leaves the original message unchanged; this rule is shared by every stream wrapper.
        /// </summary>
        private static T ApplyMessageHook<T>(Func<T, T> hook, T message)
        {
            if (hook == null)
            {
                return message;
            }
            var mapped = hook(message);
            return mapped != null ? mapped : message;
        }

        /// <summary>
        /// Stream reader that passes each message through an optional hook and
        /// invokes an optional callback whenever the underlying reader reports the end of the stream.
        /// </summary>
        private class WrappedAsyncStreamReader<T> : IAsyncStreamReader<T>
        {
            readonly IAsyncStreamReader<T> reader;
            readonly Func<T, T> onMessage;
            readonly Action onStreamEnd;

            public WrappedAsyncStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd)
            {
                this.reader = reader;
                this.onMessage = onMessage;
                this.onStreamEnd = onStreamEnd;
            }

            public void Dispose() => ((IDisposable)reader).Dispose();

            private T current;

            public T Current
            {
                get
                {
                    if (current == null)
                    {
                        throw new InvalidOperationException("No current element is available.");
                    }
                    return current;
                }
            }

            public async Task<bool> MoveNext(CancellationToken token)
            {
                if (await reader.MoveNext(token))
                {
                    current = ApplyMessageHook(onMessage, reader.Current);
                    return true;
                }
                onStreamEnd?.Invoke();
                return false;
            }
        }

        /// <summary>
        /// Client request-stream writer that passes each outgoing message through an optional hook
        /// and invokes an optional callback once the underlying stream has been completed successfully.
        /// </summary>
        private class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
        {
            readonly IClientStreamWriter<T> writer;
            readonly Func<T, T> onMessage;
            readonly Action onStreamEnd;

            public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onStreamEnd)
            {
                this.writer = writer;
                this.onMessage = onMessage;
                this.onStreamEnd = onStreamEnd;
            }

            public Task CompleteAsync()
            {
                if (onStreamEnd == null)
                {
                    return writer.CompleteAsync();
                }
                return CompleteAndNotifyAsync();
            }

            // Awaiting (instead of ContinueWith) lets a failed CompleteAsync fault the returned task
            // rather than being swallowed, and only reports stream end after a successful completion.
            private async Task CompleteAndNotifyAsync()
            {
                await writer.CompleteAsync().ConfigureAwait(false);
                onStreamEnd();
            }

            public Task WriteAsync(T message)
            {
                return writer.WriteAsync(ApplyMessageHook(onMessage, message));
            }

            public WriteOptions WriteOptions
            {
                get
                {
                    return writer.WriteOptions;
                }
                set
                {
                    writer.WriteOptions = value;
                }
            }
        }

        /// <summary>
        /// Server response-stream writer that passes each outgoing message through an optional hook.
        /// </summary>
        private class WrappedAsyncStreamWriter<T> : IServerStreamWriter<T>
        {
            readonly IAsyncStreamWriter<T> writer;
            readonly Func<T, T> onMessage;

            public WrappedAsyncStreamWriter(IAsyncStreamWriter<T> writer, Func<T, T> onMessage)
            {
                this.writer = writer;
                this.onMessage = onMessage;
            }

            public Task WriteAsync(T message)
            {
                return writer.WriteAsync(ApplyMessageHook(onMessage, message));
            }

            public WriteOptions WriteOptions
            {
                get
                {
                    return writer.WriteOptions;
                }
                set
                {
                    writer.WriteOptions = value;
                }
            }
        }
    }
}