Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use dedicated registration ID for auth flow #2337

Merged
merged 20 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions cmd/headscale/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"fmt"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"tailscale.com/types/key"
)

const (
Expand Down Expand Up @@ -79,7 +79,7 @@ var createNodeCmd = &cobra.Command{
)
}

machineKey, err := cmd.Flags().GetString("key")
registrationID, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(
err,
Expand All @@ -88,8 +88,7 @@ var createNodeCmd = &cobra.Command{
)
}

var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(machineKey))
_, err = types.RegistrationIDFromString(registrationID)
if err != nil {
ErrorOutput(
err,
Expand All @@ -108,7 +107,7 @@ var createNodeCmd = &cobra.Command{
}

request := &v1.DebugCreateNodeRequest{
Key: machineKey,
Key: registrationID,
Name: name,
User: user,
Routes: routes,
Expand Down
4 changes: 2 additions & 2 deletions cmd/headscale/cli/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ var registerNodeCmd = &cobra.Command{
defer cancel()
defer conn.Close()

machineKey, err := cmd.Flags().GetString("key")
registrationID, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(
err,
Expand All @@ -132,7 +132,7 @@ var registerNodeCmd = &cobra.Command{
}

request := &v1.RegisterNodeRequest{
Key: machineKey,
Key: registrationID,
User: user,
}

Expand Down
6 changes: 3 additions & 3 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

registrationCache *zcache.Cache[string, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]

authProvider AuthProvider

Expand All @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

registrationCache := zcache.New[string, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)

if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
Expand Down
135 changes: 89 additions & 46 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/db"
Expand All @@ -20,16 +22,18 @@ import (

type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
AuthURL(types.RegistrationID) string
}

func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -41,6 +45,7 @@ func logAuthFunc(
func(msg string) {
log.Trace().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -52,6 +57,7 @@ func logAuthFunc(
func(err error, msg string) {
log.Error().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -63,16 +69,64 @@ func logAuthFunc(
}
}

func (h *Headscale) waitForFollowup(
req *http.Request,
regReq tailcfg.RegisterRequest,
logTrace func(string),
) {
logTrace("register request is a followup")
fu, err := url.Parse(regReq.Followup)
if err != nil {
logTrace("failed to parse followup URL")
return
}

followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil {
logTrace("followup URL does not contains a valid registration ID")
return
}

logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))

if reg, ok := h.registrationCache.Get(followupReg); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
logTrace("node went away before it was registered")
return
case <-reg.Registered:
logTrace("node has successfully registered")
return
}
}
}

// handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister(
writer http.ResponseWriter,
req *http.Request,
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)

return
}

logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")

// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -84,27 +138,9 @@ func (h *Headscale) handleRegister(
}

// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)

return
}
}
h.waitForFollowup(req, regReq, logTrace)
return
}

logInfo("Node not found in database, creating new")
Expand All @@ -113,25 +149,28 @@ func (h *Headscale) handleRegister(
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
newNode := types.RegisterNode{
Node: types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
},
Registered: make(chan struct{}),
}

if !regReq.Expiry.IsZero() {
logTrace("Non-zero expiry time requested")
newNode.Expiry = &regReq.Expiry
newNode.Node.Expiry = &regReq.Expiry
}

h.registrationCache.Set(
machineKey.String(),
registrationId,
newNode,
)

h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand Down Expand Up @@ -206,27 +245,28 @@ func (h *Headscale) handleRegister(
}

if regReq.Followup != "" {
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
}
h.waitForFollowup(req, regReq, logTrace)
return
}

// The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)

// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{}

// TODO(kradalby): do we need to rethink this as part of authflow?
// If we are here it means the client needs to be reauthorized,
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = regReq.NodeKey
h.registrationCache.Set(
machineKey.String(),
*node,
registrationId,
types.RegisterNode{
Node: *node,
Registered: make(chan struct{}),
},
)

return
Expand Down Expand Up @@ -296,6 +336,8 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if node != nil {
log.Trace().
Expand Down Expand Up @@ -444,16 +486,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)

resp := tailcfg.RegisterResponse{}

// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")
logTrace("The node is new, sending auth url")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand Down Expand Up @@ -660,6 +702,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
resp := tailcfg.RegisterResponse{}

Expand All @@ -673,12 +716,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand All @@ -703,7 +746,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(

log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname).
Expand Down
Loading
Loading