Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Mar 18, 2024
1 parent 6994a10 commit 348f8e7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 127 deletions.
116 changes: 0 additions & 116 deletions ecml_tools/data/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .debug import Node
from .debug import Source
from .debug import debug_indexing
from .forewards import Combined
from .forewards import Forwards
from .indexing import apply_index_to_slices_changes
from .indexing import expand_list_indexing
Expand All @@ -24,121 +23,6 @@
LOG = logging.getLogger(__name__)


class Join(Combined):
    """Join the datasets along the variables axis.

    The variables axis is axis 1: the joined dataset has as many entries as
    its first dataset, and its second dimension is the sum of the second
    dimensions of all the joined datasets.
    """

    def check_compatibility(self, d1, d2):
        """Check that ``d1`` and ``d2`` can be joined.

        Runs the parent class checks, then additionally requires the two
        datasets to have the same sub-shapes once axis 1 is dropped.
        """
        super().check_compatibility(d1, d2)
        self.check_same_sub_shapes(d1, d2, drop_axis=1)

    def check_same_variables(self, d1, d2):
        """Do nothing: datasets joined along the variables axis may differ in variables."""
        # Turned off because we are joining along the variables axis

    def __len__(self):
        """Return the length of the first dataset."""
        return len(self.datasets[0])

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, index):
        """Return the data selected by a tuple index."""
        index, changes = index_to_slices(index, self.shape)
        # Ask every dataset for all of its variables (axis 1); the selection
        # originally requested on that axis is applied after concatenation.
        index, previous = update_tuple(index, 1, slice(None))

        # TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
        parts = [d[index] for d in self.datasets]

        joined = np.concatenate(parts, axis=1)
        return apply_index_to_slices_changes(joined[:, previous], changes)

    @debug_indexing
    def _get_slice(self, s):
        """Return the entries selected by the slice ``s``, stacked in a new array."""
        # NOTE(review): `_len` is not defined in this class; presumably it is
        # provided by a parent class -- confirm.
        return np.stack([self[i] for i in range(*s.indices(self._len))])

    @debug_indexing
    def __getitem__(self, n):
        """Return the item(s) at ``n``, which may be a tuple, a slice or a single index."""
        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self._get_slice(n)

        # Single entry: concatenate what each dataset returns for that entry.
        return np.concatenate([d[n] for d in self.datasets])

    @cached_property
    def shape(self):
        """Shape of the joined dataset: ``(len(self), total variables) + trailing dimensions``."""
        cols = sum(d.shape[1] for d in self.datasets)
        return (len(self), cols) + self.datasets[0].shape[2:]

    def _overlay(self):
        """Return a view in which repeated variable names are overridden by their last occurrence.

        Returns ``self`` unchanged when no variable name appears more than
        once. Otherwise returns a ``Select`` that keeps, for each name, only
        the position of its last occurrence, and logs a warning for every
        dataset that contributes no variable to the result.
        """
        # Map each variable name to the last position at which it appears.
        last_position = {}
        position = 0
        for d in self.datasets:
            for v in d.variables:
                last_position[v] = position
                position += 1

        if len(last_position) == position:
            # No overlay
            return self

        indices = list(last_position.values())
        # Set used only for fast membership tests; `indices` keeps its order
        # because it is what is handed over to Select.
        kept = set(indices)

        position = 0
        for d in self.datasets:
            ok = False
            for _ in d.variables:
                if position in kept:
                    ok = True
                position += 1
            if not ok:
                LOG.warning("Dataset %r completely overridden.", d)

        # Imported locally, as in the original code (presumably to avoid a
        # circular import -- confirm).
        from .select import Select

        return Select(self, indices, {"overlay": True})

    @cached_property
    def variables(self):
        """Names of the variables of all datasets, in order, made unique.

        When a name occurs more than once, its last occurrence keeps the
        plain name and each earlier occurrence is wrapped in one more pair of
        parentheses until it no longer clashes.
        """
        seen = set()
        reversed_names = []
        # Walk backwards so that later occurrences claim the plain name first.
        for d in reversed(self.datasets):
            for v in reversed(d.variables):
                while v in seen:
                    v = f"({v})"
                seen.add(v)
                reversed_names.append(v)

        # Append then reverse once instead of inserting at the front each time.
        reversed_names.reverse()
        return reversed_names

    @cached_property
    def name_to_index(self):
        """Mapping from each name in ``variables`` to its position."""
        return {k: i for i, k in enumerate(self.variables)}

    @property
    def statistics(self):
        """Statistics of all datasets, concatenated along axis 0 for each key.

        The keys are those of the first dataset's statistics.
        """
        return {
            k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0)
            for k in self.datasets[0].statistics
        }

    def source(self, index):
        """Return a ``Source`` describing which dataset provides variable ``index``.

        Raises:
            AssertionError: if ``index`` is not smaller than the total number
                of variables of the joined datasets.
        """
        i = index
        for dataset in self.datasets:
            if i < dataset.shape[1]:
                return Source(self, index, dataset.source(i))
            i -= dataset.shape[1]
        # Raised explicitly rather than via `assert False`, which is stripped
        # under `python -O` and would make this method silently return None.
        raise AssertionError(f"Variable index {index} is out of range")

    @cached_property
    def missing(self):
        """Union of the ``missing`` sets of all datasets."""
        result = set()
        for d in self.datasets:
            result = result | d.missing
        return result

    def tree(self):
        """Return a ``Node`` for this dataset whose children are the trees of the joined datasets."""
        return Node(self, [d.tree() for d in self.datasets])


class Subset(Forwards):
"""Select a subset of the dates."""

Expand Down
13 changes: 2 additions & 11 deletions ecml_tools/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#

import numpy as np
from scipy.spatial import KDTree


def plot_mask(path, mask, lats, lons, global_lats, global_lons):
Expand Down Expand Up @@ -128,6 +127,7 @@ def cutout_mask(
"""
Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]
"""
from scipy.spatial import KDTree

# TODO: transform min_distance from lat/lon to xyz

Expand Down Expand Up @@ -184,16 +184,6 @@ def cutout_mask(
intersect = t.intersect(zero, global_point) or t.intersect(global_point, zero)
close = np.min(distance) <= min_distance

if not intersect and False:

if 0 <= global_lons_masked[i] <= 30:
if 55 <= global_lats_masked[i] <= 70:
print(global_lats_masked[i], global_lons_masked[i], distance, intersect, close)
print(lats[index[0]], lons[index[0]])
print(lats[index[1]], lons[index[1]])
print(lats[index[2]], lons[index[2]])
assert False

ok.append(intersect and not close)

j = 0
Expand Down Expand Up @@ -226,6 +216,7 @@ def thinning_mask(
"""
Return the list of points in [lats, lons] closest to [global_lats, global_lons]
"""
from scipy.spatial import KDTree

assert global_lats.ndim == 1
assert global_lons.ndim == 1
Expand Down

0 comments on commit 348f8e7

Please sign in to comment.