From df8735f5440646cb886b87191feb320416b226b4 Mon Sep 17 00:00:00 2001 From: Yoann La Cancellera Date: Thu, 23 Feb 2023 16:37:38 +0100 Subject: [PATCH] Add context counterparts from sql funcs --- sqlutils/sqlutils.go | 102 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 20 deletions(-) diff --git a/sqlutils/sqlutils.go b/sqlutils/sqlutils.go index 0a2eda2..920d0f2 100644 --- a/sqlutils/sqlutils.go +++ b/sqlutils/sqlutils.go @@ -17,6 +17,7 @@ package sqlutils import ( + "context" "database/sql" "encoding/json" "errors" @@ -252,17 +253,17 @@ func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error { return err } -// QueryRowsMap is a convenience function allowing querying a result set while poviding a callback +// QueryRowsMapContext is a convenience function allowing querying a result set while poviding a callback // function activated per read row. -func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) { +func QueryRowsMapContext(ctx context.Context, db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) { defer func() { if derr := recover(); derr != nil { - err = fmt.Errorf("QueryRowsMap unexpected error: %+v", derr) + err = fmt.Errorf("QueryRowsMapContext unexpected error: %+v", derr) } }() var rows *sql.Rows - rows, err = db.Query(query, args...) + rows, err = db.QueryContext(ctx, query, args...) if rows != nil { defer rows.Close() } @@ -273,16 +274,22 @@ func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...i return } -// queryResultData returns a raw array of rows for a given query, optionally reading and returning column names -func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...interface{}) (resultData ResultData, columns []string, err error) { +// QueryRowsMap is a convenience function allowing querying a result set while poviding a callback +// function activated per read row. +func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) { + return QueryRowsMapContext(context.Background(), db, query, on_row, args...) +} + +// queryResultDataContext returns a raw array of rows for a given query, optionally reading and returning column names +func queryResultDataContext(ctx context.Context, db *sql.DB, query string, retrieveColumns bool, args ...interface{}) (resultData ResultData, columns []string, err error) { defer func() { if derr := recover(); derr != nil { - err = errors.New(fmt.Sprintf("QueryRowsMap unexpected error: %+v", derr)) + err = errors.New(fmt.Sprintf("queryResultDataContext unexpected error: %+v", derr)) } }() var rows *sql.Rows - rows, err = db.Query(query, args...) + rows, err = db.QueryContext(ctx, query, args...) defer rows.Close() if err != nil && err != sql.ErrNoRows { return EmptyResultData, columns, err @@ -299,23 +306,40 @@ func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...int return resultData, columns, err } +// queryResultData returns a raw array of rows for a given query, optionally reading and returning column names +func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...interface{}) (resultData ResultData, columns []string, err error) { + return queryResultDataContext(context.Background(), db, query, retrieveColumns, args...) +} + +// QueryResultDataContext returns a raw array of rows +func QueryResultDataContext(ctx context.Context, db *sql.DB, query string, args ...interface{}) (ResultData, error) { + resultData, _, err := queryResultDataContext(ctx, db, query, false, args...) + return resultData, err +} + // QueryResultData returns a raw array of rows func QueryResultData(db *sql.DB, query string, args ...interface{}) (ResultData, error) { resultData, _, err := queryResultData(db, query, false, args...) return resultData, err } +// QueryResultDataNamedContext returns a raw array of rows, with column names +func QueryNamedResultDataContext(ctx context.Context, db *sql.DB, query string, args ...interface{}) (NamedResultData, error) { + resultData, columns, err := queryResultDataContext(ctx, db, query, true, args...) + return NamedResultData{Columns: columns, Data: resultData}, err +} + // QueryResultDataNamed returns a raw array of rows, with column names func QueryNamedResultData(db *sql.DB, query string, args ...interface{}) (NamedResultData, error) { resultData, columns, err := queryResultData(db, query, true, args...) return NamedResultData{Columns: columns, Data: resultData}, err } -// QueryRowsMapBuffered reads data from the database into a buffer, and only then applies the given function per row. +// QueryRowsMapBufferedContext reads data from the database into a buffer, and only then applies the given function per row. // This allows the application to take its time with processing the data, albeit consuming as much memory as required by // the result set. -func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) error { - resultData, columns, err := queryResultData(db, query, true, args...) +func QueryRowsMapBufferedContext(ctx context.Context, db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) error { + resultData, columns, err := queryResultDataContext(ctx, db, query, true, args...) if err != nil { // Already logged return err @@ -329,48 +353,77 @@ func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, a return nil } -// ExecNoPrepare executes given query using given args on given DB, without using prepared statements. -func ExecNoPrepare(db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { +// QueryRowsMapBuffered reads data from the database into a buffer, and only then applies the given function per row. +// This allows the application to take its time with processing the data, albeit consuming as much memory as required by +// the result set. +func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) error { + return QueryRowsMapBufferedContext(context.Background(), db, query, on_row, args...) +} + +// ExecNoPrepareContext executes given query using given args on given DB, without using prepared statements. +func ExecNoPrepareContext(ctx context.Context, db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { defer func() { if derr := recover(); derr != nil { err = errors.New(fmt.Sprintf("ExecNoPrepare unexpected error: %+v", derr)) } }() - res, err = db.Exec(query, args...) + res, err = db.ExecContext(ctx, query, args...) if err != nil { log.Errore(err) } return res, err } -// ExecQuery executes given query using given args on given DB. It will safele prepare, execute and close +// ExecNoPrepare executes given query using given args on given DB, without using prepared statements. +func ExecNoPrepare(db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { + return ExecNoPrepareContext(context.Background(), db, query, args...) +} + +// ExecQueryContext executes given query using given args on given DB. It will safele prepare, execute and close // the statement. -func execInternal(silent bool, db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { +func execInternalContext(ctx context.Context, silent bool, db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { defer func() { if derr := recover(); derr != nil { err = errors.New(fmt.Sprintf("execInternal unexpected error: %+v", derr)) } }() var stmt *sql.Stmt - stmt, err = db.Prepare(query) + stmt, err = db.PrepareContext(ctx, query) if err != nil { return nil, err } defer stmt.Close() - res, err = stmt.Exec(args...) + res, err = stmt.ExecContext(ctx, args...) if err != nil && !silent { log.Errore(err) } return res, err } +// ExecQuery executes given query using given args on given DB. It will safele prepare, execute and close +// the statement. +func execInternal(silent bool, db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) { + return execInternalContext(context.Background(), silent, db, query, args...) +} + +// ExecContext executes given query using given args on given DB. It will safele prepare, execute and close +// the statement. +func ExecContext(ctx context.Context, db *sql.DB, query string, args ...interface{}) (sql.Result, error) { + return execInternalContext(ctx, false, db, query, args...) +} + // Exec executes given query using given args on given DB. It will safele prepare, execute and close // the statement. func Exec(db *sql.DB, query string, args ...interface{}) (sql.Result, error) { return execInternal(false, db, query, args...) } +// ExecSilentlyContext acts like Exec but does not report any error +func ExecSilentlyContext(ctx context.Context, db *sql.DB, query string, args ...interface{}) (sql.Result, error) { + return execInternalContext(ctx, true, db, query, args...) +} + // ExecSilently acts like Exec but does not report any error func ExecSilently(db *sql.DB, query string, args ...interface{}) (sql.Result, error) { return execInternal(true, db, query, args...) @@ -396,12 +449,17 @@ func NilIfZero(i int64) interface{} { return i } +func ScanTableContext(ctx context.Context, db *sql.DB, tableName string) (NamedResultData, error) { + query := fmt.Sprintf("select * from %s", tableName) + return QueryNamedResultDataContext(ctx, db, query) +} + func ScanTable(db *sql.DB, tableName string) (NamedResultData, error) { query := fmt.Sprintf("select * from %s", tableName) return QueryNamedResultData(db, query) } -func WriteTable(db *sql.DB, tableName string, data NamedResultData) (err error) { +func WriteTableContext(ctx context.Context, db *sql.DB, tableName string, data NamedResultData) (err error) { if len(data.Data) == 0 { return nil } @@ -419,9 +477,13 @@ func WriteTable(db *sql.DB, tableName string, data NamedResultData) (err error) strings.Join(placeholders, ","), ) for _, rowData := range data.Data { - if _, execErr := db.Exec(query, rowData.Args()...); execErr != nil { + if _, execErr := db.ExecContext(ctx, query, rowData.Args()...); execErr != nil { err = execErr } } return err } + +func WriteTable(db *sql.DB, tableName string, data NamedResultData) (err error) { + return WriteTableContext(context.Background(), db, tableName, data) +}