Skip to content

Commit

Permalink
Go/C bindings add authToken support
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Jastrzebski <[email protected]>
  • Loading branch information
haaawk committed Sep 2, 2023
1 parent 160b965 commit ba31e0e
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 31 deletions.
6 changes: 5 additions & 1 deletion crates/bindings/c/include/libsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ extern "C" {

int libsql_sync(libsql_database_t db, const char **out_err_msg);

int libsql_open_sync(const char *db_path, const char *primary_url, libsql_database_t *out_db, const char **out_err_msg);
int libsql_open_sync(const char *db_path,
const char *primary_url,
const char *auth_token,
libsql_database_t *out_db,
const char **out_err_msg);

int libsql_open_ext(const char *url, libsql_database_t *out_db, const char **out_err_msg);

Expand Down
13 changes: 11 additions & 2 deletions crates/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub unsafe extern "C" fn libsql_sync(
pub unsafe extern "C" fn libsql_open_sync(
db_path: *const std::ffi::c_char,
primary_url: *const std::ffi::c_char,
auth_token: *const std::ffi::c_char,
out_db: *mut libsql_database_t,
out_err_msg: *mut *const std::ffi::c_char,
) -> std::ffi::c_int {
Expand All @@ -66,13 +67,21 @@ pub unsafe extern "C" fn libsql_open_sync(
Ok(url) => url,
Err(e) => {
set_err_msg(format!("Wrong URL: {}", e.to_string()), out_err_msg);
return 1;
return 2;
}
};
let auth_token = unsafe { std::ffi::CStr::from_ptr(auth_token) };
let auth_token = match auth_token.to_str() {
Ok(token) => token,
Err(e) => {
set_err_msg(format!("Wrong Auth Token: {}", e.to_string()), out_err_msg);
return 3;
}
};
match RT.block_on(libsql::v2::Database::open_with_sync(
db_path.to_string(),
primary_url,
"",
auth_token,
)) {
Ok(db) => {
let db = Box::leak(Box::new(libsql_database { db }));
Expand Down
20 changes: 11 additions & 9 deletions crates/bindings/go/libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ func init() {
sql.Register("libsql", driver{})
}

func NewEmbeddedReplicaConnector(dbPath, primaryUrl string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, 0)
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, 0)
}

func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, syncInterval)
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, syncInterval)
}

type driver struct{}
Expand All @@ -48,7 +48,7 @@ func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
}

func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
return openConnector(dbPath, "", 0)
return openConnector(dbPath, "", "", 0)
}

func libsqlSync(nativeDbPtr C.libsql_database_t) error {
Expand All @@ -60,13 +60,13 @@ func libsqlSync(nativeDbPtr C.libsql_database_t) error {
return nil
}

func openConnector(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) {
func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var nativeDbPtr C.libsql_database_t
var err error
var closeCh chan struct{}
var closeAckCh chan struct{}
if primaryUrl != "" {
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl)
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -160,15 +160,17 @@ func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
return db, nil
}

func libsqlOpenWithSync(dbPath, primaryUrl string) (C.libsql_database_t, error) {
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
dbPathNativeString := C.CString(dbPath)
defer C.free(unsafe.Pointer(dbPathNativeString))
primaryUrlNativeString := C.CString(primaryUrl)
defer C.free(unsafe.Pointer(primaryUrlNativeString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, &db, &errMsg)
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
}
Expand Down
44 changes: 25 additions & 19 deletions crates/bindings/go/libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"time"
)

func executeSql(t *testing.T, primaryUrl, sql string) {
func executeSql(t *testing.T, primaryUrl, authToken, sql string) {
type statement struct {
Query string `json:"q"`
}
Expand Down Expand Up @@ -56,6 +56,11 @@ func executeSql(t *testing.T, primaryUrl, sql string) {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")

if authToken != "" {
req.Header.Set("Authorization", "Bearer "+authToken)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -90,44 +95,45 @@ func executeSql(t *testing.T, primaryUrl, sql string) {
}
}

func insertRow(t *testing.T, dbUrl, tableName string, id int) {
executeSql(t, dbUrl, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id))
func insertRow(t *testing.T, dbUrl, authToken, tableName string, id int) {
executeSql(t, dbUrl, authToken, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id))
}

func insertRows(t *testing.T, dbUrl, tableName string, start, count int) {
func insertRows(t *testing.T, dbUrl, authToken, tableName string, start, count int) {
for i := 0; i < count; i++ {
insertRow(t, dbUrl, tableName, start+i)
insertRow(t, dbUrl, authToken, tableName, start+i)
}
}

func createTable(t *testing.T, dbPath string) string {
func createTable(t *testing.T, dbPath, authToken string) string {
tableName := fmt.Sprintf("test_%d", time.Now().UnixNano())
executeSql(t, dbPath, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
executeSql(t, dbPath, authToken, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
return tableName
}

func removeTable(t *testing.T, dbPath, tableName string) {
executeSql(t, dbPath, fmt.Sprintf("DROP TABLE %s;", tableName))
func removeTable(t *testing.T, dbPath, authToken, tableName string) {
executeSql(t, dbPath, authToken, fmt.Sprintf("DROP TABLE %s;", tableName))
}

func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector, sync func(connector *Connector)) {
func testSync(t *testing.T, connect func(dbPath, primaryUrl, authToken string) *Connector, sync func(connector *Connector)) {
primaryUrl := os.Getenv("LIBSQL_PRIMARY_URL")
if primaryUrl == "" {
t.Skip("LIBSQL_PRIMARY_URL is not set")
return
}
tableName := createTable(t, primaryUrl)
defer removeTable(t, primaryUrl, tableName)
authToken := os.Getenv("LIBSQL_AUTH_TOKEN")
tableName := createTable(t, primaryUrl, authToken)
defer removeTable(t, primaryUrl, authToken, tableName)

initialRowsCount := 5
insertRows(t, primaryUrl, tableName, 0, initialRowsCount)
insertRows(t, primaryUrl, authToken, tableName, 0, initialRowsCount)
dir, err := os.MkdirTemp("", "libsql-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)

connector := connect(dir+"/test.db", primaryUrl)
connector := connect(dir+"/test.db", primaryUrl, authToken)
db := sql.OpenDB(connector)
defer db.Close()

Expand Down Expand Up @@ -186,16 +192,16 @@ func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector,
}
}()
if iter+1 != iterCount {
insertRow(t, primaryUrl, tableName, initialRowsCount+iter)
insertRow(t, primaryUrl, authToken, tableName, initialRowsCount+iter)
sync(connector)
}
}
}

func TestAutoSync(t *testing.T) {
syncInterval := 1 * time.Second
testSync(t, func(dbPath, primaryUrl string) *Connector {
connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, syncInterval)
testSync(t, func(dbPath, primaryUrl, authToken string) *Connector {
connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken, syncInterval)
if err != nil {
t.Fatal(err)
}
Expand All @@ -206,8 +212,8 @@ func TestAutoSync(t *testing.T) {
}

func TestSync(t *testing.T) {
testSync(t, func(dbPath, primaryUrl string) *Connector {
connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl)
testSync(t, func(dbPath, primaryUrl, authToken string) *Connector {
connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit ba31e0e

Please sign in to comment.