GenericInterceptor.cs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  #region Copyright notice and license
  // Copyright 2018 gRPC authors.
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  #endregion
  16. using System;
  17. using System.Threading;
  18. using System.Threading.Tasks;
  19. using Grpc.Core.Internal;
  namespace Grpc.Core.Interceptors
  {
      /// <summary>
      /// Provides a base class for generic interceptor implementations that raises
      /// events and hooks to control the RPC lifecycle.
      /// </summary>
      /// <remarks>
      /// Derived classes override <c>InterceptCall</c> (client side) and/or <c>InterceptHandler</c>
      /// (server side) and return a hooks object; the overrides in this class then apply those hooks
      /// to the call. Returning null from either method leaves the call untouched, and any individual
      /// hook left null is simply not applied.
      /// </remarks>
      internal abstract class GenericInterceptor : Interceptor
      {
          /// <summary>
          /// Provides hooks through which an outgoing client-side invocation should be intercepted.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for the invocation.</typeparam>
          /// <typeparam name="TResponse">Response message type for the invocation.</typeparam>
          public sealed class ClientCallHooks<TRequest, TResponse>
              where TRequest : class
              where TResponse : class
          {
              /// <summary>
              /// Returns a shallow copy of these hooks; the call paths below only read from the copy,
              /// so later changes to this instance's properties do not affect a call in flight.
              /// </summary>
              internal ClientCallHooks<TRequest, TResponse> Freeze()
              {
                  return (ClientCallHooks<TRequest, TResponse>)MemberwiseClone();
              }

              /// <summary>
              /// Override the context for the outgoing invocation. When null, the original context is used.
              /// </summary>
              public ClientInterceptorContext<TRequest, TResponse>? ContextOverride { get; set; }

              /// <summary>
              /// Override the request for the outgoing invocation for non-client-streaming invocations.
              /// When null, the original request is sent.
              /// </summary>
              public TRequest UnaryRequestOverride { get; set; }

              /// <summary>
              /// Delegate that intercepts the response of a non-server-streaming invocation.
              /// Its return value replaces the response.
              /// </summary>
              public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

              /// <summary>
              /// Delegate that intercepts each request message of a client-streaming invocation.
              /// Its return value is written in place of the original message.
              /// </summary>
              public Func<TRequest, TRequest> OnRequestMessage { get; set; }

              /// <summary>
              /// Delegate that intercepts each response message of a server-streaming invocation.
              /// A non-null return value replaces the message; null keeps the original message.
              /// </summary>
              public Func<TResponse, TResponse> OnResponseMessage { get; set; }

              /// <summary>
              /// Callback that gets invoked when reading the response stream reaches its end.
              /// </summary>
              public Action OnResponseStreamEnd { get; set; }

              /// <summary>
              /// Callback that gets invoked after the request stream has been completed successfully.
              /// </summary>
              public Action OnRequestStreamEnd { get; set; }
          }

          /// <summary>
          /// Intercepts an outgoing call from the client side.
          /// Derived classes that intend to intercept outgoing invocations from the client side should
          /// override this and return the appropriate hooks in the form of a ClientCallHooks instance.
          /// </summary>
          /// <param name="context">The context of the outgoing invocation.</param>
          /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
          /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
          /// <param name="request">The request message for non-client-streaming invocations, null otherwise.</param>
          /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
          /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
          /// <returns>
          /// The derived class should return an instance of ClientCallHooks to control the trajectory
          /// as they see fit, or null if it does not intend to pursue the invocation any further.
          /// </returns>
          protected virtual ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
              where TRequest : class
              where TResponse : class
          {
              return null;
          }

          /// <summary>
          /// Provides hooks through which an incoming server-side handler invocation should be intercepted.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for the handler.</typeparam>
          /// <typeparam name="TResponse">Response message type for the handler.</typeparam>
          public sealed class ServerCallHooks<TRequest, TResponse>
              where TRequest : class
              where TResponse : class
          {
              /// <summary>
              /// Returns a shallow copy of these hooks; the handler paths below only read from the copy,
              /// so later changes to this instance's properties do not affect a handler in flight.
              /// </summary>
              internal ServerCallHooks<TRequest, TResponse> Freeze()
              {
                  return (ServerCallHooks<TRequest, TResponse>)MemberwiseClone();
              }

              /// <summary>
              /// Override the incoming request passed to non-client-streaming handlers.
              /// When null, the original request is passed to the handler.
              /// </summary>
              public TRequest UnaryRequestOverride { get; set; }

              /// <summary>
              /// Delegate that intercepts the response returned by a non-server-streaming handler.
              /// Its return value replaces the response.
              /// </summary>
              public Func<TResponse, TResponse> OnUnaryResponse { get; set; }

              /// <summary>
              /// Delegate that intercepts each request message read by a client-streaming handler.
              /// A non-null return value replaces the message; null keeps the original message.
              /// </summary>
              public Func<TRequest, TRequest> OnRequestMessage { get; set; }

              /// <summary>
              /// Delegate that intercepts each response message written by a server-streaming handler.
              /// Its return value is written in place of the original message.
              /// </summary>
              public Func<TResponse, TResponse> OnResponseMessage { get; set; }

              /// <summary>
              /// Callback that gets invoked after the handler (and any response hook) completes successfully.
              /// It is not invoked if the handler throws.
              /// </summary>
              public Action OnHandlerEnd { get; set; }

              /// <summary>
              /// Callback that gets invoked when the handler reads to the end of the request stream.
              /// </summary>
              public Action OnRequestStreamEnd { get; set; }
          }

          /// <summary>
          /// Intercepts an incoming service handler invocation on the server side.
          /// Derived classes that intend to intercept incoming handlers on the server side should
          /// override this and return the appropriate hooks in the form of a ServerCallHooks instance.
          /// </summary>
          /// <param name="context">The context of the incoming invocation.</param>
          /// <param name="clientStreaming">True if the invocation is client-streaming.</param>
          /// <param name="serverStreaming">True if the invocation is server-streaming.</param>
          /// <param name="request">The request message for non-client-streaming invocations, null otherwise.</param>
          /// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
          /// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
          /// <returns>
          /// The derived class should return an instance of ServerCallHooks to control the trajectory
          /// as they see fit, or null if it does not intend to pursue the invocation any further.
          /// </returns>
          protected virtual Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
              where TRequest : class
              where TResponse : class
          {
              return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
          }

          /// <summary>
          /// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly.
          /// </summary>
          public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
          {
              var hooks = InterceptCall(context, false, false, request)?.Freeze();
              context = hooks?.ContextOverride ?? context;
              request = hooks?.UnaryRequestOverride ?? request;
              var response = continuation(request, context);
              if (hooks?.OnUnaryResponse != null)
              {
                  response = hooks.OnUnaryResponse(response);
              }
              return response;
          }

          /// <summary>
          /// Intercepts an asynchronous invocation of a simple remote call and dispatches the events accordingly.
          /// </summary>
          public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
          {
              var hooks = InterceptCall(context, false, false, request)?.Freeze();
              context = hooks?.ContextOverride ?? context;
              request = hooks?.UnaryRequestOverride ?? request;
              var response = continuation(request, context);
              if (hooks?.OnUnaryResponse != null)
              {
                  response = new AsyncUnaryCall<TResponse>(
                      InterceptUnaryResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse),
                      response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
              }
              return response;
          }

          /// <summary>
          /// Intercepts an asynchronous invocation of a streaming remote call and dispatches the events accordingly.
          /// </summary>
          public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
          {
              var hooks = InterceptCall(context, false, true, request)?.Freeze();
              context = hooks?.ContextOverride ?? context;
              request = hooks?.UnaryRequestOverride ?? request;
              var response = continuation(request, context);
              if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
              {
                  response = new AsyncServerStreamingCall<TResponse>(
                      new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd),
                      response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
              }
              return response;
          }

          /// <summary>
          /// Intercepts an asynchronous invocation of a client streaming call and dispatches the events accordingly.
          /// </summary>
          public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
          {
              var hooks = InterceptCall(context, true, false, null)?.Freeze();
              context = hooks?.ContextOverride ?? context;
              var response = continuation(context);
              // Only request-side and unary-response hooks apply here: a client-streaming call has no
              // response stream, so OnResponseStreamEnd is irrelevant, but OnRequestStreamEnd must
              // trigger wrapping on its own.
              if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnUnaryResponse != null)
              {
                  var requestStream = response.RequestStream;
                  if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                  {
                      requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                  }
                  var responseAsync = response.ResponseAsync;
                  if (hooks.OnUnaryResponse != null)
                  {
                      responseAsync = InterceptUnaryResponseAsync(response.ResponseAsync, hooks.OnUnaryResponse);
                  }
                  response = new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
              }
              return response;
          }

          /// <summary>
          /// Intercepts an asynchronous invocation of a duplex streaming call and dispatches the events accordingly.
          /// </summary>
          public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
          {
              var hooks = InterceptCall(context, true, true, null)?.Freeze();
              context = hooks?.ContextOverride ?? context;
              var response = continuation(context);
              if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
              {
                  var requestStream = response.RequestStream;
                  if (hooks.OnRequestMessage != null || hooks.OnRequestStreamEnd != null)
                  {
                      requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
                  }
                  var responseStream = response.ResponseStream;
                  if (hooks.OnResponseMessage != null || hooks.OnResponseStreamEnd != null)
                  {
                      responseStream = new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd);
                  }
                  response = new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
              }
              return response;
          }

          /// <summary>
          /// Awaits the response of a call and passes it through <paramref name="onUnaryResponse"/>.
          /// </summary>
          /// <remarks>
          /// Unlike <c>ContinueWith(t => onUnaryResponse(t.Result))</c>, awaiting rethrows the original
          /// exception of a failed call (e.g. an <c>RpcException</c>) instead of wrapping it in an
          /// <see cref="AggregateException"/>, and a cancelled call stays cancelled instead of faulting.
          /// </remarks>
          /// <param name="responseAsync">The task producing the original response.</param>
          /// <param name="onUnaryResponse">Hook whose return value becomes the response.</param>
          /// <returns>A task producing the (possibly replaced) response.</returns>
          private static async Task<TResponse> InterceptUnaryResponseAsync<TResponse>(Task<TResponse> responseAsync, Func<TResponse, TResponse> onUnaryResponse)
          {
              var response = await responseAsync.ConfigureAwait(false);
              return onUnaryResponse(response);
          }

          /// <summary>
          /// Server-side handler for intercepting unary calls.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for this method.</typeparam>
          /// <typeparam name="TResponse">Response message type for this method.</typeparam>
          public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
          {
              var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, false, request))?.Freeze();
              request = hooks?.UnaryRequestOverride ?? request;
              var response = await continuation(request, context);
              if (hooks?.OnUnaryResponse != null)
              {
                  response = hooks.OnUnaryResponse(response);
              }
              // OnHandlerEnd is optional even when hooks were returned.
              hooks?.OnHandlerEnd?.Invoke();
              return response;
          }

          /// <summary>
          /// Server-side handler for intercepting client streaming call.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for this method.</typeparam>
          /// <typeparam name="TResponse">Response message type for this method.</typeparam>
          public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
          {
              var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, false, null))?.Freeze();
              if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
              {
                  requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
              }
              var response = await continuation(requestStream, context);
              if (hooks?.OnUnaryResponse != null)
              {
                  response = hooks.OnUnaryResponse(response);
              }
              // OnHandlerEnd is optional even when hooks were returned.
              hooks?.OnHandlerEnd?.Invoke();
              return response;
          }

          /// <summary>
          /// Server-side handler for intercepting server streaming calls.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for this method.</typeparam>
          /// <typeparam name="TResponse">Response message type for this method.</typeparam>
          public override async Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
          {
              var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, true, request))?.Freeze();
              request = hooks?.UnaryRequestOverride ?? request;
              if (hooks?.OnResponseMessage != null)
              {
                  responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
              }
              await continuation(request, responseStream, context);
              // OnHandlerEnd is optional even when hooks were returned.
              hooks?.OnHandlerEnd?.Invoke();
          }

          /// <summary>
          /// Server-side handler for intercepting bidi streaming calls.
          /// </summary>
          /// <typeparam name="TRequest">Request message type for this method.</typeparam>
          /// <typeparam name="TResponse">Response message type for this method.</typeparam>
          public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
          {
              var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, true, null))?.Freeze();
              if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
              {
                  requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
              }
              if (hooks?.OnResponseMessage != null)
              {
                  responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
              }
              await continuation(requestStream, responseStream, context);
              // OnHandlerEnd is optional even when hooks were returned.
              hooks?.OnHandlerEnd?.Invoke();
          }

          /// <summary>
          /// Stream reader decorator: maps each message read through <c>onMessage</c> (a null result
          /// keeps the original message) and invokes <c>onStreamEnd</c> whenever the underlying
          /// reader's <c>MoveNext</c> returns false.
          /// </summary>
          private sealed class WrappedAsyncStreamReader<T> : IAsyncStreamReader<T>
          {
              readonly IAsyncStreamReader<T> reader;
              readonly Func<T, T> onMessage;
              readonly Action onStreamEnd;
              private T current;

              public WrappedAsyncStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd)
              {
                  this.reader = reader;
                  this.onMessage = onMessage;
                  this.onStreamEnd = onStreamEnd;
              }

              // Defensive: a wrapped reader that is not IDisposable is skipped instead of throwing
              // InvalidCastException.
              public void Dispose() => (reader as IDisposable)?.Dispose();

              public T Current
              {
                  get
                  {
                      if (current == null)
                      {
                          throw new InvalidOperationException("No current element is available.");
                      }
                      return current;
                  }
              }

              public async Task<bool> MoveNext(CancellationToken token)
              {
                  if (await reader.MoveNext(token))
                  {
                      var message = reader.Current;
                      if (onMessage != null)
                      {
                          var mappedValue = onMessage(message);
                          if (mappedValue != null)
                          {
                              message = mappedValue;
                          }
                      }
                      this.current = message;
                      return true;
                  }
                  onStreamEnd?.Invoke();
                  return false;
              }
          }

          /// <summary>
          /// Client stream writer decorator: passes each message through <c>onMessage</c> before writing
          /// it, and invokes <c>onStreamEnd</c> after the underlying stream completes successfully.
          /// </summary>
          private sealed class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
          {
              readonly IClientStreamWriter<T> writer;
              readonly Func<T, T> onMessage;
              // Receives the caller's OnRequestStreamEnd hook (this writer is the request stream).
              readonly Action onStreamEnd;

              public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onStreamEnd)
              {
                  this.writer = writer;
                  this.onMessage = onMessage;
                  this.onStreamEnd = onStreamEnd;
              }

              public Task CompleteAsync()
              {
                  var completion = writer.CompleteAsync();
                  if (onStreamEnd == null)
                  {
                      return completion;
                  }
                  return NotifyWhenCompletedAsync(completion, onStreamEnd);
              }

              // Awaiting (instead of ContinueWith) propagates a failed completion to the caller rather
              // than reporting success, and skips the callback in that case.
              private static async Task NotifyWhenCompletedAsync(Task completion, Action callback)
              {
                  await completion.ConfigureAwait(false);
                  callback();
              }

              public Task WriteAsync(T message)
              {
                  if (onMessage != null)
                  {
                      message = onMessage(message);
                  }
                  return writer.WriteAsync(message);
              }

              public WriteOptions WriteOptions
              {
                  get
                  {
                      return writer.WriteOptions;
                  }
                  set
                  {
                      writer.WriteOptions = value;
                  }
              }
          }

          /// <summary>
          /// Server stream writer decorator: passes each message through <c>onMessage</c> before writing it.
          /// </summary>
          private sealed class WrappedAsyncStreamWriter<T> : IServerStreamWriter<T>
          {
              readonly IAsyncStreamWriter<T> writer;
              readonly Func<T, T> onMessage;

              public WrappedAsyncStreamWriter(IAsyncStreamWriter<T> writer, Func<T, T> onMessage)
              {
                  this.writer = writer;
                  this.onMessage = onMessage;
              }

              public Task WriteAsync(T message)
              {
                  if (onMessage != null)
                  {
                      message = onMessage(message);
                  }
                  return writer.WriteAsync(message);
              }

              public WriteOptions WriteOptions
              {
                  get
                  {
                      return writer.WriteOptions;
                  }
                  set
                  {
                      writer.WriteOptions = value;
                  }
              }
          }
      }
  }