Skip to content

Commit

Permalink
Merge pull request #206 from wokket/fix/#203-command-timeout
Browse files Browse the repository at this point in the history
Fix #203:  Missing support for (Admin)CommandTimeout
  • Loading branch information
wokket authored Jul 6, 2022
2 parents f077344 + fd3443d commit 06db239
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Data.Common;
using System;
using System.Data.Common;
using System.Linq;
using System.Threading.Tasks;
using System.Transactions;
Expand Down Expand Up @@ -67,6 +68,64 @@ public async Task Are_Inserted_Into_ScriptRunErrors_Table()
scripts.Should().HaveCount(1);
}

[Test]
public void Ensure_Command_Timeout_Fires()
{
var sql = Context.Sql.SleepTwoSeconds;

if (sql == default)
{
Assert.Ignore("DBMS doesn't support sleep() for testing");
}

var db = TestConfig.RandomDatabase();
var knownFolders = KnownFolders.In(CreateRandomTempDirectory());
var path = MakeSurePathExists(knownFolders.Up);
WriteSql(path, "goodnight.sql", sql);

// run it with a timeout shorter than the 1 second sleep, should timeout
var config = Context.GetConfiguration(db, knownFolders) with
{
CommandTimeout = 1, // shorter than the script runs for
};

Assert.CatchAsync(async () =>
{
await using var migrator = Context.GetMigrator(config);
await migrator.Migrate();
Assert.Fail("Should have thrown a timeout exception prior to this!");
});
}

[Test]
public void Ensure_AdminCommand_Timeout_Fires()
{
var sql = Context.Sql.SleepTwoSeconds;

if (sql == default)
{
Assert.Ignore("DBMS doesn't support sleep() for testing");
}

var db = TestConfig.RandomDatabase();
var knownFolders = KnownFolders.In(CreateRandomTempDirectory());
var path = MakeSurePathExists(knownFolders.AlterDatabase); //so it's run on the admin connection
WriteSql(path, "goodnight.sql", sql);

// run it with a timeout shorter than the 1 second sleep, should timeout
var config = Context.GetConfiguration(db, knownFolders) with
{
AdminCommandTimeout = 1, // shorter than the script runs for
};

Assert.CatchAsync(async () =>
{
await using var migrator = Context.GetMigrator(config);
await migrator.Migrate();
Assert.Fail("Should have thrown a timeout exception prior to this!");
});
}

// This does not work for MySql/MariaDB, as it does not support DDL transactions
// [Test]
// public async Task Makes_Whole_Transaction_Rollback()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ public string DockerCommand(string serverName, string adminPassword) =>
public SqlStatements Sql => new()
{
SelectVersion = "SELECT VERSION()",
SleepTwoSeconds = "SELECT SLEEP(2);"
};


public string ExpectedVersionPrefix => "10.5.9-MariaDB";
public bool SupportsCreateDatabase => true;
}
}
3 changes: 2 additions & 1 deletion grate.unittests/TestInfrastructure/OracleGrateTestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ public string DockerCommand(string serverName, string adminPassword) =>
public SqlStatements Sql => new()
{
SelectVersion = "SELECT * FROM v$version WHERE banner LIKE 'Oracle%'",
SleepTwoSeconds = "sys.dbms_session.sleep(2);"
};

public string ExpectedVersionPrefix => "Oracle Database 11g Express Edition Release 11.2.0.2.0 - 64bit Production";
public bool SupportsCreateDatabase => true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ public string DockerCommand(string serverName, string adminPassword) =>
public SqlStatements Sql => new()
{
SelectVersion = "SELECT version()",
SleepTwoSeconds = "SELECT pg_sleep(2);"
};


public string ExpectedVersionPrefix => "PostgreSQL 14.";
public bool SupportsCreateDatabase => true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public string DockerCommand(string serverName, string adminPassword) =>
public SqlStatements Sql => new()
{
SelectVersion = "SELECT @@VERSION",
SleepTwoSeconds = "WAITFOR DELAY '00:00:02'"
};

public string ExpectedVersionPrefix => "Microsoft SQL Server 2019";
Expand Down
3 changes: 2 additions & 1 deletion grate.unittests/TestInfrastructure/SqlStatements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
public record SqlStatements
{
public string SelectVersion { get; init; } = default!;
}
public string SleepTwoSeconds { get; init; } = default!;
}
67 changes: 38 additions & 29 deletions grate/Migration/AnsiSqlDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace grate.Migration;
public abstract class AnsiSqlDatabase : IDatabase
{
private string SchemaName { get; set; } = "";

protected GrateConfiguration? Config { get; private set; }

protected ILogger Logger { get; }
// ReSharper disable once InconsistentNaming
protected DbConnection? _connection;
Expand Down Expand Up @@ -55,7 +58,7 @@ public virtual Task InitializeConnections(GrateConfiguration configuration)
ConnectionString = configuration.ConnectionString;
AdminConnectionString = configuration.AdminConnectionString;
SchemaName = configuration.SchemaName;

Config = configuration;
return Task.CompletedTask;
}

Expand All @@ -80,11 +83,11 @@ public async Task CreateDatabase()
if (!await DatabaseExists())
{
Logger.LogTrace("Creating database {DatabaseName}", DatabaseName);

using var s = new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled);
var sql = _syntax.CreateDatabase(DatabaseName, Password);

await ExecuteNonQuery(AdminConnection, sql);
await ExecuteNonQuery(AdminConnection, sql, Config?.AdminCommandTimeout);
s.Complete();
}

Expand All @@ -99,7 +102,7 @@ public virtual async Task DropDatabase()
using var s = new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled);
await CloseConnection(); // try and ensure there's nobody else in there...
await OpenAdminConnection();
await ExecuteNonQuery(AdminConnection, _syntax.DropDatabase(DatabaseName));
await ExecuteNonQuery(AdminConnection, _syntax.DropDatabase(DatabaseName), Config?.AdminCommandTimeout);
s.Complete();
}
}
Expand All @@ -117,13 +120,13 @@ public virtual async Task<bool> DatabaseExists()
{
await OpenConnection();
var databases = (await Connection.QueryAsync<string>(sql)).ToArray();

Logger.LogTrace("Current databases: ");
foreach (var db in databases)
{
Logger.LogTrace(" * {Database}", db);
}

return databases.Contains(DatabaseName);
}
catch (DbException e)
Expand Down Expand Up @@ -175,7 +178,7 @@ private async Task CreateRunSchema()
{
if (SupportsSchemas && !await RunSchemaExists())
{
await ExecuteNonQuery(Connection, _syntax.CreateSchema(SchemaName));
await ExecuteNonQuery(Connection, _syntax.CreateSchema(SchemaName), Config?.CommandTimeout);
}
}

Expand All @@ -201,12 +204,12 @@ protected virtual async Task CreateScriptsRunTable()
entry_date {_syntax.TimestampType} NULL,
modified_date {_syntax.TimestampType} NULL,
entered_by {_syntax.VarcharType}(50) NULL
{_syntax.PrimaryKeyConstraint("ScriptsRun","id")}
{_syntax.PrimaryKeyConstraint("ScriptsRun", "id")}
)";

if (!await ScriptsRunTableExists())
{
await ExecuteNonQuery(Connection, createSql);
await ExecuteNonQuery(Connection, createSql, Config?.CommandTimeout);
}
}

Expand All @@ -224,11 +227,11 @@ protected virtual async Task CreateScriptsRunErrorsTable()
entry_date {_syntax.TimestampType} NULL,
modified_date {_syntax.TimestampType} NULL,
entered_by {_syntax.VarcharType}(50) NULL
{_syntax.PrimaryKeyConstraint("ScriptsRunErrors","id")}
{_syntax.PrimaryKeyConstraint("ScriptsRunErrors", "id")}
)";
if (!await ScriptsRunErrorsTableExists())
{
await ExecuteNonQuery(Connection, createSql);
await ExecuteNonQuery(Connection, createSql, Config?.CommandTimeout);
}
}

Expand All @@ -242,11 +245,11 @@ protected virtual async Task CreateVersionTable()
entry_date {_syntax.TimestampType} NULL,
modified_date {_syntax.TimestampType} NULL,
entered_by {_syntax.VarcharType}(50) NULL
{_syntax.PrimaryKeyConstraint("Version","id")}
{_syntax.PrimaryKeyConstraint("Version", "id")}
)";
if (!await VersionTableExists())
{
await ExecuteNonQuery(Connection, createSql);
await ExecuteNonQuery(Connection, createSql, Config?.CommandTimeout);
}
}

Expand All @@ -264,7 +267,7 @@ private async Task<bool> TableExists(string schemaName, string tableName)
var res = await ExecuteScalarAsync<object>(Connection, existsSql);
return !DBNull.Value.Equals(res) && res is not null;
}

protected virtual string ExistsSql(string tableSchema, string fullTableName)
{
return $@"
Expand All @@ -274,14 +277,14 @@ protected virtual string ExistsSql(string tableSchema, string fullTableName)
table_name = '{fullTableName}'
";
}

protected virtual string CurrentVersionSql => $@"
SELECT
{_syntax.LimitN($@"
version
FROM {VersionTable}
ORDER BY id DESC", 1)}
";
";

public async Task<string> GetCurrentVersion()
{
Expand Down Expand Up @@ -324,14 +327,14 @@ public async Task RunSql(string sql, ConnectionType connectionType)
{
Logger.LogTrace("[SQL] Running (on connection '{ConnType}'): \n{Sql}", connectionType.ToString(), sql);

var conn = connectionType switch
var (conn, timeout) = connectionType switch
{
ConnectionType.Default => Connection,
ConnectionType.Admin => AdminConnection,
ConnectionType.Default => (Connection, Config?.CommandTimeout),
ConnectionType.Admin => (AdminConnection, Config?.AdminCommandTimeout),
_ => throw new ArgumentOutOfRangeException(nameof(connectionType), connectionType, "Unknown connection type: " + connectionType)
};

await ExecuteNonQuery(conn, sql);
await ExecuteNonQuery(conn, sql, timeout);
}

// ReSharper disable once ClassNeverInstantiated.Local
Expand Down Expand Up @@ -376,7 +379,7 @@ private async Task<IDictionary<string, string>> GetAllScriptsRun()
SELECT text_hash FROM {ScriptsRunTable}
WHERE script_name = @scriptName");

var hash = await ExecuteScalarAsync<string?>(Connection, hashSql, new { scriptName });
var hash = await ExecuteScalarAsync<string?>(Connection, hashSql, new { scriptName });
return hash;
}

Expand Down Expand Up @@ -431,7 +434,7 @@ INSERT INTO {ScriptsRunErrorsTable}

var versionSql = Parameterize($"SELECT version FROM {VersionTable} WHERE id = @versionId");
var version = await ExecuteScalarAsync<string>(Connection, versionSql, new { versionId });

var scriptRunErrors = new
{
version,
Expand All @@ -456,7 +459,7 @@ private static async Task Close(DbConnection? conn)
await conn.CloseAsync();
}
}

protected virtual async Task Open(DbConnection? conn)
{
if (conn != null && conn.State != ConnectionState.Open)
Expand All @@ -470,33 +473,39 @@ protected virtual async Task Open(DbConnection? conn)
{
Logger.LogTrace("SQL: {Sql}", sql);
Logger.LogTrace("Parameters: {@Parameters}", parameters);

return await conn.ExecuteScalarAsync<T?>(sql, parameters);
}

protected async Task<int> ExecuteAsync(DbConnection conn, string sql, object? parameters = null)
{
Logger.LogTrace("SQL: {Sql}", sql);
Logger.LogTrace("Parameters: {@Parameters}", parameters);

return await conn.ExecuteAsync(sql, parameters);
}

protected async Task ExecuteNonQuery(DbConnection conn, string sql)
protected async Task ExecuteNonQuery(DbConnection conn, string sql, int? timeout)
{
Logger.LogTrace("SQL: {Sql}", sql);

await using var cmd = conn.CreateCommand();
cmd.CommandText = sql;
cmd.CommandType = CommandType.Text;

if (timeout.HasValue)
{
cmd.CommandTimeout = timeout.Value;
}

await cmd.ExecuteNonQueryAsync();
}

public async ValueTask DisposeAsync()
{
await CloseConnection();
await CloseAdminConnection();

GC.SuppressFinalize(this);
}

Expand Down
Loading

0 comments on commit 06db239

Please sign in to comment.