Skip to content

Commit

Permalink
Refactor the users module to bring together all the functions for pro…
Browse files Browse the repository at this point in the history
…pagating user IDs.
  • Loading branch information
tomwilkie committed Mar 2, 2017
1 parent e98b873 commit 4054b33
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 51 deletions.
6 changes: 6 additions & 0 deletions errors/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package errors

// Error see https://dave.cheney.net/2016/04/07/constant-errors.
type Error string

func (e Error) Error() string { return string(e) }
42 changes: 6 additions & 36 deletions middleware/grpc_auth.go
Original file line number Diff line number Diff line change
@@ -1,58 +1,28 @@
package middleware

import (
"fmt"

"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/weaveworks/common/user"
)

// ClientUserHeaderInterceptor propagates the user ID from the context to gRPC metadata, which eventually ends up as a HTTP2 header.
func ClientUserHeaderInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
userID, err := user.GetID(ctx)
ctx, err := user.InjectIntoGRPCRequest(ctx)
if err != nil {
return err
}

md, ok := metadata.FromContext(ctx)
if !ok {
md = metadata.New(map[string]string{})
}

newCtx := ctx
if userIDs, ok := md[user.LowerOrgIDHeaderName]; ok {
switch len(userIDs) {
case 1:
if userIDs[0] != userID {
return fmt.Errorf("wrong user ID found")
}
default:
return fmt.Errorf("multiple user IDs found")
}
} else {
md = md.Copy()
md[user.LowerOrgIDHeaderName] = []string{userID}
newCtx = metadata.NewContext(ctx, md)
}

return invoker(newCtx, method, req, reply, cc, opts...)
return invoker(ctx, method, req, reply, cc, opts...)
}

// ServerUserHeaderInterceptor propagates the user ID from the gRPC metadata back to our context.
func ServerUserHeaderInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
md, ok := metadata.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("no metadata")
}

userIDs, ok := md[user.LowerOrgIDHeaderName]
if !ok || len(userIDs) != 1 {
return nil, fmt.Errorf("no user id")
_, ctx, err := user.ExtractFromGRPCRequest(ctx)
if err != nil {
return nil, err
}

newCtx := user.WithID(ctx, userIDs[0])
return handler(newCtx, req)
return handler(ctx, req)
}
51 changes: 51 additions & 0 deletions user/grpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package user

import (
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
)

// ExtractFromGRPCRequest extracts the user ID from the request metadata and returns
// the user ID and a context with the user ID injected.
func ExtractFromGRPCRequest(ctx context.Context) (string, context.Context, error) {
md, ok := metadata.FromContext(ctx)
if !ok {
return "", ctx, ErrNoUserID
}

userIDs, ok := md[lowerOrgIDHeaderName]
if !ok || len(userIDs) != 1 {
return "", ctx, ErrNoUserID
}

return userIDs[0], Inject(ctx, userIDs[0]), nil
}

// InjectIntoGRPCRequest injects the userID from the context into the request metadata.
func InjectIntoGRPCRequest(ctx context.Context) (context.Context, error) {
userID, err := Extract(ctx)
if err != nil {
return ctx, err
}

md, ok := metadata.FromContext(ctx)
if !ok {
md = metadata.New(map[string]string{})
}
newCtx := ctx
if userIDs, ok := md[lowerOrgIDHeaderName]; ok {
if len(userIDs) == 1 {
if userIDs[0] != userID {
return ctx, ErrDifferentIDPresent
}
} else {
return ctx, ErrTooManyUserIDs
}
} else {
md = md.Copy()
md[lowerOrgIDHeaderName] = []string{userID}
newCtx = metadata.NewContext(ctx, md)
}

return newCtx, nil
}
27 changes: 27 additions & 0 deletions user/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package user

import (
"net/http"

"golang.org/x/net/context"
)

// ExtractFromHTTPRequest extracts the user ID from the request headers and returns
// the user ID and a context with the user ID embbedded.
func ExtractFromHTTPRequest(r *http.Request) (string, context.Context, error) {
userID := r.Header.Get(orgIDHeaderName)
if userID == "" {
return "", r.Context(), ErrNoUserID
}
return userID, Inject(r.Context(), userID), nil
}

// InjectIntoHTTPRequest injects the userID from the context into the request headers.
func InjectIntoHTTPRequest(ctx context.Context, r *http.Request) error {
userID, err := Extract(ctx)
if err != nil {
return err
}
r.Header.Add(orgIDHeaderName, userID)
return nil
}
36 changes: 21 additions & 15 deletions user/id.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
package user

import (
"fmt"

"golang.org/x/net/context"

"github.com/weaveworks/common/errors"
)

// UserIDContextKey is the key used in contexts to find the userid
type contextKey int

const userIDContextKey contextKey = 0
const (
// UserIDContextKey is the key used in contexts to find the userid
userIDContextKey contextKey = 0

// OrgIDHeaderName is a legacy from scope as a service.
const OrgIDHeaderName = "X-Scope-OrgID"
// orgIDHeaderName is a legacy from scope as a service.
orgIDHeaderName = "X-Scope-OrgID"

// LowerOrgIDHeaderName as gRPC / HTTP2.0 headers are lowercased.
const LowerOrgIDHeaderName = "x-scope-orgid"
// LowerOrgIDHeaderName as gRPC / HTTP2.0 headers are lowercased.
lowerOrgIDHeaderName = "x-scope-orgid"

ErrNoUserID = errors.Error("no user id")
ErrDifferentIDPresent = errors.Error("different user ID already present")
ErrTooManyUserIDs = errors.Error("multiple user IDs present")
)

// GetID returns the user
func GetID(ctx context.Context) (string, error) {
userid, ok := ctx.Value(userIDContextKey).(string)
// GetID returns the user from the context
func Extract(ctx context.Context) (string, error) {
userID, ok := ctx.Value(userIDContextKey).(string)
if !ok {
return "", fmt.Errorf("no user id")
return "", ErrNoUserID
}
return userid, nil
return userID, nil
}

// WithID returns a derived context containing the user ID.
func WithID(ctx context.Context, userID string) context.Context {
// Inject returns a derived context containing the user ID.
func Inject(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, interface{}(userIDContextKey), userID)
}

0 comments on commit 4054b33

Please sign in to comment.