Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move leases DAL+queries into its own package #2566

Merged
merged 2 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ import (
"github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/ingress"
"github.com/TBD54566975/ftl/backend/controller/leases"
leasesdal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/observability"
"github.com/TBD54566975/ftl/backend/controller/pubsub"
"github.com/TBD54566975/ftl/backend/controller/scaling"
"github.com/TBD54566975/ftl/backend/controller/scaling/localscaling"
"github.com/TBD54566975/ftl/backend/controller/scheduledtask"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/console/pbconsoleconnect"
"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect"
Expand Down Expand Up @@ -194,6 +195,7 @@ type ControllerListListener interface {

type Service struct {
conn *sql.DB
leasesdal *leasesdal.DAL
dal *dal.DAL
key model.ControllerKey
deploymentLogsSink *deploymentLogsSink
Expand Down Expand Up @@ -231,14 +233,16 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
config.ControllerTimeout = time.Second * 5
}

ldb := leasesdal.New(conn)
db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
}

svc := &Service{
tasks: scheduledtask.New(ctx, key, db),
tasks: scheduledtask.New(ctx, key, ldb),
dal: db,
leasesdal: ldb,
conn: conn,
key: key,
deploymentLogsSink: newDeploymentLogsSink(ctx, db),
Expand Down Expand Up @@ -307,7 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {

routes, err := s.dal.GetIngressRoutes(r.Context(), r.Method)
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
http.NotFound(w, r)
observability.Ingress.Request(r.Context(), r.Method, r.URL.Path, optional.None[*schemapb.Ref](), start, optional.Some("route not found in dal"))
return
Expand Down Expand Up @@ -509,7 +513,7 @@ func (s *Service) UpdateDeploy(ctx context.Context, req *connect.Request[ftlv1.U

err = s.dal.SetDeploymentReplicas(ctx, deploymentKey, int(req.Msg.MinReplicas))
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Errorf(err, "Deployment not found: %s", deploymentKey)
return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found"))
}
Expand All @@ -531,7 +535,7 @@ func (s *Service) ReplaceDeploy(ctx context.Context, c *connect.Request[ftlv1.Re

err = s.dal.ReplaceDeployment(ctx, newDeploymentKey, int(c.Msg.MinReplicas))
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Errorf(err, "Deployment not found: %s", newDeploymentKey)
return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found"))
} else if errors.Is(err, dal.ErrReplaceDeploymentAlreadyActive) {
Expand Down Expand Up @@ -591,7 +595,7 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre
Deployment: maybeDeployment,
Labels: msg.Labels.AsMap(),
})
if errors.Is(err, dalerrs.ErrConflict) {
if errors.Is(err, libdal.ErrConflict) {
return nil, connect.NewError(connect.CodeAlreadyExists, err)
} else if err != nil {
return nil, err
Expand All @@ -608,7 +612,7 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre
}

routes, err := s.dal.GetRoutingTable(ctx, nil)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
routes = map[string][]dal.Route{}
} else if err != nil {
return nil, err
Expand Down Expand Up @@ -815,7 +819,7 @@ func (s *Service) AcquireLease(ctx context.Context, stream *connect.BidiStream[f
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not receive lease request: %w", err))
}
if lease == nil {
lease, _, err = s.dal.AcquireLease(ctx, leases.ModuleKey(msg.Module, msg.Key...), msg.Ttl.AsDuration(), optional.None[any]())
lease, _, err = s.leasesdal.AcquireLease(ctx, leases.ModuleKey(msg.Module, msg.Key...), msg.Ttl.AsDuration(), optional.None[any]())
if err != nil {
if errors.Is(err, leases.ErrConflict) {
return connect.NewError(connect.CodeResourceExhausted, fmt.Errorf("lease is held: %w", err))
Expand Down Expand Up @@ -948,7 +952,7 @@ func (s *Service) SetNextFSMEvent(ctx context.Context, req *connect.Request[ftlv
// Get the current state the instance is transitioning to.
_, currentDestinationState, err := tx.GetFSMStates(ctx, fsmKey, req.Msg.Instance)
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("fsm instance not found: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not get fsm instance: %w", err))
Expand All @@ -963,7 +967,7 @@ func (s *Service) SetNextFSMEvent(ctx context.Context, req *connect.Request[ftlv
// Set the next event.
err = tx.SetNextFSMEvent(ctx, fsmKey, msg.Instance, nextState.ToRefKey(), msg.Body, eventType)
if err != nil {
if errors.Is(err, dalerrs.ErrConflict) {
if errors.Is(err, libdal.ErrConflict) {
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("fsm instance already has its next state set: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not set next fsm event: %w", err))
Expand Down Expand Up @@ -1403,7 +1407,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration
logger.Tracef("Acquiring async call")

call, leaseCtx, err := s.dal.AcquireAsyncCall(ctx)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Tracef("No async calls to execute")
return time.Second * 2, nil
} else if err != nil {
Expand Down Expand Up @@ -1740,7 +1744,7 @@ func (s *Service) resolveFSMEvent(msg *ftlv1.SendFSMEventRequest) (fsm *schema.F
}

func (s *Service) expireStaleLeases(ctx context.Context) (time.Duration, error) {
err := s.dal.ExpireLeases(ctx)
err := s.leasesdal.ExpireLeases(ctx)
if err != nil {
return 0, fmt.Errorf("failed to expire leases: %w", err)
}
Expand Down Expand Up @@ -1972,7 +1976,7 @@ func (s *Service) getDeploymentLogger(ctx context.Context, deploymentKey model.D
// Periodically sync the routing table from the DB.
func (s *Service) syncRoutes(ctx context.Context) (time.Duration, error) {
routes, err := s.dal.GetRoutingTable(ctx, nil)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
routes = map[string][]dal.Route{}
} else if err != nil {
return 0, err
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdal "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/cron"
"github.com/TBD54566975/ftl/internal/encryption"
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestNewCronJobsForModule(t *testing.T) {

// No async calls yet
_, _, err = parentDAL.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.IsError(t, err, libdal.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")

err = cjs.scheduleCronJobs(ctx)
Expand Down
18 changes: 9 additions & 9 deletions backend/controller/cronjobs/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ import (

"github.com/TBD54566975/ftl/backend/controller/cronjobs/sql"
"github.com/TBD54566975/ftl/backend/controller/observability"
"github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/slices"
)

type DAL struct {
*dal.Handle[DAL]
*libdal.Handle[DAL]
db sql.Querier
}

func New(conn dal.Connection) *DAL {
func New(conn libdal.Connection) *DAL {
return &DAL{
db: sql.New(conn),
Handle: dal.New(conn, func(h *dal.Handle[DAL]) *DAL {
Handle: libdal.New(conn, func(h *libdal.Handle[DAL]) *DAL {
return &DAL{Handle: h, db: sql.New(h.Connection)}
}),
}
Expand All @@ -45,7 +45,7 @@ func cronJobFromRow(c sql.CronJob, d sql.Deployment) model.CronJob {
func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallParams) (int64, error) {
id, err := d.db.CreateAsyncCall(ctx, params)
if err != nil {
return 0, fmt.Errorf("failed to create async call: %w", dal.TranslatePGError(err))
return 0, fmt.Errorf("failed to create async call: %w", libdal.TranslatePGError(err))
}
observability.AsyncCalls.Created(ctx, params.Verb, optional.None[schema.RefKey](), params.Origin, 0, err)
queueDepth, err := d.db.AsyncCallQueueDepth(ctx)
Expand All @@ -62,7 +62,7 @@ func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallPar
func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) ([]model.CronJob, error) {
rows, err := d.db.GetUnscheduledCronJobs(ctx, startTime)
if err != nil {
return nil, fmt.Errorf("failed to get cron jobs: %w", dal.TranslatePGError(err))
return nil, fmt.Errorf("failed to get cron jobs: %w", libdal.TranslatePGError(err))
}
return slices.Map(rows, func(r sql.GetUnscheduledCronJobsRow) model.CronJob {
return cronJobFromRow(r.CronJob, r.Deployment)
Expand All @@ -73,7 +73,7 @@ func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) (
func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model.CronJob, error) {
row, err := d.db.GetCronJobByKey(ctx, key)
if err != nil {
return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, dal.TranslatePGError(err))
return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, libdal.TranslatePGError(err))
}
return cronJobFromRow(row.CronJob, row.Deployment), nil
}
Expand All @@ -82,7 +82,7 @@ func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model.
func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startTime time.Time) (bool, error) {
pending, err := d.db.IsCronJobPending(ctx, key, startTime)
if err != nil {
return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, dal.TranslatePGError(err))
return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, libdal.TranslatePGError(err))
}
return pending, nil
}
Expand All @@ -92,7 +92,7 @@ func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startT
func (d *DAL) UpdateCronJobExecution(ctx context.Context, params sql.UpdateCronJobExecutionParams) error {
err := d.db.UpdateCronJobExecution(ctx, params)
if err != nil {
return fmt.Errorf("failed to update cron job %q: %w", params.Key, dal.TranslatePGError(err))
return fmt.Errorf("failed to update cron job %q: %w", params.Key, libdal.TranslatePGError(err))
}
return nil
}
2 changes: 1 addition & 1 deletion backend/controller/cronjobs/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions backend/controller/cronjobs/sql/types.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package sql

import csql "github.com/TBD54566975/ftl/backend/controller/sql"
import "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes"

type Type = csql.Type
type Type = sqltypes.Type
27 changes: 14 additions & 13 deletions backend/controller/dal/async_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
"github.com/alecthomas/types/either"
"github.com/alecthomas/types/optional"

leasedal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/sql"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltypes"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/model"
Expand Down Expand Up @@ -95,7 +96,7 @@ func ParseAsyncOrigin(origin string) (AsyncOrigin, error) {
}

type AsyncCall struct {
*Lease // May be nil
*leasedal.Lease // May be nil
ID int64
Origin AsyncOrigin
Verb schema.RefKey
Expand Down Expand Up @@ -127,9 +128,9 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx c
ttl := time.Second * 5
row, err := tx.db.AcquireAsyncCall(ctx, sqltypes.Duration(ttl))
if err != nil {
err = dalerrs.TranslatePGError(err)
if errors.Is(err, dalerrs.ErrNotFound) {
return nil, ctx, fmt.Errorf("no pending async calls: %w", dalerrs.ErrNotFound)
err = libdal.TranslatePGError(err)
if errors.Is(err, libdal.ErrNotFound) {
return nil, ctx, fmt.Errorf("no pending async calls: %w", libdal.ErrNotFound)
}
return nil, ctx, fmt.Errorf("failed to acquire async call: %w", err)
}
Expand All @@ -143,7 +144,7 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx c
return nil, ctx, fmt.Errorf("failed to decrypt async call request: %w", err)
}

lease, leaseCtx := d.newLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
lease, leaseCtx := d.leasedal.NewLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
return &AsyncCall{
ID: row.AsyncCallID,
Verb: row.Verb,
Expand Down Expand Up @@ -177,7 +178,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
case *dbsql.DB:
tx, err = d.Begin(ctx)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
defer tx.CommitOrRollback(ctx, &err)
case *dbsql.Tx:
Expand All @@ -197,7 +198,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
}
_, err = tx.db.SucceedAsyncCall(ctx, optional.Some(encryptedResult), call.ID)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}

case either.Right[[]byte, string]: // Failure message.
Expand All @@ -211,7 +212,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
ScheduledAt: time.Now().Add(call.Backoff),
})
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
isFinalResult = false
didScheduleAnotherCall = true
Expand All @@ -234,14 +235,14 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
OriginalError: optional.Some(originalError),
})
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
isFinalResult = false
didScheduleAnotherCall = true
} else {
_, err = tx.db.FailAsyncCall(ctx, result.Get(), call.ID)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
}
}
Expand All @@ -254,7 +255,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
row, err := d.db.LoadAsyncCall(ctx, id)
if err != nil {
return nil, dalerrs.TranslatePGError(err)
return nil, libdal.TranslatePGError(err)
}
origin, err := ParseAsyncOrigin(row.Origin)
if err != nil {
Expand All @@ -275,7 +276,7 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
func (d *DAL) GetZombieAsyncCalls(ctx context.Context, limit int) ([]*AsyncCall, error) {
rows, err := d.db.GetZombieAsyncCalls(ctx, int32(limit))
if err != nil {
return nil, dalerrs.TranslatePGError(err)
return nil, libdal.TranslatePGError(err)
}
var calls []*AsyncCall
for _, row := range rows {
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
Expand All @@ -21,7 +21,7 @@ func TestNoCallToAcquire(t *testing.T) {
assert.NoError(t, err)

_, _, err = dal.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.IsError(t, err, libdal.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")
}

Expand Down
Loading
Loading