2024-11-12 08:45:45 +00:00
|
|
|
package pgUtils
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
reflectUtils "nettools/reflect"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Sentinel errors exposed by this package. Compare against them with
// errors.Is.
var (
	// ErrEntityNotFound reports that a query matched no entity.
	ErrEntityNotFound = errors.New("entity not found")
	// ErrTooManyRows reports that a query produced more rows than expected.
	ErrTooManyRows = errors.New("too many rows")
	// ErrEntityAlreadyExists reports that the entity is already present.
	ErrEntityAlreadyExists = errors.New("entity already exists")
	// ErrNoDstFields reports that a result set has no columns to scan into.
	ErrNoDstFields = errors.New("no destination fields")
)
|
|
|
|
|
|
|
|
type PgxQuerier interface {
|
|
|
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Select executes query on a provided querier and tries to parse db response into antit
|
|
|
|
// Works only with objects
|
|
|
|
//
|
|
|
|
// Usage:
|
|
|
|
//
|
|
|
|
// type User struct {
|
2024-11-12 08:53:46 +00:00
|
|
|
// id int `db:"id"`
|
|
|
|
// name string `db:"name"`
|
2024-11-12 08:45:45 +00:00
|
|
|
// }
|
|
|
|
//
|
|
|
|
// db := pgx.Connect(context.Background(), "<url>")
|
2024-11-12 08:53:46 +00:00
|
|
|
// users, err := pgUtils.Select[User](context.Background(), db, "SELECT id, name FROM users")
|
2024-11-12 08:45:45 +00:00
|
|
|
func Select[T any](ctx context.Context, db PgxQuerier, query string, args ...any) (out []*T, err error) {
|
|
|
|
rows, err := db.Query(ctx, query, args)
|
|
|
|
if err != nil {
|
|
|
|
switch {
|
|
|
|
case errors.Is(err, pgx.ErrNoRows):
|
|
|
|
err = ErrEntityNotFound
|
|
|
|
} // TODO: extend cases
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get column names
|
|
|
|
columns := make([]string, len(rows.FieldDescriptions()))
|
|
|
|
for i, fd := range rows.FieldDescriptions() {
|
|
|
|
columns[i] = fd.Name
|
|
|
|
}
|
2024-11-12 08:53:46 +00:00
|
|
|
itemFieldPtrs := make([]any, len(columns))
|
|
|
|
out = []*T{}
|
2024-11-12 08:45:45 +00:00
|
|
|
|
|
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
|
|
item := new(T)
|
|
|
|
dstItemPtrsMap, err := reflectUtils.GetEntityPtrs(item, "db")
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
for i, columnName := range columns {
|
|
|
|
itemFieldPtrs[i] = dstItemPtrsMap[columnName]
|
|
|
|
}
|
|
|
|
if len(itemFieldPtrs) == 0 {
|
|
|
|
return nil, ErrNoDstFields
|
|
|
|
}
|
|
|
|
if err = rows.Scan(itemFieldPtrs...); err != nil {
|
|
|
|
return out, err
|
|
|
|
}
|
|
|
|
out = append(out, item)
|
|
|
|
}
|
|
|
|
return out, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Tx creates new transaction. Cancels it if returned not nil err
|
|
|
|
func Tx(ctx context.Context, db *pgxpool.Pool, exec func(ctx context.Context, tx pgx.Tx) error) error {
|
|
|
|
tx, err := db.Begin(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = exec(ctx, tx)
|
|
|
|
if err != nil {
|
|
|
|
_ = tx.Rollback(ctx)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
|
|
}
|