Skip to content

Commit

Permalink
V2: new methods for *sql.DB connections, auto-cleanup database
Browse files Browse the repository at this point in the history
  • Loading branch information
olomix committed Oct 11, 2021
1 parent a6ee6dc commit 0449c0e
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 33 deletions.
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ is complete.
As a side effect tool checks that all resources are released when test exits.
If any Rows is not closed or Conn is not released to pool, test fails.

`go-test-pg` requires schema file to initialize database with. It creates
`go-test-pg` uses schema file to initialize database with. It creates
template database with this schema. Then each temporary database for every test
creates from this template database. If the template database for this
schema is exists, it will be reused. The name of the template database
is composed of `baseName` and md5 hashsum of schema file content.
is composed of `baseName` and md5 hashsum of schema file content. If schema file
is empty, then use default PostgreSQL empty database `template1`.

On complete, temporary databases would be dropped, template database will not
be dropped and would remain for future reuse.

Template database would be created only on first use. If you call `NewPool`
and do not call `With<something>` on it, real database would not be touched.

Each method was `Std` version that returns `*sql.DB`. For example,
default method `WithFixtures` returns `*pgxpool.Pool` and `WithStdFixtures`
returns `*sql.DB`.

## Example usage

Expand All @@ -37,11 +41,10 @@ import (
var dbpool = &ptg.Pgpool{SchemaFile: "../schema.sql"}

func TestX(t *testing.T) {
dbPool, dbClear := dbpool.WithEmpty(t)
defer dbClear()
dbPool := dbpool.WithEmpty(t)
var dbName string
err := dbPool.
QueryRow(context.Background(), "select current_database()").
QueryRow(context.Background(), "SELECT current_database()").
Scan(&dbName)
if err != nil {
t.Fatal(err)
Expand All @@ -55,8 +58,8 @@ Connection to database configured using standard PostgreSQL environment
variable https://www.postgresql.org/docs/11/libpq-envars.html. User needs
permissions to create databases.

If you want to skip all tests, you need to set Skip field in Pgpool struct
to false.
If you want to skip all database tests, you need to set `Skip` field in Pgpool
struct to `true`.

```go
var dbpool = &ptg.Pgpool{Skip: true}
Expand Down
151 changes: 126 additions & 25 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package go_test_pg
import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"fmt"
"io/ioutil"
Expand All @@ -15,6 +16,7 @@ import (

"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jackc/pgx/v4/stdlib"
"github.com/pkg/errors"
)

Expand All @@ -29,7 +31,7 @@ type Pgpool struct {
// BaseName is the prefix of template and temporary databases.
// Default is dbtestpg.
BaseName string
// Name of schema file. Required. Tests would fail if not set.
// Name of schema file. If empty, create empty database.
SchemaFile string // schema file name
// If true, skip all database tests.
Skip bool
Expand All @@ -42,41 +44,66 @@ type Pgpool struct {

// WithFixtures creates database from template database, and initializes it
// with fixtures from `fixtures` array
func (p *Pgpool) WithFixtures(
t testing.TB,
fixtures []Fixture,
) (*pgxpool.Pool, func()) {
pool, clean := p.WithEmpty(t)
func (p *Pgpool) WithFixtures(t testing.TB, fixtures []Fixture) *pgxpool.Pool {
pool := p.WithEmpty(t)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
for i, f := range fixtures {
if _, err := pool.Exec(ctx, f.Query, f.Params...); err != nil {
clean()
t.Fatalf(
"can't load fixture at idx %v: %+v",
i, errors.WithStack(err),
)
}
}
return pool, clean
return pool
}

// WithStdFixtures creates database from template database, and initializes it
// with fixtures from `fixtures` array
func (p *Pgpool) WithStdFixtures(t testing.TB, fixtures []Fixture) *sql.DB {
db := p.WithStdEmpty(t)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
for i, f := range fixtures {
if _, err := db.ExecContext(ctx, f.Query, f.Params...); err != nil {
t.Fatalf("can't load fixture at idx %v: %+v",
i, errors.WithStack(err))
}
}
return db
}

// WithSQLs creates database from template database, and initializes it
// with fixtures from `sqls` array
func (p *Pgpool) WithSQLs(t testing.TB, sqls []string) (*pgxpool.Pool, func()) {
pool, clean := p.WithEmpty(t)
func (p *Pgpool) WithSQLs(t testing.TB, sqls []string) *pgxpool.Pool {
pool := p.WithEmpty(t)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
for i, s := range sqls {
if _, err := pool.Exec(ctx, s); err != nil {
clean()
t.Fatalf(
"can't load fixture at idx %v: %+v",
i, errors.WithStack(err),
)
}
}
return pool, clean
return pool
}

// WithStdSQLs creates database from template database, and initializes it
// with fixtures from `sqls` array
func (p *Pgpool) WithStdSQLs(t testing.TB, sqls []string) *sql.DB {
db := p.WithStdEmpty(t)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
for i, s := range sqls {
if _, err := db.ExecContext(ctx, s); err != nil {
t.Fatalf("can't load fixture at idx %v: %+v",
i, errors.WithStack(err))
}
}
return db
}

func (p *Pgpool) getTmpl(t testing.TB) string {
Expand Down Expand Up @@ -111,11 +138,27 @@ func (p *Pgpool) getTmpl(t testing.TB) string {
return p.tmpl
}

func (p *Pgpool) createRndDB(t testing.TB) (*pgxpool.Pool, string) {
// Register pgx.ConnConfig with std driver.
// Return connection string for database/sql and error.
func (p *Pgpool) registerStdConfig(t testing.TB, dbName string) (string, error) {
connConfig, err := pgx.ParseConfig("")
if err != nil {
return "", errors.WithStack(err)
}
connConfig.Logger = newLogger(t)
connConfig.Database = dbName
return stdlib.RegisterConnConfig(connConfig), nil
}

func (p *Pgpool) createRndDB(t testing.TB) (string, error) {
tmpl := p.getTmpl(t)
dbName := fmt.Sprintf("%v_%v", tmpl, p.rnd.Int31())

err := p.createDB(dbName, tmpl)
return dbName, p.createDB(dbName, tmpl)
}

func (p *Pgpool) createRndDBPool(t testing.TB) (*pgxpool.Pool, string) {
dbName, err := p.createRndDB(t)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -191,9 +234,9 @@ func dropDB(dbName string) error {

// WithEmpty creates empty database from template database, that was
// created from `schema` file.
func (p *Pgpool) WithEmpty(t testing.TB) (*pgxpool.Pool, func()) {
pool, dbName := p.createRndDB(t)
return pool, func() {
func (p *Pgpool) WithEmpty(t testing.TB) *pgxpool.Pool {
pool, dbName := p.createRndDBPool(t)
t.Cleanup(func() {
acquiredConns := pool.Stat().AcquiredConns()
if acquiredConns > 0 {
t.Fatalf(
Expand All @@ -206,7 +249,63 @@ func (p *Pgpool) WithEmpty(t testing.TB) (*pgxpool.Pool, func()) {
if err != nil {
t.Errorf("Can't drop DB %v: %v", dbName, err)
}
})
return pool
}

// WithStdEmpty creates empty database from template database, that was
// created from `schema` file.
func (p *Pgpool) WithStdEmpty(t testing.TB) *sql.DB {
db, cleanupFn := p.newStdDBWithCleanup(t)
if cleanupFn != nil {
t.Cleanup(func() {
if err := cleanupFn(); err != nil {
t.Error(err)
}
})
}
return db
}

func (p *Pgpool) newStdDBWithCleanup(t testing.TB) (*sql.DB, func() error) {
dbName, err := p.createRndDB(t)
if err != nil {
t.Fatal(err)
return nil, nil
}

connString, err := p.registerStdConfig(t, dbName)
if err != nil {
_ = dropDB(dbName)
t.Fatal(err)
return nil, nil
}

db, err := sql.Open("pgx", connString)
if err != nil {
_ = dropDB(dbName)
t.Fatal(err)
return nil, nil
}

cleanupFn := func() error {
stats := db.Stats()
if stats.InUse > 0 {
return errors.Errorf(
"unreleased connections exists: %v, can't drop database %v",
stats.InUse, dbName)
}
err := db.Close()
if err != nil {
return errors.Errorf("Can't close DB %v: %v", dbName, err)
}
err = dropDB(dbName)
if err != nil {
return errors.Errorf("Can't drop DB %v: %v", dbName, err)
}
return nil
}
return db, cleanupFn
}

func (p *Pgpool) createDB(name, tmplName string) error {
Expand All @@ -224,9 +323,11 @@ func (p *Pgpool) createDB(name, tmplName string) error {
)
}

// Creates template db, populates with SQLs from schema file and return name
// of the new database. If database is exists, just return its name.
func (p *Pgpool) createTemplateDB() (string, error) {
if p.SchemaFile == "" {
return "", errors.New("SchemaFile is empty")
return "template1", nil
}
schemaSql, err := ioutil.ReadFile(p.SchemaFile)
if err != nil {
Expand All @@ -238,7 +339,7 @@ func (p *Pgpool) createTemplateDB() (string, error) {
if p.BaseName != "" {
baseName = p.BaseName
}
tmpl := fmt.Sprintf("%v_%v", baseName, schemaHex)
tmplDbName := fmt.Sprintf("%v_%v", baseName, schemaHex)

var dbExists bool
err = withNewConnection(
Expand All @@ -247,14 +348,14 @@ func (p *Pgpool) createTemplateDB() (string, error) {
query := `
SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
`
err := conn.QueryRow(ctx, query, tmpl).Scan(&dbExists)
err := conn.QueryRow(ctx, query, tmplDbName).Scan(&dbExists)
if err != nil {
return errors.WithStack(err)
}
if dbExists {
return nil
}
_, err = conn.Exec(ctx, `CREATE DATABASE `+quote(tmpl))
_, err = conn.Exec(ctx, `CREATE DATABASE `+quote(tmplDbName))
return errors.WithStack(err)
},
)
Expand All @@ -263,23 +364,23 @@ SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
}

if dbExists {
return tmpl, nil
return tmplDbName, nil
}

err = withNewConnection(
tmpl,
tmplDbName,
func(ctx context.Context, conn *pgx.Conn) error {
_, err = conn.Exec(ctx, string(schemaSql))
return errors.WithStack(err)
},
)

if err != nil {
_ = dropDB(tmpl)
_ = dropDB(tmplDbName)
return "", err
}

return tmpl, nil
return tmplDbName, nil
}

func quote(name string) string {
Expand Down
Loading

0 comments on commit 0449c0e

Please sign in to comment.