diff --git a/identity/pool.go b/identity/pool.go index 30a7308245b4..e07a6b8ee83e 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -23,9 +23,11 @@ type ( CredentialsIdentifierSimilar string DeclassifyCredentials []CredentialsType KeySetPagination []keysetpagination.Option + ConsistencyLevel crdbx.ConsistencyLevel + StatementTransformer func(string) string + // DEPRECATED - PagePagination *x.Page - ConsistencyLevel crdbx.ConsistencyLevel + PagePagination *x.Page } Pool interface { @@ -114,3 +116,10 @@ type ( FindIdentityByWebauthnUserHandle(ctx context.Context, userHandle []byte) (*Identity, error) } ) + +func (p ListIdentityParameters) TransformStatement(statement string) string { + if p.StatementTransformer != nil { + return p.StatementTransformer(statement) + } + return statement +} diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 4523ce7f2146..a9ca50407a19 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -876,6 +876,23 @@ func paginationAttributes(params *identity.ListIdentityParameters, paginator *ke return attrs } +// getCredentialTypeIDs returns a map of credential types to their respective IDs. +// +// If a credential type is not found, an error is returned. +func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credentialTypes []identity.CredentialsType) (map[identity.CredentialsType]uuid.UUID, error) { + result := map[identity.CredentialsType]uuid.UUID{} + + for _, ct := range credentialTypes { + typeID, err := p.findIdentityCredentialsType(ctx, ct) + if err != nil { + return nil, err + } + result[ct] = typeID.ID + } + + return result, nil +} + func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (_ []identity.Identity, nextPage *keysetpagination.Paginator, err error) { paginator := keysetpagination.GetPaginator(append( params.KeySetPagination, @@ -925,22 +942,34 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. } if len(identifier) > 0 { + types, err := p.getCredentialTypeIDs(ctx, []identity.CredentialsType{ + identity.CredentialsTypeWebAuthn, + identity.CredentialsTypePassword, + identity.CredentialsTypeCodeAuth, + identity.CredentialsTypeOIDC, + }) + if err != nil { + return err + } + // When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore // important to normalize the identifier before querying the database. - joins = ` + joins = params.TransformStatement(` INNER JOIN identity_credentials ic ON ic.identity_id = identities.id - INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id - INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id` + INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id`) + wheres += fmt.Sprintf(` AND ic.nid = ? AND ici.nid = ? - AND ((ict.name IN (?, ?, ?) AND ici.identifier %s ?) - OR (ict.name IN (?) AND ici.identifier %s ?)) + AND ((ic.identity_credential_type_id IN (?, ?, ?) AND ici.identifier %s ?) + OR (ic.identity_credential_type_id IN (?) AND ici.identifier %s ?)) `, identifierOperator, identifierOperator) args = append(args, nid, nid, - identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword, identity.CredentialsTypeCodeAuth, NormalizeIdentifier(identity.CredentialsTypePassword, identifier), - identity.CredentialsTypeOIDC, identifier) + types[identity.CredentialsTypeWebAuthn], types[identity.CredentialsTypePassword], types[identity.CredentialsTypeCodeAuth], + NormalizeIdentifier(identity.CredentialsTypePassword, identifier), + types[identity.CredentialsTypeOIDC], identifier, + ) } if params.IdsFilter != nil && len(params.IdsFilter) != 0 {