diff --git a/go.mod b/go.mod index 033dfb2..04a7c58 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.1 // indirect + github.com/matchsystems/werr v0.1.3 // indirect go.uber.org/atomic v1.7.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/text v0.18.0 // indirect diff --git a/go.sum b/go.sum index 3750adf..126ec1a 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/matchsystems/werr v0.1.3 h1:h932fzdGLE67w5O8F3O2vO49KkjmSeqsFQqDFkIOMYM= +github.com/matchsystems/werr v0.1.3/go.mod h1:MpZemBWOQ0IuQogwr5aCjNnIfWe+iEfnSh7nTGQ3M7I= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/pg/pg.go b/pg/pg.go index 6ad743e..d375644 100644 --- a/pg/pg.go +++ b/pg/pg.go @@ -2,25 +2,18 @@ package pgUtils import ( "context" - "errors" - reflectUtils "git.mic.pp.ua/anderson/nettools/reflect" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" -) - -var ( - ErrEntityNotFound = errors.New("entity not found") - ErrTooManyRows = errors.New("too many rows") - ErrEntityAlreadyExists = errors.New("entity already exists") - ErrNoDstFields = errors.New("no destination fields") + "github.com/matchsystems/werr" ) type PgxQuerier interface { Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) (pgx.Row, error) } -// Select executes query on a provided querier and tries to parse db response -// Works only with objects +// Query executes query on a provided querier and tries to parse db response +// Works only with structs // // Usage: // @@ -30,44 +23,17 @@ type PgxQuerier interface { // } // // db := pgx.Connect(context.Background(), "") -// users, err := pgUtils.Select[User](context.Background(), db, "SELECT id, name FROM users") -func Select[T any](ctx context.Context, db PgxQuerier, query string, args ...any) (out []*T, err error) { +// users, err := pgUtils.Query[User](context.Background(), db, "SELECT id, name FROM users") +func Query[T any](ctx context.Context, db PgxQuerier, query string, args ...any) (out []T, err error) { rows, err := db.Query(ctx, query, args) if err != nil { - switch { - case errors.Is(err, pgx.ErrNoRows): - err = ErrEntityNotFound - } // TODO: extend cases return nil, err } - - // Get column names - columns := make([]string, len(rows.FieldDescriptions())) - for i, fd := range rows.FieldDescriptions() { - columns[i] = fd.Name + entities, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[T]) + if err != nil { + return nil, werr.Wrapf(err, "failed to parse query results") } - itemFieldPtrs := make([]any, len(columns)) - out = []*T{} - - defer rows.Close() - for rows.Next() { - item := new(T) - dstItemPtrsMap, err := reflectUtils.GetEntityPtrs(item, "db") - if err != nil { - return nil, err - } - for i, columnName := range columns { - itemFieldPtrs[i] = dstItemPtrsMap[columnName] - } - if len(itemFieldPtrs) == 0 { - return nil, ErrNoDstFields - } - if err = rows.Scan(itemFieldPtrs...); err != nil { - return out, err - } - out = append(out, item) - } - return out, err + return entities, nil } // Tx creates new transaction. Cancels it if returned not nil err