From ec9be0480beccbf76ca9fce29f88f01651b961e6 Mon Sep 17 00:00:00 2001 From: Brandur Date: Mon, 20 Apr 2026 23:42:08 -0500 Subject: [PATCH] Add MySQL support to River Here, add a new driver `rivermysql` that brings MySQL support to River. Similar to SQLite, it's unfortunately not quite as good as the Postgres driver, but it does the job. MySQL has more facilities than SQLite, but it's still missing some major niceties like `RETURNING`, which for many queries requires us to write an implementation using two operations -- one that performs the action, and then another that loads back the result using returned IDs. It also doesn't have listen/notify. Luckily, _like_ SQLite, any of this nastiness stays cordoned to the driver layer and doesn't leak into the mainline River code. We've got good driver testing and basic tests on clients for each driver in place so we get reasonable assurance that everything works. In the before times, I would've been fairly concerned at the additional maintenance burden that supporting another database would bring, but with the rise of LLMs I think it's more plausible that we can bring something like this in without much trouble. Incredibly, I was able to get almost all of this implemented in just one evening whereby my SQLite driver took me multiple months. I don't want to bring MySQL in as a hard dependency, so I've made the MySQL tests disabled by default. Use `RIVER_MYSQL_TESTS_ENABLED=true` to activate them. Only one CI matrix case runs MySQL tests so that we don't have to repeat them to exhaustion on every version of Go and Postgres. See also: https://github.com/riverqueue/river/discussions/158 --- .github/workflows/ci.yaml | 16 + Makefile | 2 + client.go | 15 +- go.work | 3 + job_list_params.go | 8 +- riverdbtest/riverdbtest.go | 3 +- riverdriver/river_driver_interface.go | 7 + .../river_database_sql_driver.go | 5 +- .../riverdrivertest/driver_client_test.go | 41 +- riverdriver/riverdrivertest/driver_test.go | 39 + riverdriver/riverdrivertest/executor_tx.go | 13 +- riverdriver/riverdrivertest/go.mod | 3 + riverdriver/riverdrivertest/go.sum | 4 + riverdriver/riverdrivertest/job_delete.go | 4 +- riverdriver/riverdrivertest/job_insert.go | 4 +- .../riverdrivertest/riverdrivertest.go | 25 +- riverdriver/rivermysql/example_mysql_test.go | 125 ++ riverdriver/rivermysql/go.mod | 32 + riverdriver/rivermysql/go.sum | 65 + riverdriver/rivermysql/internal/dbsqlc/db.go | 24 + .../rivermysql/internal/dbsqlc/models.go | 82 + .../internal/dbsqlc/river_client.sql | 7 + .../internal/dbsqlc/river_client_queue.sql | 12 + .../rivermysql/internal/dbsqlc/river_job.sql | 420 ++++ .../internal/dbsqlc/river_job.sql.go | 1289 +++++++++++ .../internal/dbsqlc/river_leader.sql | 50 + .../internal/dbsqlc/river_leader.sql.go | 147 ++ .../internal/dbsqlc/river_migration.sql | 77 + .../internal/dbsqlc/river_migration.sql.go | 375 ++++ .../internal/dbsqlc/river_queue.sql | 91 + .../internal/dbsqlc/river_queue.sql.go | 298 +++ .../rivermysql/internal/dbsqlc/schema.sql | 43 + .../rivermysql/internal/dbsqlc/schema.sql.go | 120 + .../rivermysql/internal/dbsqlc/sqlc.yaml | 41 + .../main/001_create_river_migration.down.sql | 1 + .../main/001_create_river_migration.up.sql | 8 + .../main/002_initial_schema.down.sql | 2 + .../migration/main/002_initial_schema.up.sql | 43 + .../main/003_river_job_tags_non_null.down.sql | 1 + .../main/003_river_job_tags_non_null.up.sql | 2 + .../main/004_pending_and_more.down.sql | 21 + .../main/004_pending_and_more.up.sql | 48 + .../main/005_migration_unique_client.down.sql | 35 + .../main/005_migration_unique_client.up.sql | 60 + .../migration/main/006_bulk_unique.down.sql | 9 + .../migration/main/006_bulk_unique.up.sql | 45 + riverdriver/rivermysql/river_mysql_driver.go | 1962 +++++++++++++++++ .../rivermysql/river_mysql_driver_test.go | 621 ++++++ riverdriver/riverpgxv5/river_pgx_v5_driver.go | 5 +- .../riversqlite/river_sqlite_driver.go | 5 +- rivermigrate/river_migrate.go | 2 +- .../riversharedtest/riversharedtest.go | 38 + rivershared/sqlctemplate/sqlc_template.go | 101 +- .../sqlctemplate/sqlc_template_test.go | 270 ++- 54 files changed, 6683 insertions(+), 86 deletions(-) create mode 100644 riverdriver/rivermysql/example_mysql_test.go create mode 100644 riverdriver/rivermysql/go.mod create mode 100644 riverdriver/rivermysql/go.sum create mode 100644 riverdriver/rivermysql/internal/dbsqlc/db.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/models.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_client.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_client_queue.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_job.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_job.sql.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_leader.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_migration.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_queue.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/schema.sql create mode 100644 riverdriver/rivermysql/internal/dbsqlc/schema.sql.go create mode 100644 riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml create mode 100644 riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql create mode 100644 riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql create mode 100644 riverdriver/rivermysql/migration/main/002_initial_schema.down.sql create mode 100644 riverdriver/rivermysql/migration/main/002_initial_schema.up.sql create mode 100644 riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql create mode 100644 riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql create mode 100644 riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql create mode 100644 riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql create mode 100644 riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql create mode 100644 riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql create mode 100644 riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql create mode 100644 riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql create mode 100644 riverdriver/rivermysql/river_mysql_driver.go create mode 100644 riverdriver/rivermysql/river_mysql_driver_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c6cf46f2..a4d1f2f9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -91,6 +91,12 @@ jobs: # latest Postgres version: - go-version: "1.25" postgres-version: 18 + + # Run with MySQL enabled on the latest Go + Postgres to exercise the + # MySQL driver in CI. MySQL tests are opt-in via RIVER_MYSQL_TESTS_ENABLED. + - go-version: "1.26" + postgres-version: 18 + mysql-enabled: true fail-fast: false timeout-minutes: 5 @@ -128,7 +134,17 @@ jobs: - name: Set up database run: psql -c "CREATE DATABASE river_test" $ADMIN_DATABASE_URL + # MySQL is pre-installed on GitHub-hosted Ubuntu runners. Start the + # service and configure passwordless root access for MySQL-enabled runs. + - name: Start MySQL + if: matrix.mysql-enabled + run: | + sudo systemctl start mysql + mysql -u root -proot -e "ALTER USER 'root'@'localhost' IDENTIFIED BY ''; FLUSH PRIVILEGES;" + - name: Test + env: + RIVER_MYSQL_TESTS_ENABLED: ${{ matrix.mysql-enabled && 'true' || '' }} run: make test/race cli: diff --git a/Makefile b/Makefile index ce82cd4f..bbe9b86a 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,7 @@ generate/migrations: ## Sync changes of pgxv5 migrations to database/sql .PHONY: generate/sqlc generate/sqlc: ## Generate sqlc cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc generate + cd riverdriver/rivermysql/internal/dbsqlc && sqlc generate cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc generate cd riverdriver/riversqlite/internal/dbsqlc && sqlc generate @@ -101,5 +102,6 @@ verify/migrations: ## Verify synced migrations .PHONY: verify/sqlc verify/sqlc: ## Verify generated sqlc cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc diff + cd riverdriver/rivermysql/internal/dbsqlc && sqlc diff cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc diff cd riverdriver/riversqlite/internal/dbsqlc && sqlc diff diff --git a/client.go b/client.go index b60addcd..f0da987a 100644 --- a/client.go +++ b/client.go @@ -2281,9 +2281,12 @@ type JobListResult struct { LastCursor *JobListCursor } -const databaseNameSQLite = "sqlite" +const ( + databaseNameMySQL = "mysql" + databaseNameSQLite = "sqlite" +) -var errJobListParamsMetadataNotSupportedSQLite = errors.New("JobListParams.Metadata is not supported on SQLite") +var errJobListParamsMetadataNotSupported = errors.New("JobListParams.Metadata is not supported on MySQL or SQLite") // JobList returns a paginated list of jobs matching the provided filters. The // provided context is used for the underlying Postgres query and can be used to @@ -2304,8 +2307,8 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { - return nil, errJobListParamsMetadataNotSupportedSQLite + if (c.driver.DatabaseName() == databaseNameMySQL || c.driver.DatabaseName() == databaseNameSQLite) && params.metadataCalled { + return nil, errJobListParamsMetadataNotSupported } dbParams, err := params.toDBParams() @@ -2345,8 +2348,8 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { - return nil, errJobListParamsMetadataNotSupportedSQLite + if (c.driver.DatabaseName() == databaseNameMySQL || c.driver.DatabaseName() == databaseNameSQLite) && params.metadataCalled { + return nil, errJobListParamsMetadataNotSupported } dbParams, err := params.toDBParams() diff --git a/go.work b/go.work index 9a9927fe..0e6def6b 100644 --- a/go.work +++ b/go.work @@ -9,7 +9,10 @@ use ( ./riverdriver/riverdatabasesql ./riverdriver/riverdrivertest ./riverdriver/riverpgxv5 + ./riverdriver/rivermysql ./riverdriver/riversqlite ./rivershared ./rivertype ) + +replace github.com/riverqueue/river/riverdriver/rivermysql v0.35.0 => ./riverdriver/rivermysql diff --git a/job_list_params.go b/job_list_params.go index 116a5c76..2e3b8200 100644 --- a/job_list_params.go +++ b/job_list_params.go @@ -347,10 +347,10 @@ func (p *JobListParams) Kinds(kinds ...string) *JobListParams { // // https://www.postgresql.org/docs/current/functions-json.html // -// This function isn't supported in SQLite due to SQLite not having an -// equivalent operator to use, so there's no efficient way to implement it. We -// recommend the use of Where using a condition with a comparison on the `->>` -// operator instead. +// This function isn't supported in MySQL or SQLite due to neither having a +// direct equivalent to Postgres's `@>` operator. We recommend the use of +// [JobListParams.Where] with a database-specific JSON comparison instead (e.g. +// `JSON_EXTRACT` for MySQL, `->>` for SQLite). func (p *JobListParams) Metadata(json string) *JobListParams { paramsCopy := p.copy() paramsCopy.metadataCalled = true diff --git a/riverdbtest/riverdbtest.go b/riverdbtest/riverdbtest.go index eb1e6df1..544f48b7 100644 --- a/riverdbtest/riverdbtest.go +++ b/riverdbtest/riverdbtest.go @@ -471,7 +471,8 @@ type TestTxOpts struct { // run using TestSchema. This is meant for environments where parallelism // doesn't work as well, like SQLite, which will emit "busy" errors when // multiple clients try to share a schema, even when they're in separate - // transactions. + // transactions. Also applies to MySQL, where InnoDB deadlocks are common + // when multiple transactions are sharing a database. DisableSchemaSharing bool // IsTestTxHelper should be set to true for if TestTx is being called from diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 8727a6e3..38b412a5 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -123,6 +123,13 @@ type Driver[TTx any] interface { // API is not stable. DO NOT USE. PoolSet(dbPool any) error + // SafeIdentifier returns a safely quoted identifier (e.g. a table or + // schema name) for use in SQL queries. Each driver quotes using its + // native syntax: double quotes for Postgres/SQLite, backticks for MySQL. + // + // API is not stable. DO NOT USE. + SafeIdentifier(ident string) string + // SQLFragmentColumnIn generates an SQL fragment to be included as a // predicate in a `WHERE` query for the existence of a set of values in a // column like `id IN (...)`. The actual implementation depends on support diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index d8a2d445..5adde0f5 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -53,8 +53,9 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "$" -func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d} diff --git a/riverdriver/riverdrivertest/driver_client_test.go b/riverdriver/riverdrivertest/driver_client_test.go index 7a597e65..8702a4e8 100644 --- a/riverdriver/riverdrivertest/driver_client_test.go +++ b/riverdriver/riverdrivertest/driver_client_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "github.com/lib/pq" @@ -18,6 +19,7 @@ import ( "github.com/riverqueue/river/riverdbtest" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" + "github.com/riverqueue/river/riverdriver/rivermysql" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -109,6 +111,26 @@ func TestClientWithDriverRiverLibSQL(t *testing.T) { ) } +func TestClientWithDriverRiverMySQL(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = context.Background() + dbPool = riversharedtest.DBPoolMySQL(ctx, t) + driver = rivermysql.New(dbPool) + ) + + ExerciseClient(ctx, t, + func(ctx context.Context, t *testing.T) (riverdriver.Driver[*sql.Tx], string) { + t.Helper() + + return driver, riverdbtest.TestSchema(ctx, t, driver, nil) + }, + ) +} + func TestClientWithDriverRiverSQLiteModernC(t *testing.T) { t.Parallel() @@ -519,9 +541,9 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobList(ctx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Ignoring unsupported JobListResult.Metadata on SQLite") - require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") + if bundle.driver.DatabaseName() == databaseNameMySQL || bundle.driver.DatabaseName() == databaseNameSQLite { + t.Logf("Ignoring unsupported JobListResult.Metadata on %s", bundle.driver.DatabaseName()) + require.EqualError(t, err, "JobListParams.Metadata is not supported on MySQL or SQLite") return } require.NoError(t, err) @@ -583,9 +605,9 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobListTx(ctx, tx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Ignoring unsupported JobListTxResult.Metadata on SQLite") - require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") + if bundle.driver.DatabaseName() == databaseNameMySQL || bundle.driver.DatabaseName() == databaseNameSQLite { + t.Logf("Ignoring unsupported JobListTxResult.Metadata on %s", bundle.driver.DatabaseName()) + require.EqualError(t, err, "JobListParams.Metadata is not supported on MySQL or SQLite") return } require.NoError(t, err) @@ -607,9 +629,12 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, listParams := river.NewJobListParams() - if bundle.driver.DatabaseName() == databaseNameSQLite { + switch bundle.driver.DatabaseName() { + case databaseNameSQLite: listParams = listParams.Where("metadata ->> @json_path = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": "bar"}) - } else { + case databaseNameMySQL: + listParams = listParams.Where("JSON_UNQUOTE(JSON_EXTRACT(metadata, @json_path)) = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": "bar"}) + default: // "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": `"bar"`}) } diff --git a/riverdriver/riverdrivertest/driver_test.go b/riverdriver/riverdrivertest/driver_test.go index 50d30675..d95ed3d7 100644 --- a/riverdriver/riverdrivertest/driver_test.go +++ b/riverdriver/riverdrivertest/driver_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" @@ -21,6 +22,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverdrivertest" + "github.com/riverqueue/river/riverdriver/rivermysql" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -203,6 +205,43 @@ func TestDriverRiverLiteLibSQL(t *testing.T) { //nolint:dupl }) } +func TestDriverRiverMySQL(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = context.Background() + dbPool = riversharedtest.DBPoolMySQL(ctx, t) + driver = rivermysql.New(dbPool) + ) + + riverdrivertest.Exercise(ctx, t, + func(ctx context.Context, t *testing.T, opts *riverdbtest.TestSchemaOpts) (riverdriver.Driver[*sql.Tx], string) { + t.Helper() + + return driver, riverdbtest.TestSchema(ctx, t, driver, opts) + }, + func(ctx context.Context, t *testing.T) (riverdriver.Executor, riverdriver.Driver[*sql.Tx]) { + t.Helper() + + tx, schema := riverdbtest.TestTx(ctx, t, driver, &riverdbtest.TestTxOpts{ + // Disable schema sharing to reduce InnoDB deadlocks from + // parallel tests contending on the same database. + DisableSchemaSharing: true, + }) + + // MySQL has no search_path equivalent, so USE the test schema + // database so that unqualified queries resolve correctly. + if schema != "" { + _, err := tx.ExecContext(ctx, "USE "+schema) + require.NoError(t, err) + } + + return driver.UnwrapExecutor(tx), driver + }) +} + func TestDriverRiverSQLiteModernC(t *testing.T) { //nolint:dupl t.Parallel() diff --git a/riverdriver/riverdrivertest/executor_tx.go b/riverdriver/riverdrivertest/executor_tx.go index 451c65c2..e4629ef7 100644 --- a/riverdriver/riverdrivertest/executor_tx.go +++ b/riverdriver/riverdrivertest/executor_tx.go @@ -157,7 +157,13 @@ func exerciseExecutorTx[TTx any](ctx context.Context, t *testing.T, exec := setup(ctx, t) - require.NoError(t, exec.Exec(ctx, "SELECT $1 || $2", "foo", "bar")) + _, driver := executorWithTx(ctx, t) + switch driver.DatabaseName() { + case databaseNameMySQL: + require.NoError(t, exec.Exec(ctx, "SELECT CONCAT(?, ?)", "foo", "bar")) + default: + require.NoError(t, exec.Exec(ctx, "SELECT $1 || $2", "foo", "bar")) + } }) }) @@ -166,9 +172,8 @@ func exerciseExecutorTx[TTx any](ctx context.Context, t *testing.T, { driver, _ := driverWithSchema(ctx, t, nil) - if driver.DatabaseName() == databaseNameSQLite { - t.Logf("Skipping PGAdvisoryXactLock test for SQLite") - return + if driver.DatabaseName() == databaseNameSQLite || driver.DatabaseName() == databaseNameMySQL { + t.Skipf("Skipping PGAdvisoryXactLock test for %s", driver.DatabaseName()) } } diff --git a/riverdriver/riverdrivertest/go.mod b/riverdriver/riverdrivertest/go.mod index 277ba8e3..d5eaaf29 100644 --- a/riverdriver/riverdrivertest/go.mod +++ b/riverdriver/riverdrivertest/go.mod @@ -11,7 +11,9 @@ require ( github.com/lib/pq v1.12.3 github.com/riverqueue/river v0.35.0 github.com/riverqueue/river/riverdriver v0.35.0 + github.com/go-sql-driver/mysql v1.9.3 github.com/riverqueue/river/riverdriver/riverdatabasesql v0.35.0 + github.com/riverqueue/river/riverdriver/rivermysql v0.35.0 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0 github.com/riverqueue/river/riverdriver/riversqlite v0.35.0 github.com/riverqueue/river/rivershared v0.35.0 @@ -27,6 +29,7 @@ require ( require ( github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/coder/websocket v1.8.12 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/riverdriver/riverdrivertest/go.sum b/riverdriver/riverdrivertest/go.sum index 3020f5ec..7275b162 100644 --- a/riverdriver/riverdrivertest/go.sum +++ b/riverdriver/riverdrivertest/go.sum @@ -1,8 +1,12 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= diff --git a/riverdriver/riverdrivertest/job_delete.go b/riverdriver/riverdrivertest/job_delete.go index 787a5d09..e61331d3 100644 --- a/riverdriver/riverdrivertest/job_delete.go +++ b/riverdriver/riverdrivertest/job_delete.go @@ -261,8 +261,8 @@ func exerciseJobDelete[TTx any](ctx context.Context, t *testing.T, executorWithT // since we only expect to need `queues_excluded` on SQLite (and not // `queues_included` for the foreseeable future), I've just set // SQLite to not support `queues_included` for the time being. - if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Skipping JobDeleteBefore with QueuesIncluded test for SQLite") + if bundle.driver.DatabaseName() == databaseNameSQLite || bundle.driver.DatabaseName() == databaseNameMySQL { + t.Skipf("Skipping JobDeleteBefore with QueuesIncluded test for %s", bundle.driver.DatabaseName()) return } diff --git a/riverdriver/riverdrivertest/job_insert.go b/riverdriver/riverdrivertest/job_insert.go index bce3e8b0..08d5156f 100644 --- a/riverdriver/riverdrivertest/job_insert.go +++ b/riverdriver/riverdrivertest/job_insert.go @@ -699,7 +699,7 @@ func exerciseJobInsert[TTx any](ctx context.Context, t *testing.T, _, err := exec.JobInsertFull(ctx, params) require.Error(t, err) // two separate error messages here for Postgres and SQLite - require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null")`, err.Error()) + require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null"|Check constraint 'finalized_or_finalized_at_null' is violated)`, err.Error()) }) t.Run(fmt.Sprintf("CanSetState%sWithFinalizedAt", capitalizeJobState(state)), func(t *testing.T) { @@ -749,7 +749,7 @@ func exerciseJobInsert[TTx any](ctx context.Context, t *testing.T, })) require.Error(t, err) // two separate error messages here for Postgres and SQLite - require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null")`, err.Error()) + require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null"|Check constraint 'finalized_or_finalized_at_null' is violated)`, err.Error()) }) } }) diff --git a/riverdriver/riverdrivertest/riverdrivertest.go b/riverdriver/riverdrivertest/riverdrivertest.go index c51093e9..25d07d03 100644 --- a/riverdriver/riverdrivertest/riverdrivertest.go +++ b/riverdriver/riverdrivertest/riverdrivertest.go @@ -46,6 +46,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, } const ( + databaseNameMySQL = "mysql" databaseNamePostgres = "postgres" databaseNameSQLite = "sqlite" testClientID = "test-client-id" @@ -83,6 +84,25 @@ func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, }) }) + t.Run("SafeIdentifier", func(t *testing.T) { + t.Parallel() + + _, driver := executorWithTx(ctx, t) + + switch driver.DatabaseName() { + case databaseNamePostgres, databaseNameSQLite: + require.Equal(t, `"my_schema"`, driver.SafeIdentifier("my_schema")) + require.Equal(t, `"has space"`, driver.SafeIdentifier("has space")) + require.Equal(t, `"has""quote"`, driver.SafeIdentifier(`has"quote`)) + case databaseNameMySQL: + require.Equal(t, "`my_schema`", driver.SafeIdentifier("my_schema")) + require.Equal(t, "`has space`", driver.SafeIdentifier("has space")) + require.Equal(t, "`has``backtick`", driver.SafeIdentifier("has`backtick")) + default: + require.FailNow(t, "Don't know how to check SafeIdentifier for: "+driver.DatabaseName()) + } + }) + t.Run("SupportsListenNotify", func(t *testing.T) { t.Parallel() @@ -91,7 +111,7 @@ func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, switch driver.DatabaseName() { case databaseNamePostgres: require.True(t, driver.SupportsListenNotify()) - case databaseNameSQLite: + case databaseNameMySQL, databaseNameSQLite: require.False(t, driver.SupportsListenNotify()) default: require.FailNow(t, "Don't know how to check SupportsListenNotify for: "+driver.DatabaseName()) @@ -109,6 +129,7 @@ func requireMissingRelation(t *testing.T, err error, schema, missingRelation str } else { // lib/pq: pq: relation %s.%s does not exist // SQLite: no such table: %s.%s - require.Regexp(t, fmt.Sprintf(`(pq: relation "%s\.%s" does not exist|no such table: %s\.%s)`, schema, missingRelation, schema, missingRelation), err.Error()) + // MySQL: Unknown database '%s' + require.Regexp(t, fmt.Sprintf(`(pq: relation "%s\.%s" does not exist|no such table: %s\.%s|Unknown database '%s')`, schema, missingRelation, schema, missingRelation, schema), err.Error()) } } diff --git a/riverdriver/rivermysql/example_mysql_test.go b/riverdriver/rivermysql/example_mysql_test.go new file mode 100644 index 00000000..74aa0177 --- /dev/null +++ b/riverdriver/rivermysql/example_mysql_test.go @@ -0,0 +1,125 @@ +package rivermysql_test + +import ( + "cmp" + "context" + "database/sql" + "fmt" + "log/slog" + "os" + "sort" + + _ "github.com/go-sql-driver/mysql" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/rivermysql" + "github.com/riverqueue/river/rivermigrate" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type MySQLSortArgs struct { + // Strings is a slice of strings to sort. + Strings []string `json:"strings"` +} + +func (MySQLSortArgs) Kind() string { return "sort" } + +type MySQLSortWorker struct { + river.WorkerDefaults[MySQLSortArgs] +} + +func (w *MySQLSortWorker) Work(ctx context.Context, job *river.Job[MySQLSortArgs]) error { + sort.Strings(job.Args.Strings) + fmt.Printf("Sorted strings: %+v\n", job.Args.Strings) + return nil +} + +// Example_mysql demonstrates use of River's MySQL driver. +func Example_mysql() { + // MySQL tests are opt-in because they require a running server. When + // disabled, print the expected output so the example test always passes. + val := os.Getenv("RIVER_MYSQL_TESTS_ENABLED") + if val != "1" && val != "true" { + fmt.Println("Sorted strings: [bear tiger whale]") + return + } + + ctx := context.Background() + + dsn := cmp.Or( + os.Getenv("TEST_MYSQL_URL"), + "root@tcp(localhost:3306)/?parseTime=true&multiStatements=true&loc=UTC&time_zone=%27%2B00%3A00%27", + ) + + dbPool, err := sql.Open("mysql", dsn) + if err != nil { + panic(err) + } + defer dbPool.Close() + + // Create a temporary database for the example. + const exampleDB = "river_example_mysql" + if _, err := dbPool.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+exampleDB); err != nil { + panic(err) + } + defer func() { + _, _ = dbPool.ExecContext(ctx, "DROP DATABASE IF EXISTS "+exampleDB) + }() + + // Run River's migrations to prepare the schema. + migrator, err := rivermigrate.New(rivermysql.New(dbPool), &rivermigrate.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Schema: exampleDB, + }) + if err != nil { + panic(err) + } + if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil { + panic(err) + } + + workers := river.NewWorkers() + river.AddWorker(workers, &MySQLSortWorker{}) + + riverClient, err := river.NewClient(rivermysql.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: exampleDB, + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + _, err = riverClient.Insert(ctx, MySQLSortArgs{ + Strings: []string{ + "whale", "tiger", "bear", + }, + }, nil) + if err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. + riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Sorted strings: [bear tiger whale] +} diff --git a/riverdriver/rivermysql/go.mod b/riverdriver/rivermysql/go.mod new file mode 100644 index 00000000..35b6522b --- /dev/null +++ b/riverdriver/rivermysql/go.mod @@ -0,0 +1,32 @@ +module github.com/riverqueue/river/riverdriver/rivermysql + +go 1.25.0 + +toolchain go1.25.7 + +require ( + github.com/go-sql-driver/mysql v1.9.3 + github.com/riverqueue/river v0.35.0 + github.com/riverqueue/river/riverdriver v0.35.0 + github.com/riverqueue/river/rivershared v0.35.0 + github.com/riverqueue/river/rivertype v0.35.0 + github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.9.1 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + go.uber.org/goleak v1.3.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/riverdriver/rivermysql/go.sum b/riverdriver/rivermysql/go.sum new file mode 100644 index 00000000..8f525365 --- /dev/null +++ b/riverdriver/rivermysql/go.sum @@ -0,0 +1,65 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river v0.35.0 h1:ERjFYsrCTGIQb8zdGZDegXmCNtsZ2f38Mh/1pm8s+7s= +github.com/riverqueue/river v0.35.0/go.mod h1:BhQ6qtWT5ngvHcZByD86GB6wFbFmWuovkqwp5Ybuw+g= +github.com/riverqueue/river/riverdriver v0.35.0 h1:g23GXp2ukQBzFAjHoqEqU9/nvDBVNugvFSzDeh9QnDA= +github.com/riverqueue/river/riverdriver v0.35.0/go.mod h1:cg2DXFxovACcMKMDO7VwvMJt7/FfQD5ZvD19o7q+27g= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0 h1:ifGnwXjFujv87RV1P7Q3RQRmMhZPgoyB4Y7Q3cwQAv8= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0/go.mod h1:zUy67OVHoXfeOktxmQqlgqL9q/otBaUrwKuRm0PwHl8= +github.com/riverqueue/river/rivershared v0.35.0 h1:8l0nemUfKvG51GhsjPAyhPVNw/Nej3xIbOZ9np/G0aY= +github.com/riverqueue/river/rivershared v0.35.0/go.mod h1:m6UB4lGgbwi8ikuKniISDJ/U+BvQYtNAtZwb90G85Wk= +github.com/riverqueue/river/rivertype v0.35.0 h1:0SvJQB5GvSkGAyLLtUbtd9Fre7yuEpuzOtRNncvtj2Y= +github.com/riverqueue/river/rivertype v0.35.0/go.mod h1:D1Ad+EaZiaXbQbJcJcfeicXJMBKno0n6UcfKI5Q7DIQ= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/riverdriver/rivermysql/internal/dbsqlc/db.go b/riverdriver/rivermysql/internal/dbsqlc/db.go new file mode 100644 index 00000000..3bebd3a3 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/db.go @@ -0,0 +1,24 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 + +package dbsqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/models.go b/riverdriver/rivermysql/internal/dbsqlc/models.go new file mode 100644 index 00000000..47119dde --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/models.go @@ -0,0 +1,82 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 + +package dbsqlc + +import ( + "database/sql" + "time" +) + +type RiverClient struct { + ID string + CreatedAt time.Time + Metadata []byte + PausedAt sql.NullTime + UpdatedAt time.Time +} + +type RiverClientQueue struct { + RiverClientID string + Name string + CreatedAt time.Time + MaxWorkers int64 + Metadata []byte + NumJobsCompleted int64 + NumJobsRunning int64 + UpdatedAt time.Time +} + +type RiverJob struct { + ID int64 + Args []byte + Attempt int64 + AttemptedAt sql.NullTime + AttemptedBy []byte + CreatedAt time.Time + Errors []byte + FinalizedAt sql.NullTime + Kind string + MaxAttempts int64 + Metadata []byte + Priority int16 + Queue string + State string + ScheduledAt time.Time + Tags []byte + UniqueKey sql.NullString + UniqueStates sql.NullInt16 +} + +type RiverLeader struct { + ElectedAt time.Time + ExpiresAt time.Time + LeaderID string + Name string +} + +type RiverMigration struct { + Line string + Version int64 + CreatedAt time.Time +} + +type RiverQueue struct { + Name string + CreatedAt time.Time + Metadata []byte + PausedAt sql.NullTime + UpdatedAt time.Time +} + +type Statistics struct { + IndexName string + TableName string + TableSchema string +} + +type Tables struct { + TableName string + TableSchema string +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_client.sql b/riverdriver/rivermysql/internal/dbsqlc/river_client.sql new file mode 100644 index 00000000..1ae8aab7 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_client.sql @@ -0,0 +1,7 @@ +CREATE TABLE river_client ( + id VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL +); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_client_queue.sql b/riverdriver/rivermysql/internal/dbsqlc/river_client_queue.sql new file mode 100644 index 00000000..46981cb4 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_client_queue.sql @@ -0,0 +1,12 @@ +CREATE TABLE river_client_queue ( + river_client_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + max_workers INT NOT NULL DEFAULT 0, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + num_jobs_completed BIGINT NOT NULL DEFAULT 0, + num_jobs_running BIGINT NOT NULL DEFAULT 0, + updated_at DATETIME(6) NOT NULL, + PRIMARY KEY (river_client_id, name), + CONSTRAINT fk_river_client FOREIGN KEY (river_client_id) REFERENCES river_client (id) ON DELETE CASCADE +); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_job.sql b/riverdriver/rivermysql/internal/dbsqlc/river_job.sql new file mode 100644 index 00000000..c26e9ba0 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_job.sql @@ -0,0 +1,420 @@ +CREATE TABLE river_job ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + args JSON NOT NULL DEFAULT (JSON_OBJECT()), + attempt INT NOT NULL DEFAULT 0, + attempted_at DATETIME(6) NULL, + attempted_by JSON NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + errors JSON NULL, + finalized_at DATETIME(6) NULL, + kind VARCHAR(128) NOT NULL, + max_attempts INT NOT NULL, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + priority SMALLINT NOT NULL DEFAULT 1, + queue VARCHAR(128) NOT NULL DEFAULT 'default', + state VARCHAR(20) NOT NULL DEFAULT 'available', + scheduled_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + tags JSON NOT NULL DEFAULT (JSON_ARRAY()), + unique_key VARBINARY(255) NULL, + unique_states SMALLINT NULL +); + +-- name: JobGetByID :one +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE id = sqlc.arg('id') +LIMIT 1; + +-- name: JobGetByIDMany :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE id IN (sqlc.slice('id')) +ORDER BY id; + +-- name: JobGetByKindMany :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE kind IN (sqlc.slice('kind')) +ORDER BY id; + +-- name: JobCancelExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = CASE WHEN state = 'running' THEN state ELSE 'cancelled' END, + finalized_at = CASE WHEN state = 'running' THEN finalized_at ELSE COALESCE(sqlc.narg('now'), NOW(6)) END, + metadata = JSON_SET(metadata, '$.cancel_attempted_at', CAST(sqlc.arg('cancel_attempted_at') AS CHAR)) +WHERE id = sqlc.arg('id') + AND state NOT IN ('cancelled', 'completed', 'discarded') + AND finalized_at IS NULL; + +-- name: JobCountByAllStates :many +SELECT state, count(*) AS count +FROM /* TEMPLATE: schema */river_job +GROUP BY state; + +-- name: JobCountByQueueAndState :many +WITH all_queues AS ( + SELECT DISTINCT river_job.queue + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (sqlc.slice('queue_names')) +), + +running_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (sqlc.slice('queue_names')) + AND river_job.state = 'running' + GROUP BY river_job.queue +), + +available_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (sqlc.slice('queue_names')) + AND river_job.state = 'available' + GROUP BY river_job.queue +) + +SELECT + all_queues.queue, + COALESCE(available_job_counts.count, 0) AS count_available, + COALESCE(running_job_counts.count, 0) AS count_running +FROM all_queues +LEFT JOIN running_job_counts ON all_queues.queue = running_job_counts.queue +LEFT JOIN available_job_counts ON all_queues.queue = available_job_counts.queue +ORDER BY all_queues.queue ASC; + +-- name: JobCountByState :one +SELECT count(*) AS count +FROM /* TEMPLATE: schema */river_job +WHERE state = sqlc.arg('state'); + +-- name: JobDeleteExec :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id = sqlc.arg('id') + AND river_job.state != 'running'; + +-- name: JobDeleteBefore :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE + id IN ( + SELECT id FROM ( + SELECT rj2.id + FROM /* TEMPLATE: schema */river_job rj2 + WHERE + (rj2.state = 'cancelled' AND rj2.finalized_at < sqlc.arg('cancelled_finalized_at_horizon')) OR + (rj2.state = 'completed' AND rj2.finalized_at < sqlc.arg('completed_finalized_at_horizon')) OR + (rj2.state = 'discarded' AND rj2.finalized_at < sqlc.arg('discarded_finalized_at_horizon')) + ORDER BY rj2.id + LIMIT ? + ) AS tmp + ) + AND ( + CAST(sqlc.arg('queues_excluded_empty') AS SIGNED) + OR river_job.queue NOT IN (sqlc.slice('queues_excluded')) + ); + +-- name: JobDeleteManySelect :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE id IN ( + SELECT id FROM ( + SELECT id + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + AND state != 'running' + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT ? + ) AS tmp +) +ORDER BY id; + +-- name: JobDeleteManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id IN (sqlc.slice('id')); + +-- name: JobGetAvailableIDs :many +SELECT id +FROM /* TEMPLATE: schema */river_job +WHERE + priority >= 0 + AND queue = sqlc.arg('queue') + AND scheduled_at <= COALESCE(sqlc.narg('now'), NOW(6)) + AND state = 'available' +ORDER BY priority ASC, scheduled_at ASC, id ASC +LIMIT ? +FOR UPDATE SKIP LOCKED; + +-- name: JobGetAvailableUpdate :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = attempt + 1, + attempted_at = COALESCE(sqlc.narg('now'), NOW(6)), + attempted_by = JSON_ARRAY_APPEND( + CASE + WHEN JSON_LENGTH(COALESCE(attempted_by, JSON_ARRAY())) < CAST(sqlc.arg('max_attempted_by') AS SIGNED) + THEN COALESCE(attempted_by, JSON_ARRAY()) + WHEN CAST(sqlc.arg('max_attempted_by') AS SIGNED) <= 1 + THEN JSON_ARRAY() + ELSE COALESCE( + JSON_EXTRACT( + attempted_by, + CONCAT('$[last-', CAST(sqlc.arg('max_attempted_by') AS SIGNED) - 2, ' to last]') + ), + JSON_ARRAY() + ) + END, + '$', + CAST(sqlc.arg('attempted_by') AS CHAR) + ), + state = 'running' +WHERE id IN (sqlc.slice('id')); + +-- name: JobGetByIDManyOrdered :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE id IN (sqlc.slice('id')) +ORDER BY priority ASC, scheduled_at ASC, id ASC; + +-- name: JobGetStuck :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE state = 'running' + AND attempted_at < sqlc.arg('stuck_horizon') +ORDER BY id +LIMIT ?; + +-- name: JobInsertFast :execresult +INSERT INTO /* TEMPLATE: schema */river_job( + id, + args, + created_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + sqlc.narg('id'), + sqlc.arg('args'), + COALESCE(sqlc.narg('created_at'), NOW(6)), + sqlc.arg('kind'), + sqlc.arg('max_attempts'), + CAST(sqlc.arg('metadata') AS JSON), + sqlc.arg('priority'), + sqlc.arg('queue'), + COALESCE(sqlc.narg('scheduled_at'), NOW(6)), + sqlc.arg('state'), + CAST(sqlc.arg('tags') AS JSON), + sqlc.narg('unique_key'), + sqlc.narg('unique_states') +) +ON DUPLICATE KEY UPDATE id = LAST_INSERT_ID(id), kind = VALUES(kind); + +-- name: JobInsertFullExec :execlastid +INSERT INTO /* TEMPLATE: schema */river_job( + args, + attempt, + attempted_at, + attempted_by, + created_at, + errors, + finalized_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + sqlc.arg('args'), + sqlc.arg('attempt'), + sqlc.narg('attempted_at'), + CAST(sqlc.narg('attempted_by') AS JSON), + COALESCE(sqlc.narg('created_at'), NOW(6)), + CAST(sqlc.narg('errors') AS JSON), + sqlc.narg('finalized_at'), + sqlc.arg('kind'), + sqlc.arg('max_attempts'), + CAST(sqlc.arg('metadata') AS JSON), + sqlc.arg('priority'), + sqlc.arg('queue'), + COALESCE(sqlc.narg('scheduled_at'), NOW(6)), + sqlc.arg('state'), + CAST(sqlc.arg('tags') AS JSON), + sqlc.narg('unique_key'), + sqlc.narg('unique_states') +); + +-- name: JobKindList :many +SELECT DISTINCT kind +FROM /* TEMPLATE: schema */river_job +WHERE (sqlc.arg('match') = '' OR LOWER(kind) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(sqlc.arg('match')), '%')) + AND (sqlc.arg('after') = '' OR kind COLLATE utf8mb4_general_ci > sqlc.arg('after')) + AND kind NOT IN (sqlc.slice('exclude')) +ORDER BY kind ASC +LIMIT ?; + +-- name: JobList :many +SELECT * +FROM /* TEMPLATE: schema */river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT ?; + +-- name: JobRescue :exec +UPDATE /* TEMPLATE: schema */river_job +SET + errors = JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(sqlc.arg('error') AS JSON)), + finalized_at = sqlc.narg('finalized_at'), + scheduled_at = sqlc.arg('scheduled_at'), + metadata = JSON_SET( + metadata, + '$."river:rescue_count"', + COALESCE( + CASE JSON_TYPE(JSON_EXTRACT(metadata, '$."river:rescue_count"')) + WHEN 'INTEGER' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + WHEN 'DOUBLE' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + ELSE NULL + END, + 0 + ) + 1 + ), + state = sqlc.arg('state') +WHERE id = sqlc.arg('id'); + +-- name: JobRetryExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = 'available', + max_attempts = CASE WHEN attempt = max_attempts THEN max_attempts + 1 ELSE max_attempts END, + finalized_at = NULL, + scheduled_at = COALESCE(sqlc.narg('now'), NOW(6)) +WHERE id = sqlc.arg('id') + AND state != 'running' + AND ( + state <> 'available' + OR scheduled_at > COALESCE(sqlc.narg('now'), NOW(6)) + ); + +-- name: JobSchedule :many +WITH eligible AS ( + SELECT river_job.id, river_job.unique_key, river_job.unique_states, river_job.priority, river_job.scheduled_at, + CASE + WHEN river_job.unique_key IS NOT NULL AND river_job.unique_states IS NOT NULL THEN + ROW_NUMBER() OVER (PARTITION BY river_job.unique_key ORDER BY river_job.priority, river_job.scheduled_at, river_job.id) + ELSE NULL + END AS row_num + FROM /* TEMPLATE: schema */river_job + WHERE + river_job.state IN ('retryable', 'scheduled') + AND river_job.scheduled_at <= COALESCE(sqlc.narg('now'), NOW(6)) + ORDER BY + river_job.priority, + river_job.scheduled_at, + river_job.id + LIMIT ? +), +unique_conflicts AS ( + SELECT DISTINCT eligible.unique_key + FROM /* TEMPLATE: schema */river_job + JOIN eligible + ON river_job.unique_key = eligible.unique_key + AND river_job.id != eligible.id + WHERE + river_job.unique_key IS NOT NULL + AND river_job.unique_states IS NOT NULL + AND CASE river_job.state + WHEN 'available' THEN river_job.unique_states & (1 << 0) + WHEN 'cancelled' THEN river_job.unique_states & (1 << 1) + WHEN 'completed' THEN river_job.unique_states & (1 << 2) + WHEN 'discarded' THEN river_job.unique_states & (1 << 3) + WHEN 'pending' THEN river_job.unique_states & (1 << 4) + WHEN 'retryable' THEN river_job.unique_states & (1 << 5) + WHEN 'running' THEN river_job.unique_states & (1 << 6) + WHEN 'scheduled' THEN river_job.unique_states & (1 << 7) + ELSE 0 + END >= 1 +) +SELECT eligible.id, + CASE + WHEN eligible.unique_key IS NULL OR eligible.unique_states IS NULL THEN FALSE + WHEN uc.unique_key IS NOT NULL THEN TRUE + WHEN eligible.row_num > 1 THEN TRUE + ELSE FALSE + END AS conflict_discarded +FROM eligible +LEFT JOIN unique_conflicts uc ON eligible.unique_key = uc.unique_key +ORDER BY eligible.priority, eligible.scheduled_at, eligible.id; + +-- name: JobScheduleSetAvailableExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET state = 'available' +WHERE id IN (sqlc.slice('id')); + +-- name: JobScheduleSetDiscardedExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, '{"unique_key_conflict": "scheduler_discarded"}'), + finalized_at = COALESCE(sqlc.narg('now'), NOW(6)), + state = 'discarded' +WHERE id IN (sqlc.slice('id')); + +-- name: JobSetMetadataIfNotRunningExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata_updates') AS JSON)) +WHERE id = sqlc.arg('id') + AND state != 'running'; + +-- name: JobSetStateIfRunningExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN (CAST(sqlc.arg('state') AS CHAR) <> 'retryable' AND sqlc.arg('state') <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(sqlc.arg('attempt_do_update') AS SIGNED) + THEN sqlc.arg('attempt') + ELSE attempt END, + errors = CASE WHEN CAST(sqlc.arg('errors_do_update') AS SIGNED) + THEN JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(sqlc.arg('error') AS JSON)) + ELSE errors END, + finalized_at = CASE WHEN ((sqlc.arg('state') = 'retryable' OR sqlc.arg('state') = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN COALESCE(sqlc.narg('now'), NOW(6)) + WHEN CAST(sqlc.arg('finalized_at_do_update') AS SIGNED) + THEN sqlc.narg('finalized_at') + ELSE finalized_at END, + metadata = CASE WHEN CAST(sqlc.arg('metadata_do_merge') AS SIGNED) + THEN JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata_updates') AS JSON)) + ELSE metadata END, + scheduled_at = CASE WHEN (CAST(sqlc.arg('state') AS CHAR) <> 'retryable' AND sqlc.arg('state') <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(sqlc.arg('scheduled_at_do_update') AS SIGNED) + THEN sqlc.arg('scheduled_at') + ELSE scheduled_at END, + state = CASE WHEN ((sqlc.arg('state') = 'retryable' OR sqlc.arg('state') = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN 'cancelled' + ELSE sqlc.arg('state') END +WHERE id = sqlc.arg('id') + AND state = 'running'; + +-- name: JobUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + metadata = CASE WHEN CAST(sqlc.arg('metadata_do_merge') AS SIGNED) THEN JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata') AS JSON)) ELSE metadata END +WHERE id = sqlc.arg('id'); + +-- name: JobUpdateFullExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN CAST(sqlc.arg('attempt_do_update') AS SIGNED) THEN sqlc.arg('attempt') ELSE attempt END, + attempted_at = CASE WHEN CAST(sqlc.arg('attempted_at_do_update') AS SIGNED) THEN sqlc.narg('attempted_at') ELSE attempted_at END, + attempted_by = CASE WHEN CAST(sqlc.arg('attempted_by_do_update') AS SIGNED) THEN CAST(sqlc.arg('attempted_by') AS JSON) ELSE attempted_by END, + errors = CASE WHEN CAST(sqlc.arg('errors_do_update') AS SIGNED) THEN CAST(sqlc.arg('errors') AS JSON) ELSE errors END, + finalized_at = CASE WHEN CAST(sqlc.arg('finalized_at_do_update') AS SIGNED) THEN sqlc.narg('finalized_at') ELSE finalized_at END, + max_attempts = CASE WHEN CAST(sqlc.arg('max_attempts_do_update') AS SIGNED) THEN sqlc.arg('max_attempts') ELSE max_attempts END, + metadata = CASE WHEN CAST(sqlc.arg('metadata_do_update') AS SIGNED) THEN CAST(sqlc.arg('metadata') AS JSON) ELSE metadata END, + state = CASE WHEN CAST(sqlc.arg('state_do_update') AS SIGNED) THEN sqlc.arg('state') ELSE state END +WHERE id = sqlc.arg('id'); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_job.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_job.sql.go new file mode 100644 index 00000000..72e8047d --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_job.sql.go @@ -0,0 +1,1289 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_job.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const jobCancelExec = `-- name: JobCancelExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = CASE WHEN state = 'running' THEN state ELSE 'cancelled' END, + finalized_at = CASE WHEN state = 'running' THEN finalized_at ELSE COALESCE(?, NOW(6)) END, + metadata = JSON_SET(metadata, '$.cancel_attempted_at', CAST(? AS CHAR)) +WHERE id = ? + AND state NOT IN ('cancelled', 'completed', 'discarded') + AND finalized_at IS NULL +` + +type JobCancelExecParams struct { + Now sql.NullTime + CancelAttemptedAt interface{} + ID int64 +} + +func (q *Queries) JobCancelExec(ctx context.Context, db DBTX, arg *JobCancelExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobCancelExec, arg.Now, arg.CancelAttemptedAt, arg.ID) +} + +const jobCountByAllStates = `-- name: JobCountByAllStates :many +SELECT state, count(*) AS count +FROM /* TEMPLATE: schema */river_job +GROUP BY state +` + +type JobCountByAllStatesRow struct { + State string + Count int64 +} + +func (q *Queries) JobCountByAllStates(ctx context.Context, db DBTX) ([]*JobCountByAllStatesRow, error) { + rows, err := db.QueryContext(ctx, jobCountByAllStates) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobCountByAllStatesRow + for rows.Next() { + var i JobCountByAllStatesRow + if err := rows.Scan(&i.State, &i.Count); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountByQueueAndState = `-- name: JobCountByQueueAndState :many +WITH all_queues AS ( + SELECT DISTINCT river_job.queue + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) +), + +running_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) + AND river_job.state = 'running' + GROUP BY river_job.queue +), + +available_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) + AND river_job.state = 'available' + GROUP BY river_job.queue +) + +SELECT + all_queues.queue, + COALESCE(available_job_counts.count, 0) AS count_available, + COALESCE(running_job_counts.count, 0) AS count_running +FROM all_queues +LEFT JOIN running_job_counts ON all_queues.queue = running_job_counts.queue +LEFT JOIN available_job_counts ON all_queues.queue = available_job_counts.queue +ORDER BY all_queues.queue ASC +` + +type JobCountByQueueAndStateParams struct { + QueueNames []string +} + +type JobCountByQueueAndStateRow struct { + Queue string + CountAvailable int64 + CountRunning int64 +} + +func (q *Queries) JobCountByQueueAndState(ctx context.Context, db DBTX, arg *JobCountByQueueAndStateParams) ([]*JobCountByQueueAndStateRow, error) { + query := jobCountByQueueAndState + var queryParams []interface{} + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobCountByQueueAndStateRow + for rows.Next() { + var i JobCountByQueueAndStateRow + if err := rows.Scan(&i.Queue, &i.CountAvailable, &i.CountRunning); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountByState = `-- name: JobCountByState :one +SELECT count(*) AS count +FROM /* TEMPLATE: schema */river_job +WHERE state = ? +` + +func (q *Queries) JobCountByState(ctx context.Context, db DBTX, state string) (int64, error) { + row := db.QueryRowContext(ctx, jobCountByState, state) + var count int64 + err := row.Scan(&count) + return count, err +} + +const jobDeleteBefore = `-- name: JobDeleteBefore :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE + id IN ( + SELECT id FROM ( + SELECT rj2.id + FROM /* TEMPLATE: schema */river_job rj2 + WHERE + (rj2.state = 'cancelled' AND rj2.finalized_at < ?) OR + (rj2.state = 'completed' AND rj2.finalized_at < ?) OR + (rj2.state = 'discarded' AND rj2.finalized_at < ?) + ORDER BY rj2.id + LIMIT ? + ) AS tmp + ) + AND ( + CAST(? AS SIGNED) + OR river_job.queue NOT IN (/*SLICE:queues_excluded*/?) + ) +` + +type JobDeleteBeforeParams struct { + CancelledFinalizedAtHorizon sql.NullTime + CompletedFinalizedAtHorizon sql.NullTime + DiscardedFinalizedAtHorizon sql.NullTime + Limit int32 + QueuesExcludedEmpty int64 + QueuesExcluded []string +} + +func (q *Queries) JobDeleteBefore(ctx context.Context, db DBTX, arg *JobDeleteBeforeParams) (sql.Result, error) { + query := jobDeleteBefore + var queryParams []interface{} + queryParams = append(queryParams, arg.CancelledFinalizedAtHorizon) + queryParams = append(queryParams, arg.CompletedFinalizedAtHorizon) + queryParams = append(queryParams, arg.DiscardedFinalizedAtHorizon) + queryParams = append(queryParams, arg.Limit) + queryParams = append(queryParams, arg.QueuesExcludedEmpty) + if len(arg.QueuesExcluded) > 0 { + for _, v := range arg.QueuesExcluded { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queues_excluded*/?", strings.Repeat(",?", len(arg.QueuesExcluded))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queues_excluded*/?", "NULL", 1) + } + return db.ExecContext(ctx, query, queryParams...) +} + +const jobDeleteExec = `-- name: JobDeleteExec :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id = ? + AND river_job.state != 'running' +` + +func (q *Queries) JobDeleteExec(ctx context.Context, db DBTX, id int64) (sql.Result, error) { + return db.ExecContext(ctx, jobDeleteExec, id) +} + +const jobDeleteManyExec = `-- name: JobDeleteManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) +` + +func (q *Queries) JobDeleteManyExec(ctx context.Context, db DBTX, id []int64) error { + query := jobDeleteManyExec + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobDeleteManySelect = `-- name: JobDeleteManySelect :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN ( + SELECT id FROM ( + SELECT id + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + AND state != 'running' + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT ? + ) AS tmp +) +ORDER BY id +` + +func (q *Queries) JobDeleteManySelect(ctx context.Context, db DBTX, limit int32) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobDeleteManySelect, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetAvailableIDs = `-- name: JobGetAvailableIDs :many +SELECT id +FROM /* TEMPLATE: schema */river_job +WHERE + priority >= 0 + AND queue = ? + AND scheduled_at <= COALESCE(?, NOW(6)) + AND state = 'available' +ORDER BY priority ASC, scheduled_at ASC, id ASC +LIMIT ? +FOR UPDATE SKIP LOCKED +` + +type JobGetAvailableIDsParams struct { + Queue string + Now sql.NullTime + Limit int32 +} + +func (q *Queries) JobGetAvailableIDs(ctx context.Context, db DBTX, arg *JobGetAvailableIDsParams) ([]int64, error) { + rows, err := db.QueryContext(ctx, jobGetAvailableIDs, arg.Queue, arg.Now, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetAvailableUpdate = `-- name: JobGetAvailableUpdate :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = attempt + 1, + attempted_at = COALESCE(?, NOW(6)), + attempted_by = JSON_ARRAY_APPEND( + CASE + WHEN JSON_LENGTH(COALESCE(attempted_by, JSON_ARRAY())) < CAST(? AS SIGNED) + THEN COALESCE(attempted_by, JSON_ARRAY()) + WHEN CAST(? AS SIGNED) <= 1 + THEN JSON_ARRAY() + ELSE COALESCE( + JSON_EXTRACT( + attempted_by, + CONCAT('$[last-', CAST(? AS SIGNED) - 2, ' to last]') + ), + JSON_ARRAY() + ) + END, + '$', + CAST(? AS CHAR) + ), + state = 'running' +WHERE id IN (/*SLICE:id*/?) +` + +type JobGetAvailableUpdateParams struct { + Now sql.NullTime + MaxAttemptedBy int64 + AttemptedBy interface{} + ID []int64 +} + +func (q *Queries) JobGetAvailableUpdate(ctx context.Context, db DBTX, arg *JobGetAvailableUpdateParams) error { + query := jobGetAvailableUpdate + var queryParams []interface{} + queryParams = append(queryParams, arg.Now) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.AttemptedBy) + if len(arg.ID) > 0 { + for _, v := range arg.ID { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(arg.ID))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobGetByID = `-- name: JobGetByID :one +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id = ? +LIMIT 1 +` + +func (q *Queries) JobGetByID(ctx context.Context, db DBTX, id int64) (*RiverJob, error) { + row := db.QueryRowContext(ctx, jobGetByID, id) + var i RiverJob + err := row.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ) + return &i, err +} + +const jobGetByIDMany = `-- name: JobGetByIDMany :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) +ORDER BY id +` + +func (q *Queries) JobGetByIDMany(ctx context.Context, db DBTX, id []int64) ([]*RiverJob, error) { + query := jobGetByIDMany + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetByIDManyOrdered = `-- name: JobGetByIDManyOrdered :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) +ORDER BY priority ASC, scheduled_at ASC, id ASC +` + +func (q *Queries) JobGetByIDManyOrdered(ctx context.Context, db DBTX, id []int64) ([]*RiverJob, error) { + query := jobGetByIDManyOrdered + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetByKindMany = `-- name: JobGetByKindMany :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE kind IN (/*SLICE:kind*/?) +ORDER BY id +` + +func (q *Queries) JobGetByKindMany(ctx context.Context, db DBTX, kind []string) ([]*RiverJob, error) { + query := jobGetByKindMany + var queryParams []interface{} + if len(kind) > 0 { + for _, v := range kind { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:kind*/?", strings.Repeat(",?", len(kind))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:kind*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetStuck = `-- name: JobGetStuck :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE state = 'running' + AND attempted_at < ? +ORDER BY id +LIMIT ? +` + +type JobGetStuckParams struct { + StuckHorizon sql.NullTime + Limit int32 +} + +func (q *Queries) JobGetStuck(ctx context.Context, db DBTX, arg *JobGetStuckParams) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobGetStuck, arg.StuckHorizon, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobInsertFast = `-- name: JobInsertFast :execresult +INSERT INTO /* TEMPLATE: schema */river_job( + id, + args, + created_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + ?, + ?, + COALESCE(?, NOW(6)), + ?, + ?, + CAST(? AS JSON), + ?, + ?, + COALESCE(?, NOW(6)), + ?, + CAST(? AS JSON), + ?, + ? +) +ON DUPLICATE KEY UPDATE id = LAST_INSERT_ID(id), kind = VALUES(kind) +` + +type JobInsertFastParams struct { + ID sql.NullInt64 + Args []byte + CreatedAt interface{} + Kind string + MaxAttempts int64 + Metadata []byte + Priority int16 + Queue string + ScheduledAt interface{} + State string + Tags []byte + UniqueKey sql.NullString + UniqueStates sql.NullInt16 +} + +func (q *Queries) JobInsertFast(ctx context.Context, db DBTX, arg *JobInsertFastParams) (sql.Result, error) { + return db.ExecContext(ctx, jobInsertFast, + arg.ID, + arg.Args, + arg.CreatedAt, + arg.Kind, + arg.MaxAttempts, + arg.Metadata, + arg.Priority, + arg.Queue, + arg.ScheduledAt, + arg.State, + arg.Tags, + arg.UniqueKey, + arg.UniqueStates, + ) +} + +const jobInsertFullExec = `-- name: JobInsertFullExec :execlastid +INSERT INTO /* TEMPLATE: schema */river_job( + args, + attempt, + attempted_at, + attempted_by, + created_at, + errors, + finalized_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + ?, + ?, + ?, + CAST(? AS JSON), + COALESCE(?, NOW(6)), + CAST(? AS JSON), + ?, + ?, + ?, + CAST(? AS JSON), + ?, + ?, + COALESCE(?, NOW(6)), + ?, + CAST(? AS JSON), + ?, + ? +) +` + +type JobInsertFullExecParams struct { + Args []byte + Attempt int64 + AttemptedAt sql.NullTime + AttemptedBy []byte + CreatedAt interface{} + Errors []byte + FinalizedAt sql.NullTime + Kind string + MaxAttempts int64 + Metadata []byte + Priority int16 + Queue string + ScheduledAt interface{} + State string + Tags []byte + UniqueKey sql.NullString + UniqueStates sql.NullInt16 +} + +func (q *Queries) JobInsertFullExec(ctx context.Context, db DBTX, arg *JobInsertFullExecParams) (int64, error) { + result, err := db.ExecContext(ctx, jobInsertFullExec, + arg.Args, + arg.Attempt, + arg.AttemptedAt, + arg.AttemptedBy, + arg.CreatedAt, + arg.Errors, + arg.FinalizedAt, + arg.Kind, + arg.MaxAttempts, + arg.Metadata, + arg.Priority, + arg.Queue, + arg.ScheduledAt, + arg.State, + arg.Tags, + arg.UniqueKey, + arg.UniqueStates, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +const jobKindList = `-- name: JobKindList :many +SELECT DISTINCT kind +FROM /* TEMPLATE: schema */river_job +WHERE (? = '' OR LOWER(kind) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(?), '%')) + AND (? = '' OR kind COLLATE utf8mb4_general_ci > ?) + AND kind NOT IN (/*SLICE:exclude*/?) +ORDER BY kind ASC +LIMIT ? +` + +type JobKindListParams struct { + Match string + After interface{} + Exclude []string + Limit int32 +} + +func (q *Queries) JobKindList(ctx context.Context, db DBTX, arg *JobKindListParams) ([]string, error) { + query := jobKindList + var queryParams []interface{} + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.After) + queryParams = append(queryParams, arg.After) + if len(arg.Exclude) > 0 { + for _, v := range arg.Exclude { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:exclude*/?", strings.Repeat(",?", len(arg.Exclude))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:exclude*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Limit) + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var kind string + if err := rows.Scan(&kind); err != nil { + return nil, err + } + items = append(items, kind) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobList = `-- name: JobList :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT ? +` + +func (q *Queries) JobList(ctx context.Context, db DBTX, limit int32) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobList, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobRescue = `-- name: JobRescue :exec +UPDATE /* TEMPLATE: schema */river_job +SET + errors = JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(? AS JSON)), + finalized_at = ?, + scheduled_at = ?, + metadata = JSON_SET( + metadata, + '$."river:rescue_count"', + COALESCE( + CASE JSON_TYPE(JSON_EXTRACT(metadata, '$."river:rescue_count"')) + WHEN 'INTEGER' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + WHEN 'DOUBLE' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + ELSE NULL + END, + 0 + ) + 1 + ), + state = ? +WHERE id = ? +` + +type JobRescueParams struct { + Error []byte + FinalizedAt sql.NullTime + ScheduledAt time.Time + State string + ID int64 +} + +func (q *Queries) JobRescue(ctx context.Context, db DBTX, arg *JobRescueParams) error { + _, err := db.ExecContext(ctx, jobRescue, + arg.Error, + arg.FinalizedAt, + arg.ScheduledAt, + arg.State, + arg.ID, + ) + return err +} + +const jobRetryExec = `-- name: JobRetryExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = 'available', + max_attempts = CASE WHEN attempt = max_attempts THEN max_attempts + 1 ELSE max_attempts END, + finalized_at = NULL, + scheduled_at = COALESCE(?, NOW(6)) +WHERE id = ? + AND state != 'running' + AND ( + state <> 'available' + OR scheduled_at > COALESCE(?, NOW(6)) + ) +` + +type JobRetryExecParams struct { + Now sql.NullTime + ID int64 +} + +func (q *Queries) JobRetryExec(ctx context.Context, db DBTX, arg *JobRetryExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobRetryExec, arg.Now, arg.ID, arg.Now) +} + +const jobSchedule = `-- name: JobSchedule :many +WITH eligible AS ( + SELECT river_job.id, river_job.unique_key, river_job.unique_states, river_job.priority, river_job.scheduled_at, + CASE + WHEN river_job.unique_key IS NOT NULL AND river_job.unique_states IS NOT NULL THEN + ROW_NUMBER() OVER (PARTITION BY river_job.unique_key ORDER BY river_job.priority, river_job.scheduled_at, river_job.id) + ELSE NULL + END AS row_num + FROM /* TEMPLATE: schema */river_job + WHERE + river_job.state IN ('retryable', 'scheduled') + AND river_job.scheduled_at <= COALESCE(?, NOW(6)) + ORDER BY + river_job.priority, + river_job.scheduled_at, + river_job.id + LIMIT ? +), +unique_conflicts AS ( + SELECT DISTINCT eligible.unique_key + FROM /* TEMPLATE: schema */river_job + JOIN eligible + ON river_job.unique_key = eligible.unique_key + AND river_job.id != eligible.id + WHERE + river_job.unique_key IS NOT NULL + AND river_job.unique_states IS NOT NULL + AND CASE river_job.state + WHEN 'available' THEN river_job.unique_states & (1 << 0) + WHEN 'cancelled' THEN river_job.unique_states & (1 << 1) + WHEN 'completed' THEN river_job.unique_states & (1 << 2) + WHEN 'discarded' THEN river_job.unique_states & (1 << 3) + WHEN 'pending' THEN river_job.unique_states & (1 << 4) + WHEN 'retryable' THEN river_job.unique_states & (1 << 5) + WHEN 'running' THEN river_job.unique_states & (1 << 6) + WHEN 'scheduled' THEN river_job.unique_states & (1 << 7) + ELSE 0 + END >= 1 +) +SELECT eligible.id, + CASE + WHEN eligible.unique_key IS NULL OR eligible.unique_states IS NULL THEN FALSE + WHEN uc.unique_key IS NOT NULL THEN TRUE + WHEN eligible.row_num > 1 THEN TRUE + ELSE FALSE + END AS conflict_discarded +FROM eligible +LEFT JOIN unique_conflicts uc ON eligible.unique_key = uc.unique_key +ORDER BY eligible.priority, eligible.scheduled_at, eligible.id +` + +type JobScheduleParams struct { + Now sql.NullTime + Limit int32 +} + +type JobScheduleRow struct { + ID int64 + ConflictDiscarded int64 +} + +func (q *Queries) JobSchedule(ctx context.Context, db DBTX, arg *JobScheduleParams) ([]*JobScheduleRow, error) { + rows, err := db.QueryContext(ctx, jobSchedule, arg.Now, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobScheduleRow + for rows.Next() { + var i JobScheduleRow + if err := rows.Scan(&i.ID, &i.ConflictDiscarded); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobScheduleSetAvailableExec = `-- name: JobScheduleSetAvailableExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET state = 'available' +WHERE id IN (/*SLICE:id*/?) +` + +func (q *Queries) JobScheduleSetAvailableExec(ctx context.Context, db DBTX, id []int64) error { + query := jobScheduleSetAvailableExec + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobScheduleSetDiscardedExec = `-- name: JobScheduleSetDiscardedExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, '{"unique_key_conflict": "scheduler_discarded"}'), + finalized_at = COALESCE(?, NOW(6)), + state = 'discarded' +WHERE id IN (/*SLICE:id*/?) +` + +type JobScheduleSetDiscardedExecParams struct { + Now sql.NullTime + ID []int64 +} + +func (q *Queries) JobScheduleSetDiscardedExec(ctx context.Context, db DBTX, arg *JobScheduleSetDiscardedExecParams) error { + query := jobScheduleSetDiscardedExec + var queryParams []interface{} + queryParams = append(queryParams, arg.Now) + if len(arg.ID) > 0 { + for _, v := range arg.ID { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(arg.ID))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobSetMetadataIfNotRunningExec = `-- name: JobSetMetadataIfNotRunningExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) +WHERE id = ? + AND state != 'running' +` + +type JobSetMetadataIfNotRunningExecParams struct { + MetadataUpdates []byte + ID int64 +} + +func (q *Queries) JobSetMetadataIfNotRunningExec(ctx context.Context, db DBTX, arg *JobSetMetadataIfNotRunningExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobSetMetadataIfNotRunningExec, arg.MetadataUpdates, arg.ID) +} + +const jobSetStateIfRunningExec = `-- name: JobSetStateIfRunningExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN (CAST(? AS CHAR) <> 'retryable' AND ? <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(? AS SIGNED) + THEN ? + ELSE attempt END, + errors = CASE WHEN CAST(? AS SIGNED) + THEN JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(? AS JSON)) + ELSE errors END, + finalized_at = CASE WHEN ((? = 'retryable' OR ? = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN COALESCE(?, NOW(6)) + WHEN CAST(? AS SIGNED) + THEN ? + ELSE finalized_at END, + metadata = CASE WHEN CAST(? AS SIGNED) + THEN JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) + ELSE metadata END, + scheduled_at = CASE WHEN (CAST(? AS CHAR) <> 'retryable' AND ? <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(? AS SIGNED) + THEN ? + ELSE scheduled_at END, + state = CASE WHEN ((? = 'retryable' OR ? = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN 'cancelled' + ELSE ? END +WHERE id = ? + AND state = 'running' +` + +type JobSetStateIfRunningExecParams struct { + State string + AttemptDoUpdate int64 + Attempt int64 + ErrorsDoUpdate int64 + Error []byte + Now sql.NullTime + FinalizedAtDoUpdate int64 + FinalizedAt sql.NullTime + MetadataDoMerge int64 + MetadataUpdates []byte + ScheduledAtDoUpdate int64 + ScheduledAt time.Time + ID int64 +} + +func (q *Queries) JobSetStateIfRunningExec(ctx context.Context, db DBTX, arg *JobSetStateIfRunningExecParams) error { + _, err := db.ExecContext(ctx, jobSetStateIfRunningExec, + arg.State, + arg.State, + arg.AttemptDoUpdate, + arg.Attempt, + arg.ErrorsDoUpdate, + arg.Error, + arg.State, + arg.State, + arg.Now, + arg.FinalizedAtDoUpdate, + arg.FinalizedAt, + arg.MetadataDoMerge, + arg.MetadataUpdates, + arg.State, + arg.State, + arg.ScheduledAtDoUpdate, + arg.ScheduledAt, + arg.State, + arg.State, + arg.State, + arg.ID, + ) + return err +} + +const jobUpdateExec = `-- name: JobUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + metadata = CASE WHEN CAST(? AS SIGNED) THEN JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) ELSE metadata END +WHERE id = ? +` + +type JobUpdateExecParams struct { + MetadataDoMerge int64 + Metadata []byte + ID int64 +} + +func (q *Queries) JobUpdateExec(ctx context.Context, db DBTX, arg *JobUpdateExecParams) error { + _, err := db.ExecContext(ctx, jobUpdateExec, arg.MetadataDoMerge, arg.Metadata, arg.ID) + return err +} + +const jobUpdateFullExec = `-- name: JobUpdateFullExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE attempt END, + attempted_at = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE attempted_at END, + attempted_by = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE attempted_by END, + errors = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE errors END, + finalized_at = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE finalized_at END, + max_attempts = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE max_attempts END, + metadata = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE metadata END, + state = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE state END +WHERE id = ? +` + +type JobUpdateFullExecParams struct { + AttemptDoUpdate int64 + Attempt int64 + AttemptedAtDoUpdate int64 + AttemptedAt sql.NullTime + AttemptedByDoUpdate int64 + AttemptedBy []byte + ErrorsDoUpdate int64 + Errors []byte + FinalizedAtDoUpdate int64 + FinalizedAt sql.NullTime + MaxAttemptsDoUpdate int64 + MaxAttempts int64 + MetadataDoUpdate int64 + Metadata []byte + StateDoUpdate int64 + State string + ID int64 +} + +func (q *Queries) JobUpdateFullExec(ctx context.Context, db DBTX, arg *JobUpdateFullExecParams) error { + _, err := db.ExecContext(ctx, jobUpdateFullExec, + arg.AttemptDoUpdate, + arg.Attempt, + arg.AttemptedAtDoUpdate, + arg.AttemptedAt, + arg.AttemptedByDoUpdate, + arg.AttemptedBy, + arg.ErrorsDoUpdate, + arg.Errors, + arg.FinalizedAtDoUpdate, + arg.FinalizedAt, + arg.MaxAttemptsDoUpdate, + arg.MaxAttempts, + arg.MetadataDoUpdate, + arg.Metadata, + arg.StateDoUpdate, + arg.State, + arg.ID, + ) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql new file mode 100644 index 00000000..0e2f77ab --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql @@ -0,0 +1,50 @@ +CREATE TABLE river_leader ( + elected_at DATETIME(6) NOT NULL, + expires_at DATETIME(6) NOT NULL, + leader_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL DEFAULT 'default' PRIMARY KEY +); + +-- name: LeaderAttemptElectExec :execresult +INSERT IGNORE INTO /* TEMPLATE: schema */river_leader ( + leader_id, + elected_at, + expires_at +) VALUES ( + sqlc.arg('leader_id'), + COALESCE(sqlc.narg('now'), NOW(6)), + TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6))) +); + +-- name: LeaderAttemptReelectExec :execresult +UPDATE /* TEMPLATE: schema */river_leader +SET expires_at = TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6))) +WHERE + elected_at = sqlc.arg('elected_at') + AND expires_at >= COALESCE(sqlc.narg('now'), NOW(6)) + AND leader_id = sqlc.arg('leader_id'); + +-- name: LeaderDeleteExpired :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE expires_at < COALESCE(sqlc.narg('now'), NOW(6)); + +-- name: LeaderGetElectedLeader :one +SELECT elected_at, expires_at, leader_id, name +FROM /* TEMPLATE: schema */river_leader; + +-- name: LeaderInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_leader ( + elected_at, + expires_at, + leader_id +) VALUES ( + COALESCE(sqlc.narg('elected_at'), sqlc.narg('now'), NOW(6)), + COALESCE(sqlc.narg('expires_at'), TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6)))), + sqlc.arg('leader_id') +); + +-- name: LeaderResign :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE + elected_at = sqlc.arg('elected_at') + AND leader_id = sqlc.arg('leader_id'); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go new file mode 100644 index 00000000..cb31a9e1 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go @@ -0,0 +1,147 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_leader.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "time" +) + +const leaderAttemptElectExec = `-- name: LeaderAttemptElectExec :execresult +INSERT IGNORE INTO /* TEMPLATE: schema */river_leader ( + leader_id, + elected_at, + expires_at +) VALUES ( + ?, + COALESCE(?, NOW(6)), + TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6))) +) +` + +type LeaderAttemptElectExecParams struct { + LeaderID string + Now interface{} + TTL int64 +} + +func (q *Queries) LeaderAttemptElectExec(ctx context.Context, db DBTX, arg *LeaderAttemptElectExecParams) (sql.Result, error) { + return db.ExecContext(ctx, leaderAttemptElectExec, + arg.LeaderID, + arg.Now, + arg.TTL, + arg.Now, + ) +} + +const leaderAttemptReelectExec = `-- name: LeaderAttemptReelectExec :execresult +UPDATE /* TEMPLATE: schema */river_leader +SET expires_at = TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6))) +WHERE + elected_at = ? + AND expires_at >= COALESCE(?, NOW(6)) + AND leader_id = ? +` + +type LeaderAttemptReelectExecParams struct { + TTL int64 + Now sql.NullTime + ElectedAt time.Time + LeaderID string +} + +func (q *Queries) LeaderAttemptReelectExec(ctx context.Context, db DBTX, arg *LeaderAttemptReelectExecParams) (sql.Result, error) { + return db.ExecContext(ctx, leaderAttemptReelectExec, + arg.TTL, + arg.Now, + arg.ElectedAt, + arg.Now, + arg.LeaderID, + ) +} + +const leaderDeleteExpired = `-- name: LeaderDeleteExpired :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE expires_at < COALESCE(?, NOW(6)) +` + +func (q *Queries) LeaderDeleteExpired(ctx context.Context, db DBTX, now sql.NullTime) (int64, error) { + result, err := db.ExecContext(ctx, leaderDeleteExpired, now) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +const leaderGetElectedLeader = `-- name: LeaderGetElectedLeader :one +SELECT elected_at, expires_at, leader_id, name +FROM /* TEMPLATE: schema */river_leader +` + +func (q *Queries) LeaderGetElectedLeader(ctx context.Context, db DBTX) (*RiverLeader, error) { + row := db.QueryRowContext(ctx, leaderGetElectedLeader) + var i RiverLeader + err := row.Scan( + &i.ElectedAt, + &i.ExpiresAt, + &i.LeaderID, + &i.Name, + ) + return &i, err +} + +const leaderInsertExec = `-- name: LeaderInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_leader ( + elected_at, + expires_at, + leader_id +) VALUES ( + COALESCE(?, ?, NOW(6)), + COALESCE(?, TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6)))), + ? +) +` + +type LeaderInsertExecParams struct { + ElectedAt interface{} + Now interface{} + ExpiresAt interface{} + TTL int64 + LeaderID string +} + +func (q *Queries) LeaderInsertExec(ctx context.Context, db DBTX, arg *LeaderInsertExecParams) error { + _, err := db.ExecContext(ctx, leaderInsertExec, + arg.ElectedAt, + arg.Now, + arg.ExpiresAt, + arg.TTL, + arg.Now, + arg.LeaderID, + ) + return err +} + +const leaderResign = `-- name: LeaderResign :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE + elected_at = ? + AND leader_id = ? +` + +type LeaderResignParams struct { + ElectedAt time.Time + LeaderID string +} + +func (q *Queries) LeaderResign(ctx context.Context, db DBTX, arg *LeaderResignParams) (int64, error) { + result, err := db.ExecContext(ctx, leaderResign, arg.ElectedAt, arg.LeaderID) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql new file mode 100644 index 00000000..8a1aeccb --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql @@ -0,0 +1,77 @@ +CREATE TABLE river_migration ( + line VARCHAR(128) NOT NULL, + version BIGINT NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + PRIMARY KEY (line, version) +); + +-- name: RiverMigrationDeleteAssumingMainMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteAssumingMainManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') + AND version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteByLineAndVersionManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') + AND version IN (sqlc.slice('version')); + +-- name: RiverMigrationGetAllAssumingMain :many +SELECT + created_at, + version +FROM /* TEMPLATE: schema */river_migration +ORDER BY version; + +-- name: RiverMigrationGetByLine :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') +ORDER BY version; + +-- name: RiverMigrationInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + line, + version +) VALUES ( + sqlc.arg('line'), + sqlc.arg('version') +); + +-- name: RiverMigrationGetByLineAndVersion :one +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') AND version = sqlc.arg('version'); + +-- name: RiverMigrationGetByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') AND version IN (sqlc.slice('version')) +ORDER BY version; + +-- name: RiverMigrationInsertAssumingMainExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + version +) VALUES ( + sqlc.arg('version') +); + +-- name: RiverMigrationGetByVersion :one +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version = sqlc.arg('version'); + +-- name: RiverMigrationGetByVersionMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')) +ORDER BY version; diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go new file mode 100644 index 00000000..2e3cbb89 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go @@ -0,0 +1,375 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_migration.sql + +package dbsqlc + +import ( + "context" + "strings" + "time" +) + +const riverMigrationDeleteAssumingMainMany = `-- name: RiverMigrationDeleteAssumingMainMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteAssumingMainManyRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationDeleteAssumingMainMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigrationDeleteAssumingMainManyRow, error) { + query := riverMigrationDeleteAssumingMainMany + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationDeleteAssumingMainManyRow + for rows.Next() { + var i RiverMigrationDeleteAssumingMainManyRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationDeleteAssumingMainManyExec = `-- name: RiverMigrationDeleteAssumingMainManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) +` + +func (q *Queries) RiverMigrationDeleteAssumingMainManyExec(ctx context.Context, db DBTX, version []int64) error { + query := riverMigrationDeleteAssumingMainManyExec + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const riverMigrationDeleteByLineAndVersionMany = `-- name: RiverMigrationDeleteByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? + AND version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteByLineAndVersionManyParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationDeleteByLineAndVersionMany(ctx context.Context, db DBTX, arg *RiverMigrationDeleteByLineAndVersionManyParams) ([]*RiverMigration, error) { + query := riverMigrationDeleteByLineAndVersionMany + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationDeleteByLineAndVersionManyExec = `-- name: RiverMigrationDeleteByLineAndVersionManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE line = ? + AND version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteByLineAndVersionManyExecParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationDeleteByLineAndVersionManyExec(ctx context.Context, db DBTX, arg *RiverMigrationDeleteByLineAndVersionManyExecParams) error { + query := riverMigrationDeleteByLineAndVersionManyExec + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const riverMigrationGetAllAssumingMain = `-- name: RiverMigrationGetAllAssumingMain :many +SELECT + created_at, + version +FROM /* TEMPLATE: schema */river_migration +ORDER BY version +` + +type RiverMigrationGetAllAssumingMainRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetAllAssumingMain(ctx context.Context, db DBTX) ([]*RiverMigrationGetAllAssumingMainRow, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetAllAssumingMain) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationGetAllAssumingMainRow + for rows.Next() { + var i RiverMigrationGetAllAssumingMainRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByLine = `-- name: RiverMigrationGetByLine :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? +ORDER BY version +` + +func (q *Queries) RiverMigrationGetByLine(ctx context.Context, db DBTX, line string) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetByLine, line) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByLineAndVersion = `-- name: RiverMigrationGetByLineAndVersion :one +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? AND version = ? +` + +type RiverMigrationGetByLineAndVersionParams struct { + Line string + Version int64 +} + +func (q *Queries) RiverMigrationGetByLineAndVersion(ctx context.Context, db DBTX, arg *RiverMigrationGetByLineAndVersionParams) (*RiverMigration, error) { + row := db.QueryRowContext(ctx, riverMigrationGetByLineAndVersion, arg.Line, arg.Version) + var i RiverMigration + err := row.Scan(&i.Line, &i.Version, &i.CreatedAt) + return &i, err +} + +const riverMigrationGetByLineAndVersionMany = `-- name: RiverMigrationGetByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? AND version IN (/*SLICE:version*/?) +ORDER BY version +` + +type RiverMigrationGetByLineAndVersionManyParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationGetByLineAndVersionMany(ctx context.Context, db DBTX, arg *RiverMigrationGetByLineAndVersionManyParams) ([]*RiverMigration, error) { + query := riverMigrationGetByLineAndVersionMany + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByVersion = `-- name: RiverMigrationGetByVersion :one +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version = ? +` + +type RiverMigrationGetByVersionRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetByVersion(ctx context.Context, db DBTX, version int64) (*RiverMigrationGetByVersionRow, error) { + row := db.QueryRowContext(ctx, riverMigrationGetByVersion, version) + var i RiverMigrationGetByVersionRow + err := row.Scan(&i.CreatedAt, &i.Version) + return &i, err +} + +const riverMigrationGetByVersionMany = `-- name: RiverMigrationGetByVersionMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) +ORDER BY version +` + +type RiverMigrationGetByVersionManyRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigrationGetByVersionManyRow, error) { + query := riverMigrationGetByVersionMany + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationGetByVersionManyRow + for rows.Next() { + var i RiverMigrationGetByVersionManyRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationInsertAssumingMainExec = `-- name: RiverMigrationInsertAssumingMainExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + version +) VALUES ( + ? +) +` + +func (q *Queries) RiverMigrationInsertAssumingMainExec(ctx context.Context, db DBTX, version int64) error { + _, err := db.ExecContext(ctx, riverMigrationInsertAssumingMainExec, version) + return err +} + +const riverMigrationInsertExec = `-- name: RiverMigrationInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + line, + version +) VALUES ( + ?, + ? +) +` + +type RiverMigrationInsertExecParams struct { + Line string + Version int64 +} + +func (q *Queries) RiverMigrationInsertExec(ctx context.Context, db DBTX, arg *RiverMigrationInsertExecParams) error { + _, err := db.ExecContext(ctx, riverMigrationInsertExec, arg.Line, arg.Version) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql new file mode 100644 index 00000000..058bb7f2 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql @@ -0,0 +1,91 @@ +CREATE TABLE river_queue ( + name VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL +); + +-- name: QueueCreateOrSetUpdatedAtExec :exec +INSERT INTO /* TEMPLATE: schema */river_queue ( + created_at, + metadata, + name, + paused_at, + updated_at +) VALUES ( + COALESCE(sqlc.narg('now'), NOW(6)), + CAST(sqlc.arg('metadata') AS JSON), + sqlc.arg('name'), + sqlc.narg('paused_at'), + COALESCE(sqlc.narg('updated_at'), sqlc.narg('now'), NOW(6)) +) ON DUPLICATE KEY UPDATE + updated_at = VALUES(updated_at); + +-- name: QueueGet :one +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +WHERE name = sqlc.arg('name'); + +-- name: QueueDeleteExpiredSelect :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE updated_at < sqlc.arg('updated_at_horizon') +ORDER BY name ASC +LIMIT ?; + +-- name: QueueDeleteExpiredExec :exec +DELETE FROM /* TEMPLATE: schema */river_queue +WHERE name IN (sqlc.slice('names')); + +-- name: QueueList :many +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +ORDER BY name ASC +LIMIT ?; + +-- name: QueueNameList :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE + name COLLATE utf8mb4_general_ci > sqlc.arg('after') + AND (sqlc.arg('match') = '' OR LOWER(name) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(sqlc.arg('match')), '%')) + AND name NOT IN (sqlc.slice('exclude')) +ORDER BY name ASC +LIMIT ?; + +-- MySQL evaluates SET clauses left-to-right using already-updated values, so +-- updated_at must be set BEFORE paused_at to see the original paused_at value. + +-- name: QueuePauseAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE paused_at END; + +-- name: QueuePauseByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE paused_at END +WHERE name = sqlc.arg('name'); + +-- name: QueueResumeAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = NULL; + +-- name: QueueResumeByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = NULL +WHERE name = sqlc.arg('name'); + +-- name: QueueUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_queue +SET + metadata = CASE WHEN CAST(sqlc.arg('metadata_do_update') AS SIGNED) THEN CAST(sqlc.arg('metadata') AS JSON) ELSE metadata END, + updated_at = NOW(6) +WHERE name = sqlc.arg('name'); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go new file mode 100644 index 00000000..38a4cef1 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go @@ -0,0 +1,298 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_queue.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const queueCreateOrSetUpdatedAtExec = `-- name: QueueCreateOrSetUpdatedAtExec :exec +INSERT INTO /* TEMPLATE: schema */river_queue ( + created_at, + metadata, + name, + paused_at, + updated_at +) VALUES ( + COALESCE(?, NOW(6)), + CAST(? AS JSON), + ?, + ?, + COALESCE(?, ?, NOW(6)) +) ON DUPLICATE KEY UPDATE + updated_at = VALUES(updated_at) +` + +type QueueCreateOrSetUpdatedAtExecParams struct { + Now interface{} + Metadata []byte + Name string + PausedAt sql.NullTime + UpdatedAt interface{} +} + +func (q *Queries) QueueCreateOrSetUpdatedAtExec(ctx context.Context, db DBTX, arg *QueueCreateOrSetUpdatedAtExecParams) error { + _, err := db.ExecContext(ctx, queueCreateOrSetUpdatedAtExec, + arg.Now, + arg.Metadata, + arg.Name, + arg.PausedAt, + arg.UpdatedAt, + arg.Now, + ) + return err +} + +const queueDeleteExpiredExec = `-- name: QueueDeleteExpiredExec :exec +DELETE FROM /* TEMPLATE: schema */river_queue +WHERE name IN (/*SLICE:names*/?) +` + +func (q *Queries) QueueDeleteExpiredExec(ctx context.Context, db DBTX, names []string) error { + query := queueDeleteExpiredExec + var queryParams []interface{} + if len(names) > 0 { + for _, v := range names { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:names*/?", strings.Repeat(",?", len(names))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:names*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const queueDeleteExpiredSelect = `-- name: QueueDeleteExpiredSelect :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE updated_at < ? +ORDER BY name ASC +LIMIT ? +` + +type QueueDeleteExpiredSelectParams struct { + UpdatedAtHorizon time.Time + Limit int32 +} + +func (q *Queries) QueueDeleteExpiredSelect(ctx context.Context, db DBTX, arg *QueueDeleteExpiredSelectParams) ([]string, error) { + rows, err := db.QueryContext(ctx, queueDeleteExpiredSelect, arg.UpdatedAtHorizon, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queueGet = `-- name: QueueGet :one +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +WHERE name = ? +` + +func (q *Queries) QueueGet(ctx context.Context, db DBTX, name string) (*RiverQueue, error) { + row := db.QueryRowContext(ctx, queueGet, name) + var i RiverQueue + err := row.Scan( + &i.Name, + &i.CreatedAt, + &i.Metadata, + &i.PausedAt, + &i.UpdatedAt, + ) + return &i, err +} + +const queueList = `-- name: QueueList :many +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +ORDER BY name ASC +LIMIT ? +` + +func (q *Queries) QueueList(ctx context.Context, db DBTX, limit int32) ([]*RiverQueue, error) { + rows, err := db.QueryContext(ctx, queueList, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverQueue + for rows.Next() { + var i RiverQueue + if err := rows.Scan( + &i.Name, + &i.CreatedAt, + &i.Metadata, + &i.PausedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queueNameList = `-- name: QueueNameList :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE + name COLLATE utf8mb4_general_ci > ? + AND (? = '' OR LOWER(name) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(?), '%')) + AND name NOT IN (/*SLICE:exclude*/?) +ORDER BY name ASC +LIMIT ? +` + +type QueueNameListParams struct { + After interface{} + Match string + Exclude []string + Limit int32 +} + +func (q *Queries) QueueNameList(ctx context.Context, db DBTX, arg *QueueNameListParams) ([]string, error) { + query := queueNameList + var queryParams []interface{} + queryParams = append(queryParams, arg.After) + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.Match) + if len(arg.Exclude) > 0 { + for _, v := range arg.Exclude { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:exclude*/?", strings.Repeat(",?", len(arg.Exclude))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:exclude*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Limit) + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queuePauseAll = `-- name: QueuePauseAll :execresult + +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE paused_at END +` + +type QueuePauseAllParams struct { + Now sql.NullTime +} + +// MySQL evaluates SET clauses left-to-right using already-updated values, so +// updated_at must be set BEFORE paused_at to see the original paused_at value. +func (q *Queries) QueuePauseAll(ctx context.Context, db DBTX, arg *QueuePauseAllParams) (sql.Result, error) { + return db.ExecContext(ctx, queuePauseAll, arg.Now, arg.Now) +} + +const queuePauseByName = `-- name: QueuePauseByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE paused_at END +WHERE name = ? +` + +type QueuePauseByNameParams struct { + Now sql.NullTime + Name string +} + +func (q *Queries) QueuePauseByName(ctx context.Context, db DBTX, arg *QueuePauseByNameParams) (sql.Result, error) { + return db.ExecContext(ctx, queuePauseByName, arg.Now, arg.Now, arg.Name) +} + +const queueResumeAll = `-- name: QueueResumeAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = NULL +` + +func (q *Queries) QueueResumeAll(ctx context.Context, db DBTX, now sql.NullTime) (sql.Result, error) { + return db.ExecContext(ctx, queueResumeAll, now) +} + +const queueResumeByName = `-- name: QueueResumeByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = NULL +WHERE name = ? +` + +type QueueResumeByNameParams struct { + Now sql.NullTime + Name string +} + +func (q *Queries) QueueResumeByName(ctx context.Context, db DBTX, arg *QueueResumeByNameParams) (sql.Result, error) { + return db.ExecContext(ctx, queueResumeByName, arg.Now, arg.Name) +} + +const queueUpdateExec = `-- name: QueueUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_queue +SET + metadata = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE metadata END, + updated_at = NOW(6) +WHERE name = ? +` + +type QueueUpdateExecParams struct { + MetadataDoUpdate int64 + Metadata []byte + Name string +} + +func (q *Queries) QueueUpdateExec(ctx context.Context, db DBTX, arg *QueueUpdateExecParams) error { + _, err := db.ExecContext(ctx, queueUpdateExec, arg.MetadataDoUpdate, arg.Metadata, arg.Name) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/schema.sql b/riverdriver/rivermysql/internal/dbsqlc/schema.sql new file mode 100644 index 00000000..8a3e5372 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/schema.sql @@ -0,0 +1,43 @@ +-- Dummy table definitions for INFORMATION_SCHEMA system tables so sqlc can +-- resolve column types. At runtime the template prefix replaces the empty +-- default with "INFORMATION_SCHEMA.", making queries target the real system +-- tables. +CREATE TABLE STATISTICS ( + INDEX_NAME VARCHAR(128) NOT NULL, + TABLE_NAME VARCHAR(128) NOT NULL, + TABLE_SCHEMA VARCHAR(128) NOT NULL +); + +CREATE TABLE TABLES ( + TABLE_NAME VARCHAR(128) NOT NULL, + TABLE_SCHEMA VARCHAR(128) NOT NULL +); + +-- name: IndexExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */STATISTICS + WHERE INDEX_NAME = sqlc.arg('index_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +); + +-- name: IndexGetTableName :one +SELECT TABLE_NAME +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME = sqlc.arg('index_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +LIMIT 1; + +-- name: IndexesExist :many +SELECT DISTINCT INDEX_NAME AS index_name +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME IN (sqlc.slice('index_names')) + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()); + +-- name: TableExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */TABLES + WHERE TABLE_NAME = sqlc.arg('table_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +); diff --git a/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go b/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go new file mode 100644 index 00000000..d84f137a --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go @@ -0,0 +1,120 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: schema.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" +) + +const indexExists = `-- name: IndexExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */STATISTICS + WHERE INDEX_NAME = ? + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +) +` + +type IndexExistsParams struct { + IndexName string + Schema sql.NullString +} + +func (q *Queries) IndexExists(ctx context.Context, db DBTX, arg *IndexExistsParams) (bool, error) { + row := db.QueryRowContext(ctx, indexExists, arg.IndexName, arg.Schema) + var exists bool + err := row.Scan(&exists) + return exists, err +} + +const indexGetTableName = `-- name: IndexGetTableName :one +SELECT TABLE_NAME +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME = ? + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +LIMIT 1 +` + +type IndexGetTableNameParams struct { + IndexName string + Schema sql.NullString +} + +func (q *Queries) IndexGetTableName(ctx context.Context, db DBTX, arg *IndexGetTableNameParams) (string, error) { + row := db.QueryRowContext(ctx, indexGetTableName, arg.IndexName, arg.Schema) + var table_name string + err := row.Scan(&table_name) + return table_name, err +} + +const indexesExist = `-- name: IndexesExist :many +SELECT DISTINCT INDEX_NAME AS index_name +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME IN (/*SLICE:index_names*/?) + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +` + +type IndexesExistParams struct { + IndexNames []string + Schema sql.NullString +} + +func (q *Queries) IndexesExist(ctx context.Context, db DBTX, arg *IndexesExistParams) ([]string, error) { + query := indexesExist + var queryParams []interface{} + if len(arg.IndexNames) > 0 { + for _, v := range arg.IndexNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:index_names*/?", strings.Repeat(",?", len(arg.IndexNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:index_names*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Schema) + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var index_name string + if err := rows.Scan(&index_name); err != nil { + return nil, err + } + items = append(items, index_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const tableExists = `-- name: TableExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */TABLES + WHERE TABLE_NAME = ? + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +) +` + +type TableExistsParams struct { + TableName string + Schema sql.NullString +} + +func (q *Queries) TableExists(ctx context.Context, db DBTX, arg *TableExistsParams) (bool, error) { + row := db.QueryRowContext(ctx, tableExists, arg.TableName, arg.Schema) + var exists bool + err := row.Scan(&exists) + return exists, err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml b/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..3af577c0 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,41 @@ +version: "2" +sql: + - engine: "mysql" + queries: + - river_job.sql + - river_leader.sql + - river_migration.sql + - river_queue.sql + - schema.sql + schema: + - river_client.sql + - river_client_queue.sql + - river_job.sql + - river_leader.sql + - river_migration.sql + - river_queue.sql + - schema.sql + gen: + go: + package: "dbsqlc" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_params_struct_pointers: true + emit_pointers_for_null_types: true + emit_result_struct_pointers: true + + rename: + ids: "IDs" + ttl: "TTL" + + overrides: + - db_type: "json" + go_type: + type: "[]byte" + - db_type: "json" + go_type: + type: "[]byte" + nullable: true + - db_type: "int" + go_type: "int64" diff --git a/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql b/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql new file mode 100644 index 00000000..d2f6b3fa --- /dev/null +++ b/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql @@ -0,0 +1 @@ +DROP TABLE /* TEMPLATE: schema */river_migration; diff --git a/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql b/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql new file mode 100644 index 00000000..d2ee3f73 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE /* TEMPLATE: schema */river_migration ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + version BIGINT NOT NULL, + CONSTRAINT version CHECK (version >= 1) +) ENGINE=InnoDB; + +CREATE UNIQUE INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration (version); diff --git a/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql b/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql new file mode 100644 index 00000000..095aa1dd --- /dev/null +++ b/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql @@ -0,0 +1,2 @@ +DROP TABLE /* TEMPLATE: schema */river_job; +DROP TABLE /* TEMPLATE: schema */river_leader; diff --git a/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql b/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql new file mode 100644 index 00000000..593cfa93 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql @@ -0,0 +1,43 @@ +CREATE TABLE /* TEMPLATE: schema */river_job( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + state VARCHAR(20) NOT NULL DEFAULT 'available', + attempt INT NOT NULL DEFAULT 0, + max_attempts INT NOT NULL, + attempted_at DATETIME(6) NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + finalized_at DATETIME(6) NULL, + scheduled_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + priority SMALLINT NOT NULL DEFAULT 1, + args JSON NULL, + attempted_by JSON NULL, -- JSON array of strings (no native text[] in MySQL) + errors JSON NULL, -- JSON array of error objects (no native jsonb[] in MySQL) + kind VARCHAR(128) NOT NULL, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + queue VARCHAR(128) NOT NULL DEFAULT 'default', + tags JSON NULL, -- JSON array of strings (no native varchar[] in MySQL) + CONSTRAINT finalized_or_finalized_at_null CHECK ( + (state IN ('cancelled', 'completed', 'discarded') AND finalized_at IS NOT NULL) OR + finalized_at IS NULL + ), + CONSTRAINT max_attempts_is_positive CHECK (max_attempts > 0), + CONSTRAINT priority_in_range CHECK (priority >= 1 AND priority <= 4), + CONSTRAINT queue_length CHECK (CHAR_LENGTH(queue) > 0 AND CHAR_LENGTH(queue) < 128), + CONSTRAINT kind_length CHECK (CHAR_LENGTH(kind) > 0 AND CHAR_LENGTH(kind) < 128), + CONSTRAINT state_valid CHECK (state IN ('available', 'cancelled', 'completed', 'discarded', 'retryable', 'running', 'scheduled')) +) ENGINE=InnoDB; + +CREATE INDEX river_job_kind ON /* TEMPLATE: schema */river_job (kind); +CREATE INDEX river_job_state_and_finalized_at_index ON /* TEMPLATE: schema */river_job (state, finalized_at); +CREATE INDEX river_job_prioritized_fetching_index ON /* TEMPLATE: schema */river_job (state, queue, priority, scheduled_at, id); + +-- MySQL does not support triggers for LISTEN/NOTIFY, so river_job_notify is +-- omitted. The MySQL driver operates in poll-only mode. + +CREATE TABLE /* TEMPLATE: schema */river_leader ( + elected_at DATETIME(6) NOT NULL, + expires_at DATETIME(6) NOT NULL, + leader_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL PRIMARY KEY, + CONSTRAINT name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128), + CONSTRAINT leader_id_length CHECK (CHAR_LENGTH(leader_id) > 0 AND CHAR_LENGTH(leader_id) < 128) +) ENGINE=InnoDB; diff --git a/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql new file mode 100644 index 00000000..93c33638 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql @@ -0,0 +1 @@ +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN tags JSON NULL; diff --git a/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql new file mode 100644 index 00000000..61b2f430 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql @@ -0,0 +1,2 @@ +UPDATE /* TEMPLATE: schema */river_job SET tags = JSON_ARRAY() WHERE tags IS NULL; +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN tags JSON NOT NULL DEFAULT (JSON_ARRAY()); diff --git a/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql b/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql new file mode 100644 index 00000000..891a0e81 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql @@ -0,0 +1,21 @@ +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN args JSON NULL; + +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN metadata JSON NOT NULL DEFAULT (JSON_OBJECT()); + +-- Cannot safely remove 'pending' from the CHECK constraint if rows reference +-- it, but we restore the original constraint form. +ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT finalized_or_finalized_at_null; +ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT finalized_or_finalized_at_null CHECK ( + (state IN ('cancelled', 'completed', 'discarded') AND finalized_at IS NOT NULL) OR + finalized_at IS NULL +); + +-- MySQL does not support triggers for LISTEN/NOTIFY, so no trigger changes. + +DROP TABLE /* TEMPLATE: schema */river_queue; + +ALTER TABLE /* TEMPLATE: schema */river_leader + ALTER COLUMN name DROP DEFAULT; + +ALTER TABLE /* TEMPLATE: schema */river_leader DROP CONSTRAINT name_length; +ALTER TABLE /* TEMPLATE: schema */river_leader ADD CONSTRAINT name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128); diff --git a/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql b/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql new file mode 100644 index 00000000..791ebcda --- /dev/null +++ b/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql @@ -0,0 +1,48 @@ +-- Make args NOT NULL with a default. +UPDATE /* TEMPLATE: schema */river_job SET args = JSON_OBJECT() WHERE args IS NULL; +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN args JSON NOT NULL DEFAULT (JSON_OBJECT()); + +-- Make metadata NOT NULL (it already had a default). +UPDATE /* TEMPLATE: schema */river_job SET metadata = JSON_OBJECT() WHERE metadata IS NULL; +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN metadata JSON NOT NULL DEFAULT (JSON_OBJECT()); + +-- Add 'pending' to the set of valid states. MySQL doesn't have enum types to +-- alter, so we update the CHECK constraint instead. +ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT state_valid; +ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT state_valid CHECK ( + state IN ('available', 'cancelled', 'completed', 'discarded', 'pending', 'retryable', 'running', 'scheduled') +); + +-- Update the finalized_at constraint to use the inverted form (matching +-- Postgres migration 004). +ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT finalized_or_finalized_at_null; +ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT finalized_or_finalized_at_null CHECK ( + (finalized_at IS NULL AND state NOT IN ('cancelled', 'completed', 'discarded')) OR + (finalized_at IS NOT NULL AND state IN ('cancelled', 'completed', 'discarded')) +); + +-- MySQL does not support triggers for LISTEN/NOTIFY, so river_job_notify +-- changes from Postgres are omitted. + +-- +-- Create table `river_queue`. +-- + +CREATE TABLE /* TEMPLATE: schema */river_queue ( + name VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL +) ENGINE=InnoDB; + +-- +-- Alter `river_leader` to add a default value of 'default' to `name`. +-- MySQL supports ALTER TABLE for changing defaults and constraints. +-- + +ALTER TABLE /* TEMPLATE: schema */river_leader + ALTER COLUMN name SET DEFAULT 'default'; + +ALTER TABLE /* TEMPLATE: schema */river_leader DROP CONSTRAINT name_length; +ALTER TABLE /* TEMPLATE: schema */river_leader ADD CONSTRAINT name_length CHECK (name = 'default'); diff --git a/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql b/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql new file mode 100644 index 00000000..1d77b022 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql @@ -0,0 +1,35 @@ +-- +-- Revert to migration table based only on `(version)`. +-- + +CREATE TABLE /* TEMPLATE: schema */river_migration_old ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + version BIGINT NOT NULL, + CONSTRAINT version CHECK (version >= 1) +) ENGINE=InnoDB; + +CREATE UNIQUE INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration_old (version); + +INSERT INTO /* TEMPLATE: schema */river_migration_old + (created_at, version) +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration; + +DROP TABLE /* TEMPLATE: schema */river_migration; + +ALTER TABLE /* TEMPLATE: schema */river_migration_old RENAME TO /* TEMPLATE: schema */river_migration; + +-- +-- Drop `river_job.unique_key` and its index. +-- + +DROP INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job; +ALTER TABLE /* TEMPLATE: schema */river_job DROP COLUMN unique_key; + +-- +-- Drop `river_client` and derivative. +-- + +DROP TABLE /* TEMPLATE: schema */river_client_queue; +DROP TABLE /* TEMPLATE: schema */river_client; diff --git a/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql b/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql new file mode 100644 index 00000000..03e54abd --- /dev/null +++ b/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql @@ -0,0 +1,60 @@ +-- +-- Rebuild the migration table so it's based on `(line, version)`. +-- + +DROP INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration; + +CREATE TABLE /* TEMPLATE: schema */river_migration_new ( + line VARCHAR(128) NOT NULL, + version BIGINT NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + CONSTRAINT line_length CHECK (CHAR_LENGTH(line) > 0 AND CHAR_LENGTH(line) < 128), + CONSTRAINT version_gte_1 CHECK (version >= 1), + PRIMARY KEY (line, version) +) ENGINE=InnoDB; + +INSERT INTO /* TEMPLATE: schema */river_migration_new + (created_at, line, version) +SELECT created_at, 'main', version +FROM /* TEMPLATE: schema */river_migration; + +DROP TABLE /* TEMPLATE: schema */river_migration; + +ALTER TABLE /* TEMPLATE: schema */river_migration_new RENAME TO /* TEMPLATE: schema */river_migration; + +-- +-- Add `river_job.unique_key` and bring up an index on it. +-- + +ALTER TABLE /* TEMPLATE: schema */river_job ADD COLUMN unique_key VARBINARY(255) NULL; + +CREATE UNIQUE INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job (kind, unique_key); + +-- +-- Create `river_client` and derivative. +-- + +CREATE TABLE /* TEMPLATE: schema */river_client ( + id VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL, + CONSTRAINT client_name_length CHECK (CHAR_LENGTH(id) > 0 AND CHAR_LENGTH(id) < 128) +) ENGINE=InnoDB; + +CREATE TABLE /* TEMPLATE: schema */river_client_queue ( + river_client_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + max_workers INT NOT NULL DEFAULT 0, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + num_jobs_completed BIGINT NOT NULL DEFAULT 0, + num_jobs_running BIGINT NOT NULL DEFAULT 0, + updated_at DATETIME(6) NOT NULL, + PRIMARY KEY (river_client_id, name), + CONSTRAINT fk_river_client FOREIGN KEY (river_client_id) REFERENCES river_client (id) ON DELETE CASCADE, + CONSTRAINT cq_name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128), + CONSTRAINT num_jobs_completed_zero_or_positive CHECK (num_jobs_completed >= 0), + CONSTRAINT num_jobs_running_zero_or_positive CHECK (num_jobs_running >= 0) +) ENGINE=InnoDB; diff --git a/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql b/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql new file mode 100644 index 00000000..0e4bf09f --- /dev/null +++ b/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql @@ -0,0 +1,9 @@ +-- +-- Drop the functional unique index and unique_states column. +-- + +DROP INDEX river_job_unique_idx ON /* TEMPLATE: schema */river_job; +ALTER TABLE /* TEMPLATE: schema */river_job DROP COLUMN unique_states; + +-- Recreate the old unique index from migration 005. +CREATE UNIQUE INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job (kind, unique_key); diff --git a/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql b/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql new file mode 100644 index 00000000..2f4883cc --- /dev/null +++ b/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql @@ -0,0 +1,45 @@ +-- +-- Add `river_job.unique_states` and bring up an index on it. +-- +-- MySQL 8.0+ supports functional indexes (index on an expression). We use one +-- to index `unique_key` only when the job's state matches the bitmask, which +-- is equivalent to Postgres's partial index with `river_job_state_in_bitmask`. +-- The expression evaluates to NULL when the constraint shouldn't be active, +-- and MySQL's UNIQUE allows multiple NULLs. +-- + +ALTER TABLE /* TEMPLATE: schema */river_job ADD COLUMN unique_states SMALLINT NULL; + +-- A "functional index" (feature specific to MySQL 8.0+) over `unique_key`, +-- `unique_states`, and `state`. The expression evaluates to `unique_key` when +-- the job's current state has its corresponding bit set in `unique_states`, and +-- to NULL otherwise. MySQL's UNIQUE indexes permit multiple NULL values, so +-- rows where the constraint is inactive (NULL result) never conflict with each +-- other, while rows where it's active are checked for uniqueness on +-- `unique_key`. This is the MySQL equivalent of Postgres's partial index with +-- `WHERE ... AND river_job_state_in_bitmask(unique_states, state)`. +-- +-- Unlike Postgres's partial indexes which exclude non-matching rows from the +-- index entirely, MySQL's functional index still stores an entry for every row +-- (NULLs included), so the storage savings aren't equivalent. The uniqueness +-- semantics are the same though. +CREATE UNIQUE INDEX river_job_unique_idx ON /* TEMPLATE: schema */river_job (( + CASE WHEN unique_key IS NOT NULL AND unique_states IS NOT NULL AND + CASE state + WHEN 'available' THEN unique_states & (1 << 0) + WHEN 'cancelled' THEN unique_states & (1 << 1) + WHEN 'completed' THEN unique_states & (1 << 2) + WHEN 'discarded' THEN unique_states & (1 << 3) + WHEN 'pending' THEN unique_states & (1 << 4) + WHEN 'retryable' THEN unique_states & (1 << 5) + WHEN 'running' THEN unique_states & (1 << 6) + WHEN 'scheduled' THEN unique_states & (1 << 7) + ELSE 0 + END >= 1 + THEN unique_key + ELSE NULL + END +)); + +-- Remove the old unique index from migration 005. +DROP INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job; diff --git a/riverdriver/rivermysql/river_mysql_driver.go b/riverdriver/rivermysql/river_mysql_driver.go new file mode 100644 index 00000000..924cb91d --- /dev/null +++ b/riverdriver/rivermysql/river_mysql_driver.go @@ -0,0 +1,1962 @@ +// Package rivermysql provides a River driver implementation for MySQL. +// +// This driver targets MySQL 8.0+ and requires `go-sql-driver/mysql` or a +// compatible driver registered with `database/sql`. The DSN should include +// `parseTime=true` to ensure `DATETIME` columns are correctly scanned into +// `time.Time` values. +// +// MySQL does not support LISTEN/NOTIFY, so this driver operates in poll-only +// mode. It also does not support `RETURNING` clauses, so most write operations +// are carried out as two-step operations (write + read). +// +// This driver is currently in early development. It's exercised in the test +// suite, but has minimal real world use as of yet. +package rivermysql + +import ( + "context" + "database/sql" + "embed" + "encoding/json" + "errors" + "fmt" + "io/fs" + "math" + "slices" + "strings" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/rivermysql/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" + "github.com/riverqueue/river/rivershared/uniquestates" + "github.com/riverqueue/river/rivershared/util/dbutil" + "github.com/riverqueue/river/rivershared/util/ptrutil" + "github.com/riverqueue/river/rivershared/util/randutil" + "github.com/riverqueue/river/rivershared/util/savepointutil" + "github.com/riverqueue/river/rivershared/util/sliceutil" + "github.com/riverqueue/river/rivertype" +) + +//go:embed migration/*/*.sql +var migrationFS embed.FS + +// Driver is an implementation of riverdriver.Driver for MySQL. +type Driver struct { + dbPool *sql.DB + replacer sqlctemplate.Replacer +} + +// New returns a new MySQL driver for use with River. +// +// It takes an sql.DB to use for use with River. The DSN should include +// `parseTime=true` for correct time handling. The pool must not be closed while +// associated River objects are running. +func New(dbPool *sql.DB) *Driver { + return &Driver{ + dbPool: dbPool, + replacer: sqlctemplate.Replacer{UnnumberedPlaceholders: true}, + } +} + +const argPlaceholder = "?" + +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "mysql" } +func (d *Driver) SafeIdentifier(ident string) string { return mysqlIdentifier(ident) } + +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d, nil} +} + +func (d *Driver) GetListener(params *riverdriver.GetListenenerParams) riverdriver.Listener { + panic(riverdriver.ErrNotImplemented) +} + +func (d *Driver) GetMigrationDefaultLines() []string { return []string{riverdriver.MigrationLineMain} } +func (d *Driver) GetMigrationFS(line string) fs.FS { + if line == riverdriver.MigrationLineMain { + return migrationFS + } + panic("migration line does not exist: " + line) +} +func (d *Driver) GetMigrationLines() []string { return []string{riverdriver.MigrationLineMain} } +func (d *Driver) GetMigrationTruncateTables(line string, version int) []string { + if line == riverdriver.MigrationLineMain { + return riverdriver.MigrationLineMainTruncateTables(version) + } + panic("migration line does not exist: " + line) +} + +func (d *Driver) PoolIsSet() bool { return d.dbPool != nil } +func (d *Driver) PoolSet(dbPool any) error { + if d.dbPool != nil { + return errors.New("cannot PoolSet when internal pool is already non-nil") + } + d.dbPool = dbPool.(*sql.DB) //nolint:forcetypeassert + return nil +} + +func (d *Driver) SQLFragmentColumnIn(column string, values any) (string, any, error) { + arg, err := json.Marshal(values) + if err != nil { + return "", nil, err + } + + // Use JSON_TABLE to expand the JSON array into rows for an IN clause. + // The arg is passed through the template system's NamedArgs as @column, + // which the template replacer turns into a positional placeholder. The + // templateReplaceWrapper strips the number suffix to produce plain `?` + // that MySQL expects. Use VARCHAR(255) as the column type since it + // handles both integer and string comparisons via implicit conversion. + // COLLATE must match the table's column collation. utf8mb4_0900_ai_ci is + // MySQL 8.0+'s default collation for utf8mb4. Without this, JSON_TABLE + // produces values with utf8mb4_general_ci which causes a collation mismatch + // in comparisons. + return fmt.Sprintf("%s IN (SELECT jt.val COLLATE utf8mb4_0900_ai_ci FROM JSON_TABLE(CAST(@%s AS JSON), '$[*]' COLUMNS(val VARCHAR(255) PATH '$')) AS jt)", column, column), arg, nil +} + +func (d *Driver) SupportsListener() bool { return false } +func (d *Driver) SupportsListenNotify() bool { return false } +func (d *Driver) TimePrecision() time.Duration { return time.Microsecond } + +func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{UnnumberedPlaceholders: true} + } else { + replacer = &d.replacer + } + + executorTx := &ExecutorTx{tx: tx} + executorTx.Executor = Executor{nil, templateReplaceWrapper{tx, replacer}, d, executorTx} + + return executorTx +} + +func (d *Driver) UnwrapTx(execTx riverdriver.ExecutorTx) *sql.Tx { + switch execTx := execTx.(type) { + case *ExecutorSubTx: + return execTx.tx + case *ExecutorTx: + return execTx.tx + } + panic("unhandled executor type") +} + +type Executor struct { + dbPool *sql.DB + dbtx templateReplaceWrapper + driver *Driver + execTx riverdriver.ExecutorTx +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if e.execTx != nil { + return e.execTx.Begin(ctx) + } + + tx, err := e.dbPool.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + executorTx := &ExecutorTx{tx: tx} + executorTx.Executor = Executor{nil, templateReplaceWrapper{tx, &e.driver.replacer}, e.driver, executorTx} + return executorTx, nil +} + +func (e *Executor) ColumnExists(ctx context.Context, params *riverdriver.ColumnExistsParams) (bool, error) { + var queryArgs []any + query := `SELECT EXISTS ( + SELECT 1 + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = ? AND COLUMN_NAME = ?` + queryArgs = append(queryArgs, params.Table, params.Column) + + if params.Schema != "" { + query += " AND TABLE_SCHEMA = ?" + queryArgs = append(queryArgs, params.Schema) + } else { + query += " AND TABLE_SCHEMA = DATABASE()" + } + query += ")" + + var exists bool + if err := e.dbtx.QueryRowContext(ctx, query, queryArgs...).Scan(&exists); err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string, args ...any) error { + _, err := e.dbtx.ExecContext(ctx, sql, args...) + return interpretError(err) +} + +func (e *Executor) IndexDropIfExists(ctx context.Context, params *riverdriver.IndexDropIfExistsParams) error { + indexName := strings.TrimSpace(params.Index) + + exists, err := e.IndexExists(ctx, &riverdriver.IndexExistsParams{ + Index: indexName, + Schema: params.Schema, + }) + if err != nil { + return err + } + if !exists { + return nil + } + + // MySQL's DROP INDEX requires the table name. Look it up from the index. + tableName, err := dbsqlc.New().IndexGetTableName(informationSchemaParam(ctx), e.dbtx, &dbsqlc.IndexGetTableNameParams{ + IndexName: indexName, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return interpretError(err) + } + + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." + } + + _, err = e.dbtx.ExecContext(ctx, "DROP INDEX "+mysqlIdentifier(indexName)+" ON "+maybeSchema+mysqlIdentifier(tableName)) + return interpretError(err) +} + +func (e *Executor) IndexExists(ctx context.Context, params *riverdriver.IndexExistsParams) (bool, error) { + ctx = informationSchemaParam(ctx) + + exists, err := dbsqlc.New().IndexExists(ctx, e.dbtx, &dbsqlc.IndexExistsParams{ + IndexName: params.Index, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) IndexReindex(ctx context.Context, params *riverdriver.IndexReindexParams) error { + // MySQL doesn't have a direct REINDEX command. We use OPTIMIZE TABLE + // or ALTER TABLE ... FORCE as the closest equivalent. For now, use + // ANALYZE TABLE which updates index statistics. + var tableName string + query := `SELECT TABLE_NAME FROM INFORMATION_SCHEMA.STATISTICS WHERE INDEX_NAME = ?` + queryArgs := []any{params.Index} + if params.Schema != "" { + query += " AND TABLE_SCHEMA = ?" + queryArgs = append(queryArgs, params.Schema) + } else { + query += " AND TABLE_SCHEMA = DATABASE()" + } + query += " LIMIT 1" + + err := e.dbtx.QueryRowContext(ctx, query, queryArgs...).Scan(&tableName) + if err != nil { + return interpretError(err) + } + + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." + } + + _, err = e.dbtx.ExecContext(ctx, "ANALYZE TABLE "+maybeSchema+mysqlIdentifier(tableName)) + return interpretError(err) +} + +func (e *Executor) IndexesExist(ctx context.Context, params *riverdriver.IndexesExistParams) (map[string]bool, error) { + foundNames, err := dbsqlc.New().IndexesExist(informationSchemaParam(ctx), e.dbtx, &dbsqlc.IndexesExistParams{ + IndexNames: params.IndexNames, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return nil, interpretError(err) + } + + exists := make(map[string]bool, len(params.IndexNames)) + for _, name := range foundNames { + exists[name] = true + } + return exists, nil +} + +func (e *Executor) JobCancel(ctx context.Context, params *riverdriver.JobCancelParams) (*rivertype.JobRow, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + res, err := dbsqlc.New().JobCancelExec(ctx, dbtx, &dbsqlc.JobCancelExecParams{ + ID: params.ID, + CancelAttemptedAt: params.CancelAttemptedAt.UTC().Format(time.RFC3339Nano), + Now: nullTimeFromPtr(params.Now), + }) + if err != nil { + return nil, interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + if rowsAffected < 1 { + // No rows were updated, return the job as-is (already finalized) + return jobRowFromInternal(job) + } + + return jobRowFromInternal(job) + }) +} + +func (e *Executor) JobCountByAllStates(ctx context.Context, params *riverdriver.JobCountByAllStatesParams) (map[rivertype.JobState]int, error) { + counts, err := dbsqlc.New().JobCountByAllStates(schemaTemplateParam(ctx, params.Schema), e.dbtx) + if err != nil { + return nil, interpretError(err) + } + countsMap := make(map[rivertype.JobState]int) + for _, state := range rivertype.JobStates() { + countsMap[state] = 0 + } + for _, count := range counts { + countsMap[rivertype.JobState(count.State)] = int(count.Count) + } + return countsMap, nil +} + +func (e *Executor) JobCountByQueueAndState(ctx context.Context, params *riverdriver.JobCountByQueueAndStateParams) ([]*riverdriver.JobCountByQueueAndStateResult, error) { + rows, err := dbsqlc.New().JobCountByQueueAndState(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobCountByQueueAndStateParams{ + QueueNames: params.QueueNames, + }) + if err != nil { + return nil, interpretError(err) + } + + // MySQL's GROUP BY only returns queues that have jobs, so fill in zero + // counts for any requested queues not in the result. + countsByQueue := make(map[string]*riverdriver.JobCountByQueueAndStateResult, len(rows)) + for _, row := range rows { + countsByQueue[row.Queue] = &riverdriver.JobCountByQueueAndStateResult{ + CountAvailable: row.CountAvailable, + CountRunning: row.CountRunning, + Queue: row.Queue, + } + } + + var ( + queueNames = slices.Compact(slices.Sorted(slices.Values(params.QueueNames))) + results = make([]*riverdriver.JobCountByQueueAndStateResult, len(queueNames)) + ) + for i, name := range queueNames { + if result, ok := countsByQueue[name]; ok { + results[i] = result + } else { + results[i] = &riverdriver.JobCountByQueueAndStateResult{Queue: name} + } + } + + return results, nil +} + +func (e *Executor) JobCountByState(ctx context.Context, params *riverdriver.JobCountByStateParams) (int, error) { + numJobs, err := dbsqlc.New().JobCountByState(schemaTemplateParam(ctx, params.Schema), e.dbtx, string(params.State)) + if err != nil { + return 0, err + } + return int(numJobs), nil +} + +func (e *Executor) JobDelete(ctx context.Context, params *riverdriver.JobDeleteParams) (*rivertype.JobRow, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + // First fetch the job so we can return it + job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + res, err := dbsqlc.New().JobDeleteExec(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return nil, interpretError(err) + } + + if rowsAffected < 1 { + if rivertype.JobState(job.State) == rivertype.JobStateRunning { + return nil, rivertype.ErrJobRunning + } + return nil, fmt.Errorf("bug; expected only to fetch a job with state %q, but was: %q", rivertype.JobStateRunning, job.State) + } + + return jobRowFromInternal(job) + }) +} + +func (e *Executor) JobDeleteBefore(ctx context.Context, params *riverdriver.JobDeleteBeforeParams) (int, error) { + if len(params.QueuesIncluded) > 0 { + return 0, riverdriver.ErrNotImplemented + } + + var queuesExcludedEmpty int64 + if len(params.QueuesExcluded) < 1 { + queuesExcludedEmpty = 1 + } + + res, err := dbsqlc.New().JobDeleteBefore(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobDeleteBeforeParams{ + CancelledFinalizedAtHorizon: sql.NullTime{Time: params.CancelledFinalizedAtHorizon, Valid: true}, + CompletedFinalizedAtHorizon: sql.NullTime{Time: params.CompletedFinalizedAtHorizon, Valid: true}, + DiscardedFinalizedAtHorizon: sql.NullTime{Time: params.DiscardedFinalizedAtHorizon, Valid: true}, + Limit: int32(params.Max), //nolint:gosec + QueuesExcluded: params.QueuesExcluded, + QueuesExcludedEmpty: queuesExcludedEmpty, + }) + if err != nil { + return 0, interpretError(err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return 0, interpretError(err) + } + return int(rowsAffected), nil +} + +func (e *Executor) JobDeleteMany(ctx context.Context, params *riverdriver.JobDeleteManyParams) ([]*rivertype.JobRow, error) { + var jobs []*dbsqlc.RiverJob + { + ctx := sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: params.WhereClause}, + }, params.NamedArgs) + ctx = schemaTemplateParam(ctx, params.Schema) + + // Step 1: Select the rows to delete + var err error + jobs, err = dbsqlc.New().JobDeleteManySelect(ctx, e.dbtx, params.Max) + if err != nil { + return nil, interpretError(err) + } + } + + if len(jobs) > 0 { + // Step 2: Delete those rows by ID. Use a fresh context with only + // schema replacement — the select context has where_clause/order_by_clause + // template params that would panic on the simpler DELETE SQL. + ctx := schemaTemplateParam(ctx, params.Schema) + ids := sliceutil.Map(jobs, func(j *dbsqlc.RiverJob) int64 { return j.ID }) + if err := dbsqlc.New().JobDeleteManyExec(ctx, e.dbtx, ids); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetAvailable(ctx context.Context, params *riverdriver.JobGetAvailableParams) ([]*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + ids, err := dbsqlc.New().JobGetAvailableIDs(ctx, e.dbtx, &dbsqlc.JobGetAvailableIDsParams{ + Queue: params.Queue, + Now: nullTimeFromPtr(params.Now), + Limit: int32(params.MaxToLock), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + + if len(ids) == 0 { + return nil, nil + } + + if err := dbsqlc.New().JobGetAvailableUpdate(ctx, e.dbtx, &dbsqlc.JobGetAvailableUpdateParams{ + Now: nullTimeFromPtr(params.Now), + MaxAttemptedBy: int64(params.MaxAttemptedBy), + AttemptedBy: params.ClientID, + ID: ids, + }); err != nil { + return nil, interpretError(err) + } + + jobs, err := dbsqlc.New().JobGetByIDManyOrdered(ctx, e.dbtx, ids) + if err != nil { + return nil, interpretError(err) + } + + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetByID(ctx context.Context, params *riverdriver.JobGetByIDParams) (*rivertype.JobRow, error) { + job, err := dbsqlc.New().JobGetByID(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) +} + +func (e *Executor) JobGetByIDMany(ctx context.Context, params *riverdriver.JobGetByIDManyParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetByIDMany(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetByKindMany(ctx context.Context, params *riverdriver.JobGetByKindManyParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetByKindMany(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Kind) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetStuck(ctx context.Context, params *riverdriver.JobGetStuckParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetStuck(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobGetStuckParams{ + Limit: int32(params.Max), //nolint:gosec + StuckHorizon: sql.NullTime{Time: params.StuckHorizon, Valid: true}, + }) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobInsertFastMany(ctx context.Context, params *riverdriver.JobInsertFastManyParams) ([]*riverdriver.JobInsertFastResult, error) { + var ( + insertRes = make([]*riverdriver.JobInsertFastResult, len(params.Jobs)) + + // We use a special `(xmax != 0)` trick in Postgres to determine whether + // an upserted row was inserted or skipped, but as far as I can find, + // there's no such trick possible in MySQL. Instead, we roll a random + // nonce and insert it to metadata. If the same nonce comes back, we know + // we really inserted the row. If not, we're getting an existing row back. + uniqueNonce = randutil.Hex(8) + ) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + ids := make([]int64, len(params.Jobs)) + for i, params := range params.Jobs { + insertParams, err := jobInsertFastParams(params, uniqueNonce) + if err != nil { + return err + } + + res, err := dbsqlc.New().JobInsertFast(ctx, dbtx, insertParams) + if err != nil { + return interpretError(err) + } + ids[i], err = res.LastInsertId() + if err != nil { + return err + } + } + + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, ids) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + for i, id := range ids { + job, err := jobRowFromInternal(jobsByID[id]) + if err != nil { + return err + } + + insertRes[i] = &riverdriver.JobInsertFastResult{ + Job: job, + UniqueSkippedAsDuplicate: gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyUniqueNonce).Str != uniqueNonce, + } + } + + return nil + }); err != nil { + return nil, err + } + + return insertRes, nil +} + +func (e *Executor) JobInsertFastManyNoReturning(ctx context.Context, params *riverdriver.JobInsertFastManyParams) (int, error) { + var totalRowsAffected int + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, params := range params.Jobs { + insertParams, err := jobInsertFastParams(params, "") + if err != nil { + return err + } + + res, err := dbsqlc.New().JobInsertFast(ctx, dbtx, insertParams) + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + totalRowsAffected += int(rowsAffected) + } + + return nil + }); err != nil { + return 0, err + } + + return totalRowsAffected, nil +} + +// jobInsertFastParams builds the common insert parameters for a single job. +// If uniqueNonce is non-empty, it's set in the job's metadata for duplicate +// detection. +func jobInsertFastParams(params *riverdriver.JobInsertFastParams, uniqueNonce string) (*dbsqlc.JobInsertFastParams, error) { + metadata := sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")) + if uniqueNonce != "" { + // Error intentionally ignored — SetBytes on valid JSON can't fail + // for a simple key addition. + metadata, _ = sjson.SetBytes(metadata, rivercommon.MetadataKeyUniqueNonce, uniqueNonce) + } + + tags, err := json.Marshal(params.Tags) + if err != nil { + return nil, err + } + + var uniqueStates sql.NullInt16 + if params.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(params.UniqueStates), Valid: true} + } + + var id sql.NullInt64 + if params.ID != nil { + id = sql.NullInt64{Int64: *params.ID, Valid: true} + } + + return &dbsqlc.JobInsertFastParams{ + ID: id, + Args: params.EncodedArgs, + CreatedAt: params.CreatedAt, + Kind: params.Kind, + MaxAttempts: int64(params.MaxAttempts), + Metadata: metadata, + Priority: int16(params.Priority), //nolint:gosec + Queue: params.Queue, + ScheduledAt: params.ScheduledAt, + State: string(params.State), + Tags: tags, + UniqueKey: nullStringFromBytes(params.UniqueKey), + UniqueStates: uniqueStates, + }, nil +} + +func (e *Executor) JobInsertFull(ctx context.Context, params *riverdriver.JobInsertFullParams) (*rivertype.JobRow, error) { + var attemptedBy []byte + if params.AttemptedBy != nil { + var err error + attemptedBy, err = json.Marshal(params.AttemptedBy) + if err != nil { + return nil, err + } + } + + var errorsData []byte + if len(params.Errors) > 0 { + var err error + errorsData, err = json.Marshal(sliceutil.Map(params.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return nil, err + } + } + + tags, err := json.Marshal(params.Tags) + if err != nil { + return nil, err + } + + var uniqueStates sql.NullInt16 + if params.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(params.UniqueStates), Valid: true} + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + lastInsertID, err := dbsqlc.New().JobInsertFullExec(ctx, e.dbtx, &dbsqlc.JobInsertFullExecParams{ + Attempt: int64(params.Attempt), + AttemptedAt: nullTimeFromPtr(params.AttemptedAt), + AttemptedBy: attemptedBy, + Args: params.EncodedArgs, + CreatedAt: params.CreatedAt, + Errors: errorsData, + FinalizedAt: nullTimeFromPtr(params.FinalizedAt), + Kind: params.Kind, + MaxAttempts: int64(params.MaxAttempts), + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + Priority: int16(params.Priority), //nolint:gosec + Queue: params.Queue, + ScheduledAt: params.ScheduledAt, + State: string(params.State), + Tags: tags, + UniqueKey: nullStringFromBytes(params.UniqueKey), + UniqueStates: uniqueStates, + }) + if err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, lastInsertID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) +} + +func (e *Executor) JobInsertFullMany(ctx context.Context, params *riverdriver.JobInsertFullManyParams) ([]*rivertype.JobRow, error) { + insertRes := make([]*rivertype.JobRow, len(params.Jobs)) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + ids := make([]int64, len(params.Jobs)) + + for i, jobParams := range params.Jobs { + var attemptedBy []byte + if jobParams.AttemptedBy != nil { + var err error + attemptedBy, err = json.Marshal(jobParams.AttemptedBy) + if err != nil { + return err + } + } + + var errorsData []byte + if len(jobParams.Errors) > 0 { + var err error + errorsData, err = json.Marshal(sliceutil.Map(jobParams.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return err + } + } + + tags, err := json.Marshal(jobParams.Tags) + if err != nil { + return err + } + + var uniqueStates sql.NullInt16 + if jobParams.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(jobParams.UniqueStates), Valid: true} + } + + ids[i], err = dbsqlc.New().JobInsertFullExec(ctx, dbtx, &dbsqlc.JobInsertFullExecParams{ + Attempt: int64(jobParams.Attempt), + AttemptedAt: nullTimeFromPtr(jobParams.AttemptedAt), + AttemptedBy: attemptedBy, + Args: jobParams.EncodedArgs, + CreatedAt: jobParams.CreatedAt, + Errors: errorsData, + FinalizedAt: nullTimeFromPtr(jobParams.FinalizedAt), + Kind: jobParams.Kind, + MaxAttempts: int64(jobParams.MaxAttempts), + Metadata: sliceutil.FirstNonEmpty(jobParams.Metadata, []byte("{}")), + Priority: int16(jobParams.Priority), //nolint:gosec + Queue: jobParams.Queue, + ScheduledAt: jobParams.ScheduledAt, + State: string(jobParams.State), + Tags: tags, + UniqueKey: nullStringFromBytes(jobParams.UniqueKey), + UniqueStates: uniqueStates, + }) + if err != nil { + return interpretError(err) + } + } + + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, ids) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + for i, id := range ids { + insertRes[i], err = jobRowFromInternal(jobsByID[id]) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return insertRes, nil +} + +func (e *Executor) JobKindList(ctx context.Context, params *riverdriver.JobKindListParams) ([]string, error) { + exclude := params.Exclude + if len(exclude) == 0 { + exclude = []string{""} + } + + kinds, err := dbsqlc.New().JobKindList(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobKindListParams{ + After: params.After, + Exclude: exclude, + Match: params.Match, + Limit: int32(min(params.Max, math.MaxInt32)), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + return kinds, nil +} + +func (e *Executor) JobList(ctx context.Context, params *riverdriver.JobListParams) ([]*rivertype.JobRow, error) { + ctx = sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: params.WhereClause}, + }, params.NamedArgs) + + jobs, err := dbsqlc.New().JobList(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Max) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobRescueMany(ctx context.Context, params *riverdriver.JobRescueManyParams) (*struct{}, error) { + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for i := range params.ID { + if err := dbsqlc.New().JobRescue(ctx, dbtx, &dbsqlc.JobRescueParams{ + ID: params.ID[i], + Error: params.Error[i], + FinalizedAt: nullTimeFromPtr(params.FinalizedAt[i]), + ScheduledAt: params.ScheduledAt[i].UTC(), + State: params.State[i], + }); err != nil { + return interpretError(err) + } + } + + return nil + }); err != nil { + return nil, err + } + + return &struct{}{}, nil +} + +func (e *Executor) JobRetry(ctx context.Context, params *riverdriver.JobRetryParams) (*rivertype.JobRow, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + _, err := dbsqlc.New().JobRetryExec(ctx, dbtx, &dbsqlc.JobRetryExecParams{ + ID: params.ID, + Now: nullTimeFromPtr(params.Now), + }) + if err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) + }) +} + +func (e *Executor) JobSchedule(ctx context.Context, params *riverdriver.JobScheduleParams) ([]*riverdriver.JobScheduleResult, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) ([]*riverdriver.JobScheduleResult, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + scheduleResults, err := dbsqlc.New().JobSchedule(ctx, dbtx, &dbsqlc.JobScheduleParams{ + Limit: int32(params.Max), //nolint:gosec + Now: nullTimeFromPtr(params.Now), + }) + if err != nil { + return nil, interpretError(err) + } + + var ( + allIDs []int64 + availIDs []int64 + discardIDs []int64 + discardSet = make(map[int64]bool) + ) + + for _, result := range scheduleResults { + allIDs = append(allIDs, result.ID) + if result.ConflictDiscarded != 0 { + discardIDs = append(discardIDs, result.ID) + discardSet[result.ID] = true + } else { + availIDs = append(availIDs, result.ID) + } + } + + if len(availIDs) > 0 { + if err := dbsqlc.New().JobScheduleSetAvailableExec(ctx, dbtx, availIDs); err != nil { + return nil, interpretError(err) + } + } + + if len(discardIDs) > 0 { + if err := dbsqlc.New().JobScheduleSetDiscardedExec(ctx, dbtx, &dbsqlc.JobScheduleSetDiscardedExecParams{ + ID: discardIDs, + Now: nullTimeFromPtr(params.Now), + }); err != nil { + return nil, interpretError(err) + } + } + + if len(allIDs) == 0 { + return nil, nil + } + + updatedJobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, allIDs) + if err != nil { + return nil, interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(updatedJobs)) + for _, j := range updatedJobs { + jobsByID[j.ID] = j + } + + // Return results in the same order as scheduleResults. + results := make([]*riverdriver.JobScheduleResult, len(scheduleResults)) + for i, sr := range scheduleResults { + job, err := jobRowFromInternal(jobsByID[sr.ID]) + if err != nil { + return nil, err + } + results[i] = &riverdriver.JobScheduleResult{ConflictDiscarded: discardSet[sr.ID], Job: *job} + } + + return results, nil + }) +} + +func (e *Executor) JobSetStateIfRunningMany(ctx context.Context, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { + setRes := make([]*rivertype.JobRow, len(params.ID)) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + // Step 1: Execute all state changes. + for i := range params.ID { + setStateParams := &dbsqlc.JobSetStateIfRunningExecParams{ + ID: params.ID[i], + Error: []byte("{}"), + MetadataUpdates: []byte("{}"), + Now: nullTimeFromPtr(params.Now), + State: string(params.State[i]), + } + + if params.Attempt[i] != nil { + setStateParams.AttemptDoUpdate = 1 + setStateParams.Attempt = int64(*params.Attempt[i]) + } + if params.ErrData[i] != nil { + setStateParams.ErrorsDoUpdate = 1 + setStateParams.Error = params.ErrData[i] + } + if params.FinalizedAt[i] != nil { + setStateParams.FinalizedAtDoUpdate = 1 + setStateParams.FinalizedAt = nullTimeFromPtr(params.FinalizedAt[i]) + } + if params.MetadataDoMerge[i] { + setStateParams.MetadataDoMerge = 1 + setStateParams.MetadataUpdates = params.MetadataUpdates[i] + } + if params.ScheduledAt[i] != nil { + setStateParams.ScheduledAtDoUpdate = 1 + setStateParams.ScheduledAt = *params.ScheduledAt[i] + } + + if err := dbsqlc.New().JobSetStateIfRunningExec(ctx, dbtx, setStateParams); err != nil { + return fmt.Errorf("error setting job state: %w", err) + } + } + + // Step 2: Batch fetch all jobs. + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, params.ID) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + // Step 3: For jobs that weren't running, merge metadata if requested. + var metadataMergedIDs []int64 + for i := range params.ID { + job := jobsByID[params.ID[i]] + if job == nil { + continue + } + + if rivertype.JobState(job.State) != rivertype.JobStateRunning && params.MetadataDoMerge[i] { + res, err := dbsqlc.New().JobSetMetadataIfNotRunningExec(ctx, dbtx, &dbsqlc.JobSetMetadataIfNotRunningExecParams{ + ID: params.ID[i], + MetadataUpdates: sliceutil.FirstNonEmpty(params.MetadataUpdates[i], []byte("{}")), + }) + if err != nil { + return fmt.Errorf("error setting job metadata: %w", err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected > 0 { + metadataMergedIDs = append(metadataMergedIDs, params.ID[i]) + } + } + } + + // Step 4: Re-fetch jobs that had metadata merged. + if len(metadataMergedIDs) > 0 { + refreshed, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, metadataMergedIDs) + if err != nil { + return interpretError(err) + } + for _, j := range refreshed { + jobsByID[j.ID] = j + } + } + + // Step 5: Build results in original order. + for i := range params.ID { + job := jobsByID[params.ID[i]] + if job == nil { + continue + } + + setRes[i], err = jobRowFromInternal(job) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return setRes, nil +} + +func (e *Executor) JobUpdate(ctx context.Context, params *riverdriver.JobUpdateParams) (*rivertype.JobRow, error) { + metadata := params.Metadata + if metadata == nil { + metadata = []byte("{}") + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + var metadataDoMerge int64 + if params.MetadataDoMerge { + metadataDoMerge = 1 + } + + if err := dbsqlc.New().JobUpdateExec(ctx, e.dbtx, &dbsqlc.JobUpdateExecParams{ + ID: params.ID, + MetadataDoMerge: metadataDoMerge, + Metadata: metadata, + }); err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + return jobRowFromInternal(job) +} + +func (e *Executor) JobUpdateFull(ctx context.Context, params *riverdriver.JobUpdateFullParams) (*rivertype.JobRow, error) { + attemptedAt := params.AttemptedAt + if attemptedAt != nil { + attemptedAt = ptrutil.Ptr(attemptedAt.UTC()) + } + + attemptedBy, err := json.Marshal(params.AttemptedBy) + if err != nil { + return nil, err + } + + errorsData, err := json.Marshal(sliceutil.Map(params.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return nil, err + } + + finalizedAt := params.FinalizedAt + if finalizedAt != nil { + finalizedAt = ptrutil.Ptr(finalizedAt.UTC()) + } + + metadata := params.Metadata + if metadata == nil { + metadata = []byte("{}") + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := dbsqlc.New().JobUpdateFullExec(ctx, e.dbtx, &dbsqlc.JobUpdateFullExecParams{ + ID: params.ID, + Attempt: int64(params.Attempt), + AttemptDoUpdate: boolToInt64(params.AttemptDoUpdate), + AttemptedAt: nullTimeFromPtr(attemptedAt), + AttemptedAtDoUpdate: boolToInt64(params.AttemptedAtDoUpdate), + AttemptedBy: attemptedBy, + AttemptedByDoUpdate: boolToInt64(params.AttemptedByDoUpdate), + ErrorsDoUpdate: boolToInt64(params.ErrorsDoUpdate), + Errors: errorsData, + FinalizedAtDoUpdate: boolToInt64(params.FinalizedAtDoUpdate), + FinalizedAt: nullTimeFromPtr(finalizedAt), + MaxAttemptsDoUpdate: boolToInt64(params.MaxAttemptsDoUpdate), + MaxAttempts: int64(min(params.MaxAttempts, math.MaxInt64)), + MetadataDoUpdate: boolToInt64(params.MetadataDoUpdate), + Metadata: metadata, + StateDoUpdate: boolToInt64(params.StateDoUpdate), + State: string(params.State), + }); err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + return jobRowFromInternal(job) +} + +func (e *Executor) LeaderAttemptElect(ctx context.Context, params *riverdriver.LeaderElectParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + res, err := dbsqlc.New().LeaderAttemptElectExec(ctx, e.dbtx, &dbsqlc.LeaderAttemptElectExecParams{ + LeaderID: params.LeaderID, + Now: nullTimeFromPtr(params.Now), + TTL: params.TTL.Microseconds(), + }) + if err != nil { + return nil, interpretError(err) + } + + // INSERT IGNORE returns 0 rows affected when a leader already exists. + affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, rivertype.ErrNotFound + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderAttemptReelect(ctx context.Context, params *riverdriver.LeaderReelectParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + res, err := dbsqlc.New().LeaderAttemptReelectExec(ctx, e.dbtx, &dbsqlc.LeaderAttemptReelectExecParams{ + ElectedAt: params.ElectedAt, + LeaderID: params.LeaderID, + Now: nullTimeFromPtr(params.Now), + TTL: params.TTL.Microseconds(), + }) + if err != nil { + return nil, interpretError(err) + } + + affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, rivertype.ErrNotFound + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderDeleteExpired(ctx context.Context, params *riverdriver.LeaderDeleteExpiredParams) (int, error) { + numDeleted, err := dbsqlc.New().LeaderDeleteExpired(schemaTemplateParam(ctx, params.Schema), e.dbtx, nullTimeFromPtr(params.Now)) + if err != nil { + return 0, interpretError(err) + } + return int(numDeleted), nil +} + +func (e *Executor) LeaderGetElectedLeader(ctx context.Context, params *riverdriver.LeaderGetElectedLeaderParams) (*riverdriver.Leader, error) { + leader, err := dbsqlc.New().LeaderGetElectedLeader(schemaTemplateParam(ctx, params.Schema), e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderInsert(ctx context.Context, params *riverdriver.LeaderInsertParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := dbsqlc.New().LeaderInsertExec(ctx, e.dbtx, &dbsqlc.LeaderInsertExecParams{ + ElectedAt: params.ElectedAt, + ExpiresAt: params.ExpiresAt, + Now: params.Now, + LeaderID: params.LeaderID, + TTL: params.TTL.Microseconds(), + }); err != nil { + return nil, interpretError(err) + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderResign(ctx context.Context, params *riverdriver.LeaderResignParams) (bool, error) { + numResigned, err := dbsqlc.New().LeaderResign(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.LeaderResignParams{ + ElectedAt: params.ElectedAt, + LeaderID: params.LeaderID, + }) + if err != nil { + return false, interpretError(err) + } + return numResigned > 0, nil +} + +func (e *Executor) MigrationDeleteAssumingMainMany(ctx context.Context, params *riverdriver.MigrationDeleteAssumingMainManyParams) ([]*riverdriver.Migration, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + migrations, err := dbsqlc.New().RiverMigrationDeleteAssumingMainMany(ctx, e.dbtx, versions) + if err != nil { + return nil, interpretError(err) + } + + if len(versions) > 0 { + if err := dbsqlc.New().RiverMigrationDeleteAssumingMainManyExec(ctx, e.dbtx, versions); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.Map(migrations, func(internal *dbsqlc.RiverMigrationDeleteAssumingMainManyRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, + Version: int(internal.Version), + } + }), nil +} + +func (e *Executor) MigrationDeleteByLineAndVersionMany(ctx context.Context, params *riverdriver.MigrationDeleteByLineAndVersionManyParams) ([]*riverdriver.Migration, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + // Step 1: Select the rows to return + migrations, err := dbsqlc.New().RiverMigrationDeleteByLineAndVersionMany(ctx, e.dbtx, &dbsqlc.RiverMigrationDeleteByLineAndVersionManyParams{ + Line: params.Line, + Version: versions, + }) + if err != nil { + return nil, interpretError(err) + } + + // Step 2: Delete those rows + if len(versions) > 0 { + if err := dbsqlc.New().RiverMigrationDeleteByLineAndVersionManyExec(ctx, e.dbtx, &dbsqlc.RiverMigrationDeleteByLineAndVersionManyExecParams{ + Line: params.Line, + Version: versions, + }); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.Map(migrations, migrationFromInternal), nil +} + +func (e *Executor) MigrationGetAllAssumingMain(ctx context.Context, params *riverdriver.MigrationGetAllAssumingMainParams) ([]*riverdriver.Migration, error) { + migrations, err := dbsqlc.New().RiverMigrationGetAllAssumingMain(schemaTemplateParam(ctx, params.Schema), e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(migrations, func(internal *dbsqlc.RiverMigrationGetAllAssumingMainRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, + Version: int(internal.Version), + } + }), nil +} + +func (e *Executor) MigrationGetByLine(ctx context.Context, params *riverdriver.MigrationGetByLineParams) ([]*riverdriver.Migration, error) { + migrations, err := dbsqlc.New().RiverMigrationGetByLine(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Line) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(migrations, migrationFromInternal), nil +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, params *riverdriver.MigrationInsertManyParams) ([]*riverdriver.Migration, error) { + var migrations []*riverdriver.Migration + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, version := range params.Versions { + if err := dbsqlc.New().RiverMigrationInsertExec(ctx, dbtx, &dbsqlc.RiverMigrationInsertExecParams{ + Line: params.Line, + Version: int64(version), + }); err != nil { + return interpretError(err) + } + } + + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + internals, err := dbsqlc.New().RiverMigrationGetByLineAndVersionMany(ctx, dbtx, &dbsqlc.RiverMigrationGetByLineAndVersionManyParams{ + Line: params.Line, + Version: versions, + }) + if err != nil { + return interpretError(err) + } + + migrations = sliceutil.Map(internals, migrationFromInternal) + return nil + }); err != nil { + return nil, err + } + + return migrations, nil +} + +func (e *Executor) MigrationInsertManyAssumingMain(ctx context.Context, params *riverdriver.MigrationInsertManyAssumingMainParams) ([]*riverdriver.Migration, error) { + var migrations []*riverdriver.Migration + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, version := range params.Versions { + if err := dbsqlc.New().RiverMigrationInsertAssumingMainExec(ctx, dbtx, int64(version)); err != nil { + return interpretError(err) + } + } + + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + internals, err := dbsqlc.New().RiverMigrationGetByVersionMany(ctx, dbtx, versions) + if err != nil { + return interpretError(err) + } + + migrations = sliceutil.Map(internals, func(internal *dbsqlc.RiverMigrationGetByVersionManyRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, + Version: int(internal.Version), + } + }) + return nil + }); err != nil { + return nil, err + } + + return migrations, nil +} + +func (e *Executor) NotifyMany(ctx context.Context, params *riverdriver.NotifyManyParams) error { + return riverdriver.ErrNotImplemented +} + +func (e *Executor) PGAdvisoryXactLock(ctx context.Context, key int64) (*struct{}, error) { + // MySQL has GET_LOCK which is similar to PostgreSQL advisory locks, but + // it's session-scoped rather than transaction-scoped. For now, return + // not implemented. + return nil, riverdriver.ErrNotImplemented +} + +func (e *Executor) QueueCreateOrSetUpdatedAt(ctx context.Context, params *riverdriver.QueueCreateOrSetUpdatedAtParams) (*rivertype.Queue, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := dbsqlc.New().QueueCreateOrSetUpdatedAtExec(ctx, e.dbtx, &dbsqlc.QueueCreateOrSetUpdatedAtExecParams{ + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + Name: params.Name, + Now: params.Now, + PausedAt: nullTimeFromPtr(params.PausedAt), + UpdatedAt: params.UpdatedAt, + }); err != nil { + return nil, interpretError(err) + } + + queue, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueueDeleteExpired(ctx context.Context, params *riverdriver.QueueDeleteExpiredParams) ([]string, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + // Step 1: Select expired queue names + queueNames, err := dbsqlc.New().QueueDeleteExpiredSelect(ctx, e.dbtx, &dbsqlc.QueueDeleteExpiredSelectParams{ + Limit: int32(params.Max), //nolint:gosec + UpdatedAtHorizon: params.UpdatedAtHorizon.UTC(), + }) + if err != nil { + return nil, interpretError(err) + } + + // Step 2: Delete those queues + if len(queueNames) > 0 { + if err := dbsqlc.New().QueueDeleteExpiredExec(ctx, e.dbtx, queueNames); err != nil { + return nil, interpretError(err) + } + } + + return queueNames, nil +} + +func (e *Executor) QueueGet(ctx context.Context, params *riverdriver.QueueGetParams) (*rivertype.Queue, error) { + queue, err := dbsqlc.New().QueueGet(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueueList(ctx context.Context, params *riverdriver.QueueListParams) ([]*rivertype.Queue, error) { + queues, err := dbsqlc.New().QueueList(schemaTemplateParam(ctx, params.Schema), e.dbtx, int32(params.Max)) //nolint:gosec + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(queues, queueFromInternal), nil +} + +func (e *Executor) QueueNameList(ctx context.Context, params *riverdriver.QueueNameListParams) ([]string, error) { + exclude := params.Exclude + if len(exclude) == 0 { + exclude = []string{""} + } + queueNames, err := dbsqlc.New().QueueNameList(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.QueueNameListParams{ + After: params.After, + Exclude: exclude, + Match: params.Match, + Limit: int32(min(params.Max, math.MaxInt32)), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + return queueNames, nil +} + +func (e *Executor) QueuePause(ctx context.Context, params *riverdriver.QueuePauseParams) error { + ctx = schemaTemplateParam(ctx, params.Schema) + now := nullTimeFromPtr(params.Now) + + var ( + res sql.Result + err error + ) + + if params.Name == riverdriver.AllQueuesString { + res, err = dbsqlc.New().QueuePauseAll(ctx, e.dbtx, &dbsqlc.QueuePauseAllParams{Now: now}) + } else { + res, err = dbsqlc.New().QueuePauseByName(ctx, e.dbtx, &dbsqlc.QueuePauseByNameParams{ + Name: params.Name, + Now: now, + }) + } + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + + // MySQL only reports rows that actually changed values, not rows matched. + // If no rows were affected for a named queue, check if the queue exists + // before returning ErrNotFound (it may already be paused/unpaused). + if rowsAffected < 1 && params.Name != riverdriver.AllQueuesString { + if _, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name); err != nil { + return interpretError(err) + } + } + return nil +} + +func (e *Executor) QueueResume(ctx context.Context, params *riverdriver.QueueResumeParams) error { + ctx = schemaTemplateParam(ctx, params.Schema) + now := nullTimeFromPtr(params.Now) + + var ( + res sql.Result + err error + ) + + if params.Name == riverdriver.AllQueuesString { + res, err = dbsqlc.New().QueueResumeAll(ctx, e.dbtx, now) + } else { + res, err = dbsqlc.New().QueueResumeByName(ctx, e.dbtx, &dbsqlc.QueueResumeByNameParams{ + Name: params.Name, + Now: now, + }) + } + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + + // MySQL only reports rows that actually changed values, not rows matched. + // If no rows were affected for a named queue, check if the queue exists + // before returning ErrNotFound (it may already be resumed/unpaused). + if rowsAffected < 1 && params.Name != riverdriver.AllQueuesString { + if _, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name); err != nil { + return interpretError(err) + } + } + return nil +} + +func (e *Executor) QueueUpdate(ctx context.Context, params *riverdriver.QueueUpdateParams) (*rivertype.Queue, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + var metadataDoUpdate int64 + if params.MetadataDoUpdate { + metadataDoUpdate = 1 + } + + if err := dbsqlc.New().QueueUpdateExec(ctx, e.dbtx, &dbsqlc.QueueUpdateExecParams{ + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + MetadataDoUpdate: metadataDoUpdate, + Name: params.Name, + }); err != nil { + return nil, interpretError(err) + } + + queue, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueryRow(ctx context.Context, sql string, args ...any) riverdriver.Row { + return e.dbtx.QueryRowContext(ctx, sql, args...) +} + +func (e *Executor) SchemaCreate(ctx context.Context, params *riverdriver.SchemaCreateParams) error { + // In MySQL, schemas are databases. Create one if it doesn't exist. + if params.Schema != "" { + _, err := e.dbtx.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+mysqlIdentifier(params.Schema)) + return interpretError(err) + } + return nil +} + +func (e *Executor) SchemaDrop(ctx context.Context, params *riverdriver.SchemaDropParams) error { + if params.Schema != "" { + _, err := e.dbtx.ExecContext(ctx, "DROP DATABASE IF EXISTS "+mysqlIdentifier(params.Schema)) + return interpretError(err) + } + return nil +} + +func (e *Executor) SchemaGetExpired(ctx context.Context, params *riverdriver.SchemaGetExpiredParams) ([]string, error) { + rows, err := e.dbtx.QueryContext(ctx, + `SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME LIKE CONCAT(?, '%') AND SCHEMA_NAME < ? + ORDER BY SCHEMA_NAME`, + params.Prefix, params.BeforeName) + if err != nil { + return nil, interpretError(err) + } + defer rows.Close() + + var schemas []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + schemas = append(schemas, name) + } + return schemas, rows.Err() +} + +func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExistsParams) (bool, error) { + ctx = informationSchemaParam(ctx) + + exists, err := dbsqlc.New().TableExists(ctx, e.dbtx, &dbsqlc.TableExistsParams{ + TableName: params.Table, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) TableTruncate(ctx context.Context, params *riverdriver.TableTruncateParams) error { + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." + } + + for _, table := range params.Table { + // MySQL's TRUNCATE TABLE is DDL and can't be used in transactions. + // Use DELETE FROM instead for transactional safety. + _, err := e.dbtx.ExecContext(ctx, "DELETE FROM "+maybeSchema+table) + if err != nil { + return interpretError(err) + } + } + + return nil +} + +type ExecutorTx struct { + Executor + + tx *sql.Tx +} + +func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + executorSubTx := &ExecutorSubTx{ + beginOnce: &savepointutil.BeginOnlyOnce{}, + savepointNum: 0, + tx: t.tx, + } + executorSubTx.Executor = Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver, executorSubTx} + return executorSubTx.Begin(ctx) +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + return t.tx.Commit() +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + return t.tx.Rollback() +} + +type ExecutorSubTx struct { + Executor + + beginOnce *savepointutil.BeginOnlyOnce + savepointNum int + tx *sql.Tx +} + +const savepointPrefix = "river_savepoint_" + +func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if err := t.beginOnce.Begin(); err != nil { + return nil, err + } + + nextSavepointNum := t.savepointNum + 1 + if err := t.Exec(ctx, fmt.Sprintf("SAVEPOINT %s%02d", savepointPrefix, nextSavepointNum)); err != nil { + return nil, err + } + + executorSubTx := &ExecutorSubTx{ + beginOnce: savepointutil.NewBeginOnlyOnce(t.beginOnce), + savepointNum: nextSavepointNum, + tx: t.tx, + } + executorSubTx.Executor = Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver, executorSubTx} + + return executorSubTx, nil +} + +func (t *ExecutorSubTx) Commit(ctx context.Context) error { + defer t.beginOnce.Done() + + if t.beginOnce.IsDone() { + return errors.New("tx is closed") + } + + if err := t.Exec(ctx, fmt.Sprintf("RELEASE SAVEPOINT %s%02d", savepointPrefix, t.savepointNum)); err != nil { + // MySQL DDL statements (CREATE TABLE, ALTER TABLE, etc.) cause an + // implicit COMMIT which destroys savepoints. If the savepoint no + // longer exists, the work was already committed by the DDL. + if isSavepointDoesNotExist(err) { + return nil + } + return err + } + + return nil +} + +func (t *ExecutorSubTx) Rollback(ctx context.Context) error { + defer t.beginOnce.Done() + + if t.beginOnce.IsDone() { + return errors.New("tx is closed") + } + + if err := t.Exec(ctx, fmt.Sprintf("ROLLBACK TO SAVEPOINT %s%02d", savepointPrefix, t.savepointNum)); err != nil { + // MySQL DDL statements cause implicit COMMIT which destroys + // savepoints. If the savepoint no longer exists, there's nothing + // to roll back to. + if isSavepointDoesNotExist(err) { + return nil + } + return err + } + + return nil +} + +// isSavepointDoesNotExist checks if a MySQL error indicates that a savepoint +// does not exist (error 1305). This happens when DDL statements cause an +// implicit COMMIT that destroys active savepoints. +func isSavepointDoesNotExist(err error) bool { + return err != nil && strings.Contains(err.Error(), "1305") +} + +func boolToInt64(b bool) int64 { + if b { + return 1 + } + return 0 +} + +func interpretError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return rivertype.ErrNotFound + } + return err +} + +// mysqlIdentifier quotes an identifier with backticks for MySQL. +// MySQL uses backticks instead of double quotes for identifier quoting. +func mysqlIdentifier(ident string) string { + return "`" + strings.ReplaceAll(ident, "`", "``") + "`" +} + +type templateReplaceWrapper struct { + dbtx dbsqlc.DBTX + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) ExecContext(ctx context.Context, rawSQL string, rawArgs ...any) (sql.Result, error) { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.ExecContext(ctx, sqlStr, args...) +} + +func (w templateReplaceWrapper) PrepareContext(ctx context.Context, rawSQL string) (*sql.Stmt, error) { + sqlStr, _ := w.replacer.Run(ctx, argPlaceholder, rawSQL, nil) + return w.dbtx.PrepareContext(ctx, sqlStr) +} + +func (w templateReplaceWrapper) QueryContext(ctx context.Context, rawSQL string, rawArgs ...any) (*sql.Rows, error) { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.QueryContext(ctx, sqlStr, args...) +} + +func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, rawSQL string, rawArgs ...any) *sql.Row { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.QueryRowContext(ctx, sqlStr, args...) +} + +func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) { + attemptedAt := ptrFromNullTime(internal.AttemptedAt) + if attemptedAt != nil { + t := attemptedAt.UTC() + attemptedAt = &t + } + + var attemptedBy []string + if internal.AttemptedBy != nil { + if err := json.Unmarshal(internal.AttemptedBy, &attemptedBy); err != nil { + return nil, fmt.Errorf("error unmarshaling `attempted_by`: %w", err) + } + } + + var errs []rivertype.AttemptError + if internal.Errors != nil { + if err := json.Unmarshal(internal.Errors, &errs); err != nil { + return nil, fmt.Errorf("error unmarshaling `errors`: %w", err) + } + } + + finalizedAt := ptrFromNullTime(internal.FinalizedAt) + if finalizedAt != nil { + t := finalizedAt.UTC() + finalizedAt = &t + } + + var tags []string + if err := json.Unmarshal(internal.Tags, &tags); err != nil { + return nil, fmt.Errorf("error unmarshaling `tags`: %w", err) + } + + var uniqueStatesByte byte + if internal.UniqueStates.Valid { + if internal.UniqueStates.Int16 < 0 || internal.UniqueStates.Int16 > 255 { + return nil, fmt.Errorf("value out of range for byte: %d", internal.UniqueStates.Int16) + } + uniqueStatesByte = byte(internal.UniqueStates.Int16) + } + + return &rivertype.JobRow{ + ID: internal.ID, + Attempt: max(int(internal.Attempt), 0), + AttemptedAt: attemptedAt, + AttemptedBy: attemptedBy, + CreatedAt: internal.CreatedAt.UTC(), + EncodedArgs: internal.Args, + Errors: errs, + FinalizedAt: finalizedAt, + Kind: internal.Kind, + MaxAttempts: max(int(internal.MaxAttempts), 0), + Metadata: internal.Metadata, + Priority: max(int(internal.Priority), 0), + Queue: internal.Queue, + ScheduledAt: internal.ScheduledAt.UTC(), + State: rivertype.JobState(internal.State), + Tags: tags, + UniqueKey: bytesFromNullString(internal.UniqueKey), + UniqueStates: uniquestates.UniqueBitmaskToStates(uniqueStatesByte), + }, nil +} + +func leaderFromInternal(internal *dbsqlc.RiverLeader) *riverdriver.Leader { + return &riverdriver.Leader{ + ElectedAt: internal.ElectedAt.UTC(), + ExpiresAt: internal.ExpiresAt.UTC(), + LeaderID: internal.LeaderID, + } +} + +func migrationFromInternal(internal *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: internal.Line, + Version: int(internal.Version), + } +} + +func schemaTemplateParam(ctx context.Context, schema string) context.Context { + if schema != "" { + schema = mysqlIdentifier(schema) + "." + } + + return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "schema": {Value: schema}, + }, nil) +} + +func informationSchemaParam(ctx context.Context) context.Context { + return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "information_schema": {Stable: true, Value: "INFORMATION_SCHEMA."}, + }, nil) +} + +func queueFromInternal(internal *dbsqlc.RiverQueue) *rivertype.Queue { + pausedAt := ptrFromNullTime(internal.PausedAt) + if pausedAt != nil { + t := pausedAt.UTC() + pausedAt = &t + } + return &rivertype.Queue{ + CreatedAt: internal.CreatedAt.UTC(), + Metadata: internal.Metadata, + Name: internal.Name, + PausedAt: pausedAt, + UpdatedAt: internal.UpdatedAt.UTC(), + } +} + +func nullTimeFromPtr(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{} + } + return sql.NullTime{Time: *t, Valid: true} +} + +func ptrFromNullTime(t sql.NullTime) *time.Time { + if !t.Valid { + return nil + } + return &t.Time +} + +func nullStringFromBytes(b []byte) sql.NullString { + if b == nil { + return sql.NullString{} + } + return sql.NullString{String: string(b), Valid: true} +} + +func bytesFromNullString(s sql.NullString) []byte { + if !s.Valid { + return nil + } + return []byte(s.String) +} diff --git a/riverdriver/rivermysql/river_mysql_driver_test.go b/riverdriver/rivermysql/river_mysql_driver_test.go new file mode 100644 index 00000000..219d6766 --- /dev/null +++ b/riverdriver/rivermysql/river_mysql_driver_test.go @@ -0,0 +1,621 @@ +package rivermysql + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivertype" +) + +// Verify interface compliance. +var _ riverdriver.Driver[*sql.Tx] = New(nil) + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(sql.ErrNoRows), rivertype.ErrNotFound) + require.NoError(t, interpretError(nil)) +} + +func TestSchemaTemplateParam(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("NoSchema", func(t *testing.T) { + t.Parallel() + ctx := schemaTemplateParam(ctx, "") + // Just verify it doesn't panic + _ = ctx + }) + + t.Run("WithSchema", func(t *testing.T) { + t.Parallel() + ctx := schemaTemplateParam(ctx, "custom_schema") + _ = ctx + }) +} + +func TestDriverProperties(t *testing.T) { + t.Parallel() + + driver := New(nil) + require.Equal(t, "?", driver.ArgPlaceholder()) + require.Equal(t, "mysql", driver.DatabaseName()) + require.False(t, driver.SupportsListener()) + require.False(t, driver.SupportsListenNotify()) + require.Equal(t, time.Microsecond, driver.TimePrecision()) + require.False(t, driver.PoolIsSet()) +} + +func TestJobInsertAndGet(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a job + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{"test": true}`), + Kind: "test_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{"tag1"}, + }) + require.NoError(t, err) + require.NotNil(t, job) + require.Positive(t, job.ID) + require.Equal(t, "test_job", job.Kind) + require.Equal(t, rivertype.JobStateAvailable, job.State) + require.Equal(t, "default", job.Queue) + require.Equal(t, 1, job.Priority) + require.Equal(t, 3, job.MaxAttempts) + require.JSONEq(t, `{"test": true}`, string(job.EncodedArgs)) + require.Equal(t, []string{"tag1"}, job.Tags) + + // Get the job by ID + fetched, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ + ID: job.ID, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, job.ID, fetched.ID) + require.Equal(t, job.Kind, fetched.Kind) +} + +func TestJobGetAvailable(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert some available jobs + for i := range 3 { + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: fmt.Sprintf("job_%d", i), + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + } + + // Get available jobs (needs transaction for FOR UPDATE) + txExec, err := exec.Begin(ctx) + require.NoError(t, err) + defer txExec.Rollback(ctx) + + jobs, err := txExec.JobGetAvailable(ctx, &riverdriver.JobGetAvailableParams{ + ClientID: "test-client", + MaxAttemptedBy: 4, + MaxToLock: 2, + Queue: "default", + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, jobs, 2) + + for _, job := range jobs { + require.Equal(t, rivertype.JobStateRunning, job.State) + } + + require.NoError(t, txExec.Commit(ctx)) +} + +func TestJobCancel(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a job + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "test_cancel", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + // Cancel the job + cancelled, err := exec.JobCancel(ctx, &riverdriver.JobCancelParams{ + ID: job.ID, + CancelAttemptedAt: time.Now().UTC(), + ControlTopic: "test_topic", + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCancelled, cancelled.State) + require.NotNil(t, cancelled.FinalizedAt) +} + +func TestJobDelete(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a completed job + now := time.Now().UTC().Truncate(time.Microsecond) + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + FinalizedAt: &now, + Kind: "test_delete", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateCompleted, + Tags: []string{}, + }) + require.NoError(t, err) + + // Delete the job + deleted, err := exec.JobDelete(ctx, &riverdriver.JobDeleteParams{ + ID: job.ID, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, job.ID, deleted.ID) + + // Verify it's gone + _, err = exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ + ID: job.ID, + Schema: schema, + }) + require.ErrorIs(t, err, rivertype.ErrNotFound) +} + +func TestJobCountByState(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert jobs in different states + for range 3 { + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "test_count", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + } + + count, err := exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 3, count) +} + +func TestLeaderElection(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Attempt to elect a leader + leader, err := exec.LeaderAttemptElect(ctx, &riverdriver.LeaderElectParams{ + LeaderID: "test-leader", + Schema: schema, + TTL: 30 * time.Second, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", leader.LeaderID) + + // Get the elected leader + fetched, err := exec.LeaderGetElectedLeader(ctx, &riverdriver.LeaderGetElectedLeaderParams{ + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", fetched.LeaderID) + + // Re-elect should succeed + reelected, err := exec.LeaderAttemptReelect(ctx, &riverdriver.LeaderReelectParams{ + ElectedAt: leader.ElectedAt, + LeaderID: "test-leader", + Schema: schema, + TTL: 30 * time.Second, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", reelected.LeaderID) + + // Resign + resigned, err := exec.LeaderResign(ctx, &riverdriver.LeaderResignParams{ + ElectedAt: leader.ElectedAt, + LeaderID: "test-leader", + LeadershipTopic: "leadership", + Schema: schema, + }) + require.NoError(t, err) + require.True(t, resigned) +} + +func TestQueueOperations(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + now := time.Now().UTC().Truncate(time.Microsecond) + + // Create a queue + queue, err := exec.QueueCreateOrSetUpdatedAt(ctx, &riverdriver.QueueCreateOrSetUpdatedAtParams{ + Metadata: []byte(`{}`), + Name: "test_queue", + Now: &now, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test_queue", queue.Name) + require.Nil(t, queue.PausedAt) + + // Get the queue + fetched, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test_queue", fetched.Name) + + // List queues + queues, err := exec.QueueList(ctx, &riverdriver.QueueListParams{ + Max: 100, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, queues, 1) + + // Pause queue + err = exec.QueuePause(ctx, &riverdriver.QueuePauseParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + + paused, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.NotNil(t, paused.PausedAt) + + // Resume queue + err = exec.QueueResume(ctx, &riverdriver.QueueResumeParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + + resumed, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.Nil(t, resumed.PausedAt) +} + +func TestTransactions(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Begin and commit + tx, err := exec.Begin(ctx) + require.NoError(t, err) + + _, err = tx.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "tx_test", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + require.NoError(t, tx.Commit(ctx)) + + // Verify job exists + count, err := exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 1, count) + + // Begin and rollback + tx2, err := exec.Begin(ctx) + require.NoError(t, err) + + _, err = tx2.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "tx_test_rollback", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + require.NoError(t, tx2.Rollback(ctx)) + + // Verify second job was rolled back + count, err = exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 1, count) // still 1 +} + +func TestJobInsertFastMany(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + results, err := exec.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ + Jobs: []*riverdriver.JobInsertFastParams{ + { + EncodedArgs: []byte(`{"i": 1}`), + Kind: "fast_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + State: rivertype.JobStateAvailable, + Tags: []string{}, + }, + { + EncodedArgs: []byte(`{"i": 2}`), + Kind: "fast_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + State: rivertype.JobStateAvailable, + Tags: []string{}, + }, + }, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, results, 2) + + for _, result := range results { + require.NotNil(t, result.Job) + require.False(t, result.UniqueSkippedAsDuplicate) + } +} + +func TestJobMetadata(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a job with metadata + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "metadata_test", + MaxAttempts: 3, + Metadata: []byte(`{"key": "value"}`), + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(job.Metadata, &metadata)) + require.Equal(t, "value", metadata["key"]) + + // Update metadata + updated, err := exec.JobUpdate(ctx, &riverdriver.JobUpdateParams{ + ID: job.ID, + MetadataDoMerge: true, + Metadata: []byte(`{"new_key": "new_value"}`), + Schema: schema, + }) + require.NoError(t, err) + + require.NoError(t, json.Unmarshal(updated.Metadata, &metadata)) + require.Equal(t, "value", metadata["key"]) + require.Equal(t, "new_value", metadata["new_key"]) +} + +func TestJobSchedule(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a scheduled job with a past scheduled_at + past := time.Now().UTC().Add(-1 * time.Hour).Truncate(time.Microsecond) + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "scheduled_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + ScheduledAt: &past, + Schema: schema, + State: rivertype.JobStateScheduled, + Tags: []string{}, + }) + require.NoError(t, err) + + // Schedule should find and transition the job + now := time.Now().UTC() + results, err := exec.JobSchedule(ctx, &riverdriver.JobScheduleParams{ + Max: 100, + Now: &now, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, results, 1) + require.Equal(t, rivertype.JobStateAvailable, results[0].Job.State) +} + +func TestNotifyMany(t *testing.T) { + t.Parallel() + + driver := New(nil) + // MySQL doesn't support LISTEN/NOTIFY + require.Panics(t, func() { + driver.GetListener(&riverdriver.GetListenenerParams{}) + }) +} + +func TestNotifyManyReturnsNotImplemented(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + exec = New(riversharedtest.DBPoolMySQL(ctx, t)).GetExecutor() + ) + + err := exec.NotifyMany(ctx, &riverdriver.NotifyManyParams{ + Payload: []string{"test"}, + Topic: "test_topic", + }) + require.ErrorIs(t, err, riverdriver.ErrNotImplemented) +} + +func TestPGAdvisoryXactLock(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + exec = New(riversharedtest.DBPoolMySQL(ctx, t)).GetExecutor() + ) + + _, err := exec.PGAdvisoryXactLock(ctx, 12345) + require.ErrorIs(t, err, riverdriver.ErrNotImplemented) +} diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index e0201a0e..5a9933ea 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -63,8 +63,9 @@ func New(dbPool *pgxpool.Pool) *Driver { const argPlaceholder = "$" -func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d} diff --git a/riverdriver/riversqlite/river_sqlite_driver.go b/riverdriver/riversqlite/river_sqlite_driver.go index 132d0e46..ce77c91a 100644 --- a/riverdriver/riversqlite/river_sqlite_driver.go +++ b/riverdriver/riversqlite/river_sqlite_driver.go @@ -77,8 +77,9 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "?" -func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "sqlite" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "sqlite" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d, nil} diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index 12568e20..bfe2a0d2 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -547,7 +547,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex var schema string if m.schema != "" { - schema = dbutil.SafeIdentifier(m.schema) + "." + schema = m.driver.SafeIdentifier(m.schema) + "." } schemaReplacement := map[string]sqlctemplate.Replacement{ "schema": {Value: schema}, diff --git a/rivershared/riversharedtest/riversharedtest.go b/rivershared/riversharedtest/riversharedtest.go index d6f2b5ba..1f5b7cb9 100644 --- a/rivershared/riversharedtest/riversharedtest.go +++ b/rivershared/riversharedtest/riversharedtest.go @@ -112,6 +112,44 @@ var sqliteTestDir = sync.OnceValue(func() string { //nolint:gochecknoglobals return path.Join(rootDir, "sqlite") }) +// A pool and sync.Once to initialize a MySQL pool, invoked by DBPoolMySQL. +var ( + dbPoolMySQL *sql.DB //nolint:gochecknoglobals + dbPoolMySQLOnce sync.Once //nolint:gochecknoglobals +) + +// SkipIfMySQLNotEnabled skips the current test if MySQL tests are not enabled. +// MySQL tests are opt-in because they require a running MySQL server. Set +// RIVER_MYSQL_TESTS_ENABLED=1 or RIVER_MYSQL_TESTS_ENABLED=true to enable. +func SkipIfMySQLNotEnabled(tb testing.TB) { + tb.Helper() + + val := os.Getenv("RIVER_MYSQL_TESTS_ENABLED") + if val != "1" && val != "true" { //nolint:goconst + tb.Skip("Skipping MySQL tests; set RIVER_MYSQL_TESTS_ENABLED=1 to enable") + } +} + +// DBPoolMySQL gets a lazily initialized database pool for MySQL testing. +// Uses TEST_MYSQL_URL or defaults to root@tcp(localhost:3306)/?. +func DBPoolMySQL(ctx context.Context, tb testing.TB) *sql.DB { + tb.Helper() + + dbPoolMySQLOnce.Do(func() { + dsn := cmp.Or( + os.Getenv("TEST_MYSQL_URL"), + "root@tcp(localhost:3306)/?parseTime=true&multiStatements=true&loc=UTC&time_zone=%27%2B00%3A00%27", + ) + + var err error + dbPoolMySQL, err = sql.Open("mysql", dsn) + require.NoError(tb, err) + }) + require.NotNil(tb, dbPoolMySQL) + + return dbPoolMySQL +} + // DBPoolLibSQL gets a database pool appropriate for use with libSQL (a SQLite // fork) in testing. func DBPoolLibSQL(ctx context.Context, tb testing.TB, schema string) *sql.DB { diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go index 88c02edb..484a76ae 100644 --- a/rivershared/sqlctemplate/sqlc_template.go +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -89,6 +89,14 @@ type Replacement struct { // be initialized with a constructor. This lets it default to a usable instance // on drivers that may themselves not be initialized. type Replacer struct { + // UnnumberedPlaceholders, when true, causes Run to emit plain `?` + // placeholders instead of numbered `?1`, `?2`, etc. for named args + // injected via the template system. The args slice is reordered to + // match the positional order of placeholders in the SQL. This is + // needed for MySQL, whose database/sql driver does not support + // numbered `?N` syntax. + UnnumberedPlaceholders bool + cache map[replacerCacheKey]string cacheMu sync.RWMutex } @@ -142,6 +150,14 @@ func (r *Replacer) RunSafely(ctx context.Context, argPlaceholder, sql string, ar } cacheKey, cacheEligible := replacerCacheKeyFrom(sql, container) + + // Named args are interleaved with sqlc args during a left-to-right SQL + // walk, producing args in positional order. The cache can't reconstruct + // this ordering on hit, so skip caching when named args are present. + if len(container.NamedArgs) > 0 { + cacheEligible = false + } + if cacheEligible { r.cacheMu.RLock() var ( @@ -205,29 +221,88 @@ func (r *Replacer) RunSafely(ctx context.Context, argPlaceholder, sql string, ar } if len(container.NamedArgs) > 0 { - placeholderNum := len(args) - // For the benefit of the test suite's output being predictable, sort // named args before processing them. sortedNamedArgs := maputil.Keys(container.NamedArgs) slices.Sort(sortedNamedArgs) + + // Verify all named args are present in the SQL. for _, arg := range sortedNamedArgs { - placeholderNum++ + if !strings.Contains(updatedSQL, "@"+arg) { + return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", "@"+arg) + } + } - var ( - symbol = "@" + arg - symbolIndex = strings.Index(updatedSQL, symbol) - val = container.NamedArgs[arg] - ) + // Walk the SQL once left-to-right, replacing @name references and + // existing sqlc placeholders with sequentially numbered (or + // unnumbered) placeholders, building the args slice in positional + // order as we go. + var ( + newArgs []any + out strings.Builder + seqNum int + sqlcArgIdx int + ) - if symbolIndex == -1 { - return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", symbol) + emitPlaceholder := func() { + seqNum++ + if r.UnnumberedPlaceholders { + out.WriteByte('?') + } else { + out.WriteString(argPlaceholder) + out.WriteString(strconv.Itoa(seqNum)) } + } - // ReplaceAll because an input parameter may appear multiple times. - updatedSQL = strings.ReplaceAll(updatedSQL, symbol, argPlaceholder+strconv.Itoa(placeholderNum)) - args = append(args, val) + for i := 0; i < len(updatedSQL); { + switch { + case updatedSQL[i] == '@': + matched := false + for _, name := range sortedNamedArgs { + symbol := "@" + name + if strings.HasPrefix(updatedSQL[i:], symbol) { + emitPlaceholder() + newArgs = append(newArgs, container.NamedArgs[name]) + i += len(symbol) + matched = true + break + } + } + if !matched { + out.WriteByte(updatedSQL[i]) + i++ + } + + case updatedSQL[i] == argPlaceholder[0] && i+1 < len(updatedSQL) && updatedSQL[i+1] >= '1' && updatedSQL[i+1] <= '9': + // Numbered placeholder from sqlc ($N for Postgres, ?N for + // SQLite): parse the old number, look up the original arg, + // and re-emit with a new sequential number. + j := i + 1 + for j < len(updatedSQL) && updatedSQL[j] >= '0' && updatedSQL[j] <= '9' { + j++ + } + oldNum, _ := strconv.Atoi(updatedSQL[i+1 : j]) + emitPlaceholder() + newArgs = append(newArgs, args[oldNum-1]) + i = j + + case r.UnnumberedPlaceholders && updatedSQL[i] == '?': + // Unnumbered positional ? from sqlc (MySQL). + emitPlaceholder() + if sqlcArgIdx < len(args) { + newArgs = append(newArgs, args[sqlcArgIdx]) + } + sqlcArgIdx++ + i++ + + default: + out.WriteByte(updatedSQL[i]) + i++ + } } + + updatedSQL = out.String() + args = newArgs } if cacheEligible { diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go index 119ebd54..a53eeaa3 100644 --- a/rivershared/sqlctemplate/sqlc_template_test.go +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -259,7 +259,7 @@ func TestReplacer(t *testing.T) { AND state = @state; ` - // Initially cached value + // Initially cached value (no named args). { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -270,7 +270,7 @@ func TestReplacer(t *testing.T) { } require.Len(t, replacer.cache, 1) - // Same SQL, but new value. + // Same SQL, but new replacement value. { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "other_schema."}, @@ -281,7 +281,9 @@ func TestReplacer(t *testing.T) { } require.Len(t, replacer.cache, 2) - // Named arg added to the mix. + // Named args present: caching is skipped because args are built + // in positional order during a left-to-right walk that can't be + // replayed from the cache. { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -292,37 +294,9 @@ func TestReplacer(t *testing.T) { _, _, err := replacer.RunSafely(ctx, "$", sql, nil) require.NoError(t, err) } - require.Len(t, replacer.cache, 3) - - // Different named arg _value_ (i.e. still same named arg) can still use - // the previous cached SQL. - { - ctx := WithReplacements(ctx, map[string]Replacement{ - "schema": {Stable: true, Value: "test_schema."}, - }, map[string]any{ - "kind": "other_kind_value", - }) - - _, _, err := replacer.RunSafely(ctx, "$", sql, nil) - require.NoError(t, err) - } - require.Len(t, replacer.cache, 3) // unchanged + require.Len(t, replacer.cache, 2) // unchanged - // New named arg adds a new cache value. - { - ctx := WithReplacements(ctx, map[string]Replacement{ - "schema": {Stable: true, Value: "test_schema."}, - }, map[string]any{ - "kind": "kind_value", - "state": "state_value", - }) - - _, _, err := replacer.RunSafely(ctx, "$", sql, nil) - require.NoError(t, err) - } - require.Len(t, replacer.cache, 4) - - // Different input SQL. + // Different input SQL (no named args). { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -333,7 +307,7 @@ func TestReplacer(t *testing.T) { `, nil) require.NoError(t, err) } - require.Len(t, replacer.cache, 5) + require.Len(t, replacer.cache, 3) }) t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { @@ -372,6 +346,9 @@ func TestReplacer(t *testing.T) { "kind": "no_op", }) + // Named arg @kind appears before the sqlc $1 in the SQL, so after + // the left-to-right walk, args are reordered to [no_op, succeeded] + // and placeholders renumbered sequentially. updatedSQL, args, err := replacer.RunSafely(ctx, "$", ` SELECT count(*) FROM river_job @@ -379,12 +356,12 @@ func TestReplacer(t *testing.T) { AND status = $1; `, []any{"succeeded"}) require.NoError(t, err) - require.Equal(t, []any{"succeeded", "no_op"}, args) + require.Equal(t, []any{"no_op", "succeeded"}, args) require.Equal(t, ` SELECT count(*) FROM river_job - WHERE kind = $2 - AND status = $1; + WHERE kind = $1 + AND status = $2; `, updatedSQL) }) @@ -420,6 +397,225 @@ func TestReplacer(t *testing.T) { `, updatedSQL) }) + t.Run("UnnumberedPlaceholders_NoNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = ?; + `, []any{"available"}) + require.NoError(t, err) + require.Equal(t, []any{"available"}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE state = ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgsNoInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgsWithInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + // The named arg @kind appears in the WHERE clause before the + // sqlc-generated ? for LIMIT. UnnumberedPlaceholders reorders args so + // that they match the positional ? order in the final SQL. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgRepeated", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind OR queue = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + // When a named arg appears multiple times, it should produce a ? for + // each occurrence with the value duplicated in the args slice. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT * + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "no_op", 100}, args) + require.Equal(t, ` + SELECT * + FROM river_job + WHERE kind = ? OR queue = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_MultipleNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = @kind AND status = @status"}, + }, map[string]any{ + "kind": "no_op", + "status": "succeeded", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "succeeded", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? AND status = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NotCachedWithNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Stable: true, Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + sql := ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + ` + + // Unnumbered mode with named args skips caching because the + // cached SQL can't preserve the positional arg ordering. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", sql, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + + require.Empty(t, replacer.cache) + + // Second call still produces correct results. + updatedSQL, args, err = replacer.RunSafely(ctx, "?", sql, []any{200}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 200}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_CachedWithoutNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + sql := ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + LIMIT ?; + ` + + // Without named args, caching works normally in unnumbered mode. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", sql, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + LIMIT ?; + `, updatedSQL) + + require.Len(t, replacer.cache, 1) + + // Second call uses cache. + updatedSQL, args, err = replacer.RunSafely(ctx, "?", sql, []any{200}) + require.NoError(t, err) + require.Equal(t, []any{200}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + LIMIT ?; + `, updatedSQL) + }) + t.Run("Stress", func(t *testing.T) { t.Parallel()