From f71de230a6cf9f15c54a1efd237d062c039a0bde Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 13 Dec 2023 13:15:18 +0800 Subject: [PATCH] audit, client/http: use X-Caller-ID to replace the component signature key (#7536) ref tikv/pd#7300 Use `X-Caller-ID` to replace the component signature key. Signed-off-by: JmPotato --- client/http/client.go | 4 +- client/http/client_test.go | 2 +- pkg/audit/audit.go | 2 +- pkg/audit/audit_test.go | 12 +++--- pkg/utils/apiutil/apiutil.go | 55 ++++++++++++++++----------- pkg/utils/apiutil/apiutil_test.go | 8 ++-- pkg/utils/requestutil/context_test.go | 4 +- pkg/utils/requestutil/request_info.go | 14 ++++--- server/api/middleware.go | 2 +- server/metrics.go | 2 +- tests/pdctl/global_test.go | 12 +++--- tests/server/api/api_test.go | 10 ++--- tools/pd-ctl/pdctl/command/global.go | 17 +++++---- 13 files changed, 80 insertions(+), 64 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index b79aa9ca002..613ebf33294 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -221,7 +221,7 @@ func (c *client) execDuration(name string, duration time.Duration) { // Header key definition constants. const ( pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - componentSignatureKey = "component" + xCallerIDKey = "X-Caller-ID" ) // HeaderOption configures the HTTP header. @@ -279,7 +279,7 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(componentSignatureKey, c.callerID) + req.Header.Set(xCallerIDKey, c.callerID) start := time.Now() resp, err := c.inner.cli.Do(req) diff --git a/client/http/client_test.go b/client/http/client_test.go index 621910e29ea..70c2ddee08b 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -59,7 +59,7 @@ func TestCallerID(t *testing.T) { re := require.New(t) expectedVal := defaultCallerID httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { - val := req.Header.Get(componentSignatureKey) + val := req.Header.Get(xCallerIDKey) if val != expectedVal { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) diff --git a/pkg/audit/audit.go b/pkg/audit/audit.go index 3553e5f8377..b971b09ed7e 100644 --- a/pkg/audit/audit.go +++ b/pkg/audit/audit.go @@ -98,7 +98,7 @@ func (b *PrometheusHistogramBackend) ProcessHTTPRequest(req *http.Request) bool if !ok { return false } - b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.Component, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) + b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.CallerID, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) return true } diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index d59c9627115..8098b36975e 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -51,7 +51,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { Name: "audit_handling_seconds_test", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) prometheus.MustRegister(serviceAuditHistogramTest) @@ -62,7 +62,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://127.0.0.1:2379/test?test=test", http.NoBody) info := requestutil.GetRequestInfo(req) info.ServiceLabel = "test" - info.Component = "user1" + info.CallerID = "user1" info.IP = "localhost" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.False(backend.ProcessHTTPRequest(req)) @@ -73,7 +73,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { re.True(backend.ProcessHTTPRequest(req)) re.True(backend.ProcessHTTPRequest(req)) - info.Component = "user2" + info.CallerID = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.True(backend.ProcessHTTPRequest(req)) @@ -85,8 +85,8 @@ func TestPrometheusHistogramBackend(t *testing.T) { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") } func TestLocalLogBackendUsingFile(t *testing.T) { @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) { b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) re.Equal( - fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+ + fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, CallerID:anonymous, IP:, Port:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), output[3], diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 164b3c0783d..53fab682fcb 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -39,15 +39,14 @@ import ( "go.uber.org/zap" ) -var ( - // componentSignatureKey is used for http request header key - // to identify component signature +const ( + // componentSignatureKey is used for http request header key to identify component signature. + // Deprecated: please use `XCallerIDHeader` below to obtain a more granular source identification. + // This is kept for backward compatibility. componentSignatureKey = "component" - // componentAnonymousValue identifies anonymous request source - componentAnonymousValue = "anonymous" -) + // anonymousValue identifies anonymous request source + anonymousValue = "anonymous" -const ( // PDRedirectorHeader is used to mark which PD redirected this request. PDRedirectorHeader = "PD-Redirector" // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. @@ -58,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // XCallerIDHeader is used to mark the caller ID. + XCallerIDHeader = "X-Caller-ID" // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. ForwardToMicroServiceHeader = "Forward-To-Micro-Service" @@ -112,7 +113,7 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { // GetIPPortFromHTTPRequest returns http client host IP and port from context. // Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension), -// so `X-Forwarded-For` has the higher priority than `X-Real-IP`. +// so `X-Forwarded-For` has the higher priority than `X-Real-Ip`. // And both of them have the higher priority than `RemoteAddr` func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",") @@ -136,32 +137,42 @@ func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { return splitIP, splitPort } -// GetComponentNameOnHTTP returns component name from Request Header -func GetComponentNameOnHTTP(r *http.Request) string { +// getComponentNameOnHTTP returns component name from the request header. +func getComponentNameOnHTTP(r *http.Request) string { componentName := r.Header.Get(componentSignatureKey) if len(componentName) == 0 { - componentName = componentAnonymousValue + componentName = anonymousValue } return componentName } -// ComponentSignatureRoundTripper is used to add component signature in HTTP header -type ComponentSignatureRoundTripper struct { - proxied http.RoundTripper - component string +// GetCallerIDOnHTTP returns caller ID from the request header. +func GetCallerIDOnHTTP(r *http.Request) string { + callerID := r.Header.Get(XCallerIDHeader) + if len(callerID) == 0 { + // Fall back to get the component name to keep backward compatibility. + callerID = getComponentNameOnHTTP(r) + } + return callerID +} + +// CallerIDRoundTripper is used to add caller ID in the HTTP header. +type CallerIDRoundTripper struct { + proxied http.RoundTripper + callerID string } -// NewComponentSignatureRoundTripper returns a new ComponentSignatureRoundTripper. -func NewComponentSignatureRoundTripper(roundTripper http.RoundTripper, componentName string) *ComponentSignatureRoundTripper { - return &ComponentSignatureRoundTripper{ - proxied: roundTripper, - component: componentName, +// NewCallerIDRoundTripper returns a new `CallerIDRoundTripper`. +func NewCallerIDRoundTripper(roundTripper http.RoundTripper, callerID string) *CallerIDRoundTripper { + return &CallerIDRoundTripper{ + proxied: roundTripper, + callerID: callerID, } } // RoundTrip is used to implement RoundTripper -func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { - req.Header.Add(componentSignatureKey, rt.component) +func (rt *CallerIDRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { + req.Header.Add(XCallerIDHeader, rt.callerID) // Send the request, get the response and the error resp, err = rt.proxied.RoundTrip(req) return diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index a4e7b97aa4d..106d3fb21cb 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -101,7 +101,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" with port + // IPv4 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -111,7 +111,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" without port + // IPv4 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ @@ -158,7 +158,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "", }, - // IPv6 "X-Real-IP" with port + // IPv6 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -168,7 +168,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "5299", }, - // IPv6 "X-Real-IP" without port + // IPv6 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ diff --git a/pkg/utils/requestutil/context_test.go b/pkg/utils/requestutil/context_test.go index 475b109e410..298fc1ff8a3 100644 --- a/pkg/utils/requestutil/context_test.go +++ b/pkg/utils/requestutil/context_test.go @@ -34,7 +34,7 @@ func TestRequestInfo(t *testing.T) { RequestInfo{ ServiceLabel: "test label", Method: http.MethodPost, - Component: "pdctl", + CallerID: "pdctl", IP: "localhost", URLParam: "{\"id\"=1}", BodyParam: "{\"state\"=\"Up\"}", @@ -45,7 +45,7 @@ func TestRequestInfo(t *testing.T) { re.True(ok) re.Equal("test label", result.ServiceLabel) re.Equal(http.MethodPost, result.Method) - re.Equal("pdctl", result.Component) + re.Equal("pdctl", result.CallerID) re.Equal("localhost", result.IP) re.Equal("{\"id\"=1}", result.URLParam) re.Equal("{\"state\"=\"Up\"}", result.BodyParam) diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index 40724bb790f..cc5403f7232 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -27,9 +27,11 @@ import ( // RequestInfo holds service information from http.Request type RequestInfo struct { - ServiceLabel string - Method string - Component string + ServiceLabel string + Method string + // CallerID is used to identify the specific source of a HTTP request, it will be marked in + // the PD HTTP client, with granularity that can be refined to a specific functionality within a component. + CallerID string IP string Port string URLParam string @@ -38,8 +40,8 @@ type RequestInfo struct { } func (info *RequestInfo) String() string { - s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", - info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) + s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, CallerID:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", + info.ServiceLabel, info.Method, info.CallerID, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) return s } @@ -49,7 +51,7 @@ func GetRequestInfo(r *http.Request) RequestInfo { return RequestInfo{ ServiceLabel: apiutil.GetRouteName(r), Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), - Component: apiutil.GetComponentNameOnHTTP(r), + CallerID: apiutil.GetCallerIDOnHTTP(r), IP: ip, Port: port, URLParam: getURLParam(r), diff --git a/server/api/middleware.go b/server/api/middleware.go index 627d7fecc92..4173c37b396 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -69,7 +69,7 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques w.Header().Add("body-param", requestInfo.BodyParam) w.Header().Add("url-param", requestInfo.URLParam) w.Header().Add("method", requestInfo.Method) - w.Header().Add("component", requestInfo.Component) + w.Header().Add("caller-id", requestInfo.CallerID) w.Header().Add("ip", requestInfo.IP) }) diff --git a/server/metrics.go b/server/metrics.go index 2d13d02d564..54c5830dc52 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -151,7 +151,7 @@ var ( Name: "audit_handling_seconds", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) serverMaxProcs = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "pd", diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index 7e57f589249..00d31a384d5 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -30,18 +30,20 @@ import ( "go.uber.org/zap" ) +const pdControlCallerID = "pd-ctl" + func TestSendAndGetComponent(t *testing.T) { re := require.New(t) handler := func(ctx context.Context, s *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { - component := apiutil.GetComponentNameOnHTTP(r) + callerID := apiutil.GetCallerIDOnHTTP(r) for k := range r.Header { log.Info("header", zap.String("key", k)) } - log.Info("component", zap.String("component", component)) - re.Equal("pdctl", component) - fmt.Fprint(w, component) + log.Info("caller id", zap.String("caller-id", callerID)) + re.Equal(pdControlCallerID, callerID) + fmt.Fprint(w, callerID) }) info := apiutil.APIServiceGroup{ IsCore: true, @@ -65,5 +67,5 @@ func TestSendAndGetComponent(t *testing.T) { args := []string{"-u", pdAddr, "health"} output, err := ExecuteCommand(cmd, args...) re.NoError(err) - re.Equal("pdctl\n", string(output)) + re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output)) } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 905fb8ec096..f5db6bb2513 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -164,7 +164,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { suite.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param")) suite.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) suite.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) - suite.Equal("anonymous", resp.Header.Get("component")) + suite.Equal("anonymous", resp.Header.Get("caller-id")) suite.Equal("127.0.0.1", resp.Header.Get("ip")) input = map[string]interface{}{ @@ -408,7 +408,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") // resign to test persist config oldLeaderName := leader.GetServer().Name() @@ -434,7 +434,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) output = string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") input = map[string]interface{}{ "enable-audit": "false", @@ -543,7 +543,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { func doTestRequestWithLogAudit(srv *tests.TestServer) { req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } @@ -551,7 +551,7 @@ func doTestRequestWithLogAudit(srv *tests.TestServer) { func doTestRequestWithPrometheus(srv *tests.TestServer) { timeUnix := time.Now().Unix() - 20 req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 5d8552da51a..0b1f4b4409a 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -29,14 +29,15 @@ import ( "go.etcd.io/etcd/pkg/transport" ) -var ( - pdControllerComponentName = "pdctl" - dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper(http.DefaultTransport, pdControllerComponentName), - } - pingPrefix = "pd/api/v1/ping" +const ( + pdControlCallerID = "pd-ctl" + pingPrefix = "pd/api/v1/ping" ) +var dialClient = &http.Client{ + Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID), +} + // InitHTTPSClient creates https client with ca file func InitHTTPSClient(caPath, certPath, keyPath string) error { tlsInfo := transport.TLSInfo{ @@ -50,8 +51,8 @@ func InitHTTPSClient(caPath, certPath, keyPath string) error { } dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper( - &http.Transport{TLSClientConfig: tlsConfig}, pdControllerComponentName), + Transport: apiutil.NewCallerIDRoundTripper( + &http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID), } return nil