From 53d5b4972f788efec06d8b049a0341afc147e0b0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 13:37:48 +0200 Subject: [PATCH 1/4] Sweep DAG-function annotations to ScalarFloat/ScalarInt/Age; merge u_alive This batches several follow-ups from the consumption-floor refactor: * Annotation sweep: every DAG function's scalar fixed_params move from Python `float`/`int`/`bool` to pylcm's post-#345 canonical scalar aliases (`ScalarFloat`, `ScalarInt`, `ScalarBool`, `Age`). Internal test helpers and analytics functions (`aime_to_pia`, `pia_to_aime`, `compute_hcc_insurer_table`) keep Python literal contracts. * Tests: switch every callsite that passed Python literals to a 0-d jax scalar, matching the post-boundary contract. * `u_alive` merges what were `u_canwork` / `u_forcedout`; leisure split three ways (`leisure_canwork_retiree_or_nongroup`, `leisure_canwork_tied`, `leisure_forcedout`). * `u_dead` deleted; `preferences.bequest` registers as the dead-regime utility directly. * `fixed_cost_of_work` extracted as a DAG function used by both `leisure_canwork_*` consumers. * `reference_hours` raised to 1000.0 (lands on the working-hours grid). * `_HCC_RHO` and `_WAGE_RHO` hoisted to module-level constants in `baseline/regimes/_common.py`. * `consumption_dollars_grid`: Python-side guard rejects collapses where the married floor would sit at or beyond the third gridpoint. * `MAX_CONSUMPTION_DOLLARS` carries a TODO pointing to pylcm#348 for routing through `fixed_params` once that lands. * `benchmark_params.pkl` regenerated to reflect the `average_consumption_equiv` key + `reference_hours=1000.0`. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_benchmark_data/benchmark_params.pkl | Bin 68327 -> 68325 bytes src/aca_model/aca/health_insurance.py | 15 +- src/aca_model/agent/assets_and_income.py | 5 +- src/aca_model/agent/labor_market.py | 5 +- src/aca_model/agent/preferences.py | 149 +++++++----------- src/aca_model/baseline/health_insurance.py | 91 +++++------ src/aca_model/baseline/regimes/_common.py | 32 ++-- src/aca_model/baseline/regimes/_nongroup.py | 8 +- src/aca_model/baseline/regimes/_retiree.py | 8 +- src/aca_model/baseline/regimes/_tied.py | 6 +- src/aca_model/consumption_dollars_grid.py | 24 ++- src/aca_model/environment/pensions.py | 4 +- src/aca_model/environment/social_security.py | 59 ++++--- tests/test_baseline_equivalence.py | 12 +- tests/test_health_insurance.py | 42 ++--- tests/test_model_components.py | 77 ++++----- tests/test_pension_integration.py | 2 +- tests/test_pensions.py | 4 +- tests/test_preferences.py | 112 ++++++------- tests/test_social_security.py | 25 +-- tests/test_ss_benefit_integration.py | 4 +- 21 files changed, 335 insertions(+), 349 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index d0d9c1dac7cf1641c10007edf3ad3865dc0d0436..650c3902fc840003a043e1d0a5d24aa41e09120f 100644 GIT binary patch delta 104 zcmV-u0GI#gl?3IL1O$Kul>)H@(gqqCVRmJ5VP|DuV{dMAb!~8TX>V>{WpQa4WnX1(WN&wEWo~qoM?kKv FloatND: """Compute primary OOP costs with ACA cost-sharing reductions. diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 92e9abb..e517fb5 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -9,12 +9,13 @@ ContinuousAction, ContinuousState, FloatND, + ScalarFloat, ) def capital_income( assets: ContinuousState, - rate_of_return: float, + rate_of_return: ScalarFloat, ) -> FloatND: """Compute capital income from assets.""" return assets * rate_of_return @@ -36,7 +37,7 @@ def cash_on_hand( def consumption_dollars_floor( - consumption_equiv_floor: float, + consumption_equiv_floor: ScalarFloat, equivalence_scale: FloatND, ) -> FloatND: """Per-household $-floor on consumption.""" diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 6e36947..14b9d65 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -12,6 +12,7 @@ FloatND, IntND, Period, + ScalarFloat, ) @@ -52,8 +53,8 @@ def income( good_health: IntND, log_ft_wage_mean: FloatND, log_ft_wage_std: FloatND, - adj_wage_hours_exp: float, - adj_wage_hours_int: float, + adj_wage_hours_exp: ScalarFloat, + adj_wage_hours_int: ScalarFloat, ) -> FloatND: """Labor income with wage-hours interaction (French & Jones 2011). diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 5c08541..575ebcd 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -6,12 +6,15 @@ import jax.numpy as jnp from lcm import categorical from lcm.typing import ( + Age, BoolND, ContinuousAction, ContinuousState, DiscreteState, FloatND, IntND, + ScalarFloat, + ScalarInt, ) from aca_model.agent.labor_market import LaggedLaborSupply @@ -46,7 +49,7 @@ def positive_leisure(leisure: FloatND) -> BoolND: return leisure > 0 -def equivalence_scale(is_married: IntND, exponent: float) -> FloatND: +def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND: """Return the equivalence scale for household size adjustment. Single (is_married=False) → 1.0, married (is_married=True) → 2^exponent. @@ -54,69 +57,70 @@ def equivalence_scale(is_married: IntND, exponent: float) -> FloatND: return jnp.where(is_married, 2.0**exponent, 1.0) -def leisure( +def fixed_cost_of_work( + age: Age, + fixed_cost_of_work_intercept: ScalarFloat, + fixed_cost_of_work_age_trend: ScalarFloat, + reference_age: ScalarInt, +) -> ScalarFloat: + """Age-dependent fixed cost of working (intercept + trend slope on age).""" + return fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( + age - reference_age + ) + + +def leisure_canwork_retiree_or_nongroup( working_hours_value: FloatND, - age: int, good_health: IntND, lagged_labor_supply: DiscreteState, - time_endowment: float, - leisure_cost_of_bad_health: float, - fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - labor_force_reentry_cost: float, - reference_age: int, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, + fixed_cost_of_work: ScalarFloat, + labor_force_reentry_cost: ScalarFloat, ) -> FloatND: - """Compute leisure given hours worked and state variables. + """Compute leisure for canwork retiree / nongroup regimes. - Fixed cost of work is age-dependent: intercept + trend * (age - reference_age). Reentry cost applies when returning to work after not working last period. - Working status is derived from working_hours_value > 0. """ - is_working = working_hours_value > 0.0 health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - - fixed_cost = fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( - age - reference_age - ) reentry_cost = jnp.where( lagged_labor_supply == LaggedLaborSupply.did_not_work, labor_force_reentry_cost, 0.0, ) work_loss = jnp.where( - is_working, working_hours_value + fixed_cost + reentry_cost, 0.0 + working_hours_value > 0.0, + working_hours_value + fixed_cost_of_work + reentry_cost, + 0.0, ) return time_endowment - health_loss - work_loss -def leisure_tied( +def leisure_canwork_tied( working_hours_value: FloatND, - age: int, good_health: IntND, - time_endowment: float, - leisure_cost_of_bad_health: float, - fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - reference_age: int, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, + fixed_cost_of_work: ScalarFloat, ) -> FloatND: - """Compute leisure for tied regimes (no reentry cost, no lagged_labor_supply).""" + """Compute leisure for canwork tied regimes. + + No need to consider reentry costs. + """ health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - fixed_cost = fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * ( - age - reference_age - ) work_loss = jnp.where( - working_hours_value > 0.0, working_hours_value + fixed_cost, 0.0 + working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0 ) return time_endowment - health_loss - work_loss -def leisure_retired( +def leisure_forcedout( good_health: IntND, - time_endowment: float, - leisure_cost_of_bad_health: float, + time_endowment: ScalarFloat, + leisure_cost_of_bad_health: ScalarFloat, ) -> FloatND: - """Compute leisure for retired agents (no work).""" + """Compute leisure for forcedout regimes (no work).""" health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) return time_endowment - health_loss @@ -129,14 +133,18 @@ def consumption_equiv( return consumption_dollars / equivalence_scale -def u_can_work( +def u_alive( consumption_equiv: FloatND, leisure: FloatND, consumption_weight: FloatND, coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility for canwork regimes: CES over consumption and leisure.""" + """Within-period utility for every non-dead regime: CES over consumption and leisure. + + `leisure` is a DAG input — supplied per-regime by `leisure_canwork_retiree_or_nongroup`, + `leisure_canwork_tied`, or `leisure_forcedout`. + """ composite = consumption_equiv**consumption_weight * leisure ** ( 1.0 - consumption_weight ) @@ -152,49 +160,6 @@ def u_can_work( return u * utility_scale_factor -def u_cannot_work( - consumption_equiv: FloatND, - good_health: IntND, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, - time_endowment: float, - leisure_cost_of_bad_health: float, -) -> FloatND: - """Within-period utility for forcedout regimes (no work, retired leisure).""" - leisure = leisure_retired( - good_health=good_health, - time_endowment=time_endowment, - leisure_cost_of_bad_health=leisure_cost_of_bad_health, - ) - return u_can_work( - consumption_equiv=consumption_equiv, - leisure=leisure, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - utility_scale_factor=utility_scale_factor, - ) - - -def u_dead( - assets: ContinuousState, - bequest_shifter: float, - scaled_bequest_weight: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, -) -> FloatND: - """Terminal bequest utility for the dead regime.""" - return bequest( - assets=assets, - bequest_shifter=bequest_shifter, - scaled_bequest_weight=scaled_bequest_weight, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - utility_scale_factor=utility_scale_factor, - ) - - def consumption_weight( consumption_weights: FloatND, pref_type: DiscreteState, @@ -233,16 +198,16 @@ def discount_factor( def utility_scale_factor( - average_consumption_dollars: float, + average_consumption_equiv: ScalarFloat, consumption_weight: FloatND, coefficient_rra: FloatND, - time_endowment: float, - fixed_cost_of_work_intercept: float, - reference_hours: float, + time_endowment: ScalarFloat, + fixed_cost_of_work_intercept: ScalarFloat, + reference_hours: ScalarFloat, ) -> FloatND: """Compute the scale factor so utility is approximately 1 at typical values.""" average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept - u_cons = average_consumption_dollars**consumption_weight + u_cons = average_consumption_equiv**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) one_minus_rra = jnp.where( @@ -257,12 +222,12 @@ def utility_scale_factor( def scaled_bequest_weight( - bequest_weight: float, - consumption_weight: float, - coefficient_rra: float, - time_endowment: float, - time_discount_factor: float, - rate_of_return: float, + bequest_weight: ScalarFloat, + consumption_weight: ScalarFloat, + coefficient_rra: ScalarFloat, + time_endowment: ScalarFloat, + time_discount_factor: ScalarFloat, + rate_of_return: ScalarFloat, ) -> FloatND: """Transform raw bequest weight into the form used in the bequest function. @@ -283,8 +248,8 @@ def scaled_bequest_weight( def bequest( assets: ContinuousState, - bequest_shifter: float, - scaled_bequest_weight: float, + bequest_shifter: ScalarFloat, + scaled_bequest_weight: ScalarFloat, consumption_weight: FloatND, coefficient_rra: FloatND, utility_scale_factor: FloatND, diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 3732d6d..8371a29 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -15,6 +15,7 @@ import jax.numpy as jnp from lcm import categorical from lcm.typing import ( + Age, BoolND, ContinuousState, DiscreteAction, @@ -22,6 +23,8 @@ FloatND, IntND, Period, + ScalarBool, + ScalarFloat, ) from aca_model.agent.labor_market import LaborSupply @@ -47,8 +50,8 @@ def countable_income( spousal_income_amounts: FloatND, ss_benefit: FloatND, pension_benefit: FloatND, - ssi_ignored_overall: float, - ssi_ignored_earned: float, + ssi_ignored_overall: ScalarFloat, + ssi_ignored_earned: ScalarFloat, ) -> FloatND: """Compute countable income for SSI eligibility test. @@ -69,7 +72,7 @@ def is_ssi_eligible( assets: ContinuousState, countable_income: FloatND, spousal_income: DiscreteState, - gets_medicare: bool, + gets_medicare: ScalarBool, ssi_assets_test: FloatND, ssi_maximum_benefit: FloatND, ) -> BoolND: @@ -99,21 +102,21 @@ def ssi_benefit( def premium( - age: int, + age: Age, good_health: IntND, is_married: IntND, labor_supply: DiscreteAction, buy_private: DiscreteAction, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_works: float, - premium_married_works: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_works: ScalarFloat, + premium_married_works: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for canwork regimes. @@ -141,20 +144,20 @@ def premium( def premium_insured( - age: int, + age: Age, good_health: IntND, is_married: IntND, labor_supply: DiscreteAction, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_works: float, - premium_married_works: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_works: ScalarFloat, + premium_married_works: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for canwork regimes without `buy_private`. @@ -178,17 +181,17 @@ def premium_insured( def premium_retired( - age: int, + age: Age, good_health: IntND, is_married: IntND, - premium_intercept: float, - premium_age: int, - premium_age_sq: float, - premium_age_cub: float, - premium_predicted_hcc: float, - premium_good_health: float, - premium_married: float, - premium_minimum: float, + premium_intercept: ScalarFloat, + premium_age: ScalarFloat, + premium_age_sq: ScalarFloat, + premium_age_cub: ScalarFloat, + premium_predicted_hcc: ScalarFloat, + premium_good_health: ScalarFloat, + premium_married: ScalarFloat, + premium_minimum: ScalarFloat, predicted_hcc_insurer: FloatND, ) -> FloatND: """Compute health insurance premium for forcedout regimes. @@ -209,9 +212,9 @@ def premium_retired( def oop_costs( total_health_costs: FloatND, - deductible: float | FloatND, - coinsurance_rate: float | FloatND, - oop_max: float | FloatND, + deductible: ScalarFloat | FloatND, + coinsurance_rate: ScalarFloat | FloatND, + oop_max: ScalarFloat | FloatND, ) -> FloatND: """Compute out-of-pocket health care costs. @@ -228,9 +231,9 @@ def oop_costs( def primary_oop( total_health_costs: FloatND, buy_private: DiscreteAction, - deductible: float, - coinsurance_rate: float, - oop_max: float, + deductible: ScalarFloat, + coinsurance_rate: ScalarFloat, + oop_max: ScalarFloat, ) -> FloatND: """Compute primary OOP costs. @@ -272,9 +275,9 @@ def target_his( def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, - deductible_medicaid: float, - coinsurance_rate_medicaid: float, - oop_max_medicaid: float, + deductible_medicaid: ScalarFloat, + coinsurance_rate_medicaid: ScalarFloat, + oop_max_medicaid: ScalarFloat, ) -> FloatND: """Apply Medicaid cost-sharing on top of primary insurance OOP costs. @@ -352,7 +355,7 @@ def total_costs( log_std: FloatND, hcc_persistent: ContinuousState, hcc_transitory: ContinuousState, - std_xsect_persistent: float, + std_xsect_persistent: ScalarFloat, ) -> FloatND: """Compute total health care costs from log-normal model. diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index a2e3a13..0b83321 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -213,6 +213,13 @@ class Grids: can vary across optimizer iterations without re-importing this module). """ +# AR(1) persistence of the Rouwenhorst shocks. Calibrated once; not +# routed through fixed_params because they shape the grid topology +# rather than feed any DAG function. The Rouwenhorst innovation std is +# `sqrt(1 - rho**2)` so the grid carries unit unconditional variance. +_HCC_RHO = 0.925 +_WAGE_RHO = 0.977 + def build_grids( *, @@ -240,7 +247,6 @@ def build_grids( # grid to have unconditional variance 1, the Rouwenhorst innovation # std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc, # 0.5627 for wage) would mis-scale the grid. - _WAGE_RHO = 0.977 wage_res = lcm.shocks.ar1.Rouwenhorst( n_points=grid_config.n_wage_res_gridpoints, rho=_WAGE_RHO, @@ -277,9 +283,6 @@ def build_grids( ) -_HCC_RHO = 0.925 - - def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. @@ -442,7 +445,7 @@ def build_dead_regime(grids: Grids) -> Regime: return Regime( transition=None, functions={ - "utility": preferences.u_dead, + "utility": preferences.bequest, "consumption_weight": preferences.consumption_weight, "coefficient_rra": preferences.coefficient_rra, "utility_scale_factor": preferences.utility_scale_factor, @@ -468,18 +471,13 @@ def select_ss_benefit(spec: RegimeSpec) -> Callable[..., Any]: return social_security.benefit_inelig_pre65 -def select_utility(spec: RegimeSpec) -> Callable[..., Any]: - """Select the utility function for a regime.""" - if spec["canwork"] != "canwork": - return preferences.u_cannot_work - return preferences.u_can_work - - def _select_leisure(spec: RegimeSpec) -> Callable[..., Any]: - """Select the leisure function for a canwork regime.""" + """Select the leisure function for a non-dead regime.""" + if spec["canwork"] == "forcedout": + return preferences.leisure_forcedout if spec["his"] == "tied": - return preferences.leisure_tied - return preferences.leisure + return preferences.leisure_canwork_tied + return preferences.leisure_canwork_retiree_or_nongroup def build_common_functions(spec: RegimeSpec) -> dict: @@ -503,9 +501,11 @@ def build_common_functions(spec: RegimeSpec) -> dict: if can_work: functions["working_hours_value"] = labor_market.working_hours_value - functions["leisure"] = _select_leisure(spec) functions["labor_income"] = labor_market.income + functions["fixed_cost_of_work"] = preferences.fixed_cost_of_work + functions["leisure"] = _select_leisure(spec) + functions["utility"] = preferences.u_alive functions["capital_income"] = assets_and_income.capital_income # spousal_income_amounts is a lookup table param, not a DAG function functions["is_married"] = labor_market.is_married diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index a723b44..730dcc4 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -7,7 +7,7 @@ from collections.abc import Callable from lcm import MarkovTransition, Regime -from lcm.typing import DiscreteAction, FloatND, Period +from lcm.typing import Age, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -25,7 +25,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -41,7 +40,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, survival_probs: FloatND, @@ -65,7 +64,7 @@ def _make_transition_forcedout( """ def transition( - age: int, + age: Age, period: Period, survival_probs: FloatND, ) -> FloatND: @@ -80,7 +79,6 @@ def _build_functions(spec: RegimeSpec) -> dict: can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index 4f16faa..4cb52d9 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from lcm import MarkovTransition, Regime -from lcm.typing import BoolND, DiscreteAction, FloatND, Period +from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -26,7 +26,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -43,7 +42,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, is_medicaid_eligible: BoolND, @@ -72,7 +71,7 @@ def _make_transition_forcedout( """ def transition( - age: int, + age: Age, period: Period, is_medicaid_eligible: BoolND, survival_probs: FloatND, @@ -92,7 +91,6 @@ def _build_functions(spec: RegimeSpec) -> dict: can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index df76fa4..c9eeecc 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -9,7 +9,7 @@ import jax.numpy as jnp from lcm import MarkovTransition, Regime -from lcm.typing import BoolND, DiscreteAction, FloatND, Period +from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period from aca_model.agent import assets_and_income, preferences from aca_model.agent.labor_market import LaborSupply @@ -27,7 +27,6 @@ make_targets, select_ss_benefit, select_target_for_age, - select_utility, ) from aca_model.environment import pensions @@ -44,7 +43,7 @@ def _make_transition_canwork( """ def transition( - age: int, + age: Age, period: Period, labor_supply: DiscreteAction, is_medicaid_eligible: BoolND, @@ -70,7 +69,6 @@ def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a tied regime.""" functions = build_common_functions(spec) - functions["utility"] = select_utility(spec) functions["ss_benefit"] = select_ss_benefit(spec) # his and gets_medicare are fixed params (constants per regime), diff --git a/src/aca_model/consumption_dollars_grid.py b/src/aca_model/consumption_dollars_grid.py index 7487fd8..5de175d 100644 --- a/src/aca_model/consumption_dollars_grid.py +++ b/src/aca_model/consumption_dollars_grid.py @@ -39,7 +39,7 @@ def inject_consumption_dollars_points( Walks every regime, reads its `consumption_dollars` action grid, and writes `params[regime_name]["consumption_dollars"] = {"points": }`. - The lower two gridpoints are the single and married unequiv + The lower two gridpoints are the single and married Dollar-valued transfer floors; the rest are geomspaced from the married floor up to `MAX_CONSUMPTION_DOLLARS`. @@ -101,7 +101,7 @@ def _compute_consumption_dollars_points( ) -> Array: """Return log-spaced consumption_dollars gridpoints with both floors pinned. - Single and married households face different unequiv (in-$) floors + Single and married households face different Dollar-valued floors (`consumption_equiv_floor` and the married-scaled twin respectively). Both must land exactly on the action grid so the borrowing constraint's `max(cash_on_hand, floor)` kink boundary is @@ -111,14 +111,26 @@ def _compute_consumption_dollars_points( `MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly increasing. """ - married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent + married_dollar_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent tail = jnp.geomspace( - married_unequiv_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 + married_dollar_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 ) pts = jnp.concatenate([consumption_equiv_floor[None], tail]) # `jnp.geomspace` returns `start * r^0` for the first tail element, - # which mathematically equals `married_unequiv_floor` but drifts by + # which mathematically equals `married_dollar_floor` but drifts by # sub-ULP on some XLA backends. Pin the slot back to the exact # arithmetic value so the borrowing-constraint kink boundary at the # married floor is exactly representable. - return pts.at[1].set(married_unequiv_floor) + pts = pts.at[1].set(married_dollar_floor) + # The runtime params are concrete, not JIT-traced — a Python `if` + # is fine. Guard against a degenerate grid where the geomspace step + # is too small for the next point to clear `married_dollar_floor`. + if not float(married_dollar_floor) < float(pts[2]): + msg = ( + f"consumption_dollars grid is not strictly increasing at the " + f"married-floor kink: pts[1]={float(married_dollar_floor):.6g}, " + f"pts[2]={float(pts[2]):.6g}. Either `MAX_CONSUMPTION_DOLLARS` " + f"is too close to the married floor or `n_points` is too small." + ) + raise ValueError(msg) + return pts diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index eef72d4..cb03a6c 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import ContinuousState, FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period, ScalarFloat def benefit( @@ -131,7 +131,7 @@ def wealth_next_before_adjustment( pension_wealth: FloatND, pension_benefit: FloatND, pension_accrual: FloatND, - rate_of_return: float, + rate_of_return: ScalarFloat, unconditional_survival_prob: FloatND, period: Period, ) -> FloatND: diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index e3574cf..5863812 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -9,7 +9,16 @@ import jax.numpy as jnp from lcm import categorical -from lcm.typing import ContinuousState, DiscreteAction, DiscreteState, FloatND, Period +from lcm.typing import ( + Age, + ContinuousState, + DiscreteAction, + DiscreteState, + FloatND, + Period, + ScalarFloat, + ScalarInt, +) from aca_model.agent.labor_market import LaborSupply @@ -77,17 +86,17 @@ def benefit_forced( def benefit_choose_post65( pia: FloatND, - age: int, + age: Age, period: Period, claim_ss: DiscreteAction, claimed_ss: DiscreteState, labor_supply: DiscreteAction, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, ) -> FloatND: """SS benefit for post-65, ss=choose: SS if claiming, 0 otherwise.""" ss = jnp.maximum(claim_ss, claimed_ss) @@ -110,7 +119,7 @@ def benefit_choose_post65( def benefit_choose_pre65( pia: FloatND, ssdi_pia: FloatND, - age: int, + age: Age, period: Period, claim_ss: DiscreteAction, claimed_ss: DiscreteState, @@ -118,11 +127,11 @@ def benefit_choose_pre65( labor_supply: DiscreteAction, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, - ssdi_substantial_gainful_activity: float, + earnings_test_repealed_age: ScalarInt, + ssdi_substantial_gainful_activity: ScalarFloat, ) -> FloatND: """SS benefit for pre-65, ss=choose: SS if claiming, SSDI if disabled, else 0.""" ss = jnp.maximum(claim_ss, claimed_ss) @@ -160,7 +169,7 @@ def benefit_inelig_pre65( ssdi_pia: FloatND, health: DiscreteState, labor_income: FloatND, - ssdi_substantial_gainful_activity: float, + ssdi_substantial_gainful_activity: ScalarFloat, ) -> FloatND: """SS benefit for pre-65, ss=inelig: SSDI if disabled, else 0.""" is_disabled = health == 0 @@ -200,16 +209,16 @@ def benefit_withheld_fraction( def _apply_benefit_rules( *, pia: FloatND, - age: int, + age: Age, period: Period, ss: FloatND, work: FloatND, labor_income: FloatND, early_ret_adjustment: FloatND, - normal_retirement_age: int, + normal_retirement_age: ScalarInt, earnings_test_threshold: FloatND, earnings_test_fraction: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, ) -> FloatND: """Apply early retirement adjustment and earnings test to PIA. @@ -246,16 +255,16 @@ def next_aime( aime: ContinuousState, labor_income: FloatND, period: Period, - age: int, + age: Age, benefit_withheld_fraction: FloatND, earnings_test_credited_back: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, pia_table: FloatND, pia_aime_grid: FloatND, - aime_accrual_factor: float, - aggregate_wage_growth: float, - aime_last_age_with_indexing: int, - aime_kink_2: float, + aime_accrual_factor: ScalarFloat, + aggregate_wage_growth: ScalarFloat, + aime_last_age_with_indexing: ScalarInt, + aime_kink_2: ScalarFloat, ratio_lowest_earnings: FloatND, ) -> ContinuousState: """Compute next period's AIME given labor earnings. @@ -306,19 +315,19 @@ def next_aime_disabled( aime: ContinuousState, labor_income: FloatND, period: Period, - age: int, + age: Age, health: DiscreteState, benefit_withheld_fraction: FloatND, earnings_test_credited_back: FloatND, - earnings_test_repealed_age: int, + earnings_test_repealed_age: ScalarInt, pia_table: FloatND, pia_aime_grid: FloatND, - aime_accrual_factor: float, - aggregate_wage_growth: float, - aime_last_age_with_indexing: int, - aime_kink_2: float, + aime_accrual_factor: ScalarFloat, + aggregate_wage_growth: ScalarFloat, + aime_last_age_with_indexing: ScalarInt, + aime_kink_2: ScalarFloat, ratio_lowest_earnings: FloatND, - medicare_age: int, + medicare_age: ScalarInt, di_dropout_scale: FloatND, di_dropout_next_period_ratio: FloatND, ) -> ContinuousState: diff --git a/tests/test_baseline_equivalence.py b/tests/test_baseline_equivalence.py index 7d18e5c..5e68e86 100644 --- a/tests/test_baseline_equivalence.py +++ b/tests/test_baseline_equivalence.py @@ -75,9 +75,9 @@ def test_aca_cash_on_hand_matches_baseline_when_neutral() -> None: def test_baseline_primary_oop_no_cost_sharing_scale() -> None: """Baseline primary_oop applies raw deductible/coinsurance/oop_max.""" costs = jnp.array(5000.0) - deductible = 500.0 - coinsurance = 0.2 - oop_max_val = 3000.0 + deductible = jnp.asarray(500.0) + coinsurance = jnp.asarray(0.2) + oop_max_val = jnp.asarray(3000.0) result = health_insurance.primary_oop( total_health_costs=costs, buy_private=jnp.array(BuyPrivate.yes), @@ -97,9 +97,9 @@ def test_baseline_primary_oop_no_cost_sharing_scale() -> None: def test_aca_primary_oop_scaled_reduces_costs() -> None: """ACA primary_oop with scale < 1.0 reduces OOP costs.""" costs = jnp.array(5000.0) - deductible = 500.0 - coinsurance = 0.2 - oop_max_val = 3000.0 + deductible = jnp.asarray(500.0) + coinsurance = jnp.asarray(0.2) + oop_max_val = jnp.asarray(3000.0) oop_full = aca_hi.primary_oop( total_health_costs=costs, cost_sharing_scale=jnp.array(1.0), diff --git a/tests/test_health_insurance.py b/tests/test_health_insurance.py index cd89c9f..06a23b7 100644 --- a/tests/test_health_insurance.py +++ b/tests/test_health_insurance.py @@ -18,7 +18,7 @@ def test_ssi_eligible_assets_too_high() -> None: assets=jnp.array(5000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -30,7 +30,7 @@ def test_ssi_eligible_income_too_high() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(9000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -42,7 +42,7 @@ def test_ssi_eligible_no_medicare() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=False, + gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -54,7 +54,7 @@ def test_ssi_eligible_all_pass() -> None: assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), spousal_income=jnp.array(0), - gets_medicare=True, + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -148,20 +148,20 @@ def test_compute_table_uniform_transition(table_inputs: dict) -> None: _PREMIUM_KWARGS: dict = { - "age": 60, + "age": jnp.int32(60), "good_health": jnp.array(True), "is_married": jnp.array(False), "labor_supply": jnp.array(LaborSupply.h2000), - "premium_intercept": 1000.0, - "premium_age": 0, - "premium_age_sq": 0.0, - "premium_age_cub": 0.0, - "premium_predicted_hcc": 0.0, - "premium_good_health": 0.0, - "premium_married": 0.0, - "premium_works": 0.0, - "premium_married_works": 0.0, - "premium_minimum": 500.0, + "premium_intercept": jnp.asarray(1000.0), + "premium_age": jnp.asarray(0.0), + "premium_age_sq": jnp.asarray(0.0), + "premium_age_cub": jnp.asarray(0.0), + "premium_predicted_hcc": jnp.asarray(0.0), + "premium_good_health": jnp.asarray(0.0), + "premium_married": jnp.asarray(0.0), + "premium_works": jnp.asarray(0.0), + "premium_married_works": jnp.asarray(0.0), + "premium_minimum": jnp.asarray(500.0), "predicted_hcc_insurer": jnp.array(0.0), } @@ -187,9 +187,9 @@ def test_primary_oop_insured_applies_deductible_coinsurance() -> None: result = health_insurance.primary_oop( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), ) expected = 500.0 + (10000.0 - 500.0) * 0.2 # 2400 assert jnp.isclose(result, expected, atol=ATOL) @@ -200,8 +200,8 @@ def test_primary_oop_uninsured_equals_total_costs() -> None: result = health_insurance.primary_oop( total_health_costs=total, buy_private=jnp.array(BuyPrivate.no), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), ) assert jnp.isclose(result, total) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 5b7df6a..b3569c5 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -7,77 +7,68 @@ def test_equivalence_scale_single() -> None: - result = preferences.equivalence_scale(jnp.array(False), 0.7) + result = preferences.equivalence_scale(jnp.array(False), jnp.asarray(0.7)) assert jnp.isclose(result, 1.0) def test_equivalence_scale_married() -> None: - result = preferences.equivalence_scale(jnp.array(True), 0.7) + result = preferences.equivalence_scale(jnp.array(True), jnp.asarray(0.7)) assert jnp.isclose(result, 2.0**0.7) def test_leisure_not_working() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(0.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) assert jnp.isclose(result, 5000.0) def test_leisure_working_good_health() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(1), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) - # 5000 - 0 (good health) - (2000 + 100 + 5*(60-50) + 0 (lagged=1)) - expected = 5000.0 - 2000.0 - 100.0 - 50.0 + # 5000 - 0 (good health) - (2000 + 150 + 0 (lagged=1)) + expected = 5000.0 - 2000.0 - 150.0 assert jnp.isclose(result, expected) def test_leisure_reentry_cost() -> None: - result = preferences.leisure( + result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - age=60, good_health=jnp.array(1.0), lagged_labor_supply=jnp.array(0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, - fixed_cost_of_work_intercept=100.0, - fixed_cost_of_work_age_trend=5, - labor_force_reentry_cost=200.0, - reference_age=50, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), + fixed_cost_of_work=jnp.asarray(150.0), + labor_force_reentry_cost=jnp.asarray(200.0), ) - expected = 5000.0 - 2000.0 - 100.0 - 50.0 - 200.0 + expected = 5000.0 - 2000.0 - 150.0 - 200.0 assert jnp.isclose(result, expected) def test_leisure_bad_health() -> None: - result = preferences.leisure_retired( + result = preferences.leisure_forcedout( good_health=jnp.array(0.0), - time_endowment=5000.0, - leisure_cost_of_bad_health=500.0, + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(500.0), ) assert jnp.isclose(result, 4500.0) def test_utility_positive_leisure() -> None: - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), @@ -88,7 +79,7 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), @@ -103,8 +94,8 @@ def test_utility_log_case() -> None: def test_bequest_positive_assets() -> None: result = preferences.bequest( assets=jnp.array(100000.0), - bequest_shifter=5000.0, - scaled_bequest_weight=0.5, + bequest_shifter=jnp.asarray(5000.0), + scaled_bequest_weight=jnp.asarray(0.5), consumption_weight=jnp.array(0.4), coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), @@ -115,8 +106,8 @@ def test_bequest_positive_assets() -> None: def test_bequest_zero_assets() -> None: result = preferences.bequest( assets=jnp.array(0.0), - bequest_shifter=5000.0, - scaled_bequest_weight=0.5, + bequest_shifter=jnp.asarray(5000.0), + scaled_bequest_weight=jnp.asarray(0.5), consumption_weight=jnp.array(0.4), coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), @@ -164,13 +155,13 @@ def test_next_aime_accrual() -> None: age=jnp.int32(55), benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=jnp.array([0.0, 711.9, 2115.1, 3015.1]), pia_aime_grid=jnp.array([0.0, 791.0, 4768.0, 8000.0]), - aime_accrual_factor=1 / 35, - aggregate_wage_growth=0.02, - aime_last_age_with_indexing=60, - aime_kink_2=8000.0, + aime_accrual_factor=jnp.asarray(1 / 35), + aggregate_wage_growth=jnp.asarray(0.02), + aime_last_age_with_indexing=jnp.int32(60), + aime_kink_2=jnp.asarray(8000.0), ratio_lowest_earnings=ratio, ) assert result > 1000.0 diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 0f6c07d..ae27d3f 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -12,7 +12,7 @@ from aca_model.environment import pensions ATOL = 0.01 -RATE_OF_RETURN = 0.03 +RATE_OF_RETURN = jnp.asarray(0.03) # Pension imputation coefficients — two HIS types with different intercepts. # HIS 0 (retiree): intercept = -50, HIS 1 (nongroup): intercept = -80. diff --git a/tests/test_pensions.py b/tests/test_pensions.py index 514ab8c..286c72a 100644 --- a/tests/test_pensions.py +++ b/tests/test_pensions.py @@ -141,7 +141,7 @@ def test_pension_wealth_next_accrual_only() -> None: lli = math.log(10000) prob = math.exp(0.1) / (1 + math.exp(0.1)) accrual = lli * 0.5 * prob * 10000 - r = 0.03 + r = jnp.asarray(0.03) result = pensions.wealth_next_before_adjustment( pension_wealth=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -157,7 +157,7 @@ def test_pension_wealth_next_with_benefit() -> None: lli = math.log(10000) prob = math.exp(0.1) / (1 + math.exp(0.1)) accrual = lli * 0.5 * prob * 10000 - r = 0.03 + r = jnp.asarray(0.03) result = pensions.wealth_next_before_adjustment( pension_wealth=jnp.array(3000.0), pension_benefit=jnp.array(2000.0), diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 1b5107f..8017635 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -7,16 +7,18 @@ from aca_model.agent import preferences -# Struct-ret preference parameters -CONSUMPTION_WEIGHT = 0.6 -TIME_DISCOUNT_FACTOR = 0.85 -TIME_ENDOWMENT = 5000.0 -FIXED_COST_INTERCEPT = 0.0 -AVERAGE_CONSUMPTION = 10000.0 -RATE_OF_RETURN = 0.01 -BEQUEST_WEIGHT = 0.02 -BEQUEST_SHIFTER = 500_000.0 -REFERENCE_HOURS = 1000.0 +# Struct-ret preference parameters. Tests call DAG functions directly, so +# every scalar fixed_param is supplied as a 0-d jax array (the type pylcm +# casts user-provided Python scalars to before passing them into the DAG). +CONSUMPTION_WEIGHT = jnp.asarray(0.6) +TIME_DISCOUNT_FACTOR = jnp.asarray(0.85) +TIME_ENDOWMENT = jnp.asarray(5000.0) +FIXED_COST_INTERCEPT = jnp.asarray(0.0) +AVERAGE_CONSUMPTION = jnp.asarray(10000.0) +RATE_OF_RETURN = jnp.asarray(0.01) +BEQUEST_WEIGHT = jnp.asarray(0.02) +BEQUEST_SHIFTER = jnp.asarray(500_000.0) +REFERENCE_HOURS = jnp.asarray(1000.0) # --- utility_scale_factor --- @@ -24,9 +26,9 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -36,9 +38,9 @@ def test_utility_scale_factor_crra() -> None: def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -53,7 +55,7 @@ def test_scaled_bequest_weight_positive() -> None: result = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -65,7 +67,7 @@ def test_scaled_bequest_weight_log() -> None: result = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=1.0, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -75,9 +77,9 @@ def test_scaled_bequest_weight_log() -> None: def test_scaled_bequest_weight_zero() -> None: result = preferences.scaled_bequest_weight( - bequest_weight=0.0, + bequest_weight=jnp.asarray(0.0), consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -90,18 +92,18 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) @@ -109,18 +111,18 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_can_work( + result = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) @@ -129,25 +131,25 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption_dollars should equal single utility.""" scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - single = preferences.u_can_work( + single = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) - married = preferences.u_can_work( + married = preferences.u_alive( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(single, married, rtol=1e-5) @@ -158,9 +160,9 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -168,7 +170,7 @@ def test_bequest_log_regression() -> None: bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=1.0, + coefficient_rra=jnp.asarray(1.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -176,9 +178,9 @@ def test_bequest_log_regression() -> None: result = preferences.bequest( assets=jnp.array(10000.0), bequest_shifter=BEQUEST_SHIFTER, - scaled_bequest_weight=bwt.item(), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(1.0), + scaled_bequest_weight=bwt, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) @@ -186,9 +188,9 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_dollars=AVERAGE_CONSUMPTION, - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + average_consumption_equiv=AVERAGE_CONSUMPTION, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -196,7 +198,7 @@ def test_bequest_crra_regression() -> None: bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, consumption_weight=CONSUMPTION_WEIGHT, - coefficient_rra=5.0, + coefficient_rra=jnp.asarray(5.0), time_endowment=TIME_ENDOWMENT, time_discount_factor=TIME_DISCOUNT_FACTOR, rate_of_return=RATE_OF_RETURN, @@ -204,9 +206,9 @@ def test_bequest_crra_regression() -> None: result = preferences.bequest( assets=jnp.array(10000.0), bequest_shifter=BEQUEST_SHIFTER, - scaled_bequest_weight=bwt.item(), - consumption_weight=jnp.array(CONSUMPTION_WEIGHT), - coefficient_rra=jnp.array(5.0), + scaled_bequest_weight=bwt, + consumption_weight=CONSUMPTION_WEIGHT, + coefficient_rra=jnp.asarray(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index b8ac44a..c4c704f 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -23,10 +23,11 @@ PIA_CONVERSION_RATE_2 = 0.15 PIA_KINK_0 = 5151.6 PIA_KINK_1 = 14359.9 -AIME_ACCRUAL_FACTOR = 0.025 -AGGREGATE_WAGE_GROWTH = 0.03 -AIME_LAST_AGE_WITH_INDEXING = 59 -SSDI_SGA = 12840.0 +AIME_ACCRUAL_FACTOR = jnp.asarray(0.025) +AGGREGATE_WAGE_GROWTH = jnp.asarray(0.03) +AIME_LAST_AGE_WITH_INDEXING = jnp.int32(59) +AIME_KINK_2_SCALAR = jnp.asarray(AIME_KINK_2) +SSDI_SGA = jnp.asarray(12840.0) PIA_PARAMS = { "aime_kink_0": AIME_KINK_0, @@ -59,8 +60,8 @@ DI_SCALE = jnp.array( compute_di_dropout_scale( pd.Series(_RATIO_NP), - AIME_ACCRUAL_FACTOR, - start_age=jnp.int32(0), + AIME_ACCRUAL_FACTOR.item(), + start_age=0, n_periods=100, ) ) @@ -135,7 +136,7 @@ def test_next_aime_indexing_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) expected = 1000 * 1.03 + (20000 - 0.2 * 1000 * 1.03) * 0.025 @@ -156,7 +157,7 @@ def test_next_aime_indexing_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 10000 * 1.03, atol=ATOL) @@ -176,7 +177,7 @@ def test_next_aime_no_indexing_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) expected = 1000 + (20000 - 0.4 * 1000) * 0.025 @@ -197,7 +198,7 @@ def test_next_aime_no_indexing_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 1000, atol=ATOL) @@ -217,7 +218,7 @@ def test_next_aime_cap_high_aime_high_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 39000, atol=ATOL) @@ -237,7 +238,7 @@ def test_next_aime_cap_high_aime_low_income() -> None: aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, - aime_kink_2=AIME_KINK_2, + aime_kink_2=AIME_KINK_2_SCALAR, ratio_lowest_earnings=RATIO, ) assert jnp.isclose(result, 39000, atol=ATOL) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 5e74e9a..81a1b61 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -64,7 +64,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), earnings_test_repealed_age=jnp.int32(66), - ssdi_substantial_gainful_activity=13560.0, + ssdi_substantial_gainful_activity=jnp.asarray(13560.0), ) benefit_not_working = social_security.benefit_choose_pre65( @@ -82,7 +82,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), earnings_test_repealed_age=jnp.int32(66), - ssdi_substantial_gainful_activity=13560.0, + ssdi_substantial_gainful_activity=jnp.asarray(13560.0), ) assert benefit_working < benefit_not_working From b731def3372380c8af306fb89a1ccad5167ea296 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 14:39:55 +0200 Subject: [PATCH 2/4] =?UTF-8?q?@categorical=20fields:=20`int`=20=E2=86=92?= =?UTF-8?q?=20`ScalarInt`=20(pylcm=20#349=20cascade)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors the pylcm-side change: every `@categorical`-decorated class moves its field annotations from `int` to `ScalarInt`, matching the runtime type the decorator now produces. Touched 13 classes across 6 files (`PrefType`, `BenchmarkPrefType`, `LaborSupply`, `LaggedLaborSupply`, `SpousalIncome`, `IsMarried`, `HealthWithDisability`, `Health`, `GoodHealth`, `BuyPrivate`, `HealthInsuranceState`, `ClaimedSS`, `RegimeId`). Hashability fix: `RegimeId.` is now a `jnp.int32` 0-d scalar (unhashable), so the four `id_to_name = {getattr(RegimeId, name): name ...}` sites in `baseline/regimes/_common.py` coerce keys with `int(...)`. `precompute_target_regimes` likewise wraps its `getattr(RegimeId, ...)` returns in `int(...)` so the precomputed mapping's values can serve as dict keys and `in`-set members in the per-target builders downstream. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/health.py | 16 +++--- src/aca_model/agent/labor_market.py | 25 +++++---- src/aca_model/agent/preferences.py | 10 ++-- src/aca_model/baseline/health_insurance.py | 11 ++-- src/aca_model/baseline/regimes/_common.py | 59 +++++++++++--------- src/aca_model/environment/social_security.py | 4 +- 6 files changed, 66 insertions(+), 59 deletions(-) diff --git a/src/aca_model/agent/health.py b/src/aca_model/agent/health.py index e0edf5f..66f04fe 100644 --- a/src/aca_model/agent/health.py +++ b/src/aca_model/agent/health.py @@ -6,28 +6,28 @@ import jax.numpy as jnp from lcm import categorical -from lcm.typing import DiscreteState, FloatND, IntND, Period +from lcm.typing import DiscreteState, FloatND, IntND, Period, ScalarInt @categorical(ordered=True) class HealthWithDisability: - disabled: int - bad: int - good: int + disabled: ScalarInt + bad: ScalarInt + good: ScalarInt @categorical(ordered=True) class Health: - bad: int - good: int + bad: ScalarInt + good: ScalarInt @categorical(ordered=True) class GoodHealth: """Derived categorical for good_health DAG output (0=no, 1=yes).""" - no: int - yes: int + no: ScalarInt + yes: ScalarInt def is_good_health_3(health: DiscreteState) -> IntND: diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 14b9d65..b260c43 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -13,29 +13,30 @@ IntND, Period, ScalarFloat, + ScalarInt, ) @categorical(ordered=True) class LaborSupply: - do_not_work: int - h1000: int - h1500: int - h2000: int - h2500: int + do_not_work: ScalarInt + h1000: ScalarInt + h1500: ScalarInt + h2000: ScalarInt + h2500: ScalarInt @categorical(ordered=False) class LaggedLaborSupply: - did_not_work: int - worked: int + did_not_work: ScalarInt + worked: ScalarInt @categorical(ordered=False) class SpousalIncome: - single: int - married_no_inc: int - married_has_inc: int + single: ScalarInt + married_no_inc: ScalarInt + married_has_inc: ScalarInt HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) @@ -89,8 +90,8 @@ def next_lagged_supply(labor_supply: DiscreteAction) -> DiscreteState: class IsMarried: """Derived categorical for is_married DAG output (0=no, 1=yes).""" - no: int - yes: int + no: ScalarInt + yes: ScalarInt def is_married(spousal_income: DiscreteState) -> IntND: diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 575ebcd..612896b 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -24,9 +24,9 @@ class PrefType: """Unobserved preference type for heterogeneity in estimation.""" - type_0: int - type_1: int - type_2: int + type_0: ScalarInt + type_1: ScalarInt + type_2: ScalarInt @categorical(ordered=False) @@ -40,8 +40,8 @@ class BenchmarkPrefType: measured. """ - type_0: int - type_1: int + type_0: ScalarInt + type_1: ScalarInt def positive_leisure(leisure: FloatND) -> BoolND: diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 8371a29..d0e9322 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -25,6 +25,7 @@ Period, ScalarBool, ScalarFloat, + ScalarInt, ) from aca_model.agent.labor_market import LaborSupply @@ -32,15 +33,15 @@ @categorical(ordered=False) class BuyPrivate: - no: int - yes: int + no: ScalarInt + yes: ScalarInt @categorical(ordered=False) class HealthInsuranceState: - retiree: int - tied: int - nongroup: int + retiree: ScalarInt + tied: ScalarInt + nongroup: ScalarInt def countable_income( diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 0b83321..836ff71 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -23,7 +23,7 @@ ) from lcm.grids.continuous import ContinuousGrid from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND, RegimeName, UserParams +from lcm.typing import BoolND, FloatND, RegimeName, ScalarInt, UserParams from aca_model.agent import ( assets_and_income, @@ -42,25 +42,25 @@ @categorical(ordered=False) class RegimeId: - retiree_nomc_inelig_canwork: int - tied_nomc_inelig_canwork: int - nongroup_nomc_inelig_canwork: int - retiree_dimc_inelig_canwork: int - nongroup_dimc_inelig_canwork: int - retiree_nomc_choose_canwork: int - tied_nomc_choose_canwork: int - nongroup_nomc_choose_canwork: int - retiree_dimc_choose_canwork: int - nongroup_dimc_choose_canwork: int - retiree_oamc_choose_canwork: int - tied_oamc_choose_canwork: int - nongroup_oamc_choose_canwork: int - retiree_oamc_forced_canwork: int - tied_oamc_forced_canwork: int - nongroup_oamc_forced_canwork: int - retiree_oamc_forced_forcedout: int - nongroup_oamc_forced_forcedout: int - dead: int + retiree_nomc_inelig_canwork: ScalarInt + tied_nomc_inelig_canwork: ScalarInt + nongroup_nomc_inelig_canwork: ScalarInt + retiree_dimc_inelig_canwork: ScalarInt + nongroup_dimc_inelig_canwork: ScalarInt + retiree_nomc_choose_canwork: ScalarInt + tied_nomc_choose_canwork: ScalarInt + nongroup_nomc_choose_canwork: ScalarInt + retiree_dimc_choose_canwork: ScalarInt + nongroup_dimc_choose_canwork: ScalarInt + retiree_oamc_choose_canwork: ScalarInt + tied_oamc_choose_canwork: ScalarInt + nongroup_oamc_choose_canwork: ScalarInt + retiree_oamc_forced_canwork: ScalarInt + tied_oamc_forced_canwork: ScalarInt + nongroup_oamc_forced_canwork: ScalarInt + retiree_oamc_forced_forcedout: ScalarInt + nongroup_oamc_forced_forcedout: ScalarInt + dead: ScalarInt class RegimeSpec(TypedDict): @@ -552,7 +552,12 @@ def build_common_functions(spec: RegimeSpec) -> dict: def precompute_target_regimes(spec: RegimeSpec) -> MappingProxyType[str, int]: - """Pre-compute target regime IDs for each next-age bracket.""" + """Pre-compute target regime IDs for each next-age bracket. + + Coerces each `RegimeId.` (`ScalarInt`, post-pylcm#349) to a + Python `int` so the returned mapping's values can serve as dict + keys and `in`-set members downstream. + """ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: for name, s in REGIME_SPECS.items(): @@ -562,8 +567,8 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: and s["ss"] == ss_val and s["canwork"] == canwork_val ): - return getattr(RegimeId, name) - return RegimeId.dead + return int(getattr(RegimeId, name)) + return int(RegimeId.dead) ng_his = "nongroup" if spec["his"] == "tied" else spec["his"] @@ -674,7 +679,7 @@ def _build_per_target_regime_assets( targets use the full `next_assets` with the pension correction. """ target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., FloatND]] = {} seen_ids: set[int] = set() @@ -701,7 +706,7 @@ def _build_per_target_regime_health( Cross-grid transitions (3->2) happen at the age-65 boundary. """ target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, MarkovTransition] = {} seen_ids: set[int] = set() @@ -737,7 +742,7 @@ def _build_per_target_regime_claimed_ss( return {} target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() @@ -778,7 +783,7 @@ def _build_per_target_regime_lagged_labor_supply( return {} target_regimes = precompute_target_regimes(spec) - id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + id_to_name = {int(getattr(RegimeId, name)): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index 5863812..8b655d1 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -25,8 +25,8 @@ @categorical(ordered=False) class ClaimedSS: - no: int - yes: int + no: ScalarInt + yes: ScalarInt def next_claimed_ss( From b807b286175e00f8a721f039f4afd657bbb10e2f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 14:53:03 +0200 Subject: [PATCH 3/4] ci: pin pylcm to feat/categorical-scalarint pending pylcm#350 Revert to @main once pylcm#350 lands so the cascade can be tested on this branch in the meantime. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 67c82fa..1148345 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm + - name: Install pylcm (feature branch — revert to @main once pylcm#350 merges) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@main" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/categorical-scalarint" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From c8193c439d9f239e434e07b9542aa93b68f75939 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 22 May 2026 16:26:52 +0200 Subject: [PATCH 4/4] consumption_dollars: route max_consumption_dollars through fixed_params (pylcm #351 cascade) (#11) Co-authored-by: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 4 +-- .../_benchmark_data/benchmark_params.pkl | Bin 68325 -> 68562 bytes src/aca_model/baseline/regimes/_common.py | 11 +----- src/aca_model/consumption_dollars_grid.py | 33 ++++++++++-------- tests/test_consumption_dollars_grid.py | 13 ++++--- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1148345..8104668 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm (feature branch — revert to @main once pylcm#350 merges) + - name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feat/categorical-scalarint" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 650c3902fc840003a043e1d0a5d24aa41e09120f..f7c505ef8fb94731af080e93df12a095b942dc27 100644 GIT binary patch delta 11859 zcmdT~3tW^{7U#~414uB4!)t~&hy$Xi<+_y$h{>-eG>vxMem?}viEpkj(7#3xxW{TO%))1|YUi&k*ZB?vY+g;srzi%GjH%~y^Sp4}t?z#8e zbN=_7bI(2VecK{CZISK$QJxf8Uh;U+BNc`!Q+Z{zzQRy6zoM+H#8_p{%L$vg#1H<+ zRrztWmF$fTzk5f0Zq%u}>mL2Xrk@T`68*EjXPK+bPnnrL^;Om8+2)4eN_|ODiN0J{ zRBF`EH#1vFaX>iQL6QQ}!fky5X`kwnd1X2WePKWv+QV)NI3z(`WYz+8Fw@EC3W2kR zVVG}7-GWH6F(?^bAxDG8qVw$4peGROCWhchTlZ^@?$>E|C3A@}GM5<>2)79SKe4Kc zxk9>vGhuXKh-h?-A{6=g*~jB>QXHWk?r!*1#YEJ_W`yoX(r)eTq%t@hroS;PMln*v zAd{qrxo0pZEDh#=I?QSQ9MSOU9_KGtPJrQClupCvh=vz>9Nrjl$I$a7g=QvCvyMTe zHDoTE9vLPCGDcA$;#KXD*Y2n-U>y0XNLXAP-J!BAZpb`i$s-GN)m25N%1WJ)t}x$8 zr|Lj^NTxc@7txv?VQeq*NRGs#3+%YM69p zwAOFAnaNyf)@YGds+FuTGdoB@{GBjgV|;A%R*!w&p?%ov@pVWNEU`2rq$CNaWcMj0 zetz|ZpmHQU;3#^DhKsJ+JP!~ltgl*Wp2uY47-%!C0KY4NR-kB@`Pd)Y9JV}?e3>{I zWs&%#OcY0ECS^#HC1hON$bp(|x@a~n%%)P|$+Xdkr_h_g|8N#NDY;1^F-XYF1#zKe zqM=VnNJ&IAIC+u;uVn8`l}nI;WNW(6X0mB&5(P!M69#ag?1l%thhFxw1!)_l=puSdAOj_^^m2qfGv>%`btZVuZ$Yp?Gii*X4iJHxppn|tTDSwyOav1}@l)a7 zLjyodxp2@rD6Kq&vUZ}p$y1KlC^~!$9x2;A%~UpG(grG9aaJT-IBFd{a?*}E30(wY z>2N1B%)n>)k(>ZlH6&zd{%}Pj5F88JO&-pPiC#|IjCcjT36Wk&*5v#u(5@r-AZJ9_ zbI?g^M}evc_$ErQ9bmMR%$)iF2Y9N&4tAUrkW3(FrY43vl**O!n}8XgBtf~VX;c@j z+=eh{4IL2u2OUHtw^}CCDaI3h?tR=^IT-m^s$nqB{Zd8n6w26XsWd!gyp1AUs$6;C zEl>;nMyVJ|XSMBT_+HXJP3>jvE=?Pyu&*7!2Zlr%0uxA6K_JO03MGbw0`kT^Fpz0fWLO74S#2l7XcMFN=fmmNVYn!T&8kj4ko^Y8hfk5Hm-(Z-OA!m24nt3oSw*R%R{$f< z@$2CMKSM9i!i7e|6Mzr#GQ*pMZuLrg;rYGOJcee#(y+OhB%k(?1~NWCH=DB7(B919=xoC{;&$#F2bD zcl=<&F0`b7*y#oayFw9mOMPHhP@0Kov`~`Cq5dQxC4k*tiV@mIJ~pd(!rGRg(5)h# zyJU2aS7Fa9vPPeZwzIG6Yg{-T6mdH0lhfV>Ip zB@We)d_y!BlKT)Sdq~l$DE5w`Ay=1-FL3m(pGAz5Fz7N2p4H1TuslI@@w|0sTD;aK$bM7VUJup+2^v`0v zPTvP$bRuACeF5|I3ZKAOgV;>~!xpWY-LFmD>(qS!rBwvVeqW#*XoLvG{?OlB}h z7=1;)TBWBJ5R_i_iiAB%+S|PO?fcs&po2u(mNBqG zd?6Cz8y`YkYWp<^k@k`sA+$GvgrC1KVhr&oZ?}2LDA+(QJ_ZliNLuzKLeh8k*jaYXD+r-yNi*A>_AJn}r#v0Ui|9xby)0wbzxuMg*!4JW ztwqg+P7s0U0H-p>e4awNI8he!lw~#wO$yH*>jKG^ws-ULrT1Dsf3EWMAZ2MkRc#@M zPN?`oT?aXJFd409FCDCw*KLGe4y8REp9fYmzBoebO{d7$$3D0jB|j1=c?OhJM<1tc zMtp+agh%iWw&z5QB${!l>!vYxG)VM!kTLQ&WqmHB)qZ}W-$%%>_tG7PMW(&CoG;4- z+iV3^+qo>4EH1^xgPr08Q7B86k{nHfpprM*=MZ82i|jl7G}I@v-mi!2!B6VbxRRa8 zCX#bCiL0;jB|G-X4~|Irnw`D7kA+VS;7r<0d36*|aZ|Iqlcxw}b}YP!w4YwoKk;dm zDmR`ow;T8G=$oD?)B0&mTC@_%#ikMy{IoTXsl=X(t;heQ4ZaRWl9x>$9IAXlwwo9_ zr$q3zL|82uFwM(e`N-s-Ueu5JtzE1^9BY!iMN5!)tK0|m#> zOC1^d=~$~qn-7Fdytw~c&}fue3Pgttyjss6KK<|kz1E#t<~JXue^#~LOZw>B;+%S` zT5pfN|CXo)hT%o6l+MgSwot3Mg~BQdyI=vVK$F>L(_GeJY)aOFu5p^?k6Q3Az93fd zn>~oXxJxt0@a7%E32#$FRJC3=zsNvZ&qWSSfmewTHZ{Z;Dh#E@3i8HzwTpoa8b`%q zyD&AVjk;>RQAdt`rW#~0yY)i{%%6Sk3iF`%LkHeJ{n>_hL+yo(<;{<|+n}1KDQ7(N zkFD+^`NCods}Q014eGe=KhKql!O-}1@r^skQ_lr!E;5d7Rve}oEQhNUefif(^L zp2Ws?ha)~winnDY;zTJmPMq<0&W}kb(Urzb_KAz$xEjq+ggfctQq8frP)b>w&E-l(dncLH%FX$Ve z7S;4VEqdl+s@2N}nbkRh>1I&O;pKb&YV-1S>nn;ebD=KwU_h0LA?zdADh@3yRXP}_&A;LN3Qu82EMA$%rpjsMsmR?Ly*^Cy&Bu>pkA zBZ?ps+e|7i$0%%|+c4j7dAwD00p`x4TYb{4=(oZt$8n-NvKwmTILvA;yyVV$9J`X| z-{Z8EALFJH?*dF;7s-$=&X4;>sL zd)oJw@E@EEuy6}S!aWJX!RKh@w9$wS^d_w0R1m}U=Ue`WqWxRigXp4>8(&0_u{U;F zev43;bcYtcyGlw%J1tEJMNfYPhI56#*Wdxa4j1RI@%BR-hpnI_ybanMXdMo;%#tGR zBbWf(1tj>6gUD}WNiaW2axW&aWFr(w&fgdy0lP&4Uf~2(MR(FxBR)%SLdd!-NfMN6 zg{0TgTHcbl;CW9kRE61{@eCO%#52M2fE1~{NV`mmG?v9Olx&$RLot@qQlzwKWN4!0 zXDL!yc1cmb1lLh_wKM?-*KV7w;NnEwLTu7dOc delta 11385 zcmd5?30PFu73R(itFq{@&ajS(U=+kGhBV?JF8wlM5>wkWmEa&F!&MfCL9~L{s00iy zr5;#L(=~DB3g_6O)>xiI~``*?R7Kv%ELJ%qZd~e%|-qJ@=h^ z&VQD7-kJA-RdKdivHPMQMEkdqP0;BdLAJw1|D>4Kh6=M~x}`C=yjY*FFD^0U7nzC+ zEyhid9>CU23P_ExuL)#p%=&_P1}A-8Kq_hFUJE!NBOS1@L>+8AZZzasYAh*6qw!Nn zO3suSyWsNyqsb{QYTzP5I-w~z%D(oJbL~&?esHw%TcgQ%-dI^PM~FOOt`TpCQ)HyNe6dt*vmxTsk;@D!Xtb;jLUSR zVE5YS&}0dlG{_yS?opD(gHutBX9i2uhz-+9ZlB~vinZalBZIbZ7Y4tRFqr9v!Jf$5 z1T5g~km+1*)LvfDVMz@(o&*I=&5uPiS&n3$o3 zMfEDtq7uESsJK)=FW;m$8!DWOmZ{rGE9Az;I(R=4%Vws=4OcHm*;4;xnTB>%(MG&L zAEB$6z;=J!9G4VglY?5tgmWT7i1+i;s5rbS*!#F<#?#zx;$KxEMdvbvK+w zII0pdk?EdW6bqmra(z>1{D^I(G=kf(e>*5(g&DNSQ_k8C!?J#Gucn3C_y^Bzk4#Y zP8tavDTQPsWS`MQ%U4+F4kYp5?8Fn=%7$HBRq6&gY3HIw?Ifff4yUCIU>xXPHXP(G zr5zerYsT~3X0pO!yqBdUeV)Y_N!Q>cbCf1LBH+eIV>H82Mc3m5Ld>}E5IZR?0)?@5 zp}Zha4%#UO`Z}I6wo_6jaG7_nCyI8Q;O5=31Eu7Jy@gGPq%zRsI5CZ$5N_o6w`n1R zq#|YLDkM1V+X*$qmml7 zMFuqH+{0VS$;gAD!tlPH4HY3!7-JVh4HqaGc8b_gIevltpjKj0cxWKo)&B6JzlP3< z>Y=@w4gs51S_QSF9Hxm9CP*poLMKbxMgs<-!>Kd?||nEnNWjm&Pj>GoiJ3hYdu83U{&)3+U&p zAL`9sfVzj77#oM~+vM=G) ze>=v#U^fmc-h=-3;9Xwg6{ol^cttQa4!klI9`L&S@Xdr*G?<2e%tu1rmi4 zBx)EE9wRIfNPajRtkZZx#g>@*iprDq-lF1m!>d50MS{vp9{l6u)nhPu4O$~jUSS_e zn0(@k?=9}0*8`dum#iiiF5gMGT=K-_PRK$oRp7!U2voO@pgDmr#)7{(ldejcy zgvBaut<~G)kjixId-<6pfohVUhqB@8r*5XQA&N_G@ith>uMb{E30{q!@SdF?wr??0G-yX*mg>#@;*K+$ipOILdob&tsQF+C$K}&E50w z{MOsaK1kj+^7;z#u|$YZJqV%PekTesb-RlY_x2HqUnD{({P~7Vuf-LYxue}yM#(h( z&^0VREYV+_Q&Rms8F*`l8>v68(;$ zekv7hW0NL&gdN1Mzi)7n?XfbM&IfAV93k}=cL^M=8l*pU^fef;w&Ih)gepbBK|Bz1kHDbiMLMg1J~D^~vMcTOwQu3y+cZdg8y&!}eA@R_vZ*wk4rEu^))UJa8ybvP4KxHt z9igFkpkZvNfR^KX`}D=jjWRhH)$7MIMZOaeoV@Ait?chAA0HC%p=q#kfau3X{N+u_1$`nTLXDDOF9{?yN^-o* z&uzVxCdn!87K)$lF!mg)MGl@@%a2Mhd|4w@CDnEj+*Kv*pIxz4C647Ds*=Ys4Kt;% ztZ*rgtpcTos^l4gB9_}(l~luob4#vW1->$QjBVhkOb$cZSMDm4g0K4GS_S>*TDJUH z4@9)cezw>en&-gb+wZ; zJnLcoe{NC?ZBM~^c;TB=QURZQGt!m>uO*Gt_P?#sb7>b2j?#`+IIm#G@eoKmZ?BoI z#Tl?V?Qdsrufp%U?7rNq1p@4e95POX_hWoHeB+@Dv7s)Y!-@-IY@X69%mG?k%rQA{ z5U`h=-XQf9{6z>~U-o`C#m$EU3t{{B;n+eMe|ZiJy|`wm20whzJ=EKS!567i&7HUy zCs%iGo7lb?@C}5f>NsT(uEocI8V`;iw#HvBA{y9wxpbH}`vmgs55o93GWQ)Y^U63f z1D?Dxl8oYxT=`5!##obri1LQXGgBgu9_7Jy#RF{8M1R8$;&#Oh_`@&jta}K_Y93FX zuujCMzkIt+21~Zbg*Cq$#1|kw(;mf7k!X{(i;$Sfe?}I3lWRSm(7)i|`o&iZ&hdzf zx}T5XoD-z2fz~Vq;lHg&IDi!P>uFSBamDPi%2IRw9Q|x_nJK?uQNDRmIqbO-4(qo> z$!z-%@nV1L8aat-{?MVy5Yi{v+a$7g@Up97_OV$L-OmnUL=RX)`$yabs7Mzn4hrXC0>?YBz$;SCACzIA4;hY78;6&t46sg=JMGhElU-inev Array: """Return log-spaced consumption_dollars gridpoints with both floors pinned. @@ -108,12 +113,12 @@ def _compute_consumption_dollars_points( a feasible action; otherwise sub-ULP drift can flip the `<=` comparison for subjects with very negative cash. The geomspace tail starts at the married floor and runs to - `MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly + `max_consumption_dollars` so the two pinned points stay strictly increasing. """ married_dollar_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent tail = jnp.geomspace( - married_dollar_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 + married_dollar_floor, max_consumption_dollars, num=n_points - 1 ) pts = jnp.concatenate([consumption_equiv_floor[None], tail]) # `jnp.geomspace` returns `start * r^0` for the first tail element, @@ -129,7 +134,7 @@ def _compute_consumption_dollars_points( msg = ( f"consumption_dollars grid is not strictly increasing at the " f"married-floor kink: pts[1]={float(married_dollar_floor):.6g}, " - f"pts[2]={float(pts[2]):.6g}. Either `MAX_CONSUMPTION_DOLLARS` " + f"pts[2]={float(pts[2]):.6g}. Either `max_consumption_dollars` " f"is too close to the married floor or `n_points` is too small." ) raise ValueError(msg) diff --git a/tests/test_consumption_dollars_grid.py b/tests/test_consumption_dollars_grid.py index 1f42e6f..2d452b0 100644 --- a/tests/test_consumption_dollars_grid.py +++ b/tests/test_consumption_dollars_grid.py @@ -22,20 +22,21 @@ rejects every action for the affected subjects. `_compute_consumption_dollars_points` therefore prepends the singles' -floor as `pts[0]`, runs `geomspace` from the married floor up to -`MAX_CONSUMPTION_DOLLARS` for the rest, and pins the geomspace start -back to the married floor exactly. Test those invariants directly. +floor as `pts[0]`, runs `geomspace` from the married floor up to the +caller-supplied `max_consumption_dollars` for the rest, and pins the +geomspace start back to the married floor exactly. Test those invariants +directly. """ import jax.numpy as jnp import pytest -from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS from aca_model.consumption_dollars_grid import _compute_consumption_dollars_points EXPONENT = 0.7 # production value (env_constants["exponent"]) SINGLE_FLOOR = 1597.0921419521899 # production value MARRIED_SCALE = 2.0**EXPONENT +MAX_CONSUMPTION_DOLLARS = 300_000.0 # production value (env_constants) @pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) @@ -46,6 +47,7 @@ def test_compute_consumption_dollars_points_first_equals_singles_floor( pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), + max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS), n_points=n_points, ) assert float(pts[0]) == SINGLE_FLOOR @@ -59,6 +61,7 @@ def test_compute_consumption_dollars_points_second_equals_married_floor( pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), + max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS), n_points=n_points, ) expected = float(jnp.asarray(SINGLE_FLOOR) * jnp.asarray(2.0) ** EXPONENT) @@ -70,6 +73,7 @@ def test_compute_consumption_dollars_points_strictly_increasing() -> None: pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), + max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS), n_points=70, ) diffs = jnp.diff(pts) @@ -81,6 +85,7 @@ def test_compute_consumption_dollars_points_last_equals_max() -> None: pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), + max_consumption_dollars=jnp.asarray(MAX_CONSUMPTION_DOLLARS), n_points=70, ) assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_DOLLARS)