diff --git a/sqlite3.go b/sqlite3.go index ce985ec8..e7a6f876 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -201,6 +201,53 @@ static int sqlite3_system_errno(sqlite3 *db) { return 0; } #endif + +#define GO_SQLITE3_DECL_DATE (1 << 7) +#define GO_SQLITE3_DECL_BOOL (1 << 6) +#define GO_SQLITE3_DECL_MASK (GO_SQLITE3_DECL_DATE | GO_SQLITE3_DECL_BOOL) +#define GO_SQLITE3_TYPE_MASK (GO_SQLITE3_DECL_BOOL - 1) + +// _sqlite3_column_decltypes stores the declared column type in the decls array. +// This function must always be called before _sqlite3_column_types since it +// overwrites the datatype. +static void _sqlite3_column_decltypes(sqlite3_stmt* stmt, uint8_t *decls, int ndecls) { + for (int i = 0; i < ndecls; i++) { + const char *typ = sqlite3_column_decltype(stmt, i); + if (typ == NULL) { + decls[i] = 0; + continue; + } + switch (typ[0]) { + case 'b': + case 'B': + if (!sqlite3_stricmp(typ, "boolean")) { + decls[i] = GO_SQLITE3_DECL_BOOL; + } + break; + case 'd': + case 'D': + if (!sqlite3_stricmp(typ, "date") || !sqlite3_stricmp(typ, "datetime")) { + decls[i] = GO_SQLITE3_DECL_DATE; + } + break; + case 't': + case 'T': + if (!sqlite3_stricmp(typ, "timestamp")) { + decls[i] = GO_SQLITE3_DECL_DATE; + } + break; + default: + decls[i] = 0; + } + } +} + +static void _sqlite3_column_types(sqlite3_stmt *stmt, uint8_t *typs, int ntyps) { + for (int i = 0; i < ntyps; i++) { + typs[i] &= GO_SQLITE3_DECL_MASK; // clear lower bits + typs[i] |= (uint8_t)sqlite3_column_type(stmt, i); + } +} */ import "C" import ( @@ -239,12 +286,6 @@ var SQLiteTimestampFormats = []string{ "2006-01-02", } -const ( - columnDate string = "date" - columnDatetime string = "datetime" - columnTimestamp string = "timestamp" -) - // This variable can be replaced with -ldflags like below: // go build -ldflags="-X 'github.com/mattn/go-sqlite3.driverName=my-sqlite3'" var driverName = "sqlite3" @@ -390,12 +431,31 @@ type SQLiteResult struct { changes int64 } +// A columnType is a compact representation of sqlite3 columns datatype and +// declared type. The first two bits store the declared type and the remaining +// six bits store the sqlite3 datatype. +type columnType uint8 + +// DeclType returns the declared type, which is currently GO_SQLITE3_DECL_DATE +// or GO_SQLITE3_DECL_BOOL, since those are the only two types that we need for +// converting values. +func (c columnType) DeclType() int { + return int(c) & C.GO_SQLITE3_DECL_MASK +} + +// DataType returns the sqlite3 datatype code of the column, which is the +// result of sqlite3_column_type. +func (c columnType) DataType() int { + return int(c) & C.GO_SQLITE3_TYPE_MASK +} + // SQLiteRows implements driver.Rows. type SQLiteRows struct { s *SQLiteStmt nc int cols []string decltype []string + coltype []columnType cls bool closed bool ctx context.Context // no better alternative to pass context into Next() method @@ -2146,7 +2206,10 @@ func (rc *SQLiteRows) Columns() []string { return rc.cols } -func (rc *SQLiteRows) declTypes() []string { +// DeclTypes return column types. +func (rc *SQLiteRows) DeclTypes() []string { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() if rc.s.s != nil && rc.decltype == nil { rc.decltype = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { @@ -2156,13 +2219,6 @@ func (rc *SQLiteRows) declTypes() []string { return rc.decltype } -// DeclTypes return column types. -func (rc *SQLiteRows) DeclTypes() []string { - rc.s.mu.Lock() - defer rc.s.mu.Unlock() - return rc.declTypes() -} - // Next move cursor to next. Attempts to honor context timeout from QueryContext call. func (rc *SQLiteRows) Next(dest []driver.Value) error { rc.s.mu.Lock() @@ -2195,6 +2251,13 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { } } +func (rc *SQLiteRows) colTypePtr() *C.uint8_t { + if len(rc.coltype) == 0 { + return nil + } + return (*C.uint8_t)(unsafe.Pointer(&rc.coltype[0])) +} + // nextSyncLocked moves cursor to next; must be called with locked mutex. func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { rv := C._sqlite3_step_internal(rc.s.s) @@ -2208,15 +2271,24 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { } return nil } + if len(dest) == 0 { + return nil + } - rc.declTypes() + if rc.coltype == nil { + rc.coltype = make([]columnType, rc.nc) + C._sqlite3_column_decltypes(rc.s.s, rc.colTypePtr(), C.int(rc.nc)) + } + // Must call this each time since sqlite3 is loosely + // typed and the column types can vary between rows. + C._sqlite3_column_types(rc.s.s, rc.colTypePtr(), C.int(rc.nc)) for i := range dest { - switch C.sqlite3_column_type(rc.s.s, C.int(i)) { + switch rc.coltype[i].DataType() { case C.SQLITE_INTEGER: val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) - switch rc.decltype[i] { - case columnTimestamp, columnDatetime, columnDate: + switch rc.coltype[i].DeclType() { + case C.GO_SQLITE3_DECL_DATE: var t time.Time // Assume a millisecond unix timestamp if it's 13 digits -- too // large to be a reasonable timestamp in seconds. @@ -2231,7 +2303,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { t = t.In(rc.s.c.loc) } dest[i] = t - case "boolean": + case C.GO_SQLITE3_DECL_BOOL: dest[i] = val > 0 default: dest[i] = val @@ -2255,8 +2327,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i))) s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n)) - switch rc.decltype[i] { - case columnTimestamp, columnDatetime, columnDate: + if rc.coltype[i].DeclType() == C.GO_SQLITE3_DECL_DATE { var t time.Time s = strings.TrimSuffix(s, "Z") for _, format := range SQLiteTimestampFormats { @@ -2273,7 +2344,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { t = t.In(rc.s.c.loc) } dest[i] = t - default: + } else { dest[i] = s } } diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..43652b23 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -2030,7 +2030,7 @@ func BenchmarkCustomFunctions(b *testing.B) { } func TestSuite(t *testing.T) { - initializeTestDB(t) + initializeTestDB(t, false) defer freeTestDB() for _, test := range tests { @@ -2039,7 +2039,7 @@ func TestSuite(t *testing.T) { } func BenchmarkSuite(b *testing.B) { - initializeTestDB(b) + initializeTestDB(b, true) defer freeTestDB() for _, benchmark := range benchmarks { @@ -2068,8 +2068,13 @@ type TestDB struct { var db *TestDB -func initializeTestDB(t testing.TB) { - tempFilename := TempFilename(t) +func initializeTestDB(t testing.TB, memory bool) { + var tempFilename string + if memory { + tempFilename = ":memory:" + } else { + tempFilename = TempFilename(t) + } d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { os.Remove(tempFilename) @@ -2084,9 +2089,11 @@ func freeTestDB() { if err != nil { panic(err) } - err = os.Remove(db.tempFilename) - if err != nil { - panic(err) + if db.tempFilename != "" && db.tempFilename != ":memory:" { + err := os.Remove(db.tempFilename) + if err != nil { + panic(err) + } } } @@ -2111,6 +2118,7 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, + {Name: "BenchmarkStmt10Cols", F: benchmarkStmt10Cols}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2568,3 +2576,60 @@ func benchmarkStmtRows(b *testing.B) { } } } + +func benchmarkStmt10Cols(b *testing.B) { + db.once.Do(makeBench) + + const createTableStmt = ` + DROP TABLE IF EXISTS bench_cols; + VACUUM; + CREATE TABLE bench_cols ( + r0 INTEGER NOT NULL, + r1 INTEGER NOT NULL, + r2 INTEGER NOT NULL, + r3 INTEGER NOT NULL, + r4 INTEGER NOT NULL, + r5 INTEGER NOT NULL, + r6 INTEGER NOT NULL, + r7 INTEGER NOT NULL, + r8 INTEGER NOT NULL, + r9 INTEGER NOT NULL + );` + if _, err := db.Exec(createTableStmt); err != nil { + b.Fatal(err) + } + for i := int64(0); i < 4; i++ { + _, err := db.Exec("INSERT INTO bench_cols VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);", + i, i, i, i, i, i, i, i, i, i) + if err != nil { + b.Fatal(err) + } + } + + stmt, err := db.Prepare("SELECT * FROM bench_cols;") + if err != nil { + b.Fatal(err) + } + defer stmt.Close() + + b.ResetTimer() + var ( + v0, v1, v2, v3, v4 int64 + v5, v6, v7, v8, v9 int64 + ) + for i := 0; i < b.N; i++ { + rows, err := stmt.Query() + if err != nil { + b.Fatal(err) + } + for rows.Next() { + err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9) + if err != nil { + b.Fatal(err) + } + } + if err := rows.Err(); err != nil { + b.Fatal(err) + } + } +}