Skip to content

Commit

Permalink
Change core_profile_setters_test to have test prefix and simplify tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716503857
  • Loading branch information
tamaranorman authored and Torax team committed Jan 21, 2025
1 parent 198d7b5 commit 0ba99f6
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 320 deletions.
122 changes: 53 additions & 69 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from torax import math_utils
from torax import physics
from torax import state
from torax.config import numerics
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import geometry
Expand All @@ -38,83 +40,76 @@

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name

def updated_ion_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,

def _updated_ion_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated ion temp. Used upon initialization and if temp_ion=False."""
# pylint: disable=invalid-name
Ti_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right
)

Ti_bound_right = jax_utils.error_if_not_positive(
Ti_bound_right,
dynamic_profile_conditions.Ti_bound_right,
'Ti_bound_right',
)
temp_ion = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Ti,
value=dynamic_profile_conditions.Ti,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Ti_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name

return temp_ion


def updated_electron_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
def _updated_electron_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated electron temp. Used upon initialization and if temp_el=False."""
# pylint: disable=invalid-name
Te_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Te_bound_right
)

Te_bound_right = jax_utils.error_if_not_positive(
Te_bound_right,
dynamic_profile_conditions.Te_bound_right,
'Te_bound_right',
)
temp_el = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Te,
value=dynamic_profile_conditions.Te,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Te_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name
return temp_el


# pylint: disable=invalid-name
def _get_ne(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_numerics: numerics.DynamicNumerics,
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Gets initial or prescribed electron density profile at current timestep."""
# pylint: disable=invalid-name

nGW = (
dynamic_runtime_params_slice.profile_conditions.Ip_tot
dynamic_profile_conditions.Ip_tot
/ (jnp.pi * geo.Rmin**2)
* 1e20
/ dynamic_runtime_params_slice.numerics.nref
/ dynamic_numerics.nref
)
ne_value = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne * nGW,
dynamic_runtime_params_slice.profile_conditions.ne,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.ne * nGW,
dynamic_profile_conditions.ne,
)
# Calculate ne_bound_right.
ne_bound_right = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right,
dynamic_profile_conditions.ne_bound_right_is_fGW,
dynamic_profile_conditions.ne_bound_right * nGW,
dynamic_profile_conditions.ne_bound_right,
)

if dynamic_runtime_params_slice.profile_conditions.normalize_to_nbar:
if dynamic_profile_conditions.normalize_to_nbar:
face_left = ne_value[0] # Zero gradient boundary condition at left face.
face_right = ne_bound_right
face_inner = (ne_value[..., :-1] + ne_value[..., 1:]) / 2.0
Expand All @@ -132,16 +127,15 @@ def _get_ne(
Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0]
# find target nbar in absolute units
target_nbar = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.nbar * nGW,
dynamic_runtime_params_slice.profile_conditions.nbar,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.nbar * nGW,
dynamic_profile_conditions.nbar,
)
if (
not dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_absolute
not dynamic_profile_conditions.ne_bound_right_is_absolute
):
# In this case, ne_bound_right is taken from ne and we also normalize it.
C = target_nbar / (_trapz(ne_face, geo.Rout_face) / Rmin_out)
# pylint: enable=invalid-name
ne_bound_right = C * ne_bound_right
else:
# If ne_bound_right is absolute, subtract off contribution from outer
Expand Down Expand Up @@ -180,7 +174,6 @@ def _get_charge_states(
array_typing.ArrayFloat,
]:
"""Updated charge states based on IonMixtures and electron temperature."""
# pylint: disable=invalid-name
Zi = charge_states.get_average_charge_state(
ion_symbols=static_runtime_params_slice.main_ion_names,
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
Expand Down Expand Up @@ -308,7 +301,6 @@ def _prescribe_currents_no_bootstrap(
"""
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name

# Calculate splitting of currents depending on input runtime params.
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot
Expand Down Expand Up @@ -397,7 +389,6 @@ def _prescribe_currents_with_bootstrap(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_value(
Expand Down Expand Up @@ -483,7 +474,6 @@ def _calculate_currents_from_psi(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
core_profiles.psi,
Expand Down Expand Up @@ -574,7 +564,6 @@ def _update_psi_from_j(
return psi


# pylint: enable=invalid-name
def _calculate_psi_grad_constraint(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -689,12 +678,10 @@ def _init_psi_and_current(
geo,
currents.jtot_hires,
)
# pylint: disable=invalid-name
_, _, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
psi,
)
# pylint: enable=invalid-name
currents = dataclasses.replace(currents, Ip_profile_face=Ip_profile_face)
else:
raise ValueError('Cannot compute psi for given config.')
Expand All @@ -721,14 +708,19 @@ def initial_core_profiles(
Returns:
Initial core profiles.
"""
# pylint: disable=invalid-name

# To set initial values and compute the boundary conditions, we need to handle
# potentially time-varying inputs from the users.
# The default time in build_dynamic_runtime_params_slice is t_initial
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo)
temp_el = updated_electron_temperature(dynamic_runtime_params_slice, geo)
ne = _get_ne(dynamic_runtime_params_slice, geo)
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
temp_el = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
ne = _get_ne(dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo)

ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
static_runtime_params_slice,
Expand Down Expand Up @@ -795,15 +787,12 @@ def initial_core_profiles(
core_profiles = dataclasses.replace(core_profiles, psidot=psidot)

# Set psi as source of truth and recalculate jtot, q, s
core_profiles = physics.update_jtot_q_face_s_face(
return physics.update_jtot_q_face_s_face(
geo=geo,
core_profiles=core_profiles,
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
)

# pylint: enable=invalid-name
return core_profiles


def get_prescribed_core_profile_values(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
Expand All @@ -824,23 +813,22 @@ def get_prescribed_core_profile_values(
Returns:
Updated core profiles values on the cell grid.
"""
# pylint: disable=invalid-name

# If profiles are not evolved, they can still potential be time-evolving,
# depending on the runtime params. If so, they are updated below.
if (
not static_runtime_params_slice.ion_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo).value
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo).value
else:
temp_ion = core_profiles.temp_ion.value
if (
not static_runtime_params_slice.el_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_el_cell_variable = updated_electron_temperature(
dynamic_runtime_params_slice, geo
temp_el_cell_variable = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
temp_el = temp_el_cell_variable.value
else:
Expand All @@ -850,7 +838,10 @@ def get_prescribed_core_profile_values(
not static_runtime_params_slice.dens_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
ne_cell_variable = _get_ne(dynamic_runtime_params_slice, geo)
ne_cell_variable = _get_ne(
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo)
else:
ne_cell_variable = core_profiles.ne
ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
Expand Down Expand Up @@ -908,15 +899,13 @@ def get_update(x_new, var):
psi = get_update(x_new, 'psi')
ne = get_update(x_new, 'ne')

# pylint: disable=invalid-name
ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
ne,
temp_el,
)
# pylint: enable=invalid-name

return dataclasses.replace(
core_profiles,
Expand Down Expand Up @@ -950,24 +939,24 @@ def compute_boundary_conditions(
each CellVariable in the state. This dict can in theory recursively replace
values in a State object.
"""
Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Ti_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right,
'Ti_bound_right',
)

Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Te_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Te_bound_right,
'Te_bound_right',
)
# TODO(b/390143606): Separate out the boundary condition calculation from the
# core profile calculation.
ne = _get_ne(
dynamic_runtime_params_slice,
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo,
)
ne_bound_right = ne.right_face_constraint

# pylint: disable=invalid-name
Zi_edge = charge_states.get_average_charge_state(
static_runtime_params_slice.main_ion_names,
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
Expand All @@ -978,7 +967,6 @@ def compute_boundary_conditions(
ion_mixture=dynamic_runtime_params_slice.plasma_composition.impurity,
Te=Te_bound_right,
)
# pylint: disable=invalid-name

dilution_factor_edge = physics.get_main_ion_dilution_factor(
Zi_edge,
Expand Down Expand Up @@ -1026,7 +1014,6 @@ def compute_boundary_conditions(
}


# pylint: disable=invalid-name
def _get_jtot_hires(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -1064,6 +1051,3 @@ def _get_jtot_hires(
johm_hires = jformula_hires * Cohm_hires
jtot_hires = johm_hires + external_current_hires + j_bootstrap_hires
return jtot_hires


# pylint: enable=invalid-name
34 changes: 0 additions & 34 deletions torax/fvm/cell_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,6 @@ class CellVariable:
# Can't make the above default values be jax zeros because that would be a
# call to jax before absl.app.run

def project(self, weights):
assert self.history is not None

def project(x):
return jnp.dot(weights, x)

def opt_project(x):
if x is None:
return None
return project(x)

return dataclasses.replace(
self,
value=project(self.value),
dr=self.dr[0],
left_face_constraint=opt_project(self.left_face_constraint),
left_face_grad_constraint=opt_project(self.left_face_grad_constraint),
right_face_constraint=opt_project(self.right_face_constraint),
right_face_grad_constraint=opt_project(self.right_face_grad_constraint),
history=None,
)

def __post_init__(self):
self.sanity_check()

Expand Down Expand Up @@ -266,18 +244,6 @@ def assert_not_history(self):
'by `jax.lax.scan`. Most methods of a CellVariable '
'do not work in history mode.'
)
if hasattr(self.history, 'ndim'):
if self.history.ndim == 0 or (
self.history.ndim == 1 and self.history.shape[0] == 1
):
msg += (
f' self.history={self.history} which probably indicates'
' (due to its scalar shape)'
' that an indexing or projection operation failed to'
' turn off history mode. self.history should be None for'
' non-history or a a vector of shape (history_length) for'
' history.'
)
raise AssertionError(msg)

def __hash__(self):
Expand Down
Loading

0 comments on commit 0ba99f6

Please sign in to comment.