Ver Fonte

Working on tests

Yash Tibrewal há 7 anos atrás
pai
commit
3a17f5b05e

+ 3 - 3
include/grpcpp/impl/codegen/client_interceptor.h

@@ -47,7 +47,7 @@ class ClientRpcInfo {
  public:
   ClientRpcInfo() {}
   ClientRpcInfo(grpc::ClientContext* ctx, const char* method,
-                const grpc::Channel* channel,
+                grpc::Channel* channel,
                 const std::vector<std::unique_ptr<
                     experimental::ClientInterceptorFactoryInterface>>& creators)
       : ctx_(ctx), method_(method), channel_(channel) {
@@ -64,7 +64,7 @@ class ClientRpcInfo {
 
   // Getter methods
   const char* method() { return method_; }
-  const Channel* channel() { return channel_; }
+  Channel* channel() { return channel_; }
   grpc::ClientContext* client_context() { return ctx_; }
 
  public:
@@ -79,7 +79,7 @@ class ClientRpcInfo {
  private:
   grpc::ClientContext* ctx_ = nullptr;
   const char* method_ = nullptr;
-  const grpc::Channel* channel_ = nullptr;
+  grpc::Channel* channel_ = nullptr;
   std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;
   bool hijacked_ = false;
   int hijacked_interceptor_ = false;

+ 14 - 15
src/cpp/server/server_cc.cc

@@ -243,13 +243,13 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
 
       interceptor_methods_.SetCall(&call_);
       interceptor_methods_.SetReverse();
-      /* Set interception point for RECV INITIAL METADATA */
+      // Set interception point for RECV INITIAL METADATA
       interceptor_methods_.AddInterceptionHookPoint(
           experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
       interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_);
 
       if (has_request_payload_) {
-        /* Set interception point for RECV MESSAGE */
+        // Set interception point for RECV MESSAGE
         auto* handler = resources_ ? method_->handler()
                                    : server_->resource_exhausted_handler_.get();
         request_ = handler->Deserialize(request_payload_, &request_status_);
@@ -264,8 +264,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
       if (interceptor_methods_.RunInterceptors(f)) {
         ContinueRunAfterInterception();
       } else {
-        /* There were interceptors to be run, so ContinueRunAfterInterception
-        will be run when interceptors are done. */
+        // There were interceptors to be run, so ContinueRunAfterInterception
+        // will be run when interceptors are done.
       }
     }
 
@@ -318,7 +318,6 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
   grpc_metadata_array request_metadata_;
   grpc_byte_buffer* request_payload_;
   grpc_completion_queue* cq_;
-  bool done_intercepting_ = false;
 };
 
 // Implementation of ThreadManager. Each instance of SyncRequestThreadManager
@@ -763,7 +762,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
   context_->set_call(call_);
   context_->cq_ = call_cq_;
   if (call_wrapper_.call() == nullptr) {
-    /* Fill it since it is empty. */
+    // Fill it since it is empty.
     call_wrapper_ = internal::Call(
         call_, server_, call_cq_, server_->max_receive_message_size(), nullptr);
   }
@@ -773,7 +772,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
 
   if (*status && call_ && call_wrapper_.server_rpc_info()) {
     done_intercepting_ = true;
-    /* Set interception point for RECV INITIAL METADATA */
+    // Set interception point for RECV INITIAL METADATA
     interceptor_methods_.AddInterceptionHookPoint(
         experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
     interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_);
@@ -781,11 +780,11 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
                            ContinueFinalizeResultAfterInterception,
                        this);
     if (interceptor_methods_.RunInterceptors(f)) {
-      /* There are no interceptors to run. Continue */
+      // There are no interceptors to run. Continue
     } else {
-      /* There were interceptors to be run, so
-      ContinueFinalizeResultAfterInterception will be run when interceptors are
-      done. */
+      // There were interceptors to be run, so
+      // ContinueFinalizeResultAfterInterception will be run when interceptors
+      // are done.
       return false;
     }
   }
@@ -802,7 +801,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
 void ServerInterface::BaseAsyncRequest::
     ContinueFinalizeResultAfterInterception() {
   context_->BeginCompletionOp(&call_wrapper_);
-  /* Queue a tag which will be returned immediately */
+  // Queue a tag which will be returned immediately
   dummy_alarm_ = new Alarm();
   static_cast<Alarm*>(dummy_alarm_)
       ->Set(notification_cq_,
@@ -844,7 +843,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest(
 
 bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
                                                           bool* status) {
-  /* If we are done intercepting, there is nothing more for us to do */
+  // If we are done intercepting, there is nothing more for us to do
   if (done_intercepting_) {
     return BaseAsyncRequest::FinalizeResult(tag, status);
   }
@@ -870,7 +869,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
 bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag,
                                                        bool* status) {
   if (GenericAsyncRequest::FinalizeResult(tag, status)) {
-    /* We either had no interceptors run or we are done interceptinh */
+    // We either had no interceptors run or we are done intercepting
     if (*status) {
       new UnimplementedAsyncRequest(server_, cq_);
       new UnimplementedAsyncResponse(this);
@@ -878,7 +877,7 @@ bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag,
       delete this;
     }
   } else {
-    /* The tag was swallowed due to interception. We will see it again. */
+    // The tag was swallowed due to interception. We will see it again.
   }
   return false;
 }

+ 122 - 10
test/cpp/end2end/client_interceptors_end2end_test.cc

@@ -60,6 +60,8 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test {
   std::unique_ptr<Server> server_;
 };
 
+/* This interceptor does nothing. Just keeps a global count on the number of
+ * times it was invoked. */
 class DummyInterceptor : public experimental::Interceptor {
  public:
   DummyInterceptor(experimental::ClientRpcInfo* info) {}
@@ -91,6 +93,7 @@ class DummyInterceptorFactory
   }
 };
 
+/* Hijacks Echo RPC and fills in the expected values */
 class HijackingInterceptor : public experimental::Interceptor {
  public:
   HijackingInterceptor(experimental::ClientRpcInfo* info) {
@@ -195,6 +198,111 @@ class HijackingInterceptorFactory
   }
 };
 
+class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
+ public:
+  HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
+    info_ = info;
+    // Make sure it is the right method
+    EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+  }
+
+  virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+    gpr_log(GPR_ERROR, "ran this");
+    bool hijack = false;
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      auto* map = methods->GetSendInitialMetadata();
+      // Check that we can see the test metadata
+      ASSERT_EQ(map->size(), 1);
+      auto iterator = map->begin();
+      EXPECT_EQ("testkey", iterator->first);
+      EXPECT_EQ("testvalue", iterator->second);
+      hijack = true;
+      // Make a copy of the map
+      metadata_map_ = *map;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      EchoRequest req;
+      auto* buffer = methods->GetSendMessage();
+      auto copied_buffer = *buffer;
+      SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req);
+      EXPECT_EQ(req.message(), "Hello");
+      auto stub = grpc::testing::EchoTestService::NewStub(
+          std::shared_ptr<Channel>(info_->channel()));
+      ClientContext ctx;
+      EchoResponse resp;
+      Status s = stub->Echo(&ctx, req, &resp);
+      EXPECT_EQ(s.ok(), true);
+      EXPECT_EQ(resp.message(), "Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here for now
+      EXPECT_EQ(map->size(), 0);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      // Check that we got the hijacked message, and re-insert the expected
+      // message
+      EXPECT_EQ(resp->message(), "Hello1");
+      resp->set_message("Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      bool found = false;
+      // Check that we received the metadata as an echo
+      for (const auto& pair : *map) {
+        found = pair.first.starts_with("testkey") &&
+                pair.second.starts_with("testvalue");
+        if (found) break;
+      }
+      EXPECT_EQ(found, true);
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here at the moment
+      EXPECT_EQ(map->size(), 0);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+      // Insert a different message than expected
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      resp->set_message("Hello1");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      // insert the metadata that we want
+      EXPECT_EQ(map->size(), 0);
+      map->insert(std::make_pair("testkey", "testvalue"));
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::OK, "");
+    }
+    if (hijack) {
+      methods->Hijack();
+    } else {
+      methods->Proceed();
+    }
+  }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+  std::multimap<grpc::string, grpc::string> metadata_map_;
+};
+
 class LoggingInterceptor : public experimental::Interceptor {
  public:
   LoggingInterceptor(experimental::ClientRpcInfo* info) {
@@ -268,6 +376,19 @@ class LoggingInterceptorFactory
   }
 };
 
+void MakeCall(std::shared_ptr<Channel> channel) {
+  auto stub = grpc::testing::EchoTestService::NewStub(channel);
+  ClientContext ctx;
+  EchoRequest req;
+  req.mutable_param()->set_echo_metadata(true);
+  ctx.AddMetadata("testkey", "testvalue");
+  req.set_message("Hello");
+  EchoResponse resp;
+  Status s = stub->Echo(&ctx, req, &resp);
+  EXPECT_EQ(s.ok(), true);
+  EXPECT_EQ(resp.message(), "Hello");
+}
+
 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
   ChannelArguments args;
   DummyInterceptor::Reset();
@@ -284,16 +405,7 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
   }
   auto channel = experimental::CreateCustomChannelWithInterceptors(
       server_address_, InsecureChannelCredentials(), args, std::move(creators));
-  auto stub = grpc::testing::EchoTestService::NewStub(channel);
-  ClientContext ctx;
-  EchoRequest req;
-  req.mutable_param()->set_echo_metadata(true);
-  ctx.AddMetadata("testkey", "testvalue");
-  req.set_message("Hello");
-  EchoResponse resp;
-  Status s = stub->Echo(&ctx, req, &resp);
-  EXPECT_EQ(s.ok(), true);
-  EXPECT_EQ(resp.message(), "Hello");
+  MakeCall(channel);
   // Make sure all 20 dummy interceptors were run
   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
 }