Skip to content

Commit

Permalink
feat: add single session per user with tags support
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Nov 14, 2023
1 parent d76f439 commit fcd541f
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 20 deletions.
81 changes: 61 additions & 20 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,18 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}

if session != nil {
var notAfter time.Time
result := session.CheckValidity(retryStart, &token.UpdatedAt, config.Sessions.Timebox, config.Sessions.InactivityTimeout)

if session.NotAfter != nil {
notAfter = *session.NotAfter
}

if config.Sessions.Timebox != nil {
sessionEndsAt := session.CreatedAt.Add((*config.Sessions.Timebox).Abs())
switch result {
case models.SessionValid:
// do nothing

if notAfter.IsZero() || notAfter.After(sessionEndsAt) {
notAfter = sessionEndsAt
}
}
case models.SessionTimedOut:
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")

if !notAfter.IsZero() && a.Now().After(notAfter) {
default:
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired")
}

if config.Sessions.InactivityTimeout != nil {
timesOutAt := session.LastRefreshedAt(&token.UpdatedAt).Add(*config.Sessions.InactivityTimeout)

if timesOutAt.Before(a.Now()) {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")
}
}
}

// Basic checks above passed, now we need to serialize access
Expand Down Expand Up @@ -120,6 +107,60 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError(terr.Error())
}

if a.config.Sessions.SinglePerUser {
sessions, terr := models.FindAllSessionsForUser(tx, user.ID, true /* forUpdate */)
if models.IsNotFoundError(terr) {
// because forUpdate was set, and the
// previous check outside the
// transaction found a user and
// session, but now we're getting a
// IsNotFoundError, this means that the
// user is locked and we need to retry
// in a few milliseconds
retry = true
return terr
} else if terr != nil {
return internalServerError(terr.Error())
}

sessionTag := session.DetermineTag(config.Sessions.Tags)

// go through all sessions of the user and
// check if the current session is the user's
// most recently refreshed valid session
for _, s := range sessions {
if s.ID == session.ID {
// current session, skip it
continue
}

if s.CheckValidity(retryStart, nil, config.Sessions.Timebox, config.Sessions.InactivityTimeout) != models.SessionValid {
// session is not valid so it
// can't be regarded as active
// on the user
continue
}

if s.DetermineTag(config.Sessions.Tags) != sessionTag {
// if tags are specified,
// ignore sessions with a
// mismatching tag
continue
}

// since token is not the refresh token
// of s, we can't use it's UpdatedAt
// time to compare!
if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) {
// session is not the most
// recently active one
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (New Session)")
}
}

// this session is the user's active session
}

// refresh token row and session are locked at this
// point, cannot be concurrently refreshed

Expand Down
43 changes: 43 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (ts *TokenTestSuite) TestSessionTimebox() {

defer func() {
ts.API.overrideTime = nil
ts.API.config.Sessions.Timebox = nil
}()

var buffer bytes.Buffer
Expand Down Expand Up @@ -176,6 +177,48 @@ func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() {
assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken)
}

func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() {
ts.API.config.Sessions.SinglePerUser = true
defer func() {
ts.API.config.Sessions.SinglePerUser = false
}()

firstRefreshToken := ts.RefreshToken

// just in case to give some delay between first and second session creation
time.Sleep(10 * time.Millisecond)

secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{})

require.NoError(ts.T(), err)

require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId)
require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID)

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": firstRefreshToken.Token,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()

ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser)

var firstResult struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), "invalid_grant", firstResult.Error)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (New Session)", firstResult.ErrorDescription)
}

func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
Expand Down
3 changes: 3 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func (a *APIConfiguration) Validate() error {
type SessionsConfiguration struct {
Timebox *time.Duration `json:"timebox"`
InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"`

SinglePerUser bool `json:"single_per_user" split_words:"true"`
Tags []string `json:"tags,omitempty"`
}

func (c *SessionsConfiguration) Validate() error {
Expand Down
5 changes: 5 additions & 0 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type GrantParams struct {
FactorID *uuid.UUID

SessionNotAfter *time.Time
SessionTag *string

UserAgent string
IP string
Expand Down Expand Up @@ -145,6 +146,10 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
session.IP = &params.IP
}

if params.SessionTag != nil && *params.SessionTag != "" {
session.Tag = params.SessionTag
}

if err := tx.Create(session); err != nil {
return nil, errors.Wrap(err, "error creating new session")
}
Expand Down
79 changes: 79 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type Session struct {
RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"`
UserAgent *string `json:"user_agent,omitempty" db:"user_agent"`
IP *string `json:"ip,omitempty" db:"ip"`

Tag *string `json:"tag" db:"tag"`
}

func (Session) TableName() string {
Expand Down Expand Up @@ -104,6 +106,54 @@ func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error {
return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip")
}

type SessionValidityReason = int

const (
SessionValid SessionValidityReason = iota
SessionPastNotAfter = iota
SessionPastTimebox = iota
SessionTimedOut = iota
)

func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason {
if s.NotAfter != nil && now.After(*s.NotAfter) {
return SessionPastNotAfter
}

if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) {
return SessionPastTimebox
}

if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) {
return SessionTimedOut
}

return SessionValid
}

func (s *Session) DetermineTag(tags []string) string {
if len(tags) == 0 {
return ""
}

if s.Tag == nil {
return tags[0]
}

tag := *s.Tag
if tag == "" {
return tags[0]
}

for _, t := range tags {
if t == tag {
return tag
}
}

return tags[0]
}

func NewSession() (*Session, error) {
id := uuid.Must(uuid.NewV4())

Expand Down Expand Up @@ -168,6 +218,35 @@ func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Sess
return sessions, nil
}

// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is
// set, it will first lock on the user row which can be used to prevent issues
// with concurrency. If the lock is acquired, it will return a
// UserNotFoundError and the operation should be retried. If there are no
// sessions for the user, a nil result is returned without an error.
func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) {
if forUpdate {
user := &User{}
if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, UserNotFoundError{}
}

return nil, err
}
}

var sessions []*Session
if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, nil
}

return nil, err
}

return sessions, nil
}

func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error {
return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec()
}
Expand Down
2 changes: 2 additions & 0 deletions migrations/20231114161723_add_sessions_tag.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
alter table if exists {{ index .Options "Namespace" }}.sessions
add column if not exists tag text;

0 comments on commit fcd541f

Please sign in to comment.