From 186c81594bb93ab826c1752f085aedff16a5f146 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Mon, 13 Nov 2023 12:15:43 +0800 Subject: [PATCH] api: add rule middleware (#7357) ref tikv/pd#5839 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/api/router.go | 44 ++++----- server/api/rule.go | 220 +++++++++++-------------------------------- 2 files changed, 80 insertions(+), 184 deletions(-) diff --git a/server/api/router.go b/server/api/router.go index 5ec74908c0d..150f0eb47ba 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -171,29 +171,31 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/config/replication-mode", confHandler.SetReplicationModeConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) rulesHandler := newRulesHandler(svr, rd) - registerFunc(clusterRouter, "/config/rules", rulesHandler.GetAllRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules", rulesHandler.SetAllRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rules/batch", rulesHandler.BatchRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rules/group/{group}", rulesHandler.GetRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/region/{region}", rulesHandler.GetRulesByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/region/{region}/detail", rulesHandler.CheckRegionPlacementRule, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/key/{key}", rulesHandler.GetRulesByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.GetRuleByGroupAndID, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule", rulesHandler.SetRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.DeleteRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) - - registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule_group", rulesHandler.SetGroupConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods(http.MethodGet), setAuditBackend(prometheus)) - - registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.GetPlacementRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.SetPlacementRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + ruleRouter := clusterRouter.NewRoute().Subrouter() + ruleRouter.Use(newRuleMiddleware(svr, rd).Middleware) + registerFunc(ruleRouter, "/config/rules", rulesHandler.GetAllRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules", rulesHandler.SetAllRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rules/batch", rulesHandler.BatchRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rules/group/{group}", rulesHandler.GetRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/region/{region}", rulesHandler.GetRulesByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/region/{region}/detail", rulesHandler.CheckRegionPlacementRule, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/key/{key}", rulesHandler.GetRulesByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule/{group}/{id}", rulesHandler.GetRuleByGroupAndID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule", rulesHandler.SetRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule/{group}/{id}", rulesHandler.DeleteRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + registerFunc(ruleRouter, "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule_group", rulesHandler.SetGroupConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + registerFunc(ruleRouter, "/config/placement-rule", rulesHandler.GetPlacementRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/placement-rule", rulesHandler.SetPlacementRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) // {group} can be a regular expression, we should enable path encode to // support special characters. - registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.GetPlacementRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.SetPlacementRuleByGroup, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(escapeRouter, "/config/placement-rule/{group}", rulesHandler.DeletePlacementRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.GetPlacementRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.SetPlacementRuleByGroup, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.DeletePlacementRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) regionLabelHandler := newRegionLabelHandler(svr, rd) registerFunc(clusterRouter, "/config/region-label/rules", regionLabelHandler.GetAllRegionLabelRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) diff --git a/server/api/rule.go b/server/api/rule.go index 77aad42eb42..47964d594be 100644 --- a/server/api/rule.go +++ b/server/api/rule.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/hex" "fmt" "net/http" @@ -42,6 +43,42 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { } } +type ruleMiddleware struct { + s *server.Server + rd *render.Render + *server.Handler +} + +func newRuleMiddleware(s *server.Server, rd *render.Render) ruleMiddleware { + return ruleMiddleware{ + s: s, + rd: rd, + Handler: s.GetHandler(), + } +} + +func (m ruleMiddleware) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + manager, err := m.GetRuleManager() + if err == errs.ErrPlacementDisabled { + m.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + m.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + ctx := context.WithValue(r.Context(), ruleCtxKey{}, manager) + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type ruleCtxKey struct{} + +func getRuleManager(r *http.Request) *placement.RuleManager { + return r.Context().Value(ruleCtxKey{}).(*placement.RuleManager) +} + // @Tags rule // @Summary List all rules of cluster. // @Produce json @@ -50,15 +87,7 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [get] func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) rules := manager.GetAllRules() h.rd.JSON(w, http.StatusOK, rules) } @@ -73,15 +102,7 @@ func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [post] func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var rules []*placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rules); err != nil { return @@ -113,15 +134,7 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/group/{group} [get] func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group := mux.Vars(r)["group"] rules := manager.GetRulesByGroup(group) h.rd.JSON(w, http.StatusOK, rules) @@ -138,15 +151,7 @@ func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/region/{region} [get] func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) regionStr := mux.Vars(r)["region"] region, code, err := h.PreCheckForRegion(regionStr) if err != nil { @@ -196,15 +201,7 @@ func (h *ruleHandler) CheckRegionPlacementRule(w http.ResponseWriter, r *http.Re // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/key/{key} [get] func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) keyHex := mux.Vars(r)["key"] key, err := hex.DecodeString(keyHex) if err != nil { @@ -225,15 +222,7 @@ func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { // @Failure 412 {string} string "Placement rules feature is disabled." // @Router /config/rule/{group}/{id} [get] func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] rule := manager.GetRule(group, id) if rule == nil { @@ -254,15 +243,7 @@ func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule [post] func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var rule placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rule); err != nil { return @@ -312,15 +293,7 @@ func (h *ruleHandler) syncReplicateConfigWithDefaultRule(rule *placement.Rule) e // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule/{group}/{id} [delete] func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] rule := manager.GetRule(group, id) if err := manager.DeleteRule(group, id); err != nil { @@ -345,15 +318,7 @@ func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/batch [post] func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var opts []placement.RuleOp if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &opts); err != nil { return @@ -380,15 +345,7 @@ func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [get] func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) id := mux.Vars(r)["id"] group := manager.GetRuleGroup(id) if group == nil { @@ -409,15 +366,7 @@ func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group [post] func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var ruleGroup placement.RuleGroup if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &ruleGroup); err != nil { return @@ -442,17 +391,9 @@ func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [delete] func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) id := mux.Vars(r)["id"] - err = manager.DeleteRuleGroup(id) + err := manager.DeleteRuleGroup(id) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return @@ -472,15 +413,7 @@ func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_groups [get] func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) ruleGroups := manager.GetRuleGroups() h.rd.JSON(w, http.StatusOK, ruleGroups) } @@ -493,15 +426,7 @@ func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [get] func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) bundles := manager.GetAllGroupBundles() h.rd.JSON(w, http.StatusOK, bundles) } @@ -516,15 +441,7 @@ func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [post] func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var groups []placement.GroupBundle if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &groups); err != nil { return @@ -551,15 +468,7 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [get] func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) g := mux.Vars(r)["group"] group := manager.GetGroupBundle(g) h.rd.JSON(w, http.StatusOK, group) @@ -576,16 +485,9 @@ func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [delete] func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group := mux.Vars(r)["group"] + var err error group, err = url.PathUnescape(group) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -608,15 +510,7 @@ func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http. // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [post] func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) groupID := mux.Vars(r)["group"] var group placement.GroupBundle if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &group); err != nil {