diff --git a/lib/postgrex/replication_connection.ex b/lib/postgrex/replication_connection.ex index 8fcde20c..523da8e4 100644 --- a/lib/postgrex/replication_connection.ex +++ b/lib/postgrex/replication_connection.ex @@ -527,8 +527,8 @@ defmodule Postgrex.ReplicationConnection do defp handle_data([], s), do: {:keep_state, s} defp handle_data([:copy_done | copies], %{state: {mod, mod_state}} = s) do - with {:keep_state, s} <- handle(mod, :handle_data, [:done, mod_state], nil, s) do - handle_data(copies, %{s | streaming: nil}) + with {:keep_state, s} <- handle(mod, :handle_data, [:done, mod_state], nil, %{s | streaming: nil}) do + handle_data(copies, s) end end diff --git a/test/replication_connection_test.exs b/test/replication_connection_test.exs index a013e32a..f862fbe4 100644 --- a/test/replication_connection_test.exs +++ b/test/replication_connection_test.exs @@ -50,6 +50,22 @@ defmodule ReplicationTest do {:noreply, [reply], pid} end + # This is part of the "stream_continuation" test and handles the COPY :done + # state. It will start another stream right away by starting the replication + # slot. + def handle_data(:done, %{pid: pid, test: "stream_continuation"}) do + send(pid, {:done, System.unique_integer([:monotonic])}) + query = "START_REPLICATION SLOT postgrex_test LOGICAL 0/0 (proto_version '1', publication_names 'postgrex_example')" + + {:stream, query, [], pid} + end + + # This is part of the "stream_continuation" test and handles the COPY results. + def handle_data(msg, %{pid: pid, test: "stream_continuation"} = s) do + send(pid, {msg, System.unique_integer([:monotonic])}) + {:noreply, [], s} + end + def handle_data(msg, pid) do send(pid, {msg, System.unique_integer([:monotonic])}) {:noreply, [], pid} @@ -80,6 +96,12 @@ defmodule ReplicationTest do {:query, query, {from, pid}} end + # This is part of the "stream_continuation" test and handles call that + # triggers that chain of events. + def handle_call({:query, query, %{test: "stream_continuation", next_query: _} = opts}, from, pid) do + {:query, query, Map.merge(opts, %{from: from, pid: pid})} + end + @impl true def handle_call({:disconnect, reason}, _, _) do {:disconnect, reason} @@ -97,6 +119,12 @@ defmodule ReplicationTest do {:noreply, pid} end + # Handles the result of the "stream_continuation" query call. It is the results of the slot creation. + def handle_result(results, %{from: from, test: "stream_continuation", next_query: next_query} = s) do + Postgrex.ReplicationConnection.reply(from, {:ok, results}) + {:stream, next_query, [], Map.delete(s, :next_query)} + end + @epoch DateTime.to_unix(~U[2000-01-01 00:00:00Z], :microsecond) defp current_time(), do: System.os_time(:microsecond) - @epoch end @@ -288,6 +316,24 @@ defmodule ReplicationTest do # Can query after copy is done {:ok, [%Postgrex.Result{}]} = PR.call(context.repl, {:query, "SELECT 1"}) end + + test "allow replication stream right after a COPY stream", context do + P.query!(context.pid, "INSERT INTO repl_test VALUES ($1, $2), ($3, $4)", [42, "42", 1, "1"]) + + query = "CREATE_REPLICATION_SLOT postgrex_test TEMPORARY LOGICAL pgoutput NOEXPORT_SNAPSHOT" + next_query = "COPY repl_test TO STDOUT" + + PR.call( + context.repl, + {:query, query, %{test: "stream_continuation", next_query: next_query}} + ) + + assert_receive {"42\t42\n", i1}, @timeout + assert_receive {"1\t1\n", i2} when i1 < i2, @timeout + assert_receive {:done, i3} when i2 < i3, @timeout + # Prior to allowing one stream to start after another, this would fail + assert_receive <>, @timeout + end end defp start_replication(repl) do