Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't show oidc subject in login hints #4264

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 29 additions & 33 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"reflect"
"slices"
"sort"
"strings"

"github.com/ory/kratos/schema"
"github.com/ory/x/sqlcon"
Expand Down Expand Up @@ -102,7 +103,7 @@
return nil
}

func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, err error) {
func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, conflictAddressType string, err error) {
for ct, cred := range i.Credentials {
for _, id := range cred.Identifiers {
found, _, err = m.r.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, id)
Expand All @@ -112,10 +113,10 @@

// FindByCredentialsIdentifier does not expand identity credentials.
if err = m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, found, ExpandCredentials); err != nil {
return nil, "", err
return nil, "", "", err

Check warning on line 116 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L116

Added line #L116 was not covered by tests
}

return found, id, nil
return found, id, ct.String(), nil
}
}

Expand All @@ -125,16 +126,16 @@
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
return nil, "", "", err

Check warning on line 129 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L129

Added line #L129 was not covered by tests
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
return nil, "", "", err

Check warning on line 135 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L135

Added line #L135 was not covered by tests
}

return found, foundConflictAddress, nil
return found, foundConflictAddress, va.Via, nil
}

// Last option: check the recovery address
Expand All @@ -143,27 +144,27 @@
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
return nil, "", "", err

Check warning on line 147 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L147

Added line #L147 was not covered by tests
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
return nil, "", "", err

Check warning on line 153 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L153

Added line #L153 was not covered by tests
}

return found, foundConflictAddress, nil
return found, foundConflictAddress, string(va.Via), nil
}

return nil, "", sqlcon.ErrNoRows
return nil, "", "", sqlcon.ErrNoRows
}

func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) {
if !m.r.Config().SelfServiceFlowRegistrationLoginHints(ctx) {
return &ErrDuplicateCredentials{error: e}
}

found, foundConflictAddress, err := m.ConflictingIdentity(ctx, i)
found, foundConflictAddress, conflictingAddressType, err := m.ConflictingIdentity(ctx, i)
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return &ErrDuplicateCredentials{error: e}
Expand All @@ -181,6 +182,11 @@
})

duplicateCredErr := &ErrDuplicateCredentials{error: e}
// OIDC credentials are not email addresses but the sub claim from the OIDC provider.
// This is useless for the user, so in that case, we don't set the identifier hint.
if conflictingAddressType != CredentialsTypeOIDC.String() {
duplicateCredErr.SetIdentifierHint(strings.Trim(foundConflictAddress, " "))
}

for _, cred := range creds {
if cred.Config == nil {
Expand All @@ -192,11 +198,9 @@
// in to the first factor (obviously).
switch cred.Type {
case CredentialsTypePassword:
identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])
}
duplicateCredErr.SetIdentifierHint(identifierHint)

var cfg CredentialsPassword
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
Expand All @@ -209,14 +213,7 @@
}

duplicateCredErr.AddCredentialsType(cred.Type)

case CredentialsTypeCodeAuth:
identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
}

duplicateCredErr.SetIdentifierHint(identifierHint)
duplicateCredErr.AddCredentialsType(cred.Type)
case CredentialsTypeOIDC:
var cfg CredentialsOIDC
Expand All @@ -230,23 +227,19 @@
}

duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(foundConflictAddress)
duplicateCredErr.availableOIDCProviders = available
case CredentialsTypeWebAuthn:
var cfg CredentialsWebAuthnConfig
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])

Check warning on line 238 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L238

Added line #L238 was not covered by tests
}

for _, webauthn := range cfg.Credentials {
if webauthn.IsPasswordless {
duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(identifierHint)
break
}
}
Expand All @@ -256,15 +249,12 @@
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])

Check warning on line 253 in identity/manager.go

View check run for this annotation

Codecov / codecov/patch

identity/manager.go#L253

Added line #L253 was not covered by tests
}

for _, webauthn := range cfg.Credentials {
if webauthn.IsPasswordless {
duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(identifierHint)
break
}
}
Expand Down Expand Up @@ -343,6 +333,7 @@
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
}

func (e *CreateIdentitiesError) Unwrap() []error {
e.init()
var errs []error
Expand All @@ -356,17 +347,20 @@
e.init()
e.failedIdentities[ident] = err
}

func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) {
e.init()
for k, v := range other.failedIdentities {
e.failedIdentities[k] = v
}
}

func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
e.init()
_, found := e.failedIdentities[ident]
return found
}

func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
e.init()
if err, found := e.failedIdentities[ident]; found {
Expand All @@ -375,12 +369,14 @@

return nil
}

func (e *CreateIdentitiesError) ErrOrNil() error {
if e == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
}

func (e *CreateIdentitiesError) init() {
if e.failedIdentities == nil {
e.failedIdentities = map[*Identity]*herodot.DefaultError{}
Expand Down
14 changes: 9 additions & 5 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ func TestManager(t *testing.T) {
assert.ErrorAs(t, err, &verr)
assert.ElementsMatch(t, []string{"oidc"}, verr.AvailableCredentials())
assert.ElementsMatch(t, []string{"google", "github"}, verr.AvailableOIDCProviders())
// The conflicting identifier is the oidc subject, which is not useful for the user
assert.Equal(t, email, verr.IdentifierHint())
})

Expand Down Expand Up @@ -756,29 +757,31 @@ func TestManager(t *testing.T) {
require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnRecoveryAddress))

t.Run("case=returns not found if no conflict", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
assert.Nil(t, found)
assert.Empty(t, foundConflictAddress)
assert.Empty(t, addressType)
})

t.Run("case=conflict on identifier", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
require.NoError(t, err)
assert.Equal(t, conflicOnIdentifier.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.EqualValues(t, string(identity.CredentialsTypePassword), addressType)
})

t.Run("case=conflict on verifiable address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
VerifiableAddresses: []identity.VerifiableAddress{{
Value: "[email protected]",
Via: "email",
Expand All @@ -787,10 +790,10 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, conflicOnVerifiableAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.Equal(t, "email", addressType)
})

t.Run("case=conflict on recovery address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
RecoveryAddresses: []identity.RecoveryAddress{{
Value: "[email protected]",
Via: "email",
Expand All @@ -799,6 +802,7 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, conflicOnRecoveryAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.Equal(t, "email", addressType)
})
})
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
}

func (e *HookExecutor) getDuplicateIdentifier(ctx context.Context, i *identity.Identity) (string, error) {
_, id, err := e.d.IdentityManager().ConflictingIdentity(ctx, i)
_, id, _, err := e.d.IdentityManager().ConflictingIdentity(ctx, i)
if err != nil {
return "", err
}
Expand Down
Loading