|  | @@ -270,6 +270,129 @@ class HijackingInterceptorMakesAnotherCallFactory
 | 
	
		
			
				|  |  |    }
 | 
	
		
			
				|  |  |  };
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
 | 
	
		
			
				|  |  | +    info_ = info;
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
 | 
	
		
			
				|  |  | +    bool hijack = false;
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
 | 
	
		
			
				|  |  | +      CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
 | 
	
		
			
				|  |  | +      hijack = true;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
 | 
	
		
			
				|  |  | +      EchoRequest req;
 | 
	
		
			
				|  |  | +      auto* buffer = methods->GetSerializedSendMessage();
 | 
	
		
			
				|  |  | +      auto copied_buffer = *buffer;
 | 
	
		
			
				|  |  | +      EXPECT_TRUE(
 | 
	
		
			
				|  |  | +          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
 | 
	
		
			
				|  |  | +              .ok());
 | 
	
		
			
				|  |  | +      EXPECT_EQ(req.message().find("Hello"), 0u);
 | 
	
		
			
				|  |  | +      msg = req.message();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
 | 
	
		
			
				|  |  | +      // Got nothing to do here for now
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
 | 
	
		
			
				|  |  | +      CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
 | 
	
		
			
				|  |  | +                    "testvalue");
 | 
	
		
			
				|  |  | +      auto* status = methods->GetRecvStatus();
 | 
	
		
			
				|  |  | +      EXPECT_EQ(status->ok(), true);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
 | 
	
		
			
				|  |  | +      EchoResponse* resp =
 | 
	
		
			
				|  |  | +          static_cast<EchoResponse*>(methods->GetRecvMessage());
 | 
	
		
			
				|  |  | +      resp->set_message(msg);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
 | 
	
		
			
				|  |  | +      EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
 | 
	
		
			
				|  |  | +                    ->message()
 | 
	
		
			
				|  |  | +                    .find("Hello"),
 | 
	
		
			
				|  |  | +                0u);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
 | 
	
		
			
				|  |  | +      auto* map = methods->GetRecvTrailingMetadata();
 | 
	
		
			
				|  |  | +      // insert the metadata that we want
 | 
	
		
			
				|  |  | +      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
 | 
	
		
			
				|  |  | +      map->insert(std::make_pair("testkey", "testvalue"));
 | 
	
		
			
				|  |  | +      auto* status = methods->GetRecvStatus();
 | 
	
		
			
				|  |  | +      *status = Status(StatusCode::OK, "");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (hijack) {
 | 
	
		
			
				|  |  | +      methods->Hijack();
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +      methods->Proceed();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | + private:
 | 
	
		
			
				|  |  | +  experimental::ClientRpcInfo* info_;
 | 
	
		
			
				|  |  | +  grpc::string msg;
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class ClientStreamingRpcHijackingInterceptor
 | 
	
		
			
				|  |  | +    : public experimental::Interceptor {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
 | 
	
		
			
				|  |  | +    info_ = info;
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
 | 
	
		
			
				|  |  | +    bool hijack = false;
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
 | 
	
		
			
				|  |  | +      hijack = true;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
 | 
	
		
			
				|  |  | +      if (++count_ > 10) {
 | 
	
		
			
				|  |  | +        methods->FailHijackedSendMessage();
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
 | 
	
		
			
				|  |  | +      EXPECT_FALSE(got_failed_send_);
 | 
	
		
			
				|  |  | +      got_failed_send_ = !methods->GetSendMessageStatus();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
 | 
	
		
			
				|  |  | +      auto* status = methods->GetRecvStatus();
 | 
	
		
			
				|  |  | +      *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (hijack) {
 | 
	
		
			
				|  |  | +      methods->Hijack();
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +      methods->Proceed();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  static bool GotFailedSend() { return got_failed_send_; }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | + private:
 | 
	
		
			
				|  |  | +  experimental::ClientRpcInfo* info_;
 | 
	
		
			
				|  |  | +  int count_ = 0;
 | 
	
		
			
				|  |  | +  static bool got_failed_send_;
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class ClientStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  | +    : public experimental::ClientInterceptorFactoryInterface {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  virtual experimental::Interceptor* CreateClientInterceptor(
 | 
	
		
			
				|  |  | +      experimental::ClientRpcInfo* info) override {
 | 
	
		
			
				|  |  | +    return new ClientStreamingRpcHijackingInterceptor(info);
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class ServerStreamingRpcHijackingInterceptor
 | 
	
		
			
				|  |  |      : public experimental::Interceptor {
 | 
	
		
			
				|  |  |   public:
 | 
	
	
		
			
				|  | @@ -292,7 +415,7 @@ class ServerStreamingRpcHijackingInterceptor
 | 
	
		
			
				|  |  |      if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  |              experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
 | 
	
		
			
				|  |  |        EchoRequest req;
 | 
	
		
			
				|  |  | -      auto* buffer = methods->GetSendMessage();
 | 
	
		
			
				|  |  | +      auto* buffer = methods->GetSerializedSendMessage();
 | 
	
		
			
				|  |  |        auto copied_buffer = *buffer;
 | 
	
		
			
				|  |  |        EXPECT_TRUE(
 | 
	
		
			
				|  |  |            SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
 | 
	
	
		
			
				|  | @@ -367,6 +490,15 @@ class ServerStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  |    }
 | 
	
		
			
				|  |  |  };
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class BidiStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  | +    : public experimental::ClientInterceptorFactoryInterface {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  virtual experimental::Interceptor* CreateClientInterceptor(
 | 
	
		
			
				|  |  | +      experimental::ClientRpcInfo* info) override {
 | 
	
		
			
				|  |  | +    return new BidiStreamingRpcHijackingInterceptor(info);
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class LoggingInterceptor : public experimental::Interceptor {
 | 
	
		
			
				|  |  |   public:
 | 
	
		
			
				|  |  |    LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
 | 
	
	
		
			
				|  | @@ -647,6 +779,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
 | 
	
		
			
				|  |  |    EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
 | 
	
		
			
				|  |  | +  ChannelArguments args;
 | 
	
		
			
				|  |  | +  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
 | 
	
		
			
				|  |  | +      creators;
 | 
	
		
			
				|  |  | +  creators.push_back(
 | 
	
		
			
				|  |  | +      std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
 | 
	
		
			
				|  |  | +          new ClientStreamingRpcHijackingInterceptorFactory()));
 | 
	
		
			
				|  |  | +  auto channel = experimental::CreateCustomChannelWithInterceptors(
 | 
	
		
			
				|  |  | +      server_address_, InsecureChannelCredentials(), args, std::move(creators));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  auto stub = grpc::testing::EchoTestService::NewStub(channel);
 | 
	
		
			
				|  |  | +  ClientContext ctx;
 | 
	
		
			
				|  |  | +  EchoRequest req;
 | 
	
		
			
				|  |  | +  EchoResponse resp;
 | 
	
		
			
				|  |  | +  req.mutable_param()->set_echo_metadata(true);
 | 
	
		
			
				|  |  | +  req.set_message("Hello");
 | 
	
		
			
				|  |  | +  string expected_resp = "";
 | 
	
		
			
				|  |  | +  auto writer = stub->RequestStream(&ctx, &resp);
 | 
	
		
			
				|  |  | +  for (int i = 0; i < 10; i++) {
 | 
	
		
			
				|  |  | +    EXPECT_TRUE(writer->Write(req));
 | 
	
		
			
				|  |  | +    expected_resp += "Hello";
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  // The interceptor will reject the 11th message
 | 
	
		
			
				|  |  | +  writer->Write(req);
 | 
	
		
			
				|  |  | +  Status s = writer->Finish();
 | 
	
		
			
				|  |  | +  EXPECT_EQ(s.ok(), false);
 | 
	
		
			
				|  |  | +  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
 | 
	
		
			
				|  |  |    ChannelArguments args;
 | 
	
		
			
				|  |  |    DummyInterceptor::Reset();
 | 
	
	
		
			
				|  | @@ -661,6 +822,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
 | 
	
		
			
				|  |  |    EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
 | 
	
		
			
				|  |  | +  ChannelArguments args;
 | 
	
		
			
				|  |  | +  DummyInterceptor::Reset();
 | 
	
		
			
				|  |  | +  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
 | 
	
		
			
				|  |  | +      creators;
 | 
	
		
			
				|  |  | +  creators.push_back(
 | 
	
		
			
				|  |  | +      std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
 | 
	
		
			
				|  |  | +          new BidiStreamingRpcHijackingInterceptorFactory()));
 | 
	
		
			
				|  |  | +  auto channel = experimental::CreateCustomChannelWithInterceptors(
 | 
	
		
			
				|  |  | +      server_address_, InsecureChannelCredentials(), args, std::move(creators));
 | 
	
		
			
				|  |  | +  MakeBidiStreamingCall(channel);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
 | 
	
		
			
				|  |  |    ChannelArguments args;
 | 
	
		
			
				|  |  |    DummyInterceptor::Reset();
 |