From 010be751ccc1f795cb9975782df17430ef71def9 Mon Sep 17 00:00:00 2001 From: Brandur Date: Wed, 22 Apr 2026 21:40:57 -0500 Subject: [PATCH] Implement resumable jobs Here, implement "resumable" jobs, which are jobs that can checkpoint their progress so that in case they have to stop early, they're picked up from a point that lets them skip work that's already been done. This is especially useful for long running jobs that are at risk of being interrupted from something like a deploy. Here's roughly the shape of the API, with the same normal `Work` function that all jobs implement, and with a series of `ResumableStep` calls within, each of which take a name for the step and function representing it: func (w *ResumableWorker) Work(ctx context.Context, job *river.Job[ResumableArgs]) error { river.ResumableStep(ctx, "step1", func(ctx context.Context) error { fmt.Println("Step 1") return nil }) river.ResumableStep(ctx, "step2", func(ctx context.Context) error { fmt.Println("Step 2") return nil }) river.ResumableStep(ctx, "step3", func(ctx context.Context) error { fmt.Println("Step 3") return nil }) return nil } We also provide a cursor API for more granularity. This lets a step set an arbitrary cursor value periodically as it's doing something like looping over records in a set: river.ResumableStepCursor(ctx, "process_ids", func(ctx context.Context, cursor ResumableCursor) error { for _, id := range job.Args.IDs { if id <= cursor.LastProcessedID { continue } fmt.Printf("Processed %d\n", id) if err := river.ResumableSetCursor(ctx, ResumableCursor{LastProcessedID: id}); err != nil { return err } } return nil }) The function is `ResumableStepCursor[TCursor any]` where `TCursor` can be defined arbitrarily by the user. This could be a simple scalar value representing an ID, or a more complex `struct` value containing multiple IDs, enabling nested loops that set inner and outer IDs at the same time. `ResumableStep` and `ResumableStepCursor` steps can be freely intermingled, and multiple `ResumableStepCursor` steps with different cursor types are supported. Cursors must be JSON marshable because they're stored to a job's metadata. Lastly, we provide `ResumableSetStepTx` and `ResumableSetStepCursorTx` for cases where a transaction guarantee is necessary. Normally, resumable step and cursor are set as a job's being completed, but there's a chance this is never called in case of sudden failure. `ResumableSetStepTx` (and its cursor version) is available to durably persist a step at the cost of an extra database operation similar to how `JobCompleteTx` does the same for job completion. One neat aspect the implementation here is that I was able to make it entirely middleware-only. So all the resumable job logic goes in an internal `resumableMiddleware` that's included in all clients by default. This is kind of nice because it keeps its code highly modular and will hopefully act as a template for future features. --- CHANGELOG.md | 4 + client.go | 7 +- client_test.go | 114 +++++++++- example_resumable_cursor_job_test.go | 105 +++++++++ example_resumable_job_test.go | 96 ++++++++ example_resumable_set_step_tx_test.go | 126 +++++++++++ internal/maintenance/job_scheduler_test.go | 3 +- internal/rivercommon/river_common.go | 8 + job_list_params.go | 5 +- resumable.go | 205 +++++++++++++++++ resumable_step_tx.go | 112 +++++++++ resumable_step_tx_test.go | 153 +++++++++++++ resumable_test.go | 252 +++++++++++++++++++++ 13 files changed, 1182 insertions(+), 8 deletions(-) create mode 100644 example_resumable_cursor_job_test.go create mode 100644 example_resumable_job_test.go create mode 100644 example_resumable_set_step_tx_test.go create mode 100644 resumable.go create mode 100644 resumable_step_tx.go create mode 100644 resumable_step_tx_test.go create mode 100644 resumable_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 159020b0..095a0811 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added "resumable jobs" that can be broken down into multiple steps and with a step persisted after it finishes that lets them skip work that's already been done. This is particularly useful for long running jobs that may experience a cancellation (like in the event of a deploy) during the span of their run. [PR #1226](https://github.com/riverqueue/river/pull/1226). + ## [0.35.0] - 2026-04-18 ### Changed diff --git a/client.go b/client.go index b60addcd..923b98a1 100644 --- a/client.go +++ b/client.go @@ -780,7 +780,8 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client // the more abstract config.Middleware for middleware are set, but not both, // so in practice we never append all three of these to each other. { - middleware := config.Middleware + middleware := defaultMiddleware() + middleware = append(middleware, config.Middleware...) for _, jobInsertMiddleware := range config.JobInsertMiddleware { middleware = append(middleware, jobInsertMiddleware) } @@ -1002,6 +1003,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client return client, nil } +func defaultMiddleware() []rivertype.Middleware { + return []rivertype.Middleware{&resumableMiddleware{}} +} + // Start starts the client's job fetching and working loops. Once this is called, // the client will run in a background goroutine until stopped. All jobs are // run with a context inheriting from the provided context, but with a timeout diff --git a/client_test.go b/client_test.go index f8d035c4..4a6e98cf 100644 --- a/client_test.go +++ b/client_test.go @@ -82,6 +82,62 @@ func (w *periodicJobWorker) Work(ctx context.Context, job *Job[periodicJobArgs]) return nil } +type resumableClientTestArgs struct{} + +func (resumableClientTestArgs) Kind() string { return "resumable_client_test" } + +type resumableClientTestWorker struct { + WorkerDefaults[resumableClientTestArgs] + + calls []string + callsMu sync.Mutex + failedOnce atomic.Bool +} + +func (w *resumableClientTestWorker) Calls() []string { + w.callsMu.Lock() + defer w.callsMu.Unlock() + + return append([]string(nil), w.calls...) +} + +func (w *resumableClientTestWorker) Work(ctx context.Context, job *Job[resumableClientTestArgs]) error { + appendCall := func(call string) { + w.callsMu.Lock() + defer w.callsMu.Unlock() + + w.calls = append(w.calls, call) + } + + ResumableStep(ctx, "step1", func(ctx context.Context) error { + appendCall("step1") + return nil + }) + + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor int) error { + appendCall("step2:" + strconv.Itoa(cursor)) + + for itemID := cursor + 1; itemID <= 2; itemID++ { + appendCall("item:" + strconv.Itoa(itemID)) + if err := ResumableSetCursor(ctx, itemID); err != nil { + return err + } + if !w.failedOnce.Swap(true) { + return errors.New("retry me") + } + } + + return nil + }) + + ResumableStep(ctx, "step3", func(ctx context.Context) error { + appendCall("step3") + return nil + }) + + return nil +} + func makeAwaitWorker[T JobArgs](startedCh chan<- int64, doneCh chan struct{}) Worker[T] { return WorkFunc(func(ctx context.Context, job *Job[T]) error { client := ClientFromContext[pgx.Tx](ctx) @@ -6936,6 +6992,58 @@ func Test_Client_JobCompletion(t *testing.T) { require.Nil(t, reloadedJob.FinalizedAt) }) + t.Run("ResumableJobRetriesAndResumes", func(t *testing.T) { + t.Parallel() + + config := newTestConfig(t, "") + config.RetryPolicy = &retrypolicytest.RetryPolicyNoJitter{} + + worker := &resumableClientTestWorker{} + AddWorker(config.Workers, worker) + + client, bundle := setup(t, config) + + insertRes, err := client.Insert(ctx, resumableClientTestArgs{}, nil) + require.NoError(t, err) + + // Wait for the first attempt to fail after step2 checkpoints cursor + // progress and intentionally returns "retry me", leaving the job queued + // for retry. + eventFailed := riversharedtest.WaitOrTimeout(t, bundle.subscribeChan) + require.Equal(t, EventKindJobFailed, eventFailed.Kind) + require.Equal(t, insertRes.Job.ID, eventFailed.Job.ID) + + var retryableMetadata map[string]any + require.Contains(t, []rivertype.JobState{rivertype.JobStateAvailable, rivertype.JobStateRetryable}, eventFailed.Job.State) + require.NoError(t, json.Unmarshal(eventFailed.Job.Metadata, &retryableMetadata)) + require.Equal(t, "step1", retryableMetadata["river:resumable_step"]) + require.Equal(t, map[string]any{"step2": float64(1)}, retryableMetadata["river:resumable_cursor"]) + + // Wait for the retried attempt to resume and then complete successfully. + eventCompleted := riversharedtest.WaitOrTimeout(t, bundle.subscribeChan) + require.Equal(t, EventKindJobCompleted, eventCompleted.Kind) + require.Equal(t, insertRes.Job.ID, eventCompleted.Job.ID) + + reloadedJob, err := client.JobGet(ctx, insertRes.Job.ID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCompleted, reloadedJob.State) + require.Len(t, reloadedJob.Errors, 1) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step1", metadata["river:resumable_step"]) + require.Equal(t, map[string]any{"step2": float64(1)}, metadata["river:resumable_cursor"]) + + require.Equal(t, []string{ + "step1", + "step2:0", + "item:1", + "step2:1", + "item:2", + "step3", + }, worker.Calls()) + }) + t.Run("JobThatReturnsJobCancelErrorIsImmediatelyCancelled", func(t *testing.T) { t.Parallel() @@ -7602,7 +7710,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 1) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1) + require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) }, }, { @@ -7613,7 +7721,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 2) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) + require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 3) }, }, { @@ -7625,7 +7733,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 1) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1) + require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) }, }, { diff --git a/example_resumable_cursor_job_test.go b/example_resumable_cursor_job_test.go new file mode 100644 index 00000000..e49e9d49 --- /dev/null +++ b/example_resumable_cursor_job_test.go @@ -0,0 +1,105 @@ +package river_test + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableCursorArgs struct { + IDs []int `json:"ids"` +} + +func (ResumableCursorArgs) Kind() string { return "resumable_cursor" } + +type ResumableCursor struct { + LastProcessedID int `json:"last_processed_id"` +} + +type ResumableCursorWorker struct { + river.WorkerDefaults[ResumableCursorArgs] +} + +func (w *ResumableCursorWorker) Work(ctx context.Context, job *river.Job[ResumableCursorArgs]) error { + river.ResumableStepCursor(ctx, "process_ids", func(ctx context.Context, cursor ResumableCursor) error { + for _, id := range job.Args.IDs { + if id <= cursor.LastProcessedID { + continue + } + + fmt.Printf("Processed %d\n", id) + + if err := river.ResumableSetCursor(ctx, ResumableCursor{LastProcessedID: id}); err != nil { + return err + } + } + + return nil + }) + + return nil +} + +// Example_resumableCursor demonstrates the use of a resumable cursor step, a +// step that can store arbitrary JSON state to resume a partially completed loop. +func Example_resumableCursor() { //nolint:dupl + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableCursorWorker{}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + if _, err = riverClient.Insert(ctx, ResumableCursorArgs{ + IDs: []int{1, 2, 3}, + }, nil); err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. + riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Processed 1 + // Processed 2 + // Processed 3 +} diff --git a/example_resumable_job_test.go b/example_resumable_job_test.go new file mode 100644 index 00000000..698cd0d3 --- /dev/null +++ b/example_resumable_job_test.go @@ -0,0 +1,96 @@ +package river_test + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableArgs struct{} + +func (ResumableArgs) Kind() string { return "resumable" } + +type ResumableWorker struct { + river.WorkerDefaults[ResumableArgs] +} + +func (w *ResumableWorker) Work(ctx context.Context, job *river.Job[ResumableArgs]) error { + river.ResumableStep(ctx, "step1", func(ctx context.Context) error { + fmt.Println("Step 1") + return nil + }) + + river.ResumableStep(ctx, "step2", func(ctx context.Context) error { + fmt.Println("Step 2") + return nil + }) + + river.ResumableStep(ctx, "step3", func(ctx context.Context) error { + fmt.Println("Step 3") + return nil + }) + + return nil +} + +// Example_resumable demonstrates the use of a "resumable job", a job that has +// multiple steps, and which can be resumed after each one. +func Example_resumable() { //nolint:dupl + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableWorker{}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + if _, err = riverClient.Insert(ctx, ResumableArgs{}, nil); err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. + riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Step 1 + // Step 2 + // Step 3 +} diff --git a/example_resumable_set_step_tx_test.go b/example_resumable_set_step_tx_test.go new file mode 100644 index 00000000..3d5e9566 --- /dev/null +++ b/example_resumable_set_step_tx_test.go @@ -0,0 +1,126 @@ +package river_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableStepTxArgs struct{} + +func (ResumableStepTxArgs) Kind() string { return "resumable_step_tx" } + +// ResumableStepTxWorker persists resumable step progress transactionally before +// failing the job. +type ResumableStepTxWorker struct { + river.WorkerDefaults[ResumableStepTxArgs] + + dbPool *pgxpool.Pool +} + +func (w *ResumableStepTxWorker) Work(ctx context.Context, job *river.Job[ResumableStepTxArgs]) error { + const durableStep = "durable_step" + + river.ResumableStep(ctx, durableStep, func(ctx context.Context) error { + tx, err := w.dbPool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + // Perform some kind database work in a transaction. + var result int + if err := tx.QueryRow(ctx, "SELECT 1").Scan(&result); err != nil { + return err + } + + // Then, record the step as completed in the same transaction. + if _, err := river.ResumableSetStepTx[*riverpgxv5.Driver](ctx, tx, job, durableStep); err != nil { + return err + } + + if err := tx.Commit(ctx); err != nil { + return err + } + + return errors.New("simulated failure after persisting step") + }) + + return nil +} + +// Example_resumableSetStepTx demonstrates how to transactionally persist a +// resumable step so it survives a failed attempt. +func Example_resumableSetStepTx() { + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableStepTxWorker{dbPool: dbPool}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Used only to help the example test wait for the failed attempt. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobFailed) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + insertRes, err := riverClient.Insert(ctx, ResumableStepTxArgs{}, nil) + if err != nil { + panic(err) + } + + // Wait for the failed attempt so the persisted step can be inspected. + riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + jobAfter, err := riverClient.JobGet(ctx, insertRes.Job.ID) + if err != nil { + panic(err) + } + + var metadata map[string]any + if err := json.Unmarshal(jobAfter.Metadata, &metadata); err != nil { + panic(err) + } + + fmt.Printf("Persisted resumable step: %s\n", metadata[rivercommon.MetadataKeyResumableStep]) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Persisted resumable step: durable_step +} diff --git a/internal/maintenance/job_scheduler_test.go b/internal/maintenance/job_scheduler_test.go index 25388042..ecd9009e 100644 --- a/internal/maintenance/job_scheduler_test.go +++ b/internal/maintenance/job_scheduler_test.go @@ -330,7 +330,8 @@ func TestJobScheduler(t *testing.T) { addJob := func(queue string, fromNow time.Duration, state rivertype.JobState) { t.Helper() var finalizedAt *time.Time - switch state { //nolint:exhaustive + switch state { + case rivertype.JobStateAvailable, rivertype.JobStatePending, rivertype.JobStateRetryable, rivertype.JobStateRunning, rivertype.JobStateScheduled: case rivertype.JobStateCompleted, rivertype.JobStateCancelled, rivertype.JobStateDiscarded: finalizedAt = ptrutil.Ptr(now.Add(fromNow)) } diff --git a/internal/rivercommon/river_common.go b/internal/rivercommon/river_common.go index d409ccf1..986583ed 100644 --- a/internal/rivercommon/river_common.go +++ b/internal/rivercommon/river_common.go @@ -24,6 +24,14 @@ const ( // them. MetadataKeyPeriodicJobID = "river:periodic_job_id" + // MetadataKeyResumableStep records the last successfully completed step for + // a resumable job so later attempts can skip ahead. + MetadataKeyResumableStep = "river:resumable_step" + + // MetadataKeyResumableCursor records a resumable step cursor so a later + // attempt can resume a partially completed step. + MetadataKeyResumableCursor = "river:resumable_cursor" + // MetadataKeyRescueCount records how many times the job has been rescued. MetadataKeyRescueCount = "river:rescue_count" diff --git a/job_list_params.go b/job_list_params.go index 116a5c76..cebdf57b 100644 --- a/job_list_params.go +++ b/job_list_params.go @@ -234,11 +234,10 @@ func (p *JobListParams) toDBParams() (*dblist.JobListParams, error) { if p.sortField == JobListOrderByFinalizedAt { currentNonFinalizedStates := make([]rivertype.JobState, 0, len(p.states)) for _, state := range p.states { - //nolint:exhaustive switch state { - case rivertype.JobStateCancelled, rivertype.JobStateCompleted, rivertype.JobStateDiscarded: - default: + case rivertype.JobStateAvailable, rivertype.JobStatePending, rivertype.JobStateRetryable, rivertype.JobStateRunning, rivertype.JobStateScheduled: currentNonFinalizedStates = append(currentNonFinalizedStates, state) + case rivertype.JobStateCancelled, rivertype.JobStateCompleted, rivertype.JobStateDiscarded: } } // This indicates the user overrode the States list with only non-finalized diff --git a/resumable.go b/resumable.go new file mode 100644 index 00000000..6221e5cc --- /dev/null +++ b/resumable.go @@ -0,0 +1,205 @@ +package river + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/rivertype" +) + +var ( + errResumableStepNotInWorker = errors.New("river: resumable step can only be used within a Worker") + errResumableCursorNotInStep = errors.New("river: resumable cursor can only be used within ResumableStepCursor") +) + +type resumableContextKey struct{} + +type resumableState struct { + completedStep string + cursors map[string]json.RawMessage + err error + resumeMatched bool + resumeStep string + stepName string +} + +// ResumableSetCursor records a cursor for the current resumable cursor step. +// The cursor is stored only if the job attempt ends in an error, allowing a +// later retry to resume the same step from the recorded position. +// +// Alternatively, ResumableSetStepCursorTx is available to persist a step and +// cursor immediately as part of a transaction, guaranteeing that it's stored +// durably. +func ResumableSetCursor[TCursor any](ctx context.Context, cursor TCursor) error { + state := mustResumableState(ctx) + if state.stepName == "" { + return errResumableCursorNotInStep + } + + cursorBytes, err := json.Marshal(cursor) + if err != nil { + return err + } + + if state.cursors == nil { + state.cursors = make(map[string]json.RawMessage) + } + state.cursors[state.stepName] = json.RawMessage(cursorBytes) + return nil +} + +// ResumableStep runs a resumable step, skipping the step on a later retry if +// an earlier attempt already completed it successfully. +// +// After a step returns an error, no subsequent steps will be run and the +// overall job will be marked as failed with that error. Be careful to put all +// executable code in steps, because any code outside of them will be run, even +// if a step returned an error. +func ResumableStep(ctx context.Context, name string, stepFunc func(ctx context.Context) error) { + state := mustResumableState(ctx) + if state.err != nil { + return + } + + if !state.resumeMatched { + if name == state.resumeStep { + state.completedStep = name + state.resumeMatched = true + } + return + } + + if err := stepFunc(ctx); err != nil { + state.err = err + return + } + + state.completedStep = name +} + +// ResumableStepCursor runs a resumable step that also receives a persisted +// cursor value from an earlier failed attempt, if one was recorded with +// ResumableSetCursor. +// +// The cursor type T is user-specified. It may be a primitive value like an +// integer ID, or a more complex type like a struct with multiple fields. It's +// stored in a job's metadata, so it needs to be marshable and unmarshable to +// and from JSON. +// +// Notably, it's the responsibility of the step function to call +// ResumableSetCursor with an updated cursor value as progress is made, and to +// check the cursor value before running to determine where to resume from. +// +// After a step returns an error, no subsequent steps will be run and the +// overall job will be marked as failed with that error. Be careful to put all +// executable code in steps, because any code outside of them will be run, even +// if a step returned an error. +func ResumableStepCursor[TCursor any](ctx context.Context, name string, stepFunc func(ctx context.Context, cursor TCursor) error) { + state := mustResumableState(ctx) + if state.err != nil { + return + } + + if !state.resumeMatched { + if name == state.resumeStep { + state.completedStep = name + state.resumeMatched = true + } + return + } + + var cursor TCursor + if cursorBytes, ok := state.cursors[name]; ok && len(cursorBytes) > 0 { + if err := json.Unmarshal(cursorBytes, &cursor); err != nil { + state.err = fmt.Errorf("river: unmarshal resumable cursor for step %q: %w", name, err) + return + } + } + + previousStepName := state.stepName + state.stepName = name + defer func() { state.stepName = previousStepName }() + + if err := stepFunc(ctx, cursor); err != nil { + state.err = err + return + } + + state.completedStep = name + delete(state.cursors, name) +} + +type resumableMiddleware struct { + MiddlewareDefaults +} + +func (*resumableMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + metadataUpdates, hasMetadataUpdates := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + if !hasMetadataUpdates { + return errors.New("expected to find metadata updates in context, but didn't") + } + + state := &resumableState{ + cursors: make(map[string]json.RawMessage), + resumeMatched: true, + resumeStep: gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyResumableStep).Str, + } + if state.resumeStep != "" { + state.resumeMatched = false + } + if cursorJSON := gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyResumableCursor); cursorJSON.Exists() && cursorJSON.Type == gjson.JSON { + if err := json.Unmarshal([]byte(cursorJSON.Raw), &state.cursors); err != nil { + return fmt.Errorf("river: unmarshal resumable cursors: %w", err) + } + } + + ctx = context.WithValue(ctx, resumableContextKey{}, state) + + err := doInner(ctx) + if err == nil { + switch { + case state.err != nil: + err = state.err + case state.resumeStep != "" && !state.resumeMatched: + err = fmt.Errorf("river: resumable step %q not found in Worker", state.resumeStep) + } + } + + if err != nil && state.completedStep != "" { + if len(state.cursors) > 0 { + metadataUpdates[rivercommon.MetadataKeyResumableCursor] = state.cursors + } + metadataUpdates[rivercommon.MetadataKeyResumableStep] = state.completedStep + } + + return err +} + +func mustResumableState(ctx context.Context) *resumableState { + typedState, ok := resumableStateFromContext(ctx) + if !ok { + panic(errResumableStepNotInWorker) + } + + return typedState +} + +func resumableStateFromContext(ctx context.Context) (*resumableState, bool) { + state := ctx.Value(resumableContextKey{}) + if state == nil { + return nil, false + } + + typedState, ok := state.(*resumableState) + if !ok || typedState == nil { + return nil, false + } + + return typedState, true +} diff --git a/resumable_step_tx.go b/resumable_step_tx.go new file mode 100644 index 00000000..38ce9413 --- /dev/null +++ b/resumable_step_tx.go @@ -0,0 +1,112 @@ +package river + +import ( + "context" + "encoding/json" + "errors" + + "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivertype" +) + +// ResumableSetStepTx immediately persists a resumable job's completed step as +// part of transaction tx. If tx is rolled back, the step update will be as +// well. +// +// Normally, a resumable job's step progress is recorded after it runs along +// with its result status. This is normally sufficient, but because it happens +// out-of-transaction, there's a chance that it doesn't happen in case of panic +// or other abrupt termination. This function useful in cases where a resumable +// worker needs a guarantee of a checkpoint being recorded durably, at the cost +// of an extra database operation. +func ResumableSetStepTx[TDriver riverdriver.Driver[TTx], TTx any, TArgs JobArgs](ctx context.Context, tx TTx, job *Job[TArgs], step string) (*Job[TArgs], error) { + return resumableSetStepTx(ctx, tx, job, step, nil) +} + +// ResumableSetStepCursorTx immediately persists a resumable job's completed +// step and cursor as part of transaction tx. If tx is rolled back, the step +// and cursor update will be as well. +// +// Normally, a resumable job's step progress is recorded after it runs along +// with its result status. This is normally sufficient, but because it happens +// out-of-transaction, there's a chance that it doesn't happen in case of panic +// or other abrupt termination. This function useful in cases where a resumable +// worker needs a guarantee of a checkpoint being recorded durably, at the cost +// of an extra database operation. +func ResumableSetStepCursorTx[TDriver riverdriver.Driver[TTx], TTx any, TArgs JobArgs, TCursor any](ctx context.Context, tx TTx, job *Job[TArgs], step string, cursor TCursor) (*Job[TArgs], error) { + cursorBytes, err := json.Marshal(cursor) + if err != nil { + return nil, err + } + + return resumableSetStepTx(ctx, tx, job, step, json.RawMessage(cursorBytes)) +} + +func resumableSetStepTx[TTx any, TArgs JobArgs](ctx context.Context, tx TTx, job *Job[TArgs], step string, cursor json.RawMessage) (*Job[TArgs], error) { + if job.State != rivertype.JobStateRunning { + return nil, errors.New("job must be running") + } + + client := ClientFromContext[TTx](ctx) + if client == nil { + return nil, errors.New("client not found in context, can only work within a River worker") + } + + metadataUpdates := map[string]any{ + rivercommon.MetadataKeyResumableStep: step, + } + + if state, ok := resumableStateFromContext(ctx); ok { + state.completedStep = step + if cursor != nil { + if state.cursors == nil { + state.cursors = make(map[string]json.RawMessage) + } + state.cursors[step] = cursor + } + if len(state.cursors) > 0 { + metadataUpdates[rivercommon.MetadataKeyResumableCursor] = state.cursors + } + } else if cursor != nil { + metadataUpdates[rivercommon.MetadataKeyResumableCursor] = map[string]json.RawMessage{step: cursor} + } + + workMetadataUpdates, hasWorkMetadataUpdates := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + if hasWorkMetadataUpdates { + workMetadataUpdates[rivercommon.MetadataKeyResumableStep] = step + if resumableCursorMetadata, ok := metadataUpdates[rivercommon.MetadataKeyResumableCursor]; ok { + workMetadataUpdates[rivercommon.MetadataKeyResumableCursor] = resumableCursorMetadata + } + } + + metadataUpdatesBytes, err := json.Marshal(metadataUpdates) + if err != nil { + return nil, err + } + + updatedJob, err := client.Driver().UnwrapExecutor(tx).JobUpdate(ctx, &riverdriver.JobUpdateParams{ + ID: job.ID, + MetadataDoMerge: true, + Metadata: metadataUpdatesBytes, + Schema: client.config.Schema, + }) + if err != nil { + if errors.Is(err, rivertype.ErrNotFound) { + if _, isInsideTestWorker := ctx.Value(execution.ContextKeyInsideTestWorker{}).(bool); isInsideTestWorker { + panic("to use ResumableSetStepTx or ResumableSetStepCursorTx in a rivertest.Worker, the job must be inserted into the database first") + } + } + + return nil, err + } + + result := &Job[TArgs]{JobRow: updatedJob} + if err := json.Unmarshal(result.EncodedArgs, &result.Args); err != nil { + return nil, err + } + + return result, nil +} diff --git a/resumable_step_tx_test.go b/resumable_step_tx_test.go new file mode 100644 index 00000000..5372d8c8 --- /dev/null +++ b/resumable_step_tx_test.go @@ -0,0 +1,153 @@ +package river + +import ( + "context" + "encoding/json" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/testfactory" + "github.com/riverqueue/river/rivershared/util/ptrutil" + "github.com/riverqueue/river/rivershared/util/testutil" + "github.com/riverqueue/river/rivertype" +) + +func TestResumableSetStepTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type JobArgs struct { + testutil.JobArgsReflectKind[JobArgs] + } + + type testBundle struct { + client *Client[pgx.Tx] + exec riverdriver.Executor + tx pgx.Tx + } + + setup := func(ctx context.Context, t *testing.T) (context.Context, *testBundle) { + t.Helper() + + tx := riverdbtest.TestTxPgx(ctx, t) + client, err := NewClient(riverpgxv5.New(nil), &Config{ + Logger: riversharedtest.Logger(t), + }) + require.NoError(t, err) + ctx = context.WithValue(ctx, rivercommon.ContextKeyClient{}, client) + + return ctx, &testBundle{ + client: client, + exec: riverpgxv5.New(nil).UnwrapExecutor(tx), + tx: tx, + } + } + + t.Run("SetsStep", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t) + ctx = context.WithValue(ctx, jobexecutor.ContextKeyMetadataUpdates, make(map[string]any)) + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateRunning), + }) + + updatedJob, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, "step1") + require.NoError(t, err) + require.Equal(t, rivertype.JobStateRunning, updatedJob.State) + + reloadedJob, err := bundle.exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: job.ID}) + require.NoError(t, err) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step1", metadata[rivercommon.MetadataKeyResumableStep]) + }) + + t.Run("SetsStepAndCursor", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t) + ctx = context.WithValue(ctx, jobexecutor.ContextKeyMetadataUpdates, make(map[string]any)) + ctx = context.WithValue(ctx, resumableContextKey{}, &resumableState{cursors: make(map[string]json.RawMessage)}) + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateRunning), + }) + + type Cursor struct { + ID int `json:"id"` + } + + updatedJob, err := ResumableSetStepCursorTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, "step2", Cursor{ID: 123}) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateRunning, updatedJob.State) + + reloadedJob, err := bundle.exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: job.ID}) + require.NoError(t, err) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step2", metadata[rivercommon.MetadataKeyResumableStep]) + require.Equal(t, map[string]any{"step2": map[string]any{"id": float64(123)}}, metadata[rivercommon.MetadataKeyResumableCursor]) + + metadataUpdates, ok := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + require.True(t, ok) + require.Equal(t, "step2", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + }) + + t.Run("ErrorIfNotRunning", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t) + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) + + _, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, "step1") + require.EqualError(t, err, "job must be running") + }) + + t.Run("ErrorIfJobDoesntExist", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t) + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateAvailable), + }) + _, err := bundle.exec.JobDelete(ctx, &riverdriver.JobDeleteParams{ID: job.ID}) + require.NoError(t, err) + + job.State = rivertype.JobStateRunning + _, err = ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, "step1") + require.ErrorIs(t, err, rivertype.ErrNotFound) + }) + + t.Run("PanicsIfCalledInTestWorkerWithoutInsertingJob", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t) + ctx = context.WithValue(ctx, execution.ContextKeyInsideTestWorker{}, true) + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) + _, err := bundle.client.JobDeleteTx(ctx, bundle.tx, job.ID) + require.NoError(t, err) + job.State = rivertype.JobStateRunning + + require.PanicsWithValue(t, "to use ResumableSetStepTx or ResumableSetStepCursorTx in a rivertest.Worker, the job must be inserted into the database first", func() { + _, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, "step1") + require.NoError(t, err) + }) + }) +} diff --git a/resumable_test.go b/resumable_test.go new file mode 100644 index 00000000..00169616 --- /dev/null +++ b/resumable_test.go @@ -0,0 +1,252 @@ +package river + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/rivertype" +) + +func TestResumableStep(t *testing.T) { + t.Parallel() + + setup := func(t *testing.T, metadata string) (context.Context, map[string]any, *rivertype.JobRow) { + t.Helper() + + metadataUpdates := make(map[string]any) + ctx := context.WithValue(context.Background(), jobexecutor.ContextKeyMetadataUpdates, metadataUpdates) + + return ctx, metadataUpdates, &rivertype.JobRow{Metadata: []byte(metadata)} + } + + t.Run("PanicsOutsideWorker", func(t *testing.T) { + t.Parallel() + + require.PanicsWithError(t, errResumableStepNotInWorker.Error(), func() { + ResumableStep(context.Background(), "step1", func(ctx context.Context) error { return nil }) + }) + }) + + t.Run("ResumesFromLastCompletedStep", func(t *testing.T) { + t.Parallel() + + ctx, metadataUpdates, job := setup(t, `{}`) + + var ran []string + err := (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, "step1") + return nil + }) + ResumableStep(ctx, "step2", func(ctx context.Context) error { + ran = append(ran, "step2") + return errors.New("step2 failed") + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, "step3") + return nil + }) + + return nil + }) + require.EqualError(t, err, "step2 failed") + require.Equal(t, []string{"step1", "step2"}, ran) + require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + + ctx, metadataUpdates, job = setup(t, `{"river:resumable_step":"step1"}`) + ran = nil + + err = (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, "step1") + return nil + }) + ResumableStep(ctx, "step2", func(ctx context.Context) error { + ran = append(ran, "step2") + return nil + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, "step3") + return nil + }) + + return nil + }) + require.NoError(t, err) + require.Equal(t, []string{"step2", "step3"}, ran) + require.Empty(t, metadataUpdates) + }) + + t.Run("SavesLastCompletedStepOnContextCancellation", func(t *testing.T) { + t.Parallel() + + baseCtx, metadataUpdates, job := setup(t, `{}`) + ctx, cancel := context.WithCancel(baseCtx) + defer cancel() + + var ran []string + err := (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, "step1") + cancel() + return nil + }) + ResumableStep(ctx, "step2", func(ctx context.Context) error { + ran = append(ran, "step2") + return ctx.Err() + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, "step3") + return nil + }) + + return nil + }) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, []string{"step1", "step2"}, ran) + require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + }) +} + +func TestResumableStepCursor(t *testing.T) { + t.Parallel() + + type resumableCursor struct { + ID int `json:"id"` + } + + setup := func(t *testing.T, metadata string) (context.Context, map[string]any, *rivertype.JobRow) { + t.Helper() + + metadataUpdates := make(map[string]any) + ctx := context.WithValue(context.Background(), jobexecutor.ContextKeyMetadataUpdates, metadataUpdates) + + return ctx, metadataUpdates, &rivertype.JobRow{Metadata: []byte(metadata)} + } + + t.Run("ResumesCursor", func(t *testing.T) { + t.Parallel() + + ctx, metadataUpdates, job := setup(t, `{}`) + + var ( + cursorResult resumableCursor + ran []int + ) + err := (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, 1) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor resumableCursor) error { + cursorResult = cursor + ran = append(ran, 2) + require.NoError(t, ResumableSetCursor(ctx, resumableCursor{ID: 42})) + return errors.New("step2 failed") + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, 3) + return nil + }) + + return nil + }) + require.EqualError(t, err, "step2 failed") + require.Equal(t, resumableCursor{}, cursorResult) + require.Equal(t, []int{1, 2}, ran) + require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + cursorMetadata, err := json.Marshal(metadataUpdates[rivercommon.MetadataKeyResumableCursor]) + require.NoError(t, err) + require.JSONEq(t, `{"step2":{"id":42}}`, string(cursorMetadata)) + + ctx, metadataUpdates, job = setup(t, `{"river:resumable_cursor":{"step2":{"id":42}},"river:resumable_step":"step1"}`) + cursorResult = resumableCursor{} + ran = nil + + err = (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, 1) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor resumableCursor) error { + cursorResult = cursor + ran = append(ran, 2) + return nil + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, 3) + return nil + }) + + return nil + }) + require.NoError(t, err) + require.Equal(t, resumableCursor{ID: 42}, cursorResult) + require.Equal(t, []int{2, 3}, ran) + require.Empty(t, metadataUpdates) + }) + + t.Run("SetCursorOutsideStep", func(t *testing.T) { + t.Parallel() + + ctx, _, _ := setup(t, `{}`) + + err := (&resumableMiddleware{}).Work(ctx, &rivertype.JobRow{Metadata: []byte(`{}`)}, func(ctx context.Context) error { + return ResumableSetCursor(ctx, 1) + }) + require.ErrorIs(t, err, errResumableCursorNotInStep) + }) + + t.Run("SupportsMultipleCursorStepsWithDifferentTypes", func(t *testing.T) { + t.Parallel() + + type secondCursor struct { + ID string `json:"id"` + } + + ctx, metadataUpdates, job := setup(t, `{}`) + + err := (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStepCursor(ctx, "step1", func(ctx context.Context, cursor int) error { + require.Zero(t, cursor) + require.NoError(t, ResumableSetCursor(ctx, 123)) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor secondCursor) error { + require.Equal(t, secondCursor{}, cursor) + require.NoError(t, ResumableSetCursor(ctx, secondCursor{ID: "abc"})) + return errors.New("step2 failed") + }) + + return nil + }) + require.EqualError(t, err, "step2 failed") + require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + cursorMetadata, err := json.Marshal(metadataUpdates[rivercommon.MetadataKeyResumableCursor]) + require.NoError(t, err) + require.JSONEq(t, `{"step2":{"id":"abc"}}`, string(cursorMetadata)) + + ctx, metadataUpdates, job = setup(t, `{"river:resumable_cursor":{"step1":123,"step2":{"id":"abc"}},"river:resumable_step":"step1"}`) + + err = (&resumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStepCursor(ctx, "step1", func(ctx context.Context, cursor int) error { + require.Equal(t, 123, cursor) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor secondCursor) error { + require.Equal(t, secondCursor{ID: "abc"}, cursor) + return nil + }) + + return nil + }) + require.NoError(t, err) + require.Empty(t, metadataUpdates) + }) +}