Skip to content

Commit

Permalink
Allow Postgres tests to be run on a different database (#329)
Browse files Browse the repository at this point in the history
* Allow Postgres tests to be run on a different database

Not everyone has a "scratch" PostgreSQL running on localhost:5432 for refinery to scribble all over.
Now you can specify an arbitrary PostgreSQL server to work on with the `DB_URI` environment variable (which appears to be what `refinery-cli` already uses) to test in.

* Improve DB reset process

* Use a more appropriate name for the function that does the work
* Clean just the `public` schema, rather than drop/create the whole DB.  This means that running the tests no longer requires superuser privs, and that we don't have to temporarily hide out in `template1`.
* Drop the `catch_unwind`, because it's needed any more.

---------

Co-authored-by: João Oliveira <[email protected]>
  • Loading branch information
mpalmer and jxs authored May 21, 2024
1 parent a2d6a61 commit 74d066f
Showing 1 changed file with 38 additions and 69 deletions.
107 changes: 38 additions & 69 deletions refinery/tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ mod postgres {
use assert_cmd::prelude::*;
use predicates::str::contains;
use refinery::{
config::{Config, ConfigDbType},
embed_migrations,
error::Kind,
Migrate, Migration, Runner, Target,
config::Config, embed_migrations, error::Kind, Migrate, Migration, Runner, Target,
};
use refinery_core::postgres::{Client, NoTls};
use std::process::Command;
use std::str::FromStr;
use time::OffsetDateTime;

const DEFAULT_TABLE_NAME: &str = "refinery_schema_history";
Expand All @@ -31,6 +29,10 @@ mod postgres {
embed_migrations!("./tests/migrations_missing");
}

fn db_uri() -> String {
std::env::var("DB_URI").unwrap_or("postgres://postgres@localhost:5432/postgres".to_string())
}

fn get_migrations() -> Vec<Migration> {
embed_migrations!("./tests/migrations");

Expand Down Expand Up @@ -64,36 +66,32 @@ mod postgres {
vec![migration1, migration2, migration3, migration4, migration5]
}

fn clean_database() {
let mut client =
Client::connect("postgres://postgres@localhost:5432/template1", NoTls).unwrap();
fn prep_database() {
let uri = db_uri();

let mut client = Client::connect(&db_uri(), NoTls).unwrap();

client
.execute(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname='postgres'",
&[],
)
.execute("DROP SCHEMA IF EXISTS public CASCADE", &[])
.unwrap();
client
.execute("CREATE SCHEMA IF NOT EXISTS public", &[])
.unwrap();
client.execute("DROP DATABASE POSTGRES", &[]).unwrap();
client.execute("CREATE DATABASE POSTGRES", &[]).unwrap();
}

fn run_test<T>(test: T)
where
T: FnOnce() + std::panic::UnwindSafe,
T: FnOnce(),
{
let result = std::panic::catch_unwind(test);

clean_database();
prep_database();

assert!(result.is_ok())
test();
}

#[test]
fn report_contains_applied_migrations() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let report = embedded::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -122,8 +120,7 @@ mod postgres {
#[test]
fn creates_migration_table() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();
embedded::migrations::runner().run(&mut client).unwrap();
for row in &client
.query(
Expand All @@ -144,8 +141,7 @@ mod postgres {
#[test]
fn creates_migration_table_grouped_transaction() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner()
.set_grouped(true)
Expand All @@ -171,8 +167,7 @@ mod postgres {
#[test]
fn applies_migration() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();
embedded::migrations::runner().run(&mut client).unwrap();
client
.execute(
Expand All @@ -192,8 +187,7 @@ mod postgres {
#[test]
fn applies_migration_grouped_transaction() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner()
.set_grouped(false)
Expand All @@ -218,8 +212,7 @@ mod postgres {
#[test]
fn updates_schema_history() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner().run(&mut client).unwrap();

Expand All @@ -239,8 +232,7 @@ mod postgres {
#[test]
fn updates_schema_history_grouped_transaction() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner()
.set_grouped(false)
Expand All @@ -262,8 +254,7 @@ mod postgres {
#[test]
fn updates_to_last_working_if_not_grouped() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let result = broken::migrations::runner().run(&mut client);

Expand Down Expand Up @@ -300,8 +291,7 @@ mod postgres {
#[test]
fn doesnt_update_to_last_working_if_grouped() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let result = broken::migrations::runner()
.set_grouped(true)
Expand All @@ -320,8 +310,7 @@ mod postgres {
#[test]
fn gets_applied_migrations() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -349,8 +338,7 @@ mod postgres {
#[test]
fn applies_new_migration() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -381,8 +369,7 @@ mod postgres {
#[test]
fn migrates_to_target_migration() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let report = embedded::migrations::runner()
.set_target(Target::Version(3))
Expand Down Expand Up @@ -417,8 +404,7 @@ mod postgres {
#[test]
fn migrates_to_target_migration_grouped() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let report = embedded::migrations::runner()
.set_target(Target::Version(3))
Expand Down Expand Up @@ -454,8 +440,7 @@ mod postgres {
#[test]
fn aborts_on_missing_migration_on_filesystem() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -488,8 +473,7 @@ mod postgres {
#[test]
fn aborts_on_divergent_migration() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

embedded::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -523,8 +507,7 @@ mod postgres {
#[test]
fn aborts_on_missing_migration_on_database() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

missing::migrations::runner().run(&mut client).unwrap();

Expand Down Expand Up @@ -568,11 +551,7 @@ mod postgres {
#[test]
fn migrates_from_config() {
run_test(|| {
let mut config = Config::new(ConfigDbType::Postgres)
.set_db_name("postgres")
.set_db_user("postgres")
.set_db_host("localhost")
.set_db_port("5432");
let mut config = Config::from_str(&db_uri()).unwrap();

let migrations = get_migrations();
let runner = Runner::new(&migrations)
Expand Down Expand Up @@ -608,11 +587,7 @@ mod postgres {
#[test]
fn migrate_from_config_report_contains_migrations() {
run_test(|| {
let mut config = Config::new(ConfigDbType::Postgres)
.set_db_name("postgres")
.set_db_user("postgres")
.set_db_host("localhost")
.set_db_port("5432");
let mut config = Config::from_str(&db_uri()).unwrap();

let migrations = get_migrations();
let runner = Runner::new(&migrations)
Expand Down Expand Up @@ -648,11 +623,7 @@ mod postgres {
#[test]
fn migrate_from_config_report_returns_last_applied_migration() {
run_test(|| {
let mut config = Config::new(ConfigDbType::Postgres)
.set_db_name("postgres")
.set_db_user("postgres")
.set_db_host("localhost")
.set_db_port("5432");
let mut config = Config::from_str(&db_uri()).unwrap();

let migrations = get_migrations();
let runner = Runner::new(&migrations)
Expand All @@ -677,8 +648,7 @@ mod postgres {
#[test]
fn doesnt_run_migrations_if_fake() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let report = embedded::migrations::runner()
.set_target(Target::Fake)
Expand Down Expand Up @@ -712,8 +682,7 @@ mod postgres {
#[test]
fn doesnt_run_migrations_if_fake_version() {
run_test(|| {
let mut client =
Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap();
let mut client = Client::connect(&db_uri(), NoTls).unwrap();

let report = embedded::migrations::runner()
.set_target(Target::FakeVersion(2))
Expand Down

0 comments on commit 74d066f

Please sign in to comment.