Skip to content

Commit

Permalink
Alper Review Comments Pt1
Browse files Browse the repository at this point in the history
  • Loading branch information
manishvenu committed Jan 10, 2025
1 parent 29318a0 commit f4dfef7
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 79 deletions.
111 changes: 42 additions & 69 deletions regional_mom6/regional_mom6.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@
from collections import defaultdict
import json
import copy
from . import regridding as rgd
from . import rotation as rot
from .utils import quadrilateral_areas, ap2ep, ep2ap, is_rectilinear_hgrid
from regional_mom6 import regridding as rgd
from regional_mom6 import rotation as rot
from regional_mom6.utils import (
quadrilateral_areas,
ap2ep,
ep2ap,
is_rectilinear_hgrid,
rotate,
)


warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -934,13 +940,21 @@ def _make_vgrid(self, thicknesses=None):
of largest to smallest layer thickness (``layer_thickness_ratio``) and the
total ``depth`` parameters.
(All these parameters are specified at the class level.)
Args:
thicknesses (Optional[np.ndarray]): An array of layer thicknesses. If not provided,
the layer thicknesses are generated using the :func:`~hyperbolictan_thickness_profile`
function.
"""

if thicknesses is None:
thicknesses = hyperbolictan_thickness_profile(
self.number_vertical_layers, self.layer_thickness_ratio, self.depth
)

if not isinstance(thicknesses, np.ndarray):
raise ValueError("thicknesses must be a numpy array")

zi = np.cumsum(thicknesses)
zi = np.insert(zi, 0, 0.0) # add zi = 0.0 as first interface

Expand Down Expand Up @@ -1311,9 +1325,9 @@ def setup_initial_condition(
.ffill("lat")
.bfill("lat")
)
renamed_hgrid = self.hgrid # This is not a deep copy
renamed_hgrid["lon"] = renamed_hgrid["x"]
renamed_hgrid["lat"] = renamed_hgrid["y"]

self.hgrid["lon"] = self.hgrid["x"]
self.hgrid["lat"] = self.hgrid["y"]
tgrid = (
rgd.get_hgrid_arakawa_c_points(self.hgrid, "t")
.rename({"tlon": "lon", "tlat": "lat", "nxp": "nx", "nyp": "ny"})
Expand All @@ -1323,10 +1337,10 @@ def setup_initial_condition(
## Make our three horizontal regridders

regridder_u = rgd.create_regridder(
ic_raw_u, renamed_hgrid, locstream_out=False, method="bilinear"
ic_raw_u, self.hgrid, locstream_out=False, method="bilinear"
)
regridder_v = rgd.create_regridder(
ic_raw_v, renamed_hgrid, locstream_out=False, method="bilinear"
ic_raw_v, self.hgrid, locstream_out=False, method="bilinear"
)
regridder_t = rgd.create_regridder(
ic_raw_tracers, tgrid, locstream_out=False, method="bilinear"
Expand All @@ -1344,8 +1358,7 @@ def setup_initial_condition(
regridded_u = regridder_u(ic_raw_u)
regridded_v = regridder_v(ic_raw_v)
if rotational_method == rot.RotationMethod.GIVEN_ANGLE:
rotated_u, rotated_v = segment.rotate(
None,
rotated_u, rotated_v = rotate(
regridded_u,
regridded_v,
radian_angle=np.radians(self.hgrid.angle_dx.values),
Expand All @@ -1354,7 +1367,7 @@ def setup_initial_condition(
self.hgrid["angle_dx_rm6"] = (
rot.initialize_grid_rotation_angles_using_expanded_hgrid(self.hgrid)
)
rotated_u, rotated_v = segment.rotate(
rotated_u, rotated_v = rotate(
regridded_u,
regridded_v,
radian_angle=np.radians(self.hgrid.angle_dx_rm6.values),
Expand Down Expand Up @@ -1601,13 +1614,13 @@ def setup_ocean_state_boundaries(
Default is `["south", "north", "west", "east"]`.
arakawa_grid (Optional[str]): Arakawa grid staggering type of the boundary forcing.
Either ``'A'`` (default), ``'B'``, or ``'C'``.
boundary_type (Optional[str]): Type of box around region. Currently, only ``'rectangular'`` is supported.
boundary_type (Optional[str]): Type of box around region. Currently, only ``'rectangular'`` or ``'curvilinear'`` is supported.
bathymetry_path (Optional[str]): Path to the bathymetry file. Default is None, in which case the BC is not masked.
rotational_method (Optional[str]): Method to use for rotating the boundary velocities. Default is 'GIVEN_ANGLE'.
"""
if boundary_type != "rectangular":
if not (boundary_type == "rectangular" or boundary_type == "curvilinear"):
raise ValueError(
"Only rectangular boundaries are supported by this method. To set up more complex boundary shapes you can manually call the 'simple_boundary' method for each boundary."
"Only rectangular or curvilinear boundaries are supported by this method. To set up more complex boundary shapes you can manually call the 'simple_boundary' method for each boundary."
)
for i in self.boundaries:
if i not in ["south", "north", "west", "east"]:
Expand Down Expand Up @@ -1650,7 +1663,6 @@ def setup_single_boundary(
orientation,
segment_number,
arakawa_grid="A",
boundary_type="simple",
bathymetry_path=None,
rotational_method=rot.RotationMethod.GIVEN_ANGLE,
):
Expand All @@ -1671,18 +1683,17 @@ def setup_single_boundary(
the ``MOM_input``.
arakawa_grid (Optional[str]): Arakawa grid staggering type of the boundary forcing.
Either ``'A'`` (default), ``'B'``, or ``'C'``.
boundary_type (Optional[str]): Type of boundary. Currently, only ``'simple'`` is supported. Here 'simple' refers to boundaries that are parallel to lines of constant longitude or latitude.
bathymetry_path (Optional[str]): Path to the bathymetry file. Default is None, in which case the BC is not masked
rotational_method (Optional[str]): Method to use for rotating the boundary velocities. Default is 'GIVEN_ANGLE'.
"""

print("Processing {} boundary...".format(orientation), end="")
print(
"Processing {} boundary velocity & tracers...".format(orientation), end=""
)
if not path_to_bc.exists():
raise FileNotFoundError(
f"Boundary file not found at {path_to_bc}. Please ensure that the files are named in the format `east_unprocessed.nc`."
)
if boundary_type != "simple":
raise ValueError("Only simple boundaries are supported by this method.")
self.segments[orientation] = segment(
hgrid=self.hgrid,
bathymetry_path=bathymetry_path,
Expand All @@ -1708,7 +1719,7 @@ def setup_boundary_tides(
tpxo_elevation_filepath,
tpxo_velocity_filepath,
tidal_constituents="read_from_expt_init",
boundary_type="rectangle",
boundary_type="rectangular",
bathymetry_path=None,
rotational_method=rot.RotationMethod.GIVEN_ANGLE,
):
Expand All @@ -1717,9 +1728,10 @@ def setup_boundary_tides(
Args:
path_to_td (str): Path to boundary tidal file.
tidal_filename: Name of the tpxo product that's used in the tidal_filename. Should be h_tidal_filename, u_tidal_filename
tpxo_elevation_filepath: Filepath to the TPXO elevation product. Generally of the form h_tidalversion.nc
tpxo_velocity_filepath: Filepath to the TPXO velocity product. Generally of the form u_tidalversion.nc
tidal_constituents: List of tidal constituents to include in the regridding. Default is [0] which is the M2 constituent.
boundary_type (str): Type of boundary. Currently, only rectangle is supported. Here rectangle refers to boundaries that are parallel to lines of constant longitude or latitude.
boundary_type (str): Type of boundary. Currently, only rectangle is supported. Here, rectangle refers to boundaries that are parallel to lines of constant longitude or latitude. Curvilinear is also suported.
bathymetry_path (str): Path to the bathymetry file. Default is None, in which case the BC is not masked
rotational_method (str): Method to use for rotating the tidal velocities. Default is 'GIVEN_ANGLE'.
Returns:
Expand All @@ -1740,9 +1752,9 @@ def setup_boundary_tides(
Type: Python Functions, Source Code
Web Address: https://github.com/jsimkins2/nwa25
"""
if boundary_type != "rectangle" and boundary_type != "curvilinear":
if not (boundary_type == "rectangular" or boundary_type == "curvilinear"):
raise ValueError(
"Only rectangular boundaries are supported by this method."
"Only rectangular or curvilinear boundaries are supported by this method."
)
if tidal_constituents != "read_from_expt_init":
self.tidal_constituents = tidal_constituents
Expand Down Expand Up @@ -2999,45 +3011,6 @@ def __init__(
self.segment_name = segment_name
self.repeat_year_forcing = repeat_year_forcing

def rotate_complex(self, u, v, radian_angle):
"""
Rotate velocities to grid orientation using complex number math (Same as rotate)
Args:
u (xarray.DataArray): The u-component of the velocity.
v (xarray.DataArray): The v-component of the velocity.
radian_angle (xarray.DataArray): The angle of the grid in RADIANS
Returns:
Tuple[xarray.DataArray, xarray.DataArray]: The rotated u and v components of the velocity.
"""

# express velocity in the complex plan
vel = u + v * 1j
# rotate velocity using grid angle theta
vel = vel * np.exp(1j * radian_angle)

# From here you can easily get the rotated u, v, or the magnitude/direction of the currents:
u = np.real(vel)
v = np.imag(vel)

return u, v

def rotate(self, u, v, radian_angle):
"""
Rotate the velocities to the grid orientation.
Args:
u (xarray.DataArray): The u-component of the velocity.
v (xarray.DataArray): The v-component of the velocity.
radian_angle (xarray.DataArray): The angle of the grid in RADIANS
Returns:
Tuple[xarray.DataArray, xarray.DataArray]: The rotated u and v components of the velocity.
"""

u_rot = u * np.cos(radian_angle) - v * np.sin(radian_angle)
v_rot = u * np.sin(radian_angle) + v * np.cos(radian_angle)
return u_rot, v_rot

def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANGLE):
"""
Cut out and interpolate the velocities and tracers
Expand Down Expand Up @@ -3071,7 +3044,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG

## Angle Calculation & Rotation
if rotational_method == rot.RotationMethod.GIVEN_ANGLE:
rotated_u, rotated_v = self.rotate(
rotated_u, rotated_v = rotate(
regridded[self.u],
regridded[self.v],
radian_angle=np.radians(coords.angle.values),
Expand All @@ -3092,7 +3065,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG
)["angle"]

# Rotate
rotated_u, rotated_v = self.rotate(
rotated_u, rotated_v = rotate(
regridded[self.u],
regridded[self.v],
radian_angle=np.radians(degree_angle.values),
Expand Down Expand Up @@ -3130,7 +3103,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG

# See explanation of the rotational methods in the A grid section
if rotational_method == rot.RotationMethod.GIVEN_ANGLE:
velocities_out["u"], velocities_out["v"] = self.rotate(
velocities_out["u"], velocities_out["v"] = rotate(
velocities_out["u"],
velocities_out["v"],
radian_angle=np.radians(coords.angle.values),
Expand All @@ -3145,7 +3118,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG
self.segment_name,
angle_variable_name="angle_dx_rm6",
)["angle"]
velocities_out["u"], velocities_out["v"] = self.rotate(
velocities_out["u"], velocities_out["v"] = rotate(
velocities_out["u"],
velocities_out["v"],
radian_angle=np.radians(degree_angle.values),
Expand Down Expand Up @@ -3195,7 +3168,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG

# See explanation of the rotational methods in the A grid section
if rotational_method == rot.RotationMethod.GIVEN_ANGLE:
rotated_u, rotated_v = self.rotate(
rotated_u, rotated_v = rotate(
regridded_u,
regridded_v,
radian_angle=np.radians(coords.angle.values),
Expand All @@ -3210,7 +3183,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG
self.segment_name,
angle_variable_name="angle_dx_rm6",
)["angle"]
rotated_u, rotated_v = self.rotate(
rotated_u, rotated_v = rotate(
regridded_u,
regridded_v,
radian_angle=np.radians(degree_angle.values),
Expand Down
67 changes: 57 additions & 10 deletions regional_mom6/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,62 @@ def setup_logger(
return logger


def is_rectilinear_hgrid(hgrid: xr.Dataset) -> bool:
def rotate_complex(u, v, radian_angle):
"""
Check if the hgrid is a rectilinear grid.
Rotate velocities to grid orientation using complex number math (Same as rotate)
Args:
u (xarray.DataArray): The u-component of the velocity.
v (xarray.DataArray): The v-component of the velocity.
radian_angle (xarray.DataArray): The angle of the grid in RADIANS
Returns:
Tuple[xarray.DataArray, xarray.DataArray]: The rotated u and v components of the velocity.
"""

# express velocity in the complex plan
vel = u + v * 1j
# rotate velocity using grid angle theta
vel = vel * np.exp(1j * radian_angle)

# From here you can easily get the rotated u, v, or the magnitude/direction of the currents:
u = np.real(vel)
v = np.imag(vel)

return u, v


def rotate(u, v, radian_angle):
"""
Rotate the velocities to the grid orientation.
Args:
u (xarray.DataArray): The u-component of the velocity.
v (xarray.DataArray): The v-component of the velocity.
radian_angle (xarray.DataArray): The angle of the grid in RADIANS
Returns:
Tuple[xarray.DataArray, xarray.DataArray]: The rotated u and v components of the velocity.
"""

u_rot = u * np.cos(radian_angle) - v * np.sin(radian_angle)
v_rot = u * np.sin(radian_angle) + v * np.cos(radian_angle)
return u_rot, v_rot


def is_rectilinear_hgrid(hgrid: xr.Dataset, rtol: float = 1e-3) -> bool:
"""
if hgrid.x.shape[0] < 2 or hgrid.x.shape[1] < 2:
raise ValueError("hgrid must have at least 2 points in each direction")
if not np.all(hgrid.y == hgrid.y[:, 0].values[:, np.newaxis]):
return False
if not np.all(hgrid.x == hgrid.x[0, :].values[np.newaxis, :]):
return False

return True
Check if the hgrid is a rectilinear grid. From mom6_bathy.grid.is_rectangular by Alper (Altuntas
)
Check if the grid is a rectangular lat-lon grid by comparing the first and last rows and columns of the tlon and tlat arrays.
Args:
hgrid (xarray.Dataset): The horizontal grid dataset.
rtol (float): Relative tolerance. Default is 1e-3.
"""
if (
np.allclose(hgrid.tlon[:, 0], hgrid.tlon[0, 0], rtol=rtol)
and np.allclose(hgrid.tlon[:, -1], hgrid.tlon[0, -1], rtol=rtol)
and np.allclose(hgrid.tlat[0, :], hgrid.tlat[0, 0], rtol=rtol)
and np.allclose(hgrid.tlat[-1, :], hgrid.tlat[-1, 0], rtol=rtol)
):
return True
return False

0 comments on commit f4dfef7

Please sign in to comment.