Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add batch insert #744

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"TableName": "\"test_models\"",
"ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"updated_at\"",
"Columns": [
"created_at",
"id",
"int",
"nid",
"null_time_ptr",
"string",
"updated_at"
],
"Placeholders": "(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?)"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"TableName": "\"test_models\"",
"ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"updated_at\"",
"Columns": [
"created_at",
"id",
"int",
"nid",
"null_time_ptr",
"string",
"updated_at"
],
"Placeholders": "(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?)"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[
"0001-01-01T00:00:00Z",
"0001-01-01T00:00:00Z",
"string",
42,
null,
{
"ID": "00000000-0000-0000-0000-000000000000",
"NID": "00000000-0000-0000-0000-000000000000",
"String": "string",
"Int": 42,
"NullTimePtr": null,
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z"
}
]
275 changes: 275 additions & 0 deletions sqlxx/batch/create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package batch

import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"

"github.com/jmoiron/sqlx/reflectx"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"

"github.com/ory/x/sqlxx"
)

type (
insertQueryArgs struct {
TableName string
ColumnsDecl string
Columns []string
Placeholders string
}
quoter interface {
Quote(key string) string
}
TracerConnection struct {
Tracer *otelx.Tracer
Connection *pop.Connection
}
)

func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs {
var (
v T
model = pop.NewModel(v, ctx)

columns []string
quotedColumns []string
placeholders []string
placeholderRow []string
)

for _, col := range model.Columns().Cols {
columns = append(columns, col.Name)
placeholderRow = append(placeholderRow, "?")
}

// We sort for the sole reason that the test snapshots are deterministic.
sort.Strings(columns)

for _, col := range columns {
quotedColumns = append(quotedColumns, quoter.Quote(col))
}

// We generate a list (for every row one) of VALUE statements here that
// will be substituted by their column values later:
//
// (?, ?, ?, ?),
// (?, ?, ?, ?),
// (?, ?, ?, ?)
for _, m := range models {
m := reflect.ValueOf(m)

pl := make([]string, len(placeholderRow))
copy(pl, placeholderRow)

// There is a special case - when using CockroachDB we want to generate
// UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of:
//
// (gen_random_uuid(), ?, ?, ?),
for k := range placeholderRow {
if columns[k] != "id" {
continue
}

field := mapper.FieldByName(m, columns[k])
val, ok := field.Interface().(uuid.UUID)
if !ok {
continue
}

if val == uuid.Nil && dialect == dbal.DriverCockroachDB {
pl[k] = "gen_random_uuid()"
break
}
}

placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", ")))
}

return insertQueryArgs{
TableName: quoter.Quote(model.TableName()),
ColumnsDecl: strings.Join(quotedColumns, ", "),
Columns: columns,
Placeholders: strings.Join(placeholders, ",\n"),
}
}

func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
for _, m := range models {
m := reflect.ValueOf(m)

now := nowFunc()
// Append model fields to args
for _, c := range columns {
field := mapper.FieldByName(m, c)

switch c {
case "created_at":
if pop.IsZeroOfUnderlyingType(field.Interface()) {
field.Set(reflect.ValueOf(now))
}
case "updated_at":
field.Set(reflect.ValueOf(now))
case "id":
if field.Interface().(uuid.UUID) != uuid.Nil {
break // breaks switch, not for
} else if dialect == dbal.DriverCockroachDB {
// This is a special case:
// 1. We're using cockroach
// 2. It's the primary key field ("ID")
// 3. A UUID was not yet set.
//
// If all these conditions meet, the VALUE statement will look as such:
//
// (gen_random_uuid(), ?, ?, ?, ...)
//
// For that reason, we do not add the ID value to the list of arguments,
// because one of the arguments is using a built-in and thus doesn't need a value.
continue // break switch, not for
}

id, err := uuid.NewV4()
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(id))
}

values = append(values, field.Interface())

// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
// but we want a nil pointer instead.
if i, ok := field.Interface().(*sqlxx.NullTime); ok {
if time.Time(*i).IsZero() {
field.Set(reflect.Zero(field.Type()))
}
}
}
}

return values, nil
}

// Create batch-inserts the given models into the database using a single INSERT statement.
// The models are either all created or none.
func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err error) {
ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create")
defer otelx.End(span, &err)

if len(models) == 0 {
return nil
}

var v T
model := pop.NewModel(v, ctx)

conn := p.Connection
quoter, ok := conn.Dialect.(quoter)
if !ok {
return errors.Errorf("store is not a quoter: %T", conn.Store)
}

queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models)
values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
if err != nil {
return err
}

var returningClause string
if conn.Dialect.Name() != dbal.DriverMySQL {
// PostgreSQL, CockroachDB, SQLite support RETURNING.
returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
}

query := conn.Dialect.TranslateSQL(fmt.Sprintf(
"INSERT INTO %s (%s) VALUES\n%s\n%s",
queryArgs.TableName,
queryArgs.ColumnsDecl,
queryArgs.Placeholders,
returningClause,
))

rows, err := conn.TX.QueryContext(ctx, query, values...)
if err != nil {
return sqlcon.HandleError(err)
}
defer rows.Close()

// Hydrate the models from the RETURNING clause.
//
// Databases not supporting RETURNING will just return 0 rows.
count := 0
for rows.Next() {
if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
return err
}
count++
}

if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := rows.Close(); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(err)
}

// setModelID was copy & pasted from pop. It basically sets
// the primary key to the given value read from the SQL row.
func setModelID(row *sql.Rows, model *pop.Model) error {
el := reflect.ValueOf(model.Value).Elem()
fbn := el.FieldByName("ID")
if !fbn.IsValid() {
return errors.New("model does not have a field named id")
}

pkt, err := model.PrimaryKeyType()
if err != nil {
return errors.WithStack(err)
}

switch pkt {
case "UUID":
var id uuid.UUID
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
fbn.Set(reflect.ValueOf(id))
default:
var id interface{}
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
v := reflect.ValueOf(id)
switch fbn.Kind() {
case reflect.Int, reflect.Int64:
fbn.SetInt(v.Int())
default:
fbn.Set(reflect.ValueOf(id))
}
}

return nil
}
Loading
Loading