Skip to content

Commit

Permalink
Simplify check to see whether database exists (#227)
Browse files Browse the repository at this point in the history
* Solves #211 without the need for the --dontcreatedatabase command-line option
* Simplified checking if a database exists to just try to connect, and if success -> exists. If not, doesn't exist

Chores:
* Explicitly kill MariaDB connections against database to make sure it's dropped
* Removed override for SQL server for 'check if database exists', as it was now equal to the base method
* Small cleanups on warnings
  • Loading branch information
erikbra authored Sep 2, 2022
1 parent 68ebaa6 commit 9532190
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 55 deletions.
57 changes: 44 additions & 13 deletions grate.unittests/Generic/GenericDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,37 @@ public async Task Does_not_error_if_confed_to_create_but_already_exists()
var db = "DAATAA";

// Create the database manually before running the migration
await CreateDatabase(db);
await CreateDatabaseFromConnectionString(db, Context.UserConnectionString(db));

// Check that the database has been created
IEnumerable<string> databasesBeforeMigration = await GetDatabases();
databasesBeforeMigration.Should().Contain(db);

await using var migrator = GetMigrator(GetConfiguration(db, true));

var config = GetConfiguration(true, Context.UserConnectionString(db), Context.AdminConnectionString);
await using var migrator = GetMigrator(config);

// There should be no errors running the migration
Assert.DoesNotThrowAsync(() => migrator.Migrate());
}

[TestCase("Invalid stuff")]
[TestCase(null)]
[Test]
public async Task Does_not_need_admin_connection_if_database_already_exists()
public async Task Does_not_need_admin_connection_if_database_already_exists(string adminConnectionString)
{
var db = "DATADATADATABASE";

// Create the database manually before running the migration
await CreateDatabase(db);
await CreateDatabaseFromConnectionString(db, Context.UserConnectionString(db));

// Check that the database has been created
IEnumerable<string> databasesBeforeMigration = await GetDatabases();
databasesBeforeMigration.Should().Contain(db);

// Change the admin connection string to rubbish and run the migration
await using var migrator = GetMigrator(GetConfiguration(db, true, "Invalid stuff"));

var config = GetConfiguration(true, Context.UserConnectionString(db), adminConnectionString);
await using var migrator = GetMigrator(config);

// There should be no errors running the migration
Assert.DoesNotThrowAsync(() => migrator.Migrate());
}
Expand All @@ -105,8 +109,13 @@ public async Task Does_not_needlessly_apply_case_sensitive_database_name_checks_
Assert.DoesNotThrowAsync(() => migrator.Migrate());
}

protected virtual async Task CreateDatabase(string db)
protected Task CreateDatabase(string db) => CreateDatabaseFromConnectionString(db, Context.ConnectionString(db));

protected virtual async Task CreateDatabaseFromConnectionString(string db, string connectionString)
{
var uid = TestConfig.Username(connectionString);
var pwd = TestConfig.Password(connectionString);

using (new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled))
{
for (var i = 0; i < 5; i++)
Expand All @@ -116,15 +125,29 @@ protected virtual async Task CreateDatabase(string db)
await using var conn = Context.CreateAdminDbConnection();
await conn.OpenAsync();
await using var cmd = conn.CreateCommand();
cmd.CommandText = Context.Syntax.CreateDatabase(db, TestConfig.Password(Context.ConnectionString(db)));

cmd.CommandText = Context.Syntax.CreateDatabase(db, pwd);
await cmd.ExecuteNonQueryAsync();

if (!string.IsNullOrWhiteSpace(Context.Sql.CreateUser))
{
cmd.CommandText = string.Format(Context.Sql.CreateUser, uid, pwd);
await cmd.ExecuteNonQueryAsync();
}

if (!string.IsNullOrWhiteSpace(Context.Sql.GrantAccess))
{
cmd.CommandText = string.Format(Context.Sql.GrantAccess, db, uid);
await cmd.ExecuteNonQueryAsync();
}

break;
}
catch (DbException) { }
}
}
}

protected virtual async Task<IEnumerable<string>> GetDatabases()
{
IEnumerable<string> databases =Enumerable.Empty<string>();
Expand All @@ -150,20 +173,28 @@ protected virtual async Task<IEnumerable<string>> GetDatabases()


private GrateMigrator GetMigrator(GrateConfiguration config) => Context.GetMigrator(config);

private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase)
=> GetConfiguration(databaseName, createDatabase, Context.AdminConnectionString);


private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase, string? adminConnectionString = null)
private GrateConfiguration GetConfiguration(string databaseName, bool createDatabase, string? adminConnectionString)
=> GetConfiguration(createDatabase, Context.ConnectionString(databaseName), adminConnectionString);


private GrateConfiguration GetConfiguration(bool createDatabase, string? connectionString, string? adminConnectionString)
{
var parent = TestConfig.CreateRandomTempDirectory();
return new()
{
CreateDatabase = createDatabase,
ConnectionString = Context.ConnectionString(databaseName),
AdminConnectionString = adminConnectionString ?? Context.AdminConnectionString,
ConnectionString = connectionString,
AdminConnectionString = adminConnectionString,
Folders = FoldersConfiguration.Default(null),
NonInteractive = true,
DatabaseType = Context.DatabaseType,
SqlFilesDirectory = parent
};
}

}
13 changes: 10 additions & 3 deletions grate.unittests/SqLite/Database.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,21 @@ public class Database: Generic.GenericDatabase
{
protected override IGrateTestContext Context => GrateTestContext.Sqlite;

protected override async Task CreateDatabase(string db)
protected override async Task CreateDatabaseFromConnectionString(string db, string connectionString)
{
await using var conn = new SqliteConnection(Context.ConnectionString(db));
await using var conn = new SqliteConnection(connectionString);
conn.Open();
await using var cmd = conn.CreateCommand();

// Create a table to actually create the .sqlite file
var sql = "CREATE TABLE dummy(name VARCHAR(1))";
cmd.CommandText = sql;
await cmd.ExecuteNonQueryAsync();

// Remove the table to avoid polluting the database with dummy tables :)
sql = "DROP TABLE dummy";
cmd.CommandText = sql;
await cmd.ExecuteNonQueryAsync();
}

protected override async Task<IEnumerable<string>> GetDatabases()
Expand All @@ -36,4 +43,4 @@ protected override async Task<IEnumerable<string>> GetDatabases()
}

protected override bool ThrowOnMissingDatabase => false;
}
}
1 change: 1 addition & 0 deletions grate.unittests/TestInfrastructure/IGrateTestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public interface IGrateTestContext

string AdminConnectionString { get; }
string ConnectionString(string database);
string UserConnectionString(string database);

DbConnection GetDbConnection(string connectionString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public string DockerCommand(string serverName, string adminPassword) =>

public string AdminConnectionString => $"Server=localhost;Port={Port};Database=mysql;Uid=root;Pwd={AdminPassword}";
public string ConnectionString(string database) => $"Server=localhost;Port={Port};Database={database};Uid=root;Pwd={AdminPassword}";
public string UserConnectionString(string database) => $"Server=localhost;Port={Port};Database={database};Uid={database};Pwd=mooo1213";

public DbConnection GetDbConnection(string connectionString) => new MySqlConnection(connectionString);

Expand All @@ -34,7 +35,9 @@ public string DockerCommand(string serverName, string adminPassword) =>
public SqlStatements Sql => new()
{
SelectVersion = "SELECT VERSION()",
SleepTwoSeconds = "SELECT SLEEP(2);"
SleepTwoSeconds = "SELECT SLEEP(2);",
CreateUser = "CREATE USER '{0}'@'%' IDENTIFIED BY '{1}';",
GrantAccess = "GRANT SELECT, INSERT, UPDATE, DELETE, CREATE, INDEX, DROP, ALTER, CREATE TEMPORARY TABLES, LOCK TABLES ON {0}.* TO '{1}'@'%';FLUSH PRIVILEGES;"
};


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public string DockerCommand(string serverName, string adminPassword) =>

public string AdminConnectionString => $@"Data Source=localhost:{Port}/XE;User ID=SYSTEM;Password=oracle;Pooling=False";
public string ConnectionString(string database) => $@"Data Source=localhost:{Port}/XE;User ID={database.ToUpper()};Password=oracle;Pooling=False";
public string UserConnectionString(string database) => $@"Data Source=localhost:{Port}/XE;User ID={database.ToUpper()};Password=oracle;Pooling=False";

public DbConnection GetDbConnection(string connectionString) => new OracleConnection(connectionString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public string DockerCommand(string serverName, string adminPassword) =>

public string AdminConnectionString => $"Host=localhost;Port={Port};Database=postgres;Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false";
public string ConnectionString(string database) => $"Host=localhost;Port={Port};Database={database};Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false";
public string UserConnectionString(string database) => $"Host=localhost;Port={Port};Database={database};Username=postgres;Password={AdminPassword};Include Error Detail=true;Pooling=false";

public DbConnection GetDbConnection(string connectionString) => new NpgsqlConnection(connectionString);

Expand Down
3 changes: 2 additions & 1 deletion grate.unittests/TestInfrastructure/SqLiteGrateTestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SqliteGrateTestContext : TestContextBase, IGrateTestContext

public string AdminConnectionString => $"Data Source=grate-sqlite.db";
public string ConnectionString(string database) => $"Data Source={database}.db";
public string UserConnectionString(string database) => $"Data Source={database}.db";

public DbConnection GetDbConnection(string connectionString) => new SqliteConnection(connectionString);

Expand All @@ -36,4 +37,4 @@ class SqliteGrateTestContext : TestContextBase, IGrateTestContext

public string ExpectedVersionPrefix => "3.32.3";
public bool SupportsCreateDatabase => false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public string DockerCommand(string serverName, string adminPassword) =>

public string AdminConnectionString => $"Data Source=localhost,{Port};Initial Catalog=master;User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false";
public string ConnectionString(string database) => $"Data Source=localhost,{Port};Initial Catalog={database};User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false";
public string UserConnectionString(string database) => $"Data Source=localhost,{Port};Initial Catalog={database};User Id=sa;Password={AdminPassword};Encrypt=false;Pooling=false";

public DbConnection GetDbConnection(string connectionString) => new SqlConnection(connectionString);

Expand Down
2 changes: 2 additions & 0 deletions grate.unittests/TestInfrastructure/SqlStatements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ public record SqlStatements
{
public string SelectVersion { get; init; } = default!;
public string SleepTwoSeconds { get; init; } = default!;
public string CreateUser { get; init; } = default!;
public string GrantAccess { get; init; } = default!;
}
6 changes: 5 additions & 1 deletion grate.unittests/TestInfrastructure/TestConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ public static DirectoryInfo CreateRandomTempDirectory()
var scriptsDir = Directory.CreateDirectory(dummyFile);
return scriptsDir;
}

public static string? Username(string connectionString) => connectionString.Split(";", TrimEntries | RemoveEmptyEntries)
.SingleOrDefault(entry => entry.StartsWith("Uid"))?
.Split("=", TrimEntries | RemoveEmptyEntries).Last();

public static string? Password(string connectionString) => connectionString.Split(";", TrimEntries | RemoveEmptyEntries)
.SingleOrDefault(entry => entry.StartsWith("Password"))?
.SingleOrDefault(entry => entry.StartsWith("Password") || entry.StartsWith("Pwd"))?
.Split("=", TrimEntries | RemoveEmptyEntries).Last();

private static LogLevel GetLogLevel()
Expand Down
2 changes: 1 addition & 1 deletion grate/Infrastructure/MariaDbSyntax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ public string StatementSeparatorRegex
public string Quote(string text) => $"`{text}`";
public string PrimaryKeyConstraint(string tableName, string column) => $",\nCONSTRAINT PK_{tableName}_{column} PRIMARY KEY ({column})";
public string LimitN(string sql, int n) => sql + "\nLIMIT 1";
}
}
13 changes: 2 additions & 11 deletions grate/Migration/AnsiSqlDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,10 @@ public virtual async Task DropDatabase()
/// <returns></returns>
public virtual async Task<bool> DatabaseExists()
{
var sql = _syntax.ListDatabases;

try
{
var databases = (await Connection.QueryAsync<string>(sql)).ToArray();

Logger.LogTrace("Current databases: ");
foreach (var db in databases)
{
Logger.LogTrace(" * {Database}", db);
}

return databases.Contains(DatabaseName);
await OpenActiveConnection();
return true;
}
catch (DbException e)
{
Expand Down
30 changes: 28 additions & 2 deletions grate/Migration/MariaDbDatabase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Data.Common;
using System;
using System.Data.Common;
using System.Diagnostics;
using System.Threading.Tasks;
using grate.Infrastructure;
using Microsoft.Extensions.Logging;
Expand All @@ -20,4 +22,28 @@ public override Task RestoreDatabase(string backupPath)
{
throw new System.NotImplementedException("Restoring a database from file is not currently supported for Maria DB.");
}
}

public override async Task DropDatabase()
{
// Drop the database in normal fashion
await base.DropDatabase();

// We need to kill any active connections to get MariaDB to actually delete the database,
// and stop accepting new connections to it. So we create a list of the
// active sessions against our databse, and create 'KILL X' statements (where X is session id).
// Then we execute the kill statements.
var sql = $@"
SELECT GROUP_CONCAT(CONCAT('KILL ',id,';') SEPARATOR ' ')
FROM information_schema.processlist WHERE DB = '{DatabaseName}'";

var killStatements = await ExecuteScalarAsync<object>(AdminConnection, sql);
if (killStatements != null && !DBNull.Value.Equals(killStatements))
{
string killSql = killStatements.ToString() ?? ""; // Just to keel warnings happy
await ExecuteNonQuery(AdminConnection, killSql, null);
}

var databaseExists = await DatabaseExists();
Debug.Assert(!databaseExists, "Database still exists after it is dropped");
}
}
22 changes: 0 additions & 22 deletions grate/Migration/SqlServerDatabase.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
using System;
using System.Data.Common;
using System.Linq;
using System.Threading.Tasks;
using System.Transactions;
using Dapper;
using grate.Configuration;
using grate.Infrastructure;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;

using static System.Data.CommandType;

namespace grate.Migration;

public class SqlServerDatabase : AnsiSqlDatabase
Expand Down Expand Up @@ -73,24 +69,6 @@ WITH NOUNLOAD
Logger.LogInformation("Database {DbName} successfully restored from path {Path}.", DatabaseName, backupPath);
}

public override async Task<bool> DatabaseExists()
{
// For Bug #167. Sql Server is causing issues when the database name passed in differs only in case from one already existing on the server.
// There's currently no point adding to ISyntax for this as all the other DBMS's would just be a NOP.

// This should also mean that a SQL server running a Case Sensitive collation _also_ works as expected
var sql = $"select name from sys.databases where [name] = '{DatabaseName}'";
try
{
var results = await ActiveConnection.QueryAsync<string>(sql, commandType: Text);
return results.Any();
}
catch (DbException ex)
{
Logger.LogDebug(ex, "An unexpected error occurred performing the CheckDatabaseExists check: {ErrorMessage}", ex.Message);
return false; // base method also returns false on any DbException
}
}

protected override string HasRunSql =>
$@"
Expand Down

0 comments on commit 9532190

Please sign in to comment.