diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c6cf46f2..a4d1f2f9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -91,6 +91,12 @@ jobs: # latest Postgres version: - go-version: "1.25" postgres-version: 18 + + # Run with MySQL enabled on the latest Go + Postgres to exercise the + # MySQL driver in CI. MySQL tests are opt-in via RIVER_MYSQL_TESTS_ENABLED. + - go-version: "1.26" + postgres-version: 18 + mysql-enabled: true fail-fast: false timeout-minutes: 5 @@ -128,7 +134,17 @@ jobs: - name: Set up database run: psql -c "CREATE DATABASE river_test" $ADMIN_DATABASE_URL + # MySQL is pre-installed on GitHub-hosted Ubuntu runners. Start the + # service and configure passwordless root access for MySQL-enabled runs. + - name: Start MySQL + if: matrix.mysql-enabled + run: | + sudo systemctl start mysql + mysql -u root -proot -e "ALTER USER 'root'@'localhost' IDENTIFIED BY ''; FLUSH PRIVILEGES;" + - name: Test + env: + RIVER_MYSQL_TESTS_ENABLED: ${{ matrix.mysql-enabled && 'true' || '' }} run: make test/race cli: diff --git a/Makefile b/Makefile index ce82cd4f..bbe9b86a 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,7 @@ generate/migrations: ## Sync changes of pgxv5 migrations to database/sql .PHONY: generate/sqlc generate/sqlc: ## Generate sqlc cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc generate + cd riverdriver/rivermysql/internal/dbsqlc && sqlc generate cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc generate cd riverdriver/riversqlite/internal/dbsqlc && sqlc generate @@ -101,5 +102,6 @@ verify/migrations: ## Verify synced migrations .PHONY: verify/sqlc verify/sqlc: ## Verify generated sqlc cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc diff + cd riverdriver/rivermysql/internal/dbsqlc && sqlc diff cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc diff cd riverdriver/riversqlite/internal/dbsqlc && sqlc diff diff --git a/client.go b/client.go index b60addcd..f0da987a 100644 --- a/client.go +++ 
b/client.go @@ -2281,9 +2281,12 @@ type JobListResult struct { LastCursor *JobListCursor } -const databaseNameSQLite = "sqlite" +const ( + databaseNameMySQL = "mysql" + databaseNameSQLite = "sqlite" +) -var errJobListParamsMetadataNotSupportedSQLite = errors.New("JobListParams.Metadata is not supported on SQLite") +var errJobListParamsMetadataNotSupported = errors.New("JobListParams.Metadata is not supported on MySQL or SQLite") // JobList returns a paginated list of jobs matching the provided filters. The // provided context is used for the underlying Postgres query and can be used to @@ -2304,8 +2307,8 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { - return nil, errJobListParamsMetadataNotSupportedSQLite + if (c.driver.DatabaseName() == databaseNameMySQL || c.driver.DatabaseName() == databaseNameSQLite) && params.metadataCalled { + return nil, errJobListParamsMetadataNotSupported } dbParams, err := params.toDBParams() @@ -2345,8 +2348,8 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { - return nil, errJobListParamsMetadataNotSupportedSQLite + if (c.driver.DatabaseName() == databaseNameMySQL || c.driver.DatabaseName() == databaseNameSQLite) && params.metadataCalled { + return nil, errJobListParamsMetadataNotSupported } dbParams, err := params.toDBParams() diff --git a/go.work b/go.work index 9a9927fe..0e6def6b 100644 --- a/go.work +++ b/go.work @@ -9,7 +9,10 @@ use ( ./riverdriver/riverdatabasesql ./riverdriver/riverdrivertest ./riverdriver/riverpgxv5 + ./riverdriver/rivermysql ./riverdriver/riversqlite ./rivershared ./rivertype ) + +replace github.com/riverqueue/river/riverdriver/rivermysql v0.35.0 => ./riverdriver/rivermysql diff --git a/job_list_params.go 
b/job_list_params.go index 116a5c76..2e3b8200 100644 --- a/job_list_params.go +++ b/job_list_params.go @@ -347,10 +347,10 @@ func (p *JobListParams) Kinds(kinds ...string) *JobListParams { // // https://www.postgresql.org/docs/current/functions-json.html // -// This function isn't supported in SQLite due to SQLite not having an -// equivalent operator to use, so there's no efficient way to implement it. We -// recommend the use of Where using a condition with a comparison on the `->>` -// operator instead. +// This function isn't supported in MySQL or SQLite due to neither having a +// direct equivalent to Postgres's `@>` operator. We recommend the use of +// [JobListParams.Where] with a database-specific JSON comparison instead (e.g. +// `JSON_EXTRACT` for MySQL, `->>` for SQLite). func (p *JobListParams) Metadata(json string) *JobListParams { paramsCopy := p.copy() paramsCopy.metadataCalled = true diff --git a/riverdbtest/riverdbtest.go b/riverdbtest/riverdbtest.go index eb1e6df1..544f48b7 100644 --- a/riverdbtest/riverdbtest.go +++ b/riverdbtest/riverdbtest.go @@ -471,7 +471,8 @@ type TestTxOpts struct { // run using TestSchema. This is meant for environments where parallelism // doesn't work as well, like SQLite, which will emit "busy" errors when // multiple clients try to share a schema, even when they're in separate - // transactions. + // transactions. Also applies to MySQL, where InnoDB deadlocks are common + // when multiple transactions are sharing a database. DisableSchemaSharing bool // IsTestTxHelper should be set to true for if TestTx is being called from diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 8727a6e3..38b412a5 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -123,6 +123,13 @@ type Driver[TTx any] interface { // API is not stable. DO NOT USE. PoolSet(dbPool any) error + // SafeIdentifier returns a safely quoted identifier (e.g. 
a table or + // schema name) for use in SQL queries. Each driver quotes using its + // native syntax: double quotes for Postgres/SQLite, backticks for MySQL. + // + // API is not stable. DO NOT USE. + SafeIdentifier(ident string) string + // SQLFragmentColumnIn generates an SQL fragment to be included as a // predicate in a `WHERE` query for the existence of a set of values in a // column like `id IN (...)`. The actual implementation depends on support diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index d8a2d445..5adde0f5 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -53,8 +53,9 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "$" -func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d} diff --git a/riverdriver/riverdrivertest/driver_client_test.go b/riverdriver/riverdrivertest/driver_client_test.go index 7a597e65..8702a4e8 100644 --- a/riverdriver/riverdrivertest/driver_client_test.go +++ b/riverdriver/riverdrivertest/driver_client_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "github.com/lib/pq" @@ -18,6 +19,7 @@ import ( "github.com/riverqueue/river/riverdbtest" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" + "github.com/riverqueue/river/riverdriver/rivermysql" 
"github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -109,6 +111,26 @@ func TestClientWithDriverRiverLibSQL(t *testing.T) { ) } +func TestClientWithDriverRiverMySQL(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = context.Background() + dbPool = riversharedtest.DBPoolMySQL(ctx, t) + driver = rivermysql.New(dbPool) + ) + + ExerciseClient(ctx, t, + func(ctx context.Context, t *testing.T) (riverdriver.Driver[*sql.Tx], string) { + t.Helper() + + return driver, riverdbtest.TestSchema(ctx, t, driver, nil) + }, + ) +} + func TestClientWithDriverRiverSQLiteModernC(t *testing.T) { t.Parallel() @@ -519,9 +541,9 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobList(ctx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Ignoring unsupported JobListResult.Metadata on SQLite") - require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") + if bundle.driver.DatabaseName() == databaseNameMySQL || bundle.driver.DatabaseName() == databaseNameSQLite { + t.Logf("Ignoring unsupported JobListResult.Metadata on %s", bundle.driver.DatabaseName()) + require.EqualError(t, err, "JobListParams.Metadata is not supported on MySQL or SQLite") return } require.NoError(t, err) @@ -583,9 +605,9 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobListTx(ctx, tx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Ignoring unsupported JobListTxResult.Metadata on SQLite") - require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") + if bundle.driver.DatabaseName() == databaseNameMySQL || bundle.driver.DatabaseName() == databaseNameSQLite { + t.Logf("Ignoring unsupported 
JobListTxResult.Metadata on %s", bundle.driver.DatabaseName()) + require.EqualError(t, err, "JobListParams.Metadata is not supported on MySQL or SQLite") return } require.NoError(t, err) @@ -607,9 +629,12 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, listParams := river.NewJobListParams() - if bundle.driver.DatabaseName() == databaseNameSQLite { + switch bundle.driver.DatabaseName() { + case databaseNameSQLite: listParams = listParams.Where("metadata ->> @json_path = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": "bar"}) - } else { + case databaseNameMySQL: + listParams = listParams.Where("JSON_UNQUOTE(JSON_EXTRACT(metadata, @json_path)) = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": "bar"}) + default: // "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": `"bar"`}) } diff --git a/riverdriver/riverdrivertest/driver_test.go b/riverdriver/riverdrivertest/driver_test.go index 50d30675..d95ed3d7 100644 --- a/riverdriver/riverdrivertest/driver_test.go +++ b/riverdriver/riverdrivertest/driver_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" @@ -21,6 +22,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverdrivertest" + "github.com/riverqueue/river/riverdriver/rivermysql" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -203,6 +205,43 @@ func TestDriverRiverLiteLibSQL(t *testing.T) { //nolint:dupl }) } +func TestDriverRiverMySQL(t *testing.T) { + t.Parallel() + + 
riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = context.Background() + dbPool = riversharedtest.DBPoolMySQL(ctx, t) + driver = rivermysql.New(dbPool) + ) + + riverdrivertest.Exercise(ctx, t, + func(ctx context.Context, t *testing.T, opts *riverdbtest.TestSchemaOpts) (riverdriver.Driver[*sql.Tx], string) { + t.Helper() + + return driver, riverdbtest.TestSchema(ctx, t, driver, opts) + }, + func(ctx context.Context, t *testing.T) (riverdriver.Executor, riverdriver.Driver[*sql.Tx]) { + t.Helper() + + tx, schema := riverdbtest.TestTx(ctx, t, driver, &riverdbtest.TestTxOpts{ + // Disable schema sharing to reduce InnoDB deadlocks from + // parallel tests contending on the same database. + DisableSchemaSharing: true, + }) + + // MySQL has no search_path equivalent, so USE the test schema + // database so that unqualified queries resolve correctly. + if schema != "" { + _, err := tx.ExecContext(ctx, "USE "+schema) + require.NoError(t, err) + } + + return driver.UnwrapExecutor(tx), driver + }) +} + func TestDriverRiverSQLiteModernC(t *testing.T) { //nolint:dupl t.Parallel() diff --git a/riverdriver/riverdrivertest/executor_tx.go b/riverdriver/riverdrivertest/executor_tx.go index 451c65c2..e4629ef7 100644 --- a/riverdriver/riverdrivertest/executor_tx.go +++ b/riverdriver/riverdrivertest/executor_tx.go @@ -157,7 +157,13 @@ func exerciseExecutorTx[TTx any](ctx context.Context, t *testing.T, exec := setup(ctx, t) - require.NoError(t, exec.Exec(ctx, "SELECT $1 || $2", "foo", "bar")) + _, driver := executorWithTx(ctx, t) + switch driver.DatabaseName() { + case databaseNameMySQL: + require.NoError(t, exec.Exec(ctx, "SELECT CONCAT(?, ?)", "foo", "bar")) + default: + require.NoError(t, exec.Exec(ctx, "SELECT $1 || $2", "foo", "bar")) + } }) }) @@ -166,9 +172,8 @@ func exerciseExecutorTx[TTx any](ctx context.Context, t *testing.T, { driver, _ := driverWithSchema(ctx, t, nil) - if driver.DatabaseName() == databaseNameSQLite { - t.Logf("Skipping PGAdvisoryXactLock test for 
SQLite") - return + if driver.DatabaseName() == databaseNameSQLite || driver.DatabaseName() == databaseNameMySQL { + t.Skipf("Skipping PGAdvisoryXactLock test for %s", driver.DatabaseName()) } } diff --git a/riverdriver/riverdrivertest/go.mod b/riverdriver/riverdrivertest/go.mod index 277ba8e3..d5eaaf29 100644 --- a/riverdriver/riverdrivertest/go.mod +++ b/riverdriver/riverdrivertest/go.mod @@ -11,7 +11,9 @@ require ( github.com/lib/pq v1.12.3 github.com/riverqueue/river v0.35.0 github.com/riverqueue/river/riverdriver v0.35.0 + github.com/go-sql-driver/mysql v1.9.3 github.com/riverqueue/river/riverdriver/riverdatabasesql v0.35.0 + github.com/riverqueue/river/riverdriver/rivermysql v0.35.0 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0 github.com/riverqueue/river/riverdriver/riversqlite v0.35.0 github.com/riverqueue/river/rivershared v0.35.0 @@ -27,6 +29,7 @@ require ( require ( github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/coder/websocket v1.8.12 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/riverdriver/riverdrivertest/go.sum b/riverdriver/riverdrivertest/go.sum index 3020f5ec..7275b162 100644 --- a/riverdriver/riverdrivertest/go.sum +++ b/riverdriver/riverdrivertest/go.sum @@ -1,8 +1,12 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= diff --git a/riverdriver/riverdrivertest/job_delete.go b/riverdriver/riverdrivertest/job_delete.go index 787a5d09..e61331d3 100644 --- a/riverdriver/riverdrivertest/job_delete.go +++ b/riverdriver/riverdrivertest/job_delete.go @@ -261,8 +261,8 @@ func exerciseJobDelete[TTx any](ctx context.Context, t *testing.T, executorWithT // since we only expect to need `queues_excluded` on SQLite (and not // `queues_included` for the foreseeable future), I've just set // SQLite to not support `queues_included` for the time being. 
- if bundle.driver.DatabaseName() == databaseNameSQLite { - t.Logf("Skipping JobDeleteBefore with QueuesIncluded test for SQLite") + if bundle.driver.DatabaseName() == databaseNameSQLite || bundle.driver.DatabaseName() == databaseNameMySQL { + t.Skipf("Skipping JobDeleteBefore with QueuesIncluded test for %s", bundle.driver.DatabaseName()) return } diff --git a/riverdriver/riverdrivertest/job_insert.go b/riverdriver/riverdrivertest/job_insert.go index bce3e8b0..08d5156f 100644 --- a/riverdriver/riverdrivertest/job_insert.go +++ b/riverdriver/riverdrivertest/job_insert.go @@ -699,7 +699,7 @@ func exerciseJobInsert[TTx any](ctx context.Context, t *testing.T, _, err := exec.JobInsertFull(ctx, params) require.Error(t, err) - // two separate error messages here for Postgres and SQLite + // three separate error messages here for Postgres, SQLite, and MySQL - require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null")`, err.Error()) + require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null"|Check constraint 'finalized_or_finalized_at_null' is violated)`, err.Error()) }) t.Run(fmt.Sprintf("CanSetState%sWithFinalizedAt", capitalizeJobState(state)), func(t *testing.T) { @@ -749,7 +749,7 @@ func exerciseJobInsert[TTx any](ctx context.Context, t *testing.T, })) require.Error(t, err) - // two separate error messages here for Postgres and SQLite + // three separate error messages here for Postgres, SQLite, and MySQL - require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null")`, err.Error()) + require.Regexp(t, `(CHECK constraint failed: finalized_or_finalized_at_null|violates check constraint "finalized_or_finalized_at_null"|Check constraint 'finalized_or_finalized_at_null' is violated)`, err.Error()) }) } }) diff --git a/riverdriver/riverdrivertest/riverdrivertest.go b/riverdriver/riverdrivertest/riverdrivertest.go index c51093e9..25d07d03 100644 ---
a/riverdriver/riverdrivertest/riverdrivertest.go +++ b/riverdriver/riverdrivertest/riverdrivertest.go @@ -46,6 +46,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, } const ( + databaseNameMySQL = "mysql" databaseNamePostgres = "postgres" databaseNameSQLite = "sqlite" testClientID = "test-client-id" @@ -83,6 +84,25 @@ func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, }) }) + t.Run("SafeIdentifier", func(t *testing.T) { + t.Parallel() + + _, driver := executorWithTx(ctx, t) + + switch driver.DatabaseName() { + case databaseNamePostgres, databaseNameSQLite: + require.Equal(t, `"my_schema"`, driver.SafeIdentifier("my_schema")) + require.Equal(t, `"has space"`, driver.SafeIdentifier("has space")) + require.Equal(t, `"has""quote"`, driver.SafeIdentifier(`has"quote`)) + case databaseNameMySQL: + require.Equal(t, "`my_schema`", driver.SafeIdentifier("my_schema")) + require.Equal(t, "`has space`", driver.SafeIdentifier("has space")) + require.Equal(t, "`has``backtick`", driver.SafeIdentifier("has`backtick")) + default: + require.FailNow(t, "Don't know how to check SafeIdentifier for: "+driver.DatabaseName()) + } + }) + t.Run("SupportsListenNotify", func(t *testing.T) { t.Parallel() @@ -91,7 +111,7 @@ func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, switch driver.DatabaseName() { case databaseNamePostgres: require.True(t, driver.SupportsListenNotify()) - case databaseNameSQLite: + case databaseNameMySQL, databaseNameSQLite: require.False(t, driver.SupportsListenNotify()) default: require.FailNow(t, "Don't know how to check SupportsListenNotify for: "+driver.DatabaseName()) @@ -109,6 +129,7 @@ func requireMissingRelation(t *testing.T, err error, schema, missingRelation str } else { // lib/pq: pq: relation %s.%s does not exist // SQLite: no such table: %s.%s - require.Regexp(t, fmt.Sprintf(`(pq: relation "%s\.%s" does not exist|no such table: %s\.%s)`, schema, missingRelation, schema, missingRelation), err.Error()) + // MySQL: 
Unknown database '%s' + require.Regexp(t, fmt.Sprintf(`(pq: relation "%s\.%s" does not exist|no such table: %s\.%s|Unknown database '%s')`, schema, missingRelation, schema, missingRelation, schema), err.Error()) } } diff --git a/riverdriver/rivermysql/example_mysql_test.go b/riverdriver/rivermysql/example_mysql_test.go new file mode 100644 index 00000000..74aa0177 --- /dev/null +++ b/riverdriver/rivermysql/example_mysql_test.go @@ -0,0 +1,125 @@ +package rivermysql_test + +import ( + "cmp" + "context" + "database/sql" + "fmt" + "log/slog" + "os" + "sort" + + _ "github.com/go-sql-driver/mysql" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/rivermysql" + "github.com/riverqueue/river/rivermigrate" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type MySQLSortArgs struct { + // Strings is a slice of strings to sort. + Strings []string `json:"strings"` +} + +func (MySQLSortArgs) Kind() string { return "sort" } + +type MySQLSortWorker struct { + river.WorkerDefaults[MySQLSortArgs] +} + +func (w *MySQLSortWorker) Work(ctx context.Context, job *river.Job[MySQLSortArgs]) error { + sort.Strings(job.Args.Strings) + fmt.Printf("Sorted strings: %+v\n", job.Args.Strings) + return nil +} + +// Example_mysql demonstrates use of River's MySQL driver. +func Example_mysql() { + // MySQL tests are opt-in because they require a running server. When + // disabled, print the expected output so the example test always passes. 
+ val := os.Getenv("RIVER_MYSQL_TESTS_ENABLED") + if val != "1" && val != "true" { + fmt.Println("Sorted strings: [bear tiger whale]") + return + } + + ctx := context.Background() + + dsn := cmp.Or( + os.Getenv("TEST_MYSQL_URL"), + "root@tcp(localhost:3306)/?parseTime=true&multiStatements=true&loc=UTC&time_zone=%27%2B00%3A00%27", + ) + + dbPool, err := sql.Open("mysql", dsn) + if err != nil { + panic(err) + } + defer dbPool.Close() + + // Create a temporary database for the example. + const exampleDB = "river_example_mysql" + if _, err := dbPool.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+exampleDB); err != nil { + panic(err) + } + defer func() { + _, _ = dbPool.ExecContext(ctx, "DROP DATABASE IF EXISTS "+exampleDB) + }() + + // Run River's migrations to prepare the schema. + migrator, err := rivermigrate.New(rivermysql.New(dbPool), &rivermigrate.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Schema: exampleDB, + }) + if err != nil { + panic(err) + } + if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil { + panic(err) + } + + workers := river.NewWorkers() + river.AddWorker(workers, &MySQLSortWorker{}) + + riverClient, err := river.NewClient(rivermysql.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: exampleDB, + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. 
+ subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + _, err = riverClient.Insert(ctx, MySQLSortArgs{ + Strings: []string{ + "whale", "tiger", "bear", + }, + }, nil) + if err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. + riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Sorted strings: [bear tiger whale] +} diff --git a/riverdriver/rivermysql/go.mod b/riverdriver/rivermysql/go.mod new file mode 100644 index 00000000..35b6522b --- /dev/null +++ b/riverdriver/rivermysql/go.mod @@ -0,0 +1,32 @@ +module github.com/riverqueue/river/riverdriver/rivermysql + +go 1.25.0 + +toolchain go1.25.7 + +require ( + github.com/go-sql-driver/mysql v1.9.3 + github.com/riverqueue/river v0.35.0 + github.com/riverqueue/river/riverdriver v0.35.0 + github.com/riverqueue/river/rivershared v0.35.0 + github.com/riverqueue/river/rivertype v0.35.0 + github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.9.1 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + go.uber.org/goleak v1.3.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/riverdriver/rivermysql/go.sum b/riverdriver/rivermysql/go.sum new file mode 100644 index 00000000..8f525365 --- 
/dev/null +++ b/riverdriver/rivermysql/go.sum @@ -0,0 +1,65 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river v0.35.0 h1:ERjFYsrCTGIQb8zdGZDegXmCNtsZ2f38Mh/1pm8s+7s= +github.com/riverqueue/river v0.35.0/go.mod h1:BhQ6qtWT5ngvHcZByD86GB6wFbFmWuovkqwp5Ybuw+g= +github.com/riverqueue/river/riverdriver v0.35.0 h1:g23GXp2ukQBzFAjHoqEqU9/nvDBVNugvFSzDeh9QnDA= +github.com/riverqueue/river/riverdriver v0.35.0/go.mod h1:cg2DXFxovACcMKMDO7VwvMJt7/FfQD5ZvD19o7q+27g= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0 h1:ifGnwXjFujv87RV1P7Q3RQRmMhZPgoyB4Y7Q3cwQAv8= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.35.0/go.mod h1:zUy67OVHoXfeOktxmQqlgqL9q/otBaUrwKuRm0PwHl8= +github.com/riverqueue/river/rivershared v0.35.0 h1:8l0nemUfKvG51GhsjPAyhPVNw/Nej3xIbOZ9np/G0aY= +github.com/riverqueue/river/rivershared v0.35.0/go.mod h1:m6UB4lGgbwi8ikuKniISDJ/U+BvQYtNAtZwb90G85Wk= +github.com/riverqueue/river/rivertype v0.35.0 h1:0SvJQB5GvSkGAyLLtUbtd9Fre7yuEpuzOtRNncvtj2Y= +github.com/riverqueue/river/rivertype v0.35.0/go.mod h1:D1Ad+EaZiaXbQbJcJcfeicXJMBKno0n6UcfKI5Q7DIQ= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.31.0

package dbsqlc

import (
	"context"
	"database/sql"
)

// DBTX is the subset of database/sql operations the generated queries need.
// *sql.DB, *sql.Conn, and *sql.Tx all satisfy it, so the same generated
// methods can run against a connection pool, a single connection, or an open
// transaction interchangeably.
type DBTX interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
	PrepareContext(context.Context, string) (*sql.Stmt, error)
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

// New constructs a Queries value. Queries is stateless; the DBTX to execute
// against is passed to each generated method individually.
func New() *Queries {
	return &Queries{}
}

// Queries hosts the generated query methods (see river_job.sql.go and
// siblings). It intentionally carries no state of its own.
type Queries struct {
}

// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.31.0

package dbsqlc

import (
	"database/sql"
	"time"
)

// RiverClient mirrors one row of the river_client table: a known client
// (worker process) together with its metadata and pause state.
type RiverClient struct {
	ID        string
	CreatedAt time.Time
	Metadata  []byte // raw JSON document
	PausedAt  sql.NullTime
	UpdatedAt time.Time
}

// RiverClientQueue mirrors one row of river_client_queue: per-client,
// per-queue operational counters. Keyed by (RiverClientID, Name).
type RiverClientQueue struct {
	RiverClientID    string
	Name             string
	CreatedAt        time.Time
	MaxWorkers       int64
	Metadata         []byte // raw JSON document
	NumJobsCompleted int64
	NumJobsRunning   int64
	UpdatedAt        time.Time
}

// RiverJob mirrors one row of the river_job table. JSON columns (Args,
// AttemptedBy, Errors, Metadata, Tags) are carried as raw bytes; nullable
// columns use database/sql null wrappers.
type RiverJob struct {
	ID          int64
	Args        []byte
	Attempt     int64
	AttemptedAt sql.NullTime
	AttemptedBy []byte // JSON array of client IDs
	CreatedAt   time.Time
	Errors      []byte // JSON array of error records
	FinalizedAt sql.NullTime
	Kind        string
	MaxAttempts int64
	Metadata    []byte
	Priority    int16
	State       string
	Queue       string
	ScheduledAt time.Time
	Tags        []byte // JSON array of strings
	// UniqueKey is a VARBINARY(255) column in the DDL; sqlc maps it to a
	// string here — TODO confirm intended mapping for binary keys.
	UniqueKey sql.NullString
	// UniqueStates appears to be a bitmask of job states (see the 1 << n
	// tests in the JobSchedule query) — NOTE(review): verify against the
	// Postgres driver's unique_states encoding.
	UniqueStates sql.NullInt16
}

// RiverLeader mirrors one row of river_leader: the currently elected leader
// and its lease window.
type RiverLeader struct {
	ElectedAt time.Time
	ExpiresAt time.Time
	LeaderID  string
	Name      string
}

// RiverMigration mirrors one row of river_migration: a migration version
// applied on a given line.
type RiverMigration struct {
	Line      string
	Version   int64
	CreatedAt time.Time
}

// RiverQueue mirrors one row of river_queue: a known queue with its metadata
// and pause state.
type RiverQueue struct {
	Name      string
	CreatedAt time.Time
	Metadata  []byte // raw JSON document
	PausedAt  sql.NullTime
	UpdatedAt time.Time
}

// Statistics maps rows of an information_schema-style statistics query
// (index/table/schema names).
type Statistics struct {
	IndexName   string
	TableName   string
	TableSchema string
}

// Tables maps rows of an information_schema-style tables query.
type Tables struct {
	TableName   string
	TableSchema string
}

-- river_client stores one row per known client (worker process). DATETIME(6)
-- gives microsecond precision; metadata defaults to an empty JSON object.
CREATE TABLE river_client (
    id VARCHAR(128) NOT NULL PRIMARY KEY,
    created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)),
    metadata JSON NOT NULL DEFAULT (JSON_OBJECT()),
    paused_at DATETIME(6) NULL,
    updated_at DATETIME(6) NOT NULL
);
-- river_client_queue tracks per-client queue state, one row per
-- (client, queue) pair; rows are removed automatically with their client via
-- the ON DELETE CASCADE foreign key.
CREATE TABLE river_client_queue (
    river_client_id VARCHAR(128) NOT NULL,
    name VARCHAR(128) NOT NULL,
    created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)),
    max_workers INT NOT NULL DEFAULT 0,
    metadata JSON NOT NULL DEFAULT (JSON_OBJECT()),
    num_jobs_completed BIGINT NOT NULL DEFAULT 0,
    num_jobs_running BIGINT NOT NULL DEFAULT 0,
    updated_at DATETIME(6) NOT NULL,
    PRIMARY KEY (river_client_id, name),
    CONSTRAINT fk_river_client FOREIGN KEY (river_client_id) REFERENCES river_client (id) ON DELETE CASCADE
);

-- river_job is the central job table. state is a VARCHAR rather than an enum
-- (presumably to mirror the Postgres driver's river_job_state — confirm);
-- unique_key/unique_states implement unique-job bookkeeping (unique_states
-- looks like a bitmask, see JobSchedule below). The /* TEMPLATE: schema */
-- markers are rewritten at runtime with an optional schema prefix.
CREATE TABLE river_job (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    args JSON NOT NULL DEFAULT (JSON_OBJECT()),
    attempt INT NOT NULL DEFAULT 0,
    attempted_at DATETIME(6) NULL,
    attempted_by JSON NULL,
    created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)),
    errors JSON NULL,
    finalized_at DATETIME(6) NULL,
    kind VARCHAR(128) NOT NULL,
    max_attempts INT NOT NULL,
    metadata JSON NOT NULL DEFAULT (JSON_OBJECT()),
    priority SMALLINT NOT NULL DEFAULT 1,
    queue VARCHAR(128) NOT NULL DEFAULT 'default',
    state VARCHAR(20) NOT NULL DEFAULT 'available',
    scheduled_at DATETIME(6) NOT NULL DEFAULT (NOW(6)),
    tags JSON NOT NULL DEFAULT (JSON_ARRAY()),
    unique_key VARBINARY(255) NULL,
    unique_states SMALLINT NULL
);

-- name: JobGetByID :one
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE id = sqlc.arg('id')
LIMIT 1;

-- name: JobGetByIDMany :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE id IN (sqlc.slice('id'))
ORDER BY id;

-- name: JobGetByKindMany :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE kind IN (sqlc.slice('kind'))
ORDER BY id;

-- Cancel a job unless it has already reached a finalized state. A still
-- running job keeps its state/finalized_at, but cancel_attempted_at is always
-- stamped into metadata — presumably so the executor cancels it later
-- (matches the other drivers' behavior; confirm).
-- name: JobCancelExec :execresult
UPDATE /* TEMPLATE: schema */river_job
SET
    state = CASE WHEN state = 'running' THEN state ELSE 'cancelled' END,
    finalized_at = CASE WHEN state = 'running' THEN finalized_at ELSE COALESCE(sqlc.narg('now'), NOW(6)) END,
    metadata = JSON_SET(metadata, '$.cancel_attempted_at', CAST(sqlc.arg('cancel_attempted_at') AS CHAR))
WHERE id = sqlc.arg('id')
    AND state NOT IN ('cancelled', 'completed', 'discarded')
    AND finalized_at IS NULL;

-- name: JobCountByAllStates :many
SELECT state, count(*) AS count
FROM /* TEMPLATE: schema */river_job
GROUP BY state;

-- Per-queue available/running counts for the requested queues; LEFT JOINs
-- plus COALESCE turn missing groups into zeroes.
-- name: JobCountByQueueAndState :many
WITH all_queues AS (
    SELECT DISTINCT river_job.queue
    FROM /* TEMPLATE: schema */river_job
    WHERE river_job.queue IN (sqlc.slice('queue_names'))
),

running_job_counts AS (
    SELECT river_job.queue, COUNT(*) AS count
    FROM /* TEMPLATE: schema */river_job
    WHERE river_job.queue IN (sqlc.slice('queue_names'))
        AND river_job.state = 'running'
    GROUP BY river_job.queue
),

available_job_counts AS (
    SELECT river_job.queue, COUNT(*) AS count
    FROM /* TEMPLATE: schema */river_job
    WHERE river_job.queue IN (sqlc.slice('queue_names'))
        AND river_job.state = 'available'
    GROUP BY river_job.queue
)

SELECT
    all_queues.queue,
    COALESCE(available_job_counts.count, 0) AS count_available,
    COALESCE(running_job_counts.count, 0) AS count_running
FROM all_queues
LEFT JOIN running_job_counts ON all_queues.queue = running_job_counts.queue
LEFT JOIN available_job_counts ON all_queues.queue = available_job_counts.queue
ORDER BY all_queues.queue ASC;

-- name: JobCountByState :one
SELECT count(*) AS count
FROM /* TEMPLATE: schema */river_job
WHERE state = sqlc.arg('state');

-- name: JobDeleteExec :execresult
DELETE FROM /* TEMPLATE: schema */river_job
WHERE id = sqlc.arg('id')
    AND river_job.state != 'running';

-- Batch-delete finalized jobs older than per-state horizons. The extra
-- "SELECT id FROM (...) AS tmp" derived table works around MySQL's
-- restriction on LIMIT (and on referencing the target table) inside an IN
-- subquery of a DELETE. queues_excluded_empty short-circuits the NOT IN so an
-- empty exclusion list excludes nothing.
-- name: JobDeleteBefore :execresult
DELETE FROM /* TEMPLATE: schema */river_job
WHERE
    id IN (
        SELECT id FROM (
            SELECT rj2.id
            FROM /* TEMPLATE: schema */river_job rj2
            WHERE
                (rj2.state = 'cancelled' AND rj2.finalized_at < sqlc.arg('cancelled_finalized_at_horizon')) OR
                (rj2.state = 'completed' AND rj2.finalized_at < sqlc.arg('completed_finalized_at_horizon')) OR
                (rj2.state = 'discarded' AND rj2.finalized_at < sqlc.arg('discarded_finalized_at_horizon'))
            ORDER BY rj2.id
            LIMIT ?
        ) AS tmp
    )
    AND (
        CAST(sqlc.arg('queues_excluded_empty') AS SIGNED)
        OR river_job.queue NOT IN (sqlc.slice('queues_excluded'))
    );

-- Select the non-running jobs a JobDeleteMany would remove. The TEMPLATE
-- where/order_by markers are replaced at runtime with caller-built clauses.
-- name: JobDeleteManySelect :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE id IN (
    SELECT id FROM (
        SELECT id
        FROM /* TEMPLATE: schema */river_job
        WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */
            AND state != 'running'
        ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */
        LIMIT ?
    ) AS tmp
)
ORDER BY id;

-- name: JobDeleteManyExec :exec
DELETE FROM /* TEMPLATE: schema */river_job
WHERE id IN (sqlc.slice('id'));

-- Lock a batch of due, available job IDs for fetching. FOR UPDATE SKIP LOCKED
-- lets concurrent producers pass over rows another transaction already holds.
-- name: JobGetAvailableIDs :many
SELECT id
FROM /* TEMPLATE: schema */river_job
WHERE
    priority >= 0
    AND queue = sqlc.arg('queue')
    AND scheduled_at <= COALESCE(sqlc.narg('now'), NOW(6))
    AND state = 'available'
ORDER BY priority ASC, scheduled_at ASC, id ASC
LIMIT ?
FOR UPDATE SKIP LOCKED;

-- Mark locked jobs running and append this client to attempted_by, keeping at
-- most max_attempted_by entries by trimming the oldest via a '$[last-N to
-- last]' JSON path range before appending.
-- name: JobGetAvailableUpdate :exec
UPDATE /* TEMPLATE: schema */river_job
SET
    attempt = attempt + 1,
    attempted_at = COALESCE(sqlc.narg('now'), NOW(6)),
    attempted_by = JSON_ARRAY_APPEND(
        CASE
            WHEN JSON_LENGTH(COALESCE(attempted_by, JSON_ARRAY())) < CAST(sqlc.arg('max_attempted_by') AS SIGNED)
                THEN COALESCE(attempted_by, JSON_ARRAY())
            WHEN CAST(sqlc.arg('max_attempted_by') AS SIGNED) <= 1
                THEN JSON_ARRAY()
            ELSE COALESCE(
                JSON_EXTRACT(
                    attempted_by,
                    CONCAT('$[last-', CAST(sqlc.arg('max_attempted_by') AS SIGNED) - 2, ' to last]')
                ),
                JSON_ARRAY()
            )
        END,
        '$',
        CAST(sqlc.arg('attempted_by') AS CHAR)
    ),
    state = 'running'
WHERE id IN (sqlc.slice('id'));

-- name: JobGetByIDManyOrdered :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE id IN (sqlc.slice('id'))
ORDER BY priority ASC, scheduled_at ASC, id ASC;

-- name: JobGetStuck :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE state = 'running'
    AND attempted_at < sqlc.arg('stuck_horizon')
ORDER BY id
LIMIT ?;

-- Insert a job; on a unique-key collision, LAST_INSERT_ID(id) makes the
-- existing row's id come back through LAST_INSERT_ID() (documented MySQL
-- upsert idiom), and the no-op kind = VALUES(kind) keeps the statement valid.
-- name: JobInsertFast :execresult
INSERT INTO /* TEMPLATE: schema */river_job(
    id,
    args,
    created_at,
    kind,
    max_attempts,
    metadata,
    priority,
    queue,
    scheduled_at,
    state,
    tags,
    unique_key,
    unique_states
) VALUES (
    sqlc.narg('id'),
    sqlc.arg('args'),
    COALESCE(sqlc.narg('created_at'), NOW(6)),
    sqlc.arg('kind'),
    sqlc.arg('max_attempts'),
    CAST(sqlc.arg('metadata') AS JSON),
    sqlc.arg('priority'),
    sqlc.arg('queue'),
    COALESCE(sqlc.narg('scheduled_at'), NOW(6)),
    sqlc.arg('state'),
    CAST(sqlc.arg('tags') AS JSON),
    sqlc.narg('unique_key'),
    sqlc.narg('unique_states')
)
ON DUPLICATE KEY UPDATE id = LAST_INSERT_ID(id), kind = VALUES(kind);

-- Full insert with every column caller-controlled (used for test/restore
-- style insertion); returns the new auto-increment id.
-- name: JobInsertFullExec :execlastid
INSERT INTO /* TEMPLATE: schema */river_job(
    args,
    attempt,
    attempted_at,
    attempted_by,
    created_at,
    errors,
    finalized_at,
    kind,
    max_attempts,
    metadata,
    priority,
    queue,
    scheduled_at,
    state,
    tags,
    unique_key,
    unique_states
) VALUES (
    sqlc.arg('args'),
    sqlc.arg('attempt'),
    sqlc.narg('attempted_at'),
    CAST(sqlc.narg('attempted_by') AS JSON),
    COALESCE(sqlc.narg('created_at'), NOW(6)),
    CAST(sqlc.narg('errors') AS JSON),
    sqlc.narg('finalized_at'),
    sqlc.arg('kind'),
    sqlc.arg('max_attempts'),
    CAST(sqlc.arg('metadata') AS JSON),
    sqlc.arg('priority'),
    sqlc.arg('queue'),
    COALESCE(sqlc.narg('scheduled_at'), NOW(6)),
    sqlc.arg('state'),
    CAST(sqlc.arg('tags') AS JSON),
    sqlc.narg('unique_key'),
    sqlc.narg('unique_states')
);

-- Paginated, filtered listing of distinct kinds. The explicit COLLATE
-- utf8mb4_general_ci forces case-insensitive matching/comparison regardless
-- of the column's collation.
-- name: JobKindList :many
SELECT DISTINCT kind
FROM /* TEMPLATE: schema */river_job
WHERE (sqlc.arg('match') = '' OR LOWER(kind) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(sqlc.arg('match')), '%'))
    AND (sqlc.arg('after') = '' OR kind COLLATE utf8mb4_general_ci > sqlc.arg('after'))
    AND kind NOT IN (sqlc.slice('exclude'))
ORDER BY kind ASC
LIMIT ?;

-- name: JobList :many
SELECT *
FROM /* TEMPLATE: schema */river_job
WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */
ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */
LIMIT ?;

-- Rescue a stuck job: append the rescue error and bump the numeric
-- "river:rescue_count" metadata key (the JSON_TYPE CASE ignores a
-- non-numeric existing value, restarting the count at 1).
-- name: JobRescue :exec
UPDATE /* TEMPLATE: schema */river_job
SET
    errors = JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(sqlc.arg('error') AS JSON)),
    finalized_at = sqlc.narg('finalized_at'),
    scheduled_at = sqlc.arg('scheduled_at'),
    metadata = JSON_SET(
        metadata,
        '$."river:rescue_count"',
        COALESCE(
            CASE JSON_TYPE(JSON_EXTRACT(metadata, '$."river:rescue_count"'))
                WHEN 'INTEGER' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"')
                WHEN 'DOUBLE' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"')
                ELSE NULL
            END,
            0
        ) + 1
    ),
    state = sqlc.arg('state')
WHERE id = sqlc.arg('id');

-- Force an immediate retry. If the job had exhausted max_attempts, bump it by
-- one so it can run again. Running jobs — and jobs already available and due
-- — are left untouched (zero rows affected).
-- name: JobRetryExec :execresult
UPDATE /* TEMPLATE: schema */river_job
SET
    state = 'available',
    max_attempts = CASE WHEN attempt = max_attempts THEN max_attempts + 1 ELSE max_attempts END,
    finalized_at = NULL,
    scheduled_at = COALESCE(sqlc.narg('now'), NOW(6))
WHERE id = sqlc.arg('id')
    AND state != 'running'
    AND (
        state <> 'available'
        OR scheduled_at > COALESCE(sqlc.narg('now'), NOW(6))
    );

-- Pick due retryable/scheduled jobs and classify each as promotable or a
-- unique conflict. unique_states is treated as a bitmask (bit 0 = available …
-- bit 7 = scheduled); a job conflicts when another row with the same
-- unique_key sits in a state covered by the mask, or when it loses the
-- ROW_NUMBER() race against a duplicate inside this very batch (row_num > 1).
-- name: JobSchedule :many
WITH eligible AS (
    SELECT river_job.id, river_job.unique_key, river_job.unique_states, river_job.priority, river_job.scheduled_at,
        CASE
            WHEN river_job.unique_key IS NOT NULL AND river_job.unique_states IS NOT NULL THEN
                ROW_NUMBER() OVER (PARTITION BY river_job.unique_key ORDER BY river_job.priority, river_job.scheduled_at, river_job.id)
            ELSE NULL
        END AS row_num
    FROM /* TEMPLATE: schema */river_job
    WHERE
        river_job.state IN ('retryable', 'scheduled')
        AND river_job.scheduled_at <= COALESCE(sqlc.narg('now'), NOW(6))
    ORDER BY
        river_job.priority,
        river_job.scheduled_at,
        river_job.id
    LIMIT ?
),
unique_conflicts AS (
    SELECT DISTINCT eligible.unique_key
    FROM /* TEMPLATE: schema */river_job
    JOIN eligible
        ON river_job.unique_key = eligible.unique_key
        AND river_job.id != eligible.id
    WHERE
        river_job.unique_key IS NOT NULL
        AND river_job.unique_states IS NOT NULL
        AND CASE river_job.state
            WHEN 'available' THEN river_job.unique_states & (1 << 0)
            WHEN 'cancelled' THEN river_job.unique_states & (1 << 1)
            WHEN 'completed' THEN river_job.unique_states & (1 << 2)
            WHEN 'discarded' THEN river_job.unique_states & (1 << 3)
            WHEN 'pending' THEN river_job.unique_states & (1 << 4)
            WHEN 'retryable' THEN river_job.unique_states & (1 << 5)
            WHEN 'running' THEN river_job.unique_states & (1 << 6)
            WHEN 'scheduled' THEN river_job.unique_states & (1 << 7)
            ELSE 0
        END >= 1
)
SELECT eligible.id,
    CASE
        WHEN eligible.unique_key IS NULL OR eligible.unique_states IS NULL THEN FALSE
        WHEN uc.unique_key IS NOT NULL THEN TRUE
        WHEN eligible.row_num > 1 THEN TRUE
        ELSE FALSE
    END AS conflict_discarded
FROM eligible
LEFT JOIN unique_conflicts uc ON eligible.unique_key = uc.unique_key
ORDER BY eligible.priority, eligible.scheduled_at, eligible.id;

-- name: JobScheduleSetAvailableExec :exec
UPDATE /* TEMPLATE: schema */river_job
SET state = 'available'
WHERE id IN (sqlc.slice('id'));

-- Discard the losers of a unique-key conflict found by JobSchedule, recording
-- why in metadata.
-- name: JobScheduleSetDiscardedExec :exec
UPDATE /* TEMPLATE: schema */river_job
SET metadata = JSON_MERGE_PATCH(metadata, '{"unique_key_conflict": "scheduler_discarded"}'),
    finalized_at = COALESCE(sqlc.narg('now'), NOW(6)),
    state = 'discarded'
WHERE id IN (sqlc.slice('id'));

-- name: JobSetMetadataIfNotRunningExec :execresult
UPDATE /* TEMPLATE: schema */river_job
SET metadata = JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata_updates') AS JSON))
WHERE id = sqlc.arg('id')
    AND state != 'running';

-- Finalize a running job. If the target state is retryable/scheduled but a
-- cancellation was requested earlier (cancel_attempted_at present in
-- metadata), the job is cancelled instead and attempt/scheduled_at keep their
-- old values. The *_do_update flags gate each optional column.
-- name: JobSetStateIfRunningExec :exec
UPDATE /* TEMPLATE: schema */river_job
SET
    attempt = CASE WHEN (CAST(sqlc.arg('state') AS CHAR) <> 'retryable' AND sqlc.arg('state') <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(sqlc.arg('attempt_do_update') AS SIGNED)
        THEN sqlc.arg('attempt')
        ELSE attempt END,
    errors = CASE WHEN CAST(sqlc.arg('errors_do_update') AS SIGNED)
        THEN JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(sqlc.arg('error') AS JSON))
        ELSE errors END,
    finalized_at = CASE WHEN ((sqlc.arg('state') = 'retryable' OR sqlc.arg('state') = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL)
        THEN COALESCE(sqlc.narg('now'), NOW(6))
        WHEN CAST(sqlc.arg('finalized_at_do_update') AS SIGNED)
        THEN sqlc.narg('finalized_at')
        ELSE finalized_at END,
    metadata = CASE WHEN CAST(sqlc.arg('metadata_do_merge') AS SIGNED)
        THEN JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata_updates') AS JSON))
        ELSE metadata END,
    scheduled_at = CASE WHEN (CAST(sqlc.arg('state') AS CHAR) <> 'retryable' AND sqlc.arg('state') <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(sqlc.arg('scheduled_at_do_update') AS SIGNED)
        THEN sqlc.arg('scheduled_at')
        ELSE scheduled_at END,
    state = CASE WHEN ((sqlc.arg('state') = 'retryable' OR sqlc.arg('state') = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL)
        THEN 'cancelled'
        ELSE sqlc.arg('state') END
WHERE id = sqlc.arg('id')
    AND state = 'running';

-- name: JobUpdateExec :exec
UPDATE /* TEMPLATE: schema */river_job
SET
    metadata = CASE WHEN CAST(sqlc.arg('metadata_do_merge') AS SIGNED) THEN JSON_MERGE_PATCH(metadata, CAST(sqlc.arg('metadata') AS JSON)) ELSE metadata END
WHERE id = sqlc.arg('id');

-- Column-by-column update, each gated by its *_do_update flag; unlike
-- JobUpdateExec, metadata here is replaced outright rather than merged.
-- name: JobUpdateFullExec :exec
UPDATE /* TEMPLATE: schema */river_job
SET
    attempt = CASE WHEN CAST(sqlc.arg('attempt_do_update') AS SIGNED) THEN sqlc.arg('attempt') ELSE attempt END,
    attempted_at = CASE WHEN CAST(sqlc.arg('attempted_at_do_update') AS SIGNED) THEN sqlc.narg('attempted_at') ELSE attempted_at END,
    attempted_by = CASE WHEN CAST(sqlc.arg('attempted_by_do_update') AS SIGNED) THEN CAST(sqlc.arg('attempted_by') AS JSON) ELSE attempted_by END,
    errors = CASE WHEN CAST(sqlc.arg('errors_do_update') AS SIGNED) THEN CAST(sqlc.arg('errors') AS JSON) ELSE errors END,
    finalized_at = CASE WHEN CAST(sqlc.arg('finalized_at_do_update') AS SIGNED) THEN sqlc.narg('finalized_at') ELSE finalized_at END,
    max_attempts = CASE WHEN CAST(sqlc.arg('max_attempts_do_update') AS SIGNED) THEN sqlc.arg('max_attempts') ELSE max_attempts END,
    metadata = CASE WHEN CAST(sqlc.arg('metadata_do_update') AS SIGNED) THEN CAST(sqlc.arg('metadata') AS JSON) ELSE metadata END,
    state = CASE WHEN CAST(sqlc.arg('state_do_update') AS SIGNED) THEN sqlc.arg('state') ELSE state END
WHERE id = sqlc.arg('id');

// Code generated by sqlc. DO NOT EDIT.
+// versions: +// sqlc v1.31.0 +// source: river_job.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const jobCancelExec = `-- name: JobCancelExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = CASE WHEN state = 'running' THEN state ELSE 'cancelled' END, + finalized_at = CASE WHEN state = 'running' THEN finalized_at ELSE COALESCE(?, NOW(6)) END, + metadata = JSON_SET(metadata, '$.cancel_attempted_at', CAST(? AS CHAR)) +WHERE id = ? + AND state NOT IN ('cancelled', 'completed', 'discarded') + AND finalized_at IS NULL +` + +type JobCancelExecParams struct { + Now sql.NullTime + CancelAttemptedAt interface{} + ID int64 +} + +func (q *Queries) JobCancelExec(ctx context.Context, db DBTX, arg *JobCancelExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobCancelExec, arg.Now, arg.CancelAttemptedAt, arg.ID) +} + +const jobCountByAllStates = `-- name: JobCountByAllStates :many +SELECT state, count(*) AS count +FROM /* TEMPLATE: schema */river_job +GROUP BY state +` + +type JobCountByAllStatesRow struct { + State string + Count int64 +} + +func (q *Queries) JobCountByAllStates(ctx context.Context, db DBTX) ([]*JobCountByAllStatesRow, error) { + rows, err := db.QueryContext(ctx, jobCountByAllStates) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobCountByAllStatesRow + for rows.Next() { + var i JobCountByAllStatesRow + if err := rows.Scan(&i.State, &i.Count); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountByQueueAndState = `-- name: JobCountByQueueAndState :many +WITH all_queues AS ( + SELECT DISTINCT river_job.queue + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) 
+), + +running_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) + AND river_job.state = 'running' + GROUP BY river_job.queue +), + +available_job_counts AS ( + SELECT river_job.queue, COUNT(*) AS count + FROM /* TEMPLATE: schema */river_job + WHERE river_job.queue IN (/*SLICE:queue_names*/?) + AND river_job.state = 'available' + GROUP BY river_job.queue +) + +SELECT + all_queues.queue, + COALESCE(available_job_counts.count, 0) AS count_available, + COALESCE(running_job_counts.count, 0) AS count_running +FROM all_queues +LEFT JOIN running_job_counts ON all_queues.queue = running_job_counts.queue +LEFT JOIN available_job_counts ON all_queues.queue = available_job_counts.queue +ORDER BY all_queues.queue ASC +` + +type JobCountByQueueAndStateParams struct { + QueueNames []string +} + +type JobCountByQueueAndStateRow struct { + Queue string + CountAvailable int64 + CountRunning int64 +} + +func (q *Queries) JobCountByQueueAndState(ctx context.Context, db DBTX, arg *JobCountByQueueAndStateParams) ([]*JobCountByQueueAndStateRow, error) { + query := jobCountByQueueAndState + var queryParams []interface{} + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + if len(arg.QueueNames) > 0 { + for _, v := range arg.QueueNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queue_names*/?", 
strings.Repeat(",?", len(arg.QueueNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queue_names*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobCountByQueueAndStateRow + for rows.Next() { + var i JobCountByQueueAndStateRow + if err := rows.Scan(&i.Queue, &i.CountAvailable, &i.CountRunning); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountByState = `-- name: JobCountByState :one +SELECT count(*) AS count +FROM /* TEMPLATE: schema */river_job +WHERE state = ? +` + +func (q *Queries) JobCountByState(ctx context.Context, db DBTX, state string) (int64, error) { + row := db.QueryRowContext(ctx, jobCountByState, state) + var count int64 + err := row.Scan(&count) + return count, err +} + +const jobDeleteBefore = `-- name: JobDeleteBefore :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE + id IN ( + SELECT id FROM ( + SELECT rj2.id + FROM /* TEMPLATE: schema */river_job rj2 + WHERE + (rj2.state = 'cancelled' AND rj2.finalized_at < ?) OR + (rj2.state = 'completed' AND rj2.finalized_at < ?) OR + (rj2.state = 'discarded' AND rj2.finalized_at < ?) + ORDER BY rj2.id + LIMIT ? + ) AS tmp + ) + AND ( + CAST(? AS SIGNED) + OR river_job.queue NOT IN (/*SLICE:queues_excluded*/?) 
+ ) +` + +type JobDeleteBeforeParams struct { + CancelledFinalizedAtHorizon sql.NullTime + CompletedFinalizedAtHorizon sql.NullTime + DiscardedFinalizedAtHorizon sql.NullTime + Limit int32 + QueuesExcludedEmpty int64 + QueuesExcluded []string +} + +func (q *Queries) JobDeleteBefore(ctx context.Context, db DBTX, arg *JobDeleteBeforeParams) (sql.Result, error) { + query := jobDeleteBefore + var queryParams []interface{} + queryParams = append(queryParams, arg.CancelledFinalizedAtHorizon) + queryParams = append(queryParams, arg.CompletedFinalizedAtHorizon) + queryParams = append(queryParams, arg.DiscardedFinalizedAtHorizon) + queryParams = append(queryParams, arg.Limit) + queryParams = append(queryParams, arg.QueuesExcludedEmpty) + if len(arg.QueuesExcluded) > 0 { + for _, v := range arg.QueuesExcluded { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:queues_excluded*/?", strings.Repeat(",?", len(arg.QueuesExcluded))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:queues_excluded*/?", "NULL", 1) + } + return db.ExecContext(ctx, query, queryParams...) +} + +const jobDeleteExec = `-- name: JobDeleteExec :execresult +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id = ? + AND river_job.state != 'running' +` + +func (q *Queries) JobDeleteExec(ctx context.Context, db DBTX, id int64) (sql.Result, error) { + return db.ExecContext(ctx, jobDeleteExec, id) +} + +const jobDeleteManyExec = `-- name: JobDeleteManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) 
+` + +func (q *Queries) JobDeleteManyExec(ctx context.Context, db DBTX, id []int64) error { + query := jobDeleteManyExec + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobDeleteManySelect = `-- name: JobDeleteManySelect :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN ( + SELECT id FROM ( + SELECT id + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + AND state != 'running' + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT ? 
+ ) AS tmp +) +ORDER BY id +` + +func (q *Queries) JobDeleteManySelect(ctx context.Context, db DBTX, limit int32) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobDeleteManySelect, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetAvailableIDs = `-- name: JobGetAvailableIDs :many +SELECT id +FROM /* TEMPLATE: schema */river_job +WHERE + priority >= 0 + AND queue = ? + AND scheduled_at <= COALESCE(?, NOW(6)) + AND state = 'available' +ORDER BY priority ASC, scheduled_at ASC, id ASC +LIMIT ? 
+FOR UPDATE SKIP LOCKED +` + +type JobGetAvailableIDsParams struct { + Queue string + Now sql.NullTime + Limit int32 +} + +func (q *Queries) JobGetAvailableIDs(ctx context.Context, db DBTX, arg *JobGetAvailableIDsParams) ([]int64, error) { + rows, err := db.QueryContext(ctx, jobGetAvailableIDs, arg.Queue, arg.Now, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetAvailableUpdate = `-- name: JobGetAvailableUpdate :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = attempt + 1, + attempted_at = COALESCE(?, NOW(6)), + attempted_by = JSON_ARRAY_APPEND( + CASE + WHEN JSON_LENGTH(COALESCE(attempted_by, JSON_ARRAY())) < CAST(? AS SIGNED) + THEN COALESCE(attempted_by, JSON_ARRAY()) + WHEN CAST(? AS SIGNED) <= 1 + THEN JSON_ARRAY() + ELSE COALESCE( + JSON_EXTRACT( + attempted_by, + CONCAT('$[last-', CAST(? AS SIGNED) - 2, ' to last]') + ), + JSON_ARRAY() + ) + END, + '$', + CAST(? AS CHAR) + ), + state = 'running' +WHERE id IN (/*SLICE:id*/?) 
+` + +type JobGetAvailableUpdateParams struct { + Now sql.NullTime + MaxAttemptedBy int64 + AttemptedBy interface{} + ID []int64 +} + +func (q *Queries) JobGetAvailableUpdate(ctx context.Context, db DBTX, arg *JobGetAvailableUpdateParams) error { + query := jobGetAvailableUpdate + var queryParams []interface{} + queryParams = append(queryParams, arg.Now) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.MaxAttemptedBy) + queryParams = append(queryParams, arg.AttemptedBy) + if len(arg.ID) > 0 { + for _, v := range arg.ID { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(arg.ID))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobGetByID = `-- name: JobGetByID :one +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id = ? +LIMIT 1 +` + +func (q *Queries) JobGetByID(ctx context.Context, db DBTX, id int64) (*RiverJob, error) { + row := db.QueryRowContext(ctx, jobGetByID, id) + var i RiverJob + err := row.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ) + return &i, err +} + +const jobGetByIDMany = `-- name: JobGetByIDMany :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) 
+ORDER BY id +` + +func (q *Queries) JobGetByIDMany(ctx context.Context, db DBTX, id []int64) ([]*RiverJob, error) { + query := jobGetByIDMany + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetByIDManyOrdered = `-- name: JobGetByIDManyOrdered :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE id IN (/*SLICE:id*/?) +ORDER BY priority ASC, scheduled_at ASC, id ASC +` + +func (q *Queries) JobGetByIDManyOrdered(ctx context.Context, db DBTX, id []int64) ([]*RiverJob, error) { + query := jobGetByIDManyOrdered + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetByKindMany = `-- name: JobGetByKindMany :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE kind IN (/*SLICE:kind*/?) +ORDER BY id +` + +func (q *Queries) JobGetByKindMany(ctx context.Context, db DBTX, kind []string) ([]*RiverJob, error) { + query := jobGetByKindMany + var queryParams []interface{} + if len(kind) > 0 { + for _, v := range kind { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:kind*/?", strings.Repeat(",?", len(kind))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:kind*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobGetStuck = `-- name: JobGetStuck :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE state = 'running' + AND attempted_at < ? +ORDER BY id +LIMIT ? +` + +type JobGetStuckParams struct { + StuckHorizon sql.NullTime + Limit int32 +} + +func (q *Queries) JobGetStuck(ctx context.Context, db DBTX, arg *JobGetStuckParams) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobGetStuck, arg.StuckHorizon, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobInsertFast = `-- name: JobInsertFast :execresult +INSERT INTO /* TEMPLATE: schema */river_job( + id, + args, + 
created_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + ?, + ?, + COALESCE(?, NOW(6)), + ?, + ?, + CAST(? AS JSON), + ?, + ?, + COALESCE(?, NOW(6)), + ?, + CAST(? AS JSON), + ?, + ? +) +ON DUPLICATE KEY UPDATE id = LAST_INSERT_ID(id), kind = VALUES(kind) +` + +type JobInsertFastParams struct { + ID sql.NullInt64 + Args []byte + CreatedAt interface{} + Kind string + MaxAttempts int64 + Metadata []byte + Priority int16 + Queue string + ScheduledAt interface{} + State string + Tags []byte + UniqueKey sql.NullString + UniqueStates sql.NullInt16 +} + +func (q *Queries) JobInsertFast(ctx context.Context, db DBTX, arg *JobInsertFastParams) (sql.Result, error) { + return db.ExecContext(ctx, jobInsertFast, + arg.ID, + arg.Args, + arg.CreatedAt, + arg.Kind, + arg.MaxAttempts, + arg.Metadata, + arg.Priority, + arg.Queue, + arg.ScheduledAt, + arg.State, + arg.Tags, + arg.UniqueKey, + arg.UniqueStates, + ) +} + +const jobInsertFullExec = `-- name: JobInsertFullExec :execlastid +INSERT INTO /* TEMPLATE: schema */river_job( + args, + attempt, + attempted_at, + attempted_by, + created_at, + errors, + finalized_at, + kind, + max_attempts, + metadata, + priority, + queue, + scheduled_at, + state, + tags, + unique_key, + unique_states +) VALUES ( + ?, + ?, + ?, + CAST(? AS JSON), + COALESCE(?, NOW(6)), + CAST(? AS JSON), + ?, + ?, + ?, + CAST(? AS JSON), + ?, + ?, + COALESCE(?, NOW(6)), + ?, + CAST(? AS JSON), + ?, + ? 
+) +` + +type JobInsertFullExecParams struct { + Args []byte + Attempt int64 + AttemptedAt sql.NullTime + AttemptedBy []byte + CreatedAt interface{} + Errors []byte + FinalizedAt sql.NullTime + Kind string + MaxAttempts int64 + Metadata []byte + Priority int16 + Queue string + ScheduledAt interface{} + State string + Tags []byte + UniqueKey sql.NullString + UniqueStates sql.NullInt16 +} + +func (q *Queries) JobInsertFullExec(ctx context.Context, db DBTX, arg *JobInsertFullExecParams) (int64, error) { + result, err := db.ExecContext(ctx, jobInsertFullExec, + arg.Args, + arg.Attempt, + arg.AttemptedAt, + arg.AttemptedBy, + arg.CreatedAt, + arg.Errors, + arg.FinalizedAt, + arg.Kind, + arg.MaxAttempts, + arg.Metadata, + arg.Priority, + arg.Queue, + arg.ScheduledAt, + arg.State, + arg.Tags, + arg.UniqueKey, + arg.UniqueStates, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +const jobKindList = `-- name: JobKindList :many +SELECT DISTINCT kind +FROM /* TEMPLATE: schema */river_job +WHERE (? = '' OR LOWER(kind) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(?), '%')) + AND (? = '' OR kind COLLATE utf8mb4_general_ci > ?) + AND kind NOT IN (/*SLICE:exclude*/?) +ORDER BY kind ASC +LIMIT ? 
+` + +type JobKindListParams struct { + Match string + After interface{} + Exclude []string + Limit int32 +} + +func (q *Queries) JobKindList(ctx context.Context, db DBTX, arg *JobKindListParams) ([]string, error) { + query := jobKindList + var queryParams []interface{} + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.After) + queryParams = append(queryParams, arg.After) + if len(arg.Exclude) > 0 { + for _, v := range arg.Exclude { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:exclude*/?", strings.Repeat(",?", len(arg.Exclude))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:exclude*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Limit) + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var kind string + if err := rows.Scan(&kind); err != nil { + return nil, err + } + items = append(items, kind) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobList = `-- name: JobList :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM /* TEMPLATE: schema */river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT ? 
+` + +func (q *Queries) JobList(ctx context.Context, db DBTX, limit int32) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobList, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobRescue = `-- name: JobRescue :exec +UPDATE /* TEMPLATE: schema */river_job +SET + errors = JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(? AS JSON)), + finalized_at = ?, + scheduled_at = ?, + metadata = JSON_SET( + metadata, + '$."river:rescue_count"', + COALESCE( + CASE JSON_TYPE(JSON_EXTRACT(metadata, '$."river:rescue_count"')) + WHEN 'INTEGER' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + WHEN 'DOUBLE' THEN JSON_EXTRACT(metadata, '$."river:rescue_count"') + ELSE NULL + END, + 0 + ) + 1 + ), + state = ? +WHERE id = ? +` + +type JobRescueParams struct { + Error []byte + FinalizedAt sql.NullTime + ScheduledAt time.Time + State string + ID int64 +} + +func (q *Queries) JobRescue(ctx context.Context, db DBTX, arg *JobRescueParams) error { + _, err := db.ExecContext(ctx, jobRescue, + arg.Error, + arg.FinalizedAt, + arg.ScheduledAt, + arg.State, + arg.ID, + ) + return err +} + +const jobRetryExec = `-- name: JobRetryExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET + state = 'available', + max_attempts = CASE WHEN attempt = max_attempts THEN max_attempts + 1 ELSE max_attempts END, + finalized_at = NULL, + scheduled_at = COALESCE(?, NOW(6)) +WHERE id = ? 
+ AND state != 'running' + AND ( + state <> 'available' + OR scheduled_at > COALESCE(?, NOW(6)) + ) +` + +type JobRetryExecParams struct { + Now sql.NullTime + ID int64 +} + +func (q *Queries) JobRetryExec(ctx context.Context, db DBTX, arg *JobRetryExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobRetryExec, arg.Now, arg.ID, arg.Now) +} + +const jobSchedule = `-- name: JobSchedule :many +WITH eligible AS ( + SELECT river_job.id, river_job.unique_key, river_job.unique_states, river_job.priority, river_job.scheduled_at, + CASE + WHEN river_job.unique_key IS NOT NULL AND river_job.unique_states IS NOT NULL THEN + ROW_NUMBER() OVER (PARTITION BY river_job.unique_key ORDER BY river_job.priority, river_job.scheduled_at, river_job.id) + ELSE NULL + END AS row_num + FROM /* TEMPLATE: schema */river_job + WHERE + river_job.state IN ('retryable', 'scheduled') + AND river_job.scheduled_at <= COALESCE(?, NOW(6)) + ORDER BY + river_job.priority, + river_job.scheduled_at, + river_job.id + LIMIT ? 
+), +unique_conflicts AS ( + SELECT DISTINCT eligible.unique_key + FROM /* TEMPLATE: schema */river_job + JOIN eligible + ON river_job.unique_key = eligible.unique_key + AND river_job.id != eligible.id + WHERE + river_job.unique_key IS NOT NULL + AND river_job.unique_states IS NOT NULL + AND CASE river_job.state + WHEN 'available' THEN river_job.unique_states & (1 << 0) + WHEN 'cancelled' THEN river_job.unique_states & (1 << 1) + WHEN 'completed' THEN river_job.unique_states & (1 << 2) + WHEN 'discarded' THEN river_job.unique_states & (1 << 3) + WHEN 'pending' THEN river_job.unique_states & (1 << 4) + WHEN 'retryable' THEN river_job.unique_states & (1 << 5) + WHEN 'running' THEN river_job.unique_states & (1 << 6) + WHEN 'scheduled' THEN river_job.unique_states & (1 << 7) + ELSE 0 + END >= 1 +) +SELECT eligible.id, + CASE + WHEN eligible.unique_key IS NULL OR eligible.unique_states IS NULL THEN FALSE + WHEN uc.unique_key IS NOT NULL THEN TRUE + WHEN eligible.row_num > 1 THEN TRUE + ELSE FALSE + END AS conflict_discarded +FROM eligible +LEFT JOIN unique_conflicts uc ON eligible.unique_key = uc.unique_key +ORDER BY eligible.priority, eligible.scheduled_at, eligible.id +` + +type JobScheduleParams struct { + Now sql.NullTime + Limit int32 +} + +type JobScheduleRow struct { + ID int64 + ConflictDiscarded int64 +} + +func (q *Queries) JobSchedule(ctx context.Context, db DBTX, arg *JobScheduleParams) ([]*JobScheduleRow, error) { + rows, err := db.QueryContext(ctx, jobSchedule, arg.Now, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*JobScheduleRow + for rows.Next() { + var i JobScheduleRow + if err := rows.Scan(&i.ID, &i.ConflictDiscarded); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobScheduleSetAvailableExec = `-- name: JobScheduleSetAvailableExec :exec 
+UPDATE /* TEMPLATE: schema */river_job +SET state = 'available' +WHERE id IN (/*SLICE:id*/?) +` + +func (q *Queries) JobScheduleSetAvailableExec(ctx context.Context, db DBTX, id []int64) error { + query := jobScheduleSetAvailableExec + var queryParams []interface{} + if len(id) > 0 { + for _, v := range id { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(id))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobScheduleSetDiscardedExec = `-- name: JobScheduleSetDiscardedExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, '{"unique_key_conflict": "scheduler_discarded"}'), + finalized_at = COALESCE(?, NOW(6)), + state = 'discarded' +WHERE id IN (/*SLICE:id*/?) +` + +type JobScheduleSetDiscardedExecParams struct { + Now sql.NullTime + ID []int64 +} + +func (q *Queries) JobScheduleSetDiscardedExec(ctx context.Context, db DBTX, arg *JobScheduleSetDiscardedExecParams) error { + query := jobScheduleSetDiscardedExec + var queryParams []interface{} + queryParams = append(queryParams, arg.Now) + if len(arg.ID) > 0 { + for _, v := range arg.ID { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:id*/?", strings.Repeat(",?", len(arg.ID))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:id*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const jobSetMetadataIfNotRunningExec = `-- name: JobSetMetadataIfNotRunningExec :execresult +UPDATE /* TEMPLATE: schema */river_job +SET metadata = JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) +WHERE id = ? 
+ AND state != 'running' +` + +type JobSetMetadataIfNotRunningExecParams struct { + MetadataUpdates []byte + ID int64 +} + +func (q *Queries) JobSetMetadataIfNotRunningExec(ctx context.Context, db DBTX, arg *JobSetMetadataIfNotRunningExecParams) (sql.Result, error) { + return db.ExecContext(ctx, jobSetMetadataIfNotRunningExec, arg.MetadataUpdates, arg.ID) +} + +const jobSetStateIfRunningExec = `-- name: JobSetStateIfRunningExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN (CAST(? AS CHAR) <> 'retryable' AND ? <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(? AS SIGNED) + THEN ? + ELSE attempt END, + errors = CASE WHEN CAST(? AS SIGNED) + THEN JSON_ARRAY_APPEND(COALESCE(errors, JSON_ARRAY()), '$', CAST(? AS JSON)) + ELSE errors END, + finalized_at = CASE WHEN ((? = 'retryable' OR ? = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN COALESCE(?, NOW(6)) + WHEN CAST(? AS SIGNED) + THEN ? + ELSE finalized_at END, + metadata = CASE WHEN CAST(? AS SIGNED) + THEN JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) + ELSE metadata END, + scheduled_at = CASE WHEN (CAST(? AS CHAR) <> 'retryable' AND ? <> 'scheduled' OR JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NULL) AND CAST(? AS SIGNED) + THEN ? + ELSE scheduled_at END, + state = CASE WHEN ((? = 'retryable' OR ? = 'scheduled') AND JSON_EXTRACT(metadata, '$.cancel_attempted_at') IS NOT NULL) + THEN 'cancelled' + ELSE ? END +WHERE id = ? 
+ AND state = 'running' +` + +type JobSetStateIfRunningExecParams struct { + State string + AttemptDoUpdate int64 + Attempt int64 + ErrorsDoUpdate int64 + Error []byte + Now sql.NullTime + FinalizedAtDoUpdate int64 + FinalizedAt sql.NullTime + MetadataDoMerge int64 + MetadataUpdates []byte + ScheduledAtDoUpdate int64 + ScheduledAt time.Time + ID int64 +} + +func (q *Queries) JobSetStateIfRunningExec(ctx context.Context, db DBTX, arg *JobSetStateIfRunningExecParams) error { + _, err := db.ExecContext(ctx, jobSetStateIfRunningExec, + arg.State, + arg.State, + arg.AttemptDoUpdate, + arg.Attempt, + arg.ErrorsDoUpdate, + arg.Error, + arg.State, + arg.State, + arg.Now, + arg.FinalizedAtDoUpdate, + arg.FinalizedAt, + arg.MetadataDoMerge, + arg.MetadataUpdates, + arg.State, + arg.State, + arg.ScheduledAtDoUpdate, + arg.ScheduledAt, + arg.State, + arg.State, + arg.State, + arg.ID, + ) + return err +} + +const jobUpdateExec = `-- name: JobUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + metadata = CASE WHEN CAST(? AS SIGNED) THEN JSON_MERGE_PATCH(metadata, CAST(? AS JSON)) ELSE metadata END +WHERE id = ? +` + +type JobUpdateExecParams struct { + MetadataDoMerge int64 + Metadata []byte + ID int64 +} + +func (q *Queries) JobUpdateExec(ctx context.Context, db DBTX, arg *JobUpdateExecParams) error { + _, err := db.ExecContext(ctx, jobUpdateExec, arg.MetadataDoMerge, arg.Metadata, arg.ID) + return err +} + +const jobUpdateFullExec = `-- name: JobUpdateFullExec :exec +UPDATE /* TEMPLATE: schema */river_job +SET + attempt = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE attempt END, + attempted_at = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE attempted_at END, + attempted_by = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE attempted_by END, + errors = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE errors END, + finalized_at = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE finalized_at END, + max_attempts = CASE WHEN CAST(? AS SIGNED) THEN ? 
ELSE max_attempts END, + metadata = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE metadata END, + state = CASE WHEN CAST(? AS SIGNED) THEN ? ELSE state END +WHERE id = ? +` + +type JobUpdateFullExecParams struct { + AttemptDoUpdate int64 + Attempt int64 + AttemptedAtDoUpdate int64 + AttemptedAt sql.NullTime + AttemptedByDoUpdate int64 + AttemptedBy []byte + ErrorsDoUpdate int64 + Errors []byte + FinalizedAtDoUpdate int64 + FinalizedAt sql.NullTime + MaxAttemptsDoUpdate int64 + MaxAttempts int64 + MetadataDoUpdate int64 + Metadata []byte + StateDoUpdate int64 + State string + ID int64 +} + +func (q *Queries) JobUpdateFullExec(ctx context.Context, db DBTX, arg *JobUpdateFullExecParams) error { + _, err := db.ExecContext(ctx, jobUpdateFullExec, + arg.AttemptDoUpdate, + arg.Attempt, + arg.AttemptedAtDoUpdate, + arg.AttemptedAt, + arg.AttemptedByDoUpdate, + arg.AttemptedBy, + arg.ErrorsDoUpdate, + arg.Errors, + arg.FinalizedAtDoUpdate, + arg.FinalizedAt, + arg.MaxAttemptsDoUpdate, + arg.MaxAttempts, + arg.MetadataDoUpdate, + arg.Metadata, + arg.StateDoUpdate, + arg.State, + arg.ID, + ) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql new file mode 100644 index 00000000..0e2f77ab --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql @@ -0,0 +1,50 @@ +CREATE TABLE river_leader ( + elected_at DATETIME(6) NOT NULL, + expires_at DATETIME(6) NOT NULL, + leader_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL DEFAULT 'default' PRIMARY KEY +); + +-- name: LeaderAttemptElectExec :execresult +INSERT IGNORE INTO /* TEMPLATE: schema */river_leader ( + leader_id, + elected_at, + expires_at +) VALUES ( + sqlc.arg('leader_id'), + COALESCE(sqlc.narg('now'), NOW(6)), + TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6))) +); + +-- name: LeaderAttemptReelectExec :execresult +UPDATE /* TEMPLATE: schema */river_leader +SET expires_at = 
TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6))) +WHERE + elected_at = sqlc.arg('elected_at') + AND expires_at >= COALESCE(sqlc.narg('now'), NOW(6)) + AND leader_id = sqlc.arg('leader_id'); + +-- name: LeaderDeleteExpired :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE expires_at < COALESCE(sqlc.narg('now'), NOW(6)); + +-- name: LeaderGetElectedLeader :one +SELECT elected_at, expires_at, leader_id, name +FROM /* TEMPLATE: schema */river_leader; + +-- name: LeaderInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_leader ( + elected_at, + expires_at, + leader_id +) VALUES ( + COALESCE(sqlc.narg('elected_at'), sqlc.narg('now'), NOW(6)), + COALESCE(sqlc.narg('expires_at'), TIMESTAMPADD(MICROSECOND, sqlc.arg('ttl'), COALESCE(sqlc.narg('now'), NOW(6)))), + sqlc.arg('leader_id') +); + +-- name: LeaderResign :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE + elected_at = sqlc.arg('elected_at') + AND leader_id = sqlc.arg('leader_id'); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go new file mode 100644 index 00000000..cb31a9e1 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_leader.sql.go @@ -0,0 +1,147 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.31.0 +// source: river_leader.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "time" +) + +const leaderAttemptElectExec = `-- name: LeaderAttemptElectExec :execresult +INSERT IGNORE INTO /* TEMPLATE: schema */river_leader ( + leader_id, + elected_at, + expires_at +) VALUES ( + ?, + COALESCE(?, NOW(6)), + TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6))) +) +` + +type LeaderAttemptElectExecParams struct { + LeaderID string + Now interface{} + TTL int64 +} + +func (q *Queries) LeaderAttemptElectExec(ctx context.Context, db DBTX, arg *LeaderAttemptElectExecParams) (sql.Result, error) { + return db.ExecContext(ctx, leaderAttemptElectExec, + arg.LeaderID, + arg.Now, + arg.TTL, + arg.Now, + ) +} + +const leaderAttemptReelectExec = `-- name: LeaderAttemptReelectExec :execresult +UPDATE /* TEMPLATE: schema */river_leader +SET expires_at = TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6))) +WHERE + elected_at = ? + AND expires_at >= COALESCE(?, NOW(6)) + AND leader_id = ? 
+` + +type LeaderAttemptReelectExecParams struct { + TTL int64 + Now sql.NullTime + ElectedAt time.Time + LeaderID string +} + +func (q *Queries) LeaderAttemptReelectExec(ctx context.Context, db DBTX, arg *LeaderAttemptReelectExecParams) (sql.Result, error) { + return db.ExecContext(ctx, leaderAttemptReelectExec, + arg.TTL, + arg.Now, + arg.ElectedAt, + arg.Now, + arg.LeaderID, + ) +} + +const leaderDeleteExpired = `-- name: LeaderDeleteExpired :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE expires_at < COALESCE(?, NOW(6)) +` + +func (q *Queries) LeaderDeleteExpired(ctx context.Context, db DBTX, now sql.NullTime) (int64, error) { + result, err := db.ExecContext(ctx, leaderDeleteExpired, now) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +const leaderGetElectedLeader = `-- name: LeaderGetElectedLeader :one +SELECT elected_at, expires_at, leader_id, name +FROM /* TEMPLATE: schema */river_leader +` + +func (q *Queries) LeaderGetElectedLeader(ctx context.Context, db DBTX) (*RiverLeader, error) { + row := db.QueryRowContext(ctx, leaderGetElectedLeader) + var i RiverLeader + err := row.Scan( + &i.ElectedAt, + &i.ExpiresAt, + &i.LeaderID, + &i.Name, + ) + return &i, err +} + +const leaderInsertExec = `-- name: LeaderInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_leader ( + elected_at, + expires_at, + leader_id +) VALUES ( + COALESCE(?, ?, NOW(6)), + COALESCE(?, TIMESTAMPADD(MICROSECOND, ?, COALESCE(?, NOW(6)))), + ? 
+) +` + +type LeaderInsertExecParams struct { + ElectedAt interface{} + Now interface{} + ExpiresAt interface{} + TTL int64 + LeaderID string +} + +func (q *Queries) LeaderInsertExec(ctx context.Context, db DBTX, arg *LeaderInsertExecParams) error { + _, err := db.ExecContext(ctx, leaderInsertExec, + arg.ElectedAt, + arg.Now, + arg.ExpiresAt, + arg.TTL, + arg.Now, + arg.LeaderID, + ) + return err +} + +const leaderResign = `-- name: LeaderResign :execrows +DELETE FROM /* TEMPLATE: schema */river_leader +WHERE + elected_at = ? + AND leader_id = ? +` + +type LeaderResignParams struct { + ElectedAt time.Time + LeaderID string +} + +func (q *Queries) LeaderResign(ctx context.Context, db DBTX, arg *LeaderResignParams) (int64, error) { + result, err := db.ExecContext(ctx, leaderResign, arg.ElectedAt, arg.LeaderID) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql new file mode 100644 index 00000000..8a1aeccb --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql @@ -0,0 +1,77 @@ +CREATE TABLE river_migration ( + line VARCHAR(128) NOT NULL, + version BIGINT NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + PRIMARY KEY (line, version) +); + +-- name: RiverMigrationDeleteAssumingMainMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteAssumingMainManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') + AND version IN (sqlc.slice('version')); + +-- name: RiverMigrationDeleteByLineAndVersionManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE 
line = sqlc.arg('line') + AND version IN (sqlc.slice('version')); + +-- name: RiverMigrationGetAllAssumingMain :many +SELECT + created_at, + version +FROM /* TEMPLATE: schema */river_migration +ORDER BY version; + +-- name: RiverMigrationGetByLine :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') +ORDER BY version; + +-- name: RiverMigrationInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + line, + version +) VALUES ( + sqlc.arg('line'), + sqlc.arg('version') +); + +-- name: RiverMigrationGetByLineAndVersion :one +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') AND version = sqlc.arg('version'); + +-- name: RiverMigrationGetByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = sqlc.arg('line') AND version IN (sqlc.slice('version')) +ORDER BY version; + +-- name: RiverMigrationInsertAssumingMainExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + version +) VALUES ( + sqlc.arg('version') +); + +-- name: RiverMigrationGetByVersion :one +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version = sqlc.arg('version'); + +-- name: RiverMigrationGetByVersionMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (sqlc.slice('version')) +ORDER BY version; diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go new file mode 100644 index 00000000..2e3cbb89 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_migration.sql.go @@ -0,0 +1,375 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.31.0 +// source: river_migration.sql + +package dbsqlc + +import ( + "context" + "strings" + "time" +) + +const riverMigrationDeleteAssumingMainMany = `-- name: RiverMigrationDeleteAssumingMainMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteAssumingMainManyRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationDeleteAssumingMainMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigrationDeleteAssumingMainManyRow, error) { + query := riverMigrationDeleteAssumingMainMany + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationDeleteAssumingMainManyRow + for rows.Next() { + var i RiverMigrationDeleteAssumingMainManyRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationDeleteAssumingMainManyExec = `-- name: RiverMigrationDeleteAssumingMainManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) 
+` + +func (q *Queries) RiverMigrationDeleteAssumingMainManyExec(ctx context.Context, db DBTX, version []int64) error { + query := riverMigrationDeleteAssumingMainManyExec + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const riverMigrationDeleteByLineAndVersionMany = `-- name: RiverMigrationDeleteByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? + AND version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteByLineAndVersionManyParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationDeleteByLineAndVersionMany(ctx context.Context, db DBTX, arg *RiverMigrationDeleteByLineAndVersionManyParams) ([]*RiverMigration, error) { + query := riverMigrationDeleteByLineAndVersionMany + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationDeleteByLineAndVersionManyExec = `-- name: RiverMigrationDeleteByLineAndVersionManyExec :exec +DELETE FROM /* TEMPLATE: schema */river_migration +WHERE line = ? + AND version IN (/*SLICE:version*/?) +` + +type RiverMigrationDeleteByLineAndVersionManyExecParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationDeleteByLineAndVersionManyExec(ctx context.Context, db DBTX, arg *RiverMigrationDeleteByLineAndVersionManyExecParams) error { + query := riverMigrationDeleteByLineAndVersionManyExec + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) 
+ return err +} + +const riverMigrationGetAllAssumingMain = `-- name: RiverMigrationGetAllAssumingMain :many +SELECT + created_at, + version +FROM /* TEMPLATE: schema */river_migration +ORDER BY version +` + +type RiverMigrationGetAllAssumingMainRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetAllAssumingMain(ctx context.Context, db DBTX) ([]*RiverMigrationGetAllAssumingMainRow, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetAllAssumingMain) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationGetAllAssumingMainRow + for rows.Next() { + var i RiverMigrationGetAllAssumingMainRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByLine = `-- name: RiverMigrationGetByLine :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? +ORDER BY version +` + +func (q *Queries) RiverMigrationGetByLine(ctx context.Context, db DBTX, line string) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetByLine, line) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByLineAndVersion = `-- name: RiverMigrationGetByLineAndVersion :one +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? AND version = ? 
+` + +type RiverMigrationGetByLineAndVersionParams struct { + Line string + Version int64 +} + +func (q *Queries) RiverMigrationGetByLineAndVersion(ctx context.Context, db DBTX, arg *RiverMigrationGetByLineAndVersionParams) (*RiverMigration, error) { + row := db.QueryRowContext(ctx, riverMigrationGetByLineAndVersion, arg.Line, arg.Version) + var i RiverMigration + err := row.Scan(&i.Line, &i.Version, &i.CreatedAt) + return &i, err +} + +const riverMigrationGetByLineAndVersionMany = `-- name: RiverMigrationGetByLineAndVersionMany :many +SELECT line, version, created_at +FROM /* TEMPLATE: schema */river_migration +WHERE line = ? AND version IN (/*SLICE:version*/?) +ORDER BY version +` + +type RiverMigrationGetByLineAndVersionManyParams struct { + Line string + Version []int64 +} + +func (q *Queries) RiverMigrationGetByLineAndVersionMany(ctx context.Context, db DBTX, arg *RiverMigrationGetByLineAndVersionManyParams) ([]*RiverMigration, error) { + query := riverMigrationGetByLineAndVersionMany + var queryParams []interface{} + queryParams = append(queryParams, arg.Line) + if len(arg.Version) > 0 { + for _, v := range arg.Version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(arg.Version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.Line, &i.Version, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetByVersion = `-- name: RiverMigrationGetByVersion :one +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version = ? +` + +type RiverMigrationGetByVersionRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetByVersion(ctx context.Context, db DBTX, version int64) (*RiverMigrationGetByVersionRow, error) { + row := db.QueryRowContext(ctx, riverMigrationGetByVersion, version) + var i RiverMigrationGetByVersionRow + err := row.Scan(&i.CreatedAt, &i.Version) + return &i, err +} + +const riverMigrationGetByVersionMany = `-- name: RiverMigrationGetByVersionMany :many +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration +WHERE version IN (/*SLICE:version*/?) +ORDER BY version +` + +type RiverMigrationGetByVersionManyRow struct { + CreatedAt time.Time + Version int64 +} + +func (q *Queries) RiverMigrationGetByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigrationGetByVersionManyRow, error) { + query := riverMigrationGetByVersionMany + var queryParams []interface{} + if len(version) > 0 { + for _, v := range version { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:version*/?", strings.Repeat(",?", len(version))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:version*/?", "NULL", 1) + } + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigrationGetByVersionManyRow + for rows.Next() { + var i RiverMigrationGetByVersionManyRow + if err := rows.Scan(&i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationInsertAssumingMainExec = `-- name: RiverMigrationInsertAssumingMainExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + version +) VALUES ( + ? +) +` + +func (q *Queries) RiverMigrationInsertAssumingMainExec(ctx context.Context, db DBTX, version int64) error { + _, err := db.ExecContext(ctx, riverMigrationInsertAssumingMainExec, version) + return err +} + +const riverMigrationInsertExec = `-- name: RiverMigrationInsertExec :exec +INSERT INTO /* TEMPLATE: schema */river_migration ( + line, + version +) VALUES ( + ?, + ? +) +` + +type RiverMigrationInsertExecParams struct { + Line string + Version int64 +} + +func (q *Queries) RiverMigrationInsertExec(ctx context.Context, db DBTX, arg *RiverMigrationInsertExecParams) error { + _, err := db.ExecContext(ctx, riverMigrationInsertExec, arg.Line, arg.Version) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql new file mode 100644 index 00000000..058bb7f2 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql @@ -0,0 +1,91 @@ +CREATE TABLE river_queue ( + name VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL +); + +-- name: QueueCreateOrSetUpdatedAtExec :exec +INSERT INTO /* TEMPLATE: schema */river_queue ( + created_at, + metadata, + name, + paused_at, + updated_at +) VALUES ( + COALESCE(sqlc.narg('now'), 
NOW(6)), + CAST(sqlc.arg('metadata') AS JSON), + sqlc.arg('name'), + sqlc.narg('paused_at'), + COALESCE(sqlc.narg('updated_at'), sqlc.narg('now'), NOW(6)) +) ON DUPLICATE KEY UPDATE + updated_at = VALUES(updated_at); + +-- name: QueueGet :one +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +WHERE name = sqlc.arg('name'); + +-- name: QueueDeleteExpiredSelect :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE updated_at < sqlc.arg('updated_at_horizon') +ORDER BY name ASC +LIMIT ?; + +-- name: QueueDeleteExpiredExec :exec +DELETE FROM /* TEMPLATE: schema */river_queue +WHERE name IN (sqlc.slice('names')); + +-- name: QueueList :many +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +ORDER BY name ASC +LIMIT ?; + +-- name: QueueNameList :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE + name COLLATE utf8mb4_general_ci > sqlc.arg('after') + AND (sqlc.arg('match') = '' OR LOWER(name) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(sqlc.arg('match')), '%')) + AND name NOT IN (sqlc.slice('exclude')) +ORDER BY name ASC +LIMIT ?; + +-- MySQL evaluates SET clauses left-to-right using already-updated values, so +-- updated_at must be set BEFORE paused_at to see the original paused_at value. 
+ +-- name: QueuePauseAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE paused_at END; + +-- name: QueuePauseByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE paused_at END +WHERE name = sqlc.arg('name'); + +-- name: QueueResumeAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = NULL; + +-- name: QueueResumeByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(sqlc.narg('now'), NOW(6)) ELSE updated_at END, + paused_at = NULL +WHERE name = sqlc.arg('name'); + +-- name: QueueUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_queue +SET + metadata = CASE WHEN CAST(sqlc.arg('metadata_do_update') AS SIGNED) THEN CAST(sqlc.arg('metadata') AS JSON) ELSE metadata END, + updated_at = NOW(6) +WHERE name = sqlc.arg('name'); diff --git a/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go new file mode 100644 index 00000000..38a4cef1 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/river_queue.sql.go @@ -0,0 +1,298 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.31.0 +// source: river_queue.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const queueCreateOrSetUpdatedAtExec = `-- name: QueueCreateOrSetUpdatedAtExec :exec +INSERT INTO /* TEMPLATE: schema */river_queue ( + created_at, + metadata, + name, + paused_at, + updated_at +) VALUES ( + COALESCE(?, NOW(6)), + CAST(? AS JSON), + ?, + ?, + COALESCE(?, ?, NOW(6)) +) ON DUPLICATE KEY UPDATE + updated_at = VALUES(updated_at) +` + +type QueueCreateOrSetUpdatedAtExecParams struct { + Now interface{} + Metadata []byte + Name string + PausedAt sql.NullTime + UpdatedAt interface{} +} + +func (q *Queries) QueueCreateOrSetUpdatedAtExec(ctx context.Context, db DBTX, arg *QueueCreateOrSetUpdatedAtExecParams) error { + _, err := db.ExecContext(ctx, queueCreateOrSetUpdatedAtExec, + arg.Now, + arg.Metadata, + arg.Name, + arg.PausedAt, + arg.UpdatedAt, + arg.Now, + ) + return err +} + +const queueDeleteExpiredExec = `-- name: QueueDeleteExpiredExec :exec +DELETE FROM /* TEMPLATE: schema */river_queue +WHERE name IN (/*SLICE:names*/?) +` + +func (q *Queries) QueueDeleteExpiredExec(ctx context.Context, db DBTX, names []string) error { + query := queueDeleteExpiredExec + var queryParams []interface{} + if len(names) > 0 { + for _, v := range names { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:names*/?", strings.Repeat(",?", len(names))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:names*/?", "NULL", 1) + } + _, err := db.ExecContext(ctx, query, queryParams...) + return err +} + +const queueDeleteExpiredSelect = `-- name: QueueDeleteExpiredSelect :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE updated_at < ? +ORDER BY name ASC +LIMIT ? 
+` + +type QueueDeleteExpiredSelectParams struct { + UpdatedAtHorizon time.Time + Limit int32 +} + +func (q *Queries) QueueDeleteExpiredSelect(ctx context.Context, db DBTX, arg *QueueDeleteExpiredSelectParams) ([]string, error) { + rows, err := db.QueryContext(ctx, queueDeleteExpiredSelect, arg.UpdatedAtHorizon, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queueGet = `-- name: QueueGet :one +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +WHERE name = ? +` + +func (q *Queries) QueueGet(ctx context.Context, db DBTX, name string) (*RiverQueue, error) { + row := db.QueryRowContext(ctx, queueGet, name) + var i RiverQueue + err := row.Scan( + &i.Name, + &i.CreatedAt, + &i.Metadata, + &i.PausedAt, + &i.UpdatedAt, + ) + return &i, err +} + +const queueList = `-- name: QueueList :many +SELECT name, created_at, metadata, paused_at, updated_at +FROM /* TEMPLATE: schema */river_queue +ORDER BY name ASC +LIMIT ? 
+` + +func (q *Queries) QueueList(ctx context.Context, db DBTX, limit int32) ([]*RiverQueue, error) { + rows, err := db.QueryContext(ctx, queueList, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverQueue + for rows.Next() { + var i RiverQueue + if err := rows.Scan( + &i.Name, + &i.CreatedAt, + &i.Metadata, + &i.PausedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queueNameList = `-- name: QueueNameList :many +SELECT name +FROM /* TEMPLATE: schema */river_queue +WHERE + name COLLATE utf8mb4_general_ci > ? + AND (? = '' OR LOWER(name) COLLATE utf8mb4_general_ci LIKE CONCAT('%', LOWER(?), '%')) + AND name NOT IN (/*SLICE:exclude*/?) +ORDER BY name ASC +LIMIT ? +` + +type QueueNameListParams struct { + After interface{} + Match string + Exclude []string + Limit int32 +} + +func (q *Queries) QueueNameList(ctx context.Context, db DBTX, arg *QueueNameListParams) ([]string, error) { + query := queueNameList + var queryParams []interface{} + queryParams = append(queryParams, arg.After) + queryParams = append(queryParams, arg.Match) + queryParams = append(queryParams, arg.Match) + if len(arg.Exclude) > 0 { + for _, v := range arg.Exclude { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:exclude*/?", strings.Repeat(",?", len(arg.Exclude))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:exclude*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Limit) + rows, err := db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queuePauseAll = `-- name: QueuePauseAll :execresult + +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE paused_at END +` + +type QueuePauseAllParams struct { + Now sql.NullTime +} + +// MySQL evaluates SET clauses left-to-right using already-updated values, so +// updated_at must be set BEFORE paused_at to see the original paused_at value. +func (q *Queries) QueuePauseAll(ctx context.Context, db DBTX, arg *QueuePauseAllParams) (sql.Result, error) { + return db.ExecContext(ctx, queuePauseAll, arg.Now, arg.Now) +} + +const queuePauseByName = `-- name: QueuePauseByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = CASE WHEN paused_at IS NULL THEN COALESCE(?, NOW(6)) ELSE paused_at END +WHERE name = ? 
+` + +type QueuePauseByNameParams struct { + Now sql.NullTime + Name string +} + +func (q *Queries) QueuePauseByName(ctx context.Context, db DBTX, arg *QueuePauseByNameParams) (sql.Result, error) { + return db.ExecContext(ctx, queuePauseByName, arg.Now, arg.Now, arg.Name) +} + +const queueResumeAll = `-- name: QueueResumeAll :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = NULL +` + +func (q *Queries) QueueResumeAll(ctx context.Context, db DBTX, now sql.NullTime) (sql.Result, error) { + return db.ExecContext(ctx, queueResumeAll, now) +} + +const queueResumeByName = `-- name: QueueResumeByName :execresult +UPDATE /* TEMPLATE: schema */river_queue +SET + updated_at = CASE WHEN paused_at IS NOT NULL THEN COALESCE(?, NOW(6)) ELSE updated_at END, + paused_at = NULL +WHERE name = ? +` + +type QueueResumeByNameParams struct { + Now sql.NullTime + Name string +} + +func (q *Queries) QueueResumeByName(ctx context.Context, db DBTX, arg *QueueResumeByNameParams) (sql.Result, error) { + return db.ExecContext(ctx, queueResumeByName, arg.Now, arg.Name) +} + +const queueUpdateExec = `-- name: QueueUpdateExec :exec +UPDATE /* TEMPLATE: schema */river_queue +SET + metadata = CASE WHEN CAST(? AS SIGNED) THEN CAST(? AS JSON) ELSE metadata END, + updated_at = NOW(6) +WHERE name = ? 
+` + +type QueueUpdateExecParams struct { + MetadataDoUpdate int64 + Metadata []byte + Name string +} + +func (q *Queries) QueueUpdateExec(ctx context.Context, db DBTX, arg *QueueUpdateExecParams) error { + _, err := db.ExecContext(ctx, queueUpdateExec, arg.MetadataDoUpdate, arg.Metadata, arg.Name) + return err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/schema.sql b/riverdriver/rivermysql/internal/dbsqlc/schema.sql new file mode 100644 index 00000000..8a3e5372 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/schema.sql @@ -0,0 +1,43 @@ +-- Dummy table definitions for INFORMATION_SCHEMA system tables so sqlc can +-- resolve column types. At runtime the template prefix replaces the empty +-- default with "INFORMATION_SCHEMA.", making queries target the real system +-- tables. +CREATE TABLE STATISTICS ( + INDEX_NAME VARCHAR(128) NOT NULL, + TABLE_NAME VARCHAR(128) NOT NULL, + TABLE_SCHEMA VARCHAR(128) NOT NULL +); + +CREATE TABLE TABLES ( + TABLE_NAME VARCHAR(128) NOT NULL, + TABLE_SCHEMA VARCHAR(128) NOT NULL +); + +-- name: IndexExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */STATISTICS + WHERE INDEX_NAME = sqlc.arg('index_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +); + +-- name: IndexGetTableName :one +SELECT TABLE_NAME +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME = sqlc.arg('index_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +LIMIT 1; + +-- name: IndexesExist :many +SELECT DISTINCT INDEX_NAME AS index_name +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME IN (sqlc.slice('index_names')) + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()); + +-- name: TableExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */TABLES + WHERE TABLE_NAME = sqlc.arg('table_name') + AND TABLE_SCHEMA = COALESCE(sqlc.narg('schema'), DATABASE()) +); diff --git 
a/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go b/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go new file mode 100644 index 00000000..d84f137a --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/schema.sql.go @@ -0,0 +1,120 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: schema.sql + +package dbsqlc + +import ( + "context" + "database/sql" + "strings" +) + +const indexExists = `-- name: IndexExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */STATISTICS + WHERE INDEX_NAME = ? + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +) +` + +type IndexExistsParams struct { + IndexName string + Schema sql.NullString +} + +func (q *Queries) IndexExists(ctx context.Context, db DBTX, arg *IndexExistsParams) (bool, error) { + row := db.QueryRowContext(ctx, indexExists, arg.IndexName, arg.Schema) + var exists bool + err := row.Scan(&exists) + return exists, err +} + +const indexGetTableName = `-- name: IndexGetTableName :one +SELECT TABLE_NAME +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME = ? + AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +LIMIT 1 +` + +type IndexGetTableNameParams struct { + IndexName string + Schema sql.NullString +} + +func (q *Queries) IndexGetTableName(ctx context.Context, db DBTX, arg *IndexGetTableNameParams) (string, error) { + row := db.QueryRowContext(ctx, indexGetTableName, arg.IndexName, arg.Schema) + var table_name string + err := row.Scan(&table_name) + return table_name, err +} + +const indexesExist = `-- name: IndexesExist :many +SELECT DISTINCT INDEX_NAME AS index_name +FROM /* TEMPLATE: information_schema */STATISTICS +WHERE INDEX_NAME IN (/*SLICE:index_names*/?) 
+ AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +` + +type IndexesExistParams struct { + IndexNames []string + Schema sql.NullString +} + +func (q *Queries) IndexesExist(ctx context.Context, db DBTX, arg *IndexesExistParams) ([]string, error) { + query := indexesExist + var queryParams []interface{} + if len(arg.IndexNames) > 0 { + for _, v := range arg.IndexNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:index_names*/?", strings.Repeat(",?", len(arg.IndexNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:index_names*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.Schema) + rows, err := db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var index_name string + if err := rows.Scan(&index_name); err != nil { + return nil, err + } + items = append(items, index_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const tableExists = `-- name: TableExists :one +SELECT EXISTS ( + SELECT 1 + FROM /* TEMPLATE: information_schema */TABLES + WHERE TABLE_NAME = ? 
+ AND TABLE_SCHEMA = COALESCE(?, DATABASE()) +) +` + +type TableExistsParams struct { + TableName string + Schema sql.NullString +} + +func (q *Queries) TableExists(ctx context.Context, db DBTX, arg *TableExistsParams) (bool, error) { + row := db.QueryRowContext(ctx, tableExists, arg.TableName, arg.Schema) + var exists bool + err := row.Scan(&exists) + return exists, err +} diff --git a/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml b/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..3af577c0 --- /dev/null +++ b/riverdriver/rivermysql/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,41 @@ +version: "2" +sql: + - engine: "mysql" + queries: + - river_job.sql + - river_leader.sql + - river_migration.sql + - river_queue.sql + - schema.sql + schema: + - river_client.sql + - river_client_queue.sql + - river_job.sql + - river_leader.sql + - river_migration.sql + - river_queue.sql + - schema.sql + gen: + go: + package: "dbsqlc" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_params_struct_pointers: true + emit_pointers_for_null_types: true + emit_result_struct_pointers: true + + rename: + ids: "IDs" + ttl: "TTL" + + overrides: + - db_type: "json" + go_type: + type: "[]byte" + - db_type: "json" + go_type: + type: "[]byte" + nullable: true + - db_type: "int" + go_type: "int64" diff --git a/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql b/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql new file mode 100644 index 00000000..d2f6b3fa --- /dev/null +++ b/riverdriver/rivermysql/migration/main/001_create_river_migration.down.sql @@ -0,0 +1 @@ +DROP TABLE /* TEMPLATE: schema */river_migration; diff --git a/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql b/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql new file mode 100644 index 00000000..d2ee3f73 --- /dev/null +++ 
b/riverdriver/rivermysql/migration/main/001_create_river_migration.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE /* TEMPLATE: schema */river_migration ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + version BIGINT NOT NULL, + CONSTRAINT version CHECK (version >= 1) +) ENGINE=InnoDB; + +CREATE UNIQUE INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration (version); diff --git a/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql b/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql new file mode 100644 index 00000000..095aa1dd --- /dev/null +++ b/riverdriver/rivermysql/migration/main/002_initial_schema.down.sql @@ -0,0 +1,2 @@ +DROP TABLE /* TEMPLATE: schema */river_job; +DROP TABLE /* TEMPLATE: schema */river_leader; diff --git a/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql b/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql new file mode 100644 index 00000000..593cfa93 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/002_initial_schema.up.sql @@ -0,0 +1,43 @@ +CREATE TABLE /* TEMPLATE: schema */river_job( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + state VARCHAR(20) NOT NULL DEFAULT 'available', + attempt INT NOT NULL DEFAULT 0, + max_attempts INT NOT NULL, + attempted_at DATETIME(6) NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + finalized_at DATETIME(6) NULL, + scheduled_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + priority SMALLINT NOT NULL DEFAULT 1, + args JSON NULL, + attempted_by JSON NULL, -- JSON array of strings (no native text[] in MySQL) + errors JSON NULL, -- JSON array of error objects (no native jsonb[] in MySQL) + kind VARCHAR(128) NOT NULL, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + queue VARCHAR(128) NOT NULL DEFAULT 'default', + tags JSON NULL, -- JSON array of strings (no native varchar[] in MySQL) + CONSTRAINT finalized_or_finalized_at_null CHECK ( + (state IN ('cancelled', 'completed', 'discarded') 
AND finalized_at IS NOT NULL) OR + finalized_at IS NULL + ), + CONSTRAINT max_attempts_is_positive CHECK (max_attempts > 0), + CONSTRAINT priority_in_range CHECK (priority >= 1 AND priority <= 4), + CONSTRAINT queue_length CHECK (CHAR_LENGTH(queue) > 0 AND CHAR_LENGTH(queue) < 128), + CONSTRAINT kind_length CHECK (CHAR_LENGTH(kind) > 0 AND CHAR_LENGTH(kind) < 128), + CONSTRAINT state_valid CHECK (state IN ('available', 'cancelled', 'completed', 'discarded', 'retryable', 'running', 'scheduled')) +) ENGINE=InnoDB; + +CREATE INDEX river_job_kind ON /* TEMPLATE: schema */river_job (kind); +CREATE INDEX river_job_state_and_finalized_at_index ON /* TEMPLATE: schema */river_job (state, finalized_at); +CREATE INDEX river_job_prioritized_fetching_index ON /* TEMPLATE: schema */river_job (state, queue, priority, scheduled_at, id); + +-- MySQL does not support triggers for LISTEN/NOTIFY, so river_job_notify is +-- omitted. The MySQL driver operates in poll-only mode. + +CREATE TABLE /* TEMPLATE: schema */river_leader ( + elected_at DATETIME(6) NOT NULL, + expires_at DATETIME(6) NOT NULL, + leader_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL PRIMARY KEY, + CONSTRAINT name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128), + CONSTRAINT leader_id_length CHECK (CHAR_LENGTH(leader_id) > 0 AND CHAR_LENGTH(leader_id) < 128) +) ENGINE=InnoDB; diff --git a/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql new file mode 100644 index 00000000..93c33638 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.down.sql @@ -0,0 +1 @@ +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN tags JSON NULL; diff --git a/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql new file mode 100644 index 00000000..61b2f430 --- /dev/null +++ 
b/riverdriver/rivermysql/migration/main/003_river_job_tags_non_null.up.sql @@ -0,0 +1,2 @@ +UPDATE /* TEMPLATE: schema */river_job SET tags = JSON_ARRAY() WHERE tags IS NULL; +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN tags JSON NOT NULL DEFAULT (JSON_ARRAY()); diff --git a/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql b/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql new file mode 100644 index 00000000..891a0e81 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/004_pending_and_more.down.sql @@ -0,0 +1,21 @@ +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN args JSON NULL; + +ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN metadata JSON NOT NULL DEFAULT (JSON_OBJECT()); + +-- Cannot safely remove 'pending' from the CHECK constraint if rows reference +-- it, but we restore the original constraint form. +ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT finalized_or_finalized_at_null; +ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT finalized_or_finalized_at_null CHECK ( + (state IN ('cancelled', 'completed', 'discarded') AND finalized_at IS NOT NULL) OR + finalized_at IS NULL +); + +-- MySQL does not support triggers for LISTEN/NOTIFY, so no trigger changes. + +DROP TABLE /* TEMPLATE: schema */river_queue; + +ALTER TABLE /* TEMPLATE: schema */river_leader + ALTER COLUMN name DROP DEFAULT; + +ALTER TABLE /* TEMPLATE: schema */river_leader DROP CONSTRAINT name_length; +ALTER TABLE /* TEMPLATE: schema */river_leader ADD CONSTRAINT name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128); diff --git a/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql b/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql new file mode 100644 index 00000000..791ebcda --- /dev/null +++ b/riverdriver/rivermysql/migration/main/004_pending_and_more.up.sql @@ -0,0 +1,48 @@ +-- Make args NOT NULL with a default. 
+UPDATE /* TEMPLATE: schema */river_job SET args = JSON_OBJECT() WHERE args IS NULL;
+ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN args JSON NOT NULL DEFAULT (JSON_OBJECT());
+
+-- Re-assert metadata NOT NULL with a default (already true since migration 002; kept for parity with Postgres migration 004).
+UPDATE /* TEMPLATE: schema */river_job SET metadata = JSON_OBJECT() WHERE metadata IS NULL;
+ALTER TABLE /* TEMPLATE: schema */river_job MODIFY COLUMN metadata JSON NOT NULL DEFAULT (JSON_OBJECT());
+
+-- Add 'pending' to the set of valid states. MySQL doesn't have enum types to
+-- alter, so we update the CHECK constraint instead.
+ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT state_valid;
+ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT state_valid CHECK (
+    state IN ('available', 'cancelled', 'completed', 'discarded', 'pending', 'retryable', 'running', 'scheduled')
+);
+
+-- Update the finalized_at constraint to use the inverted form (matching
+-- Postgres migration 004).
+ALTER TABLE /* TEMPLATE: schema */river_job DROP CONSTRAINT finalized_or_finalized_at_null;
+ALTER TABLE /* TEMPLATE: schema */river_job ADD CONSTRAINT finalized_or_finalized_at_null CHECK (
+    (finalized_at IS NULL AND state NOT IN ('cancelled', 'completed', 'discarded')) OR
+    (finalized_at IS NOT NULL AND state IN ('cancelled', 'completed', 'discarded'))
+);
+
+-- MySQL has no LISTEN/NOTIFY mechanism, so the river_job_notify trigger
+-- changes from Postgres are omitted.
+
+--
+-- Create table `river_queue`.
+--
+
+CREATE TABLE /* TEMPLATE: schema */river_queue (
+    name VARCHAR(128) NOT NULL PRIMARY KEY,
+    created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)),
+    metadata JSON NOT NULL DEFAULT (JSON_OBJECT()),
+    paused_at DATETIME(6) NULL,
+    updated_at DATETIME(6) NOT NULL
+) ENGINE=InnoDB;
+
+--
+-- Alter `river_leader` to add a default value of 'default' to `name`.
+-- MySQL supports ALTER TABLE for changing defaults and constraints.
+-- + +ALTER TABLE /* TEMPLATE: schema */river_leader + ALTER COLUMN name SET DEFAULT 'default'; + +ALTER TABLE /* TEMPLATE: schema */river_leader DROP CONSTRAINT name_length; +ALTER TABLE /* TEMPLATE: schema */river_leader ADD CONSTRAINT name_length CHECK (name = 'default'); diff --git a/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql b/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql new file mode 100644 index 00000000..1d77b022 --- /dev/null +++ b/riverdriver/rivermysql/migration/main/005_migration_unique_client.down.sql @@ -0,0 +1,35 @@ +-- +-- Revert to migration table based only on `(version)`. +-- + +CREATE TABLE /* TEMPLATE: schema */river_migration_old ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + version BIGINT NOT NULL, + CONSTRAINT version CHECK (version >= 1) +) ENGINE=InnoDB; + +CREATE UNIQUE INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration_old (version); + +INSERT INTO /* TEMPLATE: schema */river_migration_old + (created_at, version) +SELECT created_at, version +FROM /* TEMPLATE: schema */river_migration; + +DROP TABLE /* TEMPLATE: schema */river_migration; + +ALTER TABLE /* TEMPLATE: schema */river_migration_old RENAME TO /* TEMPLATE: schema */river_migration; + +-- +-- Drop `river_job.unique_key` and its index. +-- + +DROP INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job; +ALTER TABLE /* TEMPLATE: schema */river_job DROP COLUMN unique_key; + +-- +-- Drop `river_client` and derivative. 
+-- + +DROP TABLE /* TEMPLATE: schema */river_client_queue; +DROP TABLE /* TEMPLATE: schema */river_client; diff --git a/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql b/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql new file mode 100644 index 00000000..03e54abd --- /dev/null +++ b/riverdriver/rivermysql/migration/main/005_migration_unique_client.up.sql @@ -0,0 +1,60 @@ +-- +-- Rebuild the migration table so it's based on `(line, version)`. +-- + +DROP INDEX river_migration_version_idx ON /* TEMPLATE: schema */river_migration; + +CREATE TABLE /* TEMPLATE: schema */river_migration_new ( + line VARCHAR(128) NOT NULL, + version BIGINT NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + CONSTRAINT line_length CHECK (CHAR_LENGTH(line) > 0 AND CHAR_LENGTH(line) < 128), + CONSTRAINT version_gte_1 CHECK (version >= 1), + PRIMARY KEY (line, version) +) ENGINE=InnoDB; + +INSERT INTO /* TEMPLATE: schema */river_migration_new + (created_at, line, version) +SELECT created_at, 'main', version +FROM /* TEMPLATE: schema */river_migration; + +DROP TABLE /* TEMPLATE: schema */river_migration; + +ALTER TABLE /* TEMPLATE: schema */river_migration_new RENAME TO /* TEMPLATE: schema */river_migration; + +-- +-- Add `river_job.unique_key` and bring up an index on it. +-- + +ALTER TABLE /* TEMPLATE: schema */river_job ADD COLUMN unique_key VARBINARY(255) NULL; + +CREATE UNIQUE INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job (kind, unique_key); + +-- +-- Create `river_client` and derivative. 
+-- + +CREATE TABLE /* TEMPLATE: schema */river_client ( + id VARCHAR(128) NOT NULL PRIMARY KEY, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + paused_at DATETIME(6) NULL, + updated_at DATETIME(6) NOT NULL, + CONSTRAINT client_name_length CHECK (CHAR_LENGTH(id) > 0 AND CHAR_LENGTH(id) < 128) +) ENGINE=InnoDB; + +CREATE TABLE /* TEMPLATE: schema */river_client_queue ( + river_client_id VARCHAR(128) NOT NULL, + name VARCHAR(128) NOT NULL, + created_at DATETIME(6) NOT NULL DEFAULT (NOW(6)), + max_workers INT NOT NULL DEFAULT 0, + metadata JSON NOT NULL DEFAULT (JSON_OBJECT()), + num_jobs_completed BIGINT NOT NULL DEFAULT 0, + num_jobs_running BIGINT NOT NULL DEFAULT 0, + updated_at DATETIME(6) NOT NULL, + PRIMARY KEY (river_client_id, name), + CONSTRAINT fk_river_client FOREIGN KEY (river_client_id) REFERENCES river_client (id) ON DELETE CASCADE, + CONSTRAINT cq_name_length CHECK (CHAR_LENGTH(name) > 0 AND CHAR_LENGTH(name) < 128), + CONSTRAINT num_jobs_completed_zero_or_positive CHECK (num_jobs_completed >= 0), + CONSTRAINT num_jobs_running_zero_or_positive CHECK (num_jobs_running >= 0) +) ENGINE=InnoDB; diff --git a/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql b/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql new file mode 100644 index 00000000..0e4bf09f --- /dev/null +++ b/riverdriver/rivermysql/migration/main/006_bulk_unique.down.sql @@ -0,0 +1,9 @@ +-- +-- Drop the functional unique index and unique_states column. +-- + +DROP INDEX river_job_unique_idx ON /* TEMPLATE: schema */river_job; +ALTER TABLE /* TEMPLATE: schema */river_job DROP COLUMN unique_states; + +-- Recreate the old unique index from migration 005. 
+CREATE UNIQUE INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job (kind, unique_key); diff --git a/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql b/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql new file mode 100644 index 00000000..2f4883cc --- /dev/null +++ b/riverdriver/rivermysql/migration/main/006_bulk_unique.up.sql @@ -0,0 +1,45 @@ +-- +-- Add `river_job.unique_states` and bring up an index on it. +-- +-- MySQL 8.0+ supports functional indexes (index on an expression). We use one +-- to index `unique_key` only when the job's state matches the bitmask, which +-- is equivalent to Postgres's partial index with `river_job_state_in_bitmask`. +-- The expression evaluates to NULL when the constraint shouldn't be active, +-- and MySQL's UNIQUE allows multiple NULLs. +-- + +ALTER TABLE /* TEMPLATE: schema */river_job ADD COLUMN unique_states SMALLINT NULL; + +-- A "functional index" (feature specific to MySQL 8.0+) over `unique_key`, +-- `unique_states`, and `state`. The expression evaluates to `unique_key` when +-- the job's current state has its corresponding bit set in `unique_states`, and +-- to NULL otherwise. MySQL's UNIQUE indexes permit multiple NULL values, so +-- rows where the constraint is inactive (NULL result) never conflict with each +-- other, while rows where it's active are checked for uniqueness on +-- `unique_key`. This is the MySQL equivalent of Postgres's partial index with +-- `WHERE ... AND river_job_state_in_bitmask(unique_states, state)`. +-- +-- Unlike Postgres's partial indexes which exclude non-matching rows from the +-- index entirely, MySQL's functional index still stores an entry for every row +-- (NULLs included), so the storage savings aren't equivalent. The uniqueness +-- semantics are the same though. 
+CREATE UNIQUE INDEX river_job_unique_idx ON /* TEMPLATE: schema */river_job (( + CASE WHEN unique_key IS NOT NULL AND unique_states IS NOT NULL AND + CASE state + WHEN 'available' THEN unique_states & (1 << 0) + WHEN 'cancelled' THEN unique_states & (1 << 1) + WHEN 'completed' THEN unique_states & (1 << 2) + WHEN 'discarded' THEN unique_states & (1 << 3) + WHEN 'pending' THEN unique_states & (1 << 4) + WHEN 'retryable' THEN unique_states & (1 << 5) + WHEN 'running' THEN unique_states & (1 << 6) + WHEN 'scheduled' THEN unique_states & (1 << 7) + ELSE 0 + END >= 1 + THEN unique_key + ELSE NULL + END +)); + +-- Remove the old unique index from migration 005. +DROP INDEX river_job_kind_unique_key_idx ON /* TEMPLATE: schema */river_job; diff --git a/riverdriver/rivermysql/river_mysql_driver.go b/riverdriver/rivermysql/river_mysql_driver.go new file mode 100644 index 00000000..924cb91d --- /dev/null +++ b/riverdriver/rivermysql/river_mysql_driver.go @@ -0,0 +1,1962 @@ +// Package rivermysql provides a River driver implementation for MySQL. +// +// This driver targets MySQL 8.0+ and requires `go-sql-driver/mysql` or a +// compatible driver registered with `database/sql`. The DSN should include +// `parseTime=true` to ensure `DATETIME` columns are correctly scanned into +// `time.Time` values. +// +// MySQL does not support LISTEN/NOTIFY, so this driver operates in poll-only +// mode. It also does not support `RETURNING` clauses, so most write operations +// are carried out as two-step operations (write + read). +// +// This driver is currently in early development. It's exercised in the test +// suite, but has minimal real world use as of yet. 
+package rivermysql + +import ( + "context" + "database/sql" + "embed" + "encoding/json" + "errors" + "fmt" + "io/fs" + "math" + "slices" + "strings" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/rivermysql/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" + "github.com/riverqueue/river/rivershared/uniquestates" + "github.com/riverqueue/river/rivershared/util/dbutil" + "github.com/riverqueue/river/rivershared/util/ptrutil" + "github.com/riverqueue/river/rivershared/util/randutil" + "github.com/riverqueue/river/rivershared/util/savepointutil" + "github.com/riverqueue/river/rivershared/util/sliceutil" + "github.com/riverqueue/river/rivertype" +) + +//go:embed migration/*/*.sql +var migrationFS embed.FS + +// Driver is an implementation of riverdriver.Driver for MySQL. +type Driver struct { + dbPool *sql.DB + replacer sqlctemplate.Replacer +} + +// New returns a new MySQL driver for use with River. +// +// It takes an sql.DB to use for use with River. The DSN should include +// `parseTime=true` for correct time handling. The pool must not be closed while +// associated River objects are running. +func New(dbPool *sql.DB) *Driver { + return &Driver{ + dbPool: dbPool, + replacer: sqlctemplate.Replacer{UnnumberedPlaceholders: true}, + } +} + +const argPlaceholder = "?" 
+ +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "mysql" } +func (d *Driver) SafeIdentifier(ident string) string { return mysqlIdentifier(ident) } + +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d, nil} +} + +func (d *Driver) GetListener(params *riverdriver.GetListenenerParams) riverdriver.Listener { + panic(riverdriver.ErrNotImplemented) +} + +func (d *Driver) GetMigrationDefaultLines() []string { return []string{riverdriver.MigrationLineMain} } +func (d *Driver) GetMigrationFS(line string) fs.FS { + if line == riverdriver.MigrationLineMain { + return migrationFS + } + panic("migration line does not exist: " + line) +} +func (d *Driver) GetMigrationLines() []string { return []string{riverdriver.MigrationLineMain} } +func (d *Driver) GetMigrationTruncateTables(line string, version int) []string { + if line == riverdriver.MigrationLineMain { + return riverdriver.MigrationLineMainTruncateTables(version) + } + panic("migration line does not exist: " + line) +} + +func (d *Driver) PoolIsSet() bool { return d.dbPool != nil } +func (d *Driver) PoolSet(dbPool any) error { + if d.dbPool != nil { + return errors.New("cannot PoolSet when internal pool is already non-nil") + } + d.dbPool = dbPool.(*sql.DB) //nolint:forcetypeassert + return nil +} + +func (d *Driver) SQLFragmentColumnIn(column string, values any) (string, any, error) { + arg, err := json.Marshal(values) + if err != nil { + return "", nil, err + } + + // Use JSON_TABLE to expand the JSON array into rows for an IN clause. + // The arg is passed through the template system's NamedArgs as @column, + // which the template replacer turns into a positional placeholder. The + // templateReplaceWrapper strips the number suffix to produce plain `?` + // that MySQL expects. 
Use VARCHAR(255) as the column type since it + // handles both integer and string comparisons via implicit conversion. + // COLLATE must match the table's column collation. utf8mb4_0900_ai_ci is + // MySQL 8.0+'s default collation for utf8mb4. Without this, JSON_TABLE + // produces values with utf8mb4_general_ci which causes a collation mismatch + // in comparisons. + return fmt.Sprintf("%s IN (SELECT jt.val COLLATE utf8mb4_0900_ai_ci FROM JSON_TABLE(CAST(@%s AS JSON), '$[*]' COLUMNS(val VARCHAR(255) PATH '$')) AS jt)", column, column), arg, nil +} + +func (d *Driver) SupportsListener() bool { return false } +func (d *Driver) SupportsListenNotify() bool { return false } +func (d *Driver) TimePrecision() time.Duration { return time.Microsecond } + +func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{UnnumberedPlaceholders: true} + } else { + replacer = &d.replacer + } + + executorTx := &ExecutorTx{tx: tx} + executorTx.Executor = Executor{nil, templateReplaceWrapper{tx, replacer}, d, executorTx} + + return executorTx +} + +func (d *Driver) UnwrapTx(execTx riverdriver.ExecutorTx) *sql.Tx { + switch execTx := execTx.(type) { + case *ExecutorSubTx: + return execTx.tx + case *ExecutorTx: + return execTx.tx + } + panic("unhandled executor type") +} + +type Executor struct { + dbPool *sql.DB + dbtx templateReplaceWrapper + driver *Driver + execTx riverdriver.ExecutorTx +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if e.execTx != nil { + return e.execTx.Begin(ctx) + } + + tx, err := e.dbPool.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + executorTx := &ExecutorTx{tx: tx} + executorTx.Executor = Executor{nil, templateReplaceWrapper{tx, &e.driver.replacer}, e.driver, executorTx} + return executorTx, nil +} + +func (e *Executor) ColumnExists(ctx 
context.Context, params *riverdriver.ColumnExistsParams) (bool, error) { + var queryArgs []any + query := `SELECT EXISTS ( + SELECT 1 + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = ? AND COLUMN_NAME = ?` + queryArgs = append(queryArgs, params.Table, params.Column) + + if params.Schema != "" { + query += " AND TABLE_SCHEMA = ?" + queryArgs = append(queryArgs, params.Schema) + } else { + query += " AND TABLE_SCHEMA = DATABASE()" + } + query += ")" + + var exists bool + if err := e.dbtx.QueryRowContext(ctx, query, queryArgs...).Scan(&exists); err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string, args ...any) error { + _, err := e.dbtx.ExecContext(ctx, sql, args...) + return interpretError(err) +} + +func (e *Executor) IndexDropIfExists(ctx context.Context, params *riverdriver.IndexDropIfExistsParams) error { + indexName := strings.TrimSpace(params.Index) + + exists, err := e.IndexExists(ctx, &riverdriver.IndexExistsParams{ + Index: indexName, + Schema: params.Schema, + }) + if err != nil { + return err + } + if !exists { + return nil + } + + // MySQL's DROP INDEX requires the table name. Look it up from the index. + tableName, err := dbsqlc.New().IndexGetTableName(informationSchemaParam(ctx), e.dbtx, &dbsqlc.IndexGetTableNameParams{ + IndexName: indexName, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return interpretError(err) + } + + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." 
+ } + + _, err = e.dbtx.ExecContext(ctx, "DROP INDEX "+mysqlIdentifier(indexName)+" ON "+maybeSchema+mysqlIdentifier(tableName)) + return interpretError(err) +} + +func (e *Executor) IndexExists(ctx context.Context, params *riverdriver.IndexExistsParams) (bool, error) { + ctx = informationSchemaParam(ctx) + + exists, err := dbsqlc.New().IndexExists(ctx, e.dbtx, &dbsqlc.IndexExistsParams{ + IndexName: params.Index, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) IndexReindex(ctx context.Context, params *riverdriver.IndexReindexParams) error { + // MySQL doesn't have a direct REINDEX command. We use OPTIMIZE TABLE + // or ALTER TABLE ... FORCE as the closest equivalent. For now, use + // ANALYZE TABLE which updates index statistics. + var tableName string + query := `SELECT TABLE_NAME FROM INFORMATION_SCHEMA.STATISTICS WHERE INDEX_NAME = ?` + queryArgs := []any{params.Index} + if params.Schema != "" { + query += " AND TABLE_SCHEMA = ?" + queryArgs = append(queryArgs, params.Schema) + } else { + query += " AND TABLE_SCHEMA = DATABASE()" + } + query += " LIMIT 1" + + err := e.dbtx.QueryRowContext(ctx, query, queryArgs...).Scan(&tableName) + if err != nil { + return interpretError(err) + } + + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." 
+	}
+
+	_, err = e.dbtx.ExecContext(ctx, "ANALYZE TABLE "+maybeSchema+mysqlIdentifier(tableName))
+	return interpretError(err)
+}
+
+func (e *Executor) IndexesExist(ctx context.Context, params *riverdriver.IndexesExistParams) (map[string]bool, error) {
+	foundNames, err := dbsqlc.New().IndexesExist(informationSchemaParam(ctx), e.dbtx, &dbsqlc.IndexesExistParams{
+		IndexNames: params.IndexNames,
+		Schema:     sql.NullString{String: params.Schema, Valid: params.Schema != ""},
+	})
+	if err != nil {
+		return nil, interpretError(err)
+	}
+
+	exists := make(map[string]bool, len(params.IndexNames))
+	for _, name := range foundNames {
+		exists[name] = true
+	}
+	return exists, nil
+}
+
+func (e *Executor) JobCancel(ctx context.Context, params *riverdriver.JobCancelParams) (*rivertype.JobRow, error) {
+	return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) {
+		ctx = schemaTemplateParam(ctx, params.Schema)
+		dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer}
+
+		res, err := dbsqlc.New().JobCancelExec(ctx, dbtx, &dbsqlc.JobCancelExecParams{
+			ID:                params.ID,
+			CancelAttemptedAt: params.CancelAttemptedAt.UTC().Format(time.RFC3339Nano),
+			Now:               nullTimeFromPtr(params.Now),
+		})
+		if err != nil {
+			return nil, interpretError(err)
+		}
+
+		// Surface any driver error from RowsAffected, but the count itself
+		// isn't needed: whether or not the cancellation updated a row (it
+		// won't have for an already-finalized job), the job's latest row is
+		// fetched and returned as-is below.
+		if _, err := res.RowsAffected(); err != nil {
+			return nil, interpretError(err)
+		}
+
+		job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID)
+		if err != nil {
+			return nil, interpretError(err)
+		}
+
+		return jobRowFromInternal(job)
+	})
+}
+
+func (e *Executor) JobCountByAllStates(ctx context.Context, params *riverdriver.JobCountByAllStatesParams) (map[rivertype.JobState]int, error) {
+	counts, err := dbsqlc.New().JobCountByAllStates(schemaTemplateParam(ctx, params.Schema), e.dbtx)
+	if err != nil {
+		return nil, interpretError(err)
+	}
+	countsMap := make(map[rivertype.JobState]int)
+	for _, state := range rivertype.JobStates() {
+		countsMap[state] = 0
+	}
+	for _, count := range counts {
+		countsMap[rivertype.JobState(count.State)] = int(count.Count)
+	}
+	return countsMap, nil
+}
+
+func (e *Executor) JobCountByQueueAndState(ctx context.Context, params *riverdriver.JobCountByQueueAndStateParams) ([]*riverdriver.JobCountByQueueAndStateResult, error) {
+	rows, err := dbsqlc.New().JobCountByQueueAndState(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobCountByQueueAndStateParams{
+		QueueNames: params.QueueNames,
+	})
+	if err != nil {
+		return nil, interpretError(err)
+	}
+
+	// MySQL's GROUP BY only returns queues that have jobs, so fill in zero
+	// counts for any requested queues not in the result.
+	countsByQueue := make(map[string]*riverdriver.JobCountByQueueAndStateResult, len(rows))
+	for _, row := range rows {
+		countsByQueue[row.Queue] = &riverdriver.JobCountByQueueAndStateResult{
+			CountAvailable: row.CountAvailable,
+			CountRunning:   row.CountRunning,
+			Queue:          row.Queue,
+		}
+	}
+
+	var (
+		queueNames = slices.Compact(slices.Sorted(slices.Values(params.QueueNames)))
+		results    = make([]*riverdriver.JobCountByQueueAndStateResult, len(queueNames))
+	)
+	for i, name := range queueNames {
+		if result, ok := countsByQueue[name]; ok {
+			results[i] = result
+		} else {
+			results[i] = &riverdriver.JobCountByQueueAndStateResult{Queue: name}
+		}
+	}
+
+	return results, nil
+}
+
+func (e *Executor) JobCountByState(ctx context.Context, params *riverdriver.JobCountByStateParams) (int, error) {
+	numJobs, err := dbsqlc.New().JobCountByState(schemaTemplateParam(ctx, params.Schema), e.dbtx, string(params.State))
+	if err != nil {
+		return 0, interpretError(err)
+	}
+	return int(numJobs), nil
+}
+
+func (e *Executor) JobDelete(ctx context.Context, params *riverdriver.JobDeleteParams) (*rivertype.JobRow, error) {
+	return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx
riverdriver.ExecutorTx) (*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + // First fetch the job so we can return it + job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + res, err := dbsqlc.New().JobDeleteExec(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return nil, interpretError(err) + } + + if rowsAffected < 1 { + if rivertype.JobState(job.State) == rivertype.JobStateRunning { + return nil, rivertype.ErrJobRunning + } + return nil, fmt.Errorf("bug; expected only to fetch a job with state %q, but was: %q", rivertype.JobStateRunning, job.State) + } + + return jobRowFromInternal(job) + }) +} + +func (e *Executor) JobDeleteBefore(ctx context.Context, params *riverdriver.JobDeleteBeforeParams) (int, error) { + if len(params.QueuesIncluded) > 0 { + return 0, riverdriver.ErrNotImplemented + } + + var queuesExcludedEmpty int64 + if len(params.QueuesExcluded) < 1 { + queuesExcludedEmpty = 1 + } + + res, err := dbsqlc.New().JobDeleteBefore(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobDeleteBeforeParams{ + CancelledFinalizedAtHorizon: sql.NullTime{Time: params.CancelledFinalizedAtHorizon, Valid: true}, + CompletedFinalizedAtHorizon: sql.NullTime{Time: params.CompletedFinalizedAtHorizon, Valid: true}, + DiscardedFinalizedAtHorizon: sql.NullTime{Time: params.DiscardedFinalizedAtHorizon, Valid: true}, + Limit: int32(params.Max), //nolint:gosec + QueuesExcluded: params.QueuesExcluded, + QueuesExcludedEmpty: queuesExcludedEmpty, + }) + if err != nil { + return 0, interpretError(err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return 0, interpretError(err) + } + return int(rowsAffected), nil +} + +func (e *Executor) JobDeleteMany(ctx context.Context, params 
*riverdriver.JobDeleteManyParams) ([]*rivertype.JobRow, error) { + var jobs []*dbsqlc.RiverJob + { + ctx := sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: params.WhereClause}, + }, params.NamedArgs) + ctx = schemaTemplateParam(ctx, params.Schema) + + // Step 1: Select the rows to delete + var err error + jobs, err = dbsqlc.New().JobDeleteManySelect(ctx, e.dbtx, params.Max) + if err != nil { + return nil, interpretError(err) + } + } + + if len(jobs) > 0 { + // Step 2: Delete those rows by ID. Use a fresh context with only + // schema replacement — the select context has where_clause/order_by_clause + // template params that would panic on the simpler DELETE SQL. + ctx := schemaTemplateParam(ctx, params.Schema) + ids := sliceutil.Map(jobs, func(j *dbsqlc.RiverJob) int64 { return j.ID }) + if err := dbsqlc.New().JobDeleteManyExec(ctx, e.dbtx, ids); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetAvailable(ctx context.Context, params *riverdriver.JobGetAvailableParams) ([]*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + ids, err := dbsqlc.New().JobGetAvailableIDs(ctx, e.dbtx, &dbsqlc.JobGetAvailableIDsParams{ + Queue: params.Queue, + Now: nullTimeFromPtr(params.Now), + Limit: int32(params.MaxToLock), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + + if len(ids) == 0 { + return nil, nil + } + + if err := dbsqlc.New().JobGetAvailableUpdate(ctx, e.dbtx, &dbsqlc.JobGetAvailableUpdateParams{ + Now: nullTimeFromPtr(params.Now), + MaxAttemptedBy: int64(params.MaxAttemptedBy), + AttemptedBy: params.ClientID, + ID: ids, + }); err != nil { + return nil, interpretError(err) + } + + jobs, err := dbsqlc.New().JobGetByIDManyOrdered(ctx, e.dbtx, ids) + if err != nil { + return nil, interpretError(err) + } + + return 
sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetByID(ctx context.Context, params *riverdriver.JobGetByIDParams) (*rivertype.JobRow, error) { + job, err := dbsqlc.New().JobGetByID(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) +} + +func (e *Executor) JobGetByIDMany(ctx context.Context, params *riverdriver.JobGetByIDManyParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetByIDMany(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetByKindMany(ctx context.Context, params *riverdriver.JobGetByKindManyParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetByKindMany(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Kind) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobGetStuck(ctx context.Context, params *riverdriver.JobGetStuckParams) ([]*rivertype.JobRow, error) { + jobs, err := dbsqlc.New().JobGetStuck(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobGetStuckParams{ + Limit: int32(params.Max), //nolint:gosec + StuckHorizon: sql.NullTime{Time: params.StuckHorizon, Valid: true}, + }) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobInsertFastMany(ctx context.Context, params *riverdriver.JobInsertFastManyParams) ([]*riverdriver.JobInsertFastResult, error) { + var ( + insertRes = make([]*riverdriver.JobInsertFastResult, len(params.Jobs)) + + // We use a special `(xmax != 0)` trick in Postgres to determine whether + // an upserted row was inserted or skipped, but as far as I can find, + // there's no such trick possible in MySQL. 
Instead, we roll a random + // nonce and insert it to metadata. If the same nonce comes back, we know + // we really inserted the row. If not, we're getting an existing row back. + uniqueNonce = randutil.Hex(8) + ) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + ids := make([]int64, len(params.Jobs)) + for i, params := range params.Jobs { + insertParams, err := jobInsertFastParams(params, uniqueNonce) + if err != nil { + return err + } + + res, err := dbsqlc.New().JobInsertFast(ctx, dbtx, insertParams) + if err != nil { + return interpretError(err) + } + ids[i], err = res.LastInsertId() + if err != nil { + return err + } + } + + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, ids) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + for i, id := range ids { + job, err := jobRowFromInternal(jobsByID[id]) + if err != nil { + return err + } + + insertRes[i] = &riverdriver.JobInsertFastResult{ + Job: job, + UniqueSkippedAsDuplicate: gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyUniqueNonce).Str != uniqueNonce, + } + } + + return nil + }); err != nil { + return nil, err + } + + return insertRes, nil +} + +func (e *Executor) JobInsertFastManyNoReturning(ctx context.Context, params *riverdriver.JobInsertFastManyParams) (int, error) { + var totalRowsAffected int + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, params := range params.Jobs { + insertParams, err := jobInsertFastParams(params, "") + if err != nil { + return err + } + + res, err := 
dbsqlc.New().JobInsertFast(ctx, dbtx, insertParams) + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + totalRowsAffected += int(rowsAffected) + } + + return nil + }); err != nil { + return 0, err + } + + return totalRowsAffected, nil +} + +// jobInsertFastParams builds the common insert parameters for a single job. +// If uniqueNonce is non-empty, it's set in the job's metadata for duplicate +// detection. +func jobInsertFastParams(params *riverdriver.JobInsertFastParams, uniqueNonce string) (*dbsqlc.JobInsertFastParams, error) { + metadata := sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")) + if uniqueNonce != "" { + // Error intentionally ignored — SetBytes on valid JSON can't fail + // for a simple key addition. + metadata, _ = sjson.SetBytes(metadata, rivercommon.MetadataKeyUniqueNonce, uniqueNonce) + } + + tags, err := json.Marshal(params.Tags) + if err != nil { + return nil, err + } + + var uniqueStates sql.NullInt16 + if params.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(params.UniqueStates), Valid: true} + } + + var id sql.NullInt64 + if params.ID != nil { + id = sql.NullInt64{Int64: *params.ID, Valid: true} + } + + return &dbsqlc.JobInsertFastParams{ + ID: id, + Args: params.EncodedArgs, + CreatedAt: params.CreatedAt, + Kind: params.Kind, + MaxAttempts: int64(params.MaxAttempts), + Metadata: metadata, + Priority: int16(params.Priority), //nolint:gosec + Queue: params.Queue, + ScheduledAt: params.ScheduledAt, + State: string(params.State), + Tags: tags, + UniqueKey: nullStringFromBytes(params.UniqueKey), + UniqueStates: uniqueStates, + }, nil +} + +func (e *Executor) JobInsertFull(ctx context.Context, params *riverdriver.JobInsertFullParams) (*rivertype.JobRow, error) { + var attemptedBy []byte + if params.AttemptedBy != nil { + var err error + attemptedBy, err = json.Marshal(params.AttemptedBy) + if err != nil { + return nil, err + } + 
} + + var errorsData []byte + if len(params.Errors) > 0 { + var err error + errorsData, err = json.Marshal(sliceutil.Map(params.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return nil, err + } + } + + tags, err := json.Marshal(params.Tags) + if err != nil { + return nil, err + } + + var uniqueStates sql.NullInt16 + if params.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(params.UniqueStates), Valid: true} + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + lastInsertID, err := dbsqlc.New().JobInsertFullExec(ctx, e.dbtx, &dbsqlc.JobInsertFullExecParams{ + Attempt: int64(params.Attempt), + AttemptedAt: nullTimeFromPtr(params.AttemptedAt), + AttemptedBy: attemptedBy, + Args: params.EncodedArgs, + CreatedAt: params.CreatedAt, + Errors: errorsData, + FinalizedAt: nullTimeFromPtr(params.FinalizedAt), + Kind: params.Kind, + MaxAttempts: int64(params.MaxAttempts), + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + Priority: int16(params.Priority), //nolint:gosec + Queue: params.Queue, + ScheduledAt: params.ScheduledAt, + State: string(params.State), + Tags: tags, + UniqueKey: nullStringFromBytes(params.UniqueKey), + UniqueStates: uniqueStates, + }) + if err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, lastInsertID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) +} + +func (e *Executor) JobInsertFullMany(ctx context.Context, params *riverdriver.JobInsertFullManyParams) ([]*rivertype.JobRow, error) { + insertRes := make([]*rivertype.JobRow, len(params.Jobs)) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + ids := make([]int64, len(params.Jobs)) + + for i, jobParams := range params.Jobs { + var 
attemptedBy []byte + if jobParams.AttemptedBy != nil { + var err error + attemptedBy, err = json.Marshal(jobParams.AttemptedBy) + if err != nil { + return err + } + } + + var errorsData []byte + if len(jobParams.Errors) > 0 { + var err error + errorsData, err = json.Marshal(sliceutil.Map(jobParams.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return err + } + } + + tags, err := json.Marshal(jobParams.Tags) + if err != nil { + return err + } + + var uniqueStates sql.NullInt16 + if jobParams.UniqueStates != 0 { + uniqueStates = sql.NullInt16{Int16: int16(jobParams.UniqueStates), Valid: true} + } + + ids[i], err = dbsqlc.New().JobInsertFullExec(ctx, dbtx, &dbsqlc.JobInsertFullExecParams{ + Attempt: int64(jobParams.Attempt), + AttemptedAt: nullTimeFromPtr(jobParams.AttemptedAt), + AttemptedBy: attemptedBy, + Args: jobParams.EncodedArgs, + CreatedAt: jobParams.CreatedAt, + Errors: errorsData, + FinalizedAt: nullTimeFromPtr(jobParams.FinalizedAt), + Kind: jobParams.Kind, + MaxAttempts: int64(jobParams.MaxAttempts), + Metadata: sliceutil.FirstNonEmpty(jobParams.Metadata, []byte("{}")), + Priority: int16(jobParams.Priority), //nolint:gosec + Queue: jobParams.Queue, + ScheduledAt: jobParams.ScheduledAt, + State: string(jobParams.State), + Tags: tags, + UniqueKey: nullStringFromBytes(jobParams.UniqueKey), + UniqueStates: uniqueStates, + }) + if err != nil { + return interpretError(err) + } + } + + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, ids) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + for i, id := range ids { + insertRes[i], err = jobRowFromInternal(jobsByID[id]) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return insertRes, nil +} + +func (e *Executor) JobKindList(ctx context.Context, params *riverdriver.JobKindListParams) ([]string, error) { + 
exclude := params.Exclude + if len(exclude) == 0 { + exclude = []string{""} + } + + kinds, err := dbsqlc.New().JobKindList(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.JobKindListParams{ + After: params.After, + Exclude: exclude, + Match: params.Match, + Limit: int32(min(params.Max, math.MaxInt32)), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + return kinds, nil +} + +func (e *Executor) JobList(ctx context.Context, params *riverdriver.JobListParams) ([]*rivertype.JobRow, error) { + ctx = sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: params.WhereClause}, + }, params.NamedArgs) + + jobs, err := dbsqlc.New().JobList(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Max) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.MapError(jobs, jobRowFromInternal) +} + +func (e *Executor) JobRescueMany(ctx context.Context, params *riverdriver.JobRescueManyParams) (*struct{}, error) { + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for i := range params.ID { + if err := dbsqlc.New().JobRescue(ctx, dbtx, &dbsqlc.JobRescueParams{ + ID: params.ID[i], + Error: params.Error[i], + FinalizedAt: nullTimeFromPtr(params.FinalizedAt[i]), + ScheduledAt: params.ScheduledAt[i].UTC(), + State: params.State[i], + }); err != nil { + return interpretError(err) + } + } + + return nil + }); err != nil { + return nil, err + } + + return &struct{}{}, nil +} + +func (e *Executor) JobRetry(ctx context.Context, params *riverdriver.JobRetryParams) (*rivertype.JobRow, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { + ctx = schemaTemplateParam(ctx, 
params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + _, err := dbsqlc.New().JobRetryExec(ctx, dbtx, &dbsqlc.JobRetryExecParams{ + ID: params.ID, + Now: nullTimeFromPtr(params.Now), + }) + if err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + return jobRowFromInternal(job) + }) +} + +func (e *Executor) JobSchedule(ctx context.Context, params *riverdriver.JobScheduleParams) ([]*riverdriver.JobScheduleResult, error) { + return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) ([]*riverdriver.JobScheduleResult, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + scheduleResults, err := dbsqlc.New().JobSchedule(ctx, dbtx, &dbsqlc.JobScheduleParams{ + Limit: int32(params.Max), //nolint:gosec + Now: nullTimeFromPtr(params.Now), + }) + if err != nil { + return nil, interpretError(err) + } + + var ( + allIDs []int64 + availIDs []int64 + discardIDs []int64 + discardSet = make(map[int64]bool) + ) + + for _, result := range scheduleResults { + allIDs = append(allIDs, result.ID) + if result.ConflictDiscarded != 0 { + discardIDs = append(discardIDs, result.ID) + discardSet[result.ID] = true + } else { + availIDs = append(availIDs, result.ID) + } + } + + if len(availIDs) > 0 { + if err := dbsqlc.New().JobScheduleSetAvailableExec(ctx, dbtx, availIDs); err != nil { + return nil, interpretError(err) + } + } + + if len(discardIDs) > 0 { + if err := dbsqlc.New().JobScheduleSetDiscardedExec(ctx, dbtx, &dbsqlc.JobScheduleSetDiscardedExecParams{ + ID: discardIDs, + Now: nullTimeFromPtr(params.Now), + }); err != nil { + return nil, interpretError(err) + } + } + + if len(allIDs) == 0 { + return nil, nil + } + + updatedJobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, allIDs) + 
if err != nil { + return nil, interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(updatedJobs)) + for _, j := range updatedJobs { + jobsByID[j.ID] = j + } + + // Return results in the same order as scheduleResults. + results := make([]*riverdriver.JobScheduleResult, len(scheduleResults)) + for i, sr := range scheduleResults { + job, err := jobRowFromInternal(jobsByID[sr.ID]) + if err != nil { + return nil, err + } + results[i] = &riverdriver.JobScheduleResult{ConflictDiscarded: discardSet[sr.ID], Job: *job} + } + + return results, nil + }) +} + +func (e *Executor) JobSetStateIfRunningMany(ctx context.Context, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { + setRes := make([]*rivertype.JobRow, len(params.ID)) + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + // Step 1: Execute all state changes. 
+ for i := range params.ID { + setStateParams := &dbsqlc.JobSetStateIfRunningExecParams{ + ID: params.ID[i], + Error: []byte("{}"), + MetadataUpdates: []byte("{}"), + Now: nullTimeFromPtr(params.Now), + State: string(params.State[i]), + } + + if params.Attempt[i] != nil { + setStateParams.AttemptDoUpdate = 1 + setStateParams.Attempt = int64(*params.Attempt[i]) + } + if params.ErrData[i] != nil { + setStateParams.ErrorsDoUpdate = 1 + setStateParams.Error = params.ErrData[i] + } + if params.FinalizedAt[i] != nil { + setStateParams.FinalizedAtDoUpdate = 1 + setStateParams.FinalizedAt = nullTimeFromPtr(params.FinalizedAt[i]) + } + if params.MetadataDoMerge[i] { + setStateParams.MetadataDoMerge = 1 + setStateParams.MetadataUpdates = params.MetadataUpdates[i] + } + if params.ScheduledAt[i] != nil { + setStateParams.ScheduledAtDoUpdate = 1 + setStateParams.ScheduledAt = *params.ScheduledAt[i] + } + + if err := dbsqlc.New().JobSetStateIfRunningExec(ctx, dbtx, setStateParams); err != nil { + return fmt.Errorf("error setting job state: %w", err) + } + } + + // Step 2: Batch fetch all jobs. + jobs, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, params.ID) + if err != nil { + return interpretError(err) + } + + jobsByID := make(map[int64]*dbsqlc.RiverJob, len(jobs)) + for _, j := range jobs { + jobsByID[j.ID] = j + } + + // Step 3: For jobs that weren't running, merge metadata if requested. 
+ var metadataMergedIDs []int64 + for i := range params.ID { + job := jobsByID[params.ID[i]] + if job == nil { + continue + } + + if rivertype.JobState(job.State) != rivertype.JobStateRunning && params.MetadataDoMerge[i] { + res, err := dbsqlc.New().JobSetMetadataIfNotRunningExec(ctx, dbtx, &dbsqlc.JobSetMetadataIfNotRunningExecParams{ + ID: params.ID[i], + MetadataUpdates: sliceutil.FirstNonEmpty(params.MetadataUpdates[i], []byte("{}")), + }) + if err != nil { + return fmt.Errorf("error setting job metadata: %w", err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected > 0 { + metadataMergedIDs = append(metadataMergedIDs, params.ID[i]) + } + } + } + + // Step 4: Re-fetch jobs that had metadata merged. + if len(metadataMergedIDs) > 0 { + refreshed, err := dbsqlc.New().JobGetByIDMany(ctx, dbtx, metadataMergedIDs) + if err != nil { + return interpretError(err) + } + for _, j := range refreshed { + jobsByID[j.ID] = j + } + } + + // Step 5: Build results in original order. 
+ for i := range params.ID { + job := jobsByID[params.ID[i]] + if job == nil { + continue + } + + setRes[i], err = jobRowFromInternal(job) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return setRes, nil +} + +func (e *Executor) JobUpdate(ctx context.Context, params *riverdriver.JobUpdateParams) (*rivertype.JobRow, error) { + metadata := params.Metadata + if metadata == nil { + metadata = []byte("{}") + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + var metadataDoMerge int64 + if params.MetadataDoMerge { + metadataDoMerge = 1 + } + + if err := dbsqlc.New().JobUpdateExec(ctx, e.dbtx, &dbsqlc.JobUpdateExecParams{ + ID: params.ID, + MetadataDoMerge: metadataDoMerge, + Metadata: metadata, + }); err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + return jobRowFromInternal(job) +} + +func (e *Executor) JobUpdateFull(ctx context.Context, params *riverdriver.JobUpdateFullParams) (*rivertype.JobRow, error) { + attemptedAt := params.AttemptedAt + if attemptedAt != nil { + attemptedAt = ptrutil.Ptr(attemptedAt.UTC()) + } + + attemptedBy, err := json.Marshal(params.AttemptedBy) + if err != nil { + return nil, err + } + + errorsData, err := json.Marshal(sliceutil.Map(params.Errors, func(e []byte) json.RawMessage { return json.RawMessage(e) })) + if err != nil { + return nil, err + } + + finalizedAt := params.FinalizedAt + if finalizedAt != nil { + finalizedAt = ptrutil.Ptr(finalizedAt.UTC()) + } + + metadata := params.Metadata + if metadata == nil { + metadata = []byte("{}") + } + + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := dbsqlc.New().JobUpdateFullExec(ctx, e.dbtx, &dbsqlc.JobUpdateFullExecParams{ + ID: params.ID, + Attempt: int64(params.Attempt), + AttemptDoUpdate: boolToInt64(params.AttemptDoUpdate), + AttemptedAt: nullTimeFromPtr(attemptedAt), + AttemptedAtDoUpdate: 
boolToInt64(params.AttemptedAtDoUpdate), + AttemptedBy: attemptedBy, + AttemptedByDoUpdate: boolToInt64(params.AttemptedByDoUpdate), + ErrorsDoUpdate: boolToInt64(params.ErrorsDoUpdate), + Errors: errorsData, + FinalizedAtDoUpdate: boolToInt64(params.FinalizedAtDoUpdate), + FinalizedAt: nullTimeFromPtr(finalizedAt), + MaxAttemptsDoUpdate: boolToInt64(params.MaxAttemptsDoUpdate), + MaxAttempts: int64(min(params.MaxAttempts, math.MaxInt64)), + MetadataDoUpdate: boolToInt64(params.MetadataDoUpdate), + Metadata: metadata, + StateDoUpdate: boolToInt64(params.StateDoUpdate), + State: string(params.State), + }); err != nil { + return nil, interpretError(err) + } + + job, err := dbsqlc.New().JobGetByID(ctx, e.dbtx, params.ID) + if err != nil { + return nil, interpretError(err) + } + + return jobRowFromInternal(job) +} + +func (e *Executor) LeaderAttemptElect(ctx context.Context, params *riverdriver.LeaderElectParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + res, err := dbsqlc.New().LeaderAttemptElectExec(ctx, e.dbtx, &dbsqlc.LeaderAttemptElectExecParams{ + LeaderID: params.LeaderID, + Now: nullTimeFromPtr(params.Now), + TTL: params.TTL.Microseconds(), + }) + if err != nil { + return nil, interpretError(err) + } + + // INSERT IGNORE returns 0 rows affected when a leader already exists. 
+ affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, rivertype.ErrNotFound + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderAttemptReelect(ctx context.Context, params *riverdriver.LeaderReelectParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + res, err := dbsqlc.New().LeaderAttemptReelectExec(ctx, e.dbtx, &dbsqlc.LeaderAttemptReelectExecParams{ + ElectedAt: params.ElectedAt, + LeaderID: params.LeaderID, + Now: nullTimeFromPtr(params.Now), + TTL: params.TTL.Microseconds(), + }) + if err != nil { + return nil, interpretError(err) + } + + affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, rivertype.ErrNotFound + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderDeleteExpired(ctx context.Context, params *riverdriver.LeaderDeleteExpiredParams) (int, error) { + numDeleted, err := dbsqlc.New().LeaderDeleteExpired(schemaTemplateParam(ctx, params.Schema), e.dbtx, nullTimeFromPtr(params.Now)) + if err != nil { + return 0, interpretError(err) + } + return int(numDeleted), nil +} + +func (e *Executor) LeaderGetElectedLeader(ctx context.Context, params *riverdriver.LeaderGetElectedLeaderParams) (*riverdriver.Leader, error) { + leader, err := dbsqlc.New().LeaderGetElectedLeader(schemaTemplateParam(ctx, params.Schema), e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderInsert(ctx context.Context, params *riverdriver.LeaderInsertParams) (*riverdriver.Leader, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := 
dbsqlc.New().LeaderInsertExec(ctx, e.dbtx, &dbsqlc.LeaderInsertExecParams{ + ElectedAt: params.ElectedAt, + ExpiresAt: params.ExpiresAt, + Now: params.Now, + LeaderID: params.LeaderID, + TTL: params.TTL.Microseconds(), + }); err != nil { + return nil, interpretError(err) + } + + leader, err := dbsqlc.New().LeaderGetElectedLeader(ctx, e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return leaderFromInternal(leader), nil +} + +func (e *Executor) LeaderResign(ctx context.Context, params *riverdriver.LeaderResignParams) (bool, error) { + numResigned, err := dbsqlc.New().LeaderResign(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.LeaderResignParams{ + ElectedAt: params.ElectedAt, + LeaderID: params.LeaderID, + }) + if err != nil { + return false, interpretError(err) + } + return numResigned > 0, nil +} + +func (e *Executor) MigrationDeleteAssumingMainMany(ctx context.Context, params *riverdriver.MigrationDeleteAssumingMainManyParams) ([]*riverdriver.Migration, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + migrations, err := dbsqlc.New().RiverMigrationDeleteAssumingMainMany(ctx, e.dbtx, versions) + if err != nil { + return nil, interpretError(err) + } + + if len(versions) > 0 { + if err := dbsqlc.New().RiverMigrationDeleteAssumingMainManyExec(ctx, e.dbtx, versions); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.Map(migrations, func(internal *dbsqlc.RiverMigrationDeleteAssumingMainManyRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, + Version: int(internal.Version), + } + }), nil +} + +func (e *Executor) MigrationDeleteByLineAndVersionMany(ctx context.Context, params *riverdriver.MigrationDeleteByLineAndVersionManyParams) ([]*riverdriver.Migration, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + versions := 
sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + // Step 1: Select the rows to return + migrations, err := dbsqlc.New().RiverMigrationDeleteByLineAndVersionMany(ctx, e.dbtx, &dbsqlc.RiverMigrationDeleteByLineAndVersionManyParams{ + Line: params.Line, + Version: versions, + }) + if err != nil { + return nil, interpretError(err) + } + + // Step 2: Delete those rows + if len(versions) > 0 { + if err := dbsqlc.New().RiverMigrationDeleteByLineAndVersionManyExec(ctx, e.dbtx, &dbsqlc.RiverMigrationDeleteByLineAndVersionManyExecParams{ + Line: params.Line, + Version: versions, + }); err != nil { + return nil, interpretError(err) + } + } + + return sliceutil.Map(migrations, migrationFromInternal), nil +} + +func (e *Executor) MigrationGetAllAssumingMain(ctx context.Context, params *riverdriver.MigrationGetAllAssumingMainParams) ([]*riverdriver.Migration, error) { + migrations, err := dbsqlc.New().RiverMigrationGetAllAssumingMain(schemaTemplateParam(ctx, params.Schema), e.dbtx) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(migrations, func(internal *dbsqlc.RiverMigrationGetAllAssumingMainRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, + Version: int(internal.Version), + } + }), nil +} + +func (e *Executor) MigrationGetByLine(ctx context.Context, params *riverdriver.MigrationGetByLineParams) ([]*riverdriver.Migration, error) { + migrations, err := dbsqlc.New().RiverMigrationGetByLine(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Line) + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(migrations, migrationFromInternal), nil +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, params *riverdriver.MigrationInsertManyParams) ([]*riverdriver.Migration, error) { + var migrations []*riverdriver.Migration + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx 
riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, version := range params.Versions { + if err := dbsqlc.New().RiverMigrationInsertExec(ctx, dbtx, &dbsqlc.RiverMigrationInsertExecParams{ + Line: params.Line, + Version: int64(version), + }); err != nil { + return interpretError(err) + } + } + + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + internals, err := dbsqlc.New().RiverMigrationGetByLineAndVersionMany(ctx, dbtx, &dbsqlc.RiverMigrationGetByLineAndVersionManyParams{ + Line: params.Line, + Version: versions, + }) + if err != nil { + return interpretError(err) + } + + migrations = sliceutil.Map(internals, migrationFromInternal) + return nil + }); err != nil { + return nil, err + } + + return migrations, nil +} + +func (e *Executor) MigrationInsertManyAssumingMain(ctx context.Context, params *riverdriver.MigrationInsertManyAssumingMainParams) ([]*riverdriver.Migration, error) { + var migrations []*riverdriver.Migration + + if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { + ctx = schemaTemplateParam(ctx, params.Schema) + dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer} + + for _, version := range params.Versions { + if err := dbsqlc.New().RiverMigrationInsertAssumingMainExec(ctx, dbtx, int64(version)); err != nil { + return interpretError(err) + } + } + + versions := sliceutil.Map(params.Versions, func(v int) int64 { return int64(v) }) + + internals, err := dbsqlc.New().RiverMigrationGetByVersionMany(ctx, dbtx, versions) + if err != nil { + return interpretError(err) + } + + migrations = sliceutil.Map(internals, func(internal *dbsqlc.RiverMigrationGetByVersionManyRow) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: riverdriver.MigrationLineMain, 
+ Version: int(internal.Version), + } + }) + return nil + }); err != nil { + return nil, err + } + + return migrations, nil +} + +func (e *Executor) NotifyMany(ctx context.Context, params *riverdriver.NotifyManyParams) error { + return riverdriver.ErrNotImplemented +} + +func (e *Executor) PGAdvisoryXactLock(ctx context.Context, key int64) (*struct{}, error) { + // MySQL has GET_LOCK which is similar to PostgreSQL advisory locks, but + // it's session-scoped rather than transaction-scoped. For now, return + // not implemented. + return nil, riverdriver.ErrNotImplemented +} + +func (e *Executor) QueueCreateOrSetUpdatedAt(ctx context.Context, params *riverdriver.QueueCreateOrSetUpdatedAtParams) (*rivertype.Queue, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + if err := dbsqlc.New().QueueCreateOrSetUpdatedAtExec(ctx, e.dbtx, &dbsqlc.QueueCreateOrSetUpdatedAtExecParams{ + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + Name: params.Name, + Now: params.Now, + PausedAt: nullTimeFromPtr(params.PausedAt), + UpdatedAt: params.UpdatedAt, + }); err != nil { + return nil, interpretError(err) + } + + queue, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueueDeleteExpired(ctx context.Context, params *riverdriver.QueueDeleteExpiredParams) ([]string, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + // Step 1: Select expired queue names + queueNames, err := dbsqlc.New().QueueDeleteExpiredSelect(ctx, e.dbtx, &dbsqlc.QueueDeleteExpiredSelectParams{ + Limit: int32(params.Max), //nolint:gosec + UpdatedAtHorizon: params.UpdatedAtHorizon.UTC(), + }) + if err != nil { + return nil, interpretError(err) + } + + // Step 2: Delete those queues + if len(queueNames) > 0 { + if err := dbsqlc.New().QueueDeleteExpiredExec(ctx, e.dbtx, queueNames); err != nil { + return nil, interpretError(err) + } + } + + return queueNames, 
nil +} + +func (e *Executor) QueueGet(ctx context.Context, params *riverdriver.QueueGetParams) (*rivertype.Queue, error) { + queue, err := dbsqlc.New().QueueGet(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueueList(ctx context.Context, params *riverdriver.QueueListParams) ([]*rivertype.Queue, error) { + queues, err := dbsqlc.New().QueueList(schemaTemplateParam(ctx, params.Schema), e.dbtx, int32(params.Max)) //nolint:gosec + if err != nil { + return nil, interpretError(err) + } + return sliceutil.Map(queues, queueFromInternal), nil +} + +func (e *Executor) QueueNameList(ctx context.Context, params *riverdriver.QueueNameListParams) ([]string, error) { + exclude := params.Exclude + if len(exclude) == 0 { + exclude = []string{""} + } + queueNames, err := dbsqlc.New().QueueNameList(schemaTemplateParam(ctx, params.Schema), e.dbtx, &dbsqlc.QueueNameListParams{ + After: params.After, + Exclude: exclude, + Match: params.Match, + Limit: int32(min(params.Max, math.MaxInt32)), //nolint:gosec + }) + if err != nil { + return nil, interpretError(err) + } + return queueNames, nil +} + +func (e *Executor) QueuePause(ctx context.Context, params *riverdriver.QueuePauseParams) error { + ctx = schemaTemplateParam(ctx, params.Schema) + now := nullTimeFromPtr(params.Now) + + var ( + res sql.Result + err error + ) + + if params.Name == riverdriver.AllQueuesString { + res, err = dbsqlc.New().QueuePauseAll(ctx, e.dbtx, &dbsqlc.QueuePauseAllParams{Now: now}) + } else { + res, err = dbsqlc.New().QueuePauseByName(ctx, e.dbtx, &dbsqlc.QueuePauseByNameParams{ + Name: params.Name, + Now: now, + }) + } + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + + // MySQL only reports rows that actually changed values, not rows matched. 
+ // If no rows were affected for a named queue, check if the queue exists + // before returning ErrNotFound (it may already be paused/unpaused). + if rowsAffected < 1 && params.Name != riverdriver.AllQueuesString { + if _, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name); err != nil { + return interpretError(err) + } + } + return nil +} + +func (e *Executor) QueueResume(ctx context.Context, params *riverdriver.QueueResumeParams) error { + ctx = schemaTemplateParam(ctx, params.Schema) + now := nullTimeFromPtr(params.Now) + + var ( + res sql.Result + err error + ) + + if params.Name == riverdriver.AllQueuesString { + res, err = dbsqlc.New().QueueResumeAll(ctx, e.dbtx, now) + } else { + res, err = dbsqlc.New().QueueResumeByName(ctx, e.dbtx, &dbsqlc.QueueResumeByNameParams{ + Name: params.Name, + Now: now, + }) + } + if err != nil { + return interpretError(err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return interpretError(err) + } + + // MySQL only reports rows that actually changed values, not rows matched. + // If no rows were affected for a named queue, check if the queue exists + // before returning ErrNotFound (it may already be resumed/unpaused). 
+ if rowsAffected < 1 && params.Name != riverdriver.AllQueuesString { + if _, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name); err != nil { + return interpretError(err) + } + } + return nil +} + +func (e *Executor) QueueUpdate(ctx context.Context, params *riverdriver.QueueUpdateParams) (*rivertype.Queue, error) { + ctx = schemaTemplateParam(ctx, params.Schema) + + var metadataDoUpdate int64 + if params.MetadataDoUpdate { + metadataDoUpdate = 1 + } + + if err := dbsqlc.New().QueueUpdateExec(ctx, e.dbtx, &dbsqlc.QueueUpdateExecParams{ + Metadata: sliceutil.FirstNonEmpty(params.Metadata, []byte("{}")), + MetadataDoUpdate: metadataDoUpdate, + Name: params.Name, + }); err != nil { + return nil, interpretError(err) + } + + queue, err := dbsqlc.New().QueueGet(ctx, e.dbtx, params.Name) + if err != nil { + return nil, interpretError(err) + } + return queueFromInternal(queue), nil +} + +func (e *Executor) QueryRow(ctx context.Context, sql string, args ...any) riverdriver.Row { + return e.dbtx.QueryRowContext(ctx, sql, args...) +} + +func (e *Executor) SchemaCreate(ctx context.Context, params *riverdriver.SchemaCreateParams) error { + // In MySQL, schemas are databases. Create one if it doesn't exist. + if params.Schema != "" { + _, err := e.dbtx.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+mysqlIdentifier(params.Schema)) + return interpretError(err) + } + return nil +} + +func (e *Executor) SchemaDrop(ctx context.Context, params *riverdriver.SchemaDropParams) error { + if params.Schema != "" { + _, err := e.dbtx.ExecContext(ctx, "DROP DATABASE IF EXISTS "+mysqlIdentifier(params.Schema)) + return interpretError(err) + } + return nil +} + +func (e *Executor) SchemaGetExpired(ctx context.Context, params *riverdriver.SchemaGetExpiredParams) ([]string, error) { + rows, err := e.dbtx.QueryContext(ctx, + `SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME LIKE CONCAT(?, '%') AND SCHEMA_NAME < ? 
+ ORDER BY SCHEMA_NAME`, + params.Prefix, params.BeforeName) + if err != nil { + return nil, interpretError(err) + } + defer rows.Close() + + var schemas []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + schemas = append(schemas, name) + } + return schemas, rows.Err() +} + +func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExistsParams) (bool, error) { + ctx = informationSchemaParam(ctx) + + exists, err := dbsqlc.New().TableExists(ctx, e.dbtx, &dbsqlc.TableExistsParams{ + TableName: params.Table, + Schema: sql.NullString{String: params.Schema, Valid: params.Schema != ""}, + }) + if err != nil { + return false, interpretError(err) + } + return exists, nil +} + +func (e *Executor) TableTruncate(ctx context.Context, params *riverdriver.TableTruncateParams) error { + var maybeSchema string + if params.Schema != "" { + maybeSchema = mysqlIdentifier(params.Schema) + "." + } + + for _, table := range params.Table { + // MySQL's TRUNCATE TABLE is DDL and can't be used in transactions. + // Use DELETE FROM instead for transactional safety. 
+ _, err := e.dbtx.ExecContext(ctx, "DELETE FROM "+maybeSchema+table) + if err != nil { + return interpretError(err) + } + } + + return nil +} + +type ExecutorTx struct { + Executor + + tx *sql.Tx +} + +func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + executorSubTx := &ExecutorSubTx{ + beginOnce: &savepointutil.BeginOnlyOnce{}, + savepointNum: 0, + tx: t.tx, + } + executorSubTx.Executor = Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver, executorSubTx} + return executorSubTx.Begin(ctx) +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + return t.tx.Commit() +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + return t.tx.Rollback() +} + +type ExecutorSubTx struct { + Executor + + beginOnce *savepointutil.BeginOnlyOnce + savepointNum int + tx *sql.Tx +} + +const savepointPrefix = "river_savepoint_" + +func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if err := t.beginOnce.Begin(); err != nil { + return nil, err + } + + nextSavepointNum := t.savepointNum + 1 + if err := t.Exec(ctx, fmt.Sprintf("SAVEPOINT %s%02d", savepointPrefix, nextSavepointNum)); err != nil { + return nil, err + } + + executorSubTx := &ExecutorSubTx{ + beginOnce: savepointutil.NewBeginOnlyOnce(t.beginOnce), + savepointNum: nextSavepointNum, + tx: t.tx, + } + executorSubTx.Executor = Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver, executorSubTx} + + return executorSubTx, nil +} + +func (t *ExecutorSubTx) Commit(ctx context.Context) error { + defer t.beginOnce.Done() + + if t.beginOnce.IsDone() { + return errors.New("tx is closed") + } + + if err := t.Exec(ctx, fmt.Sprintf("RELEASE SAVEPOINT %s%02d", savepointPrefix, t.savepointNum)); err != nil { + // MySQL DDL statements (CREATE TABLE, ALTER TABLE, etc.) cause an + // implicit COMMIT which destroys savepoints. If the savepoint no + // longer exists, the work was already committed by the DDL. 
+ if isSavepointDoesNotExist(err) { + return nil + } + return err + } + + return nil +} + +func (t *ExecutorSubTx) Rollback(ctx context.Context) error { + defer t.beginOnce.Done() + + if t.beginOnce.IsDone() { + return errors.New("tx is closed") + } + + if err := t.Exec(ctx, fmt.Sprintf("ROLLBACK TO SAVEPOINT %s%02d", savepointPrefix, t.savepointNum)); err != nil { + // MySQL DDL statements cause implicit COMMIT which destroys + // savepoints. If the savepoint no longer exists, there's nothing + // to roll back to. + if isSavepointDoesNotExist(err) { + return nil + } + return err + } + + return nil +} + +// isSavepointDoesNotExist checks if a MySQL error indicates that a savepoint +// does not exist (error 1305). This happens when DDL statements cause an +// implicit COMMIT that destroys active savepoints. +func isSavepointDoesNotExist(err error) bool { + return err != nil && strings.Contains(err.Error(), "1305") +} + +func boolToInt64(b bool) int64 { + if b { + return 1 + } + return 0 +} + +func interpretError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return rivertype.ErrNotFound + } + return err +} + +// mysqlIdentifier quotes an identifier with backticks for MySQL. +// MySQL uses backticks instead of double quotes for identifier quoting. +func mysqlIdentifier(ident string) string { + return "`" + strings.ReplaceAll(ident, "`", "``") + "`" +} + +type templateReplaceWrapper struct { + dbtx dbsqlc.DBTX + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) ExecContext(ctx context.Context, rawSQL string, rawArgs ...any) (sql.Result, error) { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.ExecContext(ctx, sqlStr, args...) 
+} + +func (w templateReplaceWrapper) PrepareContext(ctx context.Context, rawSQL string) (*sql.Stmt, error) { + sqlStr, _ := w.replacer.Run(ctx, argPlaceholder, rawSQL, nil) + return w.dbtx.PrepareContext(ctx, sqlStr) +} + +func (w templateReplaceWrapper) QueryContext(ctx context.Context, rawSQL string, rawArgs ...any) (*sql.Rows, error) { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.QueryContext(ctx, sqlStr, args...) +} + +func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, rawSQL string, rawArgs ...any) *sql.Row { + sqlStr, args := w.replacer.Run(ctx, argPlaceholder, rawSQL, rawArgs) + return w.dbtx.QueryRowContext(ctx, sqlStr, args...) +} + +func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) { + attemptedAt := ptrFromNullTime(internal.AttemptedAt) + if attemptedAt != nil { + t := attemptedAt.UTC() + attemptedAt = &t + } + + var attemptedBy []string + if internal.AttemptedBy != nil { + if err := json.Unmarshal(internal.AttemptedBy, &attemptedBy); err != nil { + return nil, fmt.Errorf("error unmarshaling `attempted_by`: %w", err) + } + } + + var errs []rivertype.AttemptError + if internal.Errors != nil { + if err := json.Unmarshal(internal.Errors, &errs); err != nil { + return nil, fmt.Errorf("error unmarshaling `errors`: %w", err) + } + } + + finalizedAt := ptrFromNullTime(internal.FinalizedAt) + if finalizedAt != nil { + t := finalizedAt.UTC() + finalizedAt = &t + } + + var tags []string + if err := json.Unmarshal(internal.Tags, &tags); err != nil { + return nil, fmt.Errorf("error unmarshaling `tags`: %w", err) + } + + var uniqueStatesByte byte + if internal.UniqueStates.Valid { + if internal.UniqueStates.Int16 < 0 || internal.UniqueStates.Int16 > 255 { + return nil, fmt.Errorf("value out of range for byte: %d", internal.UniqueStates.Int16) + } + uniqueStatesByte = byte(internal.UniqueStates.Int16) + } + + return &rivertype.JobRow{ + ID: internal.ID, + Attempt: 
max(int(internal.Attempt), 0), + AttemptedAt: attemptedAt, + AttemptedBy: attemptedBy, + CreatedAt: internal.CreatedAt.UTC(), + EncodedArgs: internal.Args, + Errors: errs, + FinalizedAt: finalizedAt, + Kind: internal.Kind, + MaxAttempts: max(int(internal.MaxAttempts), 0), + Metadata: internal.Metadata, + Priority: max(int(internal.Priority), 0), + Queue: internal.Queue, + ScheduledAt: internal.ScheduledAt.UTC(), + State: rivertype.JobState(internal.State), + Tags: tags, + UniqueKey: bytesFromNullString(internal.UniqueKey), + UniqueStates: uniquestates.UniqueBitmaskToStates(uniqueStatesByte), + }, nil +} + +func leaderFromInternal(internal *dbsqlc.RiverLeader) *riverdriver.Leader { + return &riverdriver.Leader{ + ElectedAt: internal.ElectedAt.UTC(), + ExpiresAt: internal.ExpiresAt.UTC(), + LeaderID: internal.LeaderID, + } +} + +func migrationFromInternal(internal *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + CreatedAt: internal.CreatedAt.UTC(), + Line: internal.Line, + Version: int(internal.Version), + } +} + +func schemaTemplateParam(ctx context.Context, schema string) context.Context { + if schema != "" { + schema = mysqlIdentifier(schema) + "." 
+ } + + return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "schema": {Value: schema}, + }, nil) +} + +func informationSchemaParam(ctx context.Context) context.Context { + return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ + "information_schema": {Stable: true, Value: "INFORMATION_SCHEMA."}, + }, nil) +} + +func queueFromInternal(internal *dbsqlc.RiverQueue) *rivertype.Queue { + pausedAt := ptrFromNullTime(internal.PausedAt) + if pausedAt != nil { + t := pausedAt.UTC() + pausedAt = &t + } + return &rivertype.Queue{ + CreatedAt: internal.CreatedAt.UTC(), + Metadata: internal.Metadata, + Name: internal.Name, + PausedAt: pausedAt, + UpdatedAt: internal.UpdatedAt.UTC(), + } +} + +func nullTimeFromPtr(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{} + } + return sql.NullTime{Time: *t, Valid: true} +} + +func ptrFromNullTime(t sql.NullTime) *time.Time { + if !t.Valid { + return nil + } + return &t.Time +} + +func nullStringFromBytes(b []byte) sql.NullString { + if b == nil { + return sql.NullString{} + } + return sql.NullString{String: string(b), Valid: true} +} + +func bytesFromNullString(s sql.NullString) []byte { + if !s.Valid { + return nil + } + return []byte(s.String) +} diff --git a/riverdriver/rivermysql/river_mysql_driver_test.go b/riverdriver/rivermysql/river_mysql_driver_test.go new file mode 100644 index 00000000..219d6766 --- /dev/null +++ b/riverdriver/rivermysql/river_mysql_driver_test.go @@ -0,0 +1,621 @@ +package rivermysql + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivertype" +) + +// Verify interface compliance. 
+var _ riverdriver.Driver[*sql.Tx] = New(nil) + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(sql.ErrNoRows), rivertype.ErrNotFound) + require.NoError(t, interpretError(nil)) +} + +func TestSchemaTemplateParam(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("NoSchema", func(t *testing.T) { + t.Parallel() + ctx := schemaTemplateParam(ctx, "") + // Just verify it doesn't panic + _ = ctx + }) + + t.Run("WithSchema", func(t *testing.T) { + t.Parallel() + ctx := schemaTemplateParam(ctx, "custom_schema") + _ = ctx + }) +} + +func TestDriverProperties(t *testing.T) { + t.Parallel() + + driver := New(nil) + require.Equal(t, "?", driver.ArgPlaceholder()) + require.Equal(t, "mysql", driver.DatabaseName()) + require.False(t, driver.SupportsListener()) + require.False(t, driver.SupportsListenNotify()) + require.Equal(t, time.Microsecond, driver.TimePrecision()) + require.False(t, driver.PoolIsSet()) +} + +func TestJobInsertAndGet(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a job + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{"test": true}`), + Kind: "test_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{"tag1"}, + }) + require.NoError(t, err) + require.NotNil(t, job) + require.Positive(t, job.ID) + require.Equal(t, "test_job", job.Kind) + require.Equal(t, rivertype.JobStateAvailable, job.State) + require.Equal(t, "default", job.Queue) + require.Equal(t, 1, job.Priority) + require.Equal(t, 3, job.MaxAttempts) + require.JSONEq(t, `{"test": true}`, string(job.EncodedArgs)) + require.Equal(t, 
[]string{"tag1"}, job.Tags) + + // Get the job by ID + fetched, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ + ID: job.ID, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, job.ID, fetched.ID) + require.Equal(t, job.Kind, fetched.Kind) +} + +func TestJobGetAvailable(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert some available jobs + for i := range 3 { + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: fmt.Sprintf("job_%d", i), + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + } + + // Get available jobs (needs transaction for FOR UPDATE) + txExec, err := exec.Begin(ctx) + require.NoError(t, err) + defer txExec.Rollback(ctx) + + jobs, err := txExec.JobGetAvailable(ctx, &riverdriver.JobGetAvailableParams{ + ClientID: "test-client", + MaxAttemptedBy: 4, + MaxToLock: 2, + Queue: "default", + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, jobs, 2) + + for _, job := range jobs { + require.Equal(t, rivertype.JobStateRunning, job.State) + } + + require.NoError(t, txExec.Commit(ctx)) +} + +func TestJobCancel(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a job + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "test_cancel", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + 
+ // Cancel the job + cancelled, err := exec.JobCancel(ctx, &riverdriver.JobCancelParams{ + ID: job.ID, + CancelAttemptedAt: time.Now().UTC(), + ControlTopic: "test_topic", + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCancelled, cancelled.State) + require.NotNil(t, cancelled.FinalizedAt) +} + +func TestJobDelete(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a completed job + now := time.Now().UTC().Truncate(time.Microsecond) + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + FinalizedAt: &now, + Kind: "test_delete", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateCompleted, + Tags: []string{}, + }) + require.NoError(t, err) + + // Delete the job + deleted, err := exec.JobDelete(ctx, &riverdriver.JobDeleteParams{ + ID: job.ID, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, job.ID, deleted.ID) + + // Verify it's gone + _, err = exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ + ID: job.ID, + Schema: schema, + }) + require.ErrorIs(t, err, rivertype.ErrNotFound) +} + +func TestJobCountByState(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert jobs in different states + for range 3 { + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "test_count", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + } + + count, err := 
exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 3, count) +} + +func TestLeaderElection(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Attempt to elect a leader + leader, err := exec.LeaderAttemptElect(ctx, &riverdriver.LeaderElectParams{ + LeaderID: "test-leader", + Schema: schema, + TTL: 30 * time.Second, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", leader.LeaderID) + + // Get the elected leader + fetched, err := exec.LeaderGetElectedLeader(ctx, &riverdriver.LeaderGetElectedLeaderParams{ + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", fetched.LeaderID) + + // Re-elect should succeed + reelected, err := exec.LeaderAttemptReelect(ctx, &riverdriver.LeaderReelectParams{ + ElectedAt: leader.ElectedAt, + LeaderID: "test-leader", + Schema: schema, + TTL: 30 * time.Second, + }) + require.NoError(t, err) + require.Equal(t, "test-leader", reelected.LeaderID) + + // Resign + resigned, err := exec.LeaderResign(ctx, &riverdriver.LeaderResignParams{ + ElectedAt: leader.ElectedAt, + LeaderID: "test-leader", + LeadershipTopic: "leadership", + Schema: schema, + }) + require.NoError(t, err) + require.True(t, resigned) +} + +func TestQueueOperations(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + now := time.Now().UTC().Truncate(time.Microsecond) + + // Create a queue + queue, err := exec.QueueCreateOrSetUpdatedAt(ctx, &riverdriver.QueueCreateOrSetUpdatedAtParams{ + Metadata: []byte(`{}`), + Name: 
"test_queue", + Now: &now, + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test_queue", queue.Name) + require.Nil(t, queue.PausedAt) + + // Get the queue + fetched, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.Equal(t, "test_queue", fetched.Name) + + // List queues + queues, err := exec.QueueList(ctx, &riverdriver.QueueListParams{ + Max: 100, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, queues, 1) + + // Pause queue + err = exec.QueuePause(ctx, &riverdriver.QueuePauseParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + + paused, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.NotNil(t, paused.PausedAt) + + // Resume queue + err = exec.QueueResume(ctx, &riverdriver.QueueResumeParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + + resumed, err := exec.QueueGet(ctx, &riverdriver.QueueGetParams{ + Name: "test_queue", + Schema: schema, + }) + require.NoError(t, err) + require.Nil(t, resumed.PausedAt) +} + +func TestTransactions(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Begin and commit + tx, err := exec.Begin(ctx) + require.NoError(t, err) + + _, err = tx.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "tx_test", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + require.NoError(t, tx.Commit(ctx)) + + // Verify job exists + count, err := exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: 
rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 1, count) + + // Begin and rollback + tx2, err := exec.Begin(ctx) + require.NoError(t, err) + + _, err = tx2.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "tx_test_rollback", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + require.NoError(t, tx2.Rollback(ctx)) + + // Verify second job was rolled back + count, err = exec.JobCountByState(ctx, &riverdriver.JobCountByStateParams{ + Schema: schema, + State: rivertype.JobStateAvailable, + }) + require.NoError(t, err) + require.Equal(t, 1, count) // still 1 +} + +func TestJobInsertFastMany(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + results, err := exec.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ + Jobs: []*riverdriver.JobInsertFastParams{ + { + EncodedArgs: []byte(`{"i": 1}`), + Kind: "fast_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + State: rivertype.JobStateAvailable, + Tags: []string{}, + }, + { + EncodedArgs: []byte(`{"i": 2}`), + Kind: "fast_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + State: rivertype.JobStateAvailable, + Tags: []string{}, + }, + }, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, results, 2) + + for _, result := range results { + require.NotNil(t, result.Job) + require.False(t, result.UniqueSkippedAsDuplicate) + } +} + +func TestJobMetadata(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // 
Insert a job with metadata + job, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "metadata_test", + MaxAttempts: 3, + Metadata: []byte(`{"key": "value"}`), + Priority: 1, + Queue: "default", + Schema: schema, + State: rivertype.JobStateAvailable, + Tags: []string{}, + }) + require.NoError(t, err) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(job.Metadata, &metadata)) + require.Equal(t, "value", metadata["key"]) + + // Update metadata + updated, err := exec.JobUpdate(ctx, &riverdriver.JobUpdateParams{ + ID: job.ID, + MetadataDoMerge: true, + Metadata: []byte(`{"new_key": "new_value"}`), + Schema: schema, + }) + require.NoError(t, err) + + require.NoError(t, json.Unmarshal(updated.Metadata, &metadata)) + require.Equal(t, "value", metadata["key"]) + require.Equal(t, "new_value", metadata["new_key"]) +} + +func TestJobSchedule(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + driver = New(riversharedtest.DBPoolMySQL(ctx, t)) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + exec = driver.GetExecutor() + ) + + // Insert a scheduled job with a past scheduled_at + past := time.Now().UTC().Add(-1 * time.Hour).Truncate(time.Microsecond) + _, err := exec.JobInsertFull(ctx, &riverdriver.JobInsertFullParams{ + EncodedArgs: []byte(`{}`), + Kind: "scheduled_job", + MaxAttempts: 3, + Priority: 1, + Queue: "default", + ScheduledAt: &past, + Schema: schema, + State: rivertype.JobStateScheduled, + Tags: []string{}, + }) + require.NoError(t, err) + + // Schedule should find and transition the job + now := time.Now().UTC() + results, err := exec.JobSchedule(ctx, &riverdriver.JobScheduleParams{ + Max: 100, + Now: &now, + Schema: schema, + }) + require.NoError(t, err) + require.Len(t, results, 1) + require.Equal(t, rivertype.JobStateAvailable, results[0].Job.State) +} + +func TestNotifyMany(t *testing.T) { + t.Parallel() + + driver := New(nil) + 
// MySQL doesn't support LISTEN/NOTIFY + require.Panics(t, func() { + driver.GetListener(&riverdriver.GetListenenerParams{}) + }) +} + +func TestNotifyManyReturnsNotImplemented(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + exec = New(riversharedtest.DBPoolMySQL(ctx, t)).GetExecutor() + ) + + err := exec.NotifyMany(ctx, &riverdriver.NotifyManyParams{ + Payload: []string{"test"}, + Topic: "test_topic", + }) + require.ErrorIs(t, err, riverdriver.ErrNotImplemented) +} + +func TestPGAdvisoryXactLock(t *testing.T) { + t.Parallel() + + riversharedtest.SkipIfMySQLNotEnabled(t) + + var ( + ctx = t.Context() + exec = New(riversharedtest.DBPoolMySQL(ctx, t)).GetExecutor() + ) + + _, err := exec.PGAdvisoryXactLock(ctx, 12345) + require.ErrorIs(t, err, riverdriver.ErrNotImplemented) +} diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index e0201a0e..5a9933ea 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -63,8 +63,9 @@ func New(dbPool *pgxpool.Pool) *Driver { const argPlaceholder = "$" -func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d} diff --git a/riverdriver/riversqlite/river_sqlite_driver.go b/riverdriver/riversqlite/river_sqlite_driver.go index 132d0e46..ce77c91a 100644 --- a/riverdriver/riversqlite/river_sqlite_driver.go +++ b/riverdriver/riversqlite/river_sqlite_driver.go @@ -77,8 +77,9 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "?" 
-func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "sqlite" } +func (d *Driver) ArgPlaceholder() string { return argPlaceholder } +func (d *Driver) DatabaseName() string { return "sqlite" } +func (d *Driver) SafeIdentifier(ident string) string { return dbutil.SafeIdentifier(ident) } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d, nil} diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index 12568e20..bfe2a0d2 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -547,7 +547,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex var schema string if m.schema != "" { - schema = dbutil.SafeIdentifier(m.schema) + "." + schema = m.driver.SafeIdentifier(m.schema) + "." } schemaReplacement := map[string]sqlctemplate.Replacement{ "schema": {Value: schema}, diff --git a/rivershared/riversharedtest/riversharedtest.go b/rivershared/riversharedtest/riversharedtest.go index d6f2b5ba..1f5b7cb9 100644 --- a/rivershared/riversharedtest/riversharedtest.go +++ b/rivershared/riversharedtest/riversharedtest.go @@ -112,6 +112,44 @@ var sqliteTestDir = sync.OnceValue(func() string { //nolint:gochecknoglobals return path.Join(rootDir, "sqlite") }) +// A pool and sync.Once to initialize a MySQL pool, invoked by DBPoolMySQL. +var ( + dbPoolMySQL *sql.DB //nolint:gochecknoglobals + dbPoolMySQLOnce sync.Once //nolint:gochecknoglobals +) + +// SkipIfMySQLNotEnabled skips the current test if MySQL tests are not enabled. +// MySQL tests are opt-in because they require a running MySQL server. Set +// RIVER_MYSQL_TESTS_ENABLED=1 or RIVER_MYSQL_TESTS_ENABLED=true to enable. 
+func SkipIfMySQLNotEnabled(tb testing.TB) { + tb.Helper() + + val := os.Getenv("RIVER_MYSQL_TESTS_ENABLED") + if val != "1" && val != "true" { //nolint:goconst + tb.Skip("Skipping MySQL tests; set RIVER_MYSQL_TESTS_ENABLED=1 to enable") + } +} + +// DBPoolMySQL gets a lazily initialized database pool for MySQL testing. +// Uses TEST_MYSQL_URL or defaults to root@tcp(localhost:3306)/?. +func DBPoolMySQL(ctx context.Context, tb testing.TB) *sql.DB { + tb.Helper() + + dbPoolMySQLOnce.Do(func() { + dsn := cmp.Or( + os.Getenv("TEST_MYSQL_URL"), + "root@tcp(localhost:3306)/?parseTime=true&multiStatements=true&loc=UTC&time_zone=%27%2B00%3A00%27", + ) + + var err error + dbPoolMySQL, err = sql.Open("mysql", dsn) + require.NoError(tb, err) + }) + require.NotNil(tb, dbPoolMySQL) + + return dbPoolMySQL +} + // DBPoolLibSQL gets a database pool appropriate for use with libSQL (a SQLite // fork) in testing. func DBPoolLibSQL(ctx context.Context, tb testing.TB, schema string) *sql.DB { diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go index 88c02edb..484a76ae 100644 --- a/rivershared/sqlctemplate/sqlc_template.go +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -89,6 +89,14 @@ type Replacement struct { // be initialized with a constructor. This lets it default to a usable instance // on drivers that may themselves not be initialized. type Replacer struct { + // UnnumberedPlaceholders, when true, causes Run to emit plain `?` + // placeholders instead of numbered `?1`, `?2`, etc. for named args + // injected via the template system. The args slice is reordered to + // match the positional order of placeholders in the SQL. This is + // needed for MySQL, whose database/sql driver does not support + // numbered `?N` syntax. 
+ UnnumberedPlaceholders bool + cache map[replacerCacheKey]string cacheMu sync.RWMutex } @@ -142,6 +150,14 @@ func (r *Replacer) RunSafely(ctx context.Context, argPlaceholder, sql string, ar } cacheKey, cacheEligible := replacerCacheKeyFrom(sql, container) + + // Named args are interleaved with sqlc args during a left-to-right SQL + // walk, producing args in positional order. The cache can't reconstruct + // this ordering on hit, so skip caching when named args are present. + if len(container.NamedArgs) > 0 { + cacheEligible = false + } + if cacheEligible { r.cacheMu.RLock() var ( @@ -205,29 +221,88 @@ func (r *Replacer) RunSafely(ctx context.Context, argPlaceholder, sql string, ar } if len(container.NamedArgs) > 0 { - placeholderNum := len(args) - // For the benefit of the test suite's output being predictable, sort // named args before processing them. sortedNamedArgs := maputil.Keys(container.NamedArgs) slices.Sort(sortedNamedArgs) + + // Verify all named args are present in the SQL. for _, arg := range sortedNamedArgs { - placeholderNum++ + if !strings.Contains(updatedSQL, "@"+arg) { + return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", "@"+arg) + } + } - var ( - symbol = "@" + arg - symbolIndex = strings.Index(updatedSQL, symbol) - val = container.NamedArgs[arg] - ) + // Walk the SQL once left-to-right, replacing @name references and + // existing sqlc placeholders with sequentially numbered (or + // unnumbered) placeholders, building the args slice in positional + // order as we go. 
+ var ( + newArgs []any + out strings.Builder + seqNum int + sqlcArgIdx int + ) - if symbolIndex == -1 { - return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", symbol) + emitPlaceholder := func() { + seqNum++ + if r.UnnumberedPlaceholders { + out.WriteByte('?') + } else { + out.WriteString(argPlaceholder) + out.WriteString(strconv.Itoa(seqNum)) } + } - // ReplaceAll because an input parameter may appear multiple times. - updatedSQL = strings.ReplaceAll(updatedSQL, symbol, argPlaceholder+strconv.Itoa(placeholderNum)) - args = append(args, val) + for i := 0; i < len(updatedSQL); { + switch { + case updatedSQL[i] == '@': + matched := false + for _, name := range sortedNamedArgs { + symbol := "@" + name + if strings.HasPrefix(updatedSQL[i:], symbol) { + emitPlaceholder() + newArgs = append(newArgs, container.NamedArgs[name]) + i += len(symbol) + matched = true + break + } + } + if !matched { + out.WriteByte(updatedSQL[i]) + i++ + } + + case updatedSQL[i] == argPlaceholder[0] && i+1 < len(updatedSQL) && updatedSQL[i+1] >= '1' && updatedSQL[i+1] <= '9': + // Numbered placeholder from sqlc ($N for Postgres, ?N for + // SQLite): parse the old number, look up the original arg, + // and re-emit with a new sequential number. + j := i + 1 + for j < len(updatedSQL) && updatedSQL[j] >= '0' && updatedSQL[j] <= '9' { + j++ + } + oldNum, _ := strconv.Atoi(updatedSQL[i+1 : j]) + emitPlaceholder() + newArgs = append(newArgs, args[oldNum-1]) + i = j + + case r.UnnumberedPlaceholders && updatedSQL[i] == '?': + // Unnumbered positional ? from sqlc (MySQL). 
+ emitPlaceholder() + if sqlcArgIdx < len(args) { + newArgs = append(newArgs, args[sqlcArgIdx]) + } + sqlcArgIdx++ + i++ + + default: + out.WriteByte(updatedSQL[i]) + i++ + } } + + updatedSQL = out.String() + args = newArgs } if cacheEligible { diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go index 119ebd54..a53eeaa3 100644 --- a/rivershared/sqlctemplate/sqlc_template_test.go +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -259,7 +259,7 @@ func TestReplacer(t *testing.T) { AND state = @state; ` - // Initially cached value + // Initially cached value (no named args). { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -270,7 +270,7 @@ func TestReplacer(t *testing.T) { } require.Len(t, replacer.cache, 1) - // Same SQL, but new value. + // Same SQL, but new replacement value. { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "other_schema."}, @@ -281,7 +281,9 @@ func TestReplacer(t *testing.T) { } require.Len(t, replacer.cache, 2) - // Named arg added to the mix. + // Named args present: caching is skipped because args are built + // in positional order during a left-to-right walk that can't be + // replayed from the cache. { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -292,37 +294,9 @@ func TestReplacer(t *testing.T) { _, _, err := replacer.RunSafely(ctx, "$", sql, nil) require.NoError(t, err) } - require.Len(t, replacer.cache, 3) - - // Different named arg _value_ (i.e. still same named arg) can still use - // the previous cached SQL. 
- { - ctx := WithReplacements(ctx, map[string]Replacement{ - "schema": {Stable: true, Value: "test_schema."}, - }, map[string]any{ - "kind": "other_kind_value", - }) - - _, _, err := replacer.RunSafely(ctx, "$", sql, nil) - require.NoError(t, err) - } - require.Len(t, replacer.cache, 3) // unchanged + require.Len(t, replacer.cache, 2) // unchanged - // New named arg adds a new cache value. - { - ctx := WithReplacements(ctx, map[string]Replacement{ - "schema": {Stable: true, Value: "test_schema."}, - }, map[string]any{ - "kind": "kind_value", - "state": "state_value", - }) - - _, _, err := replacer.RunSafely(ctx, "$", sql, nil) - require.NoError(t, err) - } - require.Len(t, replacer.cache, 4) - - // Different input SQL. + // Different input SQL (no named args). { ctx := WithReplacements(ctx, map[string]Replacement{ "schema": {Stable: true, Value: "test_schema."}, @@ -333,7 +307,7 @@ func TestReplacer(t *testing.T) { `, nil) require.NoError(t, err) } - require.Len(t, replacer.cache, 5) + require.Len(t, replacer.cache, 3) }) t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { @@ -372,6 +346,9 @@ func TestReplacer(t *testing.T) { "kind": "no_op", }) + // Named arg @kind appears before the sqlc $1 in the SQL, so after + // the left-to-right walk, args are reordered to [no_op, succeeded] + // and placeholders renumbered sequentially. 
updatedSQL, args, err := replacer.RunSafely(ctx, "$", ` SELECT count(*) FROM river_job @@ -379,12 +356,12 @@ func TestReplacer(t *testing.T) { AND status = $1; `, []any{"succeeded"}) require.NoError(t, err) - require.Equal(t, []any{"succeeded", "no_op"}, args) + require.Equal(t, []any{"no_op", "succeeded"}, args) require.Equal(t, ` SELECT count(*) FROM river_job - WHERE kind = $2 - AND status = $1; + WHERE kind = $1 + AND status = $2; `, updatedSQL) }) @@ -420,6 +397,225 @@ func TestReplacer(t *testing.T) { `, updatedSQL) }) + t.Run("UnnumberedPlaceholders_NoNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = ?; + `, []any{"available"}) + require.NoError(t, err) + require.Equal(t, []any{"available"}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE state = ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgsNoInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgsWithInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, 
+ }, map[string]any{ + "kind": "no_op", + }) + + // The named arg @kind appears in the WHERE clause before the + // sqlc-generated ? for LIMIT. UnnumberedPlaceholders reorders args so + // that they match the positional ? order in the final SQL. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NamedArgRepeated", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind OR queue = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + // When a named arg appears multiple times, it should produce a ? for + // each occurrence with the value duplicated in the args slice. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT * + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "no_op", 100}, args) + require.Equal(t, ` + SELECT * + FROM river_job + WHERE kind = ? OR queue = ? 
+ LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_MultipleNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = @kind AND status = @status"}, + }, map[string]any{ + "kind": "no_op", + "status": "succeeded", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, "?", ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + `, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "succeeded", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? AND status = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_NotCachedWithNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Stable: true, Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + sql := ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ true /* TEMPLATE_END */ + LIMIT ?; + ` + + // Unnumbered mode with named args skips caching because the + // cached SQL can't preserve the positional arg ordering. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", sql, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + + require.Empty(t, replacer.cache) + + // Second call still produces correct results. 
+ updatedSQL, args, err = replacer.RunSafely(ctx, "?", sql, []any{200}) + require.NoError(t, err) + require.Equal(t, []any{"no_op", 200}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = ? + LIMIT ?; + `, updatedSQL) + }) + + t.Run("UnnumberedPlaceholders_CachedWithoutNamedArgs", func(t *testing.T) { + t.Parallel() + + replacer := &Replacer{UnnumberedPlaceholders: true} + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + sql := ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + LIMIT ?; + ` + + // Without named args, caching works normally in unnumbered mode. + updatedSQL, args, err := replacer.RunSafely(ctx, "?", sql, []any{100}) + require.NoError(t, err) + require.Equal(t, []any{100}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + LIMIT ?; + `, updatedSQL) + + require.Len(t, replacer.cache, 1) + + // Second call uses cache. + updatedSQL, args, err = replacer.RunSafely(ctx, "?", sql, []any{200}) + require.NoError(t, err) + require.Equal(t, []any{200}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + LIMIT ?; + `, updatedSQL) + }) + t.Run("Stress", func(t *testing.T) { t.Parallel()