Skip to content
Snippets Groups Projects
Commit 19ea8e6c authored by Michael Hoffmann's avatar Michael Hoffmann
Browse files

Merge branch 'mhoffm-user-defined-functions-part-2' into 'master'

driver: add a way to register scalar functions

Closes #95

See merge request !38
parents 4ccbc55b a9227519
No related branches found
No related tags found
1 merge request!38driver: add a way to register scalar functions
// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package functest // modernc.org/sqlite/functest
import (
"bytes"
"crypto/md5"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"strings"
"testing"
"time"
sqlite3 "modernc.org/sqlite"
)
func init() {
sqlite3.MustRegisterDeterministicScalarFunction(
"test_int64",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return int64(42), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_float64",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return float64(1e-2), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_null",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return nil, nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_error",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return nil, errors.New("boom")
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_empty_byte_slice",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return []byte{}, nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_nonempty_byte_slice",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return []byte("abcdefg"), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_empty_string",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return "", nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"test_nonempty_string",
0,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
return "abcdefg", nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"yesterday",
1,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
var arg time.Time
switch argTyped := args[0].(type) {
case int64:
arg = time.Unix(argTyped, 0)
default:
fmt.Println(argTyped)
return nil, fmt.Errorf("expected argument to be int64, got: %T", argTyped)
}
return arg.Add(-24 * time.Hour), nil
},
)
sqlite3.MustRegisterDeterministicScalarFunction(
"md5",
1,
func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
var arg *bytes.Buffer
switch argTyped := args[0].(type) {
case string:
arg = bytes.NewBuffer([]byte(argTyped))
case []byte:
arg = bytes.NewBuffer(argTyped)
default:
return nil, fmt.Errorf("expected argument to be a string, got: %T", argTyped)
}
w := md5.New()
if _, err := arg.WriteTo(w); err != nil {
return nil, fmt.Errorf("unable to compute md5 checksum: %s", err)
}
return hex.EncodeToString(w.Sum(nil)), nil
},
)
}
func TestRegisteredFunctions(t *testing.T) {
withDB := func(test func(db *sql.DB)) {
db, err := sql.Open("sqlite", "file::memory:")
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
defer db.Close()
test(db)
}
t.Run("int64", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_int64()")
var a int
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, 42; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("float64", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_float64()")
var a float64
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, 1e-2; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("error", func(tt *testing.T) {
withDB(func(db *sql.DB) {
_, err := db.Query("select test_error()")
if err == nil {
tt.Fatal("expected error, got none")
}
if !strings.Contains(err.Error(), "boom") {
tt.Fatal(err)
}
})
})
t.Run("empty_byte_slice", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_empty_byte_slice()")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if len(a) > 0 {
tt.Fatal("expected empty byte slice")
}
})
})
t.Run("nonempty_byte_slice", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_nonempty_byte_slice()")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) {
tt.Fatal(string(g), string(e))
}
})
})
t.Run("empty_string", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_empty_string()")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if len(a) > 0 {
tt.Fatal("expected empty string")
}
})
})
t.Run("nonempty_string", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_nonempty_string()")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, "abcdefg"; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("null", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select test_null()")
var a interface{}
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if a != nil {
tt.Fatal("expected nil")
}
})
})
t.Run("dates", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))")
var a int64
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) {
tt.Fatal(g, e)
}
})
})
t.Run("md5", func(tt *testing.T) {
withDB(func(db *sql.DB) {
row := db.QueryRow("select md5('abcdefg')")
var a string
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e {
tt.Fatal(g, e)
}
})
})
t.Run("md5 with blob input", func(tt *testing.T) {
withDB(func(db *sql.DB) {
if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil {
tt.Fatal(err)
}
row := db.QueryRow("select md5(b) from t")
var a []byte
if err := row.Scan(&a); err != nil {
tt.Fatal(err)
}
if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) {
tt.Fatal(string(g), string(e))
}
})
})
}
// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sqlite3
const (
SQLITE_STATIC = uintptr(0) // ((sqlite3_destructor_type)0)
SQLITE_TRANSIENT = ^uintptr(0) // ((sqlite3_destructor_type)-1)
)
......@@ -749,7 +749,6 @@ type conn struct {
writeTimeFormat string
beginMode string
udfs map[string]*userDefinedFunction
}
func newConn(dsn string) (*conn, error) {
......@@ -766,7 +765,7 @@ func newConn(dsn string) (*conn, error) {
}
}
c := &conn{tls: libc.NewTLS(), udfs: make(map[string]*userDefinedFunction)}
c := &conn{tls: libc.NewTLS()}
db, err := c.openV2(
dsn,
sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE|
......@@ -1313,13 +1312,6 @@ func (c *conn) Close() error {
c.db = 0
}
if c.udfs != nil {
for _, v := range c.udfs {
v.close(c.tls)
}
c.udfs = nil
}
if c.tls != nil {
c.tls.Close()
c.tls = nil
......@@ -1345,25 +1337,7 @@ type userDefinedFunction struct {
freeOnce sync.Once
}
func (udf *userDefinedFunction) close(tls *libc.TLS) {
if udf == nil {
return
}
udf.freeOnce.Do(func() { libc.Xfree(tls, udf.zFuncName) })
}
func (c *conn) createFunctionInternal(fun *userDefinedFunction) error {
c.Mutex.Lock()
defer c.Mutex.Unlock()
goZFuncName := libc.GoString(fun.zFuncName)
if prev, ok := c.udfs[goZFuncName]; ok {
prev.close(c.tls)
delete(c.udfs, goZFuncName)
}
c.udfs[goZFuncName] = fun
if rc := sqlite3.Xsqlite3_create_function(
c.tls,
c.db,
......@@ -1445,9 +1419,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue
}
// Driver implements database/sql/driver.Driver.
type Driver struct{}
type Driver struct {
// user defined functions that are added to every new connection on Open
udfs map[string]*userDefinedFunction
}
var d = &Driver{udfs: make(map[string]*userDefinedFunction)}
func newDriver() *Driver { return &Driver{} }
func newDriver() *Driver { return d }
// Open returns a new connection to the database. The name is a string in a
// driver-specific format.
......@@ -1479,5 +1458,162 @@ func newDriver() *Driver { return &Driver{} }
// available at
// https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
func (d *Driver) Open(name string) (driver.Conn, error) {
return newConn(name)
c, err := newConn(name)
if err != nil {
return nil, err
}
for _, udf := range d.udfs {
if err = c.createFunctionInternal(udf); err != nil {
c.Close()
return nil, err
}
}
return c, nil
}
type FunctionContext struct{}
const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
func RegisterScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc)
}
func MustRegisterScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) {
if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil {
panic(err)
}
}
func MustRegisterDeterministicScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) {
if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil {
panic(err)
}
}
func RegisterDeterministicScalarFunction(
zFuncName string,
nArg int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc)
}
func registerScalarFunction(
zFuncName string,
nArg int32,
eTextRep int32,
xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
) error {
if _, ok := d.udfs[zFuncName]; ok {
return fmt.Errorf("a function named %q is already registered", zFuncName)
}
// dont free, functions registered on the driver live as long as the program
name, err := libc.CString(zFuncName)
if err != nil {
return err
}
udf := &userDefinedFunction{
zFuncName: name,
nArg: nArg,
eTextRep: eTextRep,
xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
setErrorResult := func(res error) {
errmsg, cerr := libc.CString(res.Error())
if cerr != nil {
panic(cerr)
}
defer libc.Xfree(tls, errmsg)
sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
}
args := make([]driver.Value, argc)
for i := int32(0); i < argc; i++ {
valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType {
case sqlite3.SQLITE_TEXT:
args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr))
case sqlite3.SQLITE_INTEGER:
args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr)
case sqlite3.SQLITE_FLOAT:
args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr)
case sqlite3.SQLITE_NULL:
args[i] = nil
case sqlite3.SQLITE_BLOB:
size := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr)
v := make([]byte, size)
copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
args[i] = v
default:
panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType))
}
}
res, err := xFunc(&FunctionContext{}, args)
if err != nil {
setErrorResult(err)
return
}
switch resTyped := res.(type) {
case nil:
sqlite3.Xsqlite3_result_null(tls, ctx)
case int64:
sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped)
case float64:
sqlite3.Xsqlite3_result_double(tls, ctx, resTyped)
case bool:
sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped))
case time.Time:
sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix())
case string:
size := int32(len(resTyped))
cstr, err := libc.CString(resTyped)
if err != nil {
panic(err)
}
defer libc.Xfree(tls, cstr)
sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT)
case []byte:
size := int32(len(resTyped))
if size == 0 {
sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0)
return
}
p := libc.Xmalloc(tls, types.Size_t(size))
if p == 0 {
panic(fmt.Sprintf("unable to allocate space for blob: %d", size))
}
defer libc.Xfree(tls, p)
copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped)
sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT)
default:
setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped))
return
}
},
}
d.udfs[zFuncName] = udf
return nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment