Skip to content

Commit

Permalink
Move geometry files into a subfolder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707737297
  • Loading branch information
tamaranorman authored and Torax team committed Dec 19, 2024
1 parent 9148754 commit f6af816
Show file tree
Hide file tree
Showing 103 changed files with 355 additions and 293 deletions.
8 changes: 6 additions & 2 deletions torax/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@
from typing import Any, Generic, TypeVar

import chex
from torax import geometry
from torax import interpolated_param
from torax.config import config_args
from torax.geometry import geometry

DynamicT = TypeVar('DynamicT')
ProviderT = TypeVar('ProviderT', bound='RuntimeParametersProvider')


class GridType(enum.Enum):
"""Describes where interpolated values are defined on."""

CELL = enum.auto()
FACE = enum.auto()

Expand Down Expand Up @@ -76,7 +77,8 @@ def get_provider_kwargs(
Args:
torax_mesh: Required if any of the interpolated variables are both
temporally and radially interpolated.
temporally and radially interpolated.
Returns:
A dict of kwargs to be passed to the provider constructor.
"""
Expand Down Expand Up @@ -161,6 +163,7 @@ class RuntimeParametersProvider(Generic[DynamicT], metaclass=abc.ABCMeta):
- any constructed interpolated variables which will already be spatially
interpolated and could vary in time.
"""

runtime_params_config: RuntimeParametersConfig

def get_dynamic_params_kwargs(
Expand All @@ -177,6 +180,7 @@ def get_dynamic_params_kwargs(
Args:
t: The time to interpolate the dynamic parameters at.
Returns:
A dict of kwargs to be passed to the dynamic params constructor.
"""
Expand Down
5 changes: 3 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import copy
from typing import Any

from torax import geometry
from torax import geometry_provider
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.pedestal_model import set_tped_nped
from torax.sources import formula_config
Expand All @@ -45,6 +45,7 @@
# pylint: disable=g-import-not-at-top
try:
from torax.transport_model import qualikiz_transport_model

_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True
except ImportError:
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False
Expand Down
2 changes: 1 addition & 1 deletion torax/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.geometry import geometry
from typing_extensions import override


Expand Down
2 changes: 1 addition & 1 deletion torax/config/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry


# pylint: disable=invalid-name
Expand Down
4 changes: 2 additions & 2 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

import chex
from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry
from typing_extensions import override


Expand Down Expand Up @@ -84,7 +84,7 @@ class ProfileConditions(
ne_bound_right: interpolated_param.TimeInterpolatedInput | None = None
ne_bound_right_is_fGW: bool = False
ne_bound_right_is_absolute: bool = False
# Internal boundary condition (pedestal)
# Internal boundary condition (pedestal)
# Do not set internal boundary condition if this is False
set_pedestal: interpolated_param.TimeInterpolatedInput = True
# current profiles (broad "Ohmic" + localized "external" currents)
Expand Down
3 changes: 2 additions & 1 deletion torax/config/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import dataclasses

import chex
from torax import geometry
from torax.config import base
from torax.config import numerics as numerics_lib
from torax.config import plasma_composition as plasma_composition_lib
from torax.config import profile_conditions as profile_conditions_lib
from torax.geometry import geometry
from typing_extensions import override


Expand Down Expand Up @@ -92,6 +92,7 @@ def build_dynamic_params(
@chex.dataclass
class DynamicGeneralRuntimeParams:
"""General runtime input parameters for the `torax` module."""

plasma_composition: plasma_composition_lib.DynamicPlasmaComposition
profile_conditions: profile_conditions_lib.DynamicProfileConditions
numerics: numerics_lib.DynamicNumerics
16 changes: 7 additions & 9 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
import dataclasses

import chex
from torax import geometry
from torax.config import numerics
from torax.config import plasma_composition
from torax.config import profile_conditions
from torax.config import runtime_params as general_runtime_params_lib
from torax.geometry import geometry
from torax.pedestal_model import runtime_params as pedestal_model_params
from torax.sources import runtime_params as sources_params
from torax.stepper import runtime_params as stepper_params
Expand Down Expand Up @@ -148,7 +148,9 @@ def _build_dynamic_sources(
) -> dict[str, sources_params.DynamicRuntimeParams]:
"""Builds a dict of DynamicSourceConfigSlice based on the input config."""
return {
source_name: input_source_config.build_dynamic_params(t,)
source_name: input_source_config.build_dynamic_params(
t,
)
for source_name, input_source_config in sources.items()
}

Expand Down Expand Up @@ -262,15 +264,11 @@ def runtime_params_provider(

def _construct_providers(self):
"""Construct the providers that will give us the dynamic params."""
self._runtime_params_provider = (
self._runtime_params.make_provider(
self._torax_mesh
)
self._runtime_params_provider = self._runtime_params.make_provider(
self._torax_mesh
)
self._transport_runtime_params_provider = (
self._transport_runtime_params.make_provider(
self._torax_mesh
)
self._transport_runtime_params.make_provider(self._torax_mesh)
)
self._sources_providers = {
key: source.make_provider(self._torax_mesh)
Expand Down
4 changes: 2 additions & 2 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import geometry_provider
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import set_tped_nped
from torax.sources import formula_config
from torax.sources import formulas
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from absl.testing import absltest
from absl.testing import parameterized
from torax import geometry
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import geometry


class NumericsTest(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import geometry


class PlasmaCompositionTest(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import geometry
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
from torax.geometry import geometry
import xarray as xr


Expand Down
25 changes: 19 additions & 6 deletions torax/config/tests/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

from absl.testing import absltest
from absl.testing import parameterized
from torax import geometry
from torax.config import config_args
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.geometry import geometry


# pylint: disable=invalid-name
Expand Down Expand Up @@ -94,7 +94,9 @@ def test_recursive_replace(self):
self.assertEqual(result.a4.b4.c1, 19)
self.assertEqual(result.a4.b4.c2, 20)

def test_runtime_params_raises_for_invalid_temp_boundary_conditions(self,):
def test_runtime_params_raises_for_invalid_temp_boundary_conditions(
self,
):
"""Tests that runtime params validate boundary conditions."""
with self.assertRaises(ValueError):
general_runtime_params.GeneralRuntimeParams(
Expand All @@ -104,12 +106,23 @@ def test_runtime_params_raises_for_invalid_temp_boundary_conditions(self,):
)

@parameterized.parameters(
({0.0: {0.0: 12.0, 1.0: 2.0}}, None,), # Ti includes 1.0.
({0.0: {0.0: 12.0, 1.0: 2.0}}, 1.0,), # Both provided.
({0.0: {0.0: 12.0, 0.95: 2.0}}, 1.0,) # Ti_bound_right provided.
(
{0.0: {0.0: 12.0, 1.0: 2.0}},
None,
), # Ti includes 1.0.
(
{0.0: {0.0: 12.0, 1.0: 2.0}},
1.0,
), # Both provided.
(
{0.0: {0.0: 12.0, 0.95: 2.0}},
1.0,
), # Ti_bound_right provided.
)
def test_runtime_params_constructs_with_valid_profile_conditions(
self, Ti, Ti_bound_right,
self,
Ti,
Ti_bound_right,
):
"""Tests that runtime params validate boundary conditions."""
general_runtime_params.GeneralRuntimeParams(
Expand Down
26 changes: 20 additions & 6 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from absl.testing import parameterized
import jax
import numpy as np
from torax import geometry
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.geometry import geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import electron_density_sources
from torax.sources import formula_config
Expand Down Expand Up @@ -74,11 +74,15 @@ def test_time_dependent_provider_is_time_dependent(self):
stepper=stepper_params_lib.RuntimeParams(),
torax_mesh=self._geo.torax_mesh,
)
dynamic_runtime_params_slice = provider(t=1.0,)
dynamic_runtime_params_slice = provider(
t=1.0,
)
np.testing.assert_allclose(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 2.5
)
dynamic_runtime_params_slice = provider(t=2.0,)
dynamic_runtime_params_slice = provider(
t=2.0,
)
np.testing.assert_allclose(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, 3.0
)
Expand Down Expand Up @@ -422,9 +426,18 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition(
)

@parameterized.product(
ne_bound_right=[None, 1.0,],
ne_bound_right_is_fGW=[True, False,],
ne_is_fGW=[True, False,],
ne_bound_right=[
None,
1.0,
],
ne_bound_right_is_fGW=[
True,
False,
],
ne_is_fGW=[
True,
False,
],
)
def test_profile_conditions_set_electron_density_and_boundary_condition(
self,
Expand Down Expand Up @@ -630,5 +643,6 @@ def test_static_runtime_params_slice_hash_different_for_different_params(
)
self.assertNotEqual(hash(static_slice1), hash(static_slice2))


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit f6af816

Please sign in to comment.