diff --git a/regional_mom6/regional_mom6.py b/regional_mom6/regional_mom6.py index a15b2e51..54572741 100644 --- a/regional_mom6/regional_mom6.py +++ b/regional_mom6/regional_mom6.py @@ -26,7 +26,7 @@ from pathlib import Path import glob from collections import defaultdict - +import json warnings.filterwarnings("ignore") __all__ = [ @@ -35,6 +35,7 @@ "calculate_rectangular_hgrid", "experiment", "segment", + "load_experiment", ] @@ -104,6 +105,87 @@ def find_MOM6_rectangular_orientation(input): else: raise ValueError("Invalid type of Input, can only be string or int.") +## Load Expirement Function + +def load_experiment(config_file_path): + print("Reading from config file....") + with open(config_file_path, "r") as f: + config_dict = json.load(f) + + print("Creating Empty Experiment Object....") + expt = experiment.create_empty() + + print("Setting Default Variables.....") + expt.name = config_dict["name"] + expt.longitude_extent = tuple(config_dict["longitude_extent"]) + expt.latitude_extent = tuple(config_dict["latitude_extent"]) + expt.date_range = (config_dict["date_range"]) + expt.date_range[0] = dt.datetime.strptime(expt.date_range[0], "%Y-%m-%d") + expt.date_range[1] = dt.datetime.strptime(expt.date_range[1], "%Y-%m-%d") + expt.mom_run_dir = Path(config_dict["run_dir"]) + expt.mom_input_dir = Path(config_dict["input_dir"]) + expt.toolpath_dir = Path(config_dict["toolpath_dir"]) + expt.resolution = config_dict["resolution"] + expt.number_vertical_layers = config_dict["number_vertical_layers"] + expt.layer_thickness_ratio = config_dict["layer_thickness_ratio"] + expt.depth = config_dict["depth"] + expt.grid_type = config_dict["grid_type"] + expt.repeat_year_forcing = config_dict["repeat_year_forcing"] + expt.ocean_mask = None + expt.layout = None + expt.min_depth = config_dict["min_depth"] + expt.tidal_constituents = config_dict["tidal_constituents"] + + print("Checking for hgrid and vgrid....") + if os.path.exists(config_dict["hgrid"]): + print("Found") + expt.hgrid = xr.open_dataset(config_dict["hgrid"]) + else: + print("Hgrid not found, creating hgrid") + expt.hgrid = expt._make_hgrid() + if os.path.exists(config_dict["vgrid"]): + print("Found") + expt.vgrid = xr.open_dataset(config_dict["vgrid"]) + else: + print("Vgrid not found, creating vgrid") + expt.vgrid = expt._make_vgrid() + + print("Checking for bathymetry...") + if config_dict["bathymetry"] is not None and os.path.exists(config_dict["bathymetry"]): + print("Found") + expt.bathymetry = xr.open_dataset(config_dict["bathymetry"]) + else: + print("Bathymetry not found. Please provide bathymetry, or call setup_bathymetry method to set up bathymetry.") + + print("Checking for ocean state files....") + found = True + for path in config_dict["ocean_state"]: + if not os.path.exists(path): + foud = False + print("At least one ocean state file not found. Please provide ocean state files, or call setup_ocean_state_boundaries method to set up ocean state.") + break + if found: + print("Found") + found = True + print("Checking for initial condition files....") + for path in config_dict["initial_conditions"]: + if not os.path.exists(path): + print("At least one initial condition file not found. Please provide initial condition files, or call setup_initial_condition method to set up initial condition.") + break + if found: + print("Found") + found = True + print("Checking for tides files....") + for path in config_dict["tides"]: + if not os.path.exists(path): + print("At least one tides file not found. If you would like tides, call setup_tides_boundaries method to set up tides") + break + if found: + print("Found") + found = True + + return expt + ## Auxiliary functions @@ -525,6 +607,26 @@ class experiment: minimum_depth (Optional[int]): The minimum depth in meters of a grid cell allowed before it is masked out and treated as land. """ + @classmethod + def create_empty(self): + + + expt = self( + longitude_extent=None, + latitude_extent=None, + date_range=None, + resolution=None, + number_vertical_layers=None, + layer_thickness_ratio=None, + depth=None, + minimum_depth=None, + mom_run_dir=None, + mom_input_dir=None, + toolpath_dir=None, + create_empty = True + ) + return expt + def __init__( self, *, @@ -543,8 +645,16 @@ def __init__( read_existing_grids=False, minimum_depth=4, tidal_constituents=["M2"], + create_empty = False, + name = None ): + if create_empty: + return + + # ## Set up the experiment with no config file ## in case list was given, convert to tuples + if name is not None: + self.name = name self.longitude_extent = tuple(longitude_extent) self.latitude_extent = tuple(latitude_extent) self.date_range = tuple(date_range) @@ -572,6 +682,7 @@ def __init__( minimum_depth # Minimum depth. Shallower water will be masked out. ) self.tidal_constituents = tidal_constituents + if read_existing_grids: try: self.hgrid = xr.open_dataset(self.mom_input_dir / "hgrid.nc") @@ -601,6 +712,9 @@ def __init__( if not input_rundir.exists(): input_rundir.symlink_to(self.mom_run_dir.resolve()) + def __str__(self) -> str: + return json.dumps(self.write_config_file(export = False, quiet = True), indent=4) + def __getattr__(self, name): available_methods = [ method for method in dir(self) if not method.startswith("__") @@ -714,10 +828,11 @@ def ocean_state_boundaries(self): ocean_state_path = self.mom_input_dir / "forcing" try: # Use glob to find all tides files - patterns = ["forcing_*", "weights/bi*"] + patterns = ["forcing_*", "weights/bi*",] all_files = [] for pattern in patterns: all_files.extend(glob.glob(os.path.join(ocean_state_path, pattern))) + all_files.extend(glob.glob(os.path.join(self.mom_input_dir, pattern))) if len(all_files) == 0: return "No ocean state files set up yet (or files misplaced from {}). Call `setup_ocean_state_boundaries` method to set up ocean state.".format( @@ -725,8 +840,8 @@ def ocean_state_boundaries(self): ) # Open the files as xarray datasets - datasets = [xr.open_dataset(file) for file in all_files] - return datasets + # datasets = [xr.open_dataset(file) for file in all_files] + return all_files except: return "Error retrieving ocean state files" @@ -742,6 +857,7 @@ def tides_boundaries(self): all_files = [] for pattern in patterns: all_files.extend(glob.glob(os.path.join(tides_path, pattern))) + all_files.extend(glob.glob(os.path.join(self.mom_input_dir, pattern))) if len(all_files) == 0: return "No tides files set up yet (or files misplaced from {}). Call `setup_tides_boundaries` method to set up tides.".format( @@ -749,8 +865,8 @@ def tides_boundaries(self): ) # Open the files as xarray datasets - datasets = [xr.open_dataset(file) for file in all_files] - return datasets + # datasets = [xr.open_dataset(file) for file in all_files] + return all_files except: return "Error retrieving tides files" @@ -769,8 +885,8 @@ def era5(self): ) # Open the files as xarray datasets - datasets = [xr.open_dataset(file) for file in all_files] - return datasets + # datasets = [xr.open_dataset(file) for file in all_files] + return all_files except: return "Error retrieving ERA5 files" @@ -779,12 +895,20 @@ def initial_condition(self): """ Read the ic's from disk, and print 'em """ - + forcing_path = self.mom_input_dir / "forcing" try: - ic_tracers = xr.open_dataset(self.mom_input_dir / "forcing/init_tracers.nc") - ic_vel = xr.open_dataset(self.mom_input_dir / "forcing/init_vel.nc") - ic_eta = xr.open_dataset(self.mom_input_dir / "forcing/init_eta.nc") - return [ic_tracers, ic_vel, ic_eta] + all_files = glob.glob(os.path.join(forcing_path, "init_*.nc")) + all_files = glob.glob(os.path.join(self.mom_input_dir , "init_*.nc")) + if len(all_files) == 0: + return "No initial conditions files set up yet (or files misplaced from {}). Call `setup_initial_condition` method to set up initial conditions.".format( + forcing_path + ) + + # Open the files as xarray datasets + # datasets = [xr.open_dataset(file) for file in all_files] + # return datasets + + return all_files except: return "No initial condition set up yet (or files misplaced from {}). Call `setup_initial_condition` method to set up initial conditions.".format( self.mom_input_dir / "forcing" @@ -798,12 +922,66 @@ def bathymetry_property(self): try: bathy = xr.open_dataset(self.mom_input_dir / "bathymetry.nc") - return [bathy] + #return [bathy] + return str(self.mom_input_dir / "bathymetry.nc") except: return "No bathymetry set up yet (or files misplaced from {}). Call `setup_bathymetry` method to set up bathymetry.".format( self.mom_input_dir ) + def write_config_file(self, export=True, quiet = False): + """ + Write a configuration file for the experiment. This is a simple json file + that contains the expirment object information to allow for reproducibility, to pick up where a user left off, and + to make information about the expirement readable. + """ + if not quiet: + print("Writing Config File.....") + ## check if files exist + vgrid_path = None + hgrid_path = None + if os.path.exists(self.mom_input_dir/"vcoord.nc"): + vgrid_path = self.mom_input_dir/"vcoord.nc" + if os.path.exists(self.mom_input_dir/"hgrid.nc"): + hgrid_path = self.mom_input_dir/"hgrid.nc" + config_dict = { + "name": self.name, + "date_range": [ + self.date_range[0].strftime("%Y-%m-%d"), + self.date_range[1].strftime("%Y-%m-%d"), + ], + "latitude_extent": self.latitude_extent, + "longitude_extent": self.longitude_extent, + "run_dir": str(self.mom_run_dir), + "input_dir": str(self.mom_input_dir), + 'toolpath_dir': str(self.toolpath_dir), + "resolution": self.resolution, + "number_vertical_layers": self.number_vertical_layers, + "layer_thickness_ratio": self.layer_thickness_ratio, + "depth": self.depth, + "grid_type": self.grid_type, + "repeat_year_forcing": self.repeat_year_forcing, + "ocean_mask": self.ocean_mask, + "layout": self.layout, + "min_depth": self.min_depth, + "vgrid": str(vgrid_path), + "hgrid": str(hgrid_path), + "bathymetry": self.bathymetry_property, + "ocean_state": self.ocean_state_boundaries, + "tides": self.tides_boundaries, + "initial_conditions": self.initial_condition, + "tidal_constituents": self.tidal_constituents + } + if export: + with open(self.mom_run_dir/"config.json", "w") as f: + json.dump( + config_dict, + f, + indent=4, + ) + if not quiet: + print("Done.") + return config_dict def setup_initial_condition( self, raw_ic_path,