diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 646ea258..e3154b32 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -103,10 +103,10 @@ jobs: map: | { "postgresql": { - "dialector": "postgresql" + "dialector": "postgres" }, "cockroachdb": { - "dialector": "postgresql" + "dialector": "postgres" }, "mysql": { "dialector": "mysql" diff --git a/Makefile b/Makefile index 7ccec49e..7f9a7e4a 100644 --- a/Makefile +++ b/Makefile @@ -29,10 +29,10 @@ sqlserver: docker compose -f "docker/sqlserver/docker-compose.yml" up -d --build test_integration_postgresql: postgresql - DB=postgresql gotestsum --format testname ./testintegration + DB=postgres gotestsum --format testname ./testintegration test_integration_cockroachdb: cockroachdb - DB=postgresql gotestsum --format testname ./testintegration -tags=cockroachdb + DB=postgres gotestsum --format testname ./testintegration -tags=cockroachdb test_integration_mysql: mysql DB=mysql gotestsum --format testname ./testintegration -tags=mysql diff --git a/orm/query/gorm_query.go b/orm/query/gorm_query.go index e82f845b..ef027ead 100644 --- a/orm/query/gorm_query.go +++ b/orm/query/gorm_query.go @@ -28,8 +28,8 @@ func (query *GormQuery) Order(field IFieldIdentifier, descending bool, joinNumbe return err } - switch query.GormDB.Dialector.Name() { - case "postgres": + switch query.Dialector() { + case Postgres: // postgres supports only order by selected fields query.AddSelect(table, field) query.GormDB = query.GormDB.Order( @@ -42,7 +42,7 @@ func (query *GormQuery) Order(field IFieldIdentifier, descending bool, joinNumbe ) return nil - case "sqlserver", "sqlite", "mysql": + case SQLServer, SQLite, MySQL: query.GormDB = query.GormDB.Order( clause.OrderByColumn{ Column: clause.Column{ @@ -173,6 +173,19 @@ func (query GormQuery) ColumnName(table Table, fieldName string) string { return query.GormDB.NamingStrategy.ColumnName(table.Name, fieldName) } +type Dialector string + +const ( + Postgres Dialector = "postgres" + MySQL Dialector = "mysql" + SQLite Dialector = "sqlite" + SQLServer Dialector = "sqlserver" +) + +func (query GormQuery) Dialector() Dialector { + return Dialector(query.GormDB.Dialector.Name()) +} + func NewGormQuery(db *gorm.DB, initialModel model.Model, initialTable Table) *GormQuery { query := &GormQuery{ GormDB: db.Select(initialTable.Name + ".*"), diff --git a/testintegration/operators_test.go b/testintegration/operators_test.go index e527dbd2..267850b1 100644 --- a/testintegration/operators_test.go +++ b/testintegration/operators_test.go @@ -11,6 +11,7 @@ import ( "github.com/ditrit/badaas/orm/mysql" "github.com/ditrit/badaas/orm/operator" "github.com/ditrit/badaas/orm/psql" + "github.com/ditrit/badaas/orm/query" "github.com/ditrit/badaas/orm/sqlite" "github.com/ditrit/badaas/testintegration/conditions" "github.com/ditrit/badaas/testintegration/models" @@ -250,12 +251,12 @@ func (ts *OperatorsIntTestSuite) TestIsTrue() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL, sqLite: + case query.Postgres, query.MySQL, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.BoolIs().True(), ).Find() - case sqlServer: + case query.SQLServer: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.BoolIs().Eq(true), @@ -277,12 +278,12 @@ func (ts *OperatorsIntTestSuite) TestIsFalse() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL, sqLite: + case query.Postgres, query.MySQL, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.BoolIs().False(), ).Find() - case sqlServer: + case query.SQLServer: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.BoolIs().Eq(false), @@ -310,12 +311,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotTrue() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL, sqLite: + case query.Postgres, query.MySQL, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().NotTrue(), ).Find() - case sqlServer: + case query.SQLServer: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().Distinct(true), @@ -343,12 +344,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotFalse() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL, sqLite: + case query.Postgres, query.MySQL, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().NotFalse(), ).Find() - case sqlServer: + case query.SQLServer: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().Distinct(false), @@ -376,12 +377,12 @@ func (ts *OperatorsIntTestSuite) TestIsUnknown() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL: + case query.Postgres, query.MySQL: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().Unknown(), ).Find() - case sqlServer, sqLite: + case query.SQLServer, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().Null(), @@ -409,12 +410,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotUnknown() { var entities []*models.Product switch getDBDialector() { - case postgreSQL, mySQL: + case query.Postgres, query.MySQL: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().NotUnknown(), ).Find() - case sqlServer, sqLite: + case query.SQLServer, query.SQLite: entities, err = orm.NewQuery[models.Product]( ts.db, conditions.Product.NullBoolIs().NotNull(), @@ -428,7 +429,7 @@ func (ts *OperatorsIntTestSuite) TestIsNotUnknown() { func (ts *OperatorsIntTestSuite) TestIsDistinct() { switch getDBDialector() { - case postgreSQL, sqlServer, sqLite: + case query.Postgres, query.SQLServer, query.SQLite: match1 := ts.createProduct("match", 3, 0, false, nil) match2 := ts.createProduct("match", 4, 0, false, nil) ts.createProduct("not_match", 2, 0, false, nil) @@ -440,14 +441,14 @@ func (ts *OperatorsIntTestSuite) TestIsDistinct() { ts.Nil(err) EqualList(&ts.Suite, []*models.Product{match1, match2}, entities) - case mySQL: + case query.MySQL: log.Println("IsDistinct not compatible") } } func (ts *OperatorsIntTestSuite) TestIsNotDistinct() { switch getDBDialector() { - case postgreSQL, sqlServer, sqLite: + case query.Postgres, query.SQLServer, query.SQLite: match := ts.createProduct("match", 3, 0, false, nil) ts.createProduct("not_match", 4, 0, false, nil) ts.createProduct("not_match", 2, 0, false, nil) @@ -459,7 +460,7 @@ func (ts *OperatorsIntTestSuite) TestIsNotDistinct() { ts.Nil(err) EqualList(&ts.Suite, []*models.Product{match}, entities) - case mySQL: + case query.MySQL: log.Println("IsNotDistinct not compatible") } } @@ -532,9 +533,9 @@ func (ts *OperatorsIntTestSuite) TestLikeEscape() { func (ts *OperatorsIntTestSuite) TestLikeOnNumeric() { switch getDBDialector() { - case postgreSQL, sqlServer, sqLite: + case query.Postgres, query.SQLServer, query.SQLite: log.Println("Like with numeric not compatible") - case mySQL: + case query.MySQL: match1 := ts.createProduct("", 10, 0, false, nil) match2 := ts.createProduct("", 100, 0, false, nil) @@ -555,9 +556,9 @@ func (ts *OperatorsIntTestSuite) TestLikeOnNumeric() { func (ts *OperatorsIntTestSuite) TestILike() { switch getDBDialector() { - case mySQL, sqlServer, sqLite: + case query.MySQL, query.SQLServer, query.SQLite: log.Println("ILike not compatible") - case postgreSQL: + case query.Postgres: match1 := ts.createProduct("basd", 0, 0, false, nil) match2 := ts.createProduct("cape", 0, 0, false, nil) match3 := ts.createProduct("bAsd", 0, 0, false, nil) @@ -579,9 +580,9 @@ func (ts *OperatorsIntTestSuite) TestILike() { func (ts *OperatorsIntTestSuite) TestSimilarTo() { switch getDBDialector() { - case mySQL, sqlServer, sqLite: + case query.MySQL, query.SQLServer, query.SQLite: log.Println("SimilarTo not compatible") - case postgreSQL: + case query.Postgres: match1 := ts.createProduct("abc", 0, 0, false, nil) match2 := ts.createProduct("aabcc", 0, 0, false, nil) @@ -611,11 +612,11 @@ func (ts *OperatorsIntTestSuite) TestPosixRegexCaseSensitive() { var posixRegexOperator operator.Operator[string] switch getDBDialector() { - case sqlServer, mySQL: + case query.SQLServer, query.MySQL: log.Println("PosixRegex not compatible") - case postgreSQL: + case query.Postgres: posixRegexOperator = psql.POSIXMatch("^a(b|x)") - case sqLite: + case query.SQLite: posixRegexOperator = sqlite.Glob("a[bx]") } @@ -643,11 +644,11 @@ func (ts *OperatorsIntTestSuite) TestPosixRegexCaseInsensitive() { var posixRegexOperator operator.Operator[string] switch getDBDialector() { - case sqlServer, sqLite: + case query.SQLServer, query.SQLite: log.Println("PosixRegex Case Insensitive not compatible") - case mySQL: + case query.MySQL: posixRegexOperator = mysql.RegexP("^a(b|x)") - case postgreSQL: + case query.Postgres: posixRegexOperator = psql.POSIXIMatch("^a(b|x)") } @@ -744,7 +745,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchConvertibl func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvertible() { switch getDBDialector() { - case sqLite: + case query.SQLite: // comparisons between types are allowed and matches nothing if not convertible ts.createProduct("", 0, 0, false, nil) ts.createProduct("", 0, 2, false, nil) @@ -757,7 +758,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert ts.Nil(err) EqualList(&ts.Suite, []*models.Product{}, entities) - case mySQL: + case query.MySQL: // comparisons between types are allowed but matches 0s if not convertible match := ts.createProduct("", 0, 0, false, nil) ts.createProduct("", 0, 2, false, nil) @@ -770,14 +771,14 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert ts.Nil(err) EqualList(&ts.Suite, []*models.Product{match}, entities) - case sqlServer: + case query.SQLServer: // returns an error _, err := orm.NewQuery[models.Product]( ts.db, conditions.Product.FloatIs().Unsafe().Eq("not_convertible_to_float"), ).Find() ts.ErrorContains(err, "mssql: Error converting data type nvarchar to float.") - case postgreSQL: + case query.Postgres: // returns an error _, err := orm.NewQuery[models.Product]( ts.db, @@ -789,7 +790,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch() { switch getDBDialector() { - case sqLite: + case query.SQLite: // comparisons between fields with different types are allowed match1 := ts.createProduct("0", 0, 0, false, nil) match2 := ts.createProduct("1", 0, 1, false, nil) @@ -803,7 +804,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch( ts.Nil(err) EqualList(&ts.Suite, []*models.Product{match1, match2}, entities) - case mySQL: + case query.MySQL: // comparisons between fields with different types are allowed but matches 0s on not convertible match1 := ts.createProduct("0", 1, 0, false, nil) match2 := ts.createProduct("1", 2, 1, false, nil) @@ -817,7 +818,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch( ts.Nil(err) EqualList(&ts.Suite, []*models.Product{match1, match2, match3}, entities) - case sqlServer: + case query.SQLServer: // comparisons between fields with different types are allowed and returns error only if at least one is not convertible match1 := ts.createProduct("0", 1, 0, false, nil) match2 := ts.createProduct("1", 2, 1, false, nil) @@ -838,7 +839,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch( conditions.Product.FloatIs().Unsafe().Eq(conditions.Product.String), ).Find() ts.ErrorContains(err, "mssql: Error converting data type nvarchar to float.") - case postgreSQL: + case query.Postgres: // returns an error _, err := orm.NewQuery[models.Product]( ts.db, diff --git a/testintegration/orm_test.go b/testintegration/orm_test.go index 586f9316..eb03d97a 100644 --- a/testintegration/orm_test.go +++ b/testintegration/orm_test.go @@ -16,6 +16,7 @@ import ( "github.com/ditrit/badaas/orm" "github.com/ditrit/badaas/orm/logger" + "github.com/ditrit/badaas/orm/query" "github.com/ditrit/badaas/persistence/database" "github.com/ditrit/badaas/persistence/gormfx" ) @@ -33,15 +34,6 @@ const ( dbName = "badaas_db" ) -type dbDialector string - -const ( - postgreSQL dbDialector = "postgresql" - mySQL dbDialector = "mysql" - sqLite dbDialector = "sqlite" - sqlServer dbDialector = "sqlserver" -) - func TestBaDaaSORM(t *testing.T) { tGlobal = t @@ -84,13 +76,13 @@ func NewDBConnection() (*gorm.DB, error) { var dialector gorm.Dialector switch getDBDialector() { - case postgreSQL: + case query.Postgres: dialector = postgres.Open(orm.CreatePostgreSQLDSN(host, username, password, sslMode, dbName, port)) - case mySQL: - dialector = mysql.Open(orm.CreateMySQLDSN(host, username, password, dbName, port)) - case sqLite: + case query.SQLite: dialector = sqlite.Open(orm.CreateSQLiteDSN(host)) - case sqlServer: + case query.MySQL: + dialector = mysql.Open(orm.CreateMySQLDSN(host, username, password, dbName, port)) + case query.SQLServer: dialector = sqlserver.Open(orm.CreateSQLServerDSN(host, username, password, dbName, port)) default: return nil, fmt.Errorf("unknown db %s", getDBDialector()) @@ -103,6 +95,6 @@ func NewDBConnection() (*gorm.DB, error) { ) } -func getDBDialector() dbDialector { - return dbDialector(os.Getenv(dbTypeEnvKey)) +func getDBDialector() query.Dialector { + return query.Dialector(os.Getenv(dbTypeEnvKey)) } diff --git a/testintegration/query_test.go b/testintegration/query_test.go index 417f9d96..2d52abcd 100644 --- a/testintegration/query_test.go +++ b/testintegration/query_test.go @@ -7,6 +7,7 @@ import ( "github.com/ditrit/badaas/orm" "github.com/ditrit/badaas/orm/errors" + "github.com/ditrit/badaas/orm/query" "github.com/ditrit/badaas/testintegration/conditions" "github.com/ditrit/badaas/testintegration/models" ) @@ -272,7 +273,7 @@ func (ts *QueryIntTestSuite) TestOffsetSkipsTheModelsReturned() { product2 := ts.createProduct("", 1, 2, false, nil) switch getDBDialector() { - case postgreSQL, sqlServer, sqLite: + case query.Postgres, query.SQLServer, query.SQLite: products, err := orm.NewQuery[models.Product]( ts.db, conditions.Product.IntIs().Eq(1), @@ -280,7 +281,7 @@ func (ts *QueryIntTestSuite) TestOffsetSkipsTheModelsReturned() { ts.Nil(err) EqualList(&ts.Suite, []*models.Product{product2}, products) - case mySQL: + case query.MySQL: products, err := orm.NewQuery[models.Product]( ts.db, conditions.Product.IntIs().Eq(1), @@ -296,7 +297,7 @@ func (ts *QueryIntTestSuite) TestOffsetReturnsEmptyIfMoreOffsetThanResults() { ts.createProduct("", 1, 0, false, nil) switch getDBDialector() { - case postgreSQL, sqlServer, sqLite: + case query.Postgres, query.SQLServer, query.SQLite: products, err := orm.NewQuery[models.Product]( ts.db, conditions.Product.IntIs().Eq(1), @@ -304,7 +305,7 @@ func (ts *QueryIntTestSuite) TestOffsetReturnsEmptyIfMoreOffsetThanResults() { ts.Nil(err) EqualList(&ts.Suite, []*models.Product{}, products) - case mySQL: + case query.MySQL: products, err := orm.NewQuery[models.Product]( ts.db, conditions.Product.IntIs().Eq(1), diff --git a/testintegration/where_conditions_test.go b/testintegration/where_conditions_test.go index 9b49c42c..446a769c 100644 --- a/testintegration/where_conditions_test.go +++ b/testintegration/where_conditions_test.go @@ -8,6 +8,7 @@ import ( "github.com/ditrit/badaas/orm" "github.com/ditrit/badaas/orm/errors" "github.com/ditrit/badaas/orm/mysql" + "github.com/ditrit/badaas/orm/query" "github.com/ditrit/badaas/orm/unsafe" "github.com/ditrit/badaas/testintegration/conditions" "github.com/ditrit/badaas/testintegration/models" @@ -500,9 +501,9 @@ func (ts *WhereConditionsIntTestSuite) TestNotOr() { func (ts *WhereConditionsIntTestSuite) TestXor() { switch getDBDialector() { - case postgreSQL, sqLite, sqlServer: + case query.Postgres, query.SQLite, query.SQLServer: log.Println("Xor not compatible") - case mySQL: + case query.MySQL: match1 := ts.createProduct("", 1, 0, false, nil) match2 := ts.createProduct("", 7, 0, false, nil)