Skip to content

Commit

Permalink
Merge pull request #1932 from levydsa/remote_writes
Browse files Browse the repository at this point in the history
Remote writes for offline databases
  • Loading branch information
penberg authored Jan 23, 2025
2 parents 09715e4 + 3bd9ef9 commit 1e6af39
Show file tree
Hide file tree
Showing 14 changed files with 461 additions and 159 deletions.
2 changes: 2 additions & 0 deletions libsql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ sync = [
"parser",
"serde",
"stream",
"remote",
"replication",
"dep:tower",
"dep:hyper",
"dep:http",
Expand Down
55 changes: 45 additions & 10 deletions libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub use libsql_sys::{Cipher, EncryptionConfig};
use crate::{Connection, Result};
use std::fmt;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;

cfg_core! {
bitflags::bitflags! {
Expand Down Expand Up @@ -82,7 +81,14 @@ enum DbType {
encryption_config: Option<EncryptionConfig>,
},
#[cfg(feature = "sync")]
Offline { db: crate::local::Database },
Offline {
db: crate::local::Database,
remote_writes: bool,
read_your_writes: bool,
url: String,
auth_token: String,
connector: crate::util::ConnectorService,
},
#[cfg(feature = "remote")]
Remote {
url: String,
Expand Down Expand Up @@ -117,7 +123,7 @@ pub struct Database {
db_type: DbType,
/// The maximum replication index returned from a write performed using any connection created using this Database object.
#[allow(dead_code)]
max_write_replication_index: Arc<AtomicU64>,
max_write_replication_index: std::sync::Arc<AtomicU64>,
}

cfg_core! {
Expand Down Expand Up @@ -375,7 +381,7 @@ cfg_replication! {
#[cfg(feature = "replication")]
DbType::Sync { db, encryption_config: _ } => db.sync().await,
#[cfg(feature = "sync")]
DbType::Offline { db } => db.sync_offline().await,
DbType::Offline { db, .. } => db.sync_offline().await,
_ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))),
}
}
Expand Down Expand Up @@ -642,13 +648,42 @@ impl Database {
}

#[cfg(feature = "sync")]
DbType::Offline { db } => {
use crate::local::impls::LibsqlConnection;

let conn = db.connect()?;

let conn = std::sync::Arc::new(LibsqlConnection { conn });
DbType::Offline {
db,
remote_writes,
read_your_writes,
url,
auth_token,
connector,
} => {
use crate::{
hrana::{connection::HttpConnection, hyper::HttpSender},
local::impls::LibsqlConnection,
replication::connection::State,
sync::connection::SyncedConnection,
};
use tokio::sync::Mutex;

let local = db.connect()?;

if *remote_writes {
let synced = SyncedConnection {
local,
remote: HttpConnection::new(
url.clone(),
auth_token.clone(),
HttpSender::new(connector.clone(), None),
),
read_your_writes: *read_your_writes,
context: db.sync_ctx.clone().unwrap(),
state: std::sync::Arc::new(Mutex::new(State::Init)),
};

let conn = std::sync::Arc::new(synced);
return Ok(Connection { conn });
}

let conn = std::sync::Arc::new(LibsqlConnection { conn: local });
Ok(Connection { conn })
}

Expand Down
33 changes: 28 additions & 5 deletions libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ impl Builder<()> {
connector: None,
version: None,
},
connector:None,
connector: None,
read_your_writes: true,
remote_writes: false,
},
}
}
Expand Down Expand Up @@ -463,6 +465,8 @@ cfg_sync! {
flags: crate::OpenFlags,
remote: Remote,
connector: Option<crate::util::ConnectorService>,
remote_writes: bool,
read_your_writes: bool,
}

impl Builder<SyncedDatabase> {
Expand All @@ -472,6 +476,16 @@ cfg_sync! {
self
}

pub fn read_your_writes(mut self, v: bool) -> Builder<SyncedDatabase> {
self.inner.read_your_writes = v;
self
}

pub fn remote_writes(mut self, v: bool) -> Builder<SyncedDatabase> {
self.inner.remote_writes = v;
self
}

/// Provide a custom http connector that will be used to create http connections.
pub fn connector<C>(mut self, connector: C) -> Builder<SyncedDatabase>
where
Expand All @@ -497,6 +511,8 @@ cfg_sync! {
version: _,
},
connector,
remote_writes,
read_your_writes,
} = self.inner;

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
Expand All @@ -515,16 +531,23 @@ cfg_sync! {
let connector = crate::util::ConnectorService::new(svc);

let db = crate::local::Database::open_local_with_offline_writes(
connector,
connector.clone(),
path,
flags,
url,
auth_token,
url.clone(),
auth_token.clone(),
)
.await?;

Ok(Database {
db_type: DbType::Offline { db },
db_type: DbType::Offline {
db,
remote_writes,
read_your_writes,
url,
auth_token,
connector,
},
max_write_replication_index: Default::default(),
})
}
Expand Down
11 changes: 7 additions & 4 deletions libsql/src/hrana/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,17 @@ impl Conn for HranaStream<HttpSender> {
let parse = crate::parser::Statement::parse(sql);
for s in parse {
let s = s?;
if s.kind == crate::parser::StmtKind::TxnBegin
|| s.kind == crate::parser::StmtKind::TxnBeginReadOnly
|| s.kind == crate::parser::StmtKind::TxnEnd
{

use crate::parser::StmtKind;
if matches!(
s.kind,
StmtKind::TxnBegin | StmtKind::TxnBeginReadOnly | StmtKind::TxnEnd
) {
return Err(Error::TransactionalBatchError(
"Transactions forbidden inside transactional batch".to_string(),
));
}

stmts.push(Stmt::new(s.stmt, false));
}
let res = self
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/hrana/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub mod connection;

cfg_remote! {
mod hyper;
pub mod hyper;
}

mod cursor;
Expand Down
138 changes: 6 additions & 132 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ cfg_replication!(

cfg_sync! {
use crate::sync::SyncContext;
use tokio::sync::Mutex;
use std::sync::Arc;
}

use crate::{database::OpenFlags, local::connection::Connection};
use crate::{Error::ConnectionFailed, Result};
use crate::{database::OpenFlags, local::connection::Connection, Error::ConnectionFailed, Result};
use libsql_sys::ffi;

// A libSQL database.
Expand All @@ -33,7 +34,7 @@ pub struct Database {
#[cfg(feature = "replication")]
pub replication_ctx: Option<ReplicationContext>,
#[cfg(feature = "sync")]
pub sync_ctx: Option<tokio::sync::Mutex<SyncContext>>,
pub sync_ctx: Option<Arc<Mutex<SyncContext>>>,
}

impl Database {
Expand Down Expand Up @@ -222,7 +223,7 @@ impl Database {

let sync_ctx =
SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token)).await?;
db.sync_ctx = Some(tokio::sync::Mutex::new(sync_ctx));
db.sync_ctx = Some(Arc::new(Mutex::new(sync_ctx)));

Ok(db)
}
Expand Down Expand Up @@ -463,137 +464,10 @@ impl Database {
#[cfg(feature = "sync")]
/// Sync WAL frames to remote.
pub async fn sync_offline(&self) -> Result<crate::database::Replicated> {
use crate::sync::SyncError;
use crate::Error;

let mut sync_ctx = self.sync_ctx.as_ref().unwrap().lock().await;
let conn = self.connect()?;

let durable_frame_no = sync_ctx.durable_frame_num();
let max_frame_no = conn.wal_frame_count();

if max_frame_no > durable_frame_no {
match self.try_push(&mut sync_ctx, &conn).await {
Ok(rep) => Ok(rep),
Err(Error::Sync(err)) => {
// Retry the sync because we are ahead of the server and we need to push some older
// frames.
if let Some(SyncError::InvalidPushFrameNoLow(_, _)) = err.downcast_ref() {
tracing::debug!("got InvalidPushFrameNo, retrying push");
self.try_push(&mut sync_ctx, &conn).await
} else {
Err(Error::Sync(err))
}
}
Err(e) => Err(e),
}
} else {
self.try_pull(&mut sync_ctx, &conn).await
}
.or_else(|err| {
let Error::Sync(err) = err else {
return Err(err);
};

// TODO(levy): upcasting should be done *only* at the API boundary, doing this in
// internal code just sucks.
let Some(SyncError::HttpDispatch(_)) = err.downcast_ref() else {
return Err(Error::Sync(err));
};

Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 0,
})
})
}

#[cfg(feature = "sync")]
async fn try_push(
&self,
sync_ctx: &mut SyncContext,
conn: &Connection,
) -> Result<crate::database::Replicated> {
let page_size = {
let rows = conn
.query("PRAGMA page_size", crate::params::Params::None)?
.unwrap();
let row = rows.next()?.unwrap();
let page_size = row.get::<u32>(0)?;
page_size
};

let max_frame_no = conn.wal_frame_count();
if max_frame_no == 0 {
return Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 0,
});
}

let generation = sync_ctx.generation(); // TODO: Probe from WAL.
let start_frame_no = sync_ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;

let mut frame_no = start_frame_no;
while frame_no <= end_frame_no {
let frame = conn.wal_get_frame(frame_no, page_size)?;

// The server returns its maximum frame number. To avoid resending
// frames the server already knows about, we need to update the
// frame number to the one returned by the server.
let max_frame_no = sync_ctx
.push_one_frame(frame.freeze(), generation, frame_no)
.await?;

if max_frame_no > frame_no {
frame_no = max_frame_no;
}
frame_no += 1;
}

sync_ctx.write_metadata().await?;

// TODO(lucio): this can underflow if the server previously returned a higher max_frame_no
// than what we have stored here.
let frame_count = end_frame_no - start_frame_no + 1;
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: frame_count as usize,
})
}

#[cfg(feature = "sync")]
async fn try_pull(
&self,
sync_ctx: &mut SyncContext,
conn: &Connection,
) -> Result<crate::database::Replicated> {
let generation = sync_ctx.generation();
let mut frame_no = sync_ctx.durable_frame_num() + 1;

let insert_handle = conn.wal_insert_handle()?;

loop {
match sync_ctx.pull_one_frame(generation, frame_no).await {
Ok(Some(frame)) => {
insert_handle.insert(&frame)?;
frame_no += 1;
}
Ok(None) => {
sync_ctx.write_metadata().await?;
return Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 1,
});
}
Err(err) => {
tracing::debug!("pull_one_frame error: {:?}", err);
sync_ctx.write_metadata().await?;
return Err(err);
}
}
}
crate::sync::sync_offline(&mut sync_ctx, &conn).await
}

pub(crate) fn path(&self) -> &str {
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl Drop for LibsqlConnection {
}
}

pub(crate) struct LibsqlStmt(pub(super) crate::local::Statement);
pub(crate) struct LibsqlStmt(pub crate::local::Statement);

#[async_trait::async_trait]
impl Stmt for LibsqlStmt {
Expand Down
7 changes: 7 additions & 0 deletions libsql/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ impl IntoParams for Params {
}
}

impl Sealed for &Params {}
impl IntoParams for &Params {
fn into_params(self) -> Result<Params> {
Ok(self.clone())
}
}

impl<T: IntoValue> Sealed for Vec<T> {}
impl<T: IntoValue> IntoParams for Vec<T> {
fn into_params(self) -> Result<Params> {
Expand Down
Loading

0 comments on commit 1e6af39

Please sign in to comment.