// +build go1.8

package ql

import (
	"context"
	"database/sql/driver"
	"fmt"
	"strconv"
	"strings"
)

// prefix is the sigil that introduces a query parameter placeholder, either
// positional ($1, $2, ...) or named ($name).
const (
	prefix = "$"
)

func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
16 17 18 19 20 21
	query, vals, err := replaceNamed(query, args)
	if err != nil {
		return nil, err
	}

	return c.Exec(query, vals)
22 23
}

func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value, error) {
	toks, err := tokenize(query)
	if err != nil {
		return "", nil, err
	}

30
	a := make([]driver.Value, len(args))
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
	m := map[string]int{}
	for _, v := range args {
		m[v.Name] = v.Ordinal
		a[v.Ordinal-1] = v.Value
	}
	for i, v := range toks {
		if len(v) > 1 && strings.HasPrefix(v, prefix) {
			if v[1] >= '1' && v[1] <= '9' {
				continue
			}

			nm := v[1:]
			k, ok := m[nm]
			if !ok {
				return query, nil, fmt.Errorf("unknown named parameter %s", nm)
			}

			toks[i] = fmt.Sprintf("$%d", k)
49 50
		}
	}
51
	return strings.Join(toks, " "), a, nil
52 53 54
}

func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
55 56 57 58 59 60
	query, vals, err := replaceNamed(query, args)
	if err != nil {
		return nil, err
	}

	return c.Query(query, vals)
61 62 63
}

func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
64 65 66 67
	query, err := filterNamedArgs(query)
	if err != nil {
		return nil, err
	}
68

69 70
	return c.Prepare(query)
}

func filterNamedArgs(query string) (string, error) {
	toks, err := tokenize(query)
	if err != nil {
		return "", err
76
	}
77 78 79 80 81 82 83 84 85 86 87 88

	n := 0
	for _, v := range toks {
		if len(v) > 1 && strings.HasPrefix(v, prefix) && v[1] >= '1' && v[1] <= '9' {
			m, err := strconv.ParseUint(v[1:], 10, 31)
			if err != nil {
				return "", err
			}

			if int(m) > n {
				n = int(m)
			}
89
		}
90 91 92 93 94 95 96 97 98
	}
	for i, v := range toks {
		if len(v) > 1 && strings.HasPrefix(v, prefix) {
			if v[1] >= '1' && v[1] <= '9' {
				continue
			}

			n++
			toks[i] = fmt.Sprintf("$%d", n)
99 100
		}
	}
101
	return strings.Join(toks, " "), nil
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
}

// ExecContext implements driver.StmtExecContext. The argument values are
// passed to s.Exec in slice order; parameter names are ignored, presumably
// because PrepareContext already rewrote named placeholders into positional
// ones (TODO confirm for statements prepared via the non-context Prepare).
//
// A context that is already done aborts the call before any work starts;
// ctx is not passed on to s.Exec.
func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	a := make([]driver.Value, len(args))
	for k, v := range args {
		a[k] = v.Value
	}
	return s.Exec(a)
}

// QueryContext implements driver.StmtQueryContext. The argument values are
// passed to s.Query in slice order; parameter names are ignored, presumably
// because PrepareContext already rewrote named placeholders into positional
// ones (TODO confirm for statements prepared via the non-context Prepare).
//
// A context that is already done aborts the call before any work starts;
// ctx is not passed on to s.Query.
func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	a := make([]driver.Value, len(args))
	for k, v := range args {
		a[k] = v.Value
	}
	return s.Query(a)
}

// tokenize runs the package lexer over s and returns the text of each token it
// yields, in order. It stops when Lex returns 0. If the text of a token ends in
// a double quote or a backquote, that quote character is prepended to it.
// Presumably TokenBytes omits the opening quote of string literals (TODO
// confirm against the lexer).
//
// Only the error from newLexer is reported; nothing else about the lexing run
// is inspected here.
func tokenize(s string) ([]string, error) {
	lx, err := newLexer(s)
	if err != nil {
		return nil, err
	}

	var (
		toks []string
		sym  yySymType
	)
	for {
		if lx.Lex(&sym) == 0 {
			break
		}

		tok := string(lx.TokenBytes(nil))
		if n := len(tok); n > 0 {
			if last := tok[n-1]; last == '"' || last == '`' {
				tok = tok[n-1:] + tok
			}
		}
		toks = append(toks, tok)
	}
	return toks, nil
}