From 10ebe7c72cba7f714755d7f08fb8729f1c0911f2 Mon Sep 17 00:00:00 2001
From: aeneasr <3372410+aeneasr@users.noreply.github.com>
Date: Mon, 6 Nov 2023 14:03:26 +0100
Subject: [PATCH 1/2] feat: add batch insert

---
 sqlxx/batch/create.go      | 300 +++++++++++++++++++++++++++++++++++++
 sqlxx/batch/create_test.go | 122 +++++++++++++++
 2 files changed, 422 insertions(+)
 create mode 100644 sqlxx/batch/create.go
 create mode 100644 sqlxx/batch/create_test.go

diff --git a/sqlxx/batch/create.go b/sqlxx/batch/create.go
new file mode 100644
index 00000000..801dcbdd
--- /dev/null
+++ b/sqlxx/batch/create.go
@@ -0,0 +1,300 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package batch
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"reflect"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/jmoiron/sqlx/reflectx"
+
+	"github.com/ory/x/dbal"
+
+	"github.com/gobuffalo/pop/v6"
+	"github.com/gofrs/uuid"
+	"github.com/pkg/errors"
+
+	"github.com/ory/x/otelx"
+	"github.com/ory/x/sqlcon"
+
+	"github.com/ory/x/sqlxx"
+)
+
+type (
+	// insertQueryArgs holds the rendered parts of a multi-row INSERT statement:
+	// the quoted table name, the quoted and comma-joined column list, the plain
+	// (sorted) column names, and one "(?, ?, ...)" placeholder tuple per model.
+	insertQueryArgs struct {
+		TableName    string
+		ColumnsDecl  string
+		Columns      []string
+		Placeholders string
+	}
+	// quoter quotes an SQL identifier such as a table or column name.
+	quoter interface {
+		Quote(key string) string
+	}
+	// TracerConnection bundles the tracer and the pop connection used by Create.
+	TracerConnection struct {
+		Tracer     *otelx.Tracer
+		Connection *pop.Connection
+	}
+)
+
+// buildInsertQueryArgs renders the table name, column list, and placeholder
+// tuples for inserting models in a single statement. Columns are taken from
+// the pop model of T and sorted by name. On CockroachDB, the "id" placeholder
+// of every model whose UUID primary key is unset is replaced by
+// gen_random_uuid(), so no argument must be bound for it.
+func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs {
+	var (
+		v     T
+		model = pop.NewModel(v, ctx)
+
+		columns        []string
+		quotedColumns  []string
+		placeholders   []string
+		placeholderRow []string
+	)
+
+	for _, col := range model.Columns().Cols {
+		columns = append(columns, col.Name)
+		placeholderRow = append(placeholderRow, "?")
+	}
+
+	// We sort for the sole reason that the test snapshots are deterministic.
+	sort.Strings(columns)
+
+	for _, col := range columns {
+		quotedColumns = append(quotedColumns, quoter.Quote(col))
+	}
+
+	// We generate a list (for every row one) of VALUE statements here that
+	// will be substituted by their column values later:
+	//
+	//	(?, ?, ?, ?),
+	//	(?, ?, ?, ?),
+	//	(?, ?, ?, ?)
+	for _, m := range models {
+		m := reflect.ValueOf(m)
+
+		pl := make([]string, len(placeholderRow))
+		copy(pl, placeholderRow)
+
+		// There is a special case - when using CockroachDB we want to generate
+		// UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of:
+		//
+		//	(gen_random_uuid(), ?, ?, ?),
+		for k := range placeholderRow {
+			if columns[k] != "id" {
+				continue
+			}
+
+			field := mapper.FieldByName(m, columns[k])
+			val, ok := field.Interface().(uuid.UUID)
+			if !ok {
+				continue
+			}
+
+			if val == uuid.Nil && dialect == dbal.DriverCockroachDB {
+				pl[k] = "gen_random_uuid()"
+				break
+			}
+		}
+
+		placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", ")))
+	}
+
+	return insertQueryArgs{
+		TableName:    quoter.Quote(model.TableName()),
+		ColumnsDecl:  strings.Join(quotedColumns, ", "),
+		Columns:      columns,
+		Placeholders: strings.Join(placeholders, ",\n"),
+	}
+}
+
+// buildInsertQueryValues collects the arguments to bind to the placeholders
+// rendered by buildInsertQueryArgs, in model order and, per model, in the order
+// of columns. As a side effect it populates the models: "created_at" is set to
+// nowFunc() when zero, "updated_at" is always set to nowFunc(), and an unset
+// UUID "id" is generated (except on CockroachDB, where the database does it).
+func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
+	for _, m := range models {
+		m := reflect.ValueOf(m)
+
+		now := nowFunc()
+		// Append model fields to args
+		for _, c := range columns {
+			field := mapper.FieldByName(m, c)
+
+			switch c {
+			case "created_at":
+				if pop.IsZeroOfUnderlyingType(field.Interface()) {
+					field.Set(reflect.ValueOf(now))
+				}
+			case "updated_at":
+				field.Set(reflect.ValueOf(now))
+			case "id":
+				if id, ok := field.Interface().(uuid.UUID); !ok || id != uuid.Nil {
+					// Primary keys that are already set, or that are not UUIDs
+					// (matching buildInsertQueryArgs), are bound as they are.
+					break // breaks switch, not for
+				} else if dialect == dbal.DriverCockroachDB {
+					// This is a special case:
+					// 1. We're using cockroach
+					// 2. It's the primary key field ("ID")
+					// 3. A UUID was not yet set.
+					//
+					// If all these conditions are met, the VALUE statement will look as such:
+					//
+					//	(gen_random_uuid(), ?, ?, ?, ...)
+					//
+					// For that reason, we do not add the ID value to the list of arguments,
+					// because one of the arguments is using a built-in and thus doesn't need a value.
+					continue // continues the column loop, so no value is appended for the ID
+				}
+
+				id, err := uuid.NewV4()
+				if err != nil {
+					return nil, err
+				}
+				field.Set(reflect.ValueOf(id))
+			}
+
+			values = append(values, field.Interface())
+
+			// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
+			// but we want a nil pointer instead.
+			if i, ok := field.Interface().(*sqlxx.NullTime); ok {
+				if time.Time(*i).IsZero() {
+					field.Set(reflect.Zero(field.Type()))
+				}
+			}
+		}
+	}
+
+	return values, nil
+}
+
+// Create batch-inserts the given models into the database using a single INSERT statement.
+// The models are either all created or none.
+//
+// The statement is executed on p.Connection.TX, so the connection must be
+// inside a transaction. For every dialect except MySQL, the primary keys are
+// read back through a RETURNING clause and written to the models.
+func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err error) {
+	ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create")
+	defer otelx.End(span, &err)
+
+	if len(models) == 0 {
+		return nil
+	}
+
+	var v T
+	model := pop.NewModel(v, ctx)
+
+	conn := p.Connection
+	if conn.TX == nil {
+		return errors.New("batch insert requires a connection with an open transaction")
+	}
+
+	quoter, ok := conn.Dialect.(quoter)
+	if !ok {
+		return errors.Errorf("dialect is not a quoter: %T", conn.Dialect)
+	}
+
+	queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models)
+	values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
+	if err != nil {
+		return err
+	}
+
+	var returningClause string
+	if conn.Dialect.Name() != dbal.DriverMySQL {
+		// PostgreSQL, CockroachDB, SQLite support RETURNING.
+		returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
+	}
+
+	query := conn.Dialect.TranslateSQL(fmt.Sprintf(
+		"INSERT INTO %s (%s) VALUES\n%s\n%s",
+		queryArgs.TableName,
+		queryArgs.ColumnsDecl,
+		queryArgs.Placeholders,
+		returningClause,
+	))
+
+	rows, err := conn.TX.QueryContext(ctx, query, values...)
+	if err != nil {
+		return sqlcon.HandleError(err)
+	}
+	defer rows.Close()
+
+	// Hydrate the models from the RETURNING clause.
+	//
+	// Databases not supporting RETURNING will just return 0 rows.
+	count := 0
+	for rows.Next() {
+		if count >= len(models) {
+			return errors.Errorf("insert returned more rows than the %d models given", len(models))
+		}
+
+		if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
+			return err
+		}
+		count++
+	}
+
+	if err := rows.Err(); err != nil {
+		return sqlcon.HandleError(err)
+	}
+
+	if err := rows.Close(); err != nil {
+		return sqlcon.HandleError(err)
+	}
+
+	return nil
+}
+
+// setModelID was copy & pasted from pop. It scans the single column of the
+// current SQL row and stores the value in the model's "ID" field.
+func setModelID(row *sql.Rows, model *pop.Model) error {
+	el := reflect.ValueOf(model.Value).Elem()
+	fbn := el.FieldByName("ID")
+	if !fbn.IsValid() {
+		return errors.New("model does not have a field named id")
+	}
+
+	pkt, err := model.PrimaryKeyType()
+	if err != nil {
+		return errors.WithStack(err)
+	}
+
+	switch pkt {
+	case "UUID":
+		var id uuid.UUID
+		if err := row.Scan(&id); err != nil {
+			return errors.WithStack(err)
+		}
+		fbn.Set(reflect.ValueOf(id))
+	default:
+		var id any
+		if err := row.Scan(&id); err != nil {
+			return errors.WithStack(err)
+		}
+		v := reflect.ValueOf(id)
+		switch fbn.Kind() {
+		case reflect.Int, reflect.Int64:
+			fbn.SetInt(v.Int())
+		default:
+			fbn.Set(v)
+		}
+	}
+
+	return nil
+}
diff --git a/sqlxx/batch/create_test.go b/sqlxx/batch/create_test.go
new file mode 100644
index 00000000..49c0ac46
--- /dev/null
+++ b/sqlxx/batch/create_test.go
@@ -0,0 +1,122 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package batch
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ory/x/dbal"
+
+	"github.com/gofrs/uuid"
+	"github.com/jmoiron/sqlx/reflectx"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/ory/x/snapshotx"
+	"github.com/ory/x/sqlxx"
+)
+
+type (
+	testModel struct {
+		ID          uuid.UUID       `db:"id"`
+		NID         uuid.UUID       `db:"nid"`
+		String      string          `db:"string"`
+		Int         int             `db:"int"`
+		NullTimePtr *sqlxx.NullTime `db:"null_time_ptr"`
+		CreatedAt   time.Time       `json:"created_at" db:"created_at"`
+		UpdatedAt   time.Time       `json:"updated_at" db:"updated_at"`
+	}
+	testQuoter struct{}
+)
+
+func (i testModel) TableName(ctx context.Context) string {
+	return "test_models"
+}
+
+func (tq testQuoter) Quote(s string) string { return fmt.Sprintf("%q", s) }
+
+func makeModels[T any]() []*T {
+	models := make([]*T, 10)
+	for k := range models {
+		models[k] = new(T)
+	}
+	return models
+}
+
+func Test_buildInsertQueryArgs(t *testing.T) {
+	ctx := context.Background()
+	t.Run("case=testModel", func(t *testing.T) {
+		models := makeModels[testModel]()
+		mapper := reflectx.NewMapper("db")
+		args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
+		snapshotx.SnapshotT(t, args)
+
+		query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders)
+		assert.Equal(t, `INSERT INTO "test_models" ("created_at", "id", "int", "nid", "null_time_ptr", "string", "updated_at") VALUES
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?),
+(?, ?, ?, ?, ?, ?, ?)`, query)
+	})
+
+	t.Run("case=cockroach", func(t *testing.T) {
+		models := makeModels[testModel]()
+		for k := range models {
+			if k%3 == 0 {
+				models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k))
+			}
+		}
+		mapper := reflectx.NewMapper("db")
+		args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models)
+		snapshotx.SnapshotT(t, args)
+	})
+}
+
+func Test_buildInsertQueryValues(t *testing.T) {
+	t.Run("case=testModel", func(t *testing.T) {
+		model := &testModel{
+			String: "string",
+			Int:    42,
+		}
+		mapper := reflectx.NewMapper("db")
+
+		nowFunc := func() time.Time {
+			return time.Time{}
+		}
+		t.Run("case=cockroach", func(t *testing.T) {
+			values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
+			require.NoError(t, err)
+			snapshotx.SnapshotT(t, values)
+		})
+
+		t.Run("case=others", func(t *testing.T) {
+			values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
+			require.NoError(t, err)
+
+			assert.NotNil(t, model.CreatedAt)
+			assert.Equal(t, model.CreatedAt, values[0])
+
+			assert.NotNil(t, model.UpdatedAt)
+			assert.Equal(t, model.UpdatedAt, values[1])
+
+			assert.NotZero(t, model.ID)
+			assert.Equal(t, model.ID, values[2])
+
+			assert.Equal(t, model.String, values[3])
+			assert.Equal(t, model.Int, values[4])
+
+			assert.Nil(t, model.NullTimePtr)
+
+		})
+	})
+}

From 4010ffa85b4e555fab8375cf5f3667d2babb51df Mon Sep 17 00:00:00 2001
From: aeneasr <3372410+aeneasr@users.noreply.github.com>
Date: Mon, 6 Nov 2023 14:16:50 +0100
Subject: [PATCH 2/2] chore: synchronize workspaces

---
 ...Test_buildInsertQueryArgs-case=cockroach.json | 14 ++++++++++++++
 ...Test_buildInsertQueryArgs-case=testModel.json | 14 ++++++++++++++
 ...ueryValues-case=testModel-case=cockroach.json | 16 ++++++++++++++++
 3 files changed, 44 insertions(+)
 create mode 100644 sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json
 create mode 100644 sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=testModel.json
 create mode 100644 sqlxx/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json

diff --git a/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json b/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json
new file mode 100644
index 00000000..51b3ae70
--- /dev/null
+++ b/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json
@@ -0,0 +1,14 @@
+{
+  "TableName": "\"test_models\"",
+  "ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"updated_at\"",
+  "Columns": [
+    "created_at",
+    "id",
+    "int",
+    "nid",
+    "null_time_ptr",
+    "string",
+    "updated_at"
+  ],
+  "Placeholders": "(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?)"
+}
diff --git a/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=testModel.json b/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=testModel.json
new file mode 100644
index 00000000..db458b94
--- /dev/null
+++ b/sqlxx/batch/.snapshots/Test_buildInsertQueryArgs-case=testModel.json
@@ -0,0 +1,14 @@
+{
+  "TableName": "\"test_models\"",
+  "ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"updated_at\"",
+  "Columns": [
+    "created_at",
+    "id",
+    "int",
+    "nid",
+    "null_time_ptr",
+    "string",
+    "updated_at"
+  ],
+  "Placeholders": "(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?)"
+}
diff --git a/sqlxx/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json b/sqlxx/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json
new file mode 100644
index 00000000..c5bdc385
--- /dev/null
+++ b/sqlxx/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json
@@ -0,0 +1,16 @@
+[
+  "0001-01-01T00:00:00Z",
+  "0001-01-01T00:00:00Z",
+  "string",
+  42,
+  null,
+  {
+    "ID": "00000000-0000-0000-0000-000000000000",
+    "NID": "00000000-0000-0000-0000-000000000000",
+    "String": "string",
+    "Int": 42,
+    "NullTimePtr": null,
+    "created_at": "0001-01-01T00:00:00Z",
+    "updated_at": "0001-01-01T00:00:00Z"
+  }
+]