Skip to content

Commit

Permalink
remove the spectators and temporarily disable through run plotting
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707244086
  • Loading branch information
tamaranorman authored and Torax team committed Dec 17, 2024
1 parent 3f07b38 commit d700f06
Show file tree
Hide file tree
Showing 9 changed files with 5 additions and 822 deletions.
3 changes: 2 additions & 1 deletion run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
_PLOT_SIM_PROGRESS = flags.DEFINE_bool(
'plot_progress',
False,
'If true, plots the time of each timestep as the simulation runs.',
'If true, plots the time of each timestep as the simulation runs.'
' Note: this is temporarily disabled.',
)

_LOG_SIM_OUTPUT = flags.DEFINE_bool(
Expand Down
98 changes: 0 additions & 98 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles as source_profiles_lib
from torax.spectators import spectator as spectator_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import time_step_calculator as ts
Expand Down Expand Up @@ -743,28 +742,18 @@ def source_models(self) -> source_models_lib.SourceModels:
def run(
self,
log_timestep_info: bool = False,
spectator: spectator_lib.Spectator | None = None,
) -> output.ToraxSimOutputs:
"""Runs the transport simulation over a prescribed time interval.
See `run_simulation` for details.
Args:
log_timestep_info: See `run_simulation()`.
spectator: If a SimulationStepFn has not yet been built for this Sim
object (if it was not passed in __init__ or this object has never been
run), then it will be built in this call, and this spectator will be
built into it. If the SimulationStepFn has already been built, then this
argument is ignored and the spectator built into the SimulationStepFn
cannot change. In these cases where you want to use a new spectator, you
must build a new Sim object.
Returns:
Tuple of all ToraxSimStates, one per time step and an additional one at
the beginning for the starting state.
"""
if spectator is not None:
spectator.reset()
return run_simulation(
static_runtime_params_slice=self.static_runtime_params_slice,
dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider,
Expand All @@ -773,7 +762,6 @@ def run(
time_step_calculator=self.time_step_calculator,
step_fn=self.step_fn,
log_timestep_info=log_timestep_info,
spectator=spectator,
)


Expand Down Expand Up @@ -1005,7 +993,6 @@ def run_simulation(
time_step_calculator: ts.TimeStepCalculator,
step_fn: SimulationStepFn,
log_timestep_info: bool = False,
spectator: spectator_lib.Spectator | None = None,
) -> output.ToraxSimOutputs:
"""Runs the transport simulation over a prescribed time interval.
Expand Down Expand Up @@ -1046,8 +1033,6 @@ def run_simulation(
ToraxSimState objects.
log_timestep_info: If True, logs basic timestep info, like time, dt, on
every step.
spectator: Object which can "spectate" values as the simulation runs. See
the Spectator class docstring for more details.
Returns:
ToraxSimOutputs, containing information on the sim error state, and the
Expand Down Expand Up @@ -1080,10 +1065,6 @@ def run_simulation(
geometry_provider,
)
)
if spectator is not None:
# Because of the updates we apply to the core sources during the next
# iteration, we need to start the spectator before step here.
spectator.before_step()

sim_state = initial_state

Expand Down Expand Up @@ -1120,14 +1101,6 @@ def run_simulation(
logging.info(
'Solver converged only within coarse tolerance in previous step.'
)
# Make sure to "spectate" the state after the source profiles have been
# merged and updated in the output sim_state.
if spectator is not None:
_update_spectator(spectator, sim_state)
# This is after the previous time step's step_fn() call.
spectator.after_step()
# Now prep the spectator for the following time step.
spectator.before_step()

if first_step:
# Initialize the sim_history with the initial state.
Expand Down Expand Up @@ -1176,10 +1149,6 @@ def run_simulation(
explicit_source_profiles=explicit_source_profiles,
implicit_source_profiles=sim_state.core_sources,
)
if spectator is not None:
# Complete the last time step.
_update_spectator(spectator, sim_state)
spectator.after_step()

# If the first step of the simulation was very long, call it out. It might
# have to do with tracing the jitted step_fn.
Expand Down Expand Up @@ -1213,73 +1182,6 @@ def run_simulation(
)


def _update_spectator(
spectator: spectator_lib.Spectator,
output_state: state.ToraxSimState,
) -> None:
"""Updates the spectator with values from the output state."""
spectator.observe(key='q_face', data=output_state.core_profiles.q_face)
spectator.observe(key='s_face', data=output_state.core_profiles.s_face)
spectator.observe(key='ne', data=output_state.core_profiles.ne.value)
spectator.observe(
key='temp_ion',
data=output_state.core_profiles.temp_ion.value,
)
spectator.observe(
key='temp_el',
data=output_state.core_profiles.temp_el.value,
)
spectator.observe(
key='j_bootstrap_face',
data=output_state.core_profiles.currents.j_bootstrap_face,
)
spectator.observe(
key='johm',
data=output_state.core_profiles.currents.johm,
)
spectator.observe(
key='generic_current_source',
data=output_state.core_profiles.currents.generic_current_source,
)
spectator.observe(
key='jtot_face',
data=output_state.core_profiles.currents.jtot_face,
)
spectator.observe(
key='chi_face_ion', data=output_state.core_transport.chi_face_ion
)
spectator.observe(
key='chi_face_el', data=output_state.core_transport.chi_face_el
)
spectator.observe(
key='Qext_i',
data=output_state.core_sources.get_profile(
'generic_ion_el_heat_source_ion'
),
)
spectator.observe(
key='Qext_e',
data=output_state.core_sources.get_profile(
'generic_ion_el_heat_source_el'
),
)
spectator.observe(
key='Qfus_i',
data=output_state.core_sources.get_profile('fusion_heat_source_ion'),
)
spectator.observe(
key='Qfus_e',
data=output_state.core_sources.get_profile('fusion_heat_source_el'),
)
spectator.observe(
key='Qohm',
data=output_state.core_sources.get_profile('ohmic_heat_source'),
)
spectator.observe(
key='Qei', data=output_state.core_sources.get_profile('qei_source')
)


def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
t: jnp.ndarray,
dt: jnp.ndarray,
Expand Down
23 changes: 3 additions & 20 deletions torax/simulation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def run(_):

from absl import logging
import jax
from matplotlib import pyplot as plt
from torax import geometry
from torax import geometry_provider
from torax import output
Expand All @@ -69,7 +68,6 @@ def run(_):
from torax.config import runtime_params_slice
from torax.pedestal_model import runtime_params as pedestal_runtime_params_lib
from torax.sources import runtime_params as source_runtime_params_lib
from torax.spectators import plotting
from torax.stepper import runtime_params as stepper_runtime_params_lib
from torax.transport_model import runtime_params as transport_runtime_params_lib
import xarray as xr
Expand Down Expand Up @@ -279,31 +277,16 @@ def main(
sim = get_sim()
geo = sim.geometry_provider(sim.initial_state.t)

spectator = None
if plot_sim_progress:
if can_plot():
plt.ion()
spectator = plotting.PlotSpectator(
plots=plotting.get_default_plot_config(geo=geo),
pyplot_figure_kwargs=dict(
figsize=(12, 6),
),
)
plt.show()
else:
logging.warning(
'plotting requested, but there is no display connected to show the '
'plot.'
)

log_to_stdout('Starting simulation.', color=AnsiColors.GREEN)
sim_outputs = sim.run(
log_timestep_info=log_sim_progress,
spectator=spectator,
)
log_to_stdout('Finished running simulation.', color=AnsiColors.GREEN)
state_history = output.StateHistory(sim_outputs, sim.source_models)

if plot_sim_progress:
raise NotImplementedError('Plotting progress is temporarily disabled.')

data_tree = state_history.simulation_output_to_xr(geo, sim.file_restart)

output_file = write_simulation_output_to_file(output_dir, data_tree)
Expand Down
21 changes: 0 additions & 21 deletions torax/spectators/__init__.py

This file was deleted.

Loading

0 comments on commit d700f06

Please sign in to comment.