Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wfe: use wildcard patterns in HTTP handlers #7791

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/ocsp-responder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ func (om *ocspMux) Handler(_ *http.Request) (http.Handler, string) {
return om.handler, "/"
}

func (om *ocspMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
om.handler.ServeHTTP(w, r)
}

func mux(responderPath string, source responder.Source, timeout time.Duration, stats prometheus.Registerer, oTelHTTPOptions []otelhttp.Option, logger blog.Logger, sampleRate int) http.Handler {
stripPrefix := http.StripPrefix(responderPath, responder.NewResponder(source, timeout, stats, logger, sampleRate))
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
7 changes: 4 additions & 3 deletions metrics/measured_http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ func (r *responseWriterWithStatus) Write(body []byte) (int, error) {
return r.ResponseWriter.Write(body)
}

// serveMux is a partial interface wrapper for the method http.ServeMux
// serveMux is a partial interface wrapper for the methods http.ServeMux
// exposes that we use. This is needed so that we can replace the default
// http.ServeMux in ocsp-responder where we don't want to use its path
// canonicalization.
type serveMux interface {
Handler(*http.Request) (http.Handler, string)
ServeHTTP(w http.ResponseWriter, r *http.Request)
}

// MeasuredHandler wraps an http.Handler and records prometheus stats
Expand Down Expand Up @@ -80,7 +81,7 @@ func (h *MeasuredHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
begin := h.clk.Now()
rwws := &responseWriterWithStatus{w, 0}

subHandler, pattern := h.Handler(r)
_, pattern := h.serveMux.Handler(r)
h.inFlightRequestsGauge.WithLabelValues(pattern).Inc()
defer h.inFlightRequestsGauge.WithLabelValues(pattern).Dec()

Expand All @@ -104,5 +105,5 @@ func (h *MeasuredHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}).Observe(h.clk.Since(begin).Seconds())
}()

subHandler.ServeHTTP(rwws, r)
h.serveMux.ServeHTTP(rwws, r)
}
92 changes: 31 additions & 61 deletions wfe2/wfe.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (wfe *WebFrontEndImpl) HandleFunc(mux *http.ServeMux, pattern string, h web
methodsMap["HEAD"] = true
}
methodsStr := strings.Join(methods, ", ")
handler := http.StripPrefix(pattern, web.NewTopHandler(wfe.log,
var handler http.Handler = web.NewTopHandler(wfe.log,
web.WFEHandlerFunc(func(ctx context.Context, logEvent *web.RequestEvent, response http.ResponseWriter, request *http.Request) {
span := trace.SpanFromContext(ctx)
span.SetName(pattern)
Expand Down Expand Up @@ -334,7 +334,7 @@ func (wfe *WebFrontEndImpl) HandleFunc(mux *http.ServeMux, pattern string, h web
h(ctx, logEvent, response, request)
cancel()
}),
))
)
mux.Handle(pattern, handler)
}

Expand Down Expand Up @@ -421,33 +421,34 @@ func (wfe *WebFrontEndImpl) Handler(stats prometheus.Registerer, oTelHTTPOptions

// POSTable ACME endpoints
wfe.HandleFunc(m, newAcctPath, wfe.NewAccount, "POST")
wfe.HandleFunc(m, acctPath, wfe.Account, "POST")
wfe.HandleFunc(m, acctPath+"{acctID}", wfe.Account, "POST")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering why you didn't just change acctPath to include the pattern, but then I realized that we also use acctPath (and similar) as an argument to web.relativeEndpoint. Is there a reasonable way to update relativeEndpoint to take a pattern as input?

wfe.HandleFunc(m, revokeCertPath, wfe.RevokeCertificate, "POST")
wfe.HandleFunc(m, rolloverPath, wfe.KeyRollover, "POST")
wfe.HandleFunc(m, newOrderPath, wfe.NewOrder, "POST")
wfe.HandleFunc(m, finalizeOrderPath, wfe.FinalizeOrder, "POST")
wfe.HandleFunc(m, finalizeOrderPath+"{acctID}/{orderID}", wfe.FinalizeOrder, "POST")

// GETable and POST-as-GETable ACME endpoints
wfe.HandleFunc(m, directoryPath, wfe.Directory, "GET", "POST")
wfe.HandleFunc(m, newNoncePath, wfe.Nonce, "GET", "POST")
// POST-as-GETable ACME endpoints
// TODO(@cpu): After November 1st, 2020 support for "GET" to the following
// endpoints will be removed, leaving only POST-as-GET support.
wfe.HandleFunc(m, orderPath, wfe.GetOrder, "GET", "POST")
wfe.HandleFunc(m, authzPath, wfe.AuthorizationHandler, "GET", "POST")
wfe.HandleFunc(m, authzPathWithAcct, wfe.AuthorizationHandlerWithAccount, "GET", "POST")
wfe.HandleFunc(m, challengePath, wfe.ChallengeHandler, "GET", "POST")
wfe.HandleFunc(m, challengePathWithAcct, wfe.ChallengeHandlerWithAccount, "GET", "POST")
wfe.HandleFunc(m, certPath, wfe.Certificate, "GET", "POST")
wfe.HandleFunc(m, orderPath+"{acctID}/{orderID}", wfe.GetOrder, "GET", "POST")
wfe.HandleFunc(m, authzPath+"{authzID}", wfe.AuthorizationHandler, "GET", "POST")
wfe.HandleFunc(m, authzPathWithAcct+"{acctID}/{authzID}", wfe.AuthorizationHandlerWithAccount, "GET", "POST")
wfe.HandleFunc(m, challengePath+"{authzID}/{challID}", wfe.ChallengeHandler, "GET", "POST")
wfe.HandleFunc(m, challengePathWithAcct+"{acctID}/{authzID}/{challID}", wfe.ChallengeHandlerWithAccount, "GET", "POST")
wfe.HandleFunc(m, certPath+"{serial}", wfe.Certificate, "GET", "POST")
wfe.HandleFunc(m, certPath+"{serial}/{chain}", wfe.Certificate, "GET", "POST")
// Boulder-specific GET-able resource endpoints
wfe.HandleFunc(m, getOrderPath, wfe.GetOrder, "GET")
wfe.HandleFunc(m, getAuthzPath, wfe.AuthorizationHandler, "GET")
wfe.HandleFunc(m, getChallengePath, wfe.ChallengeHandler, "GET")
wfe.HandleFunc(m, getCertPath, wfe.Certificate, "GET")
wfe.HandleFunc(m, getOrderPath+"{acctID}/{orderID}", wfe.GetOrder, "GET")
wfe.HandleFunc(m, getAuthzPath+"{authzID}", wfe.AuthorizationHandler, "GET")
wfe.HandleFunc(m, getChallengePath+"{challID}", wfe.ChallengeHandler, "GET")
wfe.HandleFunc(m, getCertPath+"{serial}", wfe.Certificate, "GET")

// Endpoint for draft-ietf-acme-ari
if features.Get().ServeRenewalInfo {
wfe.HandleFunc(m, renewalInfoPath, wfe.RenewalInfo, "GET", "POST")
wfe.HandleFunc(m, renewalInfoPath+"{certID}", wfe.RenewalInfo, "GET", "POST")
}

// We don't use our special HandleFunc for "/" because it matches everything,
Expand Down Expand Up @@ -1099,13 +1100,7 @@ func (wfe *WebFrontEndImpl) ChallengeHandler(
logEvent *web.RequestEvent,
response http.ResponseWriter,
request *http.Request) {
slug := strings.Split(request.URL.Path, "/")
if len(slug) != 2 {
wfe.sendError(response, logEvent, probs.NotFound("No such challenge"), nil)
return
}

wfe.Challenge(ctx, logEvent, challengePath, response, request, slug[0], slug[1])
wfe.Challenge(ctx, logEvent, challengePath, response, request, request.PathValue("authzID"), request.PathValue("challID"))
}

// ChallengeHandlerWithAccount handles POST requests to challenge URLs of the form /acme/chall/{regID}/{authzID}/{challID}.
Expand All @@ -1114,13 +1109,8 @@ func (wfe *WebFrontEndImpl) ChallengeHandlerWithAccount(
logEvent *web.RequestEvent,
response http.ResponseWriter,
request *http.Request) {
slug := strings.Split(request.URL.Path, "/")
if len(slug) != 3 {
wfe.sendError(response, logEvent, probs.NotFound("No such challenge"), nil)
return
}
// TODO(#7683): the regID is currently ignored.
wfe.Challenge(ctx, logEvent, challengePathWithAcct, response, request, slug[1], slug[2])
wfe.Challenge(ctx, logEvent, challengePathWithAcct, response, request, request.PathValue("authzID"), request.PathValue("challID"))
}

// Challenge handles POSTS to both formats of challenge URLs.
Expand Down Expand Up @@ -1403,8 +1393,7 @@ func (wfe *WebFrontEndImpl) Account(

// Requests to this handler should have a path that leads to a known
// account
idStr := request.URL.Path
id, err := strconv.ParseInt(idStr, 10, 64)
id, err := strconv.ParseInt(request.PathValue("acctID"), 10, 64)
if err != nil {
wfe.sendError(response, logEvent, probs.Malformed("Account ID must be an integer"), err)
return
Expand Down Expand Up @@ -1567,7 +1556,7 @@ func (wfe *WebFrontEndImpl) AuthorizationHandler(
logEvent *web.RequestEvent,
response http.ResponseWriter,
request *http.Request) {
wfe.Authorization(ctx, authzPath, logEvent, response, request, request.URL.Path)
wfe.Authorization(ctx, authzPath, logEvent, response, request, request.PathValue("authzID"))
}

// AuthorizationHandlerWithAccount handles requests to authorization URLs of the form /acme/authz/{regID}/{authzID}.
Expand All @@ -1576,13 +1565,8 @@ func (wfe *WebFrontEndImpl) AuthorizationHandlerWithAccount(
logEvent *web.RequestEvent,
response http.ResponseWriter,
request *http.Request) {
slug := strings.Split(request.URL.Path, "/")
if len(slug) != 2 {
wfe.sendError(response, logEvent, probs.NotFound("No such authorization"), nil)
return
}
// TODO(#7683): The regID is currently ignored.
wfe.Authorization(ctx, authzPathWithAcct, logEvent, response, request, slug[1])
wfe.Authorization(ctx, authzPathWithAcct, logEvent, response, request, request.PathValue("authzID"))
}

// Authorization handles both `/acme/authz/{authzID}` and `/acme/authz/{regID}/{authzID}` requests,
Expand Down Expand Up @@ -1706,20 +1690,19 @@ func (wfe *WebFrontEndImpl) Certificate(ctx context.Context, logEvent *web.Reque
}

requestedChain := 0
serial := request.URL.Path
serial := request.PathValue("serial")

// An alternate chain may be requested with the request path {serial}/{chain}, where chain
// is a number - an index into the slice of chains for the issuer. If a specific chain is
// not requested, then it defaults to zero - the default certificate chain for the issuer.
serialAndChain := strings.SplitN(serial, "/", 2)
if len(serialAndChain) == 2 {
idx, err := strconv.Atoi(serialAndChain[1])
chain := request.PathValue("chain")
if chain != "" {
idx, err := strconv.Atoi(chain)
if err != nil || idx < 0 {
wfe.sendError(response, logEvent, probs.Malformed("Chain ID must be a non-negative integer"),
fmt.Errorf("certificate chain id provided was not valid: %s", serialAndChain[1]))
fmt.Errorf("certificate chain id provided was not valid: %q", chain))
return
}
serial = serialAndChain[0]
requestedChain = idx
}

Expand Down Expand Up @@ -2514,18 +2497,12 @@ func (wfe *WebFrontEndImpl) GetOrder(ctx context.Context, logEvent *web.RequestE
requesterAccount = acct
}

// Path prefix is stripped, so this should be like "<account ID>/<order ID>"
fields := strings.SplitN(request.URL.Path, "/", 2)
if len(fields) != 2 {
wfe.sendError(response, logEvent, probs.NotFound("Invalid request path"), nil)
return
}
acctID, err := strconv.ParseInt(fields[0], 10, 64)
acctID, err := strconv.ParseInt(request.PathValue("acctID"), 10, 64)
if err != nil {
wfe.sendError(response, logEvent, probs.Malformed("Invalid account ID"), err)
return
}
orderID, err := strconv.ParseInt(fields[1], 10, 64)
orderID, err := strconv.ParseInt(request.PathValue("orderID"), 10, 64)
if err != nil {
wfe.sendError(response, logEvent, probs.Malformed("Invalid order ID"), err)
return
Expand Down Expand Up @@ -2594,19 +2571,12 @@ func (wfe *WebFrontEndImpl) FinalizeOrder(ctx context.Context, logEvent *web.Req
return
}

// Order URLs are like: /acme/finalize/<account>/<order>/. The prefix is
// stripped by the time we get here.
fields := strings.SplitN(request.URL.Path, "/", 2)
if len(fields) != 2 {
wfe.sendError(response, logEvent, probs.NotFound("Invalid request path"), nil)
return
}
acctID, err := strconv.ParseInt(fields[0], 10, 64)
acctID, err := strconv.ParseInt(request.PathValue("acctID"), 10, 64)
if err != nil {
wfe.sendError(response, logEvent, probs.Malformed("Invalid account ID"), nil)
return
}
orderID, err := strconv.ParseInt(fields[1], 10, 64)
orderID, err := strconv.ParseInt(request.PathValue("orderID"), 10, 64)
if err != nil {
wfe.sendError(response, logEvent, probs.Malformed("Invalid order ID"), nil)
return
Expand Down Expand Up @@ -2756,7 +2726,7 @@ func (wfe *WebFrontEndImpl) RenewalInfo(ctx context.Context, logEvent *web.Reque
return
}

decodedSerial, err := parseARICertID(request.URL.Path, wfe.issuerCertificates)
decodedSerial, err := parseARICertID(request.PathValue("certID"), wfe.issuerCertificates)
if err != nil {
wfe.sendError(response, logEvent, web.ProblemDetailsForError(err, "While parsing ARI CertID an error occurred"), err)
return
Expand Down
Loading