Skip to content

Commit

Permalink
fix(mysql): fallout from ec5326e
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Aug 18, 2024
1 parent 5e8a50f commit 3945e06
Show file tree
Hide file tree
Showing 20 changed files with 91 additions and 66 deletions.
7 changes: 7 additions & 0 deletions sqlx-mysql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ any = ["sqlx-core/any"]
offline = ["sqlx-core/offline", "serde/derive"]
migrate = ["sqlx-core/migrate"]

# Type Integration features
bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"]
chrono = ["dep:chrono", "sqlx-core/chrono"]
rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"]
time = ["dep:time", "sqlx-core/time"]
uuid = ["dep:uuid", "sqlx-core/uuid"]

[dependencies]
sqlx-core = { workspace = true }

Expand Down
8 changes: 4 additions & 4 deletions sqlx-mysql/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use bytes::{Buf, Bytes, BytesMut};
use crate::collation::{CharSet, Collation};
use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{Decode, Encode};
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::net::{BufferedSocket, Socket};
use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::protocol::{Capabilities, Packet};
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<S: Socket> MySqlStream<S> {

pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
self.sequence_id = 0;
self.write_packet(payload);
Expand All @@ -120,7 +120,7 @@ impl<S: Socket> MySqlStream<S> {

pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
self.socket
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<S: Socket> MySqlStream<S> {

pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
where
T: Decode<'de, Capabilities>,
T: ProtocolDecode<'de, Capabilities>,
{
self.recv_packet().await?.decode_with(self.capabilities)
}
Expand Down
11 changes: 6 additions & 5 deletions sqlx-mysql/src/protocol/connect/auth_switch.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::Encode;
use crate::io::{BufExt, Decode};
use crate::io::ProtocolEncode;
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::Capabilities;

Expand All @@ -14,7 +14,7 @@ pub struct AuthSwitchRequest {
pub data: Bytes,
}

impl Decode<'_, bool> for AuthSwitchRequest {
impl ProtocolDecode<'_, bool> for AuthSwitchRequest {
fn decode_with(mut buf: Bytes, enable_cleartext_plugin: bool) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
Expand Down Expand Up @@ -58,9 +58,10 @@ impl Decode<'_, bool> for AuthSwitchRequest {
#[derive(Debug)]
pub struct AuthSwitchResponse(pub Vec<u8>);

impl Encode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), Error> {
buf.extend_from_slice(&self.0);
Ok(())
}
}

Expand Down
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/connect/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use bytes::{Buf, Bytes};
use std::cmp;

use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::response::Status;
use crate::protocol::Capabilities;
Expand All @@ -27,7 +27,7 @@ pub(crate) struct Handshake {
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
}

impl Decode<'_> for Handshake {
impl ProtocolDecode<'_> for Handshake {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let protocol_version = buf.get_u8(); // int<1>
let server_version = buf.get_str_nul()?; // string<NUL>
Expand Down
24 changes: 15 additions & 9 deletions sqlx-mysql/src/protocol/connect/handshake_response.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io::MySqlBufMutExt;
use crate::io::{BufMutExt, Encode};
use crate::io::{BufMutExt, ProtocolEncode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::connect::ssl_request::SslRequest;
use crate::protocol::Capabilities;
Expand Down Expand Up @@ -27,25 +27,29 @@ pub struct HandshakeResponse<'a> {
pub auth_response: Option<&'a [u8]>,
}

impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> {
fn encode_with(
&self,
buf: &mut Vec<u8>,
mut context: Capabilities,
) -> Result<(), crate::Error> {
if self.auth_plugin.is_none() {
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
capabilities.remove(Capabilities::PLUGIN_AUTH);
context.remove(Capabilities::PLUGIN_AUTH);
}

// NOTE: Half of this packet is identical to the SSL Request packet
SslRequest {
max_packet_size: self.max_packet_size,
collation: self.collation,
}
.encode_with(buf, capabilities);
.encode_with(buf, context)?;

buf.put_str_nul(self.username);

if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
if context.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.put_bytes_lenenc(self.auth_response.unwrap_or_default());
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
} else if context.contains(Capabilities::SECURE_CONNECTION) {
let response = self.auth_response.unwrap_or_default();

buf.push(response.len() as u8);
Expand All @@ -54,20 +58,22 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
buf.push(0);
}

if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if context.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = &self.database {
buf.put_str_nul(database);
} else {
buf.push(0);
}
}

if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if context.contains(Capabilities::PLUGIN_AUTH) {
if let Some(plugin) = &self.auth_plugin {
buf.put_str_nul(plugin.name());
} else {
buf.push(0);
}
}

Ok(())
}
}
14 changes: 8 additions & 6 deletions sqlx-mysql/src/protocol/connect/ssl_request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
Expand All @@ -10,21 +10,23 @@ pub struct SslRequest {
pub collation: u8,
}

impl Encode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
impl ProtocolEncode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, context: Capabilities) -> Result<(), crate::Error> {
buf.extend(&(context.bits() as u32).to_le_bytes());
buf.extend(&self.max_packet_size.to_le_bytes());
buf.push(self.collation);

// reserved: string<19>
buf.extend(&[0_u8; 19]);

if capabilities.contains(Capabilities::MYSQL) {
if context.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.extend(&[0_u8; 4]);
} else {
// extended client capabilities (MariaDB-specified): int<4>
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
buf.extend(&((context.bits() >> 32) as u32).to_le_bytes());
}

Ok(())
}
}
16 changes: 9 additions & 7 deletions sqlx-mysql/src/protocol/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@ use std::ops::{Deref, DerefMut};
use bytes::Bytes;

use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::protocol::response::{EofPacket, OkPacket};
use crate::protocol::Capabilities;

#[derive(Debug)]
pub struct Packet<T>(pub(crate) T);

impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
impl<'en, 'stream, T> ProtocolEncode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
fn encode_with(
&self,
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
) -> Result<(), Error> {
let mut next_header = |len: u32| {
let mut buf = len.to_le_bytes();
buf[3] = *sequence_id;
Expand All @@ -33,7 +33,7 @@ where
buf.extend(&[0_u8; 4]);

// encode the payload
self.0.encode_with(buf, capabilities);
self.0.encode_with(buf, capabilities)?;

// determine the length of the encoded payload
// and write to our reserved space
Expand All @@ -59,20 +59,22 @@ where
buf.extend(&next_header(remainder.len() as u32));
buf.extend(remainder);
}

Ok(())
}
}

impl Packet<Bytes> {
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
where
T: Decode<'de, ()>,
T: ProtocolDecode<'de, ()>,
{
self.decode_with(())
}

pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
T: ProtocolDecode<'de, C>,
{
T::decode_with(self.0, context)
}
Expand Down
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/response/eof.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::protocol::response::Status;
use crate::protocol::Capabilities;

Expand All @@ -18,7 +18,7 @@ pub struct EofPacket {
pub status: Status,
}

impl Decode<'_, Capabilities> for EofPacket {
impl ProtocolDecode<'_, Capabilities> for EofPacket {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
Expand Down
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/response/err.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
Expand All @@ -15,7 +15,7 @@ pub struct ErrPacket {
pub error_message: String,
}

impl Decode<'_, Capabilities> for ErrPacket {
impl ProtocolDecode<'_, Capabilities> for ErrPacket {
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xff {
Expand Down
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/response/ok.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::Decode;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::response::Status;

/// Indicates successful completion of a previous command sent by the client.
Expand All @@ -14,7 +14,7 @@ pub struct OkPacket {
pub warnings: u16,
}

impl Decode<'_> for OkPacket {
impl ProtocolDecode<'_> for OkPacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 && header != 0xfe {
Expand Down
8 changes: 5 additions & 3 deletions sqlx-mysql/src/protocol/statement/execute.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::text::ColumnFlags;
use crate::protocol::Capabilities;
use crate::MySqlArguments;
Expand All @@ -11,8 +11,8 @@ pub struct Execute<'q> {
pub arguments: &'q MySqlArguments,
}

impl<'q> Encode<'_, Capabilities> for Execute<'q> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x17); // COM_STMT_EXECUTE
buf.extend(&self.statement.to_le_bytes());
buf.push(0); // NO_CURSOR
Expand All @@ -34,5 +34,7 @@ impl<'q> Encode<'_, Capabilities> for Execute<'q> {

buf.extend(&*self.arguments.values);
}

Ok(())
}
}
7 changes: 4 additions & 3 deletions sqlx-mysql/src/protocol/statement/prepare.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE
Expand All @@ -7,9 +7,10 @@ pub struct Prepare<'a> {
pub query: &'a str,
}

impl Encode<'_, Capabilities> for Prepare<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for Prepare<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x16); // COM_STMT_PREPARE
buf.extend(self.query.as_bytes());
Ok(())
}
}
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/statement/prepare_ok.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
Expand All @@ -15,7 +15,7 @@ pub(crate) struct PrepareOk {
pub(crate) warnings: u16,
}

impl Decode<'_, Capabilities> for PrepareOk {
impl ProtocolDecode<'_, Capabilities> for PrepareOk {
fn decode_with(buf: Bytes, _: Capabilities) -> Result<Self, Error> {
const SIZE: usize = 12;

Expand Down
4 changes: 2 additions & 2 deletions sqlx-mysql/src/protocol/statement/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::text::ColumnType;
use crate::protocol::Row;
use crate::MySqlColumn;
Expand All @@ -13,7 +13,7 @@ use crate::MySqlColumn;
#[derive(Debug)]
pub(crate) struct BinaryRow(pub(crate) Row);

impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 {
Expand Down
Loading

0 comments on commit 3945e06

Please sign in to comment.