Skip to content

Commit

Permalink
chore: synchronize workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Oct 3, 2024
1 parent 4f1a2b7 commit 35ec07f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
13 changes: 11 additions & 2 deletions identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ type (
CredentialsIdentifierSimilar string
DeclassifyCredentials []CredentialsType
KeySetPagination []keysetpagination.Option
ConsistencyLevel crdbx.ConsistencyLevel
StatementTransformer func(string) string

// DEPRECATED
PagePagination *x.Page
ConsistencyLevel crdbx.ConsistencyLevel
PagePagination *x.Page
}

Pool interface {
Expand Down Expand Up @@ -114,3 +116,10 @@ type (
FindIdentityByWebauthnUserHandle(ctx context.Context, userHandle []byte) (*Identity, error)
}
)

func (p ListIdentityParameters) TransformStatement(statement string) string {
if p.StatementTransformer != nil {
return p.StatementTransformer(statement)

Check warning on line 122 in identity/pool.go

View check run for this annotation

Codecov / codecov/patch

identity/pool.go#L122

Added line #L122 was not covered by tests
}
return statement
}
43 changes: 36 additions & 7 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,23 @@ func paginationAttributes(params *identity.ListIdentityParameters, paginator *ke
return attrs
}

// getCredentialTypeIDs returns a map of credential types to their respective IDs.
//
// If a credential type is not found, an error is returned.
func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credentialTypes []identity.CredentialsType) (map[identity.CredentialsType]uuid.UUID, error) {
result := map[identity.CredentialsType]uuid.UUID{}

for _, ct := range credentialTypes {
typeID, err := p.findIdentityCredentialsType(ctx, ct)
if err != nil {
return nil, err

Check warning on line 888 in persistence/sql/identity/persister_identity.go

View check run for this annotation

Codecov / codecov/patch

persistence/sql/identity/persister_identity.go#L888

Added line #L888 was not covered by tests
}
result[ct] = typeID.ID
}

return result, nil
}

func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (_ []identity.Identity, nextPage *keysetpagination.Paginator, err error) {
paginator := keysetpagination.GetPaginator(append(
params.KeySetPagination,
Expand Down Expand Up @@ -925,22 +942,34 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
}

if len(identifier) > 0 {
types, err := p.getCredentialTypeIDs(ctx, []identity.CredentialsType{
identity.CredentialsTypeWebAuthn,
identity.CredentialsTypePassword,
identity.CredentialsTypeCodeAuth,
identity.CredentialsTypeOIDC,
})
if err != nil {
return err

Check warning on line 952 in persistence/sql/identity/persister_identity.go

View check run for this annotation

Codecov / codecov/patch

persistence/sql/identity/persister_identity.go#L952

Added line #L952 was not covered by tests
}

// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
// important to normalize the identifier before querying the database.

joins = `
joins = params.TransformStatement(`
INNER JOIN identity_credentials ic ON ic.identity_id = identities.id
INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id
INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id`
INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id`)

wheres += fmt.Sprintf(`
AND ic.nid = ? AND ici.nid = ?
AND ((ict.name IN (?, ?, ?) AND ici.identifier %s ?)
OR (ict.name IN (?) AND ici.identifier %s ?))
AND ((ic.identity_credential_type_id IN (?, ?, ?) AND ici.identifier %s ?)
OR (ic.identity_credential_type_id IN (?) AND ici.identifier %s ?))
`, identifierOperator, identifierOperator)
args = append(args,
nid, nid,
identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword, identity.CredentialsTypeCodeAuth, NormalizeIdentifier(identity.CredentialsTypePassword, identifier),
identity.CredentialsTypeOIDC, identifier)
types[identity.CredentialsTypeWebAuthn], types[identity.CredentialsTypePassword], types[identity.CredentialsTypeCodeAuth],
NormalizeIdentifier(identity.CredentialsTypePassword, identifier),
types[identity.CredentialsTypeOIDC], identifier,
)
}

if params.IdsFilter != nil && len(params.IdsFilter) != 0 {
Expand Down

0 comments on commit 35ec07f

Please sign in to comment.