diff --git a/CHANGELOG.md b/CHANGELOG.md index ccea26c..c6c0702 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## 0.1.x (Unreleased) +- Fix thread safety issue in connector + ## 0.2.0 (2022-11-18) - Support for DirectResults diff --git a/README.md b/README.md index b345a95..f550b43 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Databricks SQL Driver for Go (Beta) +# Databricks SQL Driver for Go ![http://www.apache.org/licenses/LICENSE-2.0.txt](http://img.shields.io/:license-Apache%202-brightgreen.svg) @@ -7,47 +7,71 @@ This repo contains a Databricks SQL Driver for Go's [database/sql](https://golang.org/pkg/database/sql) package. It can be used to connect and query Databricks clusters and SQL Warehouses. -**NOTE: This Driver is Beta.** - ## Documentation -Full documentation is not yet available. See below for usage examples. +See `doc.go` for full documentation or the Databricks documentation for [SQL Driver for Go](https://docs.databricks.com/dev-tools/go-sql-driver.html). ## Usage ```go import ( - "database/sql" - "time" - - _ "github.com/databricks/databricks-sql-go" + "context" + "database/sql" + _ "github.com/databricks/databricks-sql-go" ) -db, err := sql.Open("databricks", "token:********@********.databricks.com/sql/1.0/endpoints/********") +db, err := sql.Open("databricks", "token:********@********.databricks.com:443/sql/1.0/endpoints/********") if err != nil { - panic(err) + panic(err) } +defer db.Close() -rows, err := db.Query("SELECT 1") +rows, err := db.QueryContext(context.Background(), "SELECT 1") +defer rows.Close() ``` Additional usage examples are available [here](https://github.com/databricks/databricks-sql-go/tree/main/examples). 
-### DSN (Data Source Name) +### Connecting with DSN (Data Source Name) The DSN format is: ``` -token:[your token]@[Workspace hostname][Endpoint HTTP Path]?param=value +token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?param=value ``` You can set query timeout value by appending a `timeout` query parameter (in seconds) and you can set max rows to retrieve per network request by setting the `maxRows` query parameter: ``` -token:[your token]@[Workspace hostname][Endpoint HTTP Path]?timeout=1000&maxRows=1000 +token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?timeout=1000&maxRows=1000 +``` + +### Connecting with a new Connector + +You can also connect with a new connector object. For example: + +```go +import ( +"database/sql" + dbsql "github.com/databricks/databricks-sql-go" +) + +connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(), + dbsql.WithPort(), + dbsql.WithHTTPPath(), + dbsql.WithAccessToken(), +) +if err != nil { + log.Fatal(err) +} +db := sql.OpenDB(connector) +defer db.Close() ``` +View `doc.go` or `connector.go` to understand all the functional options available when creating a new connector object. + ## Develop ### Lint diff --git a/connection.go b/connection.go index 29a15a0..6500732 100644 --- a/connection.go +++ b/connection.go @@ -21,27 +21,27 @@ type conn struct { session *cli_service.TOpenSessionResp } -// The driver does not really implement prepared statements. +// Prepare prepares a statement with the query bound to this connection. func (c *conn) Prepare(query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } -// The driver does not really implement prepared statements. +// PrepareContext prepares a statement with the query bound to this connection. +// Currently, PrepareContext does not use context and is functionally equivalent to Prepare. 
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } +// Close closes the session. +// sql package maintains a free pool of connections and only calls Close when there's a surplus of idle connections. func (c *conn) Close() error { log := logger.WithContext(c.id, "", "") ctx := driverctx.NewContextWithConnId(context.Background(), c.id) - sentinel := sentinel.Sentinel{ - OnDoneFn: func(statusResp any) (any, error) { - return c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{ - SessionHandle: c.session.SessionHandle, - }) - }, - } - _, _, err := sentinel.Watch(ctx, c.cfg.PollInterval, 15*time.Second) + + _, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{ + SessionHandle: c.session.SessionHandle, + }) + if err != nil { log.Err(err).Msg("databricks: failed to close connection") return wrapErr(err, "failed to close connection") @@ -49,20 +49,22 @@ func (c *conn) Close() error { return nil } -// Not supported in Databricks +// Not supported in Databricks. func (c *conn) Begin() (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } -// Not supported in Databricks +// Not supported in Databricks. func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } +// Ping attempts to verify that the server is accessible. +// Returns ErrBadConn if ping fails and consequently DB.Ping will remove the conn from the pool. 
func (c *conn) Ping(ctx context.Context) error { log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") ctx = driverctx.NewContextWithConnId(ctx, c.id) - ctx1, cancel := context.WithTimeout(ctx, 15*time.Second) + ctx1, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() _, err := c.QueryContext(ctx1, "select 1", nil) if err != nil { @@ -72,12 +74,13 @@ func (c *conn) Ping(ctx context.Context) error { return nil } -// Implementation of SessionResetter +// ResetSession is called prior to executing a query on the connection. +// The session with this driver does not have any important state to reset before re-use. func (c *conn) ResetSession(ctx context.Context) error { - // For now our session does not have any important state to reset before re-use return nil } +// IsValid signals whether a connection is valid or if it should be discarded. func (c *conn) IsValid() bool { return c.session.GetStatus().StatusCode == cli_service.TStatusCode_SUCCESS_STATUS } @@ -88,8 +91,11 @@ func (c *conn) IsValid() bool { // ExecContext honors the context timeout and return when it is canceled. 
// Statement ExecContext is the same as connection ExecContext func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") + corrId := driverctx.CorrelationIdFromContext(ctx) + log := logger.WithContext(c.id, corrId, "") msg, start := logger.Track("ExecContext") + defer log.Duration(msg, start) + ctx = driverctx.NewContextWithConnId(ctx, c.id) if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) @@ -97,14 +103,26 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) if exStmtResp != nil && exStmtResp.OperationHandle != nil { - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) + // we have an operation id so update the logger + log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) + + // since we have an operation handle we can close the operation if necessary + alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil + newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) + if !alreadyClosed && (opStatusResp == nil || opStatusResp.GetOperationState() != cli_service.TOperationState_CLOSED_STATE) { + _, err1 := c.client.CloseOperation(newCtx, &cli_service.TCloseOperationReq{ + OperationHandle: exStmtResp.OperationHandle, + }) + if err1 != nil { + log.Err(err1).Msg("databricks: failed to close operation after executing statement") + } + } } - defer log.Duration(msg, start) - if err != nil { log.Err(err).Msgf("databricks: failed to execute query: query %s", query) return nil, wrapErrf(err, "failed to execute query") } + res := result{AffectedRows: opStatusResp.GetNumModifiedRows()} return &res, nil 
@@ -140,22 +158,9 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam // hold on to the operation handle opHandle := exStmtResp.OperationHandle - rows := rows{ - connId: c.id, - correlationId: corrId, - client: c.client, - opHandle: opHandle, - pageSize: int64(c.cfg.MaxRows), - location: c.cfg.Location, - } - - if exStmtResp.DirectResults != nil { - // return results - rows.fetchResults = exStmtResp.DirectResults.ResultSet - rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + rows := NewRows(c.id, corrId, c.client, opHandle, int64(c.cfg.MaxRows), c.cfg.Location, exStmtResp.DirectResults) - } - return &rows, nil + return rows, nil } @@ -168,10 +173,12 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa if err != nil { return exStmtResp, nil, err } - // hold on to the operation handle opHandle := exStmtResp.OperationHandle if opHandle != nil && opHandle.OperationId != nil { - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID)) + log = logger.WithContext( + c.id, + driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID), + ) } if exStmtResp.DirectResults != nil { @@ -181,15 +188,18 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa // terminal states // good case cli_service.TOperationState_FINISHED_STATE: - // return results return exStmtResp, opStatus, nil // bad - case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: - // do we need to close the operation in these cases? 
+ case cli_service.TOperationState_CANCELED_STATE, + cli_service.TOperationState_CLOSED_STATE, + cli_service.TOperationState_ERROR_STATE, + cli_service.TOperationState_TIMEDOUT_STATE: logBadQueryState(log, opStatus) return exStmtResp, opStatus, errors.New(opStatus.GetDisplayMessage()) // live states - case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE: + case cli_service.TOperationState_INITIALIZED_STATE, + cli_service.TOperationState_PENDING_STATE, + cli_service.TOperationState_RUNNING_STATE: statusResp, err := c.pollOperation(ctx, opHandle) if err != nil { return exStmtResp, statusResp, err @@ -198,16 +208,18 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa // terminal states // good case cli_service.TOperationState_FINISHED_STATE: - // return handle to fetch results later - return exStmtResp, opStatus, nil + return exStmtResp, statusResp, nil // bad - case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: + case cli_service.TOperationState_CANCELED_STATE, + cli_service.TOperationState_CLOSED_STATE, + cli_service.TOperationState_ERROR_STATE, + cli_service.TOperationState_TIMEDOUT_STATE: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage()) + return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) // live states default: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + return exStmtResp, statusResp, errors.New("invalid operation state. 
This should not have happened") } // weird states default: @@ -224,10 +236,12 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa // terminal states // good case cli_service.TOperationState_FINISHED_STATE: - // return handle to fetch results later return exStmtResp, statusResp, nil // bad - case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: + case cli_service.TOperationState_CANCELED_STATE, + cli_service.TOperationState_CLOSED_STATE, + cli_service.TOperationState_ERROR_STATE, + cli_service.TOperationState_TIMEDOUT_STATE: logBadQueryState(log, statusResp) return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) // live states @@ -246,39 +260,52 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, "") - sentinel := sentinel.Sentinel{ - OnDoneFn: func(statusResp any) (any, error) { - req := cli_service.TExecuteStatementReq{ - SessionHandle: c.session.SessionHandle, - Statement: query, - RunAsync: c.cfg.RunAsync, - QueryTimeout: int64(c.cfg.QueryTimeout / time.Second), - // this is specific for databricks. 
It shortcuts server roundtrips - GetDirectResults: &cli_service.TSparkGetDirectResults{ - MaxRows: int64(c.cfg.MaxRows), - }, - // CanReadArrowResult_: &t, - // CanDecompressLZ4Result_: &f, - // CanDownloadResult_: &t, - } - ctx = driverctx.NewContextWithConnId(ctx, c.id) - resp, err := c.client.ExecuteStatement(ctx, &req) - return resp, wrapErr(err, "failed to execute statement") - }, - OnCancelFn: func() (any, error) { - log.Warn().Msg("databricks: execute statement canceled while creation operation") - return nil, nil + + req := cli_service.TExecuteStatementReq{ + SessionHandle: c.session.SessionHandle, + Statement: query, + RunAsync: c.cfg.RunAsync, + QueryTimeout: int64(c.cfg.QueryTimeout / time.Second), + GetDirectResults: &cli_service.TSparkGetDirectResults{ + MaxRows: int64(c.cfg.MaxRows), }, } - _, res, err := sentinel.Watch(ctx, c.cfg.PollInterval, c.cfg.QueryTimeout) - if err != nil { - return nil, err + + ctx = driverctx.NewContextWithConnId(ctx, c.id) + resp, err := c.client.ExecuteStatement(ctx, &req) + + var shouldCancel = func(resp *cli_service.TExecuteStatementResp) bool { + if resp == nil { + return false + } + hasHandle := resp.OperationHandle != nil + isOpen := resp.DirectResults != nil && resp.DirectResults.CloseOperation == nil + return hasHandle && isOpen } - exStmtResp, ok := res.(*cli_service.TExecuteStatementResp) - if !ok { - return exStmtResp, errors.New("databricks: invalid execute statement response") + + select { + default: + case <-ctx.Done(): + newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) + // in case context is done, we need to cancel the operation if necessary + if err == nil && shouldCancel(resp) { + log.Debug().Msg("databricks: canceling query") + _, err1 := c.client.CancelOperation(newCtx, &cli_service.TCancelOperationReq{ + OperationHandle: resp.GetOperationHandle(), + }) + + if err1 != nil { + log.Err(err).Msgf("databricks: cancel failed") + } + 
log.Debug().Msgf("databricks: cancel success") + + } else { + log.Debug().Msg("databricks: query did not need cancellation") + } + return nil, ctx.Err() } - return exStmtResp, err + + return resp, err } func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { @@ -301,12 +328,13 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String()) } return func() bool { - // which other states? if err != nil { return true } switch statusResp.GetOperationState() { - case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE: + case cli_service.TOperationState_INITIALIZED_STATE, + cli_service.TOperationState_PENDING_STATE, + cli_service.TOperationState_RUNNING_STATE: return false default: log.Debug().Msg("databricks: polling done") diff --git a/connection_test.go b/connection_test.go index 0d7f252..070aa2c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -86,6 +86,241 @@ func TestConn_executeStatement(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, executeStatementCount) }) + + t.Run("ExecStatement should close operation on success", func(t *testing.T) { + var executeStatementCount, closeOperationCount int + executeStatementResp := &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 223, 34, 54}, + Secret: []byte("b"), + }, + }, + DirectResults: &cli_service.TSparkDirectResults{ + OperationStatus: &cli_service.TGetOperationStatusResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationState: 
cli_service.TOperationStatePtr(cli_service.TOperationState_ERROR_STATE), + ErrorMessage: strPtr("error message"), + DisplayMessage: strPtr("display message"), + }, + ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + ResultSet: &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + }, + } + + testClient := &client.TestClient{ + FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) { + executeStatementCount++ + return executeStatementResp, nil + }, + FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) { + closeOperationCount++ + return &cli_service.TCloseOperationResp{}, nil + }, + } + testConn := &conn{ + session: getTestSession(), + client: testClient, + cfg: config.WithDefaults(), + } + + type opStateTest struct { + state cli_service.TOperationState + err string + closeOperationCount int + } + + // test behaviour with all terminal operation states + operationStateTests := []opStateTest{ + {state: cli_service.TOperationState_ERROR_STATE, err: "error state", closeOperationCount: 1}, + {state: cli_service.TOperationState_FINISHED_STATE, err: "", closeOperationCount: 1}, + {state: cli_service.TOperationState_CANCELED_STATE, err: "cancelled state", closeOperationCount: 1}, + {state: cli_service.TOperationState_CLOSED_STATE, err: "closed state", closeOperationCount: 0}, + {state: cli_service.TOperationState_TIMEDOUT_STATE, err: "timeout state", closeOperationCount: 1}, + } + + for _, opTest := range operationStateTests { + closeOperationCount = 0 + executeStatementCount = 0 + executeStatementResp.DirectResults.OperationStatus.OperationState = &opTest.state + executeStatementResp.DirectResults.OperationStatus.DisplayMessage = 
&opTest.err + _, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{}) + if opTest.err == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, opTest.err) + } + assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, opTest.closeOperationCount, closeOperationCount) + } + + // if the execute statement response contains direct results with a non-nil CloseOperation member + // we shouldn't call close + closeOperationCount = 0 + executeStatementCount = 0 + executeStatementResp.DirectResults.CloseOperation = &cli_service.TCloseOperationResp{} + finished := cli_service.TOperationState_FINISHED_STATE + executeStatementResp.DirectResults.OperationStatus.OperationState = &finished + _, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{}) + assert.NoError(t, err) + assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, 0, closeOperationCount) + }) + + t.Run("executeStatement should not call cancel if not needed", func(t *testing.T) { + var executeStatementCount int + var cancelOperationCount int + var cancel context.CancelFunc + executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) { + executeStatementCount++ + cancel() + executeStatementResp := &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54}, + Secret: []byte("b"), + }, + }, + DirectResults: &cli_service.TSparkDirectResults{ + OperationStatus: &cli_service.TGetOperationStatusResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + ErrorMessage: strPtr("error message"), + 
DisplayMessage: strPtr("display message"), + }, + ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + ResultSet: &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + CloseOperation: &cli_service.TCloseOperationResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + }, + } + return executeStatementResp, nil + } + cancelOperation := func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) { + cancelOperationCount++ + cancelOperationResp := &cli_service.TCancelOperationResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + } + return cancelOperationResp, nil + } + testClient := &client.TestClient{ + FnExecuteStatement: executeStatement, + FnCancelOperation: cancelOperation, + } + testConn := &conn{ + session: getTestSession(), + client: testClient, + cfg: config.WithDefaults(), + } + + ctx := context.Background() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + _, err := testConn.executeStatement(ctx, "select 1", []driver.NamedValue{}) + + assert.Error(t, err) + assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, 0, cancelOperationCount) + }) + t.Run("executeStatement should call cancel if needed", func(t *testing.T) { + var executeStatementCount int + var cancelOperationCount int + var cancel context.CancelFunc + executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) { + executeStatementCount++ + cancel() + executeStatementResp := &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: 
&cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54}, + Secret: []byte("b"), + }, + }, + DirectResults: &cli_service.TSparkDirectResults{ + OperationStatus: &cli_service.TGetOperationStatusResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + ErrorMessage: strPtr("error message"), + DisplayMessage: strPtr("display message"), + }, + ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + ResultSet: &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + }, + }, + } + return executeStatementResp, nil + } + cancelOperation := func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) { + cancelOperationCount++ + cancelOperationResp := &cli_service.TCancelOperationResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + } + return cancelOperationResp, nil + } + testClient := &client.TestClient{ + FnExecuteStatement: executeStatement, + FnCancelOperation: cancelOperation, + } + testConn := &conn{ + session: getTestSession(), + client: testClient, + cfg: config.WithDefaults(), + } + ctx := context.Background() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + _, err := testConn.executeStatement(ctx, "select 1", []driver.NamedValue{}) + + assert.Error(t, err) + assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, 1, cancelOperationCount) + }) + } func TestConn_pollOperation(t *testing.T) { @@ -109,7 +344,7 @@ func TestConn_pollOperation(t *testing.T) { } res, err := testConn.pollOperation(context.Background(), &cli_service.TOperationHandle{ OperationId: &cli_service.THandleIdentifier{ - GUID: []byte{1, 2, 3, 
4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54}, + GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 4, 7, 8, 223, 34, 54}, Secret: []byte("b"), }, }) @@ -496,11 +731,13 @@ func TestConn_runQuery(t *testing.T) { } return executeStatementResp, nil } + var numModRows int64 = 2 getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) { getOperationStatusCount++ getOperationStatusResp := &cli_service.TGetOperationStatusResp{ - OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + NumModifiedRows: &numModRows, } return getOperationStatusResp, nil } @@ -521,6 +758,7 @@ func TestConn_runQuery(t *testing.T) { assert.Equal(t, 1, getOperationStatusCount) assert.NotNil(t, exStmtResp) assert.NotNil(t, opStatusResp) + assert.Equal(t, &numModRows, opStatusResp.NumModifiedRows) }) t.Run("runQuery should return resp and error when query is canceled", func(t *testing.T) { @@ -540,11 +778,13 @@ func TestConn_runQuery(t *testing.T) { } return executeStatementResp, nil } + var numModRows int64 = 3 getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) { getOperationStatusCount++ getOperationStatusResp := &cli_service.TGetOperationStatusResp{ - OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_CANCELED_STATE), + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_CANCELED_STATE), + NumModifiedRows: &numModRows, } return getOperationStatusResp, nil } @@ -565,6 +805,7 @@ func TestConn_runQuery(t *testing.T) { assert.Equal(t, 1, getOperationStatusCount) assert.NotNil(t, exStmtResp) assert.NotNil(t, opStatusResp) + assert.Equal(t, &numModRows, opStatusResp.NumModifiedRows) }) t.Run("runQuery should return resp when query is finished with DirectResults", 
func(t *testing.T) { @@ -696,11 +937,12 @@ func TestConn_runQuery(t *testing.T) { } return executeStatementResp, nil } - + var numModRows int64 = 3 getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) { getOperationStatusCount++ getOperationStatusResp := &cli_service.TGetOperationStatusResp{ - OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + NumModifiedRows: &numModRows, } return getOperationStatusResp, nil } @@ -718,6 +960,7 @@ func TestConn_runQuery(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, &numModRows, opStatusResp.NumModifiedRows) assert.Equal(t, 1, getOperationStatusCount) assert.NotNil(t, exStmtResp) assert.NotNil(t, opStatusResp) @@ -816,6 +1059,11 @@ func TestConn_ExecContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, + FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) { + ctxErr := ctx.Err() + assert.NoError(t, ctxErr) + return &cli_service.TCloseOperationResp{}, nil + }, } testConn := &conn{ session: getTestSession(), @@ -859,6 +1107,11 @@ func TestConn_ExecContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, + FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) { + ctxErr := ctx.Err() + assert.NoError(t, ctxErr) + return &cli_service.TCloseOperationResp{}, nil + }, } testConn := &conn{ session: getTestSession(), @@ -873,6 +1126,71 @@ func TestConn_ExecContext(t *testing.T) { assert.Equal(t, int64(10), rowsAffected) assert.Equal(t, 1, executeStatementCount) }) + t.Run("ExecContext uses new 
context to close operation", func(t *testing.T) { + var executeStatementCount, getOperationStatusCount, closeOperationCount, cancelOperationCount int + var cancel context.CancelFunc + executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) { + executeStatementCount++ + executeStatementResp := &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54}, + Secret: []byte("b"), + }, + }, + } + return executeStatementResp, nil + } + + getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) { + getOperationStatusCount++ + cancel() + getOperationStatusResp := &cli_service.TGetOperationStatusResp{ + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + NumModifiedRows: thrift.Int64Ptr(10), + } + return getOperationStatusResp, nil + } + + testClient := &client.TestClient{ + FnExecuteStatement: executeStatement, + FnGetOperationStatus: getOperationStatus, + FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) { + closeOperationCount++ + ctxErr := ctx.Err() + assert.NoError(t, ctxErr) + return &cli_service.TCloseOperationResp{}, nil + }, + FnCancelOperation: func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) { + cancelOperationCount++ + cancelOperationResp := &cli_service.TCancelOperationResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + } + return cancelOperationResp, nil + }, + } + testConn := &conn{ + session: getTestSession(), + client: testClient, + cfg: 
config.WithDefaults(), + } + ctx := context.Background() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + res, err := testConn.ExecContext(ctx, "insert 10", []driver.NamedValue{}) + time.Sleep(10 * time.Millisecond) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, 1, executeStatementCount) + assert.Equal(t, 1, cancelOperationCount) + assert.Equal(t, 1, getOperationStatusCount) + assert.Equal(t, 1, closeOperationCount) + }) } func TestConn_QueryContext(t *testing.T) { diff --git a/connector.go b/connector.go index 74f78d8..e220998 100644 --- a/connector.go +++ b/connector.go @@ -11,21 +11,15 @@ import ( "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/internal/config" - "github.com/databricks/databricks-sql-go/internal/sentinel" "github.com/databricks/databricks-sql-go/logger" - "github.com/pkg/errors" ) type connector struct { cfg *config.Config } +// Connect returns a connection to the Databricks database from a connection pool. 
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { - - tclient, err := client.InitThriftClient(c.cfg) - if err != nil { - return nil, wrapErr(err, "error initializing thrift client") - } var catalogName *cli_service.TIdentifier var schemaName *cli_service.TIdentifier if c.cfg.Catalog != "" { @@ -35,29 +29,24 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { schemaName = cli_service.TIdentifierPtr(cli_service.TIdentifier(c.cfg.Schema)) } - // we need to ensure that open session will eventually end - sentinel := sentinel.Sentinel{ - OnDoneFn: func(statusResp any) (any, error) { - return tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ - ClientProtocol: c.cfg.ThriftProtocolVersion, - Configuration: make(map[string]string), - InitialNamespace: &cli_service.TNamespace{ - CatalogName: catalogName, - SchemaName: schemaName, - }, - CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, - }) - }, + tclient, err := client.InitThriftClient(c.cfg) + if err != nil { + return nil, wrapErr(err, "error initializing thrift client") } - // default timeout in here in addition to potential context timeout - _, res, err := sentinel.Watch(ctx, c.cfg.PollInterval, c.cfg.ConnectTimeout) + + session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ + ClientProtocol: c.cfg.ThriftProtocolVersion, + Configuration: make(map[string]string), + InitialNamespace: &cli_service.TNamespace{ + CatalogName: catalogName, + SchemaName: schemaName, + }, + CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, + }) + if err != nil { return nil, wrapErrf(err, "error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath) } - session, ok := res.(*cli_service.TOpenSessionResp) - if !ok { - return nil, errors.New("databricks: invalid open session response") - } conn := &conn{ id: client.SprintGuid(session.SessionHandle.GetSessionId().GUID), @@ -80,6 +69,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, 
error) { return conn, nil } +// Driver returns underlying databricksDriver for compatibility with sql.DB Driver method func (c *connector) Driver() driver.Driver { return &databricksDriver{} } @@ -88,7 +78,7 @@ var _ driver.Connector = (*connector)(nil) type connOption func(*config.Config) -// NewConnector creates a connection that can be used with sql.OpenDB(). +// NewConnector creates a connection that can be used with `sql.OpenDB()`. // This is an easier way to set up the DB instead of having to construct a DSN string. func NewConnector(options ...connOption) (driver.Connector, error) { // config with default options @@ -97,9 +87,8 @@ func NewConnector(options ...connOption) (driver.Connector, error) { for _, opt := range options { opt(cfg) } - // validate config? - return &connector{cfg}, nil + return &connector{cfg: cfg}, nil } // WithServerHostname sets up the server hostname. Mandatory. @@ -129,6 +118,9 @@ func WithAccessToken(token string) connOption { // WithHTTPPath sets up the endpoint to the warehouse. Mandatory. 
func WithHTTPPath(path string) connOption { return func(c *config.Config) { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } c.HTTPPath = path } } diff --git a/connector_test.go b/connector_test.go index 660612c..5e2f2e3 100644 --- a/connector_test.go +++ b/connector_test.go @@ -1,28 +1,14 @@ package dbsql import ( - "context" + "testing" + "time" + "github.com/databricks/databricks-sql-go/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" - "time" ) -func TestConnector_Connect(t *testing.T) { - t.Run("Connect returns err when thrift client initialization fails", func(t *testing.T) { - cfg := config.WithDefaults() - cfg.ThriftProtocol = "invalidprotocol" - - testConnector := connector{ - cfg: cfg, - } - conn, err := testConnector.Connect(context.Background()) - assert.Nil(t, conn) - assert.Error(t, err) - }) -} - func TestNewConnector(t *testing.T) { t.Run("Connector initialized with functional options should have all options set", func(t *testing.T) { host := "databricks-host" @@ -51,7 +37,7 @@ func TestNewConnector(t *testing.T) { Port: port, Protocol: "https", AccessToken: accessToken, - HTTPPath: httpPath, + HTTPPath: "/" + httpPath, MaxRows: maxRows, QueryTimeout: timeout, Catalog: catalog, @@ -66,4 +52,32 @@ func TestNewConnector(t *testing.T) { assert.Nil(t, err) assert.Equal(t, expectedCfg, coni.cfg) }) + t.Run("Connector initialized minimal settings", func(t *testing.T) { + host := "databricks-host" + port := 443 + accessToken := "token" + httpPath := "http-path" + maxRows := 100000 + sessionParams := map[string]string{} + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + expectedUserConfig := config.UserConfig{ + Host: host, + Port: port, + Protocol: "https", + AccessToken: accessToken, + HTTPPath: "/" + httpPath, + MaxRows: maxRows, + SessionParams: sessionParams, + } + expectedCfg := config.WithDefaults() + 
expectedCfg.UserConfig = expectedUserConfig + coni, ok := con.(*connector) + require.True(t, ok) + assert.Nil(t, err) + assert.Equal(t, expectedCfg, coni.cfg) + }) } diff --git a/doc.go b/doc.go index d052ea4..c048c29 100644 --- a/doc.go +++ b/doc.go @@ -1,4 +1,229 @@ /* Package dbsql implements the go driver to Databricks SQL + +# Usage + +Clients should use the database/sql package in conjunction with the driver: + + import ( + "database/sql" + + _ "github.com/databricks/databricks-sql-go" + ) + + func main() { + db, err := sql.Open("databricks", "token:@:/") + + if err != nil { + log.Fatal(err) + } + defer db.Close() + } + +# Connection via DSN (Data Source Name) + +Use sql.Open() to create a database handle via a data source name string: + + db, err := sql.Open("databricks", "") + +The DSN format is: + + token:[my_token]@[hostname]:[port]/[endpoint http path]?param=value + +Supported optional connection parameters can be specified in param=value and include: + + - catalog: Sets the initial catalog name in the session + - schema: Sets the initial schema name in the session + - maxRows: Sets up the max rows fetched per request. Default is 100000 + - timeout: Adds timeout (in seconds) for the server query execution. Default is no timeout + - userAgentEntry: Used to identify partners. Set as a string with format + +Supported optional session parameters can be specified in param=value and include: + + - ansi_mode: (Boolean string). Session statements will adhere to rules defined by ANSI SQL specification. + - timezone: (e.g. "America/Los_Angeles"). 
Sets the timezone of the session + +# Connection via new connector object + +Use sql.OpenDB() to create a database handle via a new connector object created with dbsql.NewConnector(): + + import ( + "database/sql" + dbsql "github.com/databricks/databricks-sql-go" + ) + + func main() { + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(), + dbsql.WithPort(), + dbsql.WithHTTPPath(), + dbsql.WithAccessToken(), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + ... + } + +Supported functional options include: + + - WithServerHostname( string): Sets up the server hostname. Mandatory + - WithPort( int): Sets up the server port. Mandatory + - WithAccessToken( string): Sets up the Personal Access Token. Mandatory + - WithHTTPPath( string): Sets up the endpoint to the warehouse. Mandatory + - WithInitialNamespace( string, string): Sets up the catalog and schema name in the session. Optional + - WithMaxRows( int): Sets up the max rows fetched per request. Default is 100000. Optional + - WithSessionParams( map[string]string): Sets up session parameters including "timezone" and "ansi_mode". Optional + - WithTimeout( Duration): Adds timeout (in time.Duration) for the server query execution. Default is no timeout. Optional + - WithUserAgentEntry( string): Used to identify partners. Optional + +# Query cancellation and timeout + +Cancelling a query via context cancellation or timeout is supported. + + // Set up context timeout + ctx, cancel := context.WithTimeout(context.Background(), 30 * time.Second) + defer cancel() + + // Execute query. Query will be cancelled after 30 seconds if still running + res, err := db.ExecContext(ctx, "CREATE TABLE example(id int, message string)") + +# CorrelationId and ConnId + +Use the driverctx package under driverctx/ctx.go to add CorrelationId and ConnId to the context. +CorrelationId and ConnId make it convenient to parse and create metrics in logging.
+ +**Connection Id** +Internal id to track what happens under a connection. Connections can be reused so this would track across queries. + +**Query Id** +Internal id to track what happens under a query. Useful because the same query can be used with multiple connections. + +**Correlation Id** +External id, such as request ID, to track what happens under a request. Useful to track multiple connections in the same request. + + ctx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "workflow-example") + +# Logging + +Use the logger package under logger.go to set up logging (from zerolog). +By default, logging level is `warn`. If you want to disable logging, use `disabled`. +The user can also utilize Track() and Duration() to custom log the elapsed time of anything tracked. + + import ( + dbsqllog "github.com/databricks/databricks-sql-go/logger" + dbsqlctx "github.com/databricks/databricks-sql-go/driverctx" + ) + + func main() { + // Optional. Set the logging level with SetLogLevel() + if err := dbsqllog.SetLogLevel("debug"); err != nil { + log.Fatal(err) + } + + // Optional. Set logging output with SetLogOutput() + // Default is os.Stderr. If running in terminal, logger will use ConsoleWriter to prettify logs + dbsqllog.SetLogOutput(os.Stdout) + + // Optional. Set correlation id with NewContextWithCorrelationId + ctx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "workflow-example") + + + // Optional. Track time spent and log elapsed time + msg, start := dbsqllog.Track("Run Main") + defer dbsqllog.Duration(msg, start) + + db, err := sql.Open("databricks", "") + ...
+ } + +The result log may look like this: + + {"level":"debug","connId":"01ed6545-5669-1ec7-8c7e-6d8a1ea0ab16","corrId":"workflow-example","queryId":"01ed6545-57cc-188a-bfc5-d9c0eaf8e189","time":1668558402,"message":"Run Main elapsed time: 1.298712292s"} + +# Supported Data Types + +================================== + +Databricks Type --> Golang Type + +================================== + +BOOLEAN --> bool + +TINYINT --> int8 + +SMALLINT --> int16 + +INT --> int32 + +BIGINT --> int64 + +FLOAT --> float32 + +DOUBLE --> float64 + +VOID --> nil + +STRING --> string + +DATE --> time.Time + +TIMESTAMP --> time.Time + +DECIMAL(p,s) --> sql.RawBytes + +BINARY --> sql.RawBytes + +ARRAY --> sql.RawBytes + +STRUCT --> sql.RawBytes + +MAP --> sql.RawBytes + +INTERVAL (year-month) --> string + +INTERVAL (day-time) --> string + +For ARRAY, STRUCT, and MAP types, sql.Scan can cast sql.RawBytes to JSON string, which can be unmarshalled to Golang +arrays, maps, and structs. For example: + + type structVal struct { + StringField string `json:"string_field"` + ArrayField []int `json:"array_field"` + } + type row struct { + arrayVal []int + mapVal map[string]int + structVal structVal + } + res := []row{} + + for rows.Next() { + r := row{} + tempArray := []byte{} + tempStruct := []byte{} + tempMap := []byte{} + if err := rows.Scan(&tempArray, &tempMap, &tempStruct); err != nil { + log.Fatal(err) + } + if err := json.Unmarshal(tempArray, &r.arrayVal); err != nil { + log.Fatal(err) + } + if err := json.Unmarshal(tempMap, &r.mapVal); err != nil { + log.Fatal(err) + } + if err := json.Unmarshal(tempStruct, &r.structVal); err != nil { + log.Fatal(err) + } + res = append(res, r) + } + +May generate the following row: + + {arrayVal:[1,2,3] mapVal:{"key1":1} structVal:{"string_field":"string_val","array_field":[4,5,6]}} */ package dbsql diff --git a/driver.go b/driver.go index 80e54ef..945a21d 100644 --- a/driver.go +++ b/driver.go @@ -15,6 +15,8 @@ func init() { type databricksDriver 
struct{} +// Open returns a new connection to Databricks database with a DSN string. +// Use sql.Open("databricks", ) after importing this driver package. func (d *databricksDriver) Open(dsn string) (driver.Conn, error) { cfg := config.WithDefaults() userCfg, err := config.ParseDSN(dsn) @@ -28,6 +30,8 @@ func (d *databricksDriver) Open(dsn string) (driver.Conn, error) { return c.Connect(context.Background()) } +// OpenConnector returns a new Connector. +// Used by sql.DB to obtain a Connector and invoke its Connect method to obtain each needed connection. func (d *databricksDriver) OpenConnector(dsn string) (driver.Connector, error) { cfg := config.WithDefaults() ucfg, err := config.ParseDSN(dsn) @@ -35,46 +39,9 @@ func (d *databricksDriver) OpenConnector(dsn string) (driver.Connector, error) { return nil, err } cfg.UserConfig = ucfg - return &connector{cfg}, nil + + return &connector{cfg: cfg}, nil } var _ driver.Driver = (*databricksDriver)(nil) var _ driver.DriverContext = (*databricksDriver)(nil) - -// type databricksDB struct { -// *sql.DB -// } - -// func OpenDB(c driver.Connector) *databricksDB { -// db := sql.OpenDB(c) -// return &databricksDB{db} -// } - -// func (db *databricksDB) QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) { -// return nil, "", nil -// } - -// func (db *databricksDB) ExecContextAsync(ctx context.Context, query string, args ...any) (result sql.Result, queryId string) { -// //go do something -// return nil, "" -// } - -// func (db *databricksDB) CancelQuery(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) GetQueryStatus(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) FetchRows(ctx context.Context, queryId string) (rows *sql.Rows, err error) { -// //go do something -// return nil, nil -// } - -// func (db *databricksDB) FetchResult(ctx 
context.Context, queryId string) (rows sql.Result, err error) { -// //go do something -// return nil, nil -// } diff --git a/examples/createdrop/main.go b/examples/createdrop/main.go new file mode 100644 index 0000000..c6835ee --- /dev/null +++ b/examples/createdrop/main.go @@ -0,0 +1,78 @@ +package main + +import ( + "context" + "database/sql" + "log" + "os" + "strconv" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + dbsqlctx "github.com/databricks/databricks-sql-go/driverctx" + dbsqllog "github.com/databricks/databricks-sql-go/logger" + "github.com/joho/godotenv" +) + +func main() { + // use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled` + if err := dbsqllog.SetLogLevel("debug"); err != nil { + log.Fatal(err) + } + // sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty + // dbsqllog.SetLogOutput(os.Stdout) + + // this is just to make it easy to load all variables + if err := godotenv.Load(); err != nil { + log.Fatal(err) + } + port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) + if err != nil { + log.Fatal(err) + } + + // programmatically initializes the connector + // another way is to use a DSN.
In this case the equivalent DSN would be: + // "token:@hostname:port/http_path?catalog=hive_metastore&schema=default&timeout=60&maxRows=10&timezone=America/Sao_Paulo&ANSI_MODE=true" + connector, err := dbsql.NewConnector( + // minimum configuration + dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")), + dbsql.WithPort(port), + dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")), + dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")), + //optional configuration + dbsql.WithSessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}), + dbsql.WithUserAgentEntry("workflow-example"), + dbsql.WithInitialNamespace("hive_metastore", "default"), + dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time. + dbsql.WithMaxRows(10), // defaults to 100000 + ) + if err != nil { + // This will not be a connection error, but a DSN parse error or + // another initialization error. + log.Fatal(err) + + } + // Opening a driver typically will not attempt to connect to the database. + db := sql.OpenDB(connector) + // make sure to close it later + defer db.Close() + + ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "createdrop-example") + + // sets the timeout to 30 seconds. Beyond that, the ping will fail. The default is 15 seconds + ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second) + defer cancel() + if err := db.PingContext(ctx1); err != nil { + log.Fatal(err) + } + + // create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection.
+ if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil { + log.Fatal(err) + } + + if _, err := db.ExecContext(ogCtx, `DROP TABLE diamonds `); err != nil { + log.Fatal(err) + } +} diff --git a/examples/queryrow/main.go b/examples/queryrow/main.go index cdca764..dfd60d3 100644 --- a/examples/queryrow/main.go +++ b/examples/queryrow/main.go @@ -40,7 +40,7 @@ func main() { // defer cancel() ctx := context.Background() var res float64 - err1 := db.QueryRowContext(ctx, `select max(carat) from default.diamonds`).Scan(res) + err1 := db.QueryRowContext(ctx, `select max(carat) from default.diamonds`).Scan(&res) if err1 != nil { if err1 == sql.ErrNoRows { diff --git a/examples/queryrows/main.go b/examples/queryrows/main.go index b49faf4..e0f28da 100644 --- a/examples/queryrows/main.go +++ b/examples/queryrows/main.go @@ -83,7 +83,7 @@ func main() { rows.Close() return } - // fmt.Printf("%v, %v\n", res1, res2) + fmt.Printf("%v, %v\n", res1, res2) } } diff --git a/examples/workflow/main.go b/examples/workflow/main.go index b43ab19..bd89294 100644 --- a/examples/workflow/main.go +++ b/examples/workflow/main.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "log" "os" "strconv" "time" @@ -17,18 +18,18 @@ import ( func main() { // use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled` if err := dbsqllog.SetLogLevel("debug"); err != nil { - panic(err) + log.Fatal(err) } // sets the logging output. By default it will use os.Stderr. 
If running in terminal, it will use ConsoleWriter to make it pretty // dbsqllog.SetLogOutput(os.Stdout) // this is just to make it easy to load all variables if err := godotenv.Load(); err != nil { - panic(err) + log.Fatal(err) } port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) if err != nil { - panic(err) + log.Fatal(err) } // programmatically initializes the connector @@ -50,7 +51,7 @@ func main() { if err != nil { // This will not be a connection error, but a DSN parse error or // another initialization error. - panic(err) + log.Fatal(err) } // Opening a driver typically will not attempt to connect to the database. @@ -88,18 +89,18 @@ func main() { ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second) defer cancel() if err := db.PingContext(ctx1); err != nil { - panic(err) + log.Fatal(err) } // create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection. if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil { - panic(err) + log.Fatal(err) } // QueryRowContext is a shortcut function to get a single value var max float64 if err := db.QueryRowContext(ogCtx, `select max(carat) from diamonds`).Scan(&max); err != nil { - panic(err) + log.Fatal(err) } else { fmt.Printf("max carat in dataset is: %f\n", max) } @@ -109,7 +110,7 @@ func main() { defer cancel() if rows, err := db.QueryContext(ctx2, "select * from diamonds limit 19"); err != nil { - panic(err) + log.Fatal(err) } else { type row struct { _c0 int @@ -127,11 +128,11 @@ func main() { cols, err := rows.Columns() if err != nil { - panic(err) + log.Fatal(err) } types, err := rows.ColumnTypes() if err != nil { - panic(err) + log.Fatal(err) } for i, c := range cols { fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName()) @@ -141,7 +142,7 @@ func main() { 
// After row 10 this will cause one fetch call, as 10 rows (maxRows config) will come from the first execute statement call. r := row{} if err := rows.Scan(&r._c0, &r.carat, &r.cut, &r.color, &r.clarity, &r.depth, &r.table, &r.price, &r.x, &r.y, &r.z); err != nil { - panic(err) + log.Fatal(err) } res = append(res, r) } @@ -156,7 +157,7 @@ func main() { var curTimezone string if err := db.QueryRowContext(ogCtx, `select current_date(), current_timestamp(), current_timezone()`).Scan(&curDate, &curTimestamp, &curTimezone); err != nil { - panic(err) + log.Fatal(err) } else { // this will print now at timezone America/Sao_Paulo is: 2022-11-16 20:25:15.282 -0300 -03 fmt.Printf("current timestamp at timezone %s is: %s\n", curTimezone, curTimestamp) @@ -170,11 +171,11 @@ func main() { array_col array < int >, map_col map < string, int >, struct_col struct < string_field string, array_field array < int > >)`); err != nil { - panic(err) + log.Fatal(err) } var numRows int if err := db.QueryRowContext(ogCtx, `select count(*) from array_map_struct`).Scan(&numRows); err != nil { - panic(err) + log.Fatal(err) } else { fmt.Printf("table has %d rows\n", numRows) } @@ -186,18 +187,18 @@ func main() { array(1, 2, 3), map('key1', 1), struct('string_val', array(4, 5, 6)))`); err != nil { - panic(err) + log.Fatal(err) } else { i, err1 := res.RowsAffected() if err1 != nil { - panic(err1) + log.Fatal(err1) } fmt.Printf("inserted %d rows", i) } } if rows, err := db.QueryContext(ogCtx, "select * from array_map_struct"); err != nil { - panic(err) + log.Fatal(err) } else { // complex data types are returned as string type row struct { @@ -208,11 +209,11 @@ func main() { res := []row{} cols, err := rows.Columns() if err != nil { - panic(err) + log.Fatal(err) } types, err := rows.ColumnTypes() if err != nil { - panic(err) + log.Fatal(err) } for i, c := range cols { fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName()) @@ -221,7 +222,7 @@ func main() { for rows.Next() 
{ r := row{} if err := rows.Scan(&r.arrayVal, &r.mapVal, &r.structVal); err != nil { - panic(err) + log.Fatal(err) } res = append(res, r) } diff --git a/go.mod b/go.mod index d0193de..95e0e28 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/fatih/color v1.13.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/hashicorp/go-cleanhttp v0.5.1 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -28,6 +29,7 @@ require ( ) require ( + github.com/hashicorp/go-retryablehttp v0.7.1 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.28.0 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect diff --git a/go.sum b/go.sum index 7266908..e190418 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,12 @@ github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= +github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ= +github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod 
h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= @@ -41,6 +47,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= diff --git a/internal/client/client.go b/internal/client/client.go index 6a4bfba..840b3f6 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "os" "github.com/apache/thrift/lib/go/thrift" @@ -13,18 +12,20 @@ import ( "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/config" "github.com/databricks/databricks-sql-go/logger" + "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" ) -// this is used to generate test data. Developer should change this manually +// RecordResults is used to generate test data. 
Developer should change this manually var RecordResults bool var resultIndex int type ThriftServiceClient struct { *cli_service.TCLIServiceClient - transport *Transport } +// OpenSession is a wrapper around the thrift operation OpenSession +// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession.json func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) { msg, start := logger.Track("OpenSession") resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req) @@ -41,6 +42,8 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic return resp, CheckStatus(resp) } +// CloseSession is a wrapper around the thrift operation CloseSession +// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession.json func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "") defer log.Duration(logger.Track("CloseSession")) @@ -56,6 +59,8 @@ func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_servi return resp, CheckStatus(resp) } +// FetchResults is a wrapper around the thrift operation FetchResults +// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults.json func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) defer log.Duration(logger.Track("FetchResults")) @@ -71,6 +76,8 @@ func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_servi return resp, CheckStatus(resp) } +// 
GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata +// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata.json func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) defer log.Duration(logger.Track("GetResultSetMetadata")) @@ -86,9 +93,11 @@ func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *c return resp, CheckStatus(resp) } +// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement +// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement.json func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) { msg, start := logger.Track("ExecuteStatement") - resp, err := tsc.TCLIServiceClient.ExecuteStatement(ctx, req) + resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req) if err != nil { return resp, errors.Wrap(err, "execute statement request error") } @@ -107,6 +116,8 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s return resp, CheckStatus(resp) } +// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus +// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus.json func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) defer 
log.Duration(logger.Track("GetOperationStatus")) @@ -122,6 +133,8 @@ func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli return resp, CheckStatus(resp) } +// CloseOperation is a wrapper around the thrift operation CloseOperation +// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation.json func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) defer log.Duration(logger.Track("CloseOperation")) @@ -137,6 +150,8 @@ func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_ser return resp, CheckStatus(resp) } +// CancelOperation is a wrapper around the thrift operation CancelOperation +// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation.json func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) { log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) defer log.Duration(logger.Track("CancelOperation")) @@ -152,25 +167,8 @@ func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_se return resp, CheckStatus(resp) } -// log.Debug().Msg(fmt.Sprint(c.transport.response.StatusCode)) -// log.Debug().Msg(c.transport.response.Header.Get("X-Databricks-Org-Id")) -// log.Debug().Msg(c.transport.response.Header.Get("x-databricks-error-or-redirect-message")) -// log.Debug().Msg(c.transport.response.Header.Get("x-thriftserver-error-message")) -// log.Debug().Msg(c.transport.response.Header.Get("x-databricks-reason-phrase")) - -// This is a wrapper of the http transport so we can have 
access to response code and headers +// InitThriftClient is a wrapper of the http transport, so we can have access to response code and headers. // It is important to know the code and headers to know if we need to retry or not -type Transport struct { - *http.Transport - response *http.Response -} - -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := t.Transport.RoundTrip(req) - t.response = resp - return resp, err -} - func InitThriftClient(cfg *config.Config) (*ThriftServiceClient, error) { endpoint := cfg.ToEndpointURL() tcfg := &thrift.TConfiguration{ @@ -197,31 +195,24 @@ func InitThriftClient(cfg *config.Config) (*ThriftServiceClient, error) { } var tTrans thrift.TTransport - var tr *Transport var err error switch cfg.ThriftTransport { case "http": - tr = &Transport{ - Transport: &http.Transport{ - TLSClientConfig: cfg.TLSConfig, - }, - } - httpclient := &http.Client{ - Transport: tr, - Timeout: cfg.ClientTimeout, - } - tTrans, err = thrift.NewTHttpClientWithOptions(endpoint, thrift.THttpClientOptions{Client: httpclient}) + retryableClient := retryablehttp.NewClient() + retryableClient.HTTPClient.Timeout = cfg.ClientTimeout + // TODO + // add custom retryableClient.CheckRetry to retry based on thrift server headers and response code + tTrans, err = thrift.NewTHttpClientWithOptions(endpoint, thrift.THttpClientOptions{Client: retryableClient.HTTPClient}) if err != nil { return nil, err } - - httpTransport := tTrans.(*thrift.THttpClient) + thriftHttpClient := tTrans.(*thrift.THttpClient) userAgent := fmt.Sprintf("%s/%s", cfg.DriverName, cfg.DriverVersion) if cfg.UserAgentEntry != "" { userAgent = fmt.Sprintf("%s/%s (%s)", cfg.DriverName, cfg.DriverVersion, cfg.UserAgentEntry) } - httpTransport.SetHeader("User-Agent", userAgent) + thriftHttpClient.SetHeader("User-Agent", userAgent) case "framed": tTrans = thrift.NewTFramedTransportConf(tTrans, tcfg) @@ -241,15 +232,17 @@ func InitThriftClient(cfg *config.Config) 
(*ThriftServiceClient, error) { iprot := protocolFactory.GetProtocol(tTrans) oprot := protocolFactory.GetProtocol(tTrans) tclient := cli_service.NewTCLIServiceClient(thrift.NewTStandardClient(iprot, oprot)) - tsClient := &ThriftServiceClient{tclient, tr} + tsClient := &ThriftServiceClient{tclient} return tsClient, nil } -// ThriftResponse respresents thrift rpc response +// ThriftResponse represents the thrift rpc response type ThriftResponse interface { GetStatus() *cli_service.TStatus } +// CheckStatus checks the status code after a thrift operation. +// Returns nil if the operation is successful or still executing, otherwise returns an error. func CheckStatus(resp interface{}) error { rpcresp, ok := resp.(ThriftResponse) if ok { @@ -268,6 +261,7 @@ func CheckStatus(resp interface{}) error { return errors.New("thrift: invalid response") } +// SprintGuid is a convenience function to format a byte array into GUID. func SprintGuid(bts []byte) string { if len(bts) == 16 { return fmt.Sprintf("%x-%x-%x-%x-%x", bts[0:4], bts[4:6], bts[6:8], bts[8:10], bts[10:16]) diff --git a/internal/config/config.go b/internal/config/config.go index 3bdf018..ff26bf6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,7 +13,7 @@ import ( "github.com/pkg/errors" ) -// Driver Configurations +// Driver Configurations. 
// Only UserConfig are currently exposed to users type Config struct { UserConfig @@ -34,6 +34,7 @@ type Config struct { ThriftDebugClientProtocol bool } +// ToEndpointURL generates the endpoint URL from Config that a Thrift client will connect to func (c *Config) ToEndpointURL() string { var userInfo string if c.AccessToken != "" { @@ -43,6 +44,7 @@ func (c *Config) ToEndpointURL() string { return endpointUrl } +// DeepCopy returns a true deep copy of Config func (c *Config) DeepCopy() *Config { if c == nil { return nil @@ -84,6 +86,7 @@ type UserConfig struct { SessionParams map[string]string } +// DeepCopy returns a true deep copy of UserConfig func (ucfg UserConfig) DeepCopy() UserConfig { var sessionParams map[string]string if ucfg.SessionParams != nil { @@ -118,17 +121,25 @@ func (ucfg UserConfig) DeepCopy() UserConfig { } } +var defaultMaxRows = 100000 + +// WithDefaults provides default settings for optional fields in UserConfig func (ucfg UserConfig) WithDefaults() UserConfig { if ucfg.MaxRows <= 0 { - ucfg.MaxRows = 10000 + ucfg.MaxRows = defaultMaxRows } if ucfg.Protocol == "" { ucfg.Protocol = "https" + ucfg.Port = 443 + } + if ucfg.Port == 0 { + ucfg.Port = 443 } ucfg.SessionParams = make(map[string]string) return ucfg } +// WithDefaults provides default settings for Config func WithDefaults() *Config { return &Config{ UserConfig: UserConfig{}.WithDefaults(), @@ -140,7 +151,7 @@ func WithDefaults() *Config { ClientTimeout: 900 * time.Second, PingTimeout: 15 * time.Second, CanUseMultipleCatalogs: true, - DriverName: "godatabrickssqlconnector", //important. Do not change + DriverName: "godatabrickssqlconnector", // important. 
Do not change DriverVersion: "0.9.0", ThriftProtocol: "binary", ThriftTransport: "http", @@ -150,6 +161,7 @@ func WithDefaults() *Config { } +// ParseDSN constructs UserConfig by parsing DSN string supplied to `sql.Open()` func ParseDSN(dsn string) (UserConfig, error) { fullDSN := dsn if !strings.HasPrefix(dsn, "https://") && !strings.HasPrefix(dsn, "http://") { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7ef42c8..b9e2f87 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -28,7 +28,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 443, - MaxRows: 10000, + MaxRows: defaultMaxRows, AccessToken: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", SessionParams: make(map[string]string), @@ -43,7 +43,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 443, - MaxRows: 10000, + MaxRows: defaultMaxRows, AccessToken: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", SessionParams: make(map[string]string), @@ -58,7 +58,7 @@ func TestParseConfig(t *testing.T) { Protocol: "http", Host: "localhost", Port: 8080, - MaxRows: 10000, + MaxRows: defaultMaxRows, HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", SessionParams: make(map[string]string), }, @@ -72,7 +72,7 @@ func TestParseConfig(t *testing.T) { Protocol: "http", Host: "localhost", Port: 8080, - MaxRows: 10000, + MaxRows: defaultMaxRows, SessionParams: make(map[string]string), }, wantErr: false, @@ -118,7 +118,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 8000, - MaxRows: 10000, + MaxRows: defaultMaxRows, HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", SessionParams: make(map[string]string), }, @@ -132,7 +132,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 8000, - MaxRows: 10000, + MaxRows: defaultMaxRows, 
AccessToken: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", Catalog: "default", @@ -148,7 +148,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 8000, - MaxRows: 10000, + MaxRows: defaultMaxRows, AccessToken: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", UserAgentEntry: "partner-name", @@ -164,7 +164,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 8000, - MaxRows: 10000, + MaxRows: defaultMaxRows, AccessToken: "supersecret2", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", Schema: "system", @@ -199,7 +199,7 @@ func TestParseConfig(t *testing.T) { Protocol: "https", Host: "example.cloud.databricks.com", Port: 443, - MaxRows: 10000, + MaxRows: defaultMaxRows, AccessToken: "supersecret", SessionParams: make(map[string]string), }, diff --git a/internal/sentinel/sentinel.go b/internal/sentinel/sentinel.go index 726d2fb..45d8e11 100644 --- a/internal/sentinel/sentinel.go +++ b/internal/sentinel/sentinel.go @@ -10,7 +10,7 @@ import ( ) const ( - DEFAULT_TIMEOUT = 0 + DEFAULT_TIMEOUT = 0 //no timeout DEFAULT_INTERVAL = 100 * time.Millisecond ) diff --git a/internal/sentinel/sentinel_test.go b/internal/sentinel/sentinel_test.go index 20b5ce1..e2cd459 100644 --- a/internal/sentinel/sentinel_test.go +++ b/internal/sentinel/sentinel_test.go @@ -144,7 +144,7 @@ func TestWatch(t *testing.T) { assert.Nil(t, res) assert.Error(t, err) }) - t.Run("it should call cancelFn upon cancelation", func(t *testing.T) { + t.Run("it should call cancelFn upon cancellation while polling", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(100 * time.Millisecond) diff --git a/logger/logger.go b/logger/logger.go index 1ba3402..e048eaa 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -14,10 +14,24 @@ type DBSQLLogger struct { zerolog.Logger } +// Track is a simple utility function to use with 
logger to log a message with a timestamp. +// Recommended to use in conjunction with Duration. +// +// For example: +// +// msg, start := log.Track("Run operation") +// defer log.Duration(msg, start) func (l *DBSQLLogger) Track(msg string) (string, time.Time) { return msg, time.Now() } +// Duration logs a debug message with the time elapsed between the provided start and the current time. +// Use in conjunction with Track. +// +// For example: +// +// msg, start := log.Track("Run operation") +// defer log.Duration(msg, start) func (l *DBSQLLogger) Duration(msg string, start time.Time) { l.Debug().Msgf("%v elapsed time: %v", msg, time.Since(start)) } @@ -26,7 +40,7 @@ var Logger = &DBSQLLogger{ zerolog.New(os.Stderr).With().Timestamp().Logger(), } -// enable pretty printing for interactive terminals and json for production. +// Enable pretty printing for interactive terminals and json for production. func init() { // for tty terminal enable pretty logs if isatty.IsTerminal(os.Stdout.Fd()) && runtime.GOOS != "windows" { @@ -114,7 +128,7 @@ func Err(err error) *zerolog.Event { return Logger.Err(err) } -// WithContext sets connectionId, correlationId, and queryID to be used as fields. +// WithContext sets connectionId, correlationId, and queryId to be used as fields. func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger { return &DBSQLLogger{Logger.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} } diff --git a/result.go b/result.go index ab12e81..bb8bfa5 100644 --- a/result.go +++ b/result.go @@ -9,10 +9,13 @@ type result struct { var _ driver.Result = (*result)(nil) +// LastInsertId returns the database's auto-generated ID after an insert into a table. +// This is currently not really implemented for this driver and will always return 0. func (res *result) LastInsertId() (int64, error) { return res.InsertId, nil } +// RowsAffected returns the number of rows affected by the query. 
func (res *result) RowsAffected() (int64, error) { return res.AffectedRows, nil } diff --git a/rows.go b/rows.go index 5c9da4c..a71ab08 100644 --- a/rows.go +++ b/rows.go @@ -28,6 +28,7 @@ type rows struct { fetchResultsMetadata *cli_service.TGetResultSetMetadataResp nextRowIndex int64 nextRowNumber int64 + closed bool } var _ driver.Rows = (*rows)(nil) @@ -36,10 +37,34 @@ var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) var _ driver.RowsColumnTypeNullable = (*rows)(nil) var _ driver.RowsColumnTypeLength = (*rows)(nil) -var errRowsFetchPriorToStart = "unable to fetch row page prior to start of results" -var errRowsNoSchemaAvailable = "no schema in result set metadata response" -var errRowsNoClient = "instance of Rows missing client" -var errRowsNilRows = "nil Rows instance" +var errRowsFetchPriorToStart = "databricks: unable to fetch row page prior to start of results" +var errRowsNoSchemaAvailable = "databricks: no schema in result set metadata response" +var errRowsNoClient = "databricks: instance of Rows missing client" +var errRowsNilRows = "databricks: nil Rows instance" +var errRowsParseValue = "databricks: unable to parse %s value '%s' from column %s" + +// NewRows generates a new rows object given the rows' fields. +// NewRows will also parse directResults if it is available for some rows' fields. +func NewRows(connID string, corrId string, client cli_service.TCLIService, opHandle *cli_service.TOperationHandle, pageSize int64, location *time.Location, directResults *cli_service.TSparkDirectResults) driver.Rows { + r := &rows{ + connId: connID, + correlationId: corrId, + client: client, + opHandle: opHandle, + pageSize: pageSize, + location: location, + } + + if directResults != nil { + r.fetchResults = directResults.ResultSet + r.fetchResultsMetadata = directResults.ResultSetMetadata + if directResults.CloseOperation != nil { + r.closed = true + } + } + + return r +} // Columns returns the names of the columns. 
The number of // columns of the result is inferred from the length of the @@ -72,20 +97,23 @@ func (r *rows) Columns() []string { // Close closes the rows iterator. func (r *rows) Close() error { - err := isValidRows(r) - if err != nil { - return err - } + if !r.closed { + err := isValidRows(r) + if err != nil { + return err + } - req := cli_service.TCloseOperationReq{ - OperationHandle: r.opHandle, - } - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) + req := cli_service.TCloseOperationReq{ + OperationHandle: r.opHandle, + } + ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) - _, err1 := r.client.CloseOperation(ctx, &req) - if err1 != nil { - return err1 + _, err1 := r.client.CloseOperation(ctx, &req) + if err1 != nil { + return err1 + } } + return nil } @@ -119,7 +147,7 @@ func (r *rows) Next(dest []driver.Value) error { return err } - // populate the destinatino slice + // populate the destination slice for i := range dest { val, err := value(r.fetchResults.Results.Columns[i], metadata.Schema.Columns[i], r.nextRowIndex, r.location) @@ -174,10 +202,9 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { } // ColumnTypeNullable returns a flag indicating whether the column is nullable -// and an ok value of true if the status of the column is known. Otherwise +// and an ok value of true if the status of the column is known. Otherwise // a value of false is returned for ok. 
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { - // TODO: Update if we can figure out this information return false, false } @@ -188,8 +215,6 @@ func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) { } typeName := getDBTypeID(columnInfo) - // TODO: figure out how to get better metadata about complex types - // currently map, array, and struct are returned as strings switch typeName { case cli_service.TTypeId_STRING_TYPE, cli_service.TTypeId_VARCHAR_TYPE, @@ -215,12 +240,11 @@ var ( scanTypeString = reflect.TypeOf("") scanTypeDateTime = reflect.TypeOf(time.Time{}) scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) - scanTypeUnknown = reflect.TypeOf(new(interface{})) + scanTypeUnknown = reflect.TypeOf(new(any)) ) func getScanType(column *cli_service.TColumnDesc) reflect.Type { - // TODO: handle non-primitive types entry := column.TypeDesc.Types[0].PrimitiveEntry switch entry.Type { @@ -261,7 +285,6 @@ func getScanType(column *cli_service.TColumnDesc) reflect.Type { } func getDBTypeName(column *cli_service.TColumnDesc) string { - // TODO: handle non-primitive types entry := column.TypeDesc.Types[0].PrimitiveEntry dbtype := strings.TrimSuffix(entry.Type.String(), "_TYPE") @@ -269,7 +292,6 @@ func getDBTypeName(column *cli_service.TColumnDesc) string { } func getDBTypeID(column *cli_service.TColumnDesc) cli_service.TTypeId { - // TODO: handle non-primitive types entry := column.TypeDesc.Types[0].PrimitiveEntry return entry.Type } @@ -308,7 +330,6 @@ func (r *rows) getColumnMetadataByIndex(index int) (*cli_service.TColumnDesc, er return nil, errors.Errorf("invalid column index: %d", index) } - // tColumns := resultMetadata.Schema.GetColumns() return columns[index], nil } @@ -366,7 +387,7 @@ func (r *rows) fetchResultPage() error { for !r.isNextRowInPage() { - // determine the direction of page fetching. Currently we only handle + // determine the direction of page fetching. 
Currently we only handle // TFetchOrientation_FETCH_PRIOR and TFetchOrientation_FETCH_NEXT var direction cli_service.TFetchOrientation = r.getPageFetchDirection() if direction == cli_service.TFetchOrientation_FETCH_PRIOR { @@ -429,13 +450,12 @@ func (r *rows) getPageStartRowNum() int64 { return r.fetchResults.GetResults().GetStartRowOffset() } -const ( - // TimestampFormat is JDBC compliant timestamp format - TimestampFormat = "2006-01-02 15:04:05.999999999" - DateFormat = "2006-01-02" -) +var dateTimeFormats map[string]string = map[string]string{ + "TIMESTAMP": "2006-01-02 15:04:05.999999999", + "DATE": "2006-01-02", +} -func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, rowNum int64, location *time.Location) (val interface{}, err error) { +func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, rowNum int64, location *time.Location) (val any, err error) { if location == nil { location = time.UTC } @@ -444,17 +464,7 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r dbtype := strings.TrimSuffix(entry.Type.String(), "_TYPE") if tVal := tColumn.GetStringVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) { val = tVal.Values[rowNum] - if dbtype == "TIMESTAMP" { - t, err := time.ParseInLocation(TimestampFormat, val.(string), location) - if err == nil { - val = t - } - } else if dbtype == "DATE" { - t, err := time.ParseInLocation(DateFormat, val.(string), location) - if err == nil { - val = t - } - } + val, err = handleDateTime(val, dbtype, tColumnDesc.ColumnName, location) } else if tVal := tColumn.GetByteVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) { val = tVal.Values[rowNum] } else if tVal := tColumn.GetI16Val(); tVal != nil && !isNull(tVal.Nulls, rowNum) { @@ -466,7 +476,15 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r } else if tVal := tColumn.GetBoolVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) { val = tVal.Values[rowNum] } else if tVal 
:= tColumn.GetDoubleVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) { - val = tVal.Values[rowNum] + if dbtype == "FLOAT" { + // database types FLOAT and DOUBLE are both returned as a float64 + // convert to a float32 is valid because the FLOAT type would have + // only been four bytes on the server + val = float32(tVal.Values[rowNum]) + } else { + val = tVal.Values[rowNum] + } + } else if tVal := tColumn.GetBinaryVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) { val = tVal.Values[rowNum] } @@ -474,6 +492,21 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r return val, err } +// handleDateTime will convert the passed val to a time.Time value if necessary +func handleDateTime(val any, dbType, columnName string, location *time.Location) (any, error) { + // if there is a date/time format corresponding to the column type we need to + // convert to time.Time + if format, ok := dateTimeFormats[dbType]; ok { + t, err := parseInLocation(format, val.(string), location) + if err != nil { + err = wrapErrf(err, errRowsParseValue, dbType, val, columnName) + } + return t, err + } + + return val, nil +} + func isNull(nulls []byte, position int64) bool { index := position / 8 if int64(len(nulls)) > index { @@ -515,3 +548,57 @@ func getNRows(rs *cli_service.TRowSet) int64 { } return 0 } + +// parseInLocation parses a date/time string in the given format and using the provided +// location. 
+// This is, essentially, a wrapper around time.ParseInLocation to handle negative year +// values +func parseInLocation(format, dateTimeString string, loc *time.Location) (time.Time, error) { + // we want to handle dates with negative year values and currently we only + // support formats that start with the year so we can just strip a leading minus + // sign + var isNegative bool + dateTimeString, isNegative = stripLeadingNegative(dateTimeString) + + date, err := time.ParseInLocation(format, dateTimeString, loc) + if err != nil { + return time.Time{}, err + } + + if isNegative { + date = date.AddDate(-2*date.Year(), 0, 0) + } + + return date, nil +} + +// stripLeadingNegative will remove a leading ascii or unicode minus +// if present. The possibly shortened string is returned and a flag indicating if +// the string was altered +func stripLeadingNegative(dateTimeString string) (string, bool) { + if dateStartsWithNegative(dateTimeString) { + // strip leading rune from dateTimeString + // using range because it is supposed to be faster than utf8.DecodeRuneInString + for i := range dateTimeString { + if i > 0 { + return dateTimeString[i:], true + } + } + } + + return dateTimeString, false +} + +// ISO 8601 allows for both the ascii and unicode characters for minus +const ( + // unicode minus sign + uMinus string = "\u2212" + // ascii hyphen/minus + aMinus string = "\x2D" +) + +// dateStartsWithNegative returns true if the string starts with +// a minus sign +func dateStartsWithNegative(val string) bool { + return strings.HasPrefix(val, aMinus) || strings.HasPrefix(val, uMinus) +} diff --git a/rows_test.go b/rows_test.go index b00e4dd..0c1ffd5 100644 --- a/rows_test.go +++ b/rows_test.go @@ -4,9 +4,11 @@ import ( "context" "database/sql/driver" "errors" + "fmt" "io" "math" "reflect" + "strings" "testing" "time" @@ -537,15 +539,15 @@ func TestNextNoDirectResults(t *testing.T) { row := make([]driver.Value, len(colNames)) err = rowSet.Next(row) - timestamp, _ := 
time.Parse(TimestampFormat, "2021-07-01 05:43:28") - date, _ := time.Parse(DateFormat, "2021-07-01") + timestamp, _ := time.Parse(dateTimeFormats["TIMESTAMP"], "2021-07-01 05:43:28") + date, _ := time.Parse(dateTimeFormats["DATE"], "2021-07-01") row0 := []driver.Value{ true, driver.Value(nil), int16(0), int32(0), int64(0), - float64(0), + float32(0), float64(0), "s0", timestamp, @@ -592,15 +594,15 @@ func TestNextWithDirectResults(t *testing.T) { err := rowSet.Next(row) - timestamp, _ := time.Parse(TimestampFormat, "2021-07-01 05:43:28") - date, _ := time.Parse(DateFormat, "2021-07-01") + timestamp, _ := time.Parse(dateTimeFormats["TIMESTAMP"], "2021-07-01 05:43:28") + date, _ := time.Parse(dateTimeFormats["DATE"], "2021-07-01") row0 := []driver.Value{ true, driver.Value(nil), int16(0), int32(0), int64(0), - float64(0), + float32(0), float64(0), "s0", timestamp, @@ -621,6 +623,63 @@ func TestNextWithDirectResults(t *testing.T) { assert.Equal(t, 1, fetchResultsCount) } +func TestHandlingDateTime(t *testing.T) { + t.Run("should do nothing if data is not a date/time", func(t *testing.T) { + val, err := handleDateTime("this is not a date", "STRING", "string_col", time.UTC) + assert.Nil(t, err, "handleDateTime should do nothing if a column is not a date/time") + assert.Equal(t, "this is not a date", val) + }) + + t.Run("should error on invalid date/time value", func(t *testing.T) { + _, err := handleDateTime("this is not a date", "DATE", "date_col", time.UTC) + assert.NotNil(t, err) + assert.True(t, strings.HasPrefix(err.Error(), fmt.Sprintf(errRowsParseValue, "DATE", "this is not a date", "date_col"))) + }) + + t.Run("should parse valid date", func(t *testing.T) { + dt, err := handleDateTime("2006-12-22", "DATE", "date_col", time.UTC) + assert.Nil(t, err) + assert.Equal(t, time.Date(2006, 12, 22, 0, 0, 0, 0, time.UTC), dt) + }) + + t.Run("should parse valid timestamp", func(t *testing.T) { + dt, err := handleDateTime("2006-12-22 17:13:11.000001000", "TIMESTAMP", 
"timestamp_col", time.UTC) + assert.Nil(t, err) + assert.Equal(t, time.Date(2006, 12, 22, 17, 13, 11, 1000, time.UTC), dt) + }) + + t.Run("should parse date with negative year", func(t *testing.T) { + expectedTime := time.Date(-2006, 12, 22, 0, 0, 0, 0, time.UTC) + dateStrings := []string{ + "-2006-12-22", + "\u22122006-12-22", + "\x2D2006-12-22", + } + + for _, s := range dateStrings { + dt, err := handleDateTime(s, "DATE", "date_col", time.UTC) + assert.Nil(t, err) + assert.Equal(t, expectedTime, dt) + } + }) + + t.Run("should parse timestamp with negative year", func(t *testing.T) { + expectedTime := time.Date(-2006, 12, 22, 17, 13, 11, 1000, time.UTC) + + timestampStrings := []string{ + "-2006-12-22 17:13:11.000001000", + "\u22122006-12-22 17:13:11.000001000", + "\x2D2006-12-22 17:13:11.000001000", + } + + for _, s := range timestampStrings { + dt, err := handleDateTime(s, "TIMESTAMP", "timestamp_col", time.UTC) + assert.Nil(t, err) + assert.Equal(t, expectedTime, dt) + } + }) +} + func TestGetScanType(t *testing.T) { var getMetadataCount, fetchResultsCount int @@ -799,6 +858,40 @@ func TestColumnTypeDatabaseTypeName(t *testing.T) { assert.Equal(t, expectedScanTypes, scanTypes) } +func TestRowsCloseOptimization(t *testing.T) { + t.Parallel() + + var closeCount int + client := &client.TestClient{ + FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) { + closeCount++ + return nil, nil + }, + } + + rowSet := NewRows("", "", client, &cli_service.TOperationHandle{}, 1, nil, nil) + + // rowSet has no direct results calling Close should result in call to client to close operation + err := rowSet.Close() + assert.Nil(t, err, "rows.Close should not throw an error") + assert.Equal(t, 1, closeCount) + + // rowSet has direct results, but operation was not closed so it should call client to close operation + closeCount = 0 + rowSet = NewRows("", "", client, &cli_service.TOperationHandle{}, 1, 
nil, &cli_service.TSparkDirectResults{}) + err = rowSet.Close() + assert.Nil(t, err, "rows.Close should not throw an error") + assert.Equal(t, 1, closeCount) + + // rowSet has direct results which include a close operation response. rowSet should be marked as closed + // and calling Close should not call into the client. + closeCount = 0 + rowSet = NewRows("", "", client, &cli_service.TOperationHandle{}, 1, nil, &cli_service.TSparkDirectResults{CloseOperation: &cli_service.TCloseOperationResp{}}) + err = rowSet.Close() + assert.Nil(t, err, "rows.Close should not throw an error") + assert.Equal(t, 0, closeCount) +} + type rowTestPagingResult struct { getMetadataCount int fetchResultsCount int diff --git a/statement.go b/statement.go index 940649a..b06f496 100644 --- a/statement.go +++ b/statement.go @@ -17,15 +17,20 @@ func (s *stmt) Close() error { return nil } +// NumInput returns -1 and the sql package will not sanity check Exec or Query argument counts. func (s *stmt) NumInput() int { return -1 } +// Exec is not implemented. +// // Deprecated: Use StmtExecContext instead. func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { return nil, errors.New(ErrNotImplemented) } +// Query is not implemented. +// // Deprecated: Use StmtQueryContext instead. func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { return nil, errors.New(ErrNotImplemented) diff --git a/testserver.go b/testserver.go index e2f0d5f..09b300b 100644 --- a/testserver.go +++ b/testserver.go @@ -20,7 +20,6 @@ func (h *thriftHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func initThriftTestServer(handler cli_service.TCLIService) *httptest.Server { - // endpoint := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) tcfg := &thrift.TConfiguration{ TLSConfig: nil, }