Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change core_profile_setters_test to have test prefix and simplify tests #665

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 53 additions & 69 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from torax import math_utils
from torax import physics
from torax import state
from torax.config import numerics
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import geometry
Expand All @@ -38,83 +40,76 @@

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name

def updated_ion_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,

def _updated_ion_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated ion temp. Used upon initialization and if temp_ion=False."""
# pylint: disable=invalid-name
Ti_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right
)

Ti_bound_right = jax_utils.error_if_not_positive(
Ti_bound_right,
dynamic_profile_conditions.Ti_bound_right,
'Ti_bound_right',
)
temp_ion = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Ti,
value=dynamic_profile_conditions.Ti,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Ti_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name

return temp_ion


def updated_electron_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
def _updated_electron_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated electron temp. Used upon initialization and if temp_el=False."""
# pylint: disable=invalid-name
Te_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Te_bound_right
)

Te_bound_right = jax_utils.error_if_not_positive(
Te_bound_right,
dynamic_profile_conditions.Te_bound_right,
'Te_bound_right',
)
temp_el = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Te,
value=dynamic_profile_conditions.Te,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Te_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name
return temp_el


# pylint: disable=invalid-name
def _get_ne(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_numerics: numerics.DynamicNumerics,
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Gets initial or prescribed electron density profile at current timestep."""
# pylint: disable=invalid-name

nGW = (
dynamic_runtime_params_slice.profile_conditions.Ip_tot
dynamic_profile_conditions.Ip_tot
/ (jnp.pi * geo.Rmin**2)
* 1e20
/ dynamic_runtime_params_slice.numerics.nref
/ dynamic_numerics.nref
)
ne_value = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne * nGW,
dynamic_runtime_params_slice.profile_conditions.ne,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.ne * nGW,
dynamic_profile_conditions.ne,
)
# Calculate ne_bound_right.
ne_bound_right = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right,
dynamic_profile_conditions.ne_bound_right_is_fGW,
dynamic_profile_conditions.ne_bound_right * nGW,
dynamic_profile_conditions.ne_bound_right,
)

if dynamic_runtime_params_slice.profile_conditions.normalize_to_nbar:
if dynamic_profile_conditions.normalize_to_nbar:
face_left = ne_value[0] # Zero gradient boundary condition at left face.
face_right = ne_bound_right
face_inner = (ne_value[..., :-1] + ne_value[..., 1:]) / 2.0
Expand All @@ -132,16 +127,15 @@ def _get_ne(
Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0]
# find target nbar in absolute units
target_nbar = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.nbar * nGW,
dynamic_runtime_params_slice.profile_conditions.nbar,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.nbar * nGW,
dynamic_profile_conditions.nbar,
)
if (
not dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_absolute
not dynamic_profile_conditions.ne_bound_right_is_absolute
):
# In this case, ne_bound_right is taken from ne and we also normalize it.
C = target_nbar / (_trapz(ne_face, geo.Rout_face) / Rmin_out)
# pylint: enable=invalid-name
ne_bound_right = C * ne_bound_right
else:
# If ne_bound_right is absolute, subtract off contribution from outer
Expand Down Expand Up @@ -180,7 +174,6 @@ def _get_charge_states(
array_typing.ArrayFloat,
]:
"""Updated charge states based on IonMixtures and electron temperature."""
# pylint: disable=invalid-name
Zi = charge_states.get_average_charge_state(
ion_symbols=static_runtime_params_slice.main_ion_names,
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
Expand Down Expand Up @@ -308,7 +301,6 @@ def _prescribe_currents_no_bootstrap(
"""
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name

# Calculate splitting of currents depending on input runtime params.
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot
Expand Down Expand Up @@ -397,7 +389,6 @@ def _prescribe_currents_with_bootstrap(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_value(
Expand Down Expand Up @@ -483,7 +474,6 @@ def _calculate_currents_from_psi(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
core_profiles.psi,
Expand Down Expand Up @@ -574,7 +564,6 @@ def _update_psi_from_j(
return psi


# pylint: enable=invalid-name
def _calculate_psi_grad_constraint(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -689,12 +678,10 @@ def _init_psi_and_current(
geo,
currents.jtot_hires,
)
# pylint: disable=invalid-name
_, _, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
psi,
)
# pylint: enable=invalid-name
currents = dataclasses.replace(currents, Ip_profile_face=Ip_profile_face)
else:
raise ValueError('Cannot compute psi for given config.')
Expand All @@ -721,14 +708,19 @@ def initial_core_profiles(
Returns:
Initial core profiles.
"""
# pylint: disable=invalid-name

# To set initial values and compute the boundary conditions, we need to handle
# potentially time-varying inputs from the users.
# The default time in build_dynamic_runtime_params_slice is t_initial
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo)
temp_el = updated_electron_temperature(dynamic_runtime_params_slice, geo)
ne = _get_ne(dynamic_runtime_params_slice, geo)
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
temp_el = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
ne = _get_ne(dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo)

ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
static_runtime_params_slice,
Expand Down Expand Up @@ -795,15 +787,12 @@ def initial_core_profiles(
core_profiles = dataclasses.replace(core_profiles, psidot=psidot)

# Set psi as source of truth and recalculate jtot, q, s
core_profiles = physics.update_jtot_q_face_s_face(
return physics.update_jtot_q_face_s_face(
geo=geo,
core_profiles=core_profiles,
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
)

# pylint: enable=invalid-name
return core_profiles


def get_prescribed_core_profile_values(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
Expand All @@ -824,23 +813,22 @@ def get_prescribed_core_profile_values(
Returns:
Updated core profiles values on the cell grid.
"""
# pylint: disable=invalid-name

# If profiles are not evolved, they can still potential be time-evolving,
# depending on the runtime params. If so, they are updated below.
if (
not static_runtime_params_slice.ion_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo).value
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo).value
else:
temp_ion = core_profiles.temp_ion.value
if (
not static_runtime_params_slice.el_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_el_cell_variable = updated_electron_temperature(
dynamic_runtime_params_slice, geo
temp_el_cell_variable = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
temp_el = temp_el_cell_variable.value
else:
Expand All @@ -850,7 +838,10 @@ def get_prescribed_core_profile_values(
not static_runtime_params_slice.dens_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
ne_cell_variable = _get_ne(dynamic_runtime_params_slice, geo)
ne_cell_variable = _get_ne(
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo)
else:
ne_cell_variable = core_profiles.ne
ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
Expand Down Expand Up @@ -908,15 +899,13 @@ def get_update(x_new, var):
psi = get_update(x_new, 'psi')
ne = get_update(x_new, 'ne')

# pylint: disable=invalid-name
ni, nimp, Zi, Zi_face, Zimp, Zimp_face = get_ion_density_and_charge_states(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
ne,
temp_el,
)
# pylint: enable=invalid-name

return dataclasses.replace(
core_profiles,
Expand Down Expand Up @@ -950,24 +939,24 @@ def compute_boundary_conditions(
each CellVariable in the state. This dict can in theory recursively replace
values in a State object.
"""
Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Ti_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right,
'Ti_bound_right',
)

Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Te_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Te_bound_right,
'Te_bound_right',
)
# TODO(b/390143606): Separate out the boundary condition calculation from the
# core profile calculation.
ne = _get_ne(
dynamic_runtime_params_slice,
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo,
)
ne_bound_right = ne.right_face_constraint

# pylint: disable=invalid-name
Zi_edge = charge_states.get_average_charge_state(
static_runtime_params_slice.main_ion_names,
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
Expand All @@ -978,7 +967,6 @@ def compute_boundary_conditions(
ion_mixture=dynamic_runtime_params_slice.plasma_composition.impurity,
Te=Te_bound_right,
)
# pylint: disable=invalid-name

dilution_factor_edge = physics.get_main_ion_dilution_factor(
Zi_edge,
Expand Down Expand Up @@ -1026,7 +1014,6 @@ def compute_boundary_conditions(
}


# pylint: disable=invalid-name
def _get_jtot_hires(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -1064,6 +1051,3 @@ def _get_jtot_hires(
johm_hires = jformula_hires * Cohm_hires
jtot_hires = johm_hires + external_current_hires + j_bootstrap_hires
return jtot_hires


# pylint: enable=invalid-name
3 changes: 3 additions & 0 deletions torax/tests/arg_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_arg_order(self, module):
fields = inspect.getmembers(module)
print(module.__name__)
for name, obj in fields:
if name.startswith("_"):
# Ignore private fields and methods.
continue
if inspect.isfunction(obj):
print("\t", name)
params = inspect.signature(obj).parameters.keys()
Expand Down
Loading
Loading