From: Brad Fitzpatrick Date: Thu, 29 Sep 2011 23:12:21 +0000 (-0700) Subject: exp/sql{,/driver}: new database packages X-Git-Tag: weekly.2011-10-06~59 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=357f2cb1a385f4d1418e48856f9abe0cce;p=gostls13.git exp/sql{,/driver}: new database packages R=gustavo, rsc, borman, dave, kevlar, nigeltao, dvyukov, kardianos, fw, r, r, david.crawshaw CC=golang-dev https://golang.org/cl/4973055 --- diff --git a/src/pkg/Makefile b/src/pkg/Makefile index 063148b5e5..6b9a2d8784 100644 --- a/src/pkg/Makefile +++ b/src/pkg/Makefile @@ -82,6 +82,8 @@ DIRS=\ exp/gui\ exp/gui/x11\ exp/norm\ + exp/sql\ + exp/sql/driver\ exp/template/html\ expvar\ flag\ @@ -202,6 +204,7 @@ NOTEST+=\ crypto/x509/pkix\ exp/gui\ exp/gui/x11\ + exp/sql/driver\ go/ast\ go/doc\ go/token\ diff --git a/src/pkg/exp/sql/Makefile b/src/pkg/exp/sql/Makefile new file mode 100644 index 0000000000..1e4f74c821 --- /dev/null +++ b/src/pkg/exp/sql/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=exp/sql +GOFILES=\ + convert.go\ + sql.go\ + +include ../../../Make.pkg diff --git a/src/pkg/exp/sql/convert.go b/src/pkg/exp/sql/convert.go new file mode 100644 index 0000000000..a35e0be9cb --- /dev/null +++ b/src/pkg/exp/sql/convert.go @@ -0,0 +1,106 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package sql + +import ( + "fmt" + "os" + "reflect" + "strconv" +) + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) os.Error { + // Common cases, without reflect. Fall through. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + *d = s + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + *d = string(s) + return nil + case *[]byte: + *d = s + return nil + } + } + + sv := reflect.ValueOf(src) + + switch d := dest.(type) { + case *string: + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = fmt.Sprintf("%v", src) + return nil + } + } + + if scanner, ok := dest.(ScannerInto); ok { + return scanner.ScanInto(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return os.NewError("destination not a pointer") + } + + dv := reflect.Indirect(dpv) + if dv.Kind() == sv.Kind() { + dv.Set(sv) + return nil + } + + switch dv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if s, ok := asString(src); ok { + i64, err := strconv.Atoi64(s) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + if dv.OverflowInt(i64) { + return fmt.Errorf("string %q overflows %s", s, dv.Kind()) + } + dv.SetInt(i64) + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s, ok := asString(src); ok { + u64, err := strconv.Atoui64(s) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + if dv.OverflowUint(u64) { + return fmt.Errorf("string %q overflows %s", s, dv.Kind()) + } + dv.SetUint(u64) + return nil + } + } + + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) +} + +func asString(src interface{}) (s string, ok bool) { + switch v := src.(type) { + case string: + return v, true + case []byte: + return string(v), true + } + return "", false +} diff --git a/src/pkg/exp/sql/convert_test.go b/src/pkg/exp/sql/convert_test.go new file mode 100644 index 0000000000..7070709b09 --- /dev/null +++ b/src/pkg/exp/sql/convert_test.go @@ -0,0 +1,108 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "fmt" + "reflect" + "testing" +) + +type conversionTest struct { + s, d interface{} // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantstr string + wanterr string +} + +// Target variables for scanning into. +var ( + scanstr string + scanint int + scanint8 int8 + scanint16 int16 + scanint32 int32 + scanuint8 uint8 + scanuint16 uint16 +) + +var conversionTests = []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + + // To strings + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: `string "256" overflows uint8`}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: parsing "foo": invalid argument`}, +} + +func intValue(intptr interface{}) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr interface{}) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests { + err := convertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.String() + } + errf := func(format string, args ...interface{}) { + base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + } +} + +func TestNullableString(t *testing.T) { + var ns NullableString + convertAssign(&ns, []byte("foo")) + if !ns.Valid { + t.Errorf("expecting not null") + } + if ns.String != "foo" { + t.Errorf("expecting foo; got %q", ns.String) + } + convertAssign(&ns, nil) + if ns.Valid { + t.Errorf("expecting null on nil") + } + if ns.String != "" { + t.Errorf("expecting blank on nil; got %q", ns.String) + } +} diff --git a/src/pkg/exp/sql/doc.txt b/src/pkg/exp/sql/doc.txt new file mode 100644 index 0000000000..fb16595487 --- /dev/null +++ b/src/pkg/exp/sql/doc.txt @@ -0,0 +1,46 @@ +Goals of the sql and sql/driver packages: + +* Provide a generic database API for a variety of SQL or SQL-like + databases. There currently exist Go libraries for SQLite, MySQL, + and Postgres, but all with a very different feel, and often + a non-Go-like feel. + +* Feel like Go. + +* Care mostly about the common cases. Common SQL should be portable. + SQL edge cases or db-specific extensions can be detected and + conditionally used by the application. It is a non-goal to care + about every particular db's extension or quirk. + +* Separate out the basic implementation of a database driver + (implementing the sql/driver interfaces) vs the implementation + of all the user-level types and convenience methods. + In a nutshell: + + User Code ---> sql package (concrete types) ---> sql/driver (interfaces) + Database Driver -> sql (to register) + sql/driver (implement interfaces) + +* Make type casting/conversions consistent between all drivers. To + achieve this, most of the conversions are done in the db package, + not in each driver. The drivers then only have to deal with a + smaller set of types. + +* Be flexible with type conversions, but be paranoid about silent + truncation or other loss of precision. + +* Handle concurrency well. Users shouldn't need to care about the + database's per-connection thread safety issues (or lack thereof), + and shouldn't have to maintain their own free pools of connections. + The 'db' package should deal with that bookkeeping as needed. Given + an *sql.DB, it should be possible to share that instance between + multiple goroutines, without any extra synchronization. + +* Push complexity, where necessary, down into the sql+driver packages, + rather than exposing it to users. Said otherwise, the sql package + should expose an ideal database that's not finnicky about how it's + accessed, even if that's not true. + +* Provide optional interfaces in sql/driver for drivers to implement + for special cases or fastpaths. But the only party that knows about + those is the sql package. To user code, some stuff just might start + working or start working slightly faster. diff --git a/src/pkg/exp/sql/driver/Makefile b/src/pkg/exp/sql/driver/Makefile new file mode 100644 index 0000000000..fce3f2c27c --- /dev/null +++ b/src/pkg/exp/sql/driver/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=exp/sql/driver +GOFILES=\ + driver.go\ + types.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/exp/sql/driver/driver.go b/src/pkg/exp/sql/driver/driver.go new file mode 100644 index 0000000000..7508b19fa1 --- /dev/null +++ b/src/pkg/exp/sql/driver/driver.go @@ -0,0 +1,169 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package driver defines interfaces to be implemented by database +// drivers as used by package sql. +// +// Code simply using databases should use package sql. +// +// Drivers only need to be aware of a subset of Go's types. The db package +// will convert all types into one of the following: +// +// int64 +// float64 +// bool +// nil +// []byte +// string [*] everywhere except from Rows.Next. +// +package driver + +import ( + "os" +) + +// Driver is the interface that must be implemented by a database +// driver. +type Driver interface { + // Open returns a new or cached connection to the database. + // The name is a string in a driver-specific format. + // + // The returned connection is only used by one goroutine at a + // time. + Open(name string) (Conn, os.Error) +} + +// Execer is an optional interface that may be implemented by a Driver +// or a Conn. +// +// If a Driver does not implement Execer, the sql package's DB.Exec +// method first obtains a free connection from its free pool or from +// the driver's Open method. Execer should only be implemented by +// drivers that can provide a more efficient implementation. +// +// If a Conn does not implement Execer, the db package's DB.Exec will +// first prepare a query, execute the statement, and then close the +// statement. +// +// All arguments are of a subset type as defined in the package docs. +type Execer interface { + Exec(query string, args []interface{}) (Result, os.Error) +} + +// Conn is a connection to a database. It is not used concurrently +// by multiple goroutines. +// +// Conn is assumed to be stateful. +type Conn interface { + // Prepare returns a prepared statement, bound to this connection. + Prepare(query string) (Stmt, os.Error) + + // Close invalidates and potentially stops any current + // prepared statements and transactions, marking this + // connection as no longer in use. The driver may cache or + // close its underlying connection to its database. + Close() os.Error + + // Begin starts and returns a new transaction. + Begin() (Tx, os.Error) +} + +// Result is the result of a query execution. +type Result interface { + // LastInsertId returns the database's auto-generated ID + // after, for example, an INSERT into a table with primary + // key. + LastInsertId() (int64, os.Error) + + // RowsAffected returns the number of rows affected by the + // query. + RowsAffected() (int64, os.Error) +} + +// Stmt is a prepared statement. It is bound to a Conn and not +// used by multiple goroutines concurrently. +type Stmt interface { + // Close closes the statement. + Close() os.Error + + // NumInput returns the number of placeholder parameters. + NumInput() int + + // Exec executes a query that doesn't return rows, such + // as an INSERT or UPDATE. The args are all of a subset + // type as defined above. + Exec(args []interface{}) (Result, os.Error) + + // Exec executes a query that may return rows, such as a + // SELECT. The args of all of a subset type as defined above. + Query(args []interface{}) (Rows, os.Error) +} + +// ColumnConverter may be optionally implemented by Stmt if the +// the statement is aware of its own columns' types and can +// convert from any type to a driver subset type. +type ColumnConverter interface { + // ColumnConverter returns a ValueConverter for the provided + // column index. If the type of a specific column isn't known + // or shouldn't be handled specially, DefaultValueConverter + // can be returned. + ColumnConverter(idx int) ValueConverter +} + +// Rows is an iterator over an executed query's results. +type Rows interface { + // Columns returns the names of the columns. The number of + // columns of the result is inferred from the length of the + // slice. If a particular column name isn't known, an empty + // string should be returned for that entry. + Columns() []string + + // Close closes the rows iterator. + Close() os.Error + + // Next is called to populate the next row of data into + // the provided slice. The provided slice will be the same + // size as the Columns() are wide. + // + // The dest slice may be populated with only with values + // of subset types defined above, but excluding string. + // All string values must be converted to []byte. + Next(dest []interface{}) os.Error +} + +// Tx is a transaction. +type Tx interface { + Commit() os.Error + Rollback() os.Error +} + +// RowsAffected implements Result for an INSERT or UPDATE operation +// which mutates a number of rows. +type RowsAffected int64 + +var _ Result = RowsAffected(0) + +func (RowsAffected) LastInsertId() (int64, os.Error) { + return 0, os.NewError("no LastInsertId available") +} + +func (v RowsAffected) RowsAffected() (int64, os.Error) { + return int64(v), nil +} + +// DDLSuccess is a pre-defined Result for drivers to return when a DDL +// command succeeds. +var DDLSuccess ddlSuccess + +type ddlSuccess struct{} + +var _ Result = ddlSuccess{} + +func (ddlSuccess) LastInsertId() (int64, os.Error) { + return 0, os.NewError("no LastInsertId available after DDL statement") +} + +func (ddlSuccess) RowsAffected() (int64, os.Error) { + return 0, os.NewError("no RowsAffected available after DDL statement") +} diff --git a/src/pkg/exp/sql/driver/types.go b/src/pkg/exp/sql/driver/types.go new file mode 100644 index 0000000000..5521d5389c --- /dev/null +++ b/src/pkg/exp/sql/driver/types.go @@ -0,0 +1,161 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "fmt" + "os" + "reflect" + "strconv" +) + +// ValueConverter is the interface providing the ConvertValue method. +type ValueConverter interface { + // ConvertValue converts a value to a restricted subset type. + ConvertValue(v interface{}) (interface{}, os.Error) +} + +// Bool is a ValueConverter that converts input values to bools. +// +// The conversion rules are: +// - .... TODO(bradfitz): TBD +var Bool boolType + +type boolType struct{} + +var _ ValueConverter = boolType{} + +func (boolType) ConvertValue(v interface{}) (interface{}, os.Error) { + return nil, fmt.Errorf("TODO(bradfitz): bool conversions") +} + +// Int32 is a ValueConverter that converts input values to int64, +// respecting the limits of an int32 value. +var Int32 int32Type + +type int32Type struct{} + +var _ ValueConverter = int32Type{} + +func (int32Type) ConvertValue(v interface{}) (interface{}, os.Error) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64 := rv.Int() + if i64 > (1<<31)-1 || i64 < -(1<<31) { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return i64, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64 := rv.Uint() + if u64 > (1<<31)-1 { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return int64(u64), nil + case reflect.String: + i, err := strconv.Atoi(rv.String()) + if err != nil { + return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v) + } + return int64(i), nil + } + return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v) +} + +// String is a ValueConverter that converts its input to a string. +// If the value is already a string or []byte, it's unchanged. +// If the value is of another type, conversion to string is done +// with fmt.Sprintf("%v", v). +var String stringType + +type stringType struct{} + +func (stringType) ConvertValue(v interface{}) (interface{}, os.Error) { + switch v.(type) { + case string, []byte: + return v, nil + } + return fmt.Sprintf("%v", v), nil +} + +// IsParameterSubsetType reports whether v is of a valid type for a +// parameter. These types are: +// +// int64 +// float64 +// bool +// nil +// []byte +// string +// +// This is the ame list as IsScanSubsetType, with the addition of +// string. +func IsParameterSubsetType(v interface{}) bool { + if IsScanSubsetType(v) { + return true + } + if _, ok := v.(string); ok { + return true + } + return false +} + +// IsScanSubsetType reports whether v is of a valid type for a +// value populated by Rows.Next. These types are: +// +// int64 +// float64 +// bool +// nil +// []byte +// +// This is the same list as IsParameterSubsetType, without string. +func IsScanSubsetType(v interface{}) bool { + if v == nil { + return true + } + switch v.(type) { + case int64, float64, []byte, bool: + return true + } + return false +} + +// DefaultParameterConverter is the default implementation of +// ValueConverter that's used when a Stmt doesn't implement +// ColumnConverter. +// +// DefaultParameterConverter returns the given value directly if +// IsSubsetType(value). Otherwise integer type are converted to +// int64, floats to float64, and strings to []byte. Other types are +// an error. +var DefaultParameterConverter defaultConverter + +type defaultConverter struct{} + +var _ ValueConverter = defaultConverter{} + +func (defaultConverter) ConvertValue(v interface{}) (interface{}, os.Error) { + if IsParameterSubsetType(v) { + return v, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return nil, fmt.Errorf("uint64 values with high bit set are not supported") + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + } + return nil, fmt.Errorf("unsupported type %s", rv.Kind()) +} diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go new file mode 100644 index 0000000000..40f60f782b --- /dev/null +++ b/src/pkg/exp/sql/fakedb_test.go @@ -0,0 +1,497 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" + "sync" + + "exp/sql/driver" +) + +var _ = log.Printf + +// fakeDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntantically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE||=,=,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT||col=val,col2=val2,col3=? +// SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? +// +// When opening a a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type fakeDriver struct { + mu sync.Mutex + openCount int + dbs map[string]*fakeDB +} + +type fakeDB struct { + name string + + mu sync.Mutex + free []*fakeConn + tables map[string]*table +} + +type table struct { + mu sync.Mutex + colname []string + coltype []string + rows []*row +} + +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type row struct { + cols []interface{} // must be same size as its table colname + coltype +} + +func (r *row) clone() *row { + nrow := &row{cols: make([]interface{}, len(r.cols))} + copy(nrow.cols, r.cols) + return nrow +} + +type fakeConn struct { + db *fakeDB // where to return ourselves to + + currTx *fakeTx +} + +type fakeTx struct { + c *fakeConn +} + +type fakeStmt struct { + c *fakeConn + q string // just for debugging + + cmd string + table string + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []string // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var fdriver driver.Driver = &fakeDriver{} + +func init() { + Register("test", fdriver) +} + +// Supports dsn forms: +// +// ;wipe +func (d *fakeDriver) Open(dsn string) (driver.Conn, os.Error) { + d.mu.Lock() + defer d.mu.Unlock() + d.openCount++ + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, os.NewError("fakedb: no database name") + } + name := parts[0] + db, ok := d.dbs[name] + if !ok { + db = &fakeDB{name: name} + d.dbs[name] = db + } + return &fakeConn{db: db}, nil +} + +func (db *fakeDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.tables = nil +} + +func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) os.Error { + db.mu.Lock() + defer db.mu.Unlock() + if db.tables == nil { + db.tables = make(map[string]*table) + } + if _, exist := db.tables[name]; exist { + return fmt.Errorf("table %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", + len(columnNames), len(columnTypes)) + } + db.tables[name] = &table{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *fakeDB) table(table string) (*table, bool) { + if db.tables == nil { + return nil, false + } + t, ok := db.tables[table] + return t, ok +} + +func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.table(table) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *fakeConn) Begin() (driver.Tx, os.Error) { + if c.currTx != nil { + return nil, os.NewError("already in a transaction") + } + c.currTx = &fakeTx{c: c} + return c.currTx, nil +} + +func (c *fakeConn) Close() os.Error { + if c.currTx != nil { + return os.NewError("can't close; in a Transaction") + } + if c.db == nil { + return os.NewError("can't close; already closed") + } + c.db = nil + return nil +} + +func errf(msg string, args ...interface{}) os.Error { + return os.NewError("fakedb: " + fmt.Sprintf(msg, args...)) +} + +// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where where columns must always contain ? marks, +// just a limitation for fakedb) +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, os.Error) { + if len(parts) != 3 { + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.table = parts[0] + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.table, column) + if !ok { + return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) + } + if value != "?" { + return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", + stmt.table, column) + } + stmt.whereCol = append(stmt.whereCol, column) + stmt.placeholders++ + } + return stmt, nil +} + +// parts are table|col=type,col2=type2 +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, os.Error) { + if len(parts) != 2 { + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are table|col=?,col2=val +func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, os.Error) { + if len(parts) != 2 { + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.table, column) + if !ok { + return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) + } + stmt.colName = append(stmt.colName, column) + + if value != "?" { + var subsetVal interface{} + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + default: + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, "?") + } + } + return stmt, nil +} + +func (c *fakeConn) Prepare(query string) (driver.Stmt, os.Error) { + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + cmd := parts[0] + parts = parts[1:] + stmt := &fakeStmt{q: query, c: c, cmd: cmd} + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + return c.prepareSelect(stmt, parts) + case "CREATE": + return c.prepareCreate(stmt, parts) + case "INSERT": + return c.prepareInsert(stmt, parts) + default: + return nil, errf("unsupported command type %q", cmd) + } + return stmt, nil +} + +func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + return s.placeholderConverter[idx] +} + +func (s *fakeStmt) Close() os.Error { + return nil +} + +func (s *fakeStmt) Exec(args []interface{}) (driver.Result, os.Error) { + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.DDLSuccess, nil + case "CREATE": + if err := db.createTable(s.table, s.colName, s.colType); err != nil { + return nil, err + } + return driver.DDLSuccess, nil + case "INSERT": + return s.execInsert(args) + } + fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) + return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) +} + +func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, os.Error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + t.mu.Lock() + defer t.mu.Unlock() + + cols := make([]interface{}, len(t.colname)) + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val interface{} + if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { + val = args[argPos] + argPos++ + } else { + val = s.colValue[n] + } + cols[colidx] = val + } + + t.rows = append(t.rows, &row{cols: cols}) + return driver.RowsAffected(1), nil +} + +func (s *fakeStmt) Query(args []interface{}) (driver.Rows, os.Error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + t.mu.Lock() + defer t.mu.Unlock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mrows := []*row{} +rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for widx, wcol := range s.whereCol { + idx := t.columnIndex(wcol) + if idx == -1 { + return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + continue rows + } + } + mrow := &row{cols: make([]interface{}, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] + } + mrows = append(mrows, mrow) + } + + cursor := &rowsCursor{ + pos: -1, + rows: mrows, + cols: s.colName, + } + return cursor, nil +} + +func (s *fakeStmt) NumInput() int { + return s.placeholders +} + +func (tx *fakeTx) Commit() os.Error { + tx.c.currTx = nil + return nil +} + +func (tx *fakeTx) Rollback() os.Error { + tx.c.currTx = nil + return nil +} + +type rowsCursor struct { + cols []string + pos int + rows []*row + closed bool +} + +func (rc *rowsCursor) Close() os.Error { + rc.closed = true + return nil +} + +func (rc *rowsCursor) Columns() []string { + return rc.cols +} + +func (rc *rowsCursor) Next(dest []interface{}) os.Error { + if rc.closed { + return os.NewError("fakedb: cursor is closed") + } + rc.pos++ + if rc.pos >= len(rc.rows) { + return os.EOF // per interface spec + } + for i, v := range rc.rows[rc.pos].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the db package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + } + return nil +} + +func converterForType(typ string) driver.ValueConverter { + switch typ { + case "bool": + return driver.Bool + case "int32": + return driver.Int32 + case "string": + return driver.String + } + panic("invalid fakedb column type of " + typ) +} diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go new file mode 100644 index 0000000000..7f0e0b2842 --- /dev/null +++ b/src/pkg/exp/sql/sql.go @@ -0,0 +1,578 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +package sql + +import ( + "fmt" + "os" + "runtime" + "sync" + + "exp/sql/driver" +) + +var drivers = make(map[string]driver.Driver) + +// Register makes a database driver available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, driver driver.Driver) { + if driver == nil { + panic("db: Register driver is nil") + } + if _, dup := drivers[name]; dup { + panic("db: Register called twice for driver " + name) + } + drivers[name] = driver +} + +// NullableString represents a string that may be null. +// NullableString implements the ScannerInto interface so +// it can be used as a scan destination: +// +// var s NullableString +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s) +// ... +// if s.Valid { +// // use s.String +// } else { +// // NULL value +// } +// +// TODO(bradfitz): add other types. +type NullableString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +// ScanInto implements the ScannerInto interface. +func (ms *NullableString) ScanInto(value interface{}) os.Error { + if value == nil { + ms.String, ms.Valid = "", false + return nil + } + ms.Valid = true + return convertAssign(&ms.String, value) +} + +// ScannerInto is an interface used by Scan. +type ScannerInto interface { + // ScanInto assigns a value from a database driver. + // + // The value will be of one of the following restricted + // set of types: + // + // int64 + // float64 + // bool + // []byte + // nil - for NULL values + // + // An error should be returned if the value can not be stored + // without loss of information. + ScanInto(value interface{}) os.Error +} + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. In such a case, QueryRow returns a placeholder *Row value that +// defers this error until a Scan. +var ErrNoRows = os.NewError("db: no rows in result set") + +// DB is a database handle. It's safe for concurrent use by multiple +// goroutines. +type DB struct { + driver driver.Driver + dsn string + + mu sync.Mutex + freeConn []driver.Conn +} + +// Open opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. +func Open(driverName, dataSourceName string) (*DB, os.Error) { + driver, ok := drivers[driverName] + if !ok { + return nil, fmt.Errorf("db: unknown driver %q (forgotten import?)", driverName) + } + return &DB{driver: driver, dsn: dataSourceName}, nil +} + +func (db *DB) maxIdleConns() int { + const defaultMaxIdleConns = 2 + // TODO(bradfitz): ask driver, if supported, for its default preference + // TODO(bradfitz): let users override? + return defaultMaxIdleConns +} + +// conn returns a newly-opened or cached driver.Conn +func (db *DB) conn() (driver.Conn, os.Error) { + db.mu.Lock() + if n := len(db.freeConn); n > 0 { + conn := db.freeConn[n-1] + db.freeConn = db.freeConn[:n-1] + db.mu.Unlock() + return conn, nil + } + db.mu.Unlock() + return db.driver.Open(db.dsn) +} + +func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + for n, conn := range db.freeConn { + if conn == wanted { + db.freeConn[n] = db.freeConn[len(db.freeConn)-1] + db.freeConn = db.freeConn[:len(db.freeConn)-1] + return wanted, true + } + } + return nil, false +} + +func (db *DB) putConn(c driver.Conn) { + if n := len(db.freeConn); n < db.maxIdleConns() { + db.freeConn = append(db.freeConn, c) + return + } + db.closeConn(c) +} + +func (db *DB) closeConn(c driver.Conn) { + // TODO: check to see if we need this Conn for any prepared statements + // that are active. + c.Close() +} + +// Prepare creates a prepared statement for later execution. +func (db *DB) Prepare(query string) (*Stmt, os.Error) { + // TODO: check if db.driver supports an optional + // driver.Preparer interface and call that instead, if so, + // otherwise we make a prepared statement that's bound + // to a connection, and to execute this prepared statement + // we either need to use this connection (if it's free), else + // get a new connection + re-prepare + execute on that one. + ci, err := db.conn() + if err != nil { + return nil, err + } + defer db.putConn(ci) + si, err := ci.Prepare(query) + if err != nil { + return nil, err + } + stmt := &Stmt{ + db: db, + query: query, + css: []connStmt{{ci, si}}, + } + return stmt, nil +} + +// Exec executes a query without returning any rows. +func (db *DB) Exec(query string, args ...interface{}) (Result, os.Error) { + // Optional fast path, if the driver implements driver.Execer. + if execer, ok := db.driver.(driver.Execer); ok { + resi, err := execer.Exec(query, args) + if err != nil { + return nil, err + } + return result{resi}, nil + } + + // If the driver does not implement driver.Execer, we need + // a connection. + conn, err := db.conn() + if err != nil { + return nil, err + } + defer db.putConn(conn) + + if execer, ok := conn.(driver.Execer); ok { + resi, err := execer.Exec(query, args) + if err != nil { + return nil, err + } + return result{resi}, nil + } + + sti, err := conn.Prepare(query) + if err != nil { + return nil, err + } + defer sti.Close() + resi, err := sti.Exec(args) + if err != nil { + return nil, err + } + return result{resi}, nil +} + +// Query executes a query that returns rows, typically a SELECT. +func (db *DB) Query(query string, args ...interface{}) (*Rows, os.Error) { + stmt, err := db.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args...) +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + rows, err := db.Query(query, args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// Begin starts a transaction. The isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, os.Error) { + // TODO(bradfitz): add another method for beginning a transaction + // at a specific isolation level. + panic(todo()) +} + +// DriverDatabase returns the database's underlying driver. +func (db *DB) Driver() driver.Driver { + return db.driver +} + +// Tx is an in-progress database transaction. +type Tx struct { + +} + +// Commit commits the transaction. +func (tx *Tx) Commit() os.Error { + panic(todo()) +} + +// Rollback aborts the transaction. +func (tx *Tx) Rollback() os.Error { + panic(todo()) +} + +// Prepare creates a prepared statement. +func (tx *Tx) Prepare(query string) (*Stmt, os.Error) { + panic(todo()) +} + +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) { + panic(todo()) +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, os.Error) { + panic(todo()) +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always return a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + panic(todo()) +} + +// connStmt is a prepared statement on a particular connection. +type connStmt struct { + ci driver.Conn + si driver.Stmt +} + +// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines. +type Stmt struct { + // Immutable: + db *DB // where we came from + query string // that created the Sttm + + mu sync.Mutex + closed bool + css []connStmt // can use any that have idle connections +} + +func todo() string { + _, file, line, _ := runtime.Caller(1) + return fmt.Sprintf("%s:%d: TODO: implement", file, line) +} + +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, os.Error) { + ci, si, err := s.connStmt() + if err != nil { + return nil, err + } + defer s.db.putConn(ci) + + if want := si.NumInput(); len(args) != want { + return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) + } + + // Convert args to subset types. + if cc, ok := si.(driver.ColumnConverter); ok { + for n, arg := range args { + args[n], err = cc.ColumnConverter(n).ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("db: converting Exec argument #%d's type: %v", n, err) + } + if !driver.IsParameterSubsetType(args[n]) { + return nil, fmt.Errorf("db: driver ColumnConverter error converted %T to unsupported type %T", + arg, args[n]) + } + } + } else { + for n, arg := range args { + args[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("db: converting Exec argument #%d's type: %v", n, err) + } + } + } + + resi, err := si.Exec(args) + if err != nil { + return nil, err + } + return result{resi}, nil +} + +func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, os.Error) { + s.mu.Lock() + if s.closed { + return nil, nil, os.NewError("db: statement is closed") + } + var cs connStmt + match := false + for _, v := range s.css { + // TODO(bradfitz): lazily clean up entries in this + // list with dead conns while enumerating + if _, match = s.db.connIfFree(cs.ci); match { + cs = v + break + } + } + s.mu.Unlock() + + // Make a new conn if all are busy. + // TODO(bradfitz): or wait for one? make configurable later? + if !match { + ci, err := s.db.conn() + if err != nil { + return nil, nil, err + } + si, err := ci.Prepare(s.query) + if err != nil { + return nil, nil, err + } + s.mu.Lock() + cs = connStmt{ci, si} + s.css = append(s.css, cs) + s.mu.Unlock() + } + + return cs.ci, cs.si, nil +} + +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, os.Error) { + ci, si, err := s.connStmt(args...) + if err != nil { + return nil, err + } + if len(args) != si.NumInput() { + return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args)) + } + rowsi, err := si.Query(args) + if err != nil { + s.db.putConn(ci) + return nil, err + } + // Note: ownership of ci passes to the *Rows + rows := &Rows{ + db: s.db, + ci: ci, + rowsi: rowsi, + } + return rows, nil +} + +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&s) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + rows, err := s.Query(args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// Close closes the statement. +func (s *Stmt) Close() os.Error { + s.mu.Lock() + defer s.mu.Unlock() // TODO(bradfitz): move this unlock after 'closed = true'? + if s.closed { + return nil + } + s.closed = true + for _, v := range s.css { + if ci, match := s.db.connIfFree(v.ci); match { + v.si.Close() + s.db.putConn(ci) + } else { + // TODO(bradfitz): care that we can't close + // this statement because the statement's + // connection is in use? + } + } + return nil +} + +// Rows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance through the rows: +// +// rows, err := db.Query("SELECT ...") +// ... +// for rows.Next() { +// var id int +// var name string +// err = rows.Scan(&id, &name) +// ... +// } +// err = rows.Error() // get any Error encountered during iteration +// ... +type Rows struct { + db *DB + ci driver.Conn // owned; must be returned when Rows is closed + rowsi driver.Rows + + closed bool + lastcols []interface{} + lasterr os.Error +} + +// Next prepares the next result row for reading with the Scan method. +// It returns true on success, false if there is no next result row. +// Every call to Scan, even the first one, must be preceded by a call +// to Next. +func (rs *Rows) Next() bool { + if rs.closed { + return false + } + if rs.lasterr != nil { + return false + } + if rs.lastcols == nil { + rs.lastcols = make([]interface{}, len(rs.rowsi.Columns())) + } + rs.lasterr = rs.rowsi.Next(rs.lastcols) + return rs.lasterr == nil +} + +// Error returns the error, if any, that was encountered during iteration. +func (rs *Rows) Error() os.Error { + if rs.lasterr == os.EOF { + return nil + } + return rs.lasterr +} + +// Scan copies the columns in the current row into the values pointed +// at by dest. If dest contains pointers to []byte, the slices should +// not be modified and should only be considered valid until the next +// call to Next or Scan. +func (rs *Rows) Scan(dest ...interface{}) os.Error { + if rs.closed { + return os.NewError("db: Rows closed") + } + if rs.lasterr != nil { + return rs.lasterr + } + if rs.lastcols == nil { + return os.NewError("db: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("db: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssign(dest[i], sv) + if err != nil { + return fmt.Errorf("db: Scan error on column index %d: %v", i, err) + } + } + return nil +} + +// Close closes the Rows, preventing further enumeration. If the +// end is encountered, the Rows are closed automatically. Close +// is idempotent. +func (rs *Rows) Close() os.Error { + if rs.closed { + return nil + } + rs.closed = true + err := rs.rowsi.Close() + rs.db.putConn(rs.ci) + return err +} + +// Row is the result of calling QueryRow to select a single row. +type Row struct { + // One of these two will be non-nil: + err os.Error // deferred error for easy chaining + rows *Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns ErrNoRows. +// +// If dest contains pointers to []byte, the slices should not be +// modified and should only be considered valid until the next call to +// Next or Scan. +func (r *Row) Scan(dest ...interface{}) os.Error { + if r.err != nil { + return r.err + } + defer r.rows.Close() + if !r.rows.Next() { + return ErrNoRows + } + return r.rows.Scan(dest...) +} + +// A Result summarizes an executed SQL command. +type Result interface { + LastInsertId() (int64, os.Error) + RowsAffected() (int64, os.Error) +} + +type result struct { + driver.Result +} diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go new file mode 100644 index 0000000000..9dca710e61 --- /dev/null +++ b/src/pkg/exp/sql/sql_test.go @@ -0,0 +1,145 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "strings" + "testing" +) + +func newTestDB(t *testing.T, name string) *DB { + db, err := Open("test", "foo") + if err != nil { + t.Fatalf("Open: %v", err) + } + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + if name == "people" { + exec(t, db, "CREATE|people|name=string,age=int32,dead=bool") + exec(t, db, "INSERT|people|name=Alice,age=?", 1) + exec(t, db, "INSERT|people|name=Bob,age=?", 2) + exec(t, db, "INSERT|people|name=Chris,age=?", 3) + + } + return db +} + +func exec(t *testing.T, db *DB, query string, args ...interface{}) { + _, err := db.Exec(query, args...) + if err != nil { + t.Fatalf("Exec of %q: %v", query, err) + } +} + +func TestQuery(t *testing.T) { + db := newTestDB(t, "people") + var name string + var age int + + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) + if err == nil || !strings.Contains(err.String(), "expected 2 destination arguments") { + t.Errorf("expected error from wrong number of arguments; actually got: %v", err) + } + + err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) + if err != nil { + t.Fatalf("age QueryRow+Scan: %v", err) + } + if name != "Bob" { + t.Errorf("expected name Bob, got %q", name) + } + if age != 2 { + t.Errorf("expected age 2, got %d", age) + } + + err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name) + if err != nil { + t.Fatalf("name QueryRow+Scan: %v", err) + } + if name != "Alice" { + t.Errorf("expected name Alice, got %q", name) + } + if age != 1 { + t.Errorf("expected age 1, got %d", age) + } +} + +func TestStatementQueryRow(t *testing.T) { + db := newTestDB(t, "people") + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + var age int + for n, tt := range []struct { + name string + want int + }{ + {"Alice", 1}, + {"Bob", 2}, + {"Chris", 3}, + } { + if err := stmt.QueryRow(tt.name).Scan(&age); err != nil { + t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err) + } else if age != tt.want { + t.Errorf("%d: age=%d, want %d", age, tt.want) + } + } + +} + +// just a test of fakedb itself +func TestBogusPreboundParameters(t *testing.T) { + db := newTestDB(t, "foo") + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") + if err == nil { + t.Fatalf("expected error") + } + if err.String() != `fakedb: invalid conversion to int32 from "bogusconversion"` { + t.Errorf("unexpected error: %v", err) + } +} + +func TestDb(t *testing.T) { + db := newTestDB(t, "foo") + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Errorf("Stmt, err = %v, %v", stmt, err) + } + + type execTest struct { + args []interface{} + wantErr string + } + execTests := []execTest{ + // Okay: + {[]interface{}{"Brad", 31}, ""}, + {[]interface{}{"Brad", int64(31)}, ""}, + {[]interface{}{"Bob", "32"}, ""}, + {[]interface{}{7, 9}, ""}, + + // Invalid conversions: + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "db: converting Exec argument #1's type: sql/driver: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, "db: converting Exec argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, + + // Wrong number of args: + {[]interface{}{}, "db: expected 2 arguments, got 0"}, + {[]interface{}{1, 2, 3}, "db: expected 2 arguments, got 3"}, + } + for n, et := range execTests { + _, err := stmt.Exec(et.args...) + errStr := "" + if err != nil { + errStr = err.String() + } + if errStr != et.wantErr { + t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q", + n, et.args, errStr, et.wantErr) + } + } +}