Skip to content

Commit

Permalink
Move build_sim_object logic to a class method
Browse files Browse the repository at this point in the history
A follow-up CL will change over occurances of this

PiperOrigin-RevId: 716475568
  • Loading branch information
tamaranorman authored and Torax team committed Jan 17, 2025
1 parent a475745 commit a3f45c6
Showing 1 changed file with 155 additions and 135 deletions.
290 changes: 155 additions & 135 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,153 @@ def run(
log_timestep_info=log_timestep_info,
)

@classmethod
def create(
cls,
*,
runtime_params: general_runtime_params.GeneralRuntimeParams,
geometry_provider: geometry_provider_lib.GeometryProvider,
stepper_builder: stepper_lib.StepperBuilder,
transport_model_builder: transport_model_lib.TransportModelBuilder,
source_models_builder: source_models_lib.SourceModelsBuilder,
pedestal_model_builder: pedestal_model_lib.PedestalModelBuilder,
time_step_calculator: Optional[ts.TimeStepCalculator] = None,
file_restart: Optional[general_runtime_params.FileRestart] = None,
) -> Sim:
"""Builds a Sim object from the input runtime params and sim components.
Args:
runtime_params: The input runtime params used throughout the simulation
run.
geometry_provider: The geometry used throughout the simulation run.
stepper_builder: A callable to build the stepper. The stepper has already
been factored out of the config.
transport_model_builder: A callable to build the transport model.
source_models_builder: Builds the SourceModels and holds its
runtime_params.
pedestal_model_builder: A callable to build the pedestal model.
time_step_calculator: The time_step_calculator, if built, otherwise a
ChiTimeStepCalculator will be built by default.
file_restart: If provided we will reconstruct the initial state from the
provided file at the given time step. This state from the file will only
be used for constructing the initial state (as well as the config) and
for all subsequent steps, the evolved state and runtime parameters from
config are used.
Returns:
sim: The built Sim instance.
"""

transport_model = transport_model_builder()
pedestal_model = pedestal_model_builder()

# TODO(b/385788907): Document all changes that lead to recompilations.
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
stepper=stepper_builder.runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport_model_builder.runtime_params,
sources=source_models_builder.runtime_params,
stepper=stepper_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
pedestal=pedestal_model_builder.runtime_params,
)
)
source_models = source_models_builder()
stepper = stepper_builder(transport_model, source_models, pedestal_model)

if time_step_calculator is None:
time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator()

# Build dynamic_runtime_params_slice at t_initial for initial conditions.
dynamic_runtime_params_slice_for_init, geo_for_init = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
runtime_params.numerics.t_initial,
dynamic_runtime_params_slice_provider,
geometry_provider,
)
)
if file_restart is not None and file_restart.do_restart:
data_tree = output.load_state_file(file_restart.filename)
# Find the closest time in the given dataset.
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
t_restart = data_tree.time.item()
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
# Remap coordinates in saved file to be consistent with expectations of
# how config_args parses xarrays.
core_profiles_dataset = core_profiles_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
core_profiles_dataset = core_profiles_dataset.squeeze()
if t_restart != runtime_params.numerics.t_initial:
logging.warning(
'Requested restart time %f not exactly available in state file %s.'
' Restarting from closest available time %f instead.',
file_restart.time,
file_restart.filename,
t_restart,
)
# Override some of dynamic runtime params slice from t=t_initial.
dynamic_runtime_params_slice_for_init, geo_for_init = (
_override_initial_runtime_params_from_file(
dynamic_runtime_params_slice_for_init,
geo_for_init,
t_restart,
core_profiles_dataset,
)
)
post_processed_dataset = data_tree.children[
output.POST_PROCESSED_OUTPUTS
].dataset
post_processed_dataset = post_processed_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
post_processed_dataset = post_processed_dataset.squeeze()
post_processed_outputs = (
_override_initial_state_post_processed_outputs_from_file(
geo_for_init,
post_processed_dataset,
)
)

step_fn = SimulationStepFn(
stepper=stepper,
time_step_calculator=time_step_calculator,
transport_model=transport_model,
pedestal_model=pedestal_model,
)

initial_state = get_initial_state(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
geo=geo_for_init,
step_fn=step_fn,
)

# If we are restarting from a file, we need to override the initial state
# post processed outputs such that cumulative outputs remain correct.
if file_restart is not None and file_restart.do_restart:
initial_state = dataclasses.replace(
initial_state,
post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable
)

return cls(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
geometry_provider=geometry_provider,
initial_state=initial_state,
step_fn=step_fn,
file_restart=file_restart,
)


def _override_initial_runtime_params_from_file(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
Expand Down Expand Up @@ -905,142 +1052,15 @@ def build_sim_object(
time_step_calculator: Optional[ts.TimeStepCalculator] = None,
file_restart: Optional[general_runtime_params.FileRestart] = None,
) -> Sim:
"""Builds a Sim object from the input runtime params and sim components.
The Sim object provides a container for all the components that go into a
single TORAX simulation run. It gives a way to reuse components without having
to rebuild or recompile them if JAX shapes or static arguments do not change.
Read more about the Sim object in its class docstring. The use of it is
optional, and users may call `sim.run_simulation()` directly as well.
Args:
runtime_params: The input runtime params used throughout the simulation run.
geometry_provider: The geometry used throughout the simulation run.
stepper_builder: A callable to build the stepper. The stepper has already
been factored out of the config.
transport_model_builder: A callable to build the transport model.
source_models_builder: Builds the SourceModels and holds its runtime_params.
pedestal_model_builder: A callable to build the pedestal model.
time_step_calculator: The time_step_calculator, if built, otherwise a
ChiTimeStepCalculator will be built by default.
file_restart: If provided we will reconstruct the initial state from the
provided file at the given time step. This state from the file will only
be used for constructing the initial state (as well as the config) and for
all subsequent steps, the evolved state and runtime parameters from config
are used.
Returns:
sim: The built Sim instance.
"""

transport_model = transport_model_builder()
pedestal_model = pedestal_model_builder()

# TODO(b/385788907): Clearly document all changes that lead to recompilations.
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
stepper=stepper_builder.runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport_model_builder.runtime_params,
sources=source_models_builder.runtime_params,
stepper=stepper_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
pedestal=pedestal_model_builder.runtime_params,
)
)
source_models = source_models_builder()
stepper = stepper_builder(transport_model, source_models, pedestal_model)

if time_step_calculator is None:
time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator()

# Build dynamic_runtime_params_slice at t_initial for initial conditions.
dynamic_runtime_params_slice_for_init, geo_for_init = (
get_consistent_dynamic_runtime_params_slice_and_geometry(
runtime_params.numerics.t_initial,
dynamic_runtime_params_slice_provider,
geometry_provider,
)
)
if file_restart is not None and file_restart.do_restart:
data_tree = output.load_state_file(file_restart.filename)
# Find the closest time in the given dataset.
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
t_restart = data_tree.time.item()
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
# Remap coordinates in saved file to be consistent with expectations of
# how config_args parses xarrays.
core_profiles_dataset = core_profiles_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
core_profiles_dataset = core_profiles_dataset.squeeze()
if t_restart != runtime_params.numerics.t_initial:
logging.warning(
'Requested restart time %f not exactly available in state file %s.'
' Restarting from closest available time %f instead.',
file_restart.time,
file_restart.filename,
t_restart,
)
# Override some of dynamic runtime params slice from t=t_initial.
dynamic_runtime_params_slice_for_init, geo_for_init = (
_override_initial_runtime_params_from_file(
dynamic_runtime_params_slice_for_init,
geo_for_init,
t_restart,
core_profiles_dataset,
)
)
post_processed_dataset = data_tree.children[
output.POST_PROCESSED_OUTPUTS
].dataset
post_processed_dataset = post_processed_dataset.rename(
{output.RHO_CELL_NORM: config_args.RHO_NORM}
)
post_processed_dataset = post_processed_dataset.squeeze()
post_processed_outputs = (
_override_initial_state_post_processed_outputs_from_file(
geo_for_init,
post_processed_dataset,
)
)

step_fn = SimulationStepFn(
stepper=stepper,
time_step_calculator=time_step_calculator,
transport_model=transport_model,
pedestal_model=pedestal_model,
)

initial_state = get_initial_state(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
geo=geo_for_init,
step_fn=step_fn,
)

# If we are restarting from a file, we need to override the initial state
# post processed outputs such that cumulative outputs remain correct.
if file_restart is not None and file_restart.do_restart:
initial_state = dataclasses.replace(
initial_state,
post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable
)

return Sim(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
"""Builds a Sim object from the input runtime params and sim components."""
return Sim.create(
runtime_params=runtime_params,
geometry_provider=geometry_provider,
initial_state=initial_state,
step_fn=step_fn,
stepper_builder=stepper_builder,
transport_model_builder=transport_model_builder,
source_models_builder=source_models_builder,
pedestal_model_builder=pedestal_model_builder,
time_step_calculator=time_step_calculator,
file_restart=file_restart,
)

Expand Down

0 comments on commit a3f45c6

Please sign in to comment.