Commit 3f53e147 authored by cznic's avatar cznic

Fix handling of named parameters. Closes #190.

	modified:   Makefile
	modified:   all_test.go
	modified:   driver1.8.go
	modified:   go1.8_test.go
	modified:   parser.go
	modified:   ql.y
	modified:   scanner.go
	modified:   scanner.l
parent 11f89f7e
...@@ -29,7 +29,7 @@ coerce.go: helper/helper.go ...@@ -29,7 +29,7 @@ coerce.go: helper/helper.go
go run helper/helper.go | gofmt > [email protected] go run helper/helper.go | gofmt > [email protected]
cover: cover:
t=$(shell tempfile) ; go test -coverprofile $$t && go tool cover -html $$t && unlink $$t t=$(shell mktemp) ; go test -coverprofile $$t && go tool cover -html $$t && unlink $$t
cpu: clean cpu: clean
go test -run @ -bench . -cpuprofile cpu.out go test -run @ -bench . -cpuprofile cpu.out
...@@ -58,7 +58,7 @@ nuke: clean ...@@ -58,7 +58,7 @@ nuke: clean
go clean -i go clean -i
parser.go: parser.y parser.go: parser.y
a=$(shell tempfile) ; \ a=$(shell mktemp) ; \
goyacc -o /dev/null -xegen $$a $< ; \ goyacc -o /dev/null -xegen $$a $< ; \
goyacc -cr -o [email protected] -xe $$a $< ; \ goyacc -cr -o [email protected] -xe $$a $< ; \
rm -f $$a rm -f $$a
......
...@@ -3444,6 +3444,23 @@ func TestIssue142(t *testing.T) { ...@@ -3444,6 +3444,23 @@ func TestIssue142(t *testing.T) {
} }
} }
func TestTokenize(t *testing.T) {
toks, err := tokenize("\"a$1\" `a$2` $3 $x $x_Yřa 'z' 3+6 -- foo\nbar")
if err != nil {
t.Fatal(err)
}
exp := []string{"\"a$1\"", "`a$2`", "$3", "$x", "$x_Yřa", "'z'", "3", "+", "6", "bar"}
if g, e := len(toks), len(exp); g != e {
t.Fatalf("\ngot %q\nexp %q", toks, exp)
}
for i, g := range toks {
if e := exp[i]; g != e {
t.Fatalf("\not %q\nexp %q", toks, exp)
}
}
}
// Both of the UPDATEs _should_ work but the 2nd one results in a _type missmatch_ error at the time of writing. // Both of the UPDATEs _should_ work but the 2nd one results in a _type missmatch_ error at the time of writing.
// see https://github.com/cznic/ql/issues/190 // see https://github.com/cznic/ql/issues/190
func TestIssue190(t *testing.T) { func TestIssue190(t *testing.T) {
......
...@@ -6,53 +6,99 @@ import ( ...@@ -6,53 +6,99 @@ import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"strconv"
"strings" "strings"
"regexp"
) )
const prefix = "$" const prefix = "$"
func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return c.Exec(replaceNamed(query, args)) query, vals, err := replaceNamed(query, args)
if err != nil {
return nil, err
}
return c.Exec(query, vals)
} }
func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value) { func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value, error) {
toks, err := tokenize(query)
if err != nil {
return "", nil, err
}
a := make([]driver.Value, len(args)) a := make([]driver.Value, len(args))
for k, v := range args { m := map[string]int{}
if v.Name != "" { for _, v := range args {
query = strings.Replace(query, prefix+v.Name, fmt.Sprintf("%s%d", prefix, v.Ordinal), -1) m[v.Name] = v.Ordinal
a[v.Ordinal-1] = v.Value
}
for i, v := range toks {
if len(v) > 1 && strings.HasPrefix(v, prefix) {
if v[1] >= '1' && v[1] <= '9' {
continue
}
nm := v[1:]
k, ok := m[nm]
if !ok {
return query, nil, fmt.Errorf("unknown named parameter %s", nm)
}
toks[i] = fmt.Sprintf("$%d", k)
} }
a[k] = v.Value
} }
return query, a return strings.Join(toks, " "), a, nil
} }
func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return c.Query(replaceNamed(query, args)) query, vals, err := replaceNamed(query, args)
if err != nil {
return nil, err
}
return c.Query(query, vals)
} }
func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.Prepare(filterNamedArgs(query)) query, err := filterNamedArgs(query)
} if err != nil {
return nil, err
}
var re = regexp.MustCompile(`^\w+`) return c.Prepare(query)
}
func filterNamedArgs(q string) string { func filterNamedArgs(query string) (string, error) {
c := strings.Count(q, prefix) toks, err := tokenize(query)
if c == 0 || c == len(q) { if err != nil {
return q return "", err
} }
pc := strings.Split(q, prefix)
for k, v := range pc { n := 0
if k == 0 { for _, v := range toks {
continue if len(v) > 1 && strings.HasPrefix(v, prefix) && v[1] >= '1' && v[1] <= '9' {
m, err := strconv.ParseUint(v[1:], 10, 31)
if err != nil {
return "", err
}
if int(m) > n {
n = int(m)
}
} }
if v != "" { }
pc[k] = re.ReplaceAllString(v, fmt.Sprint(k)) for i, v := range toks {
if len(v) > 1 && strings.HasPrefix(v, prefix) {
if v[1] >= '1' && v[1] <= '9' {
continue
}
n++
toks[i] = fmt.Sprintf("$%d", n)
} }
} }
return strings.Join(pc, prefix) return strings.Join(toks, " "), nil
} }
func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
...@@ -70,3 +116,25 @@ func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ...@@ -70,3 +116,25 @@ func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
} }
return s.Query(a) return s.Query(a)
} }
func tokenize(s string) (r []string, _ error) {
lx, err := newLexer(s)
if err != nil {
return nil, err
}
var lval yySymType
for lx.Lex(&lval) != 0 {
s := string(lx.TokenBytes(nil))
if s != "" {
switch s[len(s)-1] {
case '"':
s = "\"" + s
case '`':
s = "`" + s
}
}
r = append(r, s)
}
return r, nil
}
...@@ -84,9 +84,9 @@ func TestNamedArgs(t *testing.T) { ...@@ -84,9 +84,9 @@ func TestNamedArgs(t *testing.T) {
rows, err := db.QueryContext( rows, err := db.QueryContext(
context.Background(), context.Background(),
`select $one;select $two;select $three;`, `select $two;select $one;select $three;`,
sql.Named("one", 1), sql.Named("one", 2),
sql.Named("two", 2), sql.Named("two", 1),
sql.Named("three", 3), sql.Named("three", 3),
) )
if err != nil { if err != nil {
...@@ -119,21 +119,25 @@ func TestNamedArgs(t *testing.T) { ...@@ -119,21 +119,25 @@ func TestNamedArgs(t *testing.T) {
}{ }{
{ {
`select $one;select $two;select $three;`, `select $one;select $two;select $three;`,
`select $1;select $2;select $3;`, `select $1 ; select $2 ; select $3 ;`,
}, },
{ {
`select * from foo where t=$1`, `select * from foo where t=$1`,
`select * from foo where t=$1`, `select * from foo where t = $1`,
}, },
{ {
`select * from foo where t=$1&&name=$name`, `select * from foo where t=$1&&name=$name`,
`select * from foo where t=$1&&name=$2`, `select * from foo where t = $1 && name = $2`,
}, },
} }
for _, s := range samples { for _, s := range samples {
e := filterNamedArgs(s.src) e, err := filterNamedArgs(s.src)
if err != nil {
t.Fatal(err)
}
if e != s.exp { if e != s.exp {
t.Errorf("expected %s got %s", s.exp, e) t.Errorf("\nexpected %q\n got %q", s.exp, e)
} }
} }
......
...@@ -134,6 +134,8 @@ const ( ...@@ -134,6 +134,8 @@ const (
) )
var ( var (
yyPrec = map[int]int{}
yyXLAT = map[int]int{ yyXLAT = map[int]int{
59: 0, // ';' (204x) 59: 0, // ';' (204x)
57344: 1, // $end (203x) 57344: 1, // $end (203x)
...@@ -1700,7 +1702,7 @@ func yySymName(c int) (s string) { ...@@ -1700,7 +1702,7 @@ func yySymName(c int) (s string) {
} }
if c < 0x7f { if c < 0x7f {
return __yyfmt__.Sprintf("'%c'", c) return __yyfmt__.Sprintf("%q", c)
} }
return __yyfmt__.Sprintf("%d", c) return __yyfmt__.Sprintf("%d", c)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//TODO Put your favorite license here //TODO Put your favorite license here
// yacc source generated by ebnf2y[1] // yacc source generated by ebnf2y[1]
// at 2017-08-31 17:43:07.227157474 +0200 CEST m=+0.001846399 // at 2017-11-22 13:44:30.7008477 +0100 CET m=+0.004756809
// //
// $ ebnf2y -o ql.y -oe ql.ebnf -start StatementList -pkg ql -p _ // $ ebnf2y -o ql.y -oe ql.ebnf -start StatementList -pkg ql -p _
// //
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -324,7 +324,11 @@ ident {idchar0}{idchars}* ...@@ -324,7 +324,11 @@ ident {idchar0}{idchars}*
{ident} lval.item = string(l.TokenBytes(nil)) {ident} lval.item = string(l.TokenBytes(nil))
return identifier return identifier
($|\?){D} lval.item, _ = strconv.Atoi(string(l.TokenBytes(nil)[1:])) ($|\?)({D}|{ident}) s := string(l.TokenBytes(nil)[1:])
lval.item, _ = strconv.Atoi(s)
if s != "" && s[0] < '1' || s[0] > '9' {
l.err("parameter number must be non zero")
}
return qlParam return qlParam
%% %%
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment