godog/examples/db/txdb.go
2015-07-02 16:18:04 +03:00

146 lines
2.5 KiB
Go

package main
import (
"database/sql"
"database/sql/driver"
"io"
"sync"
)
// Register makes a driver named "txdb" available to sql.Open. Connections
// opened through it all share one transaction on the database reached via
// the drv driver name and dsn; closing a connection rolls that transaction
// back. Like sql.Register, it panics if called more than once.
func Register(drv, dsn string) {
	shared := &txDriver{drv: drv, dsn: dsn}
	sql.Register("txdb", shared)
}
// txDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type txDriver struct {
sync.Mutex
tx *sql.Tx
drv string
dsn string
db *sql.DB
}
func (d *txDriver) Open(dsn string) (driver.Conn, error) {
// first open a real database connection
var err error
if d.db == nil {
db, err := sql.Open(d.drv, d.dsn)
if err != nil {
return d, err
}
d.db = db
}
if d.tx == nil {
d.tx, err = d.db.Begin()
}
return d, err
}
func (d *txDriver) Close() error {
err := d.tx.Rollback()
d.tx = nil
return err
}
func (d *txDriver) Begin() (driver.Tx, error) {
return d, nil
}
func (d *txDriver) Commit() error {
return nil
}
func (d *txDriver) Rollback() error {
return nil
}
// Prepare implements driver.Conn. Nothing is sent to the database here: the
// query text is captured, and it only runs against the shared transaction
// when the returned statement's Exec or Query is called.
func (d *txDriver) Prepare(query string) (driver.Stmt, error) {
	st := &stmt{query: query, drv: d}
	return st, nil
}
// stmt is the driver.Stmt handed out by txDriver.Prepare. It only remembers
// the query text; the query runs on the driver's shared transaction each
// time Exec or Query is called.
type stmt struct {
	query string
	drv   *txDriver
}

// valuesToArgs converts driver values into the variadic form accepted by
// sql.Tx.Exec and sql.Tx.Query.
func valuesToArgs(args []driver.Value) []interface{} {
	out := make([]interface{}, len(args))
	for i, arg := range args {
		out[i] = arg
	}
	return out
}

// Exec runs the statement on the shared transaction and returns its result.
// It returns sql.ErrTxDone when the transaction has been rolled back by
// Close and not yet restarted by Open.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
	s.drv.Lock()
	defer s.drv.Unlock()

	if s.drv.tx == nil {
		return nil, sql.ErrTxDone
	}
	return s.drv.tx.Exec(s.query, valuesToArgs(args)...)
}

// NumInput returns -1 so database/sql does not sanity-check the argument
// count, leaving that to the underlying driver.
func (s *stmt) NumInput() int {
	return -1
}

// Close is a no-op: stmt holds no database resources of its own.
func (s *stmt) Close() error {
	return nil
}

// Query runs the statement on the shared transaction and wraps the result
// set. It returns sql.ErrTxDone when the transaction has been rolled back by
// Close and not yet restarted by Open, and nil rows on any error.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
	s.drv.Lock()
	defer s.drv.Unlock()

	if s.drv.tx == nil {
		return nil, sql.ErrTxDone
	}
	// Tx.Query is used rather than Tx.Prepare followed by Stmt.Query. A
	// statement prepared here could not be closed before its rows are
	// consumed, so it would stay open until the transaction ends. Tx.Query
	// ties any statement it has to prepare to the returned rows and closes it
	// together with them.
	rs, err := s.drv.tx.Query(s.query, valuesToArgs(args)...)
	if err != nil {
		return nil, err
	}
	return &rows{rs: rs}, nil
}
// rows adapts the *sql.Rows produced on the shared transaction to the
// driver.Rows interface that database/sql expects from a driver.
type rows struct {
	err error     // error from Columns, reported by the next call to Next
	rs  *sql.Rows // result set opened on the shared transaction
}

// Columns returns the column names of the underlying result set.
// driver.Rows.Columns cannot return an error, so a failure is remembered in
// r.err and surfaced by the following call to Next.
func (r *rows) Columns() []string {
	cols, err := r.rs.Columns()
	r.err = err
	return cols
}

// Next copies the next row of the underlying result set into dest. It
// returns io.EOF when the rows are exhausted, and the underlying error when
// iteration stopped because of a failure.
func (r *rows) Next(dest []driver.Value) error {
	if r.err != nil {
		return r.err
	}
	if !r.rs.Next() {
		// sql.Rows.Next reports both normal exhaustion and failure as false;
		// only Err distinguishes them. Returning io.EOF unconditionally would
		// make database/sql treat a failed iteration as a clean end of rows.
		if err := r.rs.Err(); err != nil {
			return err
		}
		return io.EOF
	}
	// Scanning into *interface{} copies each value out of the driver
	// without type conversion.
	values := make([]interface{}, len(dest))
	for i := range values {
		values[i] = new(interface{})
	}
	if err := r.rs.Scan(values...); err != nil {
		return err
	}
	for i, v := range values {
		dest[i] = *(v.(*interface{}))
	}
	return r.rs.Err()
}

// Close closes the underlying result set.
func (r *rows) Close() error {
	return r.rs.Close()
}