Beartype perimeter on user-facing API #90
Merged
Stacked on the af-estimator branch (`2206212` series). Mirrors the
pattern in pylcm PR #355: a per-exception `BeartypeConf` plus a
`beartype_init` class decorator routes parameter-type violations at
every documented entry point through a skillmodels-specific
exception class, so callers can write narrowly-scoped `except`
clauses against a stable hierarchy rather than catching beartype's
framework exception.
Layout
------
* `src/skillmodels/exceptions.py` -- six `TypeError` subclasses,
organised by perimeter (`ModelSpecInitializationError`,
`OptionsInitializationError`, `EstimationCallError`,
`InferenceCallError`, `SimulationCallError`,
`DiagnosticsCallError`), all inheriting from a common
`SkillmodelsInputError` for callers that want to catch the whole
hierarchy in one go.
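The hierarchy described above can be sketched as follows (class names are taken from this PR; the minimal bodies and the docstring are assumptions):

```python
# Sketch of the exception hierarchy: six perimeter-specific TypeError
# subclasses under a single catch-all base.
class SkillmodelsInputError(TypeError):
    """Base for all input-type violations raised at the beartype perimeter."""


class ModelSpecInitializationError(SkillmodelsInputError): ...
class OptionsInitializationError(SkillmodelsInputError): ...
class EstimationCallError(SkillmodelsInputError): ...
class InferenceCallError(SkillmodelsInputError): ...
class SimulationCallError(SkillmodelsInputError): ...
class DiagnosticsCallError(SkillmodelsInputError): ...
```

Because the base itself subclasses `TypeError`, a plain `except TypeError` also keeps working for callers that predate the hierarchy.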
* `src/skillmodels/_beartype_conf.py` --
`_conf(exc)` builds a `BeartypeConf` with
`violation_param_type=exc`, `strategy=BeartypeStrategy.On` (full
O(n) container scan; entry points are called rarely compared to
the JIT-compiled hot path each one kicks off), and
`is_pep484_tower=True`. `beartype_init(conf)` is a class decorator
that wraps only `__init__`; bare `@beartype` on a class would wrap
every method and surface annotation drift on non-public instance
methods, which has nothing to do with validating parameters at
construction time.
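A minimal stdlib-only stand-in for the `beartype_init` idea. The real implementation delegates the checking to beartype with the configured `BeartypeConf`; the manual `isinstance` walk below only handles simple class annotations and is purely illustrative:

```python
import inspect


class ModelSpecInitializationError(TypeError):
    """Stand-in for the real exception routed via violation_param_type."""


def beartype_init_like(exc: type[TypeError]):
    """Class decorator wrapping only __init__, as beartype_init does."""
    def decorate(cls):
        orig_init = cls.__init__
        sig = inspect.signature(orig_init)

        def __init__(self, *args, **kwargs):
            bound = sig.bind(self, *args, **kwargs)
            for name, value in bound.arguments.items():
                ann = sig.parameters[name].annotation
                # Sketch checks plain class annotations only; beartype
                # handles arbitrary PEP 484/585 hints.
                if ann is inspect.Parameter.empty or not isinstance(ann, type):
                    continue
                if not isinstance(value, ann):
                    raise exc(
                        f"{name}={value!r} violates annotation {ann.__name__}"
                    )
            orig_init(self, *args, **kwargs)

        cls.__init__ = __init__
        return cls
    return decorate


@beartype_init_like(ModelSpecInitializationError)
class FactorSpecSketch:
    def __init__(self, name: str):
        self.name = name
```

Instance methods other than `__init__` are left untouched, which is the whole point of the decorator.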
Decoration sites
----------------
* `@beartype_init(MODEL_SPEC_CONF)` on `FactorSpec`, `AnchoringSpec`,
`ModelSpec`, `Normalizations`.
* `@beartype_init(OPTIONS_CONF)` on `CHSEstimationOptions`,
`AFEstimationOptions`, `AMNEstimationOptions`.
* `@beartype(conf=ESTIMATION_CONF)` on `get_maximization_inputs`,
`get_filtered_states`, `estimate_af`, `estimate_amn`,
`get_af_posterior_states`, `get_amn_posterior_states`.
* `@beartype(conf=INFERENCE_CONF)` on
`compute_af_standard_errors`, `compute_amn_standard_errors`.
* `@beartype(conf=SIMULATION_CONF)` on `simulate_dataset`,
`simulate_policy_effect`.
* `@beartype(conf=DIAGNOSTICS_CONF)` on
`decompose_measurement_variance`,
`summarize_measurement_reliability`,
`plot_residual_boxplots`, `plot_likelihood_contributions`,
`create_state_ranges`, `plot_correlation_heatmap`,
`get_measurements_corr`, `get_quasi_scores_corr`,
`get_scores_corr`, `univariate_densities`,
`bivariate_density_contours`, `bivariate_density_surfaces`,
`combine_distribution_plots`, `get_transition_plots`,
`combine_transition_plots`.
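Caller-side, the payoff is a narrowly scoped `except` clause. The snippet below uses illustrative stand-ins (a manual guard in place of the `@beartype(conf=ESTIMATION_CONF)` decorator; the real names live in `skillmodels`):

```python
class SkillmodelsInputError(TypeError): ...
class EstimationCallError(SkillmodelsInputError): ...


def estimate_af(data):
    # In skillmodels this guard is supplied by the beartype decorator;
    # a manual isinstance check stands in for it here.
    if not isinstance(data, dict):
        raise EstimationCallError("data must be a dict")
    return "estimated"


# A caller can target one perimeter without ever naming beartype's own
# framework exception, or catch the whole hierarchy via the base class.
try:
    estimate_af([1, 2, 3])
except EstimationCallError:
    outcome = "bad estimation call"
except SkillmodelsInputError:
    outcome = "some other input problem"
```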
Side effects of perimeter-only validation
-----------------------------------------
* `_check_measurements`'s type-shape arm in
`common/check_model.py` is now dead code: the
`tuple[tuple[str, ...], ...]` annotation on
`FactorSpec.measurements` makes beartype reject every malformed
measurement structure at construction time. The function is kept
(the report aggregator might still surface non-type issues a
beartype container scan can't see), but the corresponding two
tests in `tests/test_check_model.py` are rewritten to assert
`ModelSpecInitializationError` at `FactorSpec(...)` time
instead of asserting a soft message in the aggregator output.
* `tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value`
now asserts `OptionsInitializationError` from beartype's
`Literal` check (which fires before
`AFEstimationOptions.__post_init__`'s manual ValueError).
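The ordering matters because dataclasses call `__post_init__` from inside `__init__`, so any check wrapped around `__init__` sees the bad value first. A stdlib-only illustration (the field name `optimizer_backend` is inferred from the test name and hypothetical; the real check is beartype's `Literal` validation, not this manual lookup):

```python
from dataclasses import dataclass
from typing import Literal, get_args


class OptionsInitializationError(TypeError):
    """Stand-in for the PR's exception class."""


def literal_perimeter(cls):
    """Wrap __init__ with a Literal check that fires before __post_init__."""
    orig_init = cls.__init__
    allowed = get_args(cls.__annotations__["optimizer_backend"])

    def __init__(self, optimizer_backend):
        if optimizer_backend not in allowed:
            raise OptionsInitializationError(
                f"optimizer_backend must be one of {allowed!r}"
            )
        orig_init(self, optimizer_backend)

    cls.__init__ = __init__
    return cls


@literal_perimeter
@dataclass
class AFOptionsSketch:
    optimizer_backend: Literal["optimagic", "jaxopt"]

    def __post_init__(self):
        # This manual ValueError is now unreachable for unknown values:
        # the perimeter check around __init__ runs first.
        if self.optimizer_backend not in ("optimagic", "jaxopt"):
            raise ValueError("unknown optimizer backend")
```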
* `tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results`
now asserts `EstimationCallError`; the prior body-level
`"only one of"` ValueError is still in place but is unreachable
from this fixture, which passes the same AMN result to both
parameters and so trips the type guard on `af_result` first.
* `chs/filtered_states.py` imports `AFEstimationResult` /
`AMNEstimationResult` at runtime rather than under
`TYPE_CHECKING` so beartype can resolve the annotations at call
time; ruff's TC003 autofix had been silently moving these imports
behind `TYPE_CHECKING`, turning the annotations into string forward
refs that are unresolvable at runtime.
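A stdlib-only illustration of the `TYPE_CHECKING` problem: a name imported only for the static checker leaves the annotation as a string that no runtime checker can resolve (`Decimal` here is just a convenient stand-in type):

```python
from typing import TYPE_CHECKING, get_type_hints

if TYPE_CHECKING:
    # Visible to static checkers, absent at runtime -- the situation
    # ruff's TC003 autofix creates when it moves a runtime-needed
    # import behind this guard.
    from decimal import Decimal


def scale(x: "Decimal") -> "Decimal":
    return x * 2


try:
    get_type_hints(scale)  # what a runtime checker like beartype must do
    resolved = True
except NameError:
    resolved = False
```

Moving the import back to module level makes `get_type_hints` succeed, which is exactly the `chs/filtered_states.py` fix.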
Verification
------------
* `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` --
529 passed, 1 deselected (same count as before this commit; no
regressions).
* `pixi run ty` -- clean.
* `prek run --all-files` -- clean.
Out of scope (follow-up PRs)
----------------------------
* Whole-package activation via `beartype.claw.beartype_package("skillmodels")`
in `tests/conftest.py`. That probe would surface internal-helper
annotation drift the same way pylcm's part-3 PR will, and is left
for a separate review.
* AGENTS-level conventions documentation. The perimeter is in place;
the rule for where to put the next decorator is "wherever the
signature is documented to the user" -- to be expanded once the
pattern has settled.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codecov report
--------------
Coverage diff against base `af-estimator` (#90):

    Coverage   95.05% -> 95.29%  (+0.24%)
    Files         102 ->    105  (+3)
    Lines       10070 ->  10207  (+137)
    Hits         9572 ->   9727  (+155)
    Misses        498 ->    480  (-18)
The previous fix (2206212) tried to address this by enabling x64
before `import jaxopt`. The reproducer on sonny (jax 0.10.0,
jaxopt 0.8.5) shows that's insufficient: even with x64 on before any
jaxopt code runs and float64 inputs throughout, `LBFGSB.update`'s
jit-compiled `jnp.argsort` still emits an s32 reduction accumulator
while the surrounding scatter operand is built as s64. XLA's
`permutation_sort_simplifier` HLO pass rejects the mismatch with
`INVALID_ARGUMENT: Reduction function's accumulator shape at index 0
differs from the init_value shape: s32[] vs s64[]`.

Disabling just `permutation_sort_simplifier` via `XLA_FLAGS` fixes
the crash, keeps every other XLA optimisation intact, and is a no-op
on JAX < 0.10 (the pass doesn't exist there). The flag must be set
before `import jax` because XLA reads `XLA_FLAGS` once at backend
init. Applied in two places:

- `skillmodels/__init__.py`: the primary entry point. Appends to any
  pre-existing `XLA_FLAGS` so user flags aren't clobbered.
- `skillmodels/af/jaxopt_backend.py`: belt-and-suspenders for direct
  module imports that skip the package init.

The previous comment block tying the bug to "x64 off at import time"
was wrong about the root cause; it is replaced with the actual
XLA-pass explanation. The `JAX_ENABLE_X64=1` setting is retained
because the AF pipeline assumes float64 throughout.

Verified end-to-end on sonny (jax 0.10): the minimal jaxopt repro
that previously crashed now succeeds. Local jaxopt backend tests (7)
and the full local suite (485 tests, jax 0.9) still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
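The flag plumbing described above can be sketched like this (assumed shape; the real code in `skillmodels/__init__.py` may differ in detail, but the append-don't-clobber and set-before-import constraints come straight from the commit):

```python
import os

# Disable only the offending HLO pass; all other XLA optimisations stay on.
_FLAG = "--xla_disable_hlo_passes=permutation_sort_simplifier"

existing = os.environ.get("XLA_FLAGS", "")
if _FLAG not in existing:
    # Append rather than overwrite, so user-supplied XLA flags survive.
    os.environ["XLA_FLAGS"] = f"{existing} {_FLAG}".strip()

# The AF pipeline assumes float64 throughout.
os.environ.setdefault("JAX_ENABLE_X64", "1")

# `import jax` must come only after this point: XLA reads XLA_FLAGS
# once at backend initialisation, so flags set later are ignored.
```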
The perimeter decorators (PR #90 / commit d7d29ea) covered only
public entry points. The claw activation in `tests/conftest.py`
extends type enforcement to every annotated callable in the package
during the test run, catching annotation drift on internal helpers
that would otherwise silently flow through.

Configuration:

- `is_pep484_tower=True` to mirror the perimeter conf so `int`
  satisfies `float`-typed parameters.
- `claw_skip_package_names=("skillmodels.chs.qr",)` because JAX's
  `@custom_jvp` decorator stores the secondary `.defjvp` setter on
  the wrapped object; beartype.claw's wrapping strips it.

Annotation drift fixes (sources of truth, not type-system theater):

* `FixedConstraintWithValue` moved to its own module
  `common/fixed_constraint.py`. `transition_functions.py` previously
  imported it under `if TYPE_CHECKING:` to avoid a circular import
  with `constraints.py`; beartype.claw can't resolve those forward
  refs at decoration time. The leaf type now lives where both
  modules can pull it without a cycle.
* Same `TYPE_CHECKING` removal for `ModelSpec` in `af/types.py` and
  `amn/types.py`.
* JAX-traced helpers (`_at_node`, `_chain_one_component`,
  `_compute_investment`, kalman / likelihood entry points,
  transition pipeline plumbing): annotations relaxed to accept
  `Array | np.ndarray` / `float | Array` / `int | np.integer` where
  the runtime contract is genuinely mixed. JAX vmap traces ints as
  `BatchTracer`; numpy and jax arrays interconvert freely through
  these signatures.
* `MixtureComponent`, `ConditionalDistribution`, `ChainLink`
  dataclass fields: accept `np.ndarray` alongside `jax.Array` since
  estimators fill them with both.
* `TransitionInfo.param_names`: now built with an explicit `tuple()`
  conversion at the boundary in `process_model.py` so the
  `MappingProxyType[str, tuple[str, ...]]` annotation actually holds.
* `get_has_endogenous_factors` now casts the pandas `.any()` result
  to `bool` instead of relying on a `# ty: ignore`.
* `NDArray[np.floating[Any]]` (which beartype doesn't accept as a
  dtype hint) replaced with `NDArray[np.float64]` in
  `chs/process_debug_data` and `common/visualize_factor_distributions`.
* `NDArray[np.floating]` in `common/simulate_data` widened to
  `NDArray[np.floating] | Array`.
* Internal duck-typed validators (`_check_anchoring`,
  `_process_factors`) re-typed as `Any` with `# noqa: ANN401` and an
  inline comment; they exist precisely to take partially-built
  objects.
* `_aug_periods_from_period`: `dict[int, int]` -> `Mapping[int, int]`
  (production passes a `MappingProxyType`).

Tests:

- `tests/test_check_model.py`: re-add `# ty: ignore` on the two
  `FactorSpec(measurements=...)` calls that intentionally pass
  `list` where `tuple` is required; they verify the beartype
  perimeter catches the shape error.
- `tests/test_transition_functions.py::test_constant`: rewritten to
  pass real JAX arrays now that the claw type-checks `constant`.
- Stale `# ty: ignore[invalid-argument-type]` directives stripped
  from `test_check_model.py`, `test_correlation_heatmap.py`,
  `test_process_debug_data.py` (and one in `simulate_data.py`); the
  annotations they were silencing have been relaxed.

Verification: 495 tests pass with claw enabled
(`pixi run -e tests-cpu tests`); `pixi run ty` clean;
`prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
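The activation itself is a few lines in `tests/conftest.py`. A sketch under the two configuration options documented above (any further conf contents are assumptions; this fragment must execute before anything imports `skillmodels`):

```python
# tests/conftest.py (sketch) -- runs at collection time, before any
# skillmodels import, so the claw hook can instrument every module.
from beartype import BeartypeConf
from beartype.claw import beartype_package

beartype_package(
    "skillmodels",
    conf=BeartypeConf(
        # Mirror the perimeter conf: int satisfies float-typed parameters.
        is_pep484_tower=True,
        # beartype.claw's wrapping strips the .defjvp setter that JAX's
        # @custom_jvp stores on wrapped functions, so skip that subpackage.
        claw_skip_package_names=("skillmodels.chs.qr",),
    ),
)
```

Config fragment only; it is not runnable outside a checkout with `skillmodels` and `beartype` installed.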
CI run 25868871644 surfaced 14 failures in
`tests/test_af_inference.py` that local runs missed (the file is
unmarked but the local sweep excluded similarly-shaped tests via
`-k "not long_running and not end_to_end"`). All failures had the
same root cause as the bulk of the prior claw-activation diff: AF
transition / chain helpers were annotated with strict `jax.Array`
parameters, but the runtime path through `af.inference` constructs
`prev_distribution`, chain link inputs, and chol/mean blocks from
`np.ndarray`. Beartype rejected them under the now-active claw.

Relaxed sites:

- `af.likelihood.af_per_obs_loglike_transition` /
  `af_loglike_transition` / `_integrate_transition_chain`:
  `prev_distribution` widened from `dict[str, Array]` to
  `Mapping[str, Array | np.ndarray]`. `Mapping` (covariant) lets
  callers pass `dict[str, Array]` without an explicit cast.
- `af.likelihood._map_over_obs`: `*xs: Array` ->
  `*xs: Array | np.ndarray`.
- `af.likelihood._integrate_transition_single_obs`:
  `obs_cond_weights`, `obs_cond_means`, `cond_chols` widened to
  `Array | np.ndarray`.
- `af.likelihood._rebuild_chain_at_period`: `initial_mean`,
  `initial_chol` widened to `Array | np.ndarray`. Internal `theta`
  bound through `jnp.asarray(...)` so downstream
  `_compute_investment` still sees an `Array`.
- `af.likelihood._compute_investment`: `inv_eq_params`, `inv_sds`
  widened (covered earlier; ty-narrowing followed naturally).

Verification:

- `pixi run -e tests-cpu pytest tests/test_af_inference.py` --
  14 / 14 pass (was 1 failed + 13 errors).
- `pixi run ty` clean.
- `prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously the jaxopt backend stopped on
`max|projected_grad| < tol` only. That's not how scipy_lbfgsb stops
in practice: scipy stops on EITHER `gtol_abs` OR a relative
function-value decrease `ftol_rel`. On the skill-formation
likelihoods used in the Monte Carlo benchmarks, the loglikelihood
goes locally flat before the gradient does, so scipy's ftol channel
fires ~100% of the time and gradient-norm < 1e-5 is essentially
never the actual stopping criterion in production. Without the ftol
channel, jaxopt would grind down the gradient while scipy declared
success at the same point -- a fake apples-to-oranges asymmetry that
made the AF-jaxopt vs AF-optimagic timing comparison meaningless.

Implementation: drop jaxopt's built-in `run()` and drive the solver
through an explicit `init_state` + `update` loop with the same
gtol-OR-ftol stop. Default values now match scipy_lbfgsb's defaults
(`gtol_abs=1e-5`, `ftol_rel=2.22e-9`, `maxiter=15000`). The wrapper
accepts both canonical scipy keys (`convergence_gtol_abs`,
`convergence_ftol_rel`, `stopping_maxiter`) and the historical
jaxopt keys (`tol`, `maxiter`) so the same `optimizer_options` dict
works for either backend.

This makes the two LBFGSB implementations stop on byte-identical
rules; the only remaining differences are internal (line search,
step acceptance, curvature-pair filtering) -- which is the
comparison that's actually interesting.

Verified: 7 jaxopt_backend tests still pass; ty clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
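The gtol-OR-ftol loop can be illustrated without jaxopt. Below, plain gradient descent stands in for `LBFGSB.update` (the function name and `lr` knob are inventions of this sketch), but the two stopping channels and the scipy-default values mirror the rule described above:

```python
import numpy as np


def minimize_dual_stop(f, grad, x0, *, gtol_abs=1e-5, ftol_rel=2.22e-9,
                       maxiter=15000, lr=0.1):
    """Stop on EITHER max|grad| < gtol_abs OR relative f-decrease <= ftol_rel."""
    x = np.asarray(x0, dtype=float)
    f_prev = f(x)
    for _ in range(maxiter):
        g = grad(x)
        if np.max(np.abs(g)) < gtol_abs:          # gtol channel: gradient flat
            break
        x = x - lr * g                            # stand-in for LBFGSB.update
        f_new = f(x)
        # scipy-style relative decrease: (f_k - f_{k+1}) / max(|f_k|,|f_{k+1}|,1)
        denom = max(abs(f_new), abs(f_prev), 1.0)
        if (f_prev - f_new) / denom <= ftol_rel:  # ftol channel: f flat
            break
        f_prev = f_new
    return x


# On a quadratic the ftol channel fires first, as in the benchmarks above.
x_star = minimize_dual_stop(lambda x: float(np.sum(x**2)),
                            lambda x: 2.0 * x,
                            [1.0, -2.0])
```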