Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go/C bindings add authToken support #349

Merged
merged 1 commit into from
Sep 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/bindings/c/include/libsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ extern "C" {

int libsql_sync(libsql_database_t db, const char **out_err_msg);

int libsql_open_sync(const char *db_path, const char *primary_url, libsql_database_t *out_db, const char **out_err_msg);
int libsql_open_sync(const char *db_path,
const char *primary_url,
const char *auth_token,
libsql_database_t *out_db,
const char **out_err_msg);

int libsql_open_ext(const char *url, libsql_database_t *out_db, const char **out_err_msg);

Expand Down
13 changes: 11 additions & 2 deletions crates/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub unsafe extern "C" fn libsql_sync(
pub unsafe extern "C" fn libsql_open_sync(
db_path: *const std::ffi::c_char,
primary_url: *const std::ffi::c_char,
auth_token: *const std::ffi::c_char,
out_db: *mut libsql_database_t,
out_err_msg: *mut *const std::ffi::c_char,
) -> std::ffi::c_int {
Expand All @@ -66,13 +67,21 @@ pub unsafe extern "C" fn libsql_open_sync(
Ok(url) => url,
Err(e) => {
set_err_msg(format!("Wrong URL: {}", e.to_string()), out_err_msg);
return 1;
return 2;
}
};
let auth_token = unsafe { std::ffi::CStr::from_ptr(auth_token) };
let auth_token = match auth_token.to_str() {
Ok(token) => token,
Err(e) => {
set_err_msg(format!("Wrong Auth Token: {}", e.to_string()), out_err_msg);
return 3;
}
};
match RT.block_on(libsql::v2::Database::open_with_sync(
db_path.to_string(),
primary_url,
"",
auth_token,
)) {
Ok(db) => {
let db = Box::leak(Box::new(libsql_database { db }));
Expand Down
20 changes: 11 additions & 9 deletions crates/bindings/go/libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ func init() {
sql.Register("libsql", driver{})
}

func NewEmbeddedReplicaConnector(dbPath, primaryUrl string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, 0)
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, 0)
}

func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, syncInterval)
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, syncInterval)
}

type driver struct{}
Expand All @@ -48,7 +48,7 @@ func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
}

func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
return openConnector(dbPath, "", 0)
return openConnector(dbPath, "", "", 0)
}

func libsqlSync(nativeDbPtr C.libsql_database_t) error {
Expand All @@ -60,13 +60,13 @@ func libsqlSync(nativeDbPtr C.libsql_database_t) error {
return nil
}

func openConnector(dbPath, primaryUrl string, syncInterval time.Duration) (*Connector, error) {
func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var nativeDbPtr C.libsql_database_t
var err error
var closeCh chan struct{}
var closeAckCh chan struct{}
if primaryUrl != "" {
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl)
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -160,15 +160,17 @@ func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
return db, nil
}

func libsqlOpenWithSync(dbPath, primaryUrl string) (C.libsql_database_t, error) {
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
dbPathNativeString := C.CString(dbPath)
defer C.free(unsafe.Pointer(dbPathNativeString))
primaryUrlNativeString := C.CString(primaryUrl)
defer C.free(unsafe.Pointer(primaryUrlNativeString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, &db, &errMsg)
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
}
Expand Down
44 changes: 25 additions & 19 deletions crates/bindings/go/libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"time"
)

func executeSql(t *testing.T, primaryUrl, sql string) {
func executeSql(t *testing.T, primaryUrl, authToken, sql string) {
type statement struct {
Query string `json:"q"`
}
Expand Down Expand Up @@ -56,6 +56,11 @@ func executeSql(t *testing.T, primaryUrl, sql string) {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")

if authToken != "" {
req.Header.Set("Authorization", "Bearer "+authToken)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -90,44 +95,45 @@ func executeSql(t *testing.T, primaryUrl, sql string) {
}
}

func insertRow(t *testing.T, dbUrl, tableName string, id int) {
executeSql(t, dbUrl, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id))
func insertRow(t *testing.T, dbUrl, authToken, tableName string, id int) {
executeSql(t, dbUrl, authToken, fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, id, id, id))
}

func insertRows(t *testing.T, dbUrl, tableName string, start, count int) {
func insertRows(t *testing.T, dbUrl, authToken, tableName string, start, count int) {
for i := 0; i < count; i++ {
insertRow(t, dbUrl, tableName, start+i)
insertRow(t, dbUrl, authToken, tableName, start+i)
}
}

func createTable(t *testing.T, dbPath string) string {
func createTable(t *testing.T, dbPath, authToken string) string {
tableName := fmt.Sprintf("test_%d", time.Now().UnixNano())
executeSql(t, dbPath, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
executeSql(t, dbPath, authToken, fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
return tableName
}

func removeTable(t *testing.T, dbPath, tableName string) {
executeSql(t, dbPath, fmt.Sprintf("DROP TABLE %s;", tableName))
func removeTable(t *testing.T, dbPath, authToken, tableName string) {
executeSql(t, dbPath, authToken, fmt.Sprintf("DROP TABLE %s;", tableName))
}

func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector, sync func(connector *Connector)) {
func testSync(t *testing.T, connect func(dbPath, primaryUrl, authToken string) *Connector, sync func(connector *Connector)) {
primaryUrl := os.Getenv("LIBSQL_PRIMARY_URL")
if primaryUrl == "" {
t.Skip("LIBSQL_PRIMARY_URL is not set")
return
}
tableName := createTable(t, primaryUrl)
defer removeTable(t, primaryUrl, tableName)
authToken := os.Getenv("LIBSQL_AUTH_TOKEN")
tableName := createTable(t, primaryUrl, authToken)
defer removeTable(t, primaryUrl, authToken, tableName)

initialRowsCount := 5
insertRows(t, primaryUrl, tableName, 0, initialRowsCount)
insertRows(t, primaryUrl, authToken, tableName, 0, initialRowsCount)
dir, err := os.MkdirTemp("", "libsql-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)

connector := connect(dir+"/test.db", primaryUrl)
connector := connect(dir+"/test.db", primaryUrl, authToken)
db := sql.OpenDB(connector)
defer db.Close()

Expand Down Expand Up @@ -186,16 +192,16 @@ func testSync(t *testing.T, connect func(dbPath, primaryUrl string) *Connector,
}
}()
if iter+1 != iterCount {
insertRow(t, primaryUrl, tableName, initialRowsCount+iter)
insertRow(t, primaryUrl, authToken, tableName, initialRowsCount+iter)
sync(connector)
}
}
}

func TestAutoSync(t *testing.T) {
syncInterval := 1 * time.Second
testSync(t, func(dbPath, primaryUrl string) *Connector {
connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, syncInterval)
testSync(t, func(dbPath, primaryUrl, authToken string) *Connector {
connector, err := NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken, syncInterval)
if err != nil {
t.Fatal(err)
}
Expand All @@ -206,8 +212,8 @@ func TestAutoSync(t *testing.T) {
}

func TestSync(t *testing.T) {
testSync(t, func(dbPath, primaryUrl string) *Connector {
connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl)
testSync(t, func(dbPath, primaryUrl, authToken string) *Connector {
connector, err := NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading