Skip to content

Commit

Permalink
Fixes K->C conversion in initial_condition (#176)
Browse files Browse the repository at this point in the history
* numpy.min -> numpy.nanmin to avoid issues with masked input

* add tests for K->C conversion

* don't hardcode mask

* format black

* add clarifying comments

* add another test for temperatures

* format black

* convert K->C before removing mask

* more robust K -> C

* Update inaccurate comment in initial condition generation

* black

* update temp units when they're converted

---------

Co-authored-by: ashjbarnes <ashjbarnes97@gmail.com>
  • Loading branch information
navidcy and ashjbarnes authored Jun 11, 2024
1 parent a0c1fb4 commit dcca341
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 71 deletions.
34 changes: 23 additions & 11 deletions regional_mom6/regional_mom6.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,12 @@ def initial_condition(
"Error in reading in initial condition tracers. Terminating!"
)

## if min(temperature) > 100 then assume that units must be degrees K
## (otherwise we can't be on Earth) and convert to degrees C
if np.nanmin(ic_raw[varnames["tracers"]["temp"]]) > 100:
ic_raw[varnames["tracers"]["temp"]] -= 273.15
ic_raw[varnames["tracers"]["temp"]].attrs["units"] = "degrees Celsius"

# Rename all coordinates to have 'lon' and 'lat' to work with the xesmf regridder
if arakawa_grid == "A":
if (
Expand All @@ -641,6 +647,7 @@ def initial_condition(
+ "in the varnames dictionary. For example, {'x': 'lon', 'y': 'lat'}.\n\n"
+ "Terminating!"
)

if arakawa_grid == "B":
if (
"xq" in varnames.keys()
Expand Down Expand Up @@ -695,6 +702,7 @@ def initial_condition(
+ "in the varnames dictionary. For example, {'xh': 'lonh', 'yh': 'lath', ...}.\n\n"
+ "Terminating!"
)

## Construct the xq, yh and xh, yq grids
ugrid = (
self.hgrid[["x", "y"]]
Expand Down Expand Up @@ -723,9 +731,10 @@ def initial_condition(
}
)

### Drop NaNs to be re-added later
# NaNs might be here from the land mask of the model that the IC has come from.
# If they're not removed then the coastlines from this other grid will be retained!
# The land mask comes from the bathymetry file, so we don't need NaNs
# to tell MOM6 where the land is.
ic_raw_tracers = (
ic_raw_tracers.interpolate_na("lon", method="linear")
.ffill("lon")
Expand Down Expand Up @@ -780,8 +789,11 @@ def initial_condition(
)

print("INITIAL CONDITIONS")

## Regrid all fields horizontally.
print("Regridding Velocities...", end="")

print("Regridding Velocities... ", end="")

vel_out = xr.merge(
[
regridder_u(ic_raw_u)
Expand All @@ -792,18 +804,22 @@ def initial_condition(
.rename("v"),
]
)
print("Done.\nRegridding Tracers...")

print("Done.\nRegridding Tracers... ", end="")

tracers_out = xr.merge(
[
regridder_t(ic_raw_tracers[varnames["tracers"][i]]).rename(i)
for i in varnames["tracers"]
]
).rename({"lon": "xh", "lat": "yh", varnames["zl"]: "zl"})
print("Done.\nRegridding Free surface...")

print("Done.\nRegridding Free surface... ", end="")

eta_out = (
regridder_t(ic_raw_eta).rename({"lon": "xh", "lat": "yh"}).rename("eta_t")
) ## eta_t is the name set in MOM_input by default
print("Done.")

## Return attributes to arrays

Expand All @@ -825,11 +841,6 @@ def initial_condition(
eta_out.yh.attrs = ic_raw_tracers.lat.attrs
eta_out.attrs = ic_raw_eta.attrs

## if min(temp) > 100 then assume that units must be degrees K
## (otherwise we can't be on Earth) and convert to degrees C
if np.min(tracers_out["temp"].isel({"zl": 0})) > 100:
tracers_out["temp"] -= 273.15

## Regrid the fields vertically

if (
Expand Down Expand Up @@ -874,7 +885,7 @@ def initial_condition(
"eta_t": {"_FillValue": None},
},
)
print("done setting up initial condition.")
print("Done.\nFinished setting up initial condition.")

self.ic_eta = eta_out
self.ic_tracers = tracers_out
Expand Down Expand Up @@ -1910,10 +1921,11 @@ def rectangular_brushcut(self):
del segment_out["lat"]
## Convert temperatures to celsius # use pint
if (
np.min(segment_out[self.tracers["temp"]].isel({self.time: 0, self.z: 0}))
np.nanmin(segment_out[self.tracers["temp"]].isel({self.time: 0, self.z: 0}))
> 100
):
segment_out[self.tracers["temp"]] -= 273.15
segment_out[self.tracers["temp"]].attrs["units"] = "degrees Celsius"

# fill in NaNs
segment_out = (
Expand Down
193 changes: 133 additions & 60 deletions tests/test_expt_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,97 @@ def test_setup_bathymetry(
bathymetry_file.unlink()


def number_of_gridpoints(longitude_extent, latitude_extent, resolution):
nx = int((longitude_extent[-1] - longitude_extent[0]) / resolution)
ny = int((latitude_extent[-1] - latitude_extent[0]) / resolution)

return nx, ny


def generate_temperature_arrays(nx, ny, number_vertical_layers):

# temperatures close to 0 ᵒC
temp_in_C = np.random.randn(ny, nx, number_vertical_layers)

temp_in_C_masked = np.copy(temp_in_C)
if int(ny / 4 + 4) < ny - 1 and int(nx / 3 + 4) < nx + 1:
temp_in_C_masked[
int(ny / 3) : int(ny / 3 + 5), int(nx) : int(nx / 4 + 4), :
] = float("nan")
else:
raise ValueError("use bigger domain")

temp_in_K = np.copy(temp_in_C) + 273.15
temp_in_K_masked = np.copy(temp_in_C_masked) + 273.15

# ensure we didn't mask the minimum temperature
if np.nanmin(temp_in_C_masked) == np.min(temp_in_C):
return temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked
else:
return generate_temperature_arrays(nx, ny, number_vertical_layers)


def generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
):
nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

horizontal_buffer = 5

silly_lat = np.linspace(
latitude_extent[0] - horizontal_buffer,
latitude_extent[1] + horizontal_buffer,
ny,
)
silly_lon = np.linspace(
longitude_extent[0] - horizontal_buffer,
longitude_extent[1] + horizontal_buffer,
nx,
)
silly_depth = np.linspace(0, depth, number_vertical_layers)

return silly_lat, silly_lon, silly_depth


longitude_extent = [-5, 3]
latitude_extent = (0, 10)
date_range = ["2003-01-01 00:00:00", "2003-01-01 00:00:00"]
resolution = 0.1
number_vertical_layers = 5
layer_thickness_ratio = 1
depth = 1000

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {"silly_lat": silly_lat, "silly_lon": silly_lon, "silly_depth": silly_depth}

mom_run_dir = "rundir/"
mom_input_dir = "inputdir/"
toolpath_dir = "toolpath"
grid_type = "even_spacing"

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked = generate_temperature_arrays(
nx, ny, number_vertical_layers
)

temp_C = xr.DataArray(temp_in_C, dims=dims, coords=coords)
temp_K = xr.DataArray(temp_in_K, dims=dims, coords=coords)
temp_C_masked = xr.DataArray(temp_in_C_masked, dims=dims, coords=coords)
temp_K_masked = xr.DataArray(temp_in_K_masked, dims=dims, coords=coords)

maximum_temperature_in_C = np.max(temp_in_C)


@pytest.mark.parametrize(
"temp_dataarray_initial_condition",
[temp_C, temp_C_masked, temp_K, temp_K_masked],
)
@pytest.mark.parametrize(
(
"longitude_extent",
Expand All @@ -116,13 +207,13 @@ def test_setup_bathymetry(
),
[
(
[-5, 5],
(0, 10),
["2003-01-01 00:00:00", "2003-01-01 00:00:00"],
0.1,
5,
1,
1000,
longitude_extent,
latitude_extent,
date_range,
resolution,
number_vertical_layers,
layer_thickness_ratio,
depth,
"rundir/",
"inputdir/",
"toolpath",
Expand All @@ -142,8 +233,22 @@ def test_ocean_forcing(
mom_input_dir,
toolpath_dir,
grid_type,
temp_dataarray_initial_condition,
tmp_path,
):

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {
"silly_lat": silly_lat,
"silly_lon": silly_lon,
"silly_depth": silly_depth,
}

expt = experiment(
longitude_extent=longitude_extent,
latitude_extent=latitude_extent,
Expand All @@ -160,72 +265,34 @@ def test_ocean_forcing(

## Generate some initial condition to test on

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

# initial condition includes, temp, salt, eta, u, v
initial_cond = xr.Dataset(
{
"temp": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
),
"eta": xr.DataArray(
np.random.random((100, 100)),
np.random.random((ny, nx)),
dims=["silly_lat", "silly_lon"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_lat": silly_lat,
"silly_lon": silly_lon,
},
),
"temp": temp_dataarray_initial_condition,
"salt": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"u": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"v": xr.DataArray(
np.random.random((100, 100, 10)),
dims=["silly_lat", "silly_lon", "silly_depth"],
coords={
"silly_lat": np.linspace(
latitude_extent[0] - 5, latitude_extent[1] + 5, 100
),
"silly_lon": np.linspace(
longitude_extent[0] - 5, longitude_extent[1] + 5, 100
),
"silly_depth": np.linspace(0, 1000, 10),
},
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
}
)
Expand All @@ -251,6 +318,12 @@ def test_ocean_forcing(
arakawa_grid="A",
)

# ensure that temperature is in degrees C
assert np.nanmin(expt.ic_tracers["temp"]) < 100.0

# max(temp) can be less maximum_temperature_in_C due to re-gridding
assert np.nanmax(expt.ic_tracers["temp"]) <= maximum_temperature_in_C


@pytest.mark.parametrize(
(
Expand Down

0 comments on commit dcca341

Please sign in to comment.