From 3666e67a428485b18726dd2cf4e42ce86b1a6288 Mon Sep 17 00:00:00 2001 From: ankita Date: Wed, 27 Jul 2022 11:00:18 -0700 Subject: [PATCH] [Feature] return error with stack --- connection.go | 14 +++++++------- driver.go | 14 +++++++------- go.mod | 5 ++++- go.sum | 2 ++ hive/client.go | 4 ++-- hive/errors.go | 19 +++++++++++++++++++ hive/hive.go | 6 +++--- hive/operation.go | 14 +++++++------- hive/result_set.go | 6 +++--- hive/session.go | 12 ++++++------ statement.go | 18 +++++++++--------- 11 files changed, 69 insertions(+), 45 deletions(-) create mode 100644 hive/errors.go diff --git a/connection.go b/connection.go index 12274c1..9874b95 100644 --- a/connection.go +++ b/connection.go @@ -22,11 +22,11 @@ type Conn struct { func (c *Conn) Ping(ctx context.Context) error { session, err := c.OpenSession(ctx) if err != nil { - return err + return hive.WithStack(err) } if err := session.Ping(ctx); err != nil { - return err + return hive.WithStack(err) } return nil @@ -61,13 +61,13 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) { session, err := c.OpenSession(ctx) if err != nil { - return nil, err + return nil, hive.WithStack(err) } tmpl := template(q) stmt, err := statement(tmpl, args) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return query(ctx, session, stmt) } @@ -76,13 +76,13 @@ func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedVa func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) { session, err := c.OpenSession(ctx) if err != nil { - return nil, err + return nil, hive.WithStack(err) } tmpl := template(q) stmt, err := statement(tmpl, args) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return exec(ctx, session, stmt) } @@ -109,7 +109,7 @@ func (c *Conn) OpenSession(ctx context.Context) (*hive.Session, error) { func (c *Conn) ResetSession(ctx context.Context) error { if c.session != nil { if err := c.session.Close(ctx); err != nil { - return err + return hive.WithStack(err) } c.session = nil } diff --git a/driver.go b/driver.go index 5732fe1..e597e90 100644 --- a/driver.go +++ b/driver.go @@ -29,7 +29,7 @@ type Driver struct{} func (d *Driver) Open(uri string) (driver.Conn, error) { opts, err := parseURI(uri) if err != nil { - return nil, err + return nil, hive.WithStack(err) } // (eric) Don't log opts because it contains sensitive information. @@ -37,7 +37,7 @@ func (d *Driver) Open(uri string) (driver.Conn, error) { conn, err := connect(opts) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return conn, nil } @@ -45,7 +45,7 @@ func (d *Driver) Open(uri string) (driver.Conn, error) { func parseURI(uri string) (*Options, error) { u, err := url.Parse(uri) if err != nil { - return nil, err + return nil, hive.WithStack(err) } if u.Scheme != "databricks" { @@ -69,7 +69,7 @@ func parseURI(uri string) (*Options, error) { host, port, err := net.SplitHostPort(u.Host) if err != nil { - return nil, err + return nil, hive.WithStack(err) } opts.Host = host @@ -123,7 +123,7 @@ func (d *Driver) OpenConnector(name string) (driver.Connector, error) { opts, err := parseURI(name) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return &connector{opts: opts}, nil @@ -161,14 +161,14 @@ func connect(opts *Options) (*Conn, error) { } if err != nil { - return nil, err + return nil, hive.WithStack(err) } httpOptions := thrift.THttpClientOptions{Client: httpClient} endpointUrl := fmt.Sprintf("https://%s:%s@%s:%s"+opts.HTTPPath, "token", url.QueryEscape(opts.Token), opts.Host, opts.Port) transport, err = thrift.NewTHttpClientTransportFactoryWithOptions(endpointUrl, httpOptions).GetTransport(socket) if err != nil { - return nil, err + return nil, hive.WithStack(err) } httpTransport, ok := transport.(*thrift.THttpClient) diff --git a/go.mod b/go.mod index 1e12d03..9e463bf 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/databricks/databricks-sql-go go 1.18 -require github.com/apache/thrift v0.12.0 +require ( + github.com/apache/thrift v0.12.0 + github.com/pkg/errors v0.9.1 +) diff --git a/go.sum b/go.sum index 817e349..1aea9f1 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/apache/thrift v0.12.0 h1:pODnxUFNcjP9UTLZGTdeh+j16A8lJbRvD3rOtrk/7bs= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/hive/client.go b/hive/client.go index 08471f4..bcf0f03 100644 --- a/hive/client.go +++ b/hive/client.go @@ -42,10 +42,10 @@ func (c *Client) OpenSession(ctx context.Context) (*Session, error) { resp, err := c.client.OpenSession(ctx, &req) if err != nil { - return nil, err + return nil, WithStack(err) } if err := checkStatus(resp); err != nil { - return nil, err + return nil, WithStack(err) } c.log.Printf("open session: %s", guid(resp.SessionHandle.GetSessionId().GUID)) diff --git a/hive/errors.go b/hive/errors.go new file mode 100644 index 0000000..551a848 --- /dev/null +++ b/hive/errors.go @@ -0,0 +1,19 @@ +package hive + +import ( + "github.com/pkg/errors" +) + +type errorStackTracer interface { + StackTrace() errors.StackTrace +} + +//adds a stack trace if not already present +func WithStack(err error) error { + if _, ok := err.(errorStackTracer); ok { + return err + } + // newError := errors.WithStack(err) + // fmt.Printf("%+v\n", newError) + return errors.WithStack(err) +} diff --git a/hive/hive.go b/hive/hive.go index c18da04..1795df4 100644 --- a/hive/hive.go +++ b/hive/hive.go @@ -24,10 +24,10 @@ func checkStatus(resp interface{}) error { if ok { status := rpcresp.GetStatus() if status.StatusCode == cli_service.TStatusCode_ERROR_STATUS { - return errors.New(status.GetErrorMessage()) + return WithStack(errors.New(status.GetErrorMessage())) } if status.StatusCode == cli_service.TStatusCode_INVALID_HANDLE_STATUS { - return errors.New("thrift: invalid handle") + return WithStack(errors.New("thrift: invalid handle")) } // SUCCESS, SUCCESS_WITH_INFO, STILL_EXECUTING are ok @@ -35,7 +35,7 @@ func checkStatus(resp interface{}) error { } log.Printf("response: %v", resp) - return errors.New("thrift: invalid response") + return WithStack(errors.New("thrift: invalid response")) } func guid(b []byte) string { diff --git a/hive/operation.go b/hive/operation.go index b8c09ef..30488d2 100644 --- a/hive/operation.go +++ b/hive/operation.go @@ -32,10 +32,10 @@ func (op *Operation) GetResultSetMetadata(ctx context.Context) (*TableSchema, er resp, err := op.hive.client.GetResultSetMetadata(ctx, &req) if err != nil { - return nil, err + return nil, WithStack(err) } if err := checkStatus(resp); err != nil { - return nil, err + return nil, WithStack(err) } schema := new(TableSchema) @@ -65,7 +65,7 @@ func (op *Operation) FetchResults(ctx context.Context, schema *TableSchema) (*Re resp, err := fetch(ctx, op, schema) if err != nil { - return nil, err + return nil, WithStack(err) } rs := ResultSet{ @@ -93,10 +93,10 @@ func fetch(ctx context.Context, op *Operation, schema *TableSchema) (*cli_servic resp, err := op.hive.client.FetchResults(ctx, &req) if err != nil { - return nil, err + return nil, WithStack(err) } if err := checkStatus(resp); err != nil { - return nil, err + return nil, WithStack(err) } op.hive.log.Printf("results: %v", resp.Results) @@ -110,10 +110,10 @@ func (op *Operation) Close(ctx context.Context) error { } resp, err := op.hive.client.CloseOperation(ctx, &req) if err != nil { - return err + return WithStack(err) } if err := checkStatus(resp); err != nil { - return err + return WithStack(err) } op.hive.log.Printf("close operation: %v", guid(op.h.OperationId.GUID)) diff --git a/hive/result_set.go b/hive/result_set.go index 0d72b50..7225b64 100644 --- a/hive/result_set.go +++ b/hive/result_set.go @@ -32,7 +32,7 @@ func (rs *ResultSet) Next(dest []driver.Value) error { resp, err := rs.fetchfn() if err != nil { - return err + return WithStack(err) } // Replace previous page of results with new page of results @@ -52,7 +52,7 @@ func (rs *ResultSet) Next(dest []driver.Value) error { for i := range dest { val, err := value(rs.result.Columns[i], rs.schema.Columns[i], rs.idx, rs.loc) if err != nil { - return err + return WithStack(err) } dest[i] = val @@ -123,7 +123,7 @@ func value(col *cli_service.TColumn, cd *ColDesc, i int, loc *time.Location) (in } t, err := time.ParseInLocation(TimestampFormat, col.StringVal.Values[i], loc) if err != nil { - return nil, err + return nil, WithStack(err) } return t, nil case "DATE": diff --git a/hive/session.go b/hive/session.go index e921190..ef248a4 100644 --- a/hive/session.go +++ b/hive/session.go @@ -21,10 +21,10 @@ func (s *Session) Ping(ctx context.Context) error { resp, err := s.hive.client.GetInfo(ctx, &req) if err != nil { - return err + return WithStack(err) } if err := checkStatus(resp); err != nil { - return err + return WithStack(err) } s.hive.log.Printf("ping. server name: %s", resp.InfoValue.GetStringValue()) @@ -40,10 +40,10 @@ func (s *Session) ExecuteStatement(ctx context.Context, stmt string) (*Operation resp, err := s.hive.client.ExecuteStatement(ctx, &req) if err != nil { - return nil, err + return nil, WithStack(err) } if err := checkStatus(resp); err != nil { - return nil, err + return nil, WithStack(err) } s.hive.log.Printf("execute operation: %s", guid(resp.OperationHandle.OperationId.GUID)) s.hive.log.Printf("operation. has resultset: %v", resp.OperationHandle.GetHasResultSet()) @@ -59,10 +59,10 @@ func (s *Session) Close(ctx context.Context) error { } resp, err := s.hive.client.CloseSession(ctx, &req) if err != nil { - return err + return WithStack(err) } if err := checkStatus(resp); err != nil { - return err + return WithStack(err) } return nil } diff --git a/statement.go b/statement.go index 89be891..3645584 100644 --- a/statement.go +++ b/statement.go @@ -62,11 +62,11 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { session, err := s.conn.OpenSession(ctx) if err != nil { - return nil, err + return nil, hive.WithStack(err) } stmt, err := statement(s.stmt, args) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return query(ctx, session, stmt) } @@ -75,11 +75,11 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { session, err := s.conn.OpenSession(ctx) if err != nil { - return nil, err + return nil, hive.WithStack(err) } stmt, err := statement(s.stmt, args) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return exec(ctx, session, stmt) } @@ -119,17 +119,17 @@ func statement(tmpl string, args []driver.NamedValue) (string, error) { func query(ctx context.Context, session *hive.Session, stmt string) (driver.Rows, error) { operation, err := session.ExecuteStatement(ctx, stmt) if err != nil { - return nil, err + return nil, hive.WithStack(err) } schema, err := operation.GetResultSetMetadata(ctx) if err != nil { - return nil, err + return nil, hive.WithStack(err) } rs, err := operation.FetchResults(ctx, schema) if err != nil { - return nil, err + return nil, hive.WithStack(err) } return &Rows{ @@ -142,11 +142,11 @@ func query(ctx context.Context, session *hive.Session, stmt string) (driver.Rows func exec(ctx context.Context, session *hive.Session, stmt string) (driver.Result, error) { operation, err := session.ExecuteStatement(ctx, stmt) if err != nil { - return nil, err + return nil, hive.WithStack(err) } if err := operation.Close(ctx); err != nil { - return nil, err + return nil, hive.WithStack(err) } return driver.ResultNoRows, nil