|  | @@ -393,6 +393,103 @@ class ClientStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  |    }
 | 
	
		
			
				|  |  |  };
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class ServerStreamingRpcHijackingInterceptor
 | 
	
		
			
				|  |  | +    : public experimental::Interceptor {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
 | 
	
		
			
				|  |  | +    info_ = info;
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
 | 
	
		
			
				|  |  | +    bool hijack = false;
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
 | 
	
		
			
				|  |  | +      auto* map = methods->GetSendInitialMetadata();
 | 
	
		
			
				|  |  | +      // Check that we can see the test metadata
 | 
	
		
			
				|  |  | +      ASSERT_EQ(map->size(), static_cast<unsigned>(1));
 | 
	
		
			
				|  |  | +      auto iterator = map->begin();
 | 
	
		
			
				|  |  | +      EXPECT_EQ("testkey", iterator->first);
 | 
	
		
			
				|  |  | +      EXPECT_EQ("testvalue", iterator->second);
 | 
	
		
			
				|  |  | +      hijack = true;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
 | 
	
		
			
				|  |  | +      EchoRequest req;
 | 
	
		
			
				|  |  | +      auto* buffer = methods->GetSerializedSendMessage();
 | 
	
		
			
				|  |  | +      auto copied_buffer = *buffer;
 | 
	
		
			
				|  |  | +      EXPECT_TRUE(
 | 
	
		
			
				|  |  | +          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
 | 
	
		
			
				|  |  | +              .ok());
 | 
	
		
			
				|  |  | +      EXPECT_EQ(req.message(), "Hello");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
 | 
	
		
			
				|  |  | +      // Got nothing to do here for now
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
 | 
	
		
			
				|  |  | +      auto* map = methods->GetRecvTrailingMetadata();
 | 
	
		
			
				|  |  | +      bool found = false;
 | 
	
		
			
				|  |  | +      // Check that we received the metadata as an echo
 | 
	
		
			
				|  |  | +      for (const auto& pair : *map) {
 | 
	
		
			
				|  |  | +        found = pair.first.starts_with("testkey") &&
 | 
	
		
			
				|  |  | +                pair.second.starts_with("testvalue");
 | 
	
		
			
				|  |  | +        if (found) break;
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +      EXPECT_EQ(found, true);
 | 
	
		
			
				|  |  | +      auto* status = methods->GetRecvStatus();
 | 
	
		
			
				|  |  | +      EXPECT_EQ(status->ok(), true);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
 | 
	
		
			
				|  |  | +      if (++count_ > 10) {
 | 
	
		
			
				|  |  | +        methods->FailHijackedRecvMessage();
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +      EchoResponse* resp =
 | 
	
		
			
				|  |  | +          static_cast<EchoResponse*>(methods->GetRecvMessage());
 | 
	
		
			
				|  |  | +      resp->set_message("Hello");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
 | 
	
		
			
				|  |  | +      // Only the last message will be a failure
 | 
	
		
			
				|  |  | +      EXPECT_FALSE(got_failed_message_);
 | 
	
		
			
				|  |  | +      got_failed_message_ = methods->GetRecvMessage() == nullptr;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (methods->QueryInterceptionHookPoint(
 | 
	
		
			
				|  |  | +            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
 | 
	
		
			
				|  |  | +      auto* map = methods->GetRecvTrailingMetadata();
 | 
	
		
			
				|  |  | +      // insert the metadata that we want
 | 
	
		
			
				|  |  | +      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
 | 
	
		
			
				|  |  | +      map->insert(std::make_pair("testkey", "testvalue"));
 | 
	
		
			
				|  |  | +      auto* status = methods->GetRecvStatus();
 | 
	
		
			
				|  |  | +      *status = Status(StatusCode::OK, "");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    if (hijack) {
 | 
	
		
			
				|  |  | +      methods->Hijack();
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +      methods->Proceed();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  static bool GotFailedMessage() { return got_failed_message_; }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | + private:
 | 
	
		
			
				|  |  | +  experimental::ClientRpcInfo* info_;
 | 
	
		
			
				|  |  | +  static bool got_failed_message_;
 | 
	
		
			
				|  |  | +  int count_ = 0;
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class ServerStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  | +    : public experimental::ClientInterceptorFactoryInterface {
 | 
	
		
			
				|  |  | + public:
 | 
	
		
			
				|  |  | +  virtual experimental::Interceptor* CreateClientInterceptor(
 | 
	
		
			
				|  |  | +      experimental::ClientRpcInfo* info) override {
 | 
	
		
			
				|  |  | +    return new ServerStreamingRpcHijackingInterceptor(info);
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class BidiStreamingRpcHijackingInterceptorFactory
 | 
	
		
			
				|  |  |      : public experimental::ClientInterceptorFactoryInterface {
 | 
	
		
			
				|  |  |   public:
 | 
	
	
		
			
				|  | @@ -711,6 +808,20 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
 | 
	
		
			
				|  |  |    EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
 | 
	
		
			
				|  |  | +  ChannelArguments args;
 | 
	
		
			
				|  |  | +  DummyInterceptor::Reset();
 | 
	
		
			
				|  |  | +  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
 | 
	
		
			
				|  |  | +      creators;
 | 
	
		
			
				|  |  | +  creators.push_back(
 | 
	
		
			
				|  |  | +      std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
 | 
	
		
			
				|  |  | +          new ServerStreamingRpcHijackingInterceptorFactory()));
 | 
	
		
			
				|  |  | +  auto channel = experimental::CreateCustomChannelWithInterceptors(
 | 
	
		
			
				|  |  | +      server_address_, InsecureChannelCredentials(), args, std::move(creators));
 | 
	
		
			
				|  |  | +  MakeServerStreamingCall(channel);
 | 
	
		
			
				|  |  | +  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
 | 
	
		
			
				|  |  |    ChannelArguments args;
 | 
	
		
			
				|  |  |    DummyInterceptor::Reset();
 |