Skip to content

Commit

Permalink
Changes to start simplifying sim.py
Browse files Browse the repository at this point in the history
Also fixes that we were using old geo and state for the final step

PiperOrigin-RevId: 706967398
  • Loading branch information
tamaranorman authored and Torax team committed Dec 18, 2024
1 parent d700f06 commit 6459401
Show file tree
Hide file tree
Showing 108 changed files with 516 additions and 590 deletions.
8 changes: 6 additions & 2 deletions torax/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@
from typing import Any, Generic, TypeVar

import chex
from torax import geometry
from torax import interpolated_param
from torax.config import config_args
from torax.geometry import geometry

DynamicT = TypeVar('DynamicT')
ProviderT = TypeVar('ProviderT', bound='RuntimeParametersProvider')


class GridType(enum.Enum):
"""Describes where interpolated values are defined on."""

CELL = enum.auto()
FACE = enum.auto()

Expand Down Expand Up @@ -76,7 +77,8 @@ def get_provider_kwargs(
Args:
torax_mesh: Required if any of the interpolated variables are both
temporally and radially interpolated.
temporally and radially interpolated.
Returns:
A dict of kwargs to be passed to the provider constructor.
"""
Expand Down Expand Up @@ -161,6 +163,7 @@ class RuntimeParametersProvider(Generic[DynamicT], metaclass=abc.ABCMeta):
- any constructed interpolated variables which will already be spatially
interpolated and could vary in time.
"""

runtime_params_config: RuntimeParametersConfig

def get_dynamic_params_kwargs(
Expand All @@ -177,6 +180,7 @@ def get_dynamic_params_kwargs(
Args:
t: The time to interpolate the dynamic parameters at.
Returns:
A dict of kwargs to be passed to the dynamic params constructor.
"""
Expand Down
5 changes: 3 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import copy
from typing import Any

from torax import geometry
from torax import geometry_provider
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.pedestal_model import set_tped_nped
from torax.sources import formula_config
Expand All @@ -45,6 +45,7 @@
# pylint: disable=g-import-not-at-top
try:
from torax.transport_model import qualikiz_transport_model

_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True
except ImportError:
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False
Expand Down
2 changes: 1 addition & 1 deletion torax/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.geometry import geometry
from typing_extensions import override


Expand Down
2 changes: 1 addition & 1 deletion torax/config/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry


# pylint: disable=invalid-name
Expand Down
4 changes: 2 additions & 2 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry
from typing_extensions import override


Expand Down Expand Up @@ -84,7 +84,7 @@ class ProfileConditions(
ne_bound_right: interpolated_param.TimeInterpolatedInput | None = None
ne_bound_right_is_fGW: bool = False
ne_bound_right_is_absolute: bool = False
# Internal boundary condition (pedestal)
# Internal boundary condition (pedestal)
# Do not set internal boundary condition if this is False
set_pedestal: interpolated_param.TimeInterpolatedInput = True
# current profiles (broad "Ohmic" + localized "external" currents)
Expand Down
3 changes: 2 additions & 1 deletion torax/config/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import dataclasses

import chex
from torax import geometry
from torax.config import base
from torax.config import numerics as numerics_lib
from torax.config import plasma_composition as plasma_composition_lib
from torax.config import profile_conditions as profile_conditions_lib
from torax.geometry import geometry
from typing_extensions import override


Expand Down Expand Up @@ -92,6 +92,7 @@ def build_dynamic_params(
@chex.dataclass
class DynamicGeneralRuntimeParams:
"""General runtime input parameters for the `torax` module."""

plasma_composition: plasma_composition_lib.DynamicPlasmaComposition
profile_conditions: profile_conditions_lib.DynamicProfileConditions
numerics: numerics_lib.DynamicNumerics
16 changes: 7 additions & 9 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
import dataclasses

import chex
from torax import geometry
from torax.config import numerics
from torax.config import plasma_composition
from torax.config import profile_conditions
from torax.config import runtime_params as general_runtime_params_lib
from torax.geometry import geometry
from torax.pedestal_model import runtime_params as pedestal_model_params
from torax.sources import runtime_params as sources_params
from torax.stepper import runtime_params as stepper_params
Expand Down Expand Up @@ -148,7 +148,9 @@ def _build_dynamic_sources(
) -> dict[str, sources_params.DynamicRuntimeParams]:
"""Builds a dict of DynamicSourceConfigSlice based on the input config."""
return {
source_name: input_source_config.build_dynamic_params(t,)
source_name: input_source_config.build_dynamic_params(
t,
)
for source_name, input_source_config in sources.items()
}

Expand Down Expand Up @@ -262,15 +264,11 @@ def runtime_params_provider(

def _construct_providers(self):
"""Construct the providers that will give us the dynamic params."""
self._runtime_params_provider = (
self._runtime_params.make_provider(
self._torax_mesh
)
self._runtime_params_provider = self._runtime_params.make_provider(
self._torax_mesh
)
self._transport_runtime_params_provider = (
self._transport_runtime_params.make_provider(
self._torax_mesh
)
self._transport_runtime_params.make_provider(self._torax_mesh)
)
self._sources_providers = {
key: source.make_provider(self._torax_mesh)
Expand Down
4 changes: 2 additions & 2 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import geometry_provider
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import set_tped_nped
from torax.sources import formula_config
from torax.sources import formulas
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from absl.testing import absltest
from absl.testing import parameterized
from torax import geometry
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import geometry


class NumericsTest(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import geometry


class PlasmaCompositionTest(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
from torax.geometry import geometry
import xarray as xr


Expand Down
25 changes: 19 additions & 6 deletions torax/config/tests/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

from absl.testing import absltest
from absl.testing import parameterized
from torax import geometry
from torax.config import config_args
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.geometry import geometry


# pylint: disable=invalid-name
Expand Down Expand Up @@ -94,7 +94,9 @@ def test_recursive_replace(self):
self.assertEqual(result.a4.b4.c1, 19)
self.assertEqual(result.a4.b4.c2, 20)

def test_runtime_params_raises_for_invalid_temp_boundary_conditions(self,):
def test_runtime_params_raises_for_invalid_temp_boundary_conditions(
self,
):
"""Tests that runtime params validate boundary conditions."""
with self.assertRaises(ValueError):
general_runtime_params.GeneralRuntimeParams(
Expand All @@ -104,12 +106,23 @@ def test_runtime_params_raises_for_invalid_temp_boundary_conditions(self,):
)

@parameterized.parameters(
({0.0: {0.0: 12.0, 1.0: 2.0}}, None,), # Ti includes 1.0.
({0.0: {0.0: 12.0, 1.0: 2.0}}, 1.0,), # Both provided.
({0.0: {0.0: 12.0, 0.95: 2.0}}, 1.0,) # Ti_bound_right provided.
(
{0.0: {0.0: 12.0, 1.0: 2.0}},
None,
), # Ti includes 1.0.
(
{0.0: {0.0: 12.0, 1.0: 2.0}},
1.0,
), # Both provided.
(
{0.0: {0.0: 12.0, 0.95: 2.0}},
1.0,
), # Ti_bound_right provided.
)
def test_runtime_params_constructs_with_valid_profile_conditions(
self, Ti, Ti_bound_right,
self,
Ti,
Ti_bound_right,
):
"""Tests that runtime params validate boundary conditions."""
general_runtime_params.GeneralRuntimeParams(
Expand Down
26 changes: 20 additions & 6 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from absl.testing import parameterized
import jax
import numpy as np
from torax import geometry
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.geometry import geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import electron_density_sources
from torax.sources import formula_config
Expand Down Expand Up @@ -74,11 +74,15 @@ def test_time_dependent_provider_is_time_dependent(self):
stepper=stepper_params_lib.RuntimeParams(),
torax_mesh=self._geo.torax_mesh,
)
dynamic_runtime_params_slice = provider(t=1.0,)
dynamic_runtime_params_slice = provider(
t=1.0,
)
np.testing.assert_allclose(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 2.5
)
dynamic_runtime_params_slice = provider(t=2.0,)
dynamic_runtime_params_slice = provider(
t=2.0,
)
np.testing.assert_allclose(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 3.0
)
Expand Down Expand Up @@ -422,9 +426,18 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition(
)

@parameterized.product(
ne_bound_right=[None, 1.0,],
ne_bound_right_is_fGW=[True, False,],
ne_is_fGW=[True, False,],
ne_bound_right=[
None,
1.0,
],
ne_bound_right_is_fGW=[
True,
False,
],
ne_is_fGW=[
True,
False,
],
)
def test_profile_conditions_set_electron_density_and_boundary_condition(
self,
Expand Down Expand Up @@ -630,5 +643,6 @@ def test_static_runtime_params_slice_hash_different_for_different_params(
)
self.assertNotEqual(hash(static_slice1), hash(static_slice2))


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 6459401

Please sign in to comment.