diff --git a/internal/postgres/mocks/mock_pg_querier.go b/internal/postgres/mocks/mock_pg_querier.go index ac46d93..c69c060 100644 --- a/internal/postgres/mocks/mock_pg_querier.go +++ b/internal/postgres/mocks/mock_pg_querier.go @@ -12,6 +12,7 @@ type Querier struct { QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row QueryFn func(ctx context.Context, query string, args ...any) (postgres.Rows, error) ExecFn func(context.Context, string, ...any) (postgres.CommandTag, error) + ExecInTxFn func(context.Context, func(tx postgres.Tx) error) error CloseFn func(context.Context) error } @@ -27,6 +28,10 @@ func (m *Querier) Exec(ctx context.Context, query string, args ...any) (postgres return m.ExecFn(ctx, query, args...) } +func (m *Querier) ExecInTx(ctx context.Context, fn func(tx postgres.Tx) error) error { + return m.ExecInTxFn(ctx, fn) +} + func (m *Querier) Close(ctx context.Context) error { return m.CloseFn(ctx) } diff --git a/internal/postgres/pg_conn.go b/internal/postgres/pg_conn.go index 510cf42..91ec665 100644 --- a/internal/postgres/pg_conn.go +++ b/internal/postgres/pg_conn.go @@ -7,25 +7,12 @@ import ( "fmt" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" ) type Conn struct { conn *pgx.Conn } -type Row interface { - pgx.Row -} - -type Rows interface { - pgx.Rows -} - -type CommandTag struct { - pgconn.CommandTag -} - func NewConn(ctx context.Context, url string) (*Conn, error) { pgCfg, err := pgx.ParseConfig(url) if err != nil { @@ -55,6 +42,20 @@ func (c *Conn) Exec(ctx context.Context, query string, args ...any) (CommandTag, return CommandTag{tag}, mapError(err) } +func (c *Conn) ExecInTx(ctx context.Context, fn func(Tx) error) error { + tx, err := c.conn.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return mapError(err) + } + + if err := fn(tx); err != nil { + tx.Rollback(ctx) + return mapError(err) + } + + return tx.Commit(ctx) +} + func (c *Conn) Close(ctx context.Context) error { return mapError(c.conn.Close(ctx)) } diff --git a/internal/postgres/pg_conn_pool.go b/internal/postgres/pg_conn_pool.go index 7574085..45cc9ec 100644 --- a/internal/postgres/pg_conn_pool.go +++ b/internal/postgres/pg_conn_pool.go @@ -6,6 +6,7 @@ import ( "context" "fmt" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) @@ -42,6 +43,20 @@ func (c *Pool) Exec(ctx context.Context, query string, args ...any) (CommandTag, return CommandTag{tag}, mapError(err) } +func (c *Pool) ExecInTx(ctx context.Context, fn func(Tx) error) error { + tx, err := c.Pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return mapError(err) + } + + if err := fn(tx); err != nil { + tx.Rollback(ctx) + return mapError(err) + } + + return tx.Commit(ctx) +} + func (c *Pool) Close(_ context.Context) error { c.Pool.Close() return nil diff --git a/internal/postgres/pg_querier.go b/internal/postgres/pg_querier.go index 7375b60..c75befc 100644 --- a/internal/postgres/pg_querier.go +++ b/internal/postgres/pg_querier.go @@ -2,14 +2,37 @@ package postgres -import "context" +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) type Querier interface { Query(ctx context.Context, query string, args ...any) (Rows, error) QueryRow(ctx context.Context, query string, args ...any) Row Exec(ctx context.Context, query string, args ...any) (CommandTag, error) + ExecInTx(ctx context.Context, fn func(tx Tx) error) error Close(ctx context.Context) error } + +type Row interface { + pgx.Row +} + +type Rows interface { + pgx.Rows +} + +type Tx interface { + pgx.Tx +} + +type CommandTag struct { + pgconn.CommandTag +} + type mappedRow struct { inner Row }