diff --git a/grate.unittests/Generic/GenericDatabase.cs b/grate.unittests/Generic/GenericDatabase.cs index fcb62fb3..be6654e9 100644 --- a/grate.unittests/Generic/GenericDatabase.cs +++ b/grate.unittests/Generic/GenericDatabase.cs @@ -56,33 +56,37 @@ public async Task Does_not_error_if_confed_to_create_but_already_exists() var db = "DAATAA"; // Create the database manually before running the migration - await CreateDatabase(db); + await CreateDatabaseFromConnectionString(db, Context.UserConnectionString(db)); // Check that the database has been created IEnumerable databasesBeforeMigration = await GetDatabases(); databasesBeforeMigration.Should().Contain(db); - - await using var migrator = GetMigrator(GetConfiguration(db, true)); + + var config = GetConfiguration(true, Context.UserConnectionString(db), Context.AdminConnectionString); + await using var migrator = GetMigrator(config); // There should be no errors running the migration Assert.DoesNotThrowAsync(() => migrator.Migrate()); } + [TestCase("Invalid stuff")] + [TestCase(null)] [Test] - public async Task Does_not_need_admin_connection_if_database_already_exists() + public async Task Does_not_need_admin_connection_if_database_already_exists(string adminConnectionString) { var db = "DATADATADATABASE"; // Create the database manually before running the migration - await CreateDatabase(db); + await CreateDatabaseFromConnectionString(db, Context.UserConnectionString(db)); // Check that the database has been created IEnumerable databasesBeforeMigration = await GetDatabases(); databasesBeforeMigration.Should().Contain(db); // Change the admin connection string to rubbish and run the migration - await using var migrator = GetMigrator(GetConfiguration(db, true, "Invalid stuff")); - + var config = GetConfiguration(true, Context.UserConnectionString(db), adminConnectionString); + await using var migrator = GetMigrator(config); + // There should be no errors running the migration Assert.DoesNotThrowAsync(() => migrator.Migrate()); } @@ -105,8 +109,13 @@ public async Task Does_not_needlessly_apply_case_sensitive_database_name_checks_ Assert.DoesNotThrowAsync(() => migrator.Migrate()); } - protected virtual async Task CreateDatabase(string db) + protected Task CreateDatabase(string db) => CreateDatabaseFromConnectionString(db, Context.ConnectionString(db)); + + protected virtual async Task CreateDatabaseFromConnectionString(string db, string connectionString) { + var uid = TestConfig.Username(connectionString); + var pwd = TestConfig.Password(connectionString); + using (new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled)) { for (var i = 0; i < 5; i++) @@ -116,15 +125,29 @@ protected virtual async Task CreateDatabase(string db) await using var conn = Context.CreateAdminDbConnection(); await conn.OpenAsync(); await using var cmd = conn.CreateCommand(); - cmd.CommandText = Context.Syntax.CreateDatabase(db, TestConfig.Password(Context.ConnectionString(db))); + + cmd.CommandText = Context.Syntax.CreateDatabase(db, pwd); await cmd.ExecuteNonQueryAsync(); + + if (!string.IsNullOrWhiteSpace(Context.Sql.CreateUser)) + { + cmd.CommandText = string.Format(Context.Sql.CreateUser, uid, pwd); + await cmd.ExecuteNonQueryAsync(); + } + + if (!string.IsNullOrWhiteSpace(Context.Sql.GrantAccess)) + { + cmd.CommandText = string.Format(Context.Sql.GrantAccess, db, uid); + await cmd.ExecuteNonQueryAsync(); + } + break; } catch (DbException) { } } } } - + protected virtual async Task> GetDatabases() { IEnumerable databases =Enumerable.Empty(); @@ -150,20 +173,28 @@ protected virtual async Task> GetDatabases() private GrateMigrator GetMigrator(GrateConfiguration config) => Context.GetMigrator(config); + + private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase) + => GetConfiguration(databaseName, createDatabase, Context.AdminConnectionString); - private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase, string? adminConnectionString = null) + private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase, string? adminConnectionString) + => GetConfiguration(createDatabase, Context.ConnectionString(databaseName), adminConnectionString); + + + private GrateConfiguration GetConfiguration(bool createDatabase, string? connectionString, string? adminConnectionString) { var parent = TestConfig.CreateRandomTempDirectory(); return new() { CreateDatabase = createDatabase, - ConnectionString = Context.ConnectionString(databaseName), - AdminConnectionString = adminConnectionString ?? Context.AdminConnectionString, + ConnectionString = connectionString, + AdminConnectionString = adminConnectionString, Folders = FoldersConfiguration.Default(null), NonInteractive = true, DatabaseType = Context.DatabaseType, SqlFilesDirectory = parent }; } + } diff --git a/grate.unittests/SqLite/Database.cs b/grate.unittests/SqLite/Database.cs index c3afa494..8c7272c1 100644 --- a/grate.unittests/SqLite/Database.cs +++ b/grate.unittests/SqLite/Database.cs @@ -14,14 +14,21 @@ public class Database: Generic.GenericDatabase { protected override IGrateTestContext Context => GrateTestContext.Sqlite; - protected override async Task CreateDatabase(string db) + protected override async Task CreateDatabaseFromConnectionString(string db, string connectionString) { - await using var conn = new SqliteConnection(Context.ConnectionString(db)); + await using var conn = new SqliteConnection(connectionString); conn.Open(); await using var cmd = conn.CreateCommand(); + + // Create a table to actually create the .sqlite file var sql = "CREATE TABLE dummy(name VARCHAR(1))"; cmd.CommandText = sql; await cmd.ExecuteNonQueryAsync(); + + // Remove the table to avoid polluting the database with dummy tables :) + sql = "DROP TABLE dummy"; + cmd.CommandText = sql; + await cmd.ExecuteNonQueryAsync(); } protected override async Task> GetDatabases() @@ -36,4 +43,4 @@ protected override async Task> GetDatabases() } protected override bool ThrowOnMissingDatabase => false; -} \ No newline at end of file +} diff --git a/grate.unittests/TestInfrastructure/IGrateTestContext.cs b/grate.unittests/TestInfrastructure/IGrateTestContext.cs index 9bbd09a8..1756a8c0 100644 --- a/grate.unittests/TestInfrastructure/IGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/IGrateTestContext.cs @@ -16,6 +16,7 @@ public interface IGrateTestContext string AdminConnectionString { get; } string ConnectionString(string database); + string UserConnectionString(string database); DbConnection GetDbConnection(string connectionString); diff --git a/grate.unittests/TestInfrastructure/MariaDbGrateTestContext.cs b/grate.unittests/TestInfrastructure/MariaDbGrateTestContext.cs index 5f202580..03457907 100644 --- a/grate.unittests/TestInfrastructure/MariaDbGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/MariaDbGrateTestContext.cs @@ -18,6 +18,7 @@ public string DockerCommand(string serverName, string adminPassword) => public string AdminConnectionString => $"Server=localhost;Port={Port};Database=mysql;Uid=root;Pwd={AdminPassword}"; public string ConnectionString(string database) => $"Server=localhost;Port={Port};Database={database};Uid=root;Pwd={AdminPassword}"; + public string UserConnectionString(string database) => $"Server=localhost;Port={Port};Database={database};Uid={database};Pwd=mooo1213"; public DbConnection GetDbConnection(string connectionString) => new MySqlConnection(connectionString); @@ -34,7 +35,9 @@ public string DockerCommand(string serverName, string adminPassword) => public SqlStatements Sql => new() { SelectVersion = "SELECT VERSION()", - SleepTwoSeconds = "SELECT SLEEP(2);" + SleepTwoSeconds = "SELECT SLEEP(2);", + CreateUser = "CREATE USER '{0}'@'%' IDENTIFIED BY '{1}';", + GrantAccess = "GRANT SELECT, INSERT, UPDATE, DELETE, CREATE, INDEX, DROP, ALTER, CREATE TEMPORARY TABLES, LOCK TABLES ON {0}.* TO '{1}'@'%';FLUSH PRIVILEGES;" }; diff --git a/grate.unittests/TestInfrastructure/OracleGrateTestContext.cs b/grate.unittests/TestInfrastructure/OracleGrateTestContext.cs index dbade7a4..ec555716 100644 --- a/grate.unittests/TestInfrastructure/OracleGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/OracleGrateTestContext.cs @@ -20,6 +20,7 @@ public string DockerCommand(string serverName, string adminPassword) => public string AdminConnectionString => $@"Data Source=localhost:{Port}/XE;User ID=SYSTEM;Password=oracle;Pooling=False"; public string ConnectionString(string database) => $@"Data Source=localhost:{Port}/XE;User ID={database.ToUpper()};Password=oracle;Pooling=False"; + public string UserConnectionString(string database) => $@"Data Source=localhost:{Port}/XE;User ID={database.ToUpper()};Password=oracle;Pooling=False"; public DbConnection GetDbConnection(string connectionString) => new OracleConnection(connectionString); diff --git a/grate.unittests/TestInfrastructure/PostgreSqlGrateTestContext.cs b/grate.unittests/TestInfrastructure/PostgreSqlGrateTestContext.cs index fd8e0ed8..561684e4 100644 --- a/grate.unittests/TestInfrastructure/PostgreSqlGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/PostgreSqlGrateTestContext.cs @@ -18,6 +18,7 @@ public string DockerCommand(string serverName, string adminPassword) => public string AdminConnectionString => $"Host=localhost;Port={Port};Database=postgres;Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false"; public string ConnectionString(string database) => $"Host=localhost;Port={Port};Database={database};Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false"; + public string UserConnectionString(string database) => $"Host=localhost;Port={Port};Database={database};Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false"; public DbConnection GetDbConnection(string connectionString) => new NpgsqlConnection(connectionString); diff --git a/grate.unittests/TestInfrastructure/SqLiteGrateTestContext.cs b/grate.unittests/TestInfrastructure/SqLiteGrateTestContext.cs index 491f8d21..2daa24e8 100644 --- a/grate.unittests/TestInfrastructure/SqLiteGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/SqLiteGrateTestContext.cs @@ -15,6 +15,7 @@ class SqliteGrateTestContext : TestContextBase, IGrateTestContext public string AdminConnectionString => $"Data Source=grate-sqlite.db"; public string ConnectionString(string database) => $"Data Source={database}.db"; + public string UserConnectionString(string database) => $"Data Source={database}.db"; public DbConnection GetDbConnection(string connectionString) => new SqliteConnection(connectionString); @@ -36,4 +37,4 @@ class SqliteGrateTestContext : TestContextBase, IGrateTestContext public string ExpectedVersionPrefix => "3.32.3"; public bool SupportsCreateDatabase => false; -} \ No newline at end of file +} diff --git a/grate.unittests/TestInfrastructure/SqlServerGrateTestContext.cs b/grate.unittests/TestInfrastructure/SqlServerGrateTestContext.cs index ee04924c..518c102e 100644 --- a/grate.unittests/TestInfrastructure/SqlServerGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/SqlServerGrateTestContext.cs @@ -29,6 +29,7 @@ public string DockerCommand(string serverName, string adminPassword) => public string AdminConnectionString => $"Data Source=localhost,{Port};Initial Catalog=master;User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false"; public string ConnectionString(string database) => $"Data Source=localhost,{Port};Initial Catalog={database};User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false"; + public string UserConnectionString(string database) => $"Data Source=localhost,{Port};Initial Catalog={database};User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false"; public DbConnection GetDbConnection(string connectionString) => new SqlConnection(connectionString); diff --git a/grate.unittests/TestInfrastructure/SqlStatements.cs b/grate.unittests/TestInfrastructure/SqlStatements.cs index 83718d53..f5716658 100644 --- a/grate.unittests/TestInfrastructure/SqlStatements.cs +++ b/grate.unittests/TestInfrastructure/SqlStatements.cs @@ -4,4 +4,6 @@ public record SqlStatements { public string SelectVersion { get; init; } = default!; public string SleepTwoSeconds { get; init; } = default!; + public string CreateUser { get; init; } = default!; + public string GrantAccess { get; init; } = default!; } diff --git a/grate.unittests/TestInfrastructure/TestConfig.cs b/grate.unittests/TestInfrastructure/TestConfig.cs index 22ff3da1..c7b76cc1 100644 --- a/grate.unittests/TestInfrastructure/TestConfig.cs +++ b/grate.unittests/TestInfrastructure/TestConfig.cs @@ -27,9 +27,13 @@ public static DirectoryInfo CreateRandomTempDirectory() var scriptsDir = Directory.CreateDirectory(dummyFile); return scriptsDir; } + + public static string? Username(string connectionString) => connectionString.Split(";", TrimEntries | RemoveEmptyEntries) + .SingleOrDefault(entry => entry.StartsWith("Uid"))? + .Split("=", TrimEntries | RemoveEmptyEntries).Last(); public static string? Password(string connectionString) => connectionString.Split(";", TrimEntries | RemoveEmptyEntries) - .SingleOrDefault(entry => entry.StartsWith("Password"))? + .SingleOrDefault(entry => entry.StartsWith("Password") || entry.StartsWith("Pwd"))? .Split("=", TrimEntries | RemoveEmptyEntries).Last(); private static LogLevel GetLogLevel() diff --git a/grate/Infrastructure/MariaDbSyntax.cs b/grate/Infrastructure/MariaDbSyntax.cs index 57d519f1..1ba7a0b8 100644 --- a/grate/Infrastructure/MariaDbSyntax.cs +++ b/grate/Infrastructure/MariaDbSyntax.cs @@ -30,4 +30,4 @@ public string StatementSeparatorRegex public string Quote(string text) => $"`{text}`"; public string PrimaryKeyConstraint(string tableName, string column) => $",\nCONSTRAINT PK_{tableName}_{column} PRIMARY KEY ({column})"; public string LimitN(string sql, int n) => sql + "\nLIMIT 1"; -} \ No newline at end of file +} diff --git a/grate/Migration/AnsiSqlDatabase.cs b/grate/Migration/AnsiSqlDatabase.cs index ef552352..8ff60797 100644 --- a/grate/Migration/AnsiSqlDatabase.cs +++ b/grate/Migration/AnsiSqlDatabase.cs @@ -198,19 +198,10 @@ public virtual async Task DropDatabase() /// public virtual async Task DatabaseExists() { - var sql = _syntax.ListDatabases; - try { - var databases = (await Connection.QueryAsync(sql)).ToArray(); - - Logger.LogTrace("Current databases: "); - foreach (var db in databases) - { - Logger.LogTrace(" * {Database}", db); - } - - return databases.Contains(DatabaseName); + await OpenActiveConnection(); + return true; } catch (DbException e) { diff --git a/grate/Migration/MariaDbDatabase.cs b/grate/Migration/MariaDbDatabase.cs index eb54d121..9091573a 100644 --- a/grate/Migration/MariaDbDatabase.cs +++ b/grate/Migration/MariaDbDatabase.cs @@ -1,4 +1,6 @@ -using System.Data.Common; +using System; +using System.Data.Common; +using System.Diagnostics; using System.Threading.Tasks; using grate.Infrastructure; using Microsoft.Extensions.Logging; @@ -20,4 +22,28 @@ public override Task RestoreDatabase(string backupPath) { throw new System.NotImplementedException("Restoring a database from file is not currently supported for Maria DB."); } -} \ No newline at end of file + + public override async Task DropDatabase() + { + // Drop the database in normal fashion + await base.DropDatabase(); + + // We need to kill any active connections to get MariaDB to actually delete the database, + // and stop accepting new connections to it. So we create a list of the + // active sessions against our databse, and create 'KILL X' statements (where X is session id). + // Then we execute the kill statements. + var sql = $@" +SELECT GROUP_CONCAT(CONCAT('KILL ',id,';') SEPARATOR ' ') +FROM information_schema.processlist WHERE DB = '{DatabaseName}'"; + + var killStatements = await ExecuteScalarAsync(AdminConnection, sql); + if (killStatements != null && !DBNull.Value.Equals(killStatements)) + { + string killSql = killStatements.ToString() ?? ""; // Just to keel warnings happy + await ExecuteNonQuery(AdminConnection, killSql, null); + } + + var databaseExists = await DatabaseExists(); + Debug.Assert(!databaseExists, "Database still exists after it is dropped"); + } +} diff --git a/grate/Migration/SqlServerDatabase.cs b/grate/Migration/SqlServerDatabase.cs index 0561dfdb..68e33a44 100644 --- a/grate/Migration/SqlServerDatabase.cs +++ b/grate/Migration/SqlServerDatabase.cs @@ -1,16 +1,12 @@ using System; using System.Data.Common; -using System.Linq; using System.Threading.Tasks; using System.Transactions; -using Dapper; using grate.Configuration; using grate.Infrastructure; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Logging; -using static System.Data.CommandType; - namespace grate.Migration; public class SqlServerDatabase : AnsiSqlDatabase @@ -73,24 +69,6 @@ WITH NOUNLOAD Logger.LogInformation("Database {DbName} successfully restored from path {Path}.", DatabaseName, backupPath); } - public override async Task DatabaseExists() - { - // For Bug #167. Sql Server is causing issues when the database name passed in differs only in case from one already existing on the server. - // There's currently no point adding to ISyntax for this as all the other DBMS's would just be a NOP. - - // This should also mean that a SQL server running a Case Sensitive collation _also_ works as expected - var sql = $"select name from sys.databases where [name] = '{DatabaseName}'"; - try - { - var results = await ActiveConnection.QueryAsync(sql, commandType: Text); - return results.Any(); - } - catch (DbException ex) - { - Logger.LogDebug(ex, "An unexpected error occurred performing the CheckDatabaseExists check: {ErrorMessage}", ex.Message); - return false; // base method also returns false on any DbException - } - } protected override string HasRunSql => $@"