From 28d02a8386f50348e22c0592fb5867a476ccc55a Mon Sep 17 00:00:00 2001 From: Bartosz Sypytkowski Date: Mon, 22 Jan 2024 15:08:22 +0100 Subject: [PATCH] added URL parsing --- libsql/src/hrana/connection.rs | 7 +++---- libsql/src/hrana/stream.rs | 17 ++++++++++++++++- libsql/src/local/database.rs | 19 ++++++++++++++++--- libsql/src/util/mod.rs | 2 +- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/libsql/src/hrana/connection.rs b/libsql/src/hrana/connection.rs index ae4dac4d7c..616329f5a4 100644 --- a/libsql/src/hrana/connection.rs +++ b/libsql/src/hrana/connection.rs @@ -1,7 +1,7 @@ use crate::hrana::cursor::Cursor; use crate::hrana::pipeline::{BatchStreamReq, StreamRequest, StreamResponse}; use crate::hrana::proto::{Batch, BatchResult, Stmt}; -use crate::hrana::stream::HranaStream; +use crate::hrana::stream::{parse_hrana_urls, HranaStream}; use crate::hrana::{HranaError, HttpSend, Result, Statement}; use crate::util::coerce_url_scheme; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering}; @@ -32,9 +32,8 @@ where { pub fn new(url: String, token: String, inner: T) -> Self { // The `libsql://` protocol is an alias for `https://`. - let base_url = coerce_url_scheme(&url); - let pipeline_url = Arc::from(format!("{base_url}/v3/pipeline")); - let cursor_url = Arc::from(format!("{base_url}/v3/cursor")); + let base_url = coerce_url_scheme(url); + let (pipeline_url, cursor_url) = parse_hrana_urls(&base_url); HttpConnection(Arc::new(InnerClient { inner, pipeline_url, diff --git a/libsql/src/hrana/stream.rs b/libsql/src/hrana/stream.rs index fa644191d4..972c6dbb13 100644 --- a/libsql/src/hrana/stream.rs +++ b/libsql/src/hrana/stream.rs @@ -289,7 +289,9 @@ where let body = stream_to_bytes(body).await?; let mut response: ServerMsg = serde_json::from_slice(&body)?; if let Some(base_url) = response.base_url.take() { - self.pipeline_url = Arc::from(base_url); + let (pipeline_url, cursor_url) = parse_hrana_urls(&base_url); + self.pipeline_url = pipeline_url; + self.cursor_url = cursor_url; } match response.baton.take() { None => { @@ -379,6 +381,19 @@ where } } +pub(super) fn parse_hrana_urls(url: &str) -> (Arc, Arc) { + let (mut base_url, query) = match url.rfind('?') { + Some(i) => url.split_at(i), + None => (url, ""), + }; + if base_url.ends_with('/') { + base_url = &base_url[0..(base_url.len() - 1)]; + } + let pipeline_url = Arc::from(format!("{base_url}/v3/pipeline{query}")); + let cursor_url = Arc::from(format!("{base_url}/v3/cursor{query}")); + (pipeline_url, cursor_url) +} + async fn stream_to_bytes(mut stream: S) -> Result where S: Stream> + Unpin, diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index d998c0c784..5eff3167a7 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -53,7 +53,16 @@ impl Database { auth_token: String, encryption_key: Option, ) -> Result { - Self::open_http_sync_internal(connector, db_path, endpoint, auth_token, None, false, encryption_key).await + Self::open_http_sync_internal( + connector, + db_path, + endpoint, + auth_token, + None, + false, + encryption_key, + ) + .await } #[cfg(feature = "replication")] @@ -73,7 +82,7 @@ impl Database { let mut db = Database::open(&db_path, OpenFlags::default())?; - let endpoint = coerce_url_scheme(&endpoint); + let endpoint = coerce_url_scheme(endpoint); let remote = crate::replication::client::Client::new( connector, endpoint.as_str().try_into().unwrap(), @@ -98,7 +107,11 @@ impl Database { } #[cfg(feature = "replication")] - pub async fn open_local_sync(db_path: impl Into, flags: OpenFlags, encryption_key: Option) -> Result { + pub async fn open_local_sync( + db_path: impl Into, + flags: OpenFlags, + encryption_key: Option, + ) -> Result { use std::path::PathBuf; let db_path = db_path.into(); diff --git a/libsql/src/util/mod.rs b/libsql/src/util/mod.rs index f2ecc69ba9..d724c924b4 100644 --- a/libsql/src/util/mod.rs +++ b/libsql/src/util/mod.rs @@ -7,7 +7,7 @@ cfg_replication_or_remote! { } cfg_replication_or_remote_or_hrana! { - pub(crate) fn coerce_url_scheme(url: &str) -> String { + pub(crate) fn coerce_url_scheme(url: String) -> String { let mut url = url.replace("libsql://", "https://"); if !url.contains("://") {