Skip to content

Commit

Permalink
audit, client/http: use X-Caller-ID to replace the component signatur…
Browse files Browse the repository at this point in the history
…e key (tikv#7536)

ref tikv#7300

Use `X-Caller-ID` to replace the component signature key.

Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato authored Dec 13, 2023
1 parent ab97b9a commit f71de23
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 64 deletions.
4 changes: 2 additions & 2 deletions client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (c *client) execDuration(name string, duration time.Duration) {
// Header key definition constants.
const (
pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle"
componentSignatureKey = "component"
xCallerIDKey = "X-Caller-ID"
)

// HeaderOption configures the HTTP header.
Expand Down Expand Up @@ -279,7 +279,7 @@ func (c *client) request(
for _, opt := range headerOpts {
opt(req.Header)
}
req.Header.Set(componentSignatureKey, c.callerID)
req.Header.Set(xCallerIDKey, c.callerID)

start := time.Now()
resp, err := c.inner.cli.Do(req)
Expand Down
2 changes: 1 addition & 1 deletion client/http/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestCallerID(t *testing.T) {
re := require.New(t)
expectedVal := defaultCallerID
httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error {
val := req.Header.Get(componentSignatureKey)
val := req.Header.Get(xCallerIDKey)
if val != expectedVal {
re.Failf("Caller ID header check failed",
"should be %s, but got %s", expectedVal, val)
Expand Down
2 changes: 1 addition & 1 deletion pkg/audit/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (b *PrometheusHistogramBackend) ProcessHTTPRequest(req *http.Request) bool
if !ok {
return false
}
b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.Component, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp))
b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.CallerID, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp))
return true
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/audit/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestPrometheusHistogramBackend(t *testing.T) {
Name: "audit_handling_seconds_test",
Help: "PD server service handling audit",
Buckets: prometheus.DefBuckets,
}, []string{"service", "method", "component", "ip"})
}, []string{"service", "method", "caller_id", "ip"})

prometheus.MustRegister(serviceAuditHistogramTest)

Expand All @@ -62,7 +62,7 @@ func TestPrometheusHistogramBackend(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "http://127.0.0.1:2379/test?test=test", http.NoBody)
info := requestutil.GetRequestInfo(req)
info.ServiceLabel = "test"
info.Component = "user1"
info.CallerID = "user1"
info.IP = "localhost"
req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info))
re.False(backend.ProcessHTTPRequest(req))
Expand All @@ -73,7 +73,7 @@ func TestPrometheusHistogramBackend(t *testing.T) {
re.True(backend.ProcessHTTPRequest(req))
re.True(backend.ProcessHTTPRequest(req))

info.Component = "user2"
info.CallerID = "user2"
req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info))
re.True(backend.ProcessHTTPRequest(req))

Expand All @@ -85,8 +85,8 @@ func TestPrometheusHistogramBackend(t *testing.T) {
defer resp.Body.Close()
content, _ := io.ReadAll(resp.Body)
output := string(content)
re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2")
re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1")
re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2")
re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1")
}

func TestLocalLogBackendUsingFile(t *testing.T) {
Expand All @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) {
b, _ := os.ReadFile(fname)
output := strings.SplitN(string(b), "]", 4)
re.Equal(
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, CallerID:anonymous, IP:, Port:, "+
"StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n",
time.Unix(info.StartTimeStamp, 0).String()),
output[3],
Expand Down
55 changes: 33 additions & 22 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ import (
"go.uber.org/zap"
)

var (
// componentSignatureKey is used for http request header key
// to identify component signature
const (
// componentSignatureKey is used for http request header key to identify component signature.
// Deprecated: please use `XCallerIDHeader` below to obtain a more granular source identification.
// This is kept for backward compatibility.
componentSignatureKey = "component"
// componentAnonymousValue identifies anonymous request source
componentAnonymousValue = "anonymous"
)
// anonymousValue identifies anonymous request source
anonymousValue = "anonymous"

const (
// PDRedirectorHeader is used to mark which PD redirected this request.
PDRedirectorHeader = "PD-Redirector"
// PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD.
Expand All @@ -58,6 +57,8 @@ const (
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"
// XCallerIDHeader is used to mark the caller ID.
XCallerIDHeader = "X-Caller-ID"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"

Expand Down Expand Up @@ -112,7 +113,7 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) {

// GetIPPortFromHTTPRequest returns http client host IP and port from context.
// Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension),
// so `X-Forwarded-For` has the higher priority than `X-Real-IP`.
// so `X-Forwarded-For` has the higher priority than `X-Real-Ip`.
// And both of them have the higher priority than `RemoteAddr`
func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) {
forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",")
Expand All @@ -136,32 +137,42 @@ func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) {
return splitIP, splitPort
}

// GetComponentNameOnHTTP returns component name from Request Header
func GetComponentNameOnHTTP(r *http.Request) string {
// getComponentNameOnHTTP returns component name from the request header.
func getComponentNameOnHTTP(r *http.Request) string {
componentName := r.Header.Get(componentSignatureKey)
if len(componentName) == 0 {
componentName = componentAnonymousValue
componentName = anonymousValue
}
return componentName
}

// ComponentSignatureRoundTripper is used to add component signature in HTTP header
type ComponentSignatureRoundTripper struct {
proxied http.RoundTripper
component string
// GetCallerIDOnHTTP returns caller ID from the request header.
func GetCallerIDOnHTTP(r *http.Request) string {
callerID := r.Header.Get(XCallerIDHeader)
if len(callerID) == 0 {
// Fall back to get the component name to keep backward compatibility.
callerID = getComponentNameOnHTTP(r)
}
return callerID
}

// CallerIDRoundTripper is used to add caller ID in the HTTP header.
type CallerIDRoundTripper struct {
proxied http.RoundTripper
callerID string
}

// NewComponentSignatureRoundTripper returns a new ComponentSignatureRoundTripper.
func NewComponentSignatureRoundTripper(roundTripper http.RoundTripper, componentName string) *ComponentSignatureRoundTripper {
return &ComponentSignatureRoundTripper{
proxied: roundTripper,
component: componentName,
// NewCallerIDRoundTripper returns a new `CallerIDRoundTripper`.
func NewCallerIDRoundTripper(roundTripper http.RoundTripper, callerID string) *CallerIDRoundTripper {
return &CallerIDRoundTripper{
proxied: roundTripper,
callerID: callerID,
}
}

// RoundTrip is used to implement RoundTripper
func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
req.Header.Add(componentSignatureKey, rt.component)
func (rt *CallerIDRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
req.Header.Add(XCallerIDHeader, rt.callerID)
// Send the request, get the response and the error
resp, err = rt.proxied.RoundTrip(req)
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/utils/apiutil/apiutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) {
ip: "127.0.0.1",
port: "5299",
},
// IPv4 "X-Real-IP" with port
// IPv4 "X-Real-Ip" with port
{
r: &http.Request{
Header: map[string][]string{
Expand All @@ -111,7 +111,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) {
ip: "127.0.0.1",
port: "5299",
},
// IPv4 "X-Real-IP" without port
// IPv4 "X-Real-Ip" without port
{
r: &http.Request{
Header: map[string][]string{
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) {
ip: "::1",
port: "",
},
// IPv6 "X-Real-IP" with port
// IPv6 "X-Real-Ip" with port
{
r: &http.Request{
Header: map[string][]string{
Expand All @@ -168,7 +168,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) {
ip: "::1",
port: "5299",
},
// IPv6 "X-Real-IP" without port
// IPv6 "X-Real-Ip" without port
{
r: &http.Request{
Header: map[string][]string{
Expand Down
4 changes: 2 additions & 2 deletions pkg/utils/requestutil/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestRequestInfo(t *testing.T) {
RequestInfo{
ServiceLabel: "test label",
Method: http.MethodPost,
Component: "pdctl",
CallerID: "pdctl",
IP: "localhost",
URLParam: "{\"id\"=1}",
BodyParam: "{\"state\"=\"Up\"}",
Expand All @@ -45,7 +45,7 @@ func TestRequestInfo(t *testing.T) {
re.True(ok)
re.Equal("test label", result.ServiceLabel)
re.Equal(http.MethodPost, result.Method)
re.Equal("pdctl", result.Component)
re.Equal("pdctl", result.CallerID)
re.Equal("localhost", result.IP)
re.Equal("{\"id\"=1}", result.URLParam)
re.Equal("{\"state\"=\"Up\"}", result.BodyParam)
Expand Down
14 changes: 8 additions & 6 deletions pkg/utils/requestutil/request_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import (

// RequestInfo holds service information from http.Request
type RequestInfo struct {
ServiceLabel string
Method string
Component string
ServiceLabel string
Method string
// CallerID is used to identify the specific source of a HTTP request, it will be marked in
// the PD HTTP client, with granularity that can be refined to a specific functionality within a component.
CallerID string
IP string
Port string
URLParam string
Expand All @@ -38,8 +40,8 @@ type RequestInfo struct {
}

func (info *RequestInfo) String() string {
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, CallerID:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
info.ServiceLabel, info.Method, info.CallerID, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
return s
}

Expand All @@ -49,7 +51,7 @@ func GetRequestInfo(r *http.Request) RequestInfo {
return RequestInfo{
ServiceLabel: apiutil.GetRouteName(r),
Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path),
Component: apiutil.GetComponentNameOnHTTP(r),
CallerID: apiutil.GetCallerIDOnHTTP(r),
IP: ip,
Port: port,
URLParam: getURLParam(r),
Expand Down
2 changes: 1 addition & 1 deletion server/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
w.Header().Add("body-param", requestInfo.BodyParam)
w.Header().Add("url-param", requestInfo.URLParam)
w.Header().Add("method", requestInfo.Method)
w.Header().Add("component", requestInfo.Component)
w.Header().Add("caller-id", requestInfo.CallerID)
w.Header().Add("ip", requestInfo.IP)
})

Expand Down
2 changes: 1 addition & 1 deletion server/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ var (
Name: "audit_handling_seconds",
Help: "PD server service handling audit",
Buckets: prometheus.DefBuckets,
}, []string{"service", "method", "component", "ip"})
}, []string{"service", "method", "caller_id", "ip"})
serverMaxProcs = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "pd",
Expand Down
12 changes: 7 additions & 5 deletions tests/pdctl/global_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@ import (
"go.uber.org/zap"
)

const pdControlCallerID = "pd-ctl"

func TestSendAndGetComponent(t *testing.T) {
re := require.New(t)
handler := func(ctx context.Context, s *server.Server) (http.Handler, apiutil.APIServiceGroup, error) {
mux := http.NewServeMux()
mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
component := apiutil.GetComponentNameOnHTTP(r)
callerID := apiutil.GetCallerIDOnHTTP(r)
for k := range r.Header {
log.Info("header", zap.String("key", k))
}
log.Info("component", zap.String("component", component))
re.Equal("pdctl", component)
fmt.Fprint(w, component)
log.Info("caller id", zap.String("caller-id", callerID))
re.Equal(pdControlCallerID, callerID)
fmt.Fprint(w, callerID)
})
info := apiutil.APIServiceGroup{
IsCore: true,
Expand All @@ -65,5 +67,5 @@ func TestSendAndGetComponent(t *testing.T) {
args := []string{"-u", pdAddr, "health"}
output, err := ExecuteCommand(cmd, args...)
re.NoError(err)
re.Equal("pdctl\n", string(output))
re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output))
}
10 changes: 5 additions & 5 deletions tests/server/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() {
suite.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param"))
suite.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param"))
suite.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method"))
suite.Equal("anonymous", resp.Header.Get("component"))
suite.Equal("anonymous", resp.Header.Get("caller-id"))
suite.Equal("127.0.0.1", resp.Header.Get("ip"))

input = map[string]interface{}{
Expand Down Expand Up @@ -408,7 +408,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() {
defer resp.Body.Close()
content, _ := io.ReadAll(resp.Body)
output := string(content)
suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1")
suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1")

// resign to test persist config
oldLeaderName := leader.GetServer().Name()
Expand All @@ -434,7 +434,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() {
defer resp.Body.Close()
content, _ = io.ReadAll(resp.Body)
output = string(content)
suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2")
suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2")

input = map[string]interface{}{
"enable-audit": "false",
Expand Down Expand Up @@ -543,15 +543,15 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) {

func doTestRequestWithLogAudit(srv *tests.TestServer) {
req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody)
req.Header.Set("component", "test")
req.Header.Set(apiutil.XCallerIDHeader, "test")
resp, _ := dialClient.Do(req)
resp.Body.Close()
}

func doTestRequestWithPrometheus(srv *tests.TestServer) {
timeUnix := time.Now().Unix() - 20
req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody)
req.Header.Set("component", "test")
req.Header.Set(apiutil.XCallerIDHeader, "test")
resp, _ := dialClient.Do(req)
resp.Body.Close()
}
Expand Down
17 changes: 9 additions & 8 deletions tools/pd-ctl/pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ import (
"go.etcd.io/etcd/pkg/transport"
)

var (
pdControllerComponentName = "pdctl"
dialClient = &http.Client{
Transport: apiutil.NewComponentSignatureRoundTripper(http.DefaultTransport, pdControllerComponentName),
}
pingPrefix = "pd/api/v1/ping"
const (
pdControlCallerID = "pd-ctl"
pingPrefix = "pd/api/v1/ping"
)

var dialClient = &http.Client{
Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID),
}

// InitHTTPSClient creates https client with ca file
func InitHTTPSClient(caPath, certPath, keyPath string) error {
tlsInfo := transport.TLSInfo{
Expand All @@ -50,8 +51,8 @@ func InitHTTPSClient(caPath, certPath, keyPath string) error {
}

dialClient = &http.Client{
Transport: apiutil.NewComponentSignatureRoundTripper(
&http.Transport{TLSClientConfig: tlsConfig}, pdControllerComponentName),
Transport: apiutil.NewCallerIDRoundTripper(
&http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID),
}

return nil
Expand Down

0 comments on commit f71de23

Please sign in to comment.