Commit fac1cab2 authored by cznic's avatar cznic
Browse files

Merge branch 'feat/volatile-args-opt-in' into 'master'

add FunctionImpl.VolatileArgs opt-in for zero-copy TEXT/BLOB args (#226)

See merge request !120
parents 0b223921 569614c5
Loading
Loading
Loading
Loading
+183 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ import (
	"os"
	"path"
	"path/filepath"
	"reflect"
	"regexp"
	"strings"
	"sync"
@@ -227,6 +228,18 @@ func init() {
			return nil, nil
		},
	)

	// Volatile counterpart to issue226_noop, registered with VolatileArgs=true.
	// Paired with BenchmarkUDFArgsAllocationVolatile to quantify the additional
	// saving from skipping the per-call TEXT / BLOB copy.
	MustRegisterFunction("issue226_noop_volatile", &FunctionImpl{
		NArgs:         3,
		Deterministic: true,
		VolatileArgs:  true,
		Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return nil, nil
		},
	})
}

// BenchmarkUDFArgsAllocation measures allocations in functionArgs when a
@@ -1214,3 +1227,173 @@ func TestRegisteredFunctions(t *testing.T) {
		})
	})
}

// BenchmarkUDFArgsAllocationVolatile is the VolatileArgs=true counterpart of
// BenchmarkUDFArgsAllocation: it scans a table of INTEGER / TEXT / BLOB rows
// through issue226_noop_volatile. Comparing the two benchmarks shows the
// per-call cost of copying TEXT / BLOB argument bodies into Go-owned memory,
// which is the work VolatileArgs removes.
func BenchmarkUDFArgsAllocationVolatile(b *testing.B) {
	const (
		rows      = 1000
		insertSQL = `INSERT INTO t (a, b, c) VALUES (?, ?, ?)`
		selectSQL = `SELECT issue226_noop_volatile(a, b, c) FROM t`
	)

	db, err := sql.Open(driverName, "file::memory:")
	if err != nil {
		b.Fatal(err)
	}
	defer db.Close()

	_, err = db.Exec(`CREATE TABLE t (a INTEGER, b TEXT, c BLOB)`)
	if err != nil {
		b.Fatal(err)
	}

	// Fill the table through one prepared statement inside a single
	// transaction; all of this happens before the timer is reset.
	tx, err := db.Begin()
	if err != nil {
		b.Fatal(err)
	}
	ins, err := tx.Prepare(insertSQL)
	if err != nil {
		b.Fatal(err)
	}
	for id := int64(0); id < rows; id++ {
		if _, err = ins.Exec(id, "hello", []byte{1, 2, 3}); err != nil {
			b.Fatal(err)
		}
	}
	ins.Close()
	if err = tx.Commit(); err != nil {
		b.Fatal(err)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		res, err := db.Query(selectSQL)
		if err != nil {
			b.Fatal(err)
		}
		// Drain the result set so the UDF is evaluated for every row.
		for res.Next() {
		}
		if err := res.Err(); err != nil {
			b.Fatal(err)
		}
		res.Close()
	}
}

// TestVolatileArgsScalar verifies that a scalar UDF registered with
// VolatileArgs=true still receives correct TEXT and BLOB argument values
// (including the empty cases that take a short-circuit path in functionArgs).
// The UDF copies each value before the call returns, which is the required
// usage pattern for VolatileArgs callbacks.
func TestVolatileArgsScalar(t *testing.T) {
	var (
		gotStrings []string
		gotBlobs   [][]byte
		mu         sync.Mutex
	)

	MustRegisterFunction("vol_recorder_scalar", &FunctionImpl{
		NArgs:         2,
		Deterministic: true,
		VolatileArgs:  true,
		Scalar: func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			s := args[0].(string)
			b := args[1].([]byte)
			mu.Lock()
			gotStrings = append(gotStrings, strings.Clone(s))
			gotBlobs = append(gotBlobs, append([]byte(nil), b...))
			mu.Unlock()
			return nil, nil
		},
	})

	db, err := sql.Open(driverName, "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil {
		t.Fatal(err)
	}
	if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('alpha', X'01020304'), ('beta', X'AABB'), ('', NULL)`); err != nil {
		t.Fatal(err)
	}

	rows, err := db.Query(`SELECT vol_recorder_scalar(s, COALESCE(b, X'')) FROM t ORDER BY rowid`)
	if err != nil {
		t.Fatal(err)
	}
	for rows.Next() {
	}
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}
	rows.Close()

	wantStrings := []string{"alpha", "beta", ""}
	if !reflect.DeepEqual(gotStrings, wantStrings) {
		t.Errorf("volatile scalar TEXT: got %q, want %q", gotStrings, wantStrings)
	}
	wantBlobs := [][]byte{{1, 2, 3, 4}, {0xAA, 0xBB}, nil}
	if !reflect.DeepEqual(gotBlobs, wantBlobs) {
		t.Errorf("volatile scalar BLOB: got %v, want %v", gotBlobs, wantBlobs)
	}
}

// volatileAggregate is the aggregate function used by
// TestVolatileArgsAggregate. It is registered with VolatileArgs=true, so the
// string and []byte arguments passed to Step must not be retained as-is;
// Step stores Go-owned copies instead, as the volatile contract requires.
type volatileAggregate struct {
	strs  []string // copied TEXT arguments, oldest row first
	blobs [][]byte // copied BLOB arguments, parallel to strs
}

// Step records copies of the row's TEXT (args[0]) and BLOB (args[1])
// arguments. A wrong argument count or type is returned as an error rather
// than left to panic inside the callback.
func (a *volatileAggregate) Step(ctx *FunctionContext, args []driver.Value) error {
	if len(args) != 2 {
		return fmt.Errorf("volatileAggregate: got %d arguments, want 2", len(args))
	}
	s, ok := args[0].(string)
	if !ok {
		return fmt.Errorf("volatileAggregate: argument 0 is %T, want string", args[0])
	}
	b, ok := args[1].([]byte)
	if !ok {
		return fmt.Errorf("volatileAggregate: argument 1 is %T, want []byte", args[1])
	}
	a.strs = append(a.strs, strings.Clone(s))
	a.blobs = append(a.blobs, append([]byte(nil), b...))
	return nil
}

// WindowInverse removes the oldest recorded row, so that WindowValue reflects
// only the rows still inside the window. The arguments are not needed because
// rows are removed in the order they were added. Calling it with nothing
// recorded is a no-op.
func (a *volatileAggregate) WindowInverse(ctx *FunctionContext, args []driver.Value) error {
	if len(a.strs) > 0 {
		a.strs = a.strs[1:]
	}
	if len(a.blobs) > 0 {
		a.blobs = a.blobs[1:]
	}
	return nil
}

// WindowValue renders the recorded values as `strs=[...] blobs=[...]`, with
// strings quoted (%q) and blobs printed as byte lists (%v).
func (a *volatileAggregate) WindowValue(ctx *FunctionContext) (driver.Value, error) {
	return fmt.Sprintf("strs=%q blobs=%v", a.strs, a.blobs), nil
}

// Final does nothing: the aggregate holds only Go-owned slices.
func (a *volatileAggregate) Final(ctx *FunctionContext) {}

// TestVolatileArgsAggregate verifies that the volatile-args path is honored
// by the Step trampoline for aggregate functions. The aggregate sees every
// row's TEXT and BLOB and the assembled result must match the inserted data.
func TestVolatileArgsAggregate(t *testing.T) {
	MustRegisterFunction("vol_recorder_agg", &FunctionImpl{
		NArgs:        2,
		VolatileArgs: true,
		MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) {
			return &volatileAggregate{}, nil
		},
	})

	db, err := sql.Open(driverName, "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Exec(`CREATE TABLE t (s TEXT, b BLOB)`); err != nil {
		t.Fatal(err)
	}
	if _, err := db.Exec(`INSERT INTO t (s, b) VALUES ('one', X'01'), ('two', X'02'), ('three', X'03')`); err != nil {
		t.Fatal(err)
	}

	var got string
	if err := db.QueryRow(`SELECT vol_recorder_agg(s, b) FROM (SELECT s, b FROM t ORDER BY rowid)`).Scan(&got); err != nil {
		t.Fatal(err)
	}
	want := `strs=["one" "two" "three"] blobs=[[1] [2] [3]]`
	if got != want {
		t.Errorf("volatile aggregate result:\n got %s\nwant %s", got, want)
	}
}
+127 −35
Original line number Diff line number Diff line
@@ -255,6 +255,45 @@ type FunctionImpl struct {
	// MakeAggregate is called at the beginning of each evaluation of an
	// aggregate function.
	MakeAggregate func(ctx FunctionContext) (AggregateFunction, error)

	// VolatileArgs is an opt-in performance flag that eliminates the per-call
	// allocation of string and []byte argument bodies. When true, the driver
	// passes argument strings and byte slices as zero-copy views into
	// SQLite-owned memory instead of Go-allocated copies.
	//
	// Setting this is unsafe unless the user-provided Scalar / Step /
	// WindowInverse callbacks treat string and []byte arguments as strictly
	// transient: they must not be retained past the return of the call,
	// neither directly (storing the slice or string in a struct field, map,
	// channel, or outer-scope variable) nor indirectly (passing them to
	// something that captures them, including most concurrency primitives).
	//
	// Retaining a volatile argument produces silent data corruption: SQLite
	// owns the underlying buffer and may reuse or free it as soon as the
	// callback returns. Typically a retained value later appears to hold the
	// contents of a more recent row, but it may equally read unrelated or
	// released memory. The Go race detector cannot catch this: UDF execution
	// is sequential on a single goroutine, so there is no data race for -race
	// to report.
	//
	// As a guard against accidental capture, callbacks that must retain
	// values across rows should copy:
	//
	//	blob := append([]byte(nil), args[0].([]byte)...) // for a []byte argument
	//	text := strings.Clone(args[0].(string))          // for a string argument, no aliasing
	//
	// When in doubt, leave VolatileArgs at its default (false) — the driver
	// already pools the argument-slice header (issue #226), so the per-row
	// overhead with VolatileArgs=false is one make([]byte) per BLOB column
	// and one libc.GoString per TEXT column, not a fresh slice header.
	//
	// Similarly, do not re-enter SQLite on the same connection while a
	// volatile argument is in scope. A nested Query/Exec on the same conn
	// can cause SQLite to reuse the underlying value buffers, so a volatile
	// string or []byte read before the nested call may alias different
	// bytes after it returns.
	//
	// VolatileArgs has no effect on integer, float, time, or NULL arguments.
	VolatileArgs bool
}

// An AggregateFunction is an invocation of an aggregate or window function. See
@@ -266,13 +305,18 @@ type FunctionImpl struct {
type AggregateFunction interface {
	// Step is called for each row of an aggregate function's SQL
	// invocation. The argument Values are not valid past the return of the
	// function.
	// function. When the aggregate was registered with
	// [FunctionImpl.VolatileArgs] set to true, string and []byte arguments
	// in rowArgs are zero-copy views into SQLite-owned memory and retaining
	// them produces silent data corruption — see [FunctionImpl.VolatileArgs]
	// for the full safety contract.
	Step(ctx *FunctionContext, rowArgs []driver.Value) error

	// WindowInverse is called to remove the oldest presently aggregated
	// result of Step from the current window. The arguments are those
	// passed to Step for the row being removed. The argument Values are not
	// valid past the return of the function.
	// valid past the return of the function. The same
	// [FunctionImpl.VolatileArgs] caveat applies as for Step.
	WindowInverse(ctx *FunctionContext, rowArgs []driver.Value) error

	// WindowValue is called to get the current value of an aggregate
@@ -506,7 +550,7 @@ func registerFunction(
	if impl.Scalar != nil {
		xFuncs.mu.Lock()
		id := xFuncs.ids.next()
		xFuncs.m[id] = impl.Scalar
		xFuncs.m[id] = xFuncEntry{fn: impl.Scalar, volatile: impl.VolatileArgs}
		xFuncs.mu.Unlock()

		udf.scalar = true
@@ -514,7 +558,7 @@ func registerFunction(
	} else {
		xAggregateFactories.mu.Lock()
		id := xAggregateFactories.ids.next()
		xAggregateFactories.m[id] = impl.MakeAggregate
		xAggregateFactories.m[id] = xAggregateFactoryEntry{factory: impl.MakeAggregate, volatile: impl.VolatileArgs}
		xAggregateFactories.mu.Unlock()

		udf.pApp = id
@@ -594,7 +638,13 @@ func releaseUDFArgs(sp *[]driver.Value) {
// functionArgs prepares a []driver.Value for one user-function invocation.
// The returned slice is owned by the driver and must be released via
// releaseUDFArgs once the user function returns.
func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value {
//
// When volatile is true, SQLITE_TEXT and SQLITE_BLOB arguments are returned as
// zero-copy views into SQLite-owned memory (see [FunctionImpl.VolatileArgs]
// for the user-facing safety contract). When false (the default for all
// existing call sites), text and blob payloads are copied into Go-owned
// memory and stay valid for the lifetime of the slice.
func functionArgs(tls *libc.TLS, argc int32, argv uintptr, volatile bool) *[]driver.Value {
	sp := acquireUDFArgs(int(argc))
	args := *sp
	for i := int32(0); i < argc; i++ {
@@ -602,7 +652,17 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value {

		switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType {
		case sqlite3.SQLITE_TEXT:
			if volatile {
				p := sqlite3.Xsqlite3_value_text(tls, valPtr)
				n := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
				if p == 0 || n == 0 {
					args[i] = ""
				} else {
					args[i] = unsafe.String((*byte)(unsafe.Pointer(p)), int(n))
				}
			} else {
				args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr))
			}
		case sqlite3.SQLITE_INTEGER:
			args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr)
		case sqlite3.SQLITE_FLOAT:
@@ -612,11 +672,19 @@ func functionArgs(tls *libc.TLS, argc int32, argv uintptr) *[]driver.Value {
		case sqlite3.SQLITE_BLOB:
			size := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
			blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr)
			if volatile {
				if blobPtr == 0 || size == 0 {
					args[i] = make([]byte, 0)
				} else {
					args[i] = unsafe.Slice((*byte)(unsafe.Pointer(blobPtr)), int(size))
				}
			} else {
				v := make([]byte, size)
				if size != 0 {
					copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
				}
				args[i] = v
			}
		default:
			panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType))
		}
@@ -682,26 +750,26 @@ func functionReturnValue(tls *libc.TLS, ctx uintptr, res driver.Value) error {
var (
	xFuncs = struct {
		mu  sync.RWMutex
		m   map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error)
		m   map[uintptr]xFuncEntry
		ids idGen
	}{
		m: make(map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error)),
		m: make(map[uintptr]xFuncEntry),
	}

	xAggregateFactories = struct {
		mu  sync.RWMutex
		m   map[uintptr]func(FunctionContext) (AggregateFunction, error)
		m   map[uintptr]xAggregateFactoryEntry
		ids idGen
	}{
		m: make(map[uintptr]func(FunctionContext) (AggregateFunction, error)),
		m: make(map[uintptr]xAggregateFactoryEntry),
	}

	xAggregateContext = struct {
		mu  sync.RWMutex
		m   map[uintptr]AggregateFunction
		m   map[uintptr]xAggregateContextEntry
		ids idGen
	}{
		m: make(map[uintptr]AggregateFunction),
		m: make(map[uintptr]xAggregateContextEntry),
	}

	xCollations = struct {
@@ -713,6 +781,30 @@ var (
	}
)

// xFuncEntry pairs a registered scalar function with the VolatileArgs flag
// captured at registration time. Bundled so trampolines can decide whether to
// pass zero-copy or copied argument values without a second map lookup.
//
// fn is the FunctionImpl.Scalar callback and volatile is the
// FunctionImpl.VolatileArgs value, both taken from the FunctionImpl passed to
// registerFunction. funcTrampoline forwards volatile to functionArgs and then
// invokes fn with the resulting argument slice.
type xFuncEntry struct {
	fn       func(*FunctionContext, []driver.Value) (driver.Value, error)
	volatile bool
}

// xAggregateFactoryEntry pairs a registered aggregate factory with its
// VolatileArgs flag, for the same reason as xFuncEntry.
//
// factory is the FunctionImpl.MakeAggregate callback and volatile is the
// FunctionImpl.VolatileArgs value recorded by registerFunction. makeAggregate
// calls factory to create the per-evaluation AggregateFunction and copies
// volatile into the xAggregateContextEntry it stores for that evaluation.
type xAggregateFactoryEntry struct {
	factory  func(FunctionContext) (AggregateFunction, error)
	volatile bool
}

// xAggregateContextEntry holds the AggregateFunction instance for one
// in-flight aggregate evaluation together with the VolatileArgs flag inherited
// from its factory registration. Caching the flag here avoids a second
// xAggregateFactories lookup in the Step / WindowInverse trampolines.
//
// fn is the non-nil value returned by the factory; volatile is copied from
// the xAggregateFactoryEntry that produced it. makeAggregate writes the entry
// under a freshly generated id and returns both fields on later lookups.
type xAggregateContextEntry struct {
	fn       AggregateFunction
	volatile bool
}

// idGen generates the uintptr ids used as keys of the registration maps
// (xFuncs, xAggregateFactories, xAggregateContext): next hands out an id and
// reclaim gives one back.
//
// bitset stores one bit per id, 64 ids per word; reclaim clears the bit of
// the id being returned. NOTE(review): the exact id-to-bit mapping and the
// allocation strategy of next are not visible in this excerpt.
type idGen struct {
	bitset []uint64
}
@@ -736,42 +828,42 @@ func (gen *idGen) reclaim(id uintptr) {
	gen.bitset[bit/64] &^= 1 << (bit % 64)
}

// makeAggregate returns the AggregateFunction that belongs to the aggregate
// evaluation identified by ctx, creating it through the registered factory on
// first use.
//
// Results: the aggregate instance, the VolatileArgs flag of its registration
// (in the three-result signature), and the id under which the instance is
// stored in xAggregateContext. On any failure an error result is set on ctx
// and the instance is nil with a zero id.
func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, uintptr) {
func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, bool, uintptr) {
	goCtx := FunctionContext{tls: tls, ctx: ctx}
	// Ask SQLite for this evaluation's pointer-sized context slot; the slot
	// holds the xAggregateContext id, with zero meaning "not created yet".
	aggCtx := (*uintptr)(unsafe.Pointer(sqlite3.Xsqlite3_aggregate_context(tls, ctx, int32(ptrSize))))
	setErrorResult := errorResultFunction(tls, ctx)
	if aggCtx == nil {
		setErrorResult(errors.New("insufficient memory for aggregate"))
		return nil, 0
		return nil, false, 0
	}
	if *aggCtx != 0 {
		// Already created.
		xAggregateContext.mu.RLock()
		f := xAggregateContext.m[*aggCtx]
		entry := xAggregateContext.m[*aggCtx]
		xAggregateContext.mu.RUnlock()
		return f, *aggCtx
		return entry.fn, entry.volatile, *aggCtx
	}

	// First call for this evaluation: the function's user data is the id of
	// its factory registration in xAggregateFactories.
	factoryID := sqlite3.Xsqlite3_user_data(tls, ctx)
	xAggregateFactories.mu.RLock()
	factory := xAggregateFactories.m[factoryID]
	factoryEntry := xAggregateFactories.m[factoryID]
	xAggregateFactories.mu.RUnlock()

	f, err := factory(goCtx)
	f, err := factoryEntry.factory(goCtx)
	if err != nil {
		setErrorResult(err)
		return nil, 0
		return nil, false, 0
	}
	if f == nil {
		setErrorResult(errors.New("MakeAggregate function returned nil"))
		return nil, 0
		return nil, false, 0
	}

	// Store the new instance under a fresh id and write that id into the
	// SQLite context slot so later calls take the "already created" path.
	xAggregateContext.mu.Lock()
	*aggCtx = xAggregateContext.ids.next()
	xAggregateContext.m[*aggCtx] = f
	xAggregateContext.m[*aggCtx] = xAggregateContextEntry{fn: f, volatile: factoryEntry.volatile}
	xAggregateContext.mu.Unlock()
	return f, *aggCtx
	return f, factoryEntry.volatile, *aggCtx
}

// cFuncPointer converts a function defined by a function declaration to a C pointer.
@@ -793,13 +885,13 @@ func cFuncPointer[T any](f T) uintptr {
func funcTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
	id := sqlite3.Xsqlite3_user_data(tls, ctx)
	xFuncs.mu.RLock()
	xFunc := xFuncs.m[id]
	entry := xFuncs.m[id]
	xFuncs.mu.RUnlock()

	setErrorResult := errorResultFunction(tls, ctx)
	sp := functionArgs(tls, argc, argv)
	sp := functionArgs(tls, argc, argv, entry.volatile)
	defer releaseUDFArgs(sp)
	res, err := xFunc(&FunctionContext{}, *sp)
	res, err := entry.fn(&FunctionContext{}, *sp)

	if err != nil {
		setErrorResult(err)
@@ -828,13 +920,13 @@ func sqlite3AllocCString(tls *libc.TLS, s string) uintptr {
}

func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
	impl, _ := makeAggregate(tls, ctx)
	impl, volatile, _ := makeAggregate(tls, ctx)
	if impl == nil {
		return
	}

	setErrorResult := errorResultFunction(tls, ctx)
	sp := functionArgs(tls, argc, argv)
	sp := functionArgs(tls, argc, argv, volatile)
	defer releaseUDFArgs(sp)
	err := impl.Step(&FunctionContext{}, *sp)
	if err != nil {
@@ -843,13 +935,13 @@ func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
}

func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
	impl, _ := makeAggregate(tls, ctx)
	impl, volatile, _ := makeAggregate(tls, ctx)
	if impl == nil {
		return
	}

	setErrorResult := errorResultFunction(tls, ctx)
	sp := functionArgs(tls, argc, argv)
	sp := functionArgs(tls, argc, argv, volatile)
	defer releaseUDFArgs(sp)
	err := impl.WindowInverse(&FunctionContext{}, *sp)
	if err != nil {
@@ -858,7 +950,7 @@ func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
}

func valueTrampoline(tls *libc.TLS, ctx uintptr) {
	impl, _ := makeAggregate(tls, ctx)
	impl, _, _ := makeAggregate(tls, ctx)
	if impl == nil {
		return
	}
@@ -876,7 +968,7 @@ func valueTrampoline(tls *libc.TLS, ctx uintptr) {
}

func finalTrampoline(tls *libc.TLS, ctx uintptr) {
	impl, id := makeAggregate(tls, ctx)
	impl, _, id := makeAggregate(tls, ctx)
	if impl == nil {
		return
	}
+2 −2
Original line number Diff line number Diff line
@@ -542,7 +542,7 @@ func vtabFilterTrampoline(tls *libc.TLS, pCursor uintptr, idxNum int32, idxStr u
	if idxStr != 0 {
		idxStrGo = libc.GoString(idxStr)
	}
	sp := functionArgs(tls, argc, argv)
	sp := functionArgs(tls, argc, argv, false)
	defer releaseUDFArgs(sp)
	err := gc.impl.Filter(int(idxNum), idxStrGo, *sp)
	if err != nil {
@@ -762,7 +762,7 @@ func vtabUpdateTrampoline(tls *libc.TLS, pVtab uintptr, argc int32, argv uintptr
	nCols := argc - 2
	// Extract column values starting from argv[2]
	colsPtr := argv + uintptr(2)*sqliteValPtrSize
	sp := functionArgs(tls, nCols, colsPtr)
	sp := functionArgs(tls, nCols, colsPtr, false)
	defer releaseUDFArgs(sp)
	cols := *sp