Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
Support for cutout and constants
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 16, 2024
1 parent 75196ec commit 2fb1fc0
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 9 deletions.
14 changes: 14 additions & 0 deletions ecml_tools/create/functions/actions/empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import climetlab as cml


def execute(context, dates, **kwargs):
return cml.load_source("empty")
16 changes: 16 additions & 0 deletions ecml_tools/create/functions/steps/empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import climetlab as cml


def execute(context, input, **kwargs):
# Usefull to create a pipeline that returns an empty result
# So we can reference an earlier step in a function like 'contants'
return cml.load_source("empty")
19 changes: 18 additions & 1 deletion ecml_tools/create/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from copy import deepcopy
from functools import cached_property

import numpy as np
from climetlab.core.order import build_remapping

from .group import build_groups
Expand Down Expand Up @@ -53,6 +54,8 @@ def _datasource_request(data):
params_levels = defaultdict(set)
params_steps = defaultdict(set)

area = grid = None

for field in data:
if not hasattr(field, "as_mars"):
continue
Expand Down Expand Up @@ -133,6 +136,20 @@ def _build_coords(self):

first_field = self.owner.datasource[0]
grid_points = first_field.grid_points()

lats, lons = grid_points
north = np.amax(lats)
south = np.amin(lats)
east = np.amax(lons)
west = np.amin(lons)

assert -90 <= south <= north <= 90, (south, north, first_field)
assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), (
west,
east,
first_field,
)

grid_values = list(range(len(grid_points[0])))

self._grid_points = grid_points
Expand Down Expand Up @@ -514,7 +531,7 @@ def __init__(self, context, path, *configs):
def select(self, dates):
print("🚀", self.path, f"PipeAction.select({dates}, {self.content})")
result = self.content.select(dates)
print("🍎", self.path, f"PipeAction.result", result)
print("🍎", self.path, "PipeAction.result", result)
return result

def __repr__(self):
Expand Down
90 changes: 87 additions & 3 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,12 @@ def __init__(self, datasets, axis):
# Shape: (dates, variables, ensemble, 1d-values)
assert len(datasets[0].shape) == 4, "Grids must be 1D for now"

def check_same_grid(self, d1, d2):
# We don't check the grid, because we want to be able to combine
pass


class ConcatGrids(Grids):
# TODO: select the statistics of the most global grid?
@property
def latitudes(self):
Expand All @@ -793,10 +799,75 @@ def latitudes(self):
def longitudes(self):
return np.concatenate([d.longitudes for d in self.datasets])

def check_same_grid(self, d1, d2):
# We don't check the grid, because we want to be able to combine

class CutoutGrids(Grids):
def __init__(self, datasets, axis):
from .grids import cutout_mask

super().__init__(datasets, axis)
assert len(datasets) == 2, "CutoutGrids requires two datasets"
assert axis == 3, "CutoutGrids requires axis=3"

# We assume that the LAM is the first dataset, and the global is the second
# Note: the second fields does not really need to be global

self.lam, self.globe = datasets
self.mask = cutout_mask(
self.lam.latitudes,
self.lam.longitudes,
self.globe.latitudes,
self.globe.longitudes,
plot="cutout",
)
assert len(self.mask) == self.globe.shape[3], (
len(self.mask),
self.globe.shape[3],
)

@cached_property
def shape(self):
shape = self.lam.shape
# Number of non-zero masked values in the globe dataset
nb_globe = np.count_nonzero(self.mask)
return shape[:-1] + (shape[-1] + nb_globe,)

def check_same_resolution(self, d1, d2):
# Turned off because we are combining different resolutions
pass

@property
def latitudes(self):
return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]])

@property
def longitudes(self):
return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]])

def __getitem__(self, index):
if isinstance(index, (int, slice)):
index = (index, slice(None), slice(None), slice(None))
return self._get_tuple(index)

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
assert index[self.axis] == slice(
None
), "No support for selecting a subset of the 1D values"
index, changes = index_to_slices(index, self.shape)

# In case index_to_slices has changed the last slice
index, _ = update_tuple(index, self.axis, slice(None))

lam_data = self.lam[index]
globe_data = self.globe[index]

globe_data = globe_data[:, :, :, self.mask]

result = np.concatenate([lam_data, globe_data], axis=self.axis)

return apply_index_to_slices_changes(result, changes)


class Join(Combined):
"""
Expand Down Expand Up @@ -1259,10 +1330,23 @@ def _open_dataset(*args, zarr_root, **kwargs):
if "ensemble" in kwargs:
raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")
grids = kwargs.pop("grids")
mode = kwargs.pop("mode", "concatenate")
axis = kwargs.pop("axis", 3)
assert len(args) == 0
assert isinstance(grids, (list, tuple))
return Grids([_open(e, zarr_root) for e in grids], axis=axis)._subset(**kwargs)

KLASSES = {
"concatenate": ConcatGrids,
"cutout": CutoutGrids,
}
if mode not in KLASSES:
raise ValueError(
f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}"
)

return KLASSES[mode]([_open(e, zarr_root) for e in grids], axis=axis)._subset(
**kwargs
)

for name in ("datasets", "dataset"):
if name in kwargs:
Expand Down
33 changes: 28 additions & 5 deletions ecml_tools/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
from scipy.spatial import KDTree


def plot_mask(path, mask, lats, lons, global_lats, global_lons):
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 5))
plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
plt.savefig(path + "-global.png")

fig = plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
plt.savefig(path + "-cutout.png")

fig = plt.figure(figsize=(10, 5))
plt.scatter(lons, lats, s=0.01)
plt.savefig(path + "-lam.png")
# plt.scatter(lons, lats, s=0.01)


def latlon_to_xyz(lat, lon, radius=1.0):
# https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
# We assume that the Earth is a sphere of radius 1 so N(phi) = 1
Expand Down Expand Up @@ -84,12 +101,13 @@ def cropping_mask(lats, lons, north, west, south, east):


def cutout_mask(
global_lats,
global_lons,
lats,
lons,
global_lats,
global_lons,
cropping_distance=2.0,
min_distance=0.0,
plot=None,
):
"""
Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]
Expand Down Expand Up @@ -162,7 +180,12 @@ def cutout_mask(
assert j == len(ok)

# Invert the mask, so we have only the points outside the cutout
return ~mask
mask = ~mask

if plot:
plot_mask(plot, mask, lats, lons, global_lats, global_lons)

return mask


if __name__ == "__main__":
Expand All @@ -180,12 +203,12 @@ def cutout_mask(
lats = lats.flatten()
lons = lons.flatten()

mask = cutout_mask(global_lats, global_lons, lats, lons, cropping_distance=5.0)
mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0)

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 5))
plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
plt.scatter(lons, lats, s=0.01)
# plt.scatter(lons, lats, s=0.01)
plt.savefig("cutout.png")

0 comments on commit 2fb1fc0

Please sign in to comment.