Commit a41c7060 authored by Geofrey Ernest's avatar Geofrey Ernest Committed by GitHub

[WIP] Support for go1.8 (#172)

Add NamedArg support
parent fa0b3688
// +build go1.8
package ql
import (
"context"
"database/sql/driver"
"fmt"
"strings"
"regexp"
)
const prefix = "$"
func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return c.Exec(replaceNamed(query, args))
}
func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value) {
a := make([]driver.Value, len(args))
for k, v := range args {
if v.Name != "" {
query = strings.Replace(query, prefix+v.Name, fmt.Sprintf("%s%d", prefix, v.Ordinal), -1)
}
a[k] = v.Value
}
return query, a
}
func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return c.Query(replaceNamed(query, args))
}
func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.Prepare(filterNamedArgs(query))
}
var re = regexp.MustCompile(`^\w+`)
func filterNamedArgs(q string) string {
c := strings.Count(q, prefix)
if c == 0 || c == len(q) {
return q
}
pc := strings.Split(q, prefix)
for k, v := range pc {
if k == 0 {
continue
}
if v != "" {
pc[k] = re.ReplaceAllString(v, fmt.Sprint(k))
}
}
return strings.Join(pc, prefix)
}
func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
a := make([]driver.Value, len(args))
for k, v := range args {
a[k] = v.Value
}
return s.Exec(a)
}
func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
a := make([]driver.Value, len(args))
for k, v := range args {
a[k] = v.Value
}
return s.Query(a)
}
//+build go1.8
// +build go1.8
package ql
import (
"context"
"database/sql"
"testing"
)
......@@ -72,3 +73,80 @@ func TestMultiResultSet(t *testing.T) {
t.Fatal("unexpected result set")
}
}
func TestNamedArgs(t *testing.T) {
RegisterMemDriver()
db, err := sql.Open("ql-mem", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.QueryContext(
context.Background(),
`select $one;select $two;select $three;`,
sql.Named("one", 1),
sql.Named("two", 2),
sql.Named("three", 3),
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
var i int
for rows.Next() {
if err := rows.Scan(&i); err != nil {
t.Fatal(err)
}
if i != 1 {
t.Fatalf("expected 1, got %d", i)
}
}
if !rows.NextResultSet() {
t.Fatal("expected more result sets", rows.Err())
}
for rows.Next() {
if err := rows.Scan(&i); err != nil {
t.Fatal(err)
}
if i != 2 {
t.Fatalf("expected 2, got %d", i)
}
}
samples := []struct {
src, exp string
}{
{
`select $one;select $two;select $three;`,
`select $1;select $2;select $3;`,
},
{
`select * from foo where t=$1`,
`select * from foo where t=$1`,
},
{
`select * from foo where t=$1&&name=$name`,
`select * from foo where t=$1&&name=$2`,
},
}
for _, s := range samples {
e := filterNamedArgs(s.src)
if e != s.exp {
t.Errorf("expected %s got %s", s.exp, e)
}
}
stmt, err := db.PrepareContext(context.Background(), `select $number`)
if err != nil {
t.Fatal(err)
}
var n int
err = stmt.QueryRowContext(context.Background(), sql.Named("number", 1)).Scan(&n)
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("expected 1 got %d", n)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment