Skip to content

Commit

Permalink
Changes to start simplifying sim.py
Browse files Browse the repository at this point in the history
Also fixes that we were using old geo and state for the final step

PiperOrigin-RevId: 706967398
  • Loading branch information
tamaranorman authored and Torax team committed Dec 19, 2024
1 parent f6af816 commit 79991df
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 89 deletions.
92 changes: 43 additions & 49 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,6 @@
import xarray as xr


def _log_timestep(
t: jax.Array, dt: jax.Array, outer_stepper_iterations: int
) -> None:
"""Logs basic timestep info."""
logging.info(
'\nSimulation time: %.5f, previous dt: %.6f, previous stepper'
' iterations: %d',
t,
dt,
outer_stepper_iterations,
)


def get_consistent_dynamic_runtime_params_slice_and_geometry(
t: chex.Numeric,
dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider,
Expand Down Expand Up @@ -147,7 +134,7 @@ def __call__(
dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider,
geometry_provider: geometry_provider_lib.GeometryProvider,
input_state: state.ToraxSimState,
) -> state.ToraxSimState:
) -> tuple[state.ToraxSimState, state.SimError]:
"""Advances the simulation state one time step.
Args:
Expand Down Expand Up @@ -179,6 +166,7 @@ def __call__(
1 if solver did not converge for this step (was above coarse tol)
2 if solver converged within coarse tolerance. Allowed to pass with
a warning. Occasional error=2 has low impact on final sim state.
SimError indicating if an error has occurred during simulation.
"""
dynamic_runtime_params_slice_t, geo_t = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
Expand Down Expand Up @@ -259,13 +247,14 @@ def __call__(
explicit_source_profiles,
)

return self.finalize_output(
sim_state = self.finalize_output(
input_state,
output_state,
dynamic_runtime_params_slice_t_plus_dt,
static_runtime_params_slice,
geo_t_plus_dt,
)
return sim_state, sim_state.check_for_errors()

def init_time_step_calculator(
self,
Expand Down Expand Up @@ -811,7 +800,7 @@ def run(
Tuple of all ToraxSimStates, one per time step and an additional one at
the beginning for the starting state.
"""
return run_simulation(
return _run_simulation(
static_runtime_params_slice=self.static_runtime_params_slice,
dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider,
geometry_provider=self.geometry_provider,
Expand Down Expand Up @@ -1042,7 +1031,7 @@ def build_sim_object(
)


def run_simulation(
def _run_simulation(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider,
geometry_provider: geometry_provider_lib.GeometryProvider,
Expand Down Expand Up @@ -1124,15 +1113,14 @@ def run_simulation(
)

sim_state = initial_state
sim_history = []
sim_state = post_processing.make_outputs(sim_state=sim_state, geo=geo)
sim_history.append(sim_state)

# Set the sim_error to NO_ERROR. If we encounter an error, we will set it to
# the appropriate error code.
sim_error = state.SimError.NO_ERROR

# Initialize first_step, used to post-process and append the initial state to
# the sim_history.
first_step = True
sim_history = []
# Advance the simulation until the time_step_calculator tells us we are done.
while time_step_calculator.not_done(
sim_state.t,
Expand All @@ -1142,29 +1130,9 @@ def run_simulation(
# Measure how long in wall clock time each simulation step takes.
step_start_time = time.time()
if log_timestep_info:
_log_timestep(
sim_state.t,
sim_state.dt,
sim_state.stepper_numeric_outputs.outer_stepper_iterations,
)
# TODO(b/330172917): once tol and coarse_tol are configurable in the
# runtime_params, also log the value of tol and coarse_tol below
match sim_state.stepper_numeric_outputs.stepper_error_state:
case 0:
pass
case 1:
logging.info('Solver did not converge in previous step.')
case 2:
logging.info(
'Solver converged only within coarse tolerance in previous step.'
)
_log_timestep(sim_state)

if first_step:
# Initialize the sim_history with the initial state.
sim_state = post_processing.make_outputs(sim_state=sim_state, geo=geo)
sim_history.append(sim_state)
first_step = False
sim_state = step_fn(
sim_state, sim_error = step_fn(
static_runtime_params_slice,
dynamic_runtime_params_slice_provider,
geometry_provider,
Expand All @@ -1175,7 +1143,6 @@ def run_simulation(
# Checks if sim_state is valid. If not, exit simulation early.
# We don't raise an Exception because we want to return the truncated
# simulation history to the user for inspection.
sim_error = sim_state.check_for_errors()
if sim_error != state.SimError.NO_ERROR:
sim_error.log_error()
break
Expand All @@ -1185,15 +1152,18 @@ def run_simulation(
# Log final timestep
if log_timestep_info and sim_error == state.SimError.NO_ERROR:
# The "sim_state" here has been updated by the loop above.
_log_timestep(
sim_state.t,
sim_state.dt,
sim_state.stepper_numeric_outputs.outer_stepper_iterations,
)
_log_timestep(sim_state)

# Update the final time step's source profiles based on the explicit source
# profiles computed based on the final state.
logging.info("Updating last step's source profiles.")
dynamic_runtime_params_slice, geo = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
sim_state.t,
dynamic_runtime_params_slice_provider,
geometry_provider,
)
)
explicit_source_profiles = source_models_lib.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -1598,3 +1568,27 @@ def merge_source_profiles(
j_bootstrap=summed_bootstrap_profile,
qei=summed_qei_info,
)


def _log_timestep(
sim_state: state.ToraxSimState,
) -> None:
"""Logs basic timestep info."""
logging.info(
'\nSimulation time: %.5f, previous dt: %.6f, previous stepper'
' iterations: %d',
sim_state.t,
sim_state.dt,
sim_state.stepper_numeric_outputs.outer_stepper_iterations,
)
# TODO(b/330172917): once tol and coarse_tol are configurable in the
# runtime_params, also log the value of tol and coarse_tol below
match sim_state.stepper_numeric_outputs.stepper_error_state:
case 0:
pass
case 1:
logging.info('Solver did not converge in previous step.')
case 2:
logging.info(
'Solver converged only within coarse tolerance in previous step.'
)
9 changes: 1 addition & 8 deletions torax/sources/tests/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,7 @@ def _run_sim_and_check(
ref_time: chex.Array,
):
"""Runs sim with new runtime params and checks the profiles vs. expected."""
sim_outputs = sim_lib.run_simulation(
static_runtime_params_slice=sim.static_runtime_params_slice,
dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider,
geometry_provider=sim.geometry_provider,
initial_state=sim.initial_state,
time_step_calculator=sim.time_step_calculator,
step_fn=sim.step_fn,
)
sim_outputs = sim.run()
history = output.StateHistory(sim_outputs, sim.source_models)
self._check_profiles_vs_expected(
core_profiles=history.core_profiles,
Expand Down
17 changes: 2 additions & 15 deletions torax/tests/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,21 +635,8 @@ def test_core_profiles_are_recomputable(self, test_config, halfway):

if halfway:
# Run sim till the end and check that final core profiles match reference.
initial_state.t = ref_time[index]
step_fn = sim_lib.SimulationStepFn(
stepper=sim.stepper,
time_step_calculator=sim.time_step_calculator,
transport_model=sim.transport_model,
pedestal_model=sim.pedestal_model,
)
sim_outputs = sim_lib.run_simulation(
static_runtime_params_slice=sim.static_runtime_params_slice,
dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider,
geometry_provider=sim.geometry_provider,
initial_state=initial_state,
time_step_calculator=sim.time_step_calculator,
step_fn=step_fn,
)
sim.initial_state.t = ref_time[index]
sim_outputs = sim.run()
final_core_profiles = sim_outputs.sim_history[-1].core_profiles
verify_core_profiles(ref_profiles, -1, final_core_profiles)
# pylint: enable=invalid-name
Expand Down
10 changes: 2 additions & 8 deletions torax/tests/sim_custom_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,8 @@ def _run_sim_and_check(
source_runtime_params=self.source_models_builder.runtime_params,
)
)
sim_outputs = sim_lib.run_simulation(
initial_state=sim.initial_state,
step_fn=sim.step_fn,
geometry_provider=sim.geometry_provider,
dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider,
static_runtime_params_slice=static_runtime_params_slice,
time_step_calculator=sim.time_step_calculator,
)
sim._static_runtime_params_slice = static_runtime_params_slice # pylint: disable=protected-access
sim_outputs = sim.run()
history = output.StateHistory(sim_outputs, sim.source_models)
self._check_profiles_vs_expected(
core_profiles=history.core_profiles,
Expand Down
18 changes: 10 additions & 8 deletions torax/tests/sim_output_source_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,10 @@ def custom_source_formula(
source_runtime_params=source_models_builder.runtime_params,
)
)

sim_outputs = sim_lib.run_simulation(
sim = sim_lib.Sim(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
geometry_provider=geometry_provider_lib.ConstantGeometryProvider(geo),
initial_state=sim_lib.get_initial_state(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=initial_dcs,
Expand All @@ -160,13 +162,13 @@ def custom_source_formula(
source_models=source_models,
step_fn=step_fn,
),
step_fn=step_fn,
geometry_provider=geometry_provider_lib.ConstantGeometryProvider(geo),
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
static_runtime_params_slice=static_runtime_params_slice,
time_step_calculator=time_stepper,
step_fn=step_fn,
source_models_builder=source_models_builder,
)

sim_outputs = sim.run()

# The implicit and explicit profiles get merged together before being
# outputted, and they are aligned as well as possible to be computed based
# on the state and config at time t. So both the implicit and explicit
Expand Down Expand Up @@ -304,7 +306,7 @@ def __call__(
dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider,
geometry_provider: geometry_provider_lib.GeometryProvider,
input_state: state_module.ToraxSimState,
) -> state_module.ToraxSimState:
) -> tuple[state_module.ToraxSimState, state_module.SimError]:
dt, ts_state = self._time_step_calculator.next_dt(
dynamic_runtime_params_slice=dynamic_runtime_params_slice_provider(
t=input_state.t,
Expand All @@ -331,7 +333,7 @@ def __call__(
source_models=self.stepper.source_models,
explicit=False,
),
)
), state_module.SimError.NO_ERROR


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim_time_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_time_dependent_params_update_in_adaptive_dt(
time_step_calculator=time_calculator,
)
sim_step_fn = sim.step_fn
output_state = sim_step_fn(
output_state, _ = sim_step_fn(
static_runtime_params_slice=sim.static_runtime_params_slice,
dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider,
geometry_provider=sim.geometry_provider,
Expand Down

0 comments on commit 79991df

Please sign in to comment.