diff --git a/crates/bindings/c/include/libsql.h b/crates/bindings/c/include/libsql.h index 4820e72ffd..c8c6d2599b 100644 --- a/crates/bindings/c/include/libsql.h +++ b/crates/bindings/c/include/libsql.h @@ -34,7 +34,11 @@ extern "C" { int libsql_sync(libsql_database_t db, const char **out_err_msg); -int libsql_open_sync(const char *db_path, const char *primary_url, libsql_database_t *out_db, const char **out_err_msg); +int libsql_open_sync(const char *db_path, + const char *primary_url, + const char *auth_token, + libsql_database_t *out_db, + const char **out_err_msg); int libsql_open_ext(const char *url, libsql_database_t *out_db, const char **out_err_msg); diff --git a/crates/bindings/c/src/lib.rs b/crates/bindings/c/src/lib.rs index bb84dda608..883da12ccc 100644 --- a/crates/bindings/c/src/lib.rs +++ b/crates/bindings/c/src/lib.rs @@ -50,6 +50,7 @@ pub unsafe extern "C" fn libsql_sync( pub unsafe extern "C" fn libsql_open_sync( db_path: *const std::ffi::c_char, primary_url: *const std::ffi::c_char, + auth_token: *const std::ffi::c_char, out_db: *mut libsql_database_t, out_err_msg: *mut *const std::ffi::c_char, ) -> std::ffi::c_int { @@ -66,13 +67,21 @@ pub unsafe extern "C" fn libsql_open_sync( Ok(url) => url, Err(e) => { set_err_msg(format!("Wrong URL: {}", e.to_string()), out_err_msg); - return 1; + return 2; + } + }; + let auth_token = unsafe { std::ffi::CStr::from_ptr(auth_token) }; + let auth_token = match auth_token.to_str() { + Ok(token) => token, + Err(e) => { + set_err_msg(format!("Wrong Auth Token: {}", e.to_string()), out_err_msg); + return 3; } }; match RT.block_on(libsql::v2::Database::open_with_sync( db_path.to_string(), primary_url, - "", + auth_token, )) { Ok(db) => { let db = Box::leak(Box::new(libsql_database { db })); diff --git a/crates/bindings/go/libsql.go b/crates/bindings/go/libsql.go index 881f00477c..c308dfc1c0 100644 --- a/crates/bindings/go/libsql.go +++ b/crates/bindings/go/libsql.go @@ -29,12 +29,12 @@ func init() { sql.Register("libsql", driver{}) } -func NewEmbeddedReplicaConnector(dbPath, primaryUrl string) (*Connector, error) { - return openConnector(dbPath, primaryUrl, 0) +func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) { + return openConnector(dbPath, primaryUrl, authToken, 0) } -func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) { - return openConnector(dbPath, primaryUrl, syncInterval) +func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) { + return openConnector(dbPath, primaryUrl, authToken, syncInterval) } type driver struct{} @@ -48,7 +48,7 @@ func (d driver) Open(dbPath string) (sqldriver.Conn, error) { } func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) { - return openConnector(dbPath, "", 0) + return openConnector(dbPath, "", "", 0) } func libsqlSync(nativeDbPtr C.libsql_database_t) error { @@ -60,13 +60,13 @@ func libsqlSync(nativeDbPtr C.libsql_database_t) error { return nil } -func openConnector(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) { +func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) { var nativeDbPtr C.libsql_database_t var err error var closeCh chan struct{} var closeAckCh chan struct{} if primaryUrl != "" { - nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl) + nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken) if err != nil { return nil, err } @@ -160,15 +160,17 @@ func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) { return db, nil } -func libsqlOpenWithSync(dbPath, primaryUrl string) (C.libsql_database_t, error) { +func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) { dbPathNativeString := C.CString(dbPath) defer C.free(unsafe.Pointer(dbPathNativeString)) primaryUrlNativeString := C.CString(primaryUrl) defer C.free(unsafe.Pointer(primaryUrlNativeString)) + authTokenNativeString := C.CString(authToken) + defer C.free(unsafe.Pointer(authTokenNativeString)) var db C.libsql_database_t var errMsg *C.char - statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, &db, &errMsg) + statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg) } diff --git a/crates/bindings/go/libsql_test.go b/crates/bindings/go/libsql_test.go index 0f48ec8e70..fe1775b955 100644 --- a/crates/bindings/go/libsql_test.go +++ b/crates/bindings/go/libsql_test.go @@ -15,7 +15,7 @@ import ( "time" ) -func executeSql(t *testing.T, primaryUrl, sql string) { +func executeSql(t *testing.T, primaryUrl, authToken, sql string) { type statement struct { Query string `json:"q"` } @@ -56,6 +56,11 @@ func executeSql(t *testing.T, primaryUrl, sql string) { t.Fatal(err) } req.Header.Set("Content-Type", "application/json") + + if authToken != "" { + req.Header.Set("Authorization", "Bearer "+authToken) + } + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) @@ -90,44 +95,45 @@ func executeSql(t *testing.T, primaryUrl, sql string) { } } -func insertRow(t *testing.T, dbUrl, tableName string, id int) { - executeSql(t, dbUrl, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id)) +func insertRow(t *testing.T, dbUrl, authToken, tableName string, id int) { + executeSql(t, dbUrl, authToken, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id)) } -func insertRows(t *testing.T, dbUrl, tableName string, start, count int) { +func insertRows(t *testing.T, dbUrl, authToken, tableName string, start, count int) { for i := 0; i < count; i++ { - insertRow(t, dbUrl, tableName, start+i) + insertRow(t, dbUrl, authToken, tableName, start+i) } } -func createTable(t *testing.T, dbPath string) string { +func createTable(t *testing.T, dbPath, authToken string) string { tableName := fmt.Sprintf("test_%d", time.Now().UnixNano()) - executeSql(t, dbPath, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName)) + executeSql(t, dbPath, authToken, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName)) return tableName } -func removeTable(t *testing.T, dbPath, tableName string) { - executeSql(t, dbPath, fmt.Sprintf("DROP TABLE %s;", tableName)) +func removeTable(t *testing.T, dbPath, authToken, tableName string) { + executeSql(t, dbPath, authToken, fmt.Sprintf("DROP TABLE %s;", tableName)) } -func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector, sync func(connector *Connector)) { +func testSync(t *testing.T, connect func(dbPath, primaryUrl, authToken string) *Connector, sync func(connector *Connector)) { primaryUrl := os.Getenv("LIBSQL_PRIMARY_URL") if primaryUrl == "" { t.Skip("LIBSQL_PRIMARY_URL is not set") return } - tableName := createTable(t, primaryUrl) - defer removeTable(t, primaryUrl, tableName) + authToken := os.Getenv("LIBSQL_AUTH_TOKEN") + tableName := createTable(t, primaryUrl, authToken) + defer removeTable(t, primaryUrl, authToken, tableName) initialRowsCount := 5 - insertRows(t, primaryUrl, tableName, 0, initialRowsCount) + insertRows(t, primaryUrl, authToken, tableName, 0, initialRowsCount) dir, err := os.MkdirTemp("", "libsql-*") if err != nil { t.Fatal(err) } defer os.RemoveAll(dir) - connector := connect(dir+"/test.db", primaryUrl) + connector := connect(dir+"/test.db", primaryUrl, authToken) db := sql.OpenDB(connector) defer db.Close() @@ -186,7 +192,7 @@ func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector, } }() if iter+1 != iterCount { - insertRow(t, primaryUrl, tableName, initialRowsCount+iter) + insertRow(t, primaryUrl, authToken, tableName, initialRowsCount+iter) sync(connector) } } @@ -194,8 +200,8 @@ func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector, func TestAutoSync(t *testing.T) { syncInterval := 1 * time.Second - testSync(t, func(dbPath, primaryUrl string) *Connector { - connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, syncInterval) + testSync(t, func(dbPath, primaryUrl, authToken string) *Connector { + connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken, syncInterval) if err != nil { t.Fatal(err) } @@ -206,8 +212,8 @@ func TestAutoSync(t *testing.T) { } func TestSync(t *testing.T) { - testSync(t, func(dbPath, primaryUrl string) *Connector { - connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl) + testSync(t, func(dbPath, primaryUrl, authToken string) *Connector { + connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken) if err != nil { t.Fatal(err) }