Skip to content

Commit

Permalink
refactor(postgres): make better use of traits to improve protocol han…
Browse files Browse the repository at this point in the history
…dling
  • Loading branch information
abonander committed Aug 17, 2024
1 parent ec5326e commit 5e8a50f
Show file tree
Hide file tree
Showing 40 changed files with 1,206 additions and 657 deletions.
3 changes: 2 additions & 1 deletion sqlx-postgres/src/advisory_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ impl<'lock, C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuard<'lock, C> {
// The `async fn` versions can safely use the prepared statement protocol,
// but this is the safest way to queue a query to execute on the next opportunity.
conn.as_mut()
.queue_simple_query(self.lock.get_release_query());
.queue_simple_query(self.lock.get_release_query())
.expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol");
}
}
}
1 change: 1 addition & 0 deletions sqlx-postgres/src/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments {
write!(writer, "${}", self.buffer.count)
}

#[inline(always)]
fn len(&self) -> usize {
self.buffer.count
}
Expand Down
10 changes: 4 additions & 6 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
use crate::message::{ParameterDescription, RowDescription};
use crate::query_as::query_as;
use crate::query_scalar::query_scalar;
Expand Down Expand Up @@ -416,7 +417,7 @@ WHERE rngtypid = $1

pub(crate) async fn get_nullable_for_columns(
&mut self,
stmt_id: Oid,
stmt_id: StatementId,
meta: &PgStatementMetadata,
) -> Result<Vec<Option<bool>>, Error> {
if meta.columns.is_empty() {
Expand Down Expand Up @@ -486,13 +487,10 @@ WHERE rngtypid = $1
/// and returns `None` for all others.
async fn nullables_from_explain(
&mut self,
stmt_id: Oid,
stmt_id: StatementId,
params_len: usize,
) -> Result<Vec<Option<bool>>, Error> {
let mut explain = format!(
"EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
stmt_id.0
);
let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}");
let mut comma = false;

if params_len > 0 {
Expand Down
30 changes: 14 additions & 16 deletions sqlx-postgres/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use crate::HashMap;
use crate::common::StatementCache;
use crate::connection::{sasl, stream::PgStream};
use crate::error::Error;
use crate::io::Decode;
use crate::io::StatementId;
use crate::message::{
Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup,
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
};
use crate::types::Oid;
use crate::{PgConnectOptions, PgConnection};

// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
Expand Down Expand Up @@ -44,13 +43,13 @@ impl PgConnection {
params.push(("options", options));
}

stream
.send(Startup {
username: Some(&options.username),
database: options.database.as_deref(),
params: &params,
})
.await?;
stream.write(Startup {
username: Some(&options.username),
database: options.database.as_deref(),
params: &params,
})?;

stream.flush().await?;

// The server then uses this information and the contents of
// its configuration files (such as pg_hba.conf) to determine whether the connection is
Expand All @@ -64,7 +63,7 @@ impl PgConnection {
loop {
let message = stream.recv().await?;
match message.format {
MessageFormat::Authentication => match message.decode()? {
BackendMessageFormat::Authentication => match message.decode()? {
Authentication::Ok => {
// the authentication exchange is successfully completed
// do nothing; no more information is required to continue
Expand Down Expand Up @@ -108,7 +107,7 @@ impl PgConnection {
}
},

MessageFormat::BackendKeyData => {
BackendMessageFormat::BackendKeyData => {
// provides secret-key data that the frontend must save if it wants to be
// able to issue cancel requests later

Expand All @@ -118,10 +117,9 @@ impl PgConnection {
secret_key = data.secret_key;
}

MessageFormat::ReadyForQuery => {
BackendMessageFormat::ReadyForQuery => {
// start-up is completed. The frontend can now issue commands
transaction_status =
ReadyForQuery::decode(message.contents)?.transaction_status;
transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;

break;
}
Expand All @@ -142,7 +140,7 @@ impl PgConnection {
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: Oid(1),
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
Expand Down
92 changes: 51 additions & 41 deletions sqlx-postgres/src/connection/executor.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::io::{PortalId, StatementId};
use crate::logger::QueryLogger;
use crate::message::{
self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
RowDescription,
self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
ParseComplete, Query, RowDescription,
};
use crate::statement::PgStatementMetadata;
use crate::types::Oid;
use crate::{
statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
PgValueFormat, Postgres,
Expand All @@ -16,6 +16,7 @@ use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use sqlx_core::arguments::Arguments;
use sqlx_core::Either;
use std::{borrow::Cow, sync::Arc};

Expand All @@ -24,9 +25,9 @@ async fn prepare(
sql: &str,
parameters: &[PgTypeInfo],
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
let id = conn.next_statement_id;
conn.next_statement_id.incr_one();
conn.next_statement_id = id.next();

// build a list of type OIDs to send to the database in the PARSE command
// we have not yet started the query sequence, so we are *safe* to cleanly make
Expand All @@ -42,25 +43,23 @@ async fn prepare(
conn.wait_until_ready().await?;

// next we send the PARSE command to the server
conn.stream.write(Parse {
conn.stream.write_msg(Parse {
param_types: &param_types,
query: sql,
statement: id,
});
})?;

if metadata.is_none() {
// get the statement columns and parameters
conn.stream.write(message::Describe::Statement(id));
conn.stream.write_msg(message::Describe::Statement(id))?;
}

// we ask for the server to immediately send us the result of the PARSE command
conn.write_sync();
conn.stream.flush().await?;

// indicates that the SQL query string is now successfully parsed and has semantic validity
conn.stream
.recv_expect(MessageFormat::ParseComplete)
.await?;
conn.stream.recv_expect::<ParseComplete>().await?;

let metadata = if let Some(metadata) = metadata {
// each SYNC produces one READY FOR QUERY
Expand Down Expand Up @@ -95,18 +94,18 @@ async fn prepare(
}

async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
conn.stream
.recv_expect(MessageFormat::ParameterDescription)
.await
conn.stream.recv_expect().await
}

async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
let rows: Option<RowDescription> = match conn.stream.recv().await? {
// describes the rows that will be returned when the statement is eventually executed
message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
message if message.format == BackendMessageFormat::RowDescription => {
Some(message.decode()?)
}

// no data would be returned if this statement was executed
message if message.format == MessageFormat::NoData => None,
message if message.format == BackendMessageFormat::NoData => None,

message => {
return Err(err_protocol!(
Expand All @@ -125,12 +124,12 @@ impl PgConnection {
// we need to wait for the [CloseComplete] to be returned from the server
while count > 0 {
match self.stream.recv().await? {
message if message.format == MessageFormat::PortalSuspended => {
message if message.format == BackendMessageFormat::PortalSuspended => {
// there was an open portal
// this can happen if the last time a statement was used it was not fully executed
}

message if message.format == MessageFormat::CloseComplete => {
message if message.format == BackendMessageFormat::CloseComplete => {
// successfully closed the statement (and freed up the server resources)
count -= 1;
}
Expand All @@ -147,8 +146,11 @@ impl PgConnection {
Ok(())
}

#[inline(always)]
pub(crate) fn write_sync(&mut self) {
self.stream.write(message::Sync);
self.stream
.write_msg(message::Sync)
.expect("BUG: Sync should not be too big for protocol");

// all SYNC messages will return a ReadyForQuery
self.pending_ready_for_query_count += 1;
Expand All @@ -163,7 +165,7 @@ impl PgConnection {
// optional metadata that was provided by the user, this means they are reusing
// a statement object
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
return Ok((*statement).clone());
}
Expand All @@ -172,7 +174,7 @@ impl PgConnection {

if store_to_cache && self.cache_statement.is_enabled() {
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
self.stream.write(Close::Statement(id));
self.stream.write_msg(Close::Statement(id))?;
self.write_sync();

self.stream.flush().await?;
Expand Down Expand Up @@ -201,6 +203,14 @@ impl PgConnection {
let mut metadata: Arc<PgStatementMetadata>;

let format = if let Some(mut arguments) = arguments {
// Check this before we write anything to the stream.
let num_params = i16::try_from(arguments.len()).map_err(|_| {
err_protocol!(
"PgConnection::run(): too many arguments for query: {}",
arguments.len()
)
})?;

// prepare the statement if this our first time executing it
// always return the statement ID here
let (statement, metadata_) = self
Expand All @@ -216,21 +226,21 @@ impl PgConnection {
self.wait_until_ready().await?;

// bind to attach the arguments to the statement and create a portal
self.stream.write(Bind {
portal: None,
self.stream.write_msg(Bind {
portal: PortalId::UNNAMED,
statement,
formats: &[PgValueFormat::Binary],
num_params: arguments.types.len() as i16,
num_params,
params: &arguments.buffer,
result_formats: &[PgValueFormat::Binary],
});
})?;

// executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write(message::Execute {
portal: None,
self.stream.write_msg(message::Execute {
portal: PortalId::UNNAMED,
limit: limit.into(),
});
})?;
// From https://www.postgresql.org/docs/current/protocol-flow.html:
//
// "An unnamed portal is destroyed at the end of the transaction, or as
Expand All @@ -240,7 +250,7 @@ impl PgConnection {

// we ask the database server to close the unnamed portal and free the associated resources
// earlier - after the execution of the current query.
self.stream.write(message::Close::Portal(None));
self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;

// finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
Expand All @@ -253,7 +263,7 @@ impl PgConnection {
PgValueFormat::Binary
} else {
// Query will trigger a ReadyForQuery
self.stream.write(Query(query));
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;

// metadata starts out as "nothing"
Expand All @@ -270,12 +280,12 @@ impl PgConnection {
let message = self.stream.recv().await?;

match message.format {
MessageFormat::BindComplete
| MessageFormat::ParseComplete
| MessageFormat::ParameterDescription
| MessageFormat::NoData
BackendMessageFormat::BindComplete
| BackendMessageFormat::ParseComplete
| BackendMessageFormat::ParameterDescription
| BackendMessageFormat::NoData
// unnamed portal has been closed
| MessageFormat::CloseComplete
| BackendMessageFormat::CloseComplete
=> {
// harmless messages to ignore
}
Expand All @@ -284,7 +294,7 @@ impl PgConnection {
// exactly one of these messages: CommandComplete,
// EmptyQueryResponse (if the portal was created from an
// empty query string), ErrorResponse, or PortalSuspended"
MessageFormat::CommandComplete => {
BackendMessageFormat::CommandComplete => {
// a SQL command completed normally
let cc: CommandComplete = message.decode()?;

Expand All @@ -295,16 +305,16 @@ impl PgConnection {
}));
}

MessageFormat::EmptyQueryResponse => {
BackendMessageFormat::EmptyQueryResponse => {
// empty query string passed to an unprepared execute
}

// Message::ErrorResponse is handled in self.stream.recv()

// incomplete query execution has finished
MessageFormat::PortalSuspended => {}
BackendMessageFormat::PortalSuspended => {}

MessageFormat::RowDescription => {
BackendMessageFormat::RowDescription => {
// indicates that a *new* set of rows are about to be returned
let (columns, column_names) = self
.handle_row_description(Some(message.decode()?), false)
Expand All @@ -317,7 +327,7 @@ impl PgConnection {
});
}

MessageFormat::DataRow => {
BackendMessageFormat::DataRow => {
logger.increment_rows_returned();

// one of the set of rows returned by a SELECT, FETCH, etc query
Expand All @@ -331,7 +341,7 @@ impl PgConnection {
r#yield!(Either::Right(row));
}

MessageFormat::ReadyForQuery => {
BackendMessageFormat::ReadyForQuery => {
// processing of the query string is complete
self.handle_ready_for_query(message)?;
break;
Expand Down
Loading

0 comments on commit 5e8a50f

Please sign in to comment.