Skip to content

Commit

Permalink
WIP put read method in context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
jgriesfeller committed Oct 30, 2024
1 parent b14319a commit 4ab15d1
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 11 deletions.
20 changes: 16 additions & 4 deletions src/pyaro/csvreader/CSVTimeseriesReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,24 @@ def __init__(
self._set_filters(filters)
self._extra_metadata = tuple(set(columns.keys()) - set(self.col_keys()))
if country_lookup:
lookupISO2 = _lookup_function()
self._lookupISO2 = _lookup_function()
else:
lookupISO2 = None
self._lookupISO2 = None
self._filename = filename
self._columns = columns
self._variable_units = variable_units
self._csvreader_kwargs = csvreader_kwargs

def read(self):
    """Read every input file produced by the file iterator.

    Iterates self._file_iterator (set up by __init__) and delegates the
    actual parsing of each file to _read_single_file, passing along the
    column mapping, unit mapping, country-lookup function and csv-reader
    kwargs stored on self.
    """
    for path in self._file_iterator:
        # Log which source file and which path is being read.
        logger.debug("%s: %s", self._filename, path)
        self._read_single_file(
            path,
            self._columns,
            self._variable_units,
            self._lookupISO2,
            self._csvreader_kwargs,
        )


def _read_single_file(
self, filename, columns, variable_units, country_lookup, csvreader_kwargs
):
Expand Down Expand Up @@ -208,3 +217,6 @@ def description(self):

def url(self):
    """Return the homepage URL of the pyaro project."""
    homepage = "https://github.com/metno/pyaro"
    return homepage

def read(self):
    # NOTE(review): `args` and `kwargs` are not defined anywhere in this
    # method's scope, so calling this raises NameError at runtime.
    # Presumably the constructor arguments stored by the engine (or none at
    # all) were intended here — confirm against the Engine's open()/__init__
    # before relying on this method.
    return self.reader_class().read(*args, **kwargs)
5 changes: 3 additions & 2 deletions src/pyaro/timeseries/Engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def open(self, filename_or_obj_or_url, *, filters=None):
pass

@abc.abstractmethod
def read(self):
    """read-method of the timeseries

    Abstract: concrete engines must implement this.

    :return pyaro.timeseries.Reader
    :raises UnknownFilterException
    """
    pass

@property
@abc.abstractmethod
Expand Down
20 changes: 18 additions & 2 deletions src/pyaro/timeseries/Reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ def __init__(self, filename_or_obj_or_url, filters=None, **kwargs):
pass

@abc.abstractmethod
def read(self):
    """define read method. All needed parameters should be put into self
    by the __init__ method.
    This function is usually called after the Engine's open function.

    Abstract: concrete readers must implement this.
    """
    pass

@abc.abstractmethod
def metadata(self) -> dict[str, str]:
Expand Down Expand Up @@ -79,3 +80,18 @@ def close(self) -> None:
Implement as dummy (pass) if no cleanup needed.
"""
pass

def __enter__(self):
    """Context manager entry.

    :return: this reader itself, so it can be bound by ``with ... as``
    """
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Context manager exit.

    The default implementation calls the close function. Returns None
    (falsy), so any exception raised inside the ``with`` block is
    propagated, not suppressed.

    :param exc_type: exception class, or None if no exception occurred
    :param exc_value: exception instance, or None
    :param traceback: traceback object, or None
    """
    self.close()
5 changes: 2 additions & 3 deletions src/pyaro/timeseries/Wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,13 @@ def variables(self):
def close(self):
self._reader.close()

def read(self):
    """define read method. All needed parameters should be put into self
    by the __init__ method.
    This method is called after the Engine's open function.

    :return: whatever the wrapped reader's read() returns
    """
    return self._reader.read()


34 changes: 34 additions & 0 deletions tests/test_CSVTimeSeriesReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_init(self):
engine.description()
engine.args()
with engine.open(self.file, filters=[]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -70,6 +71,7 @@ def test_init_multifile(self):
engine.description()
engine.args()
with engine.open(self.multifile, filters=[]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -83,6 +85,7 @@ def test_init_directory(self):
engine.description()
engine.args()
with engine.open(self.multifile_dir, filters=[]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -93,6 +96,7 @@ def test_init2(self):
with pyaro.open_timeseries(
"csv_timeseries", *[self.file], **{"filters": []}
) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -118,6 +122,7 @@ def test_init_extra_columns(self):
with pyaro.open_timeseries(
"csv_timeseries", *[self.file], **{"filters": [], "columns": columns}
) as ts:
ts.read()
areas = ["Rural", "Urban"]
stations = ts.stations()
self.assertEqual(stations["station1"]["area_classification"], areas[0])
Expand All @@ -127,6 +132,7 @@ def test_metadata(self):
with pyaro.open_timeseries(
"csv_timeseries", *[self.file], **{"filters": []}
) as ts:
ts.read()
self.assertIsInstance(ts.metadata(), dict)
self.assertIn("path", ts.metadata())

Expand All @@ -136,6 +142,7 @@ def test_data(self):
filename=self.file,
filters=[pyaro.timeseries.filters.get("countries", include=["NO"])],
) as ts:
ts.read()
for var in ts.variables():
# stations
ts.data(var).stations
Expand All @@ -161,6 +168,7 @@ def test_append_data(self):
filename=self.file,
filters={"countries": {"include": ["NO"]}},
) as ts:
ts.read()
var = next(iter(ts.variables()))
data = ts.data(var)
old_size = len(data)
Expand All @@ -185,6 +193,7 @@ def test_stationfilter(self):
engine = pyaro.list_timeseries_engines()["csv_timeseries"]
sfilter = pyaro.timeseries.filters.get("stations", exclude=["station1"])
with engine.open(self.file, filters=[sfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -202,6 +211,7 @@ def test_boundingboxfilter(self):
)
self.assertEqual(sfilter.init_kwargs()["include"][0][3], 0)
with engine.open(self.file, filters=[sfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -212,6 +222,7 @@ def test_boundingboxfilter(self):
)
self.assertEqual(sfilter.init_kwargs()["exclude"][0][3], -180)
with engine.open(self.file, filters=[sfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand Down Expand Up @@ -239,6 +250,7 @@ def test_timebounds(self):
self.assertIsInstance(dt1, datetime.datetime)
self.assertIsInstance(dt2, datetime.datetime)
with engine.open(self.file, filters=[tfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -258,6 +270,7 @@ def test_flagfilter(self):
ffilter.init_kwargs()["include"][0], pyaro.timeseries.Flag.VALID
)
with engine.open(self.file, filters=[ffilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -268,6 +281,7 @@ def test_flagfilter(self):
"flags", include=[pyaro.timeseries.Flag.INVALID]
)
with engine.open(self.file, filters=[ffilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -288,6 +302,7 @@ def test_variable_time_station_filter(self):
)
engine = pyaro.list_timeseries_engines()["csv_timeseries"]
with engine.open(self.file, filters=[vtsfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -311,6 +326,7 @@ def test_variable_time_station_filter_csv(self):
)
engine = pyaro.list_timeseries_engines()["csv_timeseries"]
with engine.open(self.file, filters=[vtsfilter]) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -323,6 +339,7 @@ def test_wrappers(self):
with VariableNameChangingReader(
engine.open(self.file, filters=[]), {"SOx": newsox}
) as ts:
ts.read()
self.assertEqual(ts.data(newsox).variable, newsox)
pass

Expand All @@ -333,6 +350,7 @@ def test_variables_filter(self):
"variables", reader_to_new={"SOx": newsox}
)
with engine.open(self.file, filters=[vfilter]) as ts:
ts.read()
self.assertEqual(ts.data(newsox).variable, newsox)
pass

Expand All @@ -342,13 +360,15 @@ def test_duplicate_filter(self):
self.multifile_dir + "/csvReader_testdata2.csv",
filters={"duplicates": {"duplicate_keys": None}},
) as ts:
ts.read()
self.assertEqual(len(ts.data("NOx")), 8)
with engine.open(
self.multifile_dir + "/csvReader_testdata2.csv",
filters={
"duplicates": {"duplicate_keys": ["stations", "start_times", "values"]}
},
) as ts:
ts.read()
self.assertEqual(len(ts.data("NOx")), 10)

def test_time_resolution_filter(self):
Expand All @@ -363,6 +383,7 @@ def test_time_resolution_filter(self):
self.file,
filters={"time_resolution": {"resolutions": ["1 day"]}},
) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -372,6 +393,7 @@ def test_time_resolution_filter(self):
self.file,
filters={"time_resolution": {"resolutions": ["1 hour"]}},
) as ts:
ts.read()
count = 0
for var in ts.variables():
count += len(ts.data(var))
Expand All @@ -387,6 +409,7 @@ def test_filterCollection(self):
"csv_timeseries",
filename=self.file,
) as ts:
ts.read()
filters = pyaro.timeseries.FilterCollection(
{
"countries": {"include": ["NO"]},
Expand All @@ -402,6 +425,7 @@ def test_timeseries_data_to_pd(self):
with pyaro.open_timeseries(
"csv_timeseries", *[self.file], **{"filters": []}
) as ts:
ts.read()
count = 0
vars = list(ts.variables())
data = ts.data(vars[0])
Expand All @@ -415,6 +439,7 @@ def test_country_lookup(self):
with pyaro.open_timeseries(
"csv_timeseries", *[self.file], **{"filters": [], "country_lookup": True}
) as ts:
ts.read()
count = 0
vars = list(ts.variables())
data = ts.data(vars[0])
Expand All @@ -440,6 +465,7 @@ def test_altitude_filter_1(self):
"flag": "0",
}
) as ts:
ts.read()
self.assertEqual(len(ts.stations()), 1)

def test_altitude_filter_2(self):
Expand All @@ -462,6 +488,7 @@ def test_altitude_filter_2(self):
"flag": "0",
}
) as ts:
ts.read()
self.assertEqual(len(ts.stations()), 1)

def test_altitude_filter_3(self):
Expand All @@ -484,6 +511,7 @@ def test_altitude_filter_3(self):
"flag": "0",
}
) as ts:
ts.read()
self.assertEqual(len(ts.stations()), 1)

def test_relaltitude_filter_emep_1(self):
Expand All @@ -506,6 +534,7 @@ def test_relaltitude_filter_emep_1(self):
"flag": "0",
}
) as ts:
ts.read()
# Altitudes in test dataset:
# Station | Alt_obs | Modeobs | rdiff |
# Station 1 | 100 | 12.2554 | 87.7446 |
Expand Down Expand Up @@ -534,6 +563,7 @@ def test_relaltitude_filter_emep_2(self):
"flag": "0",
}
) as ts:
ts.read()
# At rdiff = 90, only the first station should be included.
self.assertEqual(len(ts.stations()), 1)

Expand All @@ -557,6 +587,7 @@ def test_relaltitude_filter_emep_3(self):
"flag": "0",
}
) as ts:
ts.read()
# Since rdiff=300, all stations should be included.
self.assertEqual(len(ts.stations()), 3)

Expand All @@ -580,6 +611,7 @@ def test_relaltitude_filter_1(self):
"flag": "0",
}
) as ts:
ts.read()
self.assertEqual(len(ts.stations()), 0)

def test_relaltitude_filter_2(self):
Expand All @@ -602,6 +634,7 @@ def test_relaltitude_filter_2(self):
"flag": "0",
}
) as ts:
ts.read()
# At rdiff = 90, only the first station should be included.
self.assertEqual(len(ts.stations()), 1)

Expand All @@ -625,6 +658,7 @@ def test_relaltitude_filter_3(self):
"flag": "0",
}
) as ts:
ts.read()
# Since rdiff=300, all stations should be included.
self.assertEqual(len(ts.stations()), 3)

Expand Down

0 comments on commit 4ab15d1

Please sign in to comment.