Loading func_test.go +183 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ import ( "os" "path" "path/filepath" "reflect" "regexp" "strings" "sync" Loading Loading @@ -227,6 +228,18 @@ func init() { return nil, nil }, ) // Volatile counterpart to issue226_noop, registered with VolatileArgs=true. // Paired with BenchmarkUDFArgsAllocationVolatile to quantify the additional // saving from skipping the per-call TEXT / BLOB copy. MustRegisterFunction("issue226_noop_volatile", &FunctionImpl{ NArgs: 3, Deterministic: true, VolatileArgs: true, Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) { return nil, nil }, }) } // BenchmarkUDFArgsAllocation measures allocations in functionArgs when a Loading Loading @@ -1214,3 +1227,173 @@ func TestRegisteredFunctions(t *testing.T) { }) }) } // BenchmarkUDFArgsAllocationVolatile mirrors BenchmarkUDFArgsAllocation but // invokes a UDF registered with VolatileArgs=true. The difference between the // two benchmarks isolates the per-call cost of copying TEXT / BLOB argument // bodies into Go-owned memory, which is what VolatileArgs eliminates. func BenchmarkUDFArgsAllocationVolatile(b *testing.B) { db, err := sql.Open(driverName, "file::memory:") if err != nil { b.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (a INTEGER, b TEXT, c BLOB)`); err != nil { b.Fatal(err) } const rows = 1000 tx, err := db.Begin() if err != nil { b.Fatal(err) } stmt, err := tx.Prepare(`INSERT INTO t (a, b, c) VALUES (?, ?, ?)`) if err != nil { b.Fatal(err) } for i := 0; i < rows; i++ { if _, err := stmt.Exec(int64(i), "hello", []byte{1, 2, 3}); err != nil { b.Fatal(err) } } stmt.Close() if err := tx.Commit(); err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { r, err := db.Query(`SELECT issue226_noop_volatile(a, b, c) FROM t`) if err != nil { b.Fatal(err) } for r.Next() { } if err := r.Err(); err != nil { b.Fatal(err) } r.Close() } } // TestVolatileArgsScalar verifies that a scalar UDF registered with // VolatileArgs=true still receives correct TEXT and BLOB argument values // (including the empty cases that take a short-circuit path in functionArgs). // The UDF copies each value before the call returns, which is the required // usage pattern for VolatileArgs callbacks. func TestVolatileArgsScalar(t *testing.T) { var ( gotStrings []string gotBlobs [][]byte mu sync.Mutex ) MustRegisterFunction("vol_recorder_scalar", &FunctionImpl{ NArgs: 2, Deterministic: true, VolatileArgs: true, Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) { s := args[0].(string) b := args[1].([]byte) mu.Lock() gotStrings = append(gotStrings, strings.Clone(s)) gotBlobs = append(gotBlobs, append([]byte(nil), b...)) mu.Unlock() return nil, nil }, }) db, err := sql.Open(driverName, "file::memory:") if err != nil { t.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil { t.Fatal(err) } if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('alpha', X'01020304'), ('beta', X'AABB'), ('', NULL)`); err != nil { t.Fatal(err) } rows, err := db.Query(`SELECT vol_recorder_scalar(s, COALESCE(b, X'')) FROM t ORDER BY rowid`) if err != nil { t.Fatal(err) } for rows.Next() { } if err := rows.Err(); err != nil { t.Fatal(err) } rows.Close() wantStrings := []string{"alpha", "beta", ""} if !reflect.DeepEqual(gotStrings, wantStrings) { t.Errorf("volatile scalar TEXT: got %q, want %q", gotStrings, wantStrings) } wantBlobs := [][]byte{{1, 2, 3, 4}, {0xAA, 0xBB}, nil} if !reflect.DeepEqual(gotBlobs, wantBlobs) { t.Errorf("volatile scalar BLOB: got %v, want %v", gotBlobs, wantBlobs) } } // volatileAggregate is an aggregate function used by TestVolatileArgsAggregate. // Each Step copies its argument into the aggregate's own buffer; the volatile // contract requires this. type volatileAggregate struct { strs []string blobs [][]byte } func (a *volatileAggregate) Step(ctx *FunctionContext, args []driver.Value) error { a.strs = append(a.strs, strings.Clone(args[0].(string))) a.blobs = append(a.blobs, append([]byte(nil), args[1].([]byte)...)) return nil } func (a *volatileAggregate) WindowInverse(ctx *FunctionContext, args []driver.Value) error { return nil } func (a *volatileAggregate) WindowValue(ctx *FunctionContext) (driver.Value, error) { return fmt.Sprintf("strs=%q blobs=%v", a.strs, a.blobs), nil } func (a *volatileAggregate) Final(ctx *FunctionContext) {} // TestVolatileArgsAggregate verifies that the volatile-args path is honored // by the Step trampoline for aggregate functions. The aggregate sees every // row's TEXT and BLOB and the assembled result must match the inserted data. func TestVolatileArgsAggregate(t *testing.T) { MustRegisterFunction("vol_recorder_agg", &FunctionImpl{ NArgs: 2, VolatileArgs: true, MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) { return &volatileAggregate{}, nil }, }) db, err := sql.Open(driverName, "file::memory:") if err != nil { t.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil { t.Fatal(err) } if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('one', X'01'), ('two', X'02'), ('three', X'03')`); err != nil { t.Fatal(err) } var got string if err := db.QueryRow(`SELECT vol_recorder_agg(s, b) FROM (SELECT s, b FROM t ORDER BY rowid)`).Scan(&got); err != nil { t.Fatal(err) } want := `strs=["one" "two" "three"] blobs=[[1] [2] [3]]` if got != want { t.Errorf("volatile aggregate result:\n got %s\nwant %s", got, want) } } sqlite.go +127 −35 Original line number Diff line number Diff line Loading @@ -255,6 +255,45 @@ type FunctionImpl struct { // MakeAggregate is called at the beginning of each evaluation of an // aggregate function. MakeAggregate func(ctx FunctionContext) (AggregateFunction, error) // VolatileArgs is an opt-in performance flag that eliminates the per-call // allocation of string and []byte argument bodies. When true, the driver // passes argument strings and byte slices as zero-copy views into // SQLite-owned memory instead of Go-allocated copies. // // Setting this is unsafe unless the user-provided Scalar / Step / // WindowInverse callbacks treat string and []byte arguments as strictly // transient: they must not be retained past the return of the call, // neither directly (storing the slice or string in a struct field, map, // channel, or outer-scope variable) nor indirectly (passing them to // something that captures them, including most concurrency primitives). // // Retaining a volatile argument produces silent data corruption: SQLite // reuses the underlying buffer for the next row, so on a later read every // retained value will appear to hold the contents of the most recent row. // The Go race detector cannot catch this because UDF execution is // sequential on a single goroutine; the corruption is deterministic and // invisible to -race. // // As a guard against accidental capture, callbacks that must retain // values across rows should copy: // // saved := append([]byte(nil), args[0].([]byte)...) // []byte // saved := string(append([]byte(nil), args[0].(string)...)) // string, no aliasing // // When in doubt, leave VolatileArgs at its default (false) — the driver // already pools the argument-slice header (issue #226), so the per-row // overhead with VolatileArgs=false is one make([]byte) per BLOB column // and one libc.GoString per TEXT column, not a fresh slice header. // // Similarly, do not re-enter SQLite on the same connection while a // volatile argument is in scope. A nested Query/Exec on the same conn // can cause SQLite to reuse the underlying value buffers, so a volatile // string or []byte read before the nested call may alias different // bytes after it returns. // // VolatileArgs has no effect on integer, float, time, or NULL arguments. VolatileArgs bool } // An AggregateFunction is an invocation of an aggregate or window function. See Loading @@ -266,13 +305,18 @@ type FunctionImpl struct { type AggregateFunction interface { // Step is called for each row of an aggregate function's SQL // invocation. The argument Values are not valid past the return of the // function. // function. When the aggregate was registered with // [FunctionImpl.VolatileArgs] set to true, string and []byte arguments // in rowArgs are zero-copy views into SQLite-owned memory and retaining // them produces silent data corruption — see [FunctionImpl.VolatileArgs] // for the full safety contract. Step(ctx *FunctionContext, rowArgs []driver.Value) error // WindowInverse is called to remove the oldest presently aggregated // result of Step from the current window. The arguments are those // passed to Step for the row being removed. The argument Values are not // valid past the return of the function. // valid past the return of the function. The same // [FunctionImpl.VolatileArgs] caveat applies as for Step. WindowInverse(ctx *FunctionContext, rowArgs []driver.Value) error // WindowValue is called to get the current value of an aggregate Loading Loading @@ -506,7 +550,7 @@ func registerFunction( if impl.Scalar != nil { xFuncs.mu.Lock() id := xFuncs.ids.next() xFuncs.m[id] = impl.Scalar xFuncs.m[id] = xFuncEntry{fn: impl.Scalar, volatile: impl.VolatileArgs} xFuncs.mu.Unlock() udf.scalar = true Loading @@ -514,7 +558,7 @@ func registerFunction( } else { xAggregateFactories.mu.Lock() id := xAggregateFactories.ids.next() xAggregateFactories.m[id] = impl.MakeAggregate xAggregateFactories.m[id] = xAggregateFactoryEntry{factory: impl.MakeAggregate, volatile: impl.VolatileArgs} xAggregateFactories.mu.Unlock() udf.pApp = id Loading Loading @@ -594,7 +638,13 @@ func releaseUDFArgs(sp *[]driver.Value) { // functionArgs prepares a []driver.Value for one user-function invocation. // The returned slice is owned by the driver and must be released via // releaseUDFArgs once the user function returns. func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { // // When volatile is true, SQLITE_TEXT and SQLITE_BLOB arguments are returned as // zero-copy views into SQLite-owned memory (see [FunctionImpl.VolatileArgs] // for the user-facing safety contract). When false (the default for all // existing call sites), text and blob payloads are copied into Go-owned // memory and stay valid for the lifetime of the slice. func functionArgs(tls *libc.TLS, argc int32, argv uintptr, volatile bool) *[]driver.Value { sp := acquireUDFArgs(int(argc)) args := *sp for i := int32(0); i < argc; i++ { Loading @@ -602,7 +652,17 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { case sqlite3.SQLITE_TEXT: if volatile { p := sqlite3.Xsqlite3_value_text(tls, valPtr) n := sqlite3.Xsqlite3_value_bytes(tls, valPtr) if p == 0 || n == 0 { args[i] = "" } else { args[i] = unsafe.String((*byte)(unsafe.Pointer(p)), int(n)) } } else { args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) } case sqlite3.SQLITE_INTEGER: args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) case sqlite3.SQLITE_FLOAT: Loading @@ -612,11 +672,19 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { case sqlite3.SQLITE_BLOB: size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) if volatile { if blobPtr == 0 || size == 0 { args[i] = make([]byte, 0) } else { args[i] = unsafe.Slice((*byte)(unsafe.Pointer(blobPtr)), int(size)) } } else { v := make([]byte, size) if size != 0 { copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) } args[i] = v } default: panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) } Loading Loading @@ -682,26 +750,26 @@ func functionReturnValue(tls *libc.TLS, ctx uintptr, res driver.Value) error { var ( xFuncs = struct { mu sync.RWMutex m map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error) m map[uintptr]xFuncEntry ids idGen }{ m: make(map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error)), m: make(map[uintptr]xFuncEntry), } xAggregateFactories = struct { mu sync.RWMutex m map[uintptr]func(FunctionContext) (AggregateFunction, error) m map[uintptr]xAggregateFactoryEntry ids idGen }{ m: make(map[uintptr]func(FunctionContext) (AggregateFunction, error)), m: make(map[uintptr]xAggregateFactoryEntry), } xAggregateContext = struct { mu sync.RWMutex m map[uintptr]AggregateFunction m map[uintptr]xAggregateContextEntry ids idGen }{ m: make(map[uintptr]AggregateFunction), m: make(map[uintptr]xAggregateContextEntry), } xCollations = struct { Loading @@ -713,6 +781,30 @@ var ( } ) // xFuncEntry pairs a registered scalar function with the VolatileArgs flag // captured at registration time. Bundled so trampolines can decide whether to // pass zero-copy or copied argument values without a second map lookup. type xFuncEntry struct { fn func(*FunctionContext, []driver.Value) (driver.Value, error) volatile bool } // xAggregateFactoryEntry pairs a registered aggregate factory with its // VolatileArgs flag, for the same reason as xFuncEntry. type xAggregateFactoryEntry struct { factory func(FunctionContext) (AggregateFunction, error) volatile bool } // xAggregateContextEntry holds the AggregateFunction instance for one // in-flight aggregate evaluation together with the VolatileArgs flag inherited // from its factory registration. Caching the flag here avoids a second // xAggregateFactories lookup in the Step / WindowInverse trampolines. type xAggregateContextEntry struct { fn AggregateFunction volatile bool } type idGen struct { bitset []uint64 } Loading @@ -736,42 +828,42 @@ func (gen *idGen) reclaim(id uintptr) { gen.bitset[bit/64] &^= 1 << (bit % 64) } func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, uintptr) { func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, bool, uintptr) { goCtx := FunctionContext{tls: tls, ctx: ctx} aggCtx := (*uintptr)(unsafe.Pointer(sqlite3.Xsqlite3_aggregate_context(tls, ctx, int32(ptrSize)))) setErrorResult := errorResultFunction(tls, ctx) if aggCtx == nil { setErrorResult(errors.New("insufficient memory for aggregate")) return nil, 0 return nil, false, 0 } if *aggCtx != 0 { // Already created. xAggregateContext.mu.RLock() f := xAggregateContext.m[*aggCtx] entry := xAggregateContext.m[*aggCtx] xAggregateContext.mu.RUnlock() return f, *aggCtx return entry.fn, entry.volatile, *aggCtx } factoryID := sqlite3.Xsqlite3_user_data(tls, ctx) xAggregateFactories.mu.RLock() factory := xAggregateFactories.m[factoryID] factoryEntry := xAggregateFactories.m[factoryID] xAggregateFactories.mu.RUnlock() f, err := factory(goCtx) f, err := factoryEntry.factory(goCtx) if err != nil { setErrorResult(err) return nil, 0 return nil, false, 0 } if f == nil { setErrorResult(errors.New("MakeAggregate function returned nil")) return nil, 0 return nil, false, 0 } xAggregateContext.mu.Lock() *aggCtx = xAggregateContext.ids.next() xAggregateContext.m[*aggCtx] = f xAggregateContext.m[*aggCtx] = xAggregateContextEntry{fn: f, volatile: factoryEntry.volatile} xAggregateContext.mu.Unlock() return f, *aggCtx return f, factoryEntry.volatile, *aggCtx } // cFuncPointer converts a function defined by a function declaration to a C pointer. Loading @@ -793,13 +885,13 @@ func cFuncPointer[T any](f T) uintptr { func funcTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { id := sqlite3.Xsqlite3_user_data(tls, ctx) xFuncs.mu.RLock() xFunc := xFuncs.m[id] entry := xFuncs.m[id] xFuncs.mu.RUnlock() setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, entry.volatile) defer releaseUDFArgs(sp) res, err := xFunc(&FunctionContext{}, *sp) res, err := entry.fn(&FunctionContext{}, *sp) if err != nil { setErrorResult(err) Loading Loading @@ -828,13 +920,13 @@ func sqlite3AllocCString(tls *libc.TLS, s string) uintptr { } func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { impl, _ := makeAggregate(tls, ctx) impl, volatile, _ := makeAggregate(tls, ctx) if impl == nil { return } setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, volatile) defer releaseUDFArgs(sp) err := impl.Step(&FunctionContext{}, *sp) if err != nil { Loading @@ -843,13 +935,13 @@ func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { } func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { impl, _ := makeAggregate(tls, ctx) impl, volatile, _ := makeAggregate(tls, ctx) if impl == nil { return } setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, volatile) defer releaseUDFArgs(sp) err := impl.WindowInverse(&FunctionContext{}, *sp) if err != nil { Loading @@ -858,7 +950,7 @@ func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { } func valueTrampoline(tls *libc.TLS, ctx uintptr) { impl, _ := makeAggregate(tls, ctx) impl, _, _ := makeAggregate(tls, ctx) if impl == nil { return } Loading @@ -876,7 +968,7 @@ func valueTrampoline(tls *libc.TLS, ctx uintptr) { } func finalTrampoline(tls *libc.TLS, ctx uintptr) { impl, id := makeAggregate(tls, ctx) impl, _, id := makeAggregate(tls, ctx) if impl == nil { return } Loading vtab.go +2 −2 Original line number Diff line number Diff line Loading @@ -542,7 +542,7 @@ func vtabFilterTrampoline(tls *libc.TLS, pCursor uintptr, idxNum int32, idxStr u if idxStr != 0 { idxStrGo = libc.GoString(idxStr) } sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, false) defer releaseUDFArgs(sp) err := gc.impl.Filter(int(idxNum), idxStrGo, *sp) if err != nil { Loading Loading @@ -762,7 +762,7 @@ func vtabUpdateTrampoline(tls *libc.TLS, pVtab uintptr, argc int32, argv uintptr nCols := argc - 2 // Extract column values starting from argv[2] colsPtr := argv + uintptr(2)*sqliteValPtrSize sp := functionArgs(tls, nCols, colsPtr) sp := functionArgs(tls, nCols, colsPtr, false) defer releaseUDFArgs(sp) cols := *sp Loading Loading
func_test.go +183 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ import ( "os" "path" "path/filepath" "reflect" "regexp" "strings" "sync" Loading Loading @@ -227,6 +228,18 @@ func init() { return nil, nil }, ) // Volatile counterpart to issue226_noop, registered with VolatileArgs=true. // Paired with BenchmarkUDFArgsAllocationVolatile to quantify the additional // saving from skipping the per-call TEXT / BLOB copy. MustRegisterFunction("issue226_noop_volatile", &FunctionImpl{ NArgs: 3, Deterministic: true, VolatileArgs: true, Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) { return nil, nil }, }) } // BenchmarkUDFArgsAllocation measures allocations in functionArgs when a Loading Loading @@ -1214,3 +1227,173 @@ func TestRegisteredFunctions(t *testing.T) { }) }) } // BenchmarkUDFArgsAllocationVolatile mirrors BenchmarkUDFArgsAllocation but // invokes a UDF registered with VolatileArgs=true. The difference between the // two benchmarks isolates the per-call cost of copying TEXT / BLOB argument // bodies into Go-owned memory, which is what VolatileArgs eliminates. func BenchmarkUDFArgsAllocationVolatile(b *testing.B) { db, err := sql.Open(driverName, "file::memory:") if err != nil { b.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (a INTEGER, b TEXT, c BLOB)`); err != nil { b.Fatal(err) } const rows = 1000 tx, err := db.Begin() if err != nil { b.Fatal(err) } stmt, err := tx.Prepare(`INSERT INTO t (a, b, c) VALUES (?, ?, ?)`) if err != nil { b.Fatal(err) } for i := 0; i < rows; i++ { if _, err := stmt.Exec(int64(i), "hello", []byte{1, 2, 3}); err != nil { b.Fatal(err) } } stmt.Close() if err := tx.Commit(); err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { r, err := db.Query(`SELECT issue226_noop_volatile(a, b, c) FROM t`) if err != nil { b.Fatal(err) } for r.Next() { } if err := r.Err(); err != nil { b.Fatal(err) } r.Close() } } // TestVolatileArgsScalar verifies that a scalar UDF registered with // VolatileArgs=true still receives correct TEXT and BLOB argument values // (including the empty cases that take a short-circuit path in functionArgs). // The UDF copies each value before the call returns, which is the required // usage pattern for VolatileArgs callbacks. func TestVolatileArgsScalar(t *testing.T) { var ( gotStrings []string gotBlobs [][]byte mu sync.Mutex ) MustRegisterFunction("vol_recorder_scalar", &FunctionImpl{ NArgs: 2, Deterministic: true, VolatileArgs: true, Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) { s := args[0].(string) b := args[1].([]byte) mu.Lock() gotStrings = append(gotStrings, strings.Clone(s)) gotBlobs = append(gotBlobs, append([]byte(nil), b...)) mu.Unlock() return nil, nil }, }) db, err := sql.Open(driverName, "file::memory:") if err != nil { t.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil { t.Fatal(err) } if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('alpha', X'01020304'), ('beta', X'AABB'), ('', NULL)`); err != nil { t.Fatal(err) } rows, err := db.Query(`SELECT vol_recorder_scalar(s, COALESCE(b, X'')) FROM t ORDER BY rowid`) if err != nil { t.Fatal(err) } for rows.Next() { } if err := rows.Err(); err != nil { t.Fatal(err) } rows.Close() wantStrings := []string{"alpha", "beta", ""} if !reflect.DeepEqual(gotStrings, wantStrings) { t.Errorf("volatile scalar TEXT: got %q, want %q", gotStrings, wantStrings) } wantBlobs := [][]byte{{1, 2, 3, 4}, {0xAA, 0xBB}, nil} if !reflect.DeepEqual(gotBlobs, wantBlobs) { t.Errorf("volatile scalar BLOB: got %v, want %v", gotBlobs, wantBlobs) } } // volatileAggregate is an aggregate function used by TestVolatileArgsAggregate. // Each Step copies its argument into the aggregate's own buffer; the volatile // contract requires this. type volatileAggregate struct { strs []string blobs [][]byte } func (a *volatileAggregate) Step(ctx *FunctionContext, args []driver.Value) error { a.strs = append(a.strs, strings.Clone(args[0].(string))) a.blobs = append(a.blobs, append([]byte(nil), args[1].([]byte)...)) return nil } func (a *volatileAggregate) WindowInverse(ctx *FunctionContext, args []driver.Value) error { return nil } func (a *volatileAggregate) WindowValue(ctx *FunctionContext) (driver.Value, error) { return fmt.Sprintf("strs=%q blobs=%v", a.strs, a.blobs), nil } func (a *volatileAggregate) Final(ctx *FunctionContext) {} // TestVolatileArgsAggregate verifies that the volatile-args path is honored // by the Step trampoline for aggregate functions. The aggregate sees every // row's TEXT and BLOB and the assembled result must match the inserted data. func TestVolatileArgsAggregate(t *testing.T) { MustRegisterFunction("vol_recorder_agg", &FunctionImpl{ NArgs: 2, VolatileArgs: true, MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) { return &volatileAggregate{}, nil }, }) db, err := sql.Open(driverName, "file::memory:") if err != nil { t.Fatal(err) } defer db.Close() if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil { t.Fatal(err) } if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('one', X'01'), ('two', X'02'), ('three', X'03')`); err != nil { t.Fatal(err) } var got string if err := db.QueryRow(`SELECT vol_recorder_agg(s, b) FROM (SELECT s, b FROM t ORDER BY rowid)`).Scan(&got); err != nil { t.Fatal(err) } want := `strs=["one" "two" "three"] blobs=[[1] [2] [3]]` if got != want { t.Errorf("volatile aggregate result:\n got %s\nwant %s", got, want) } }
sqlite.go +127 −35 Original line number Diff line number Diff line Loading @@ -255,6 +255,45 @@ type FunctionImpl struct { // MakeAggregate is called at the beginning of each evaluation of an // aggregate function. MakeAggregate func(ctx FunctionContext) (AggregateFunction, error) // VolatileArgs is an opt-in performance flag that eliminates the per-call // allocation of string and []byte argument bodies. When true, the driver // passes argument strings and byte slices as zero-copy views into // SQLite-owned memory instead of Go-allocated copies. // // Setting this is unsafe unless the user-provided Scalar / Step / // WindowInverse callbacks treat string and []byte arguments as strictly // transient: they must not be retained past the return of the call, // neither directly (storing the slice or string in a struct field, map, // channel, or outer-scope variable) nor indirectly (passing them to // something that captures them, including most concurrency primitives). // // Retaining a volatile argument produces silent data corruption: SQLite // reuses the underlying buffer for the next row, so on a later read every // retained value will appear to hold the contents of the most recent row. // The Go race detector cannot catch this because UDF execution is // sequential on a single goroutine; the corruption is deterministic and // invisible to -race. // // As a guard against accidental capture, callbacks that must retain // values across rows should copy: // // saved := append([]byte(nil), args[0].([]byte)...) // []byte // saved := string(append([]byte(nil), args[0].(string)...)) // string, no aliasing // // When in doubt, leave VolatileArgs at its default (false) — the driver // already pools the argument-slice header (issue #226), so the per-row // overhead with VolatileArgs=false is one make([]byte) per BLOB column // and one libc.GoString per TEXT column, not a fresh slice header. // // Similarly, do not re-enter SQLite on the same connection while a // volatile argument is in scope. A nested Query/Exec on the same conn // can cause SQLite to reuse the underlying value buffers, so a volatile // string or []byte read before the nested call may alias different // bytes after it returns. // // VolatileArgs has no effect on integer, float, time, or NULL arguments. VolatileArgs bool } // An AggregateFunction is an invocation of an aggregate or window function. See Loading @@ -266,13 +305,18 @@ type FunctionImpl struct { type AggregateFunction interface { // Step is called for each row of an aggregate function's SQL // invocation. The argument Values are not valid past the return of the // function. // function. When the aggregate was registered with // [FunctionImpl.VolatileArgs] set to true, string and []byte arguments // in rowArgs are zero-copy views into SQLite-owned memory and retaining // them produces silent data corruption — see [FunctionImpl.VolatileArgs] // for the full safety contract. Step(ctx *FunctionContext, rowArgs []driver.Value) error // WindowInverse is called to remove the oldest presently aggregated // result of Step from the current window. The arguments are those // passed to Step for the row being removed. The argument Values are not // valid past the return of the function. // valid past the return of the function. The same // [FunctionImpl.VolatileArgs] caveat applies as for Step. WindowInverse(ctx *FunctionContext, rowArgs []driver.Value) error // WindowValue is called to get the current value of an aggregate Loading Loading @@ -506,7 +550,7 @@ func registerFunction( if impl.Scalar != nil { xFuncs.mu.Lock() id := xFuncs.ids.next() xFuncs.m[id] = impl.Scalar xFuncs.m[id] = xFuncEntry{fn: impl.Scalar, volatile: impl.VolatileArgs} xFuncs.mu.Unlock() udf.scalar = true Loading @@ -514,7 +558,7 @@ func registerFunction( } else { xAggregateFactories.mu.Lock() id := xAggregateFactories.ids.next() xAggregateFactories.m[id] = impl.MakeAggregate xAggregateFactories.m[id] = xAggregateFactoryEntry{factory: impl.MakeAggregate, volatile: impl.VolatileArgs} xAggregateFactories.mu.Unlock() udf.pApp = id Loading Loading @@ -594,7 +638,13 @@ func releaseUDFArgs(sp *[]driver.Value) { // functionArgs prepares a []driver.Value for one user-function invocation. // The returned slice is owned by the driver and must be released via // releaseUDFArgs once the user function returns. func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { // // When volatile is true, SQLITE_TEXT and SQLITE_BLOB arguments are returned as // zero-copy views into SQLite-owned memory (see [FunctionImpl.VolatileArgs] // for the user-facing safety contract). When false (the default for all // existing call sites), text and blob payloads are copied into Go-owned // memory and stay valid for the lifetime of the slice. func functionArgs(tls *libc.TLS, argc int32, argv uintptr, volatile bool) *[]driver.Value { sp := acquireUDFArgs(int(argc)) args := *sp for i := int32(0); i < argc; i++ { Loading @@ -602,7 +652,17 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { case sqlite3.SQLITE_TEXT: if volatile { p := sqlite3.Xsqlite3_value_text(tls, valPtr) n := sqlite3.Xsqlite3_value_bytes(tls, valPtr) if p == 0 || n == 0 { args[i] = "" } else { args[i] = unsafe.String((*byte)(unsafe.Pointer(p)), int(n)) } } else { args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) } case sqlite3.SQLITE_INTEGER: args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) case sqlite3.SQLITE_FLOAT: Loading @@ -612,11 +672,19 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value { case sqlite3.SQLITE_BLOB: size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) if volatile { if blobPtr == 0 || size == 0 { args[i] = make([]byte, 0) } else { args[i] = unsafe.Slice((*byte)(unsafe.Pointer(blobPtr)), int(size)) } } else { v := make([]byte, size) if size != 0 { copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) } args[i] = v } default: panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) } Loading Loading @@ -682,26 +750,26 @@ func functionReturnValue(tls *libc.TLS, ctx uintptr, res driver.Value) error { var ( xFuncs = struct { mu sync.RWMutex m map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error) m map[uintptr]xFuncEntry ids idGen }{ m: make(map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error)), m: make(map[uintptr]xFuncEntry), } xAggregateFactories = struct { mu sync.RWMutex m map[uintptr]func(FunctionContext) (AggregateFunction, error) m map[uintptr]xAggregateFactoryEntry ids idGen }{ m: make(map[uintptr]func(FunctionContext) (AggregateFunction, error)), m: make(map[uintptr]xAggregateFactoryEntry), } xAggregateContext = struct { mu sync.RWMutex m map[uintptr]AggregateFunction m map[uintptr]xAggregateContextEntry ids idGen }{ m: make(map[uintptr]AggregateFunction), m: make(map[uintptr]xAggregateContextEntry), } xCollations = struct { Loading @@ -713,6 +781,30 @@ var ( } ) // xFuncEntry pairs a registered scalar function with the VolatileArgs flag // captured at registration time. Bundled so trampolines can decide whether to // pass zero-copy or copied argument values without a second map lookup. type xFuncEntry struct { fn func(*FunctionContext, []driver.Value) (driver.Value, error) volatile bool } // xAggregateFactoryEntry pairs a registered aggregate factory with its // VolatileArgs flag, for the same reason as xFuncEntry. type xAggregateFactoryEntry struct { factory func(FunctionContext) (AggregateFunction, error) volatile bool } // xAggregateContextEntry holds the AggregateFunction instance for one // in-flight aggregate evaluation together with the VolatileArgs flag inherited // from its factory registration. Caching the flag here avoids a second // xAggregateFactories lookup in the Step / WindowInverse trampolines. type xAggregateContextEntry struct { fn AggregateFunction volatile bool } type idGen struct { bitset []uint64 } Loading @@ -736,42 +828,42 @@ func (gen *idGen) reclaim(id uintptr) { gen.bitset[bit/64] &^= 1 << (bit % 64) } func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, uintptr) { func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, bool, uintptr) { goCtx := FunctionContext{tls: tls, ctx: ctx} aggCtx := (*uintptr)(unsafe.Pointer(sqlite3.Xsqlite3_aggregate_context(tls, ctx, int32(ptrSize)))) setErrorResult := errorResultFunction(tls, ctx) if aggCtx == nil { setErrorResult(errors.New("insufficient memory for aggregate")) return nil, 0 return nil, false, 0 } if *aggCtx != 0 { // Already created. xAggregateContext.mu.RLock() f := xAggregateContext.m[*aggCtx] entry := xAggregateContext.m[*aggCtx] xAggregateContext.mu.RUnlock() return f, *aggCtx return entry.fn, entry.volatile, *aggCtx } factoryID := sqlite3.Xsqlite3_user_data(tls, ctx) xAggregateFactories.mu.RLock() factory := xAggregateFactories.m[factoryID] factoryEntry := xAggregateFactories.m[factoryID] xAggregateFactories.mu.RUnlock() f, err := factory(goCtx) f, err := factoryEntry.factory(goCtx) if err != nil { setErrorResult(err) return nil, 0 return nil, false, 0 } if f == nil { setErrorResult(errors.New("MakeAggregate function returned nil")) return nil, 0 return nil, false, 0 } xAggregateContext.mu.Lock() *aggCtx = xAggregateContext.ids.next() xAggregateContext.m[*aggCtx] = f xAggregateContext.m[*aggCtx] = xAggregateContextEntry{fn: f, volatile: factoryEntry.volatile} xAggregateContext.mu.Unlock() return f, *aggCtx return f, factoryEntry.volatile, *aggCtx } // cFuncPointer converts a function defined by a function declaration to a C pointer. Loading @@ -793,13 +885,13 @@ func cFuncPointer[T any](f T) uintptr { func funcTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { id := sqlite3.Xsqlite3_user_data(tls, ctx) xFuncs.mu.RLock() xFunc := xFuncs.m[id] entry := xFuncs.m[id] xFuncs.mu.RUnlock() setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, entry.volatile) defer releaseUDFArgs(sp) res, err := xFunc(&FunctionContext{}, *sp) res, err := entry.fn(&FunctionContext{}, *sp) if err != nil { setErrorResult(err) Loading Loading @@ -828,13 +920,13 @@ func sqlite3AllocCString(tls *libc.TLS, s string) uintptr { } func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { impl, _ := makeAggregate(tls, ctx) impl, volatile, _ := makeAggregate(tls, ctx) if impl == nil { return } setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, volatile) defer releaseUDFArgs(sp) err := impl.Step(&FunctionContext{}, *sp) if err != nil { Loading @@ -843,13 +935,13 @@ func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { } func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { impl, _ := makeAggregate(tls, ctx) impl, volatile, _ := makeAggregate(tls, ctx) if impl == nil { return } setErrorResult := errorResultFunction(tls, ctx) sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, volatile) defer releaseUDFArgs(sp) err := impl.WindowInverse(&FunctionContext{}, *sp) if err != nil { Loading @@ -858,7 +950,7 @@ func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { } func valueTrampoline(tls *libc.TLS, ctx uintptr) { impl, _ := makeAggregate(tls, ctx) impl, _, _ := makeAggregate(tls, ctx) if impl == nil { return } Loading @@ -876,7 +968,7 @@ func valueTrampoline(tls *libc.TLS, ctx uintptr) { } func finalTrampoline(tls *libc.TLS, ctx uintptr) { impl, id := makeAggregate(tls, ctx) impl, _, id := makeAggregate(tls, ctx) if impl == nil { return } Loading
vtab.go +2 −2 Original line number Diff line number Diff line Loading @@ -542,7 +542,7 @@ func vtabFilterTrampoline(tls *libc.TLS, pCursor uintptr, idxNum int32, idxStr u if idxStr != 0 { idxStrGo = libc.GoString(idxStr) } sp := functionArgs(tls, argc, argv) sp := functionArgs(tls, argc, argv, false) defer releaseUDFArgs(sp) err := gc.impl.Filter(int(idxNum), idxStrGo, *sp) if err != nil { Loading Loading @@ -762,7 +762,7 @@ func vtabUpdateTrampoline(tls *libc.TLS, pVtab uintptr, argc int32, argv uintptr nCols := argc - 2 // Extract column values starting from argv[2] colsPtr := argv + uintptr(2)*sqliteValPtrSize sp := functionArgs(tls, nCols, colsPtr) sp := functionArgs(tls, nCols, colsPtr, false) defer releaseUDFArgs(sp) cols := *sp Loading