Skip to content

Commit

Permalink
implement comment in postgrex
Browse files Browse the repository at this point in the history
  • Loading branch information
dkuku committed Oct 12, 2024
1 parent 5924c18 commit be74cce
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
6 changes: 4 additions & 2 deletions lib/postgrex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ defmodule Postgrex do
{:ok, Postgrex.Query.t()} | {:error, Exception.t()}
def prepare(conn, name, statement, opts \\ []) do
query = %Query{name: name, statement: statement}
opts = Keyword.put(opts, :postgrex_prepare, true)
prepare? = !Keyword.get(opts, :comment)
opts = Keyword.put(opts, :postgrex_prepare, prepare?)
DBConnection.prepare(conn, query, opts)
end

Expand All @@ -373,7 +374,8 @@ defmodule Postgrex do
"""
@spec prepare!(conn, iodata, iodata, [option]) :: Postgrex.Query.t()
def prepare!(conn, name, statement, opts \\ []) do
opts = Keyword.put(opts, :postgrex_prepare, true)
prepare? = !Keyword.get(opts, :comment)
opts = Keyword.put(opts, :postgrex_prepare, prepare?)
DBConnection.prepare!(conn, %Query{name: name, statement: statement}, opts)
end

Expand Down
48 changes: 35 additions & 13 deletions lib/postgrex/protocol.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ defmodule Postgrex.Protocol do
message:
"`:commit_comment` option cannot contain sequence \"*/\""
)
@comment_validation_error Postgrex.Error.exception(
message: "`:comment` option cannot contain sequence \"*/\""
)

defstruct sock: nil,
connection_id: nil,
Expand Down Expand Up @@ -343,12 +346,19 @@ defmodule Postgrex.Protocol do
end

def handle_prepare(%Query{name: ""} = query, opts, s) do
prepare = Keyword.get(opts, :postgrex_prepare, false)
status = new_status(opts, prepare: prepare)
prepare? = Keyword.get(opts, :postgrex_prepare, false)
status = new_status(opts, prepare: prepare?)

case prepare do
true -> parse_describe_close(s, status, query)
false -> parse_describe_flush(s, status, query)
if prepare? do
parse_describe_close(s, status, query)
else
comment = Keyword.get(opts, :comment)

if is_binary(comment) && String.contains?(comment, "*/") do
raise @comment_validation_error
else
parse_describe_flush(s, status, query, comment)
end
end
end

Expand Down Expand Up @@ -1418,9 +1428,9 @@ defmodule Postgrex.Protocol do
parse_describe(s, status, query)
end

defp parse_describe_flush(s, %{mode: :transaction} = status, query) do
defp parse_describe_flush(s, %{mode: :transaction} = status, query, comment) do
%{buffer: buffer} = s
msgs = parse_describe_msgs(query, [msg_flush()])
msgs = parse_describe_comment_msgs(query, comment, [msg_flush()])

with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer),
{:ok, %Query{ref: ref} = query, %{postgres: postgres} = s, buffer} <-
Expand All @@ -1442,11 +1452,12 @@ defmodule Postgrex.Protocol do
defp parse_describe_flush(
%{postgres: :transaction, buffer: buffer} = s,
%{mode: :savepoint} = status,
query
query,
comment
) do
msgs =
[msg_query(statement: "SAVEPOINT postgrex_query")] ++
parse_describe_msgs(query, [msg_flush()])
parse_describe_comment_msgs(query, comment, [msg_flush()])

with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer),
{:ok, _, %{buffer: buffer} = s} <- recv_transaction(s, status, buffer),
Expand All @@ -1466,7 +1477,7 @@ defmodule Postgrex.Protocol do
end
end

defp parse_describe_flush(%{postgres: postgres} = s, %{mode: :savepoint}, _)
defp parse_describe_flush(%{postgres: postgres} = s, %{mode: :savepoint}, _, _)
when postgres in [:idle, :error] do
transaction_error(s, postgres)
end
Expand Down Expand Up @@ -1593,6 +1604,16 @@ defmodule Postgrex.Protocol do
transaction_error(s, postgres)
end

defp parse_describe_comment_msgs(query, comment, tail) when is_binary(comment) do
statement = "/* #{comment} */\n" <> query.statement
query = %{query | statement: statement}
parse_describe_msgs(query, tail)
end

defp parse_describe_comment_msgs(query, _comment, tail) do
parse_describe_msgs(query, tail)
end

defp parse_describe_msgs(query, tail) do
%Query{name: name, statement: statement, param_oids: param_oids} = query
type_oids = param_oids || []
Expand Down Expand Up @@ -2079,7 +2100,7 @@ defmodule Postgrex.Protocol do

_ ->
# flush awaiting execute or declare
parse_describe_flush(s, status, query)
parse_describe_flush(s, status, query, nil)
end
end

Expand All @@ -2105,7 +2126,7 @@ defmodule Postgrex.Protocol do
defp handle_prepare_execute(%Query{name: ""} = query, params, opts, s) do
status = new_status(opts)

case parse_describe_flush(s, status, query) do
case parse_describe_flush(s, status, query, nil) do
{:ok, query, s} ->
bind_execute_close(s, status, query, params)

Expand Down Expand Up @@ -2396,7 +2417,7 @@ defmodule Postgrex.Protocol do
defp handle_prepare_bind(%Query{name: ""} = query, params, res, opts, s) do
status = new_status(opts)

case parse_describe_flush(s, status, query) do
case parse_describe_flush(s, status, query, nil) do
{:ok, query, s} ->
bind(s, status, query, params, res)

Expand Down Expand Up @@ -3371,6 +3392,7 @@ defmodule Postgrex.Protocol do

defp msg_send(s, msgs, buffer) when is_list(msgs) do
binaries = Enum.reduce(msgs, [], &[&2 | maybe_encode_msg(&1)])

do_send(s, binaries, buffer)
end

Expand Down
9 changes: 9 additions & 0 deletions test/query_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,15 @@ defmodule QueryTest do
assert [["1", "2"], ["3", "4"]] = query("COPY (VALUES (1, 2), (3, 4)) TO STDOUT", [], opts)
end

test "comment", context do
assert [[123]] = query("select 123", [], comment: "query comment goes here")

assert_raise Postgrex.Error, fn ->
query("select 123", [], comment: "*/ DROP TABLE 123 --")
end
end

@tag :big_binary
test "receive packet with remainder greater than 64MB", context do
# to ensure remainder is more than 64MB use 64MBx2+1
big_binary = :binary.copy(<<1>>, 128 * 1024 * 1024 + 1)
Expand Down

0 comments on commit be74cce

Please sign in to comment.