From 19f40d87a669e41081a629a469359d341c9223d6 Mon Sep 17 00:00:00 2001 From: Andrei Nesterov Date: Thu, 3 Oct 2024 01:32:30 +0300 Subject: [PATCH] Obey no-transaction flag in down migrations (#3528) --- sqlx-postgres/src/migrate.rs | 40 +++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index da3080581e..c37e92f4d6 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -252,20 +252,18 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( migration: &'m Migration, ) -> BoxFuture<'m, Result> { Box::pin(async move { - // Use a single transaction for the actual migration script and the essential bookeeping so we never - // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. - let mut tx = self.begin().await?; let start = Instant::now(); - let _ = tx.execute(&*migration.sql).await?; - - // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) - .bind(migration.version) - .execute(&mut *tx) - .await?; - - tx.commit().await?; + // execute migration queries + if migration.no_tx { + revert_migration(self, migration).await?; + } else { + // Use a single transaction for the actual migration script and the essential bookeeping so we never + // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. + let mut tx = self.begin().await?; + revert_migration(&mut tx, migration).await?; + tx.commit().await?; + } let elapsed = start.elapsed(); @@ -299,6 +297,24 @@ async fn execute_migration( Ok(()) } +async fn revert_migration( + conn: &mut PgConnection, + migration: &Migration, +) -> Result<(), MigrateError> { + let _ = conn + .execute(&*migration.sql) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + // language=SQL + let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) + .bind(migration.version) + .execute(conn) + .await?; + + Ok(()) +} + async fn current_database(conn: &mut PgConnection) -> Result { // language=SQL Ok(query_scalar("SELECT current_database()")