Commit ef93ba85 authored by Josh Bleecher Snyder
Browse files

improve memory safety of allocs in stmt.query

The multi-statement path was subject to both double-frees and leaks.

Having allocs declared once and re-used across loop iterations
made it hard to manage correctly. Reduce its scope to within the
inner closure and defer freeAllocs there.

Even within a loop, it was easy to mismanage; force callers to deal
with this consistently by passing a pointer to newRows and zeroing there.

Take care to Close rows objects from previous iterations when
a subsequent statement in a multi-statement query fails.

The tests here demonstrate both double-frees and
leaked memory without these changes.

I noticed while in this code that in the optimized
single-statement path, there are some unnecessary calls to
reset/clearBindings after newRows. When newRows fails, its internal
defer calls r.Close, which finalizes the prepared statement.
The subsequent reset and clearBindings calls would then operate
on an already-finalized handle. This would be bad, except that
columnCount/columnName cannot fail on a valid pstmt.
Therefore, the error path is dead code.
I didn't touch it, even though pedantically it looks wrong.
There might also be some dangling SQLITE_STATIC references.
Again, this appears to be safe in practice, so I left it alone.
A follow-up could make the single-statement path conform to the docs
and ensure extra safety.

I don't plan to do any of this follow-up work;
I am focused for now only on demonstrable, directly testable issues.
parent 2a97c686
Loading
Loading
Loading
Loading
+63 −0
Original line number Diff line number Diff line
@@ -4085,3 +4085,66 @@ func TestTxCommitBusyFix(t *testing.T) {
	// If we got here, the fix is working. Clean up.
	tx3.Rollback()
}

// TestMultiStmtQueryStringRoundtrip is a regression test for a double-free of
// stale bind-parameter memory in multi-statement queries.
//
// Every round runs a two-statement query in which only the first statement
// takes a parameter, and then immediately binds a fresh string on the same
// connection to stress the C allocator.
// Run with -race to enable checkptr, which catches the problem quickly and reliably.
func TestMultiStmtQueryStringRoundtrip(t *testing.T) {
	// To reproduce without -race/checkptr, raise rounds substantially and wait
	// for a crash, a hang, or data corruption.
	const rounds = 1

	// The first statement binds a string param; the second statement has no params.
	const roundtripSQL = "INSERT INTO t VALUES(?); SELECT v FROM t WHERE rowid = last_insert_rowid()"

	db, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	// A single pooled connection forces reuse, so that corruption accumulates.
	db.SetMaxOpenConns(1)

	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	_, err = conn.ExecContext(ctx, "CREATE TABLE t(v TEXT)")
	if err != nil {
		t.Fatal(err)
	}

	for round := 0; round < rounds; round++ {
		want := fmt.Sprintf("iter-%04d-%s", round, strings.Repeat("x", 1024))

		rows, err := conn.QueryContext(ctx, roundtripSQL, want)
		if err != nil {
			t.Fatalf("(%d) query err: %v", round, err)
		}

		var got string
		if rows.Next() {
			err = rows.Scan(&got)
			if err != nil {
				t.Fatalf("(%d) scan err: %v", round, err)
			}
		}
		err = rows.Err()
		if err != nil {
			t.Fatalf("(%d) rows err: %v", round, err)
		}
		// Without the fix, this Close double-frees the string.
		err = rows.Close()
		if err != nil {
			t.Fatalf("(%d) close err: %v", round, err)
		}

		if got != want {
			t.Fatalf("(%d) got len=%d\nwant len=%d", round, len(got), len(want))
		}

		// Bind a fresh string right away to pressure the allocator into
		// reusing the double-freed block.
		probe := fmt.Sprintf("check-%04d-%s", round, strings.Repeat("y", 1024))
		var echoed string
		err = conn.QueryRowContext(ctx, "SELECT ?", probe).Scan(&echoed)
		if err != nil {
			t.Fatalf("(%d) check query err: %v", round, err)
		}
		if echoed != probe {
			t.Fatalf("(%d) data corruption: got len=%d\nwant len=%d", round, len(echoed), len(probe))
		}
	}
}

leak_test.go

0 → 100644
+184 −0
Original line number Diff line number Diff line
package sqlite

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"testing"

	"modernc.org/libc"
)

// TestMultiStmtNopAllocsLeak exercises a leak in the multi-statement query
// path: when a middle statement binds parameters and returns SQLITE_DONE while
// a previous statement already set the rows result, the bind-parameter
// allocations are silently dropped. If a later statement also binds, it
// overwrites the allocs slice, leaking the earlier allocations.
//
// Run with -tags memory.counters to enable allocation tracking.
//
// Pattern: SELECT 1; UPDATE t SET v = ? WHERE 0; SELECT ?
//   - SELECT 1         → SQLITE_ROW, r set, no bind allocs
//   - UPDATE … WHERE 0 → bind allocs=[ptr_A], SQLITE_DONE, r!=nil → nop (ptr_A retained in allocs)
//   - SELECT ?         → bind allocs=[ptr_B] (ptr_A leaked), SQLITE_ROW
func TestMultiStmtNopAllocsLeak(t *testing.T) {
	db, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	db.SetMaxOpenConns(1)

	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if _, err := conn.ExecContext(ctx, "CREATE TABLE t(v TEXT)"); err != nil {
		t.Fatal(err)
	}

	run := func(iters int) {
		for i := range iters {
			val := fmt.Sprintf("iter-%04d-%s", i, strings.Repeat("x", 1024))
			rows, err := conn.QueryContext(ctx, "SELECT 1; UPDATE t SET v = ? WHERE 0; SELECT ?", val)
			if err != nil {
				t.Fatalf("(%d) query: %v", i, err)
			}
			if err := rows.Close(); err != nil {
				t.Fatalf("(%d) close: %v", i, err)
			}
		}
	}

	// Warm up to reach steady state.
	run(100)

	before := libc.MemStat()
	run(1000)
	after := libc.MemStat()

	leaked := after.Allocs - before.Allocs
	t.Logf("allocs before=%d after=%d delta=%d", before.Allocs, after.Allocs, leaked)
	if leaked > 100 {
		t.Fatalf("memory leak: net alloc count grew by %d over 1000 iterations", leaked)
	}
}

// TestMultiStmtErrorAllocsLeak exercises a leak on the step-error path in the
// multi-statement query closure: when step returns an error (e.g. a UNIQUE
// constraint violation), allocs from the preceding bind are not freed.
//
// The failing statement is placed first so that no prior rows object exists —
// this isolates the allocs leak from the separate "orphaned rows" bug.
//
// Run with -tags memory.counters to enable allocation tracking.
//
// Pattern: INSERT INTO t VALUES(?); SELECT 1 — with a duplicate value
//   - INSERT INTO t VALUES(?) → bind allocs=[ptr_A], step error → allocs leaked
func TestMultiStmtErrorAllocsLeak(t *testing.T) {
	db, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	db.SetMaxOpenConns(1)

	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if _, err := conn.ExecContext(ctx, "CREATE TABLE t(v TEXT UNIQUE)"); err != nil {
		t.Fatal(err)
	}
	// Seed the table with a value that we'll collide with.
	collideVal := fmt.Sprintf("collide-%s", strings.Repeat("x", 1024))
	if _, err := conn.ExecContext(ctx, "INSERT INTO t VALUES(?)", collideVal); err != nil {
		t.Fatal(err)
	}

	run := func(iters int) {
		for i := range iters {
			_, err := conn.QueryContext(ctx, "INSERT INTO t VALUES(?); SELECT 1", collideVal)
			if err == nil {
				t.Fatalf("(%d) expected UNIQUE constraint error", i)
			}
		}
	}

	// Warm up to reach steady state.
	run(100)

	before := libc.MemStat()
	run(1000)
	after := libc.MemStat()

	leaked := after.Allocs - before.Allocs
	t.Logf("allocs before=%d after=%d delta=%d", before.Allocs, after.Allocs, leaked)
	if leaked > 100 {
		t.Fatalf("memory leak: net alloc count grew by %d over 1000 iterations", leaked)
	}
}

// TestMultiStmtOrphanedRowsOnError exercises a resource leak when a later
// statement in a multi-statement query errors after an earlier statement
// already produced a rows result: the error path discards the rows object
// without closing it, leaking its prepared statement handle and bind-parameter
// allocations.
//
// Run with -tags memory.counters to enable allocation tracking.
//
// Pattern: SELECT ?; INSERT INTO t VALUES(?) — with a duplicate value
//   - SELECT ? → SQLITE_ROW, r set (holds pstmt + allocs)
//   - INSERT   → step error (UNIQUE violation) → return nil, err → r orphaned
func TestMultiStmtOrphanedRowsOnError(t *testing.T) {
	db, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	db.SetMaxOpenConns(1)

	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if _, err := conn.ExecContext(ctx, "CREATE TABLE t(v TEXT UNIQUE)"); err != nil {
		t.Fatal(err)
	}
	collideVal := fmt.Sprintf("collide-%s", strings.Repeat("x", 1024))
	if _, err := conn.ExecContext(ctx, "INSERT INTO t VALUES(?)", collideVal); err != nil {
		t.Fatal(err)
	}

	run := func(iters int) {
		for i := range iters {
			_, err := conn.QueryContext(ctx, "SELECT ?; INSERT INTO t VALUES(?)", collideVal)
			if err == nil {
				t.Fatalf("(%d) expected UNIQUE constraint error", i)
			}
		}
	}

	// Warm up to reach steady state.
	run(100)

	before := libc.MemStat()
	run(1000)
	after := libc.MemStat()

	leaked := after.Allocs - before.Allocs
	t.Logf("allocs before=%d after=%d delta=%d", before.Allocs, after.Allocs, leaked)
	if leaked > 100 {
		t.Fatalf("memory leak: net alloc count grew by %d over 1000 iterations", leaked)
	}
}
+7 −2
Original line number Diff line number Diff line
@@ -27,8 +27,13 @@ type rows struct {
	reuseStmt bool // If true, Close() resets instead of finalizing
}

func newRows(c *conn, pstmt uintptr, allocs []uintptr, empty bool) (r *rows, err error) {
	r = &rows{c: c, pstmt: pstmt, allocs: allocs, empty: empty}
func newRows(c *conn, pstmt uintptr, allocs *[]uintptr, empty bool) (r *rows, err error) {
	var a []uintptr
	if allocs != nil {
		a = *allocs
		*allocs = nil
	}
	r = &rows{c: c, pstmt: pstmt, allocs: a, empty: empty}

	defer func() {
		if err != nil {
+16 −10
Original line number Diff line number Diff line
@@ -262,8 +262,6 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
		}
	}

	var allocs []uintptr

	defer func() {
		if ctx != nil && atomic.LoadInt32(&done) != 0 {
			if r != nil {
@@ -271,7 +269,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
			}
			r, err = nil, ctx.Err()
		} else if r == nil && err == nil {
			r, err = newRows(s.c, pstmt, allocs, true)
			r, err = newRows(s.c, pstmt, nil, true)
		}

		if pstmt != 0 {
@@ -289,6 +287,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro

	// OPTIMIZED PATH: Single Cached Statement
	if s.pstmt != 0 {
		var allocs []uintptr
		// Bind
		n, err := s.c.bindParameterCount(s.pstmt)
		if err != nil {
@@ -314,7 +313,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
		switch rc & 0xff {
		case sqlite3.SQLITE_ROW:
			// Pass reuseStmt=true
			if r, err = newRows(s.c, s.pstmt, allocs, false); err != nil {
			if r, err = newRows(s.c, s.pstmt, &allocs, false); err != nil {
				s.c.reset(s.pstmt)
				s.c.clearBindings(s.pstmt)
				return nil, err
@@ -331,7 +330,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro

			// Actually, if we pass reuseStmt=true to an empty set,
			// rows.Close() will eventually reset it.
			if r, err = newRows(s.c, s.pstmt, allocs, true); err != nil {
			if r, err = newRows(s.c, s.pstmt, &allocs, true); err != nil {
				s.c.reset(s.pstmt)
				s.c.clearBindings(s.pstmt)
				return nil, err
@@ -351,6 +350,9 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
	// FALLBACK PATH: Multi-statement script
	for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
		if pstmt, err = s.c.prepareV2(&psql); err != nil {
			if r != nil {
				r.Close()
			}
			return nil, err
		}

@@ -359,6 +361,9 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
		}

		err = func() (err error) {
			var allocs []uintptr
			defer func() { s.c.freeAllocs(allocs) }()

			n, err := s.c.bindParameterCount(pstmt)
			if err != nil {
				return err
@@ -380,15 +385,14 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
				if r != nil {
					r.Close()
				}
				if r, err = newRows(s.c, pstmt, allocs, false); err != nil {
				if r, err = newRows(s.c, pstmt, &allocs, false); err != nil {
					return err
				}

				pstmt = 0
				return nil
			case sqlite3.SQLITE_DONE:
				if r == nil {
					if r, err = newRows(s.c, pstmt, allocs, true); err != nil {
					if r, err = newRows(s.c, pstmt, &allocs, true); err != nil {
						return err
					}
					pstmt = 0
@@ -404,10 +408,9 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
				if r != nil {
					r.Close()
				}
				if r, err = newRows(s.c, pstmt, allocs, true); err != nil {
				if r, err = newRows(s.c, pstmt, &allocs, true); err != nil {
					return err
				}

				pstmt = 0
			}
			return nil
@@ -423,6 +426,9 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
		}

		if err != nil {
			if r != nil {
				r.Close() // r is from a previous iteration; clean up since we won't return it
			}
			return nil, err
		}
	}