Skip to content

Commit

Permalink
Merge pull request #55 from metno/54-cf-units-is-a-required-dependency
Browse files Browse the repository at this point in the history
54 cf units is a required dependency
  • Loading branch information
heikoklein authored Oct 30, 2024
2 parents b0e7c48 + da55f08 commit 7762d0f
Showing 1 changed file with 164 additions and 58 deletions.
222 changes: 164 additions & 58 deletions src/pyaro/timeseries/Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

logger = logging.getLogger(__name__)


class Filter(abc.ABC):
"""Base-class for all filters used from pyaro-Readers"""

Expand Down Expand Up @@ -830,14 +831,17 @@ def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: s
idx |= (minmax[0] <= data_resolution) & (data_resolution <= minmax[1])
return idx


@registered_filter
class AltitudeFilter(StationReductionFilter):
"""
Filter which filters stations based on their altitude. Can be used to filter for a
minimum and/or maximum altitude.
"""

def __init__(self, min_altitude: float | None = None, max_altitude: float | None = None):
def __init__(
self, min_altitude: float | None = None, max_altitude: float | None = None
):
"""
:param min_altitude : float of minimum altitude in meters required to keep the station (inclusive).
:param max_altitude : float of maximum altitude in meters required to keep the station (inclusive).
Expand All @@ -846,7 +850,9 @@ def __init__(self, min_altitude: float | None = None, max_altitude: float | None
"""
if min_altitude is not None and max_altitude is not None:
if min_altitude > max_altitude:
raise ValueError(f"min_altitude ({min_altitude}) > max_altitude ({max_altitude}).")
raise ValueError(
f"min_altitude ({min_altitude}) > max_altitude ({max_altitude})."
)

self._min_altitude = min_altitude
self._max_altitude = max_altitude
Expand All @@ -859,13 +865,28 @@ def name(self):

def filter_stations(self, stations: dict[str, Station]) -> dict[str, Station]:
if self._min_altitude is not None:
stations = {n: s for n, s in stations.items() if (not math.isnan(s["altitude"]) and s["altitude"] >= self._min_altitude) }

stations = {
n: s
for n, s in stations.items()
if (
not math.isnan(s["altitude"])
and s["altitude"] >= self._min_altitude
)
}

if self._max_altitude is not None:
stations = {n: s for n, s in stations.items() if (not math.isnan(s["altitude"]) and s["altitude"] <= self._max_altitude) }

stations = {
n: s
for n, s in stations.items()
if (
not math.isnan(s["altitude"])
and s["altitude"] <= self._max_altitude
)
}

return stations



@registered_filter
class RelativeAltitudeFilter(StationFilter):
"""
Expand All @@ -874,14 +895,30 @@ class RelativeAltitudeFilter(StationFilter):
https://github.com/metno/pyaro/issues/39
"""
UNITS_METER = Unit("m")

# https://cfconventions.org/Data/cf-conventions/cf-conventions-1.11/cf-conventions.html#latitude-coordinate
UNITS_LAT = set(["degrees_north", "degree_north", "degree_N", "degrees_N", "degreeN", "degreesN"])
UNITS_LAT = set(
[
"degrees_north",
"degree_north",
"degree_N",
"degrees_N",
"degreeN",
"degreesN",
]
)

# https://cfconventions.org/Data/cf-conventions/cf-conventions-1.11/cf-conventions.html#longitude-coordinate
UNITS_LON = set(["degrees_east", "degree_east", "degree_E", "degrees_E", "degreeE", "degreesE"])
UNITS_LON = set(
["degrees_east", "degree_east", "degree_E", "degrees_E", "degreeE", "degreesE"]
)

def __init__(self, topo_file: str | None = None, topo_var: str = "topography", rdiff: float = 0):
def __init__(
self,
topo_file: str | None = None,
topo_var: str = "topography",
rdiff: float = 0,
):
"""
:param topo_file : A .nc file from which to read gridded topography data.
:param topo_var : Name of variable that stores altitude.
Expand All @@ -895,104 +932,168 @@ def __init__(self, topo_file: str | None = None, topo_var: str = "topography", r
Note:
-----
This filter requires additional dependencies (xarray, netcdf4, cf-units) to function. These can be installed
with `pip install .[optional]
with `pip install .[optional]`.
"""
if "cf_units" not in sys.modules:
logger.warning("relaltitude filter is missing required dependency 'cf-units'. Please install to use this filter.")
logger.info(
"relaltitude filter is missing dependency 'cf-units'. Please install to use."
)
if "xarray" not in sys.modules:
logger.warning("relaltitude filter is missing required dependency 'xarray'. Please install to use this filter.")

logger.info(
"relaltitude filter is missing dependency 'xarray'. Please install to use."
)

self._topo_file = topo_file
self._topo_var = topo_var
self._rdiff = rdiff

# topography and unit-m property initialization
self._topography = None
if topo_file is not None:
self._topography = xr.open_dataset(topo_file)
self._convert_altitude_to_meters()
self._find_lat_lon_variables()
self._extract_bounding_box()
else:
logger.warning("No topography data provided (topo_file='%s'). Relative elevation filtering will not be applied.", topo_file)
self._UNITS_METER = None

@property
def UNITS_METER(self):
if self._UNITS_METER is None:
self._UNITS_METER = Unit("m")
return self._UNITS_METER

@property
def topography(self):
if "cf_units" not in sys.modules:
raise ModuleNotFoundError(
"relaltitude filter is missing required dependency 'cf-units'. Please install to use this filter."
)
if "xarray" not in sys.modules:
raise ModuleNotFoundError(
"relaltitude filter is missing required dependency 'xarray'. Please install to use this filter."
)

if self._topography is None:
if self._topo_file is None:
raise FilterException(
f"No topography data provided (topo_file='{self._topo_file}'). Relative elevation filtering will not be applied."
)
else:
try:
with xr.open_dataset(self._topo_file) as topo:
self._topography = self._convert_altitude_to_meters(topo)
lat, lon = self._find_lat_lon_variables(topo)
self._extract_bounding_box(lat, lon)
except Exception as ex:
raise FilterException(
f"Cannot read topography from '{self._topo_file}:{self._topo_var}' : {ex}"
)
return self._topography

def _convert_altitude_to_meters(self):
def _convert_altitude_to_meters(self, topo_xr):
"""
Method which attempts to convert the altitude variable in the gridded topography data
to meters.
:param topo_xr xarray dataset containing topo
:raises TypeError
If conversion isn't possible.
:return xr.DataArray
"""
# Convert altitude to meters
units = Unit(self._topography[self._topo_var].units)
units = Unit(topo_xr[self._topo_var].units)
if units.is_convertible(self.UNITS_METER):
self._topography[self._topo_var].values = self.UNITS_METER.convert(self._topography[self._topo_var].values, self.UNITS_METER)
self._topography[self._topo_var]["units"] = str(self.UNITS_METER)
topography = topo_xr[self._topo_var]
topography.values = self.UNITS_METER.convert(
topography.values, self.UNITS_METER
)
topography["units"] = str(self.UNITS_METER)
else:
raise TypeError(f"Expected altitude units to be convertible to 'm', got '{units}'")

def _find_lat_lon_variables(self):
raise TypeError(
f"Expected altitude units to be convertible to 'm', got '{units}'"
)
return topography

def _find_lat_lon_variables(self, topo_xr):
"""
Determines the names of variables which represent the latitude and longitude
Find and load DataArrays from topo which represent the latitude and longitude
dimensions in the topography data.
These are assigned to self._lat, self._lon, respectively for later use.
These are assigned to self._lat, self._lon, respectively for later use.
:param topo_xr xr.Dataset of topography
:return lat, lon DataArrays
"""
for var_name in self._topography.coords:
unit_str = self._topography[var_name].attrs.get("units", None)
if unit_str in self.UNITS_LAT:
self._lat = var_name
lat = topo_xr[var_name]
continue
if unit_str in self.UNITS_LON:
self._lon = var_name
lon = topo_xr[var_name]
continue

if any(x is None for x in [self._lat, self._lon]):
raise ValueError(f"Required variable names for lat, lon dimensions could not be found in file '{self._topo_file}")

def _extract_bounding_box(self):

if any(x is None for x in [lat, lon]):
raise ValueError(
f"Required variable names for lat, lon dimensions could not be found in file '{self._topo_file}"
)
return lat, lon

def _extract_bounding_box(self, lat, lon):
"""
Extract the bounding box of the grid.
Extract the bounding box of the grid and set self._boundary_(north|east|south|west).
:param lat latitude (DataArray)
:param lon longitude (DataArray)
"""
self._boundary_west = float(self._topography[self._lon].min())
self._boundary_east = float(self._topography[self._lon].max())
self._boundary_south = float(self._topography[self._lat].min())
self._boundary_north = float(self._topography[self._lat].max())
logger.info("Bounding box (NESW): %.2f, %.2f, %.2f, %.2f", self._boundary_north, self._boundary_east, self._boundary_south, self._boundary_west)
self._boundary_west = float(lon.min())
self._boundary_east = float(lon.max())
self._boundary_south = float(lat.min())
self._boundary_north = float(lat.max())
logger.info(
"Bounding box (NESW) of topography: %.2f, %.2f, %.2f, %.2f",
self._boundary_north,
self._boundary_east,
self._boundary_south,
self._boundary_west,
)

def _gridded_altitude_from_lat_lon(self, lat: np.ndarray, lon: np.ndarray) -> np.ndarray:
altitude = self._topography.sel({"lat": xr.DataArray(lat, dims="latlon"), "lon": xr.DataArray(lon, dims="latlon")}, method="nearest")
def _gridded_altitude_from_lat_lon(
self, lat: np.ndarray, lon: np.ndarray
) -> np.ndarray:
altitude = self.topography.sel(
{
"lat": xr.DataArray(lat, dims="latlon"),
"lon": xr.DataArray(lon, dims="latlon"),
},
method="nearest",
)

return altitude[self._topo_var].values[0]
return altitude.values[0]

def _is_close(self, alt_gridded: np.ndarray, alt_station: np.ndarray) -> np.ndarray[bool]:
def _is_close(
self, alt_gridded: np.ndarray, alt_station: np.ndarray
) -> np.ndarray[bool]:
"""
Function to check if two altitudes are within a relative tolerance of each
other.
:param alt_gridded : Gridded altitude (in meters).
:param alt_station : Observation / station altitude (in meters).
:returns :
True if the absolute difference between alt_gridded and alt_station is
<= self._rdiff
"""
return np.abs(alt_gridded-alt_station) <= self._rdiff
return np.abs(alt_gridded - alt_station) <= self._rdiff

def init_kwargs(self):
return {
"topo_file": self._topo_file,
"topo_var": self._topo_var,
"rdiff": self._rdiff
"rdiff": self._rdiff,
}

def name(self):
return "relaltitude"

def filter_stations(self, stations: dict[str, Station]) -> dict[str, Station]:
if self._topography is None:
if self.topography is None:
return stations

names = np.ndarray(len(stations), dtype=np.dtypes.StrDType)
lats = np.ndarray(len(stations), dtype=np.float64)
lons = np.ndarray(len(stations), dtype=np.float64)
Expand All @@ -1005,9 +1106,14 @@ def filter_stations(self, stations: dict[str, Station]) -> dict[str, Station]:
lons[i] = station["longitude"]
alts[i] = station["altitude"]

out_of_bounds_mask = np.logical_or(np.logical_or(lons < self._boundary_west, lons > self._boundary_east), np.logical_or(lats < self._boundary_south, lats > self._boundary_north))
out_of_bounds_mask = np.logical_or(
np.logical_or(lons < self._boundary_west, lons > self._boundary_east),
np.logical_or(lats < self._boundary_south, lats > self._boundary_north),
)
if np.sum(out_of_bounds_mask) > 0:
logger.warning("Some stations were removed due to being out of bounds of the gridded topography")
logger.warning(
"Some stations were removed due to being out of bounds of the gridded topography"
)

topo = self._gridded_altitude_from_lat_lon(lats, lons)

Expand All @@ -1017,4 +1123,4 @@ def filter_stations(self, stations: dict[str, Station]) -> dict[str, Station]:

selected_names = names[mask]

return {name: stations[name] for name in selected_names}
return {name: stations[name] for name in selected_names}

0 comments on commit 7762d0f

Please sign in to comment.