Skip to content
GitLab
Menu
Why GitLab
Pricing
Contact Sales
Explore
Why GitLab
Pricing
Contact Sales
Explore
Sign in
Get free trial
Primary navigation
Search or go to…
Project
sqlite
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Privacy statement
Keyboard shortcuts
?
What's new
6
Snippets
Groups
Projects
Show more breadcrumbs
cznic
sqlite
Commits
19ea8e6c
Commit
19ea8e6c
authored
2 years ago
by
Michael Hoffmann
Browse files
Options
Downloads
Plain Diff
Merge branch 'mhoffm-user-defined-functions-part-2' into 'master'
driver: add a way to register scalar functions Closes
#95
See merge request
!38
parents
4ccbc55b
a9227519
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!38
driver: add a way to register scalar functions
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
functest/func_test.go
+292
-0
292 additions, 0 deletions
functest/func_test.go
lib/defs.go
+10
-0
10 additions, 0 deletions
lib/defs.go
sqlite.go
+166
-30
166 additions, 30 deletions
sqlite.go
with
468 additions
and
30 deletions
functest/func_test.go
0 → 100644
+
292
−
0
View file @
19ea8e6c
// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package
functest
// modernc.org/sqlite/functest
import
(
"bytes"
"crypto/md5"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"strings"
"testing"
"time"
sqlite3
"modernc.org/sqlite"
)
func
init
()
{
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_int64"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
int64
(
42
),
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_float64"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
float64
(
1e-2
),
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_null"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
nil
,
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_error"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
nil
,
errors
.
New
(
"boom"
)
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_empty_byte_slice"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
[]
byte
{},
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_nonempty_byte_slice"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
[]
byte
(
"abcdefg"
),
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_empty_string"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
""
,
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"test_nonempty_string"
,
0
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
return
"abcdefg"
,
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"yesterday"
,
1
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
var
arg
time
.
Time
switch
argTyped
:=
args
[
0
]
.
(
type
)
{
case
int64
:
arg
=
time
.
Unix
(
argTyped
,
0
)
default
:
fmt
.
Println
(
argTyped
)
return
nil
,
fmt
.
Errorf
(
"expected argument to be int64, got: %T"
,
argTyped
)
}
return
arg
.
Add
(
-
24
*
time
.
Hour
),
nil
},
)
sqlite3
.
MustRegisterDeterministicScalarFunction
(
"md5"
,
1
,
func
(
ctx
*
sqlite3
.
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
)
{
var
arg
*
bytes
.
Buffer
switch
argTyped
:=
args
[
0
]
.
(
type
)
{
case
string
:
arg
=
bytes
.
NewBuffer
([]
byte
(
argTyped
))
case
[]
byte
:
arg
=
bytes
.
NewBuffer
(
argTyped
)
default
:
return
nil
,
fmt
.
Errorf
(
"expected argument to be a string, got: %T"
,
argTyped
)
}
w
:=
md5
.
New
()
if
_
,
err
:=
arg
.
WriteTo
(
w
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unable to compute md5 checksum: %s"
,
err
)
}
return
hex
.
EncodeToString
(
w
.
Sum
(
nil
)),
nil
},
)
}
func
TestRegisteredFunctions
(
t
*
testing
.
T
)
{
withDB
:=
func
(
test
func
(
db
*
sql
.
DB
))
{
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file::memory:"
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to open database: %v"
,
err
)
}
defer
db
.
Close
()
test
(
db
)
}
t
.
Run
(
"int64"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_int64()"
)
var
a
int
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
42
;
g
!=
e
{
tt
.
Fatal
(
g
,
e
)
}
})
})
t
.
Run
(
"float64"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_float64()"
)
var
a
float64
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
1e-2
;
g
!=
e
{
tt
.
Fatal
(
g
,
e
)
}
})
})
t
.
Run
(
"error"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
_
,
err
:=
db
.
Query
(
"select test_error()"
)
if
err
==
nil
{
tt
.
Fatal
(
"expected error, got none"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"boom"
)
{
tt
.
Fatal
(
err
)
}
})
})
t
.
Run
(
"empty_byte_slice"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_empty_byte_slice()"
)
var
a
[]
byte
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
len
(
a
)
>
0
{
tt
.
Fatal
(
"expected empty byte slice"
)
}
})
})
t
.
Run
(
"nonempty_byte_slice"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_nonempty_byte_slice()"
)
var
a
[]
byte
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
[]
byte
(
"abcdefg"
);
!
bytes
.
Equal
(
g
,
e
)
{
tt
.
Fatal
(
string
(
g
),
string
(
e
))
}
})
})
t
.
Run
(
"empty_string"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_empty_string()"
)
var
a
string
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
len
(
a
)
>
0
{
tt
.
Fatal
(
"expected empty string"
)
}
})
})
t
.
Run
(
"nonempty_string"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_nonempty_string()"
)
var
a
string
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
"abcdefg"
;
g
!=
e
{
tt
.
Fatal
(
g
,
e
)
}
})
})
t
.
Run
(
"null"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select test_null()"
)
var
a
interface
{}
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
a
!=
nil
{
tt
.
Fatal
(
"expected nil"
)
}
})
})
t
.
Run
(
"dates"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select yesterday(unixepoch('2018-11-01'))"
)
var
a
int64
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
time
.
Unix
(
a
,
0
),
time
.
Date
(
2018
,
time
.
October
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
);
!
g
.
Equal
(
e
)
{
tt
.
Fatal
(
g
,
e
)
}
})
})
t
.
Run
(
"md5"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
row
:=
db
.
QueryRow
(
"select md5('abcdefg')"
)
var
a
string
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
"7ac66c0f148de9519b8bd264312c4d64"
;
g
!=
e
{
tt
.
Fatal
(
g
,
e
)
}
})
})
t
.
Run
(
"md5 with blob input"
,
func
(
tt
*
testing
.
T
)
{
withDB
(
func
(
db
*
sql
.
DB
)
{
if
_
,
err
:=
db
.
Exec
(
"create table t(b blob); insert into t values (?)"
,
[]
byte
(
"abcdefg"
));
err
!=
nil
{
tt
.
Fatal
(
err
)
}
row
:=
db
.
QueryRow
(
"select md5(b) from t"
)
var
a
[]
byte
if
err
:=
row
.
Scan
(
&
a
);
err
!=
nil
{
tt
.
Fatal
(
err
)
}
if
g
,
e
:=
a
,
[]
byte
(
"7ac66c0f148de9519b8bd264312c4d64"
);
!
bytes
.
Equal
(
g
,
e
)
{
tt
.
Fatal
(
string
(
g
),
string
(
e
))
}
})
})
}
This diff is collapsed.
Click to expand it.
lib/defs.go
0 → 100644
+
10
−
0
View file @
19ea8e6c
// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package
sqlite3
const
(
SQLITE_STATIC
=
uintptr
(
0
)
// ((sqlite3_destructor_type)0)
SQLITE_TRANSIENT
=
^
uintptr
(
0
)
// ((sqlite3_destructor_type)-1)
)
This diff is collapsed.
Click to expand it.
sqlite.go
+
166
−
30
View file @
19ea8e6c
...
...
@@ -749,7 +749,6 @@ type conn struct {
writeTimeFormat
string
beginMode
string
udfs
map
[
string
]
*
userDefinedFunction
}
func
newConn
(
dsn
string
)
(
*
conn
,
error
)
{
...
...
@@ -766,7 +765,7 @@ func newConn(dsn string) (*conn, error) {
}
}
c
:=
&
conn
{
tls
:
libc
.
NewTLS
()
,
udfs
:
make
(
map
[
string
]
*
userDefinedFunction
)
}
c
:=
&
conn
{
tls
:
libc
.
NewTLS
()}
db
,
err
:=
c
.
openV2
(
dsn
,
sqlite3
.
SQLITE_OPEN_READWRITE
|
sqlite3
.
SQLITE_OPEN_CREATE
|
...
...
@@ -1313,13 +1312,6 @@ func (c *conn) Close() error {
c
.
db
=
0
}
if
c
.
udfs
!=
nil
{
for
_
,
v
:=
range
c
.
udfs
{
v
.
close
(
c
.
tls
)
}
c
.
udfs
=
nil
}
if
c
.
tls
!=
nil
{
c
.
tls
.
Close
()
c
.
tls
=
nil
...
...
@@ -1345,25 +1337,7 @@ type userDefinedFunction struct {
freeOnce
sync
.
Once
}
func
(
udf
*
userDefinedFunction
)
close
(
tls
*
libc
.
TLS
)
{
if
udf
==
nil
{
return
}
udf
.
freeOnce
.
Do
(
func
()
{
libc
.
Xfree
(
tls
,
udf
.
zFuncName
)
})
}
func
(
c
*
conn
)
createFunctionInternal
(
fun
*
userDefinedFunction
)
error
{
c
.
Mutex
.
Lock
()
defer
c
.
Mutex
.
Unlock
()
goZFuncName
:=
libc
.
GoString
(
fun
.
zFuncName
)
if
prev
,
ok
:=
c
.
udfs
[
goZFuncName
];
ok
{
prev
.
close
(
c
.
tls
)
delete
(
c
.
udfs
,
goZFuncName
)
}
c
.
udfs
[
goZFuncName
]
=
fun
if
rc
:=
sqlite3
.
Xsqlite3_create_function
(
c
.
tls
,
c
.
db
,
...
...
@@ -1445,9 +1419,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue
}
// Driver implements database/sql/driver.Driver.
type
Driver
struct
{}
type
Driver
struct
{
// user defined functions that are added to every new connection on Open
udfs
map
[
string
]
*
userDefinedFunction
}
var
d
=
&
Driver
{
udfs
:
make
(
map
[
string
]
*
userDefinedFunction
)}
func
newDriver
()
*
Driver
{
return
&
Driver
{}
}
func
newDriver
()
*
Driver
{
return
d
}
// Open returns a new connection to the database. The name is a string in a
// driver-specific format.
...
...
@@ -1479,5 +1458,162 @@ func newDriver() *Driver { return &Driver{} }
// available at
// https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
func
(
d
*
Driver
)
Open
(
name
string
)
(
driver
.
Conn
,
error
)
{
return
newConn
(
name
)
c
,
err
:=
newConn
(
name
)
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
udf
:=
range
d
.
udfs
{
if
err
=
c
.
createFunctionInternal
(
udf
);
err
!=
nil
{
c
.
Close
()
return
nil
,
err
}
}
return
c
,
nil
}
type
FunctionContext
struct
{}
const
sqliteValPtrSize
=
unsafe
.
Sizeof
(
&
sqlite3
.
Sqlite3_value
{})
func
RegisterScalarFunction
(
zFuncName
string
,
nArg
int32
,
xFunc
func
(
ctx
*
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
),
)
error
{
return
registerScalarFunction
(
zFuncName
,
nArg
,
sqlite3
.
SQLITE_UTF8
,
xFunc
)
}
func
MustRegisterScalarFunction
(
zFuncName
string
,
nArg
int32
,
xFunc
func
(
ctx
*
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
),
)
{
if
err
:=
RegisterScalarFunction
(
zFuncName
,
nArg
,
xFunc
);
err
!=
nil
{
panic
(
err
)
}
}
func
MustRegisterDeterministicScalarFunction
(
zFuncName
string
,
nArg
int32
,
xFunc
func
(
ctx
*
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
),
)
{
if
err
:=
RegisterDeterministicScalarFunction
(
zFuncName
,
nArg
,
xFunc
);
err
!=
nil
{
panic
(
err
)
}
}
func
RegisterDeterministicScalarFunction
(
zFuncName
string
,
nArg
int32
,
xFunc
func
(
ctx
*
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
),
)
error
{
return
registerScalarFunction
(
zFuncName
,
nArg
,
sqlite3
.
SQLITE_UTF8
|
sqlite3
.
SQLITE_DETERMINISTIC
,
xFunc
)
}
func
registerScalarFunction
(
zFuncName
string
,
nArg
int32
,
eTextRep
int32
,
xFunc
func
(
ctx
*
FunctionContext
,
args
[]
driver
.
Value
)
(
driver
.
Value
,
error
),
)
error
{
if
_
,
ok
:=
d
.
udfs
[
zFuncName
];
ok
{
return
fmt
.
Errorf
(
"a function named %q is already registered"
,
zFuncName
)
}
// dont free, functions registered on the driver live as long as the program
name
,
err
:=
libc
.
CString
(
zFuncName
)
if
err
!=
nil
{
return
err
}
udf
:=
&
userDefinedFunction
{
zFuncName
:
name
,
nArg
:
nArg
,
eTextRep
:
eTextRep
,
xFunc
:
func
(
tls
*
libc
.
TLS
,
ctx
uintptr
,
argc
int32
,
argv
uintptr
)
{
setErrorResult
:=
func
(
res
error
)
{
errmsg
,
cerr
:=
libc
.
CString
(
res
.
Error
())
if
cerr
!=
nil
{
panic
(
cerr
)
}
defer
libc
.
Xfree
(
tls
,
errmsg
)
sqlite3
.
Xsqlite3_result_error
(
tls
,
ctx
,
errmsg
,
-
1
)
sqlite3
.
Xsqlite3_result_error_code
(
tls
,
ctx
,
sqlite3
.
SQLITE_ERROR
)
}
args
:=
make
([]
driver
.
Value
,
argc
)
for
i
:=
int32
(
0
);
i
<
argc
;
i
++
{
valPtr
:=
*
(
*
uintptr
)(
unsafe
.
Pointer
(
argv
+
uintptr
(
i
)
*
sqliteValPtrSize
))
switch
valType
:=
sqlite3
.
Xsqlite3_value_type
(
tls
,
valPtr
);
valType
{
case
sqlite3
.
SQLITE_TEXT
:
args
[
i
]
=
libc
.
GoString
(
sqlite3
.
Xsqlite3_value_text
(
tls
,
valPtr
))
case
sqlite3
.
SQLITE_INTEGER
:
args
[
i
]
=
sqlite3
.
Xsqlite3_value_int64
(
tls
,
valPtr
)
case
sqlite3
.
SQLITE_FLOAT
:
args
[
i
]
=
sqlite3
.
Xsqlite3_value_double
(
tls
,
valPtr
)
case
sqlite3
.
SQLITE_NULL
:
args
[
i
]
=
nil
case
sqlite3
.
SQLITE_BLOB
:
size
:=
sqlite3
.
Xsqlite3_value_bytes
(
tls
,
valPtr
)
blobPtr
:=
sqlite3
.
Xsqlite3_value_blob
(
tls
,
valPtr
)
v
:=
make
([]
byte
,
size
)
copy
(
v
,
(
*
libc
.
RawMem
)(
unsafe
.
Pointer
(
blobPtr
))[
:
size
:
size
])
args
[
i
]
=
v
default
:
panic
(
fmt
.
Sprintf
(
"unexpected argument type %q passed by sqlite"
,
valType
))
}
}
res
,
err
:=
xFunc
(
&
FunctionContext
{},
args
)
if
err
!=
nil
{
setErrorResult
(
err
)
return
}
switch
resTyped
:=
res
.
(
type
)
{
case
nil
:
sqlite3
.
Xsqlite3_result_null
(
tls
,
ctx
)
case
int64
:
sqlite3
.
Xsqlite3_result_int64
(
tls
,
ctx
,
resTyped
)
case
float64
:
sqlite3
.
Xsqlite3_result_double
(
tls
,
ctx
,
resTyped
)
case
bool
:
sqlite3
.
Xsqlite3_result_int
(
tls
,
ctx
,
libc
.
Bool32
(
resTyped
))
case
time
.
Time
:
sqlite3
.
Xsqlite3_result_int64
(
tls
,
ctx
,
resTyped
.
Unix
())
case
string
:
size
:=
int32
(
len
(
resTyped
))
cstr
,
err
:=
libc
.
CString
(
resTyped
)
if
err
!=
nil
{
panic
(
err
)
}
defer
libc
.
Xfree
(
tls
,
cstr
)
sqlite3
.
Xsqlite3_result_text
(
tls
,
ctx
,
cstr
,
size
,
sqlite3
.
SQLITE_TRANSIENT
)
case
[]
byte
:
size
:=
int32
(
len
(
resTyped
))
if
size
==
0
{
sqlite3
.
Xsqlite3_result_zeroblob
(
tls
,
ctx
,
0
)
return
}
p
:=
libc
.
Xmalloc
(
tls
,
types
.
Size_t
(
size
))
if
p
==
0
{
panic
(
fmt
.
Sprintf
(
"unable to allocate space for blob: %d"
,
size
))
}
defer
libc
.
Xfree
(
tls
,
p
)
copy
((
*
libc
.
RawMem
)(
unsafe
.
Pointer
(
p
))[
:
size
:
size
],
resTyped
)
sqlite3
.
Xsqlite3_result_blob
(
tls
,
ctx
,
p
,
size
,
sqlite3
.
SQLITE_TRANSIENT
)
default
:
setErrorResult
(
fmt
.
Errorf
(
"function did not return a valid driver.Value: %T"
,
resTyped
))
return
}
},
}
d
.
udfs
[
zFuncName
]
=
udf
return
nil
}
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment