Skip to content

Commit

Permalink
Go bindings support remote only dbs
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Jastrzebski <[email protected]>
  • Loading branch information
haaawk committed Jan 22, 2024
1 parent 49ea8c2 commit c23f4d6
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 54 deletions.
7 changes: 1 addition & 6 deletions bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,7 @@ pub unsafe extern "C" fn libsql_open_remote(
return 2;
}
};
match RT.block_on(libsql::Database::open_with_remote_sync(
url.to_string(),
url,
auth_token,
None,
)) {
match libsql::Database::open_remote(url, auth_token) {
Ok(db) => {
let db = Box::leak(Box::new(libsql_database { db }));
*out_db = libsql_database_t::from(db);
Expand Down
132 changes: 89 additions & 43 deletions bindings/go/libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
sqldriver "database/sql/driver"
"fmt"
"io"
"net/url"
"strings"
"time"
"unsafe"
)
Expand All @@ -30,25 +32,44 @@ func init() {
}

func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, 0)
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, 0)
}

func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, syncInterval)
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, syncInterval)
}

type driver struct{}

func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
connector, err := d.OpenConnector(dbPath)
func (d driver) Open(dbAddress string) (sqldriver.Conn, error) {
connector, err := d.OpenConnector(dbAddress)
if err != nil {
return nil, err
}
return connector.Connect(context.Background())
}

func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
return openConnector(dbPath, "", "", 0)
func (d driver) OpenConnector(dbAddress string) (sqldriver.Connector, error) {
if strings.HasPrefix(dbAddress, ":memory:") {
return openLocalConnector(dbAddress)
}
u, err := url.Parse(dbAddress)
if err != nil {
return nil, err
}
switch u.Scheme {
case "file":
return openLocalConnector(dbAddress)
case "http":
fallthrough
case "https":
fallthrough
case "libsql":
authToken := u.Query().Get("authToken")
u.RawQuery = ""
return openRemoteConnector(u.String(), authToken)
}
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https:// or http://", u.Scheme)
}

func libsqlSync(nativeDbPtr C.libsql_database_t) error {
Expand All @@ -60,44 +81,54 @@ func libsqlSync(nativeDbPtr C.libsql_database_t) error {
return nil
}

func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var nativeDbPtr C.libsql_database_t
var err error
func openLocalConnector(dbPath string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenLocal(dbPath)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}

func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenRemote(primaryUrl, authToken)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}

func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var closeCh chan struct{}
var closeAckCh chan struct{}
if primaryUrl != "" {
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
if err := libsqlSync(nativeDbPtr); err != nil {
C.libsql_close(nativeDbPtr)
return nil, err
}
if syncInterval != 0 {
closeCh = make(chan struct{}, 1)
closeAckCh = make(chan struct{}, 1)
go func() {
for {
timerCh := make(chan struct{}, 1)
go func() {
time.Sleep(syncInterval)
timerCh <- struct{}{}
}()
select {
case <-closeCh:
closeAckCh <- struct{}{}
return
case <-timerCh:
if err := libsqlSync(nativeDbPtr); err != nil {
fmt.Println(err)
}
nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
if err := libsqlSync(nativeDbPtr); err != nil {
C.libsql_close(nativeDbPtr)
return nil, err
}
if syncInterval != 0 {
closeCh = make(chan struct{}, 1)
closeAckCh = make(chan struct{}, 1)
go func() {
for {
timerCh := make(chan struct{}, 1)
go func() {
time.Sleep(syncInterval)
timerCh <- struct{}{}
}()
select {
case <-closeCh:
closeAckCh <- struct{}{}
return
case <-timerCh:
if err := libsqlSync(nativeDbPtr); err != nil {
fmt.Println(err)
}
}
}()
}
} else {
nativeDbPtr, err = libsqlOpen(dbPath)
}
}()
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -147,15 +178,30 @@ func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
}
}

func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
connectionString := C.CString(dataSourceName)
defer C.free(unsafe.Pointer(connectionString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_ext(connectionString, &db, &errMsg)
statusCode := C.libsql_open_file(connectionString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), statusCode, errMsg)
}
return db, nil
}

func libsqlOpenRemote(url, authToken string) (C.libsql_database_t, error) {
connectionString := C.CString(url)
defer C.free(unsafe.Pointer(connectionString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_remote(connectionString, authTokenNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open database ", dataSourceName), statusCode, errMsg)
return nil, libsqlError(fmt.Sprint("failed to open remote database ", url), statusCode, errMsg)
}
return db, nil
}
Expand Down
32 changes: 27 additions & 5 deletions bindings/go/libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,36 @@ func TestSync(t *testing.T) {
})
}

func TestRemote(t *testing.T) {
primaryUrl := os.Getenv("LIBSQL_PRIMARY_URL")
if primaryUrl == "" {
t.Skip("LIBSQL_PRIMARY_URL is not set")
return
}
authToken := os.Getenv("LIBSQL_AUTH_TOKEN")
db, err := sql.Open("libsql", primaryUrl+"?authToken="+authToken)
if err != nil {
t.Fatal(err)
}
tableName := fmt.Sprintf("test_%d", time.Now().UnixNano())
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, 0, 0, 0))
if err != nil {
t.Fatal(err)
}
}

func runFileTest(t *testing.T, test func(*testing.T, *sql.DB)) {
t.Parallel()
dir, err := os.MkdirTemp("", "libsql-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := sql.Open("libsql", dir+"/test.db")
db, err := sql.Open("libsql", "file:"+dir+"/test.db")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -266,7 +288,7 @@ func runMemoryAndFileTests(t *testing.T, test func(*testing.T, *sql.DB)) {

func TestErrorNonUtf8URL(t *testing.T) {
t.Parallel()
db, err := sql.Open("libsql", "a\xc5z")
db, err := sql.Open("libsql", "file:a\xc5z")
if err == nil {
defer func() {
if err := db.Close(); err != nil {
Expand All @@ -275,7 +297,7 @@ func TestErrorNonUtf8URL(t *testing.T) {
}()
t.Fatal("expected error")
}
if err.Error() != "failed to open database a\xc5z\nerror code = 1: Wrong URL: invalid utf-8 sequence of 1 bytes from index 1" {
if err.Error() != "failed to open local database file:a\xc5z\nerror code = 1: Wrong URL: invalid utf-8 sequence of 1 bytes from index 6" {
t.Fatal("unexpected error:", err)
}
}
Expand All @@ -299,7 +321,7 @@ func TestErrorWrongURL(t *testing.T) {

func TestErrorCanNotConnect(t *testing.T) {
t.Parallel()
db, err := sql.Open("libsql", "/root/test.db")
db, err := sql.Open("libsql", "file:/root/test.db")
if err != nil {
t.Fatal(err)
}
Expand All @@ -317,7 +339,7 @@ func TestErrorCanNotConnect(t *testing.T) {
}()
t.Fatal("expected error")
}
if err.Error() != "failed to connect to database\nerror code = 1: Unable to connect: Failed to connect to database: `/root/test.db`" {
if err.Error() != "failed to connect to database\nerror code = 1: Unable to connect: Failed to connect to database: `file:/root/test.db`" {
t.Fatal("unexpected error:", err)
}
}
Expand Down

0 comments on commit c23f4d6

Please sign in to comment.