146 строки
2,5 КиБ
Go
146 строки
2,5 КиБ
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"io"
|
|
"sync"
|
|
)
|
|
|
|
// Register a txdb sql driver which can be used to open
|
|
// a single transaction based database connection pool
|
|
func Register(drv, dsn string) {
|
|
sql.Register("txdb", &txDriver{dsn: dsn, drv: drv})
|
|
}
|
|
|
|
// txDriver implements database/sql/driver on top of one *sql.Tx. Open hands
// the driver itself out as the connection, so every connection database/sql
// pools shares that transaction. Close rolls it back; Commit and Rollback
// requested by callers are no-ops.
type txDriver struct {
	// drv and dsn are passed to sql.Open to reach the real database.
	drv string
	dsn string

	// The mutex serializes statement execution: stmt.Exec and stmt.Query
	// hold it while they use tx.
	sync.Mutex
	// db is the real database handle, opened lazily by Open.
	db *sql.DB
	// tx is the shared transaction. It is nil until Open begins it, and Close
	// resets it to nil.
	tx *sql.Tx
}
|
|
|
|
func (d *txDriver) Open(dsn string) (driver.Conn, error) {
|
|
// first open a real database connection
|
|
var err error
|
|
if d.db == nil {
|
|
db, err := sql.Open(d.drv, d.dsn)
|
|
if err != nil {
|
|
return d, err
|
|
}
|
|
d.db = db
|
|
}
|
|
if d.tx == nil {
|
|
d.tx, err = d.db.Begin()
|
|
}
|
|
return d, err
|
|
}
|
|
|
|
func (d *txDriver) Close() error {
|
|
err := d.tx.Rollback()
|
|
d.tx = nil
|
|
return err
|
|
}
|
|
|
|
func (d *txDriver) Begin() (driver.Tx, error) {
|
|
return d, nil
|
|
}
|
|
|
|
func (d *txDriver) Commit() error {
|
|
return nil
|
|
}
|
|
|
|
func (d *txDriver) Rollback() error {
|
|
return nil
|
|
}
|
|
|
|
func (d *txDriver) Prepare(query string) (driver.Stmt, error) {
|
|
return &stmt{drv: d, query: query}, nil
|
|
}
|
|
|
|
type stmt struct {
|
|
query string
|
|
drv *txDriver
|
|
}
|
|
|
|
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
s.drv.Lock()
|
|
defer s.drv.Unlock()
|
|
|
|
st, err := s.drv.tx.Prepare(s.query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer st.Close()
|
|
var iargs []interface{}
|
|
for _, arg := range args {
|
|
iargs = append(iargs, arg)
|
|
}
|
|
return st.Exec(iargs...)
|
|
}
|
|
|
|
func (s *stmt) NumInput() int {
|
|
return -1
|
|
}
|
|
|
|
func (s *stmt) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
s.drv.Lock()
|
|
defer s.drv.Unlock()
|
|
|
|
st, err := s.drv.tx.Prepare(s.query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// do not close the statement here, Rows need it
|
|
var iargs []interface{}
|
|
for _, arg := range args {
|
|
iargs = append(iargs, arg)
|
|
}
|
|
rs, err := st.Query(iargs...)
|
|
return &rows{rs: rs}, err
|
|
}
|
|
|
|
// rows adapts a *sql.Rows produced on the shared transaction to driver.Rows.
type rows struct {
	rs  *sql.Rows // result set being iterated
	err error     // error recorded by Columns; returned by the next Next call
}
|
|
|
|
func (r *rows) Columns() (cols []string) {
|
|
cols, r.err = r.rs.Columns()
|
|
return
|
|
}
|
|
|
|
func (r *rows) Next(dest []driver.Value) error {
|
|
if r.err != nil {
|
|
return r.err
|
|
}
|
|
if r.rs.Err() != nil {
|
|
return r.rs.Err()
|
|
}
|
|
if !r.rs.Next() {
|
|
return io.EOF
|
|
}
|
|
values := make([]interface{}, len(dest))
|
|
for i := range values {
|
|
values[i] = new(interface{})
|
|
}
|
|
if err := r.rs.Scan(values...); err != nil {
|
|
return err
|
|
}
|
|
for i, val := range values {
|
|
dest[i] = *(val.(*interface{}))
|
|
}
|
|
return r.rs.Err()
|
|
}
|
|
|
|
func (r *rows) Close() error {
|
|
return r.rs.Close()
|
|
}
|