Commit 3497a9e4 authored by Geofrey Ernest's avatar Geofrey Ernest Committed by GitHub

cache selectStmt inside where exists (select...) (#165)

Closes #159
parent 7fbcf48b
......@@ -3441,3 +3441,63 @@ func TestIssue142(t *testing.T) {
}
}
}
func TestWhereExists(t *testing.T) {
RegisterMemDriver()
db, err := sql.Open("ql-mem", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
tx.Exec(`
BEGIN TRANSACTION;
CREATE TABLE t (i int);
CREATE TABLE s (i int);
INSERT INTO t VALUES (0);
INSERT INTO t VALUES (1);
INSERT INTO t VALUES (2);
INSERT INTO t VALUES (3);
INSERT INTO t VALUES (4);
INSERT INTO t VALUES (5);
INSERT INTO s VALUES (2);
COMMIT;
`)
err = tx.Commit()
if err != nil {
t.Fatal(err)
}
s, err := db.Prepare(`
select * from t where exists (select * from s where i==$1);
`)
if err != nil {
t.Fatal(err)
}
defer s.Close()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
go func(id int, wait *sync.WaitGroup) {
var c int
err := s.QueryRow(id).Scan(&c)
if id == 2 {
if err != nil {
t.Error(err)
}
if id == 2 && c != 5 {
t.Errorf("expected %d got %d", id, c)
}
} else {
if err != sql.ErrNoRows {
t.Errorf("expected %v got %v", sql.ErrNoRows, err)
}
}
wait.Done()
}(i, &wg)
wg.Add(1)
}
wg.Wait()
}
......@@ -530,18 +530,28 @@ func (r *whereRset) plan(ctx *execCtx) (plan, error) {
o := r.src
if r.sel != nil {
var exists bool
p, err := r.sel.plan(ctx)
if err != nil {
return nil, err
}
err = p.do(ctx, func(i interface{}, data []interface{}) (bool, error) {
if len(data) > 0 {
exists = true
ctx.mu.RLock()
m, ok := ctx.cache[r.sel]
ctx.mu.RUnlock()
if ok {
exists = m.(bool)
} else {
p, err := r.sel.plan(ctx)
if err != nil {
return nil, err
}
return false, nil
})
if err != nil {
return nil, err
err = p.do(ctx, func(i interface{}, data []interface{}) (bool, error) {
if len(data) > 0 {
exists = true
}
return false, nil
})
if err != nil {
return nil, err
}
ctx.mu.Lock()
ctx.cache[r.sel] = true
ctx.mu.Unlock()
}
if r.exists == exists {
return o, nil
......@@ -840,7 +850,7 @@ func newDB(store storage) (db *DB, err error) {
return
}
ctx := &execCtx{db: db0}
ctx := newExecCtx(db0, nil)
for _, t := range db0.root.tables {
if err := t.constraintsAndDefaults(ctx); err != nil {
return nil, err
......@@ -938,7 +948,7 @@ func newDB(store storage) (db *DB, err error) {
func (db *DB) deleteIndex2ByIndexName(nm string) error {
for _, s := range deleteIndex2ByIndexName.l {
if _, err := s.exec(&execCtx{db: db, arg: []interface{}{nm}}); err != nil {
if _, err := s.exec(newExecCtx(db, []interface{}{nm})); err != nil {
return err
}
}
......@@ -947,7 +957,7 @@ func (db *DB) deleteIndex2ByIndexName(nm string) error {
func (db *DB) deleteIndex2ByTableName(nm string) error {
for _, s := range deleteIndex2ByTableName.l {
if _, err := s.exec(&execCtx{db: db, arg: []interface{}{nm}}); err != nil {
if _, err := s.exec(newExecCtx(db, []interface{}{nm})); err != nil {
return err
}
}
......@@ -1308,7 +1318,7 @@ func (db *DB) run1(pc *TCtx, s stmt, arg ...interface{}) (rs Recordset, tnla, tn
db.rwmu.RLock() // can safely grab before Unlock
db.muUnlock()
defer db.rwmu.RUnlock()
rs, err = s.exec(&execCtx{db, arg}) // R/O tctx
rs, err = s.exec(newExecCtx(db, arg)) // R/O tctx
return rs, tnla, tnlb, err
}
default: // case true:
......@@ -1390,7 +1400,7 @@ func (db *DB) run1(pc *TCtx, s stmt, arg ...interface{}) (rs Recordset, tnla, tn
db.muUnlock() // must Unlock before RLock
db.rwmu.RLock()
defer db.rwmu.RUnlock()
rs, err = s.exec(&execCtx{db, arg})
rs, err = s.exec(newExecCtx(db, arg))
return rs, tnla, tnlb, err
}
......@@ -1400,7 +1410,7 @@ func (db *DB) run1(pc *TCtx, s stmt, arg ...interface{}) (rs Recordset, tnla, tn
return nil, tnla, tnlb, fmt.Errorf("invalid passed transaction context")
}
rs, err = s.exec(&execCtx{db, arg})
rs, err = s.exec(newExecCtx(db, arg))
return rs, tnla, tnlb, err
}
}
......@@ -1568,13 +1578,13 @@ func (db *DB) info() (r *DbInfo, err error) {
ti := TableInfo{Name: nm}
m := map[string]*ColumnInfo{}
if hasColumn2 {
rs, err := selectColumn2.l[0].exec(&execCtx{db: db, arg: []interface{}{nm}})
rs, err := selectColumn2.l[0].exec(newExecCtx(db, []interface{}{nm}))
if err != nil {
return nil, err
}
if err := rs.(recordset).do(
&execCtx{db: db, arg: []interface{}{nm}},
newExecCtx(db, []interface{}{nm}),
func(id interface{}, data []interface{}) (bool, error) {
ci := &ColumnInfo{NotNull: data[1].(bool), Constraint: data[2].(string), Default: data[3].(string)}
m[data[0].(string)] = ci
......
......@@ -9,6 +9,8 @@ import (
"fmt"
"strings"
"sync"
"github.com/cznic/strutil"
)
......@@ -121,8 +123,18 @@ type stmt interface {
}
type execCtx struct { //LATER +shared temp
db *DB
arg []interface{}
db *DB
arg []interface{}
cache map[interface{}]interface{}
mu sync.RWMutex
}
func newExecCtx(db *DB, arg []interface{}) *execCtx {
return &execCtx{
db: db,
arg: arg,
cache: make(map[interface{}]interface{}),
}
}
type explainStmt struct {
......@@ -571,7 +583,7 @@ func (s *alterTableDropColumnStmt) exec(ctx *execCtx) (Recordset, error) {
}
if _, ok := ctx.db.root.tables["__Column2"]; ok {
if _, err := deleteColumn2.l[0].exec(&execCtx{db: ctx.db, arg: []interface{}{s.tableName, c.name}}); err != nil {
if _, err := deleteColumn2.l[0].exec(newExecCtx(ctx.db, []interface{}{s.tableName, c.name})); err != nil {
return nil, err
}
}
......@@ -680,7 +692,7 @@ func (s *alterTableAddStmt) exec(ctx *execCtx) (Recordset, error) {
if c.constraint != nil || c.dflt != nil {
for _, s := range createColumn2.l {
_, err := s.exec(&execCtx{db: ctx.db})
_, err := s.exec(newExecCtx(ctx.db, nil))
if err != nil {
return nil, err
}
......@@ -693,7 +705,7 @@ func (s *alterTableAddStmt) exec(ctx *execCtx) (Recordset, error) {
if e := c.dflt; e != nil {
d = e.String()
}
if _, err := insertColumn2.l[0].exec(&execCtx{db: ctx.db, arg: []interface{}{s.tableName, c.name, notNull, co, d}}); err != nil {
if _, err := insertColumn2.l[0].exec(newExecCtx(ctx.db, []interface{}{s.tableName, c.name, notNull, co, d})); err != nil {
return nil, err
}
}
......@@ -1251,7 +1263,7 @@ func (s *createTableStmt) exec(ctx *execCtx) (_ Recordset, err error) {
if c.constraint != nil || c.dflt != nil {
if mustCreateColumn2 {
for _, stmt := range createColumn2.l {
_, err := stmt.exec(&execCtx{db: ctx.db})
_, err := stmt.exec(newExecCtx(ctx.db, nil))
if err != nil {
return nil, err
}
......@@ -1267,7 +1279,7 @@ func (s *createTableStmt) exec(ctx *execCtx) (_ Recordset, err error) {
if e := c.dflt; e != nil {
d = e.String()
}
if _, err := insertColumn2.l[0].exec(&execCtx{db: ctx.db, arg: []interface{}{s.tableName, c.name, notNull, co, d}}); err != nil {
if _, err := insertColumn2.l[0].exec(newExecCtx(ctx.db, []interface{}{s.tableName, c.name, notNull, co, d})); err != nil {
return nil, err
}
}
......
......@@ -153,7 +153,7 @@ func (t *table) constraintsAndDefaults(ctx *execCtx) error {
constraints := make([]*constraint, len(cols))
defaults := make([]expression, len(cols))
arg := []interface{}{t.name}
rs, err := selectColumn2.l[0].exec(&execCtx{db: ctx.db, arg: arg})
rs, err := selectColumn2.l[0].exec(newExecCtx(ctx.db, arg))
if err != nil {
return err
}
......@@ -161,7 +161,7 @@ func (t *table) constraintsAndDefaults(ctx *execCtx) error {
var rows [][]interface{}
ok = false
if err := rs.(recordset).do(
&execCtx{db: ctx.db, arg: arg},
newExecCtx(ctx.db, arg),
func(id interface{}, data []interface{}) (more bool, err error) {
rows = append(rows, data)
return true, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment