From baa077f6dda7e4278b74ba30f5afc1a6c9e0f405 Mon Sep 17 00:00:00 2001 From: Jonas Hungershausen Date: Fri, 10 Jan 2025 12:23:37 +0100 Subject: [PATCH] fix: don't show oidc subject in login hints --- identity/manager.go | 62 +++++++++++++-------------- identity/manager_test.go | 14 +++--- selfservice/flow/registration/hook.go | 2 +- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/identity/manager.go b/identity/manager.go index a09a08a778cd..74c984eb13af 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -10,6 +10,7 @@ import ( "reflect" "slices" "sort" + "strings" "github.com/ory/kratos/schema" "github.com/ory/x/sqlcon" @@ -102,7 +103,7 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption return nil } -func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, err error) { +func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, conflictAddressType string, err error) { for ct, cred := range i.Credentials { for _, id := range cred.Identifiers { found, _, err = m.r.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, id) @@ -112,10 +113,10 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found * // FindByCredentialsIdentifier does not expand identity credentials. if err = m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, found, ExpandCredentials); err != nil { - return nil, "", err + return nil, "", "", err } - return found, id, nil + return found, id, ct.String(), nil } } @@ -125,16 +126,16 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found * if errors.Is(err, sqlcon.ErrNoRows) { continue } else if err != nil { - return nil, "", err + return nil, "", "", err } foundConflictAddress = conflictingAddress.Value found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials) if err != nil { - return nil, "", err + return nil, "", "", err } - return found, foundConflictAddress, nil + return found, foundConflictAddress, va.Via, nil } // Last option: check the recovery address @@ -143,19 +144,19 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found * if errors.Is(err, sqlcon.ErrNoRows) { continue } else if err != nil { - return nil, "", err + return nil, "", "", err } foundConflictAddress = conflictingAddress.Value found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials) if err != nil { - return nil, "", err + return nil, "", "", err } - return found, foundConflictAddress, nil + return found, foundConflictAddress, string(va.Via), nil } - return nil, "", sqlcon.ErrNoRows + return nil, "", "", sqlcon.ErrNoRows } func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) { @@ -163,7 +164,7 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi return &ErrDuplicateCredentials{error: e} } - found, foundConflictAddress, err := m.ConflictingIdentity(ctx, i) + found, foundConflictAddress, conflictingAddressType, err := m.ConflictingIdentity(ctx, i) if err != nil { if errors.Is(err, sqlcon.ErrNoRows) { return &ErrDuplicateCredentials{error: e} @@ -181,6 +182,11 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi }) duplicateCredErr := &ErrDuplicateCredentials{error: e} + // OIDC credentials are not email addresses but the sub claim from the OIDC provider. + // This is useless for the user, so in that case, we don't set the identifier hint. + if conflictingAddressType != CredentialsTypeOIDC.String() { + duplicateCredErr.SetIdentifierHint(strings.Trim(foundConflictAddress, " ")) + } for _, cred := range creds { if cred.Config == nil { @@ -192,11 +198,9 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi // in to the first factor (obviously). switch cred.Type { case CredentialsTypePassword: - identifierHint := foundConflictAddress - if len(cred.Identifiers) > 0 { - identifierHint = cred.Identifiers[0] + if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 { + duplicateCredErr.SetIdentifierHint(cred.Identifiers[0]) } - duplicateCredErr.SetIdentifierHint(identifierHint) var cfg CredentialsPassword if err := json.Unmarshal(cred.Config, &cfg); err != nil { @@ -209,14 +213,7 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi } duplicateCredErr.AddCredentialsType(cred.Type) - case CredentialsTypeCodeAuth: - identifierHint := foundConflictAddress - if len(cred.Identifiers) > 0 { - identifierHint = cred.Identifiers[0] - } - - duplicateCredErr.SetIdentifierHint(identifierHint) duplicateCredErr.AddCredentialsType(cred.Type) case CredentialsTypeOIDC: var cfg CredentialsOIDC @@ -230,7 +227,6 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi } duplicateCredErr.AddCredentialsType(cred.Type) - duplicateCredErr.SetIdentifierHint(foundConflictAddress) duplicateCredErr.availableOIDCProviders = available case CredentialsTypeWebAuthn: var cfg CredentialsWebAuthnConfig @@ -238,15 +234,12 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID)) } - identifierHint := foundConflictAddress - if len(cred.Identifiers) > 0 { - identifierHint = cred.Identifiers[0] + if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 { + duplicateCredErr.SetIdentifierHint(cred.Identifiers[0]) } - for _, webauthn := range cfg.Credentials { if webauthn.IsPasswordless { duplicateCredErr.AddCredentialsType(cred.Type) - duplicateCredErr.SetIdentifierHint(identifierHint) break } } @@ -256,15 +249,12 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID)) } - identifierHint := foundConflictAddress - if len(cred.Identifiers) > 0 { - identifierHint = cred.Identifiers[0] + if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 { + duplicateCredErr.SetIdentifierHint(cred.Identifiers[0]) } - for _, webauthn := range cfg.Credentials { if webauthn.IsPasswordless { duplicateCredErr.AddCredentialsType(cred.Type) - duplicateCredErr.SetIdentifierHint(identifierHint) break } } @@ -343,6 +333,7 @@ func (e *CreateIdentitiesError) Error() string { e.init() return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) } + func (e *CreateIdentitiesError) Unwrap() []error { e.init() var errs []error @@ -356,17 +347,20 @@ func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot. e.init() e.failedIdentities[ident] = err } + func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) { e.init() for k, v := range other.failedIdentities { e.failedIdentities[k] = v } } + func (e *CreateIdentitiesError) Contains(ident *Identity) bool { e.init() _, found := e.failedIdentities[ident] return found } + func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { e.init() if err, found := e.failedIdentities[ident]; found { @@ -375,12 +369,14 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { return nil } + func (e *CreateIdentitiesError) ErrOrNil() error { if e == nil || len(e.failedIdentities) == 0 { return nil } return e } + func (e *CreateIdentitiesError) init() { if e.failedIdentities == nil { e.failedIdentities = map[*Identity]*herodot.DefaultError{} diff --git a/identity/manager_test.go b/identity/manager_test.go index b7659eba68a1..3e1efaff0673 100644 --- a/identity/manager_test.go +++ b/identity/manager_test.go @@ -282,6 +282,7 @@ func TestManager(t *testing.T) { assert.ErrorAs(t, err, &verr) assert.ElementsMatch(t, []string{"oidc"}, verr.AvailableCredentials()) assert.ElementsMatch(t, []string{"google", "github"}, verr.AvailableOIDCProviders()) + // The conflicting identifier is the oidc subject, which is not useful for the user assert.Equal(t, email, verr.IdentifierHint()) }) @@ -756,7 +757,7 @@ func TestManager(t *testing.T) { require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnRecoveryAddress)) t.Run("case=returns not found if no conflict", func(t *testing.T) { - found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ + found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ Credentials: map[identity.CredentialsType]identity.Credentials{ identity.CredentialsTypePassword: {Identifiers: []string{"no-conflict@example.com"}}, }, @@ -764,10 +765,11 @@ func TestManager(t *testing.T) { assert.ErrorIs(t, err, sqlcon.ErrNoRows) assert.Nil(t, found) assert.Empty(t, foundConflictAddress) + assert.Empty(t, addressType) }) t.Run("case=conflict on identifier", func(t *testing.T) { - found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ + found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ Credentials: map[identity.CredentialsType]identity.Credentials{ identity.CredentialsTypePassword: {Identifiers: []string{"conflict-on-identifier@example.com"}}, }, @@ -775,10 +777,11 @@ func TestManager(t *testing.T) { require.NoError(t, err) assert.Equal(t, conflicOnIdentifier.ID, found.ID) assert.Equal(t, "conflict-on-identifier@example.com", foundConflictAddress) + assert.EqualValues(t, string(identity.CredentialsTypePassword), addressType) }) t.Run("case=conflict on verifiable address", func(t *testing.T) { - found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ + found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ VerifiableAddresses: []identity.VerifiableAddress{{ Value: "conflict-on-va@example.com", Via: "email", @@ -787,10 +790,10 @@ func TestManager(t *testing.T) { require.NoError(t, err) assert.Equal(t, conflicOnVerifiableAddress.ID, found.ID) assert.Equal(t, "conflict-on-va@example.com", foundConflictAddress) + assert.Equal(t, "email", addressType) }) - t.Run("case=conflict on recovery address", func(t *testing.T) { - found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ + found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{ RecoveryAddresses: []identity.RecoveryAddress{{ Value: "conflict-on-ra@example.com", Via: "email", @@ -799,6 +802,7 @@ func TestManager(t *testing.T) { require.NoError(t, err) assert.Equal(t, conflicOnRecoveryAddress.ID, found.ID) assert.Equal(t, "conflict-on-ra@example.com", foundConflictAddress) + assert.Equal(t, "email", addressType) }) }) } diff --git a/selfservice/flow/registration/hook.go b/selfservice/flow/registration/hook.go index ab7400b60936..d53e1ffcb047 100644 --- a/selfservice/flow/registration/hook.go +++ b/selfservice/flow/registration/hook.go @@ -330,7 +330,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque } func (e *HookExecutor) getDuplicateIdentifier(ctx context.Context, i *identity.Identity) (string, error) { - _, id, err := e.d.IdentityManager().ConflictingIdentity(ctx, i) + _, id, _, err := e.d.IdentityManager().ConflictingIdentity(ctx, i) if err != nil { return "", err }