diff --git a/deploy/default.cfg b/deploy/default.cfg index 697f21984..38bc13a07 100644 --- a/deploy/default.cfg +++ b/deploy/default.cfg @@ -24,7 +24,7 @@ geometric_features = 1.2.0 jigsaw = 0.9.14 jigsawpy = 0.3.3 mache = 1.16.0 -mpas_tools = 0.25.0 +mpas_tools = 0.27.0 otps = 2021.10 parallelio = 2.6.0 diff --git a/docs/developers_guide/ocean/api.md b/docs/developers_guide/ocean/api.md index bf31cdcca..23ce0fa8c 100644 --- a/docs/developers_guide/ocean/api.md +++ b/docs/developers_guide/ocean/api.md @@ -314,3 +314,22 @@ vertical.zlevel.compute_z_level_resting_thickness vertical.zstar.init_z_star_vertical_coord ``` + +### Visualization + +```{eval-rst} +.. currentmodule:: polaris.ocean.viz + +.. autosummary:: + :toctree: generated/ + + compute_transect + plot_transect + transect.horiz.find_spherical_transect_cells_and_weights + transect.horiz.find_planar_transect_cells_and_weights + transect.horiz.make_triangle_tree + transect.horiz.mesh_to_triangles + transect.vert.find_transect_levels_and_weights + transect.vert.interp_mpas_to_transect_cells + transect.vert.interp_mpas_to_transect_nodes + diff --git a/docs/developers_guide/ocean/framework.md b/docs/developers_guide/ocean/framework.md index dfcea88ca..736c03c61 100644 --- a/docs/developers_guide/ocean/framework.md +++ b/docs/developers_guide/ocean/framework.md @@ -422,3 +422,15 @@ density, which is horizontally constant and increases with depth. The {py:func}`polaris.ocean.rpe.compute_rpe()` is used to compute the RPE as a function of time in a series of one or more output files. The RPE is stored in `rpe.csv` and also returned as a numpy array for plotting and analysis. + +## Visualization + +The `polaris.ocean.viz` module provides functions for making plots that are +specific to the ocean component. + +The `polaris.ocean.viz.transect` modules includes functions for computing +({py:func}`polaris.ocean.viz.compute_transect()`) and plotting +({py:func}`polaris.ocean.viz.plot_transect()`) transects through the ocean +from a sequence of x-y or latitude-longitude coordinates. Currently, only +transects on xarray data arrays with dimensions `nCells` by `nVertLevels` are +supported. diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 983efb7ec..5093857ae 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -8,6 +8,7 @@ from polaris import Step from polaris.mesh.planar import compute_planar_hex_nx_ny from polaris.ocean.vertical import init_vertical_coord +from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -162,8 +163,34 @@ def run(self): write_netcdf(ds, 'initial_state.nc') cell_mask = ds.maxLevelCell >= 1 - plot_horiz_field(ds, ds_mesh, 'temperature', - 'initial_temperature.png', cell_mask=cell_mask) + plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'initial_normal_velocity.png', cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) + + y_min = ds_mesh.yVertex.min().values + y_max = ds_mesh.yVertex.max().values + x_mid = ds_mesh.xCell.median().values + + y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) + x = x_mid * xr.ones_like(y) + + ds_transect = compute_transect( + x=x, y=y, ds_horiz_mesh=ds_mesh, + layer_thickness=ds.layerThickness.isel(Time=0), + bottom_depth=ds.bottomDepth, min_level_cell=ds.minLevelCell - 1, + max_level_cell=ds.maxLevelCell - 1, spherical=False) + + field_name = 'temperature' + vmin = ds[field_name].min().values + vmax = ds[field_name].max().values + plot_transect(ds_transect=ds_transect, + mpas_field=ds[field_name].isel(Time=0), + title=f'{field_name} at x={1e-3 * x_mid:.1f} km', + out_filename=f'initial_{field_name}_section.png', + vmin=vmin, vmax=vmax, cmap='cmo.thermal', + colorbar_label=r'$^\circ$C', color_start_and_end=True) + + plot_horiz_field(ds, ds_mesh, 'temperature', 'initial_temperature.png', + vmin=vmin, vmax=vmax, cmap='cmo.thermal', + cell_mask=cell_mask, transect_x=x, transect_y=y) diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index aa2416c92..617253e71 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -3,6 +3,7 @@ import xarray as xr from polaris import Step +from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -42,9 +43,6 @@ def run(self): ds = xr.load_dataset('output.nc') t_index = ds.sizes['Time'] - 1 cell_mask = ds_init.maxLevelCell >= 1 - plot_horiz_field(ds, ds_mesh, 'temperature', - 'final_temperature.png', t_index=t_index, - cell_mask=cell_mask) max_velocity = np.max(np.abs(ds.normalVelocity.values)) plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'final_normalVelocity.png', @@ -52,3 +50,33 @@ def run(self): vmin=-max_velocity, vmax=max_velocity, cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) + + y_min = ds_mesh.yVertex.min().values + y_max = ds_mesh.yVertex.max().values + x_mid = ds_mesh.xCell.median().values + + y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) + x = x_mid * xr.ones_like(y) + + ds_transect = compute_transect( + x=x, y=y, ds_horiz_mesh=ds_mesh, + layer_thickness=ds.layerThickness.isel(Time=t_index), + bottom_depth=ds_init.bottomDepth, + min_level_cell=ds_init.minLevelCell - 1, + max_level_cell=ds_init.maxLevelCell - 1, + spherical=False) + + field_name = 'temperature' + vmin = ds[field_name].min().values + vmax = ds[field_name].max().values + mpas_field = ds[field_name].isel(Time=t_index) + plot_transect(ds_transect=ds_transect, mpas_field=mpas_field, + title=f'{field_name} at x={1e-3 * x_mid:.1f} km', + out_filename=f'final_{field_name}_section.png', + vmin=vmin, vmax=vmax, cmap='cmo.thermal', + colorbar_label=r'$^\circ$C', color_start_and_end=True) + + plot_horiz_field(ds, ds_mesh, 'temperature', 'final_temperature.png', + t_index=t_index, vmin=vmin, vmax=vmax, + cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x, + transect_y=y) diff --git a/polaris/ocean/viz/__init__.py b/polaris/ocean/viz/__init__.py new file mode 100644 index 000000000..0efd37287 --- /dev/null +++ b/polaris/ocean/viz/__init__.py @@ -0,0 +1,2 @@ +from polaris.ocean.viz.transect.plot import plot_transect +from polaris.ocean.viz.transect.vert import compute_transect diff --git a/polaris/ocean/viz/transect/__init__.py b/polaris/ocean/viz/transect/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/polaris/ocean/viz/transect/horiz.py b/polaris/ocean/viz/transect/horiz.py new file mode 100644 index 000000000..6e16e9198 --- /dev/null +++ b/polaris/ocean/viz/transect/horiz.py @@ -0,0 +1,889 @@ +import numpy as np +import xarray as xr +from mpas_tools.transects import ( + cartesian_to_lon_lat, + lon_lat_to_cartesian, + subdivide_great_circle, + subdivide_planar, +) +from mpas_tools.vector import Vector +from scipy.spatial import cKDTree +from shapely.geometry import LineString, Point + + +def mesh_to_triangles(ds_mesh): + """ + Construct a dataset in which each MPAS cell is divided into the triangles + connecting pairs of adjacent vertices to cell centers. + + Parameters + ---------- + ds_mesh : xarray.Dataset + An MPAS mesh + + Returns + ------- + ds_tris : xarray.Dataset + A dataset that defines triangles connecting pairs of adjacent vertices + to cell centers as well as the cell index that each triangle is in and + cell indices and weights for interpolating data defined at cell centers + to triangle nodes. ``ds_tris`` includes variables ``triCellIndices``, + the cell that each triangle is part of; ``nodeCellIndices`` and + ``nodeCellWeights``, the indices and weights used to interpolate from + MPAS cell centers to triangle nodes; Cartesian coordinates ``xNode``, + ``yNode``, and ``zNode``; and ``lonNode``` and ``latNode`` in radians. + ``lonNode`` is guaranteed to be within 180 degrees of the cell center + corresponding to ``triCellIndices``. Nodes always have a + counterclockwise winding. + + """ + n_vertices_on_cell = ds_mesh.nEdgesOnCell.values + vertices_on_cell = ds_mesh.verticesOnCell.values - 1 + cells_on_vertex = ds_mesh.cellsOnVertex.values - 1 + + on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES' + is_periodic = False + x_period = None + y_period = None + if not on_a_sphere: + is_periodic = ds_mesh.attrs['is_periodic'].strip() == 'YES' + if is_periodic: + x_period = ds_mesh.attrs['x_period'] + y_period = ds_mesh.attrs['y_period'] + + kite_areas_on_vertex = ds_mesh.kiteAreasOnVertex.values + + n_triangles = np.sum(n_vertices_on_cell) + + max_edges = ds_mesh.sizes['maxEdges'] + n_cells = ds_mesh.sizes['nCells'] + if ds_mesh.sizes['vertexDegree'] != 3: + raise ValueError('mesh_to_triangles only supports meshes with ' + 'vertexDegree = 3') + + # find the third vertex for each triangle + next_vertex = -1 * np.ones(vertices_on_cell.shape, int) + for i_vertex in range(max_edges): + valid = i_vertex < n_vertices_on_cell + invalid = np.logical_not(valid) + vertices_on_cell[invalid, i_vertex] = -1 + nv = n_vertices_on_cell[valid] + cell_indices = np.arange(0, n_cells)[valid] + i_next = np.where(i_vertex < nv - 1, i_vertex + 1, 0) + next_vertex[:, i_vertex][valid] = ( + vertices_on_cell[cell_indices, i_next]) + + valid = vertices_on_cell >= 0 + vertices_on_cell = vertices_on_cell[valid] + next_vertex = next_vertex[valid] + + # find the cell index for each triangle + tri_cell_indices, _ = np.meshgrid(np.arange(0, n_cells), + np.arange(0, max_edges), + indexing='ij') + tri_cell_indices = tri_cell_indices[valid] + + # find list of cells and weights for each triangle node + node_cell_indices = -1 * np.ones((n_triangles, 3, 3), dtype=int) + node_cell_weights = np.zeros((n_triangles, 3, 3)) + + # the first node is at the cell center, so the value is just the one from + # that cell + node_cell_indices[:, 0, 0] = tri_cell_indices + node_cell_weights[:, 0, 0] = 1. + + # the other 2 nodes are associated with vertices + node_cell_indices[:, 1, :] = cells_on_vertex[vertices_on_cell, :] + node_cell_weights[:, 1, :] = kite_areas_on_vertex[vertices_on_cell, :] + node_cell_indices[:, 2, :] = cells_on_vertex[next_vertex, :] + node_cell_weights[:, 2, :] = kite_areas_on_vertex[next_vertex, :] + + weight_sum = np.sum(node_cell_weights, axis=2) + for i_node in range(3): + node_cell_weights[:, :, i_node] = ( + node_cell_weights[:, :, i_node] / weight_sum) + + ds_tris = xr.Dataset() + ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices) + ds_tris['nodeCellIndices'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_indices) + ds_tris['nodeCellWeights'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_weights) + + # get Cartesian and lon/lat coordinates of each node + for prefix in ['x', 'y', 'z', 'lat', 'lon']: + out_var = f'{prefix}Node' + cell_var = f'{prefix}Cell' + vertex_var = f'{prefix}Vertex' + coord = np.zeros((n_triangles, 3)) + coord[:, 0] = ds_mesh[cell_var].values[tri_cell_indices] + coord[:, 1] = ds_mesh[vertex_var].values[vertices_on_cell] + coord[:, 2] = ds_mesh[vertex_var].values[next_vertex] + ds_tris[out_var] = (('nTriangles', 'nNodes'), coord) + + # nothing obvious we can do about triangles containing the poles + + if on_a_sphere: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='lonNode', + period=2 * np.pi) + elif is_periodic: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='xNode', + period=x_period) + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='yNode', + period=y_period) + + return ds_tris + + +def make_triangle_tree(ds_tris): + """ + Make a KD-Tree for finding triangle edges that are near enough to transect + segments that they might intersect + + Parameters + ---------- + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + Returns + ------- + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh + """ + + n_triangles = ds_tris.sizes['nTriangles'] + n_nodes = ds_tris.sizes['nNodes'] + node_coords = np.zeros((n_triangles * n_nodes, 3)) + node_coords[:, 0] = ds_tris.xNode.values.ravel() + node_coords[:, 1] = ds_tris.yNode.values.ravel() + node_coords[:, 2] = ds_tris.zNode.values.ravel() + + next_tri, next_node = np.meshgrid( + np.arange(n_triangles), np.mod(np.arange(n_nodes) + 1, 3), + indexing='ij') + nextIndices = n_nodes * next_tri.ravel() + next_node.ravel() + + # edge centers are half way between adjacent nodes (ignoring great-circle + # distance) + edgeCoords = 0.5 * (node_coords + node_coords[nextIndices, :]) + + tree = cKDTree(data=edgeCoords, copy_data=True) + return tree + + +def find_spherical_transect_cells_and_weights( + lon_transect, lat_transect, ds_tris, ds_mesh, tree, degrees=True, + earth_radius=None, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + lon_transect : xarray.DataArray + The longitude of segments making up the transect + + lat_transect : xarray.DataArray + The latitude of segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` + + degrees : bool, optional + Whether ``lon_transect`` and ``lat_transect`` are in degrees (as + opposed to radians). + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + earth_radius : float, optional + The radius of the Earth in meters, taken from the `sphere_radius` + global attribute if not provided + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also includes the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The Cartesian and lon/lat + coordinates of these nodes are ``xCartNode``, ``yCartNode``, + ``zCartNode``, ``lonNode`` and ``latNode``. The distance along the + transect of each intersection is ``dNode``. The index of the triangle + and the first triangle node in ``ds_tris`` associated with each + intersection node are given by ``horizTriangleIndices`` and + ``horizTriangleNodeIndices``, respectively. The second node on the + triangle for the edge associated with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell that a given node belongs to is given by + ``horizCellIndices``. Each node also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``lonTransect`` and ``latTransect`` are included in the + dataset, along with Cartesian coordinates ``xCartTransect``, + ``yCartTransect``, `zCartTransect``, and ``dTransect``, the + great-circle distance along the transect of each original transect + point. In order to interpolate values (e.g. observations) from the + original transect points to the intersection nodes, linear + interpolation indices ``transectIndicesOnHorizNode`` and weights + ``transectWeightsOnHorizNode`` are provided. The values at nodes are + found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + if earth_radius is None: + earth_radius = ds_mesh.attrs['sphere_radius'] + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + x, y, z = lon_lat_to_cartesian(lon_transect, lat_transect, earth_radius, + degrees) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + z_node = ds_tris.zNode.values.ravel() + + d_transect = np.zeros(lon_transect.shape) + + d_node = None + x_out = None + y_out = None + z_out = None + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + transectv0 = Vector(x[seg_index].values, + y[seg_index].values, + z[seg_index].values) + transectv1 = Vector(x[seg_index + 1].values, + y[seg_index + 1].values, + z[seg_index + 1].values) + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, z_sub, _, _ = subdivide_great_circle( + x[sub_slice].values, y[sub_slice].values, z[sub_slice].values, + subdivision_res, earth_radius) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + coords[:, 2] = z_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + n0_indices_cand = np.array(list(unique_indices)) + + if len(n0_indices_cand) == 0: + continue + + tris_cand = n0_indices_cand // n_nodes + next_node_index = np.mod(n0_indices_cand + 1, n_nodes) + n1_indices_cand = n_nodes * tris_cand + next_node_index + + n0_cand = Vector(x_node[n0_indices_cand], + y_node[n0_indices_cand], + z_node[n0_indices_cand]) + n1_cand = Vector(x_node[n1_indices_cand], + y_node[n1_indices_cand], + z_node[n1_indices_cand]) + + intersect = Vector.intersects(n0_cand, n1_cand, transectv0, + transectv1) + + n0_inter = Vector(n0_cand.x[intersect], + n0_cand.y[intersect], + n0_cand.z[intersect]) + n1_inter = Vector(n1_cand.x[intersect], + n1_cand.y[intersect], + n1_cand.z[intersect]) + + tris_inter = tris_cand[intersect] + n0_indices_inter = n0_indices_cand[intersect] + n1_indices_inter = n1_indices_cand[intersect] + + intersections = Vector.intersection(n0_inter, n1_inter, transectv0, + transectv1) + intersections = Vector(earth_radius * intersections.x, + earth_radius * intersections.y, + earth_radius * intersections.z) + + angular_distance = transectv0.angular_distance(intersections) + + d_node_local = d_start + earth_radius * angular_distance + + d_start += earth_radius * transectv0.angular_distance(transectv1) + + node0_inter = np.mod(n0_indices_inter, n_nodes) + node1_inter = np.mod(n1_indices_inter, n_nodes) + + node_weights = (intersections.angular_distance(n1_inter) / + n0_inter.angular_distance(n1_inter)) + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = intersections.x + y_out = intersections.y + z_out = intersections.z + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, intersections.x) + y_out = np.append(y_out, intersections.y) + z_out = np.append(z_out, intersections.z) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + lon_out, lat_out = cartesian_to_lon_lat(x_out, y_out, z_out, earth_radius, + degrees) + + valid_segs = seg_tris >= 0 + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + + ds_out = xr.Dataset() + ds_out['xCartNode'] = (('nNodes',), x_out) + ds_out['yCartNode'] = (('nNodes',), y_out) + ds_out['zCartNode'] = (('nNodes',), z_out) + ds_out['dNode'] = (('nNodes',), d_node) + ds_out['lonNode'] = (('nNodes',), lon_out) + ds_out['latNode'] = (('nNodes',), lat_out) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, dtype=int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['lonTransect'] = lon_transect + ds_out['latTransect'] = lat_transect + ds_out['xCartTransect'] = x + ds_out['yCartTransect'] = y + ds_out['zCartTransect'] = z + ds_out['dTransect'] = (lon_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def find_planar_transect_cells_and_weights( + x_transect, y_transect, ds_tris, ds_mesh, tree, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + x_transect : xarray.DataArray + The x points defining segments making up the transect + + y_transect : xarray.DataArray + The y points defining segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also include the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The planar coordinates of these + nodes are ``xNode`` and ``yNode``. The distance along the transect of + each intersection is ``dNode``. The index of the triangle and the first + triangle node in ``ds_tris`` associated with each intersection node are + given by ``horizTriangleIndices`` and ``horizTriangleNodeIndices``, + respectively. The second node on the triangle for the edge associated + with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell that a given node belongs to is given by + ``horizCellIndices``. Each node also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``xTransect`` and ``yTransect`` are included in the + dataset, along with ``dTransect``, the distance along the transect of + each original transect point. In order to interpolate values (e.g. + observations) from the original transect points to the intersection + nodes, linear interpolation indices ``transectIndicesOnHorizNode`` and + weights ``transectWeightsOnHorizNode`` are provided. The values at + nodes are found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x = x_transect + y = y_transect + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + + coordNode = np.zeros((len(x_node), 2)) + coordNode[:, 0] = x_node + coordNode[:, 1] = y_node + + d_transect = np.zeros(x_transect.shape) + + d_node = None + x_out = np.array([]) + y_out = np.array([]) + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, _, _ = subdivide_planar( + x[sub_slice].values, y[sub_slice].values, subdivision_res) + + start_point = Point(x_transect[seg_index].values, + y_transect[seg_index].values) + end_point = Point(x_transect[seg_index + 1].values, + y_transect[seg_index + 1].values) + + segment = LineString([start_point, end_point]) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + start_indices = np.array(list(unique_indices)) + + if len(start_indices) == 0: + continue + + tris_cand = start_indices // n_nodes + next_node_index = np.mod(start_indices + 1, n_nodes) + end_indices = n_nodes * tris_cand + next_node_index + + intersecting_nodes = list() + tris_inter_list = list() + x_intersection_list = list() + y_intersection_list = list() + node_weights_list = list() + node0_inter_list = list() + node1_inter_list = list() + distances_list = list() + + for index in range(len(start_indices)): + start = start_indices[index] + end = end_indices[index] + + node0 = Point(coordNode[start, 0], coordNode[start, 1]) + node1 = Point(coordNode[end, 0], coordNode[end, 1]) + + edge = LineString([node0, node1]) + if segment.intersects(edge): + point = segment.intersection(edge) + intersecting_nodes.append((node0, node1, start, end, edge)) + + if isinstance(point, LineString): + raise ValueError('A triangle edge exactly coincides with ' + 'a transect segment and I can\'t handle ' + 'that case. Try moving the transect a ' + 'tiny bit.') + elif not isinstance(point, Point): + raise ValueError(f'Unexpected intersection type {point}') + + x_intersection_list.append(point.x) + y_intersection_list.append(point.y) + + start_to_intersection = LineString([start_point, point]) + + weight = (LineString([point, node1]).length / + LineString([node0, node1]).length) + + node_weights_list.append(weight) + node0_inter_list.append(np.mod(start, n_nodes)) + node1_inter_list.append(np.mod(end, n_nodes)) + distances_list.append(start_to_intersection.length) + tris_inter_list.append(tris_cand[index]) + + distances = np.array(distances_list) + x_intersection = np.array(x_intersection_list) + y_intersection = np.array(y_intersection_list) + node_weights = np.array(node_weights_list) + node0_inter = np.array(node0_inter_list, dtype=int) + node1_inter = np.array(node1_inter_list, dtype=int) + tris_inter = np.array(tris_inter_list, dtype=int) + + d_node_local = d_start + distances + + d_start += segment.length + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = x_intersection + y_out = y_intersection + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, x_intersection) + y_out = np.append(y_out, y_intersection) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + z_out = np.zeros(x_out.shape) + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + valid_segs = seg_tris >= 0 + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + + ds_out = xr.Dataset() + ds_out['xNode'] = (('nNodes',), x_out) + ds_out['yNode'] = (('nNodes',), y_out) + ds_out['dNode'] = (('nNodes',), d_node) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['xTransect'] = x + ds_out['yTransect'] = y + ds_out['dTransect'] = (x_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def interp_mpas_horiz_to_transect_nodes(ds_transect, da): + """ + Interpolate a 2D (``nCells``) MPAS DataArray to transect nodes, linearly + interpolating fields between the closest neighboring cells + + Parameters + ---------- + ds_transect : xr.Dataset + A dataset that defines an MPAS transect, the results of calling + ``find_spherical_transect_cells_and_weights()`` or + ``find_planar_transect_cells_and_weights()`` + + da : xr.DataArray + An MPAS 2D field with dimensions `nCells`` (possibly among others) + + Returns + ------- + da_nodes : xr.DataArray + The data array interpolated to transect nodes with dimensions + ``nNodes`` (in addition to whatever dimensions were in ``da`` besides + ``nCells``) + """ + interp_cell_indices = ds_transect.interpHorizCellIndices + interp_cell_weights = ds_transect.interpHorizCellWeights + da = da.isel(nCells=interp_cell_indices) + da_nodes = (da * interp_cell_weights).sum(dim='nHorizWeights') + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _sort_intersections(d_node, tris, nodes, x_out, y_out, z_out, interp_cells, + cell_weights, epsilon): + """ sort nodes by distance and define segment between them """ + + sort_indices = np.argsort(d_node) + d_sorted = d_node[sort_indices] + + # make a list of indices for each unique value of d + d = d_sorted[0] + unique_d_indices = [sort_indices[0]] + unique_d_all_indices = [[sort_indices[0]]] + for index, next_d, in zip(sort_indices[1:], d_sorted[1:]): + if next_d - d < epsilon: + # this d value is effectively the same as the last, so we'll treat + # it as the same + unique_d_all_indices[-1].append(index) + else: + # this is a new d, so we'll add to a new list + d = next_d + unique_d_indices.append(index) + unique_d_all_indices.append([index]) + + # there is a segment between each unique d, though some are invalid (do + # not correspond to a triangle) + seg_tris_list = list() + seg_nodes_list = list() + + index0 = unique_d_indices[0] + indices0 = unique_d_all_indices[0] + d0 = d_node[index0] + + indices = [index0] + ds = [d0] + for seg_index in range(len(unique_d_all_indices) - 1): + indices1 = unique_d_all_indices[seg_index + 1] + index1 = unique_d_indices[seg_index + 1] + d1 = d_node[index1] + + # are there any triangles in common between this d value and the next? + tris0 = tris[indices0] + tris1 = tris[indices1] + both = set(tris0).intersection(set(tris1)) + + if len(both) > 0: + tri = both.pop() + seg_tris_list.append(tri) + indices.append(index1) + ds.append(d1) + + # the triangle nodes are the 2 corresponding to the same triangle + # in the original list + index0 = indices0[np.where(tris0 == tri)[0][0]] + index1 = indices1[np.where(tris1 == tri)[0][0]] + seg_nodes_list.append([nodes[index0], nodes[index1]]) + else: + # this is an invalid segment so we need to insert and extra invalid + # node to allow for proper masking + seg_tris_list.extend([-1, -1]) + seg_nodes_list.extend([[-1, -1], [-1, -1]]) + indices.extend([index0, index1]) + ds.extend([0.5 * (d0 + d1), d1]) + + index0 = index1 + indices0 = indices1 + d0 = d1 + + indices = np.array(indices, dtype=int) + d_node = np.array(ds, dtype=float) + seg_tris = np.array(seg_tris_list, dtype=int) + seg_nodes = np.array(seg_nodes_list, dtype=int) + + valid_nodes = np.ones(len(indices), dtype=bool) + valid_nodes[1:-1] = np.logical_or(seg_tris[0:-1] >= 0, + seg_tris[1:] > 0) + + x_out = x_out[indices] + y_out = y_out[indices] + z_out = z_out[indices] + + interp_cells = interp_cells[indices, :] + cell_weights = cell_weights[indices, :] + + return (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) + + +def _fix_periodic_tris(ds_tris, periodic_var, period): + """ + make sure the given node coordinate on tris is within one period of the + cell center + """ + coord_node = ds_tris[periodic_var].values + coord_cell = coord_node[:, 0] + n_triangles = ds_tris.sizes['nTriangles'] + copy_pos = np.zeros(coord_cell.shape, dtype=bool) + copy_neg = np.zeros(coord_cell.shape, dtype=bool) + for i_node in [1, 2]: + mask = coord_node[:, i_node] - coord_cell > 0.5 * period + copy_pos = np.logical_or(copy_pos, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] - period + mask = coord_node[:, i_node] - coord_cell < -0.5 * period + copy_neg = np.logical_or(copy_neg, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] + period + + pos_indices = np.nonzero(copy_pos)[0] + neg_indices = np.nonzero(copy_neg)[0] + tri_indices = np.append(np.append(np.arange(0, n_triangles), + pos_indices), neg_indices) + + ds_new = xr.Dataset(ds_tris) + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + ds_new = ds_new.isel(nTriangles=tri_indices) + coord_node = ds_new[periodic_var].values + + pos_slice = slice(n_triangles, n_triangles + len(pos_indices)) + coord_node[pos_slice, :] = coord_node[pos_slice, :] + period + neg_slice = slice(n_triangles + len(pos_indices), + n_triangles + len(pos_indices) + len(neg_indices)) + coord_node[neg_slice, :] = coord_node[neg_slice, :] - period + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + return ds_new diff --git a/polaris/ocean/viz/transect/plot.py b/polaris/ocean/viz/transect/plot.py new file mode 100644 index 000000000..3ebac1e2f --- /dev/null +++ b/polaris/ocean/viz/transect/plot.py @@ -0,0 +1,210 @@ +import cmocean # noqa: F401 +import matplotlib.pyplot as plt +import numpy as np + +from polaris.ocean.viz.transect.vert import ( + interp_mpas_to_transect_cells, + interp_mpas_to_transect_nodes, +) +from polaris.viz.style import use_mplstyle + + +def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, + title=None, vmin=None, vmax=None, colorbar_label=None, + cmap=None, figsize=(12, 6), dpi=200, method='flat', + outline_color='black', ssh_color=None, seafloor_color=None, + interface_color=None, cell_boundary_color=None, + linewidth=1.0, color_start_and_end=False, + start_color='red', end_color='green'): + """ + plot a transect showing the field on the MPAS-Ocean mesh and save to a file + + Parameters + ---------- + ds_transect : xarray.Dataset + A transect dataset from + :py:func:`polaris.ocean.viz.compute_transect()` + + mpas_field : xarray.DataArray + The MPAS-Ocean 3D field to plot + + out_filename : str, optional + The png file to write out to + + ax : matplotlib.axes.Axes + Axes to plot to if making a multi-panel figure + + title : str + The title of the plot + + vmin : float, optional + The minimum values for the colorbar + + vmax : float, optional + The maximum values for the colorbar + + colorbar_label : str, optional + The colorbar label, or ``None`` if no colorbar is to be included. + Use an empty string to display a colorbar without a label. + + cmap : str, optional + The name of a colormap to use + + figsize : tuple, optional + The size of the figure in inches + + dpi : int, optional + The dots per inch of the image + + method : {'flat', 'bilinear'}, optional + The type of interpolation to use in plots. ``flat`` means constant + values over each MPAS cell. ``bilinear`` means smooth interpolation + between horizontally between cell centers and vertical between the + middle of layers. + + outline_color : str or None, optional + The color to use to outline the transect or ``None`` for no outline + + ssh_color : str or None, optional + The color to use to plot the SSH (sea surface height) or ``None`` if + not plotting the SSH (except perhaps as part of the outline) + + seafloor_color : str or None, optional + The color to use to plot the seafloor depth or ``None`` if not plotting + the seafloor depth (except perhaps as part of the outline) + + interface_color : str or None, optional + The color to use to plot interfaces between layers or ``None`` if + not plotting the layer interfaces + + cell_boundary_color : str or None, optional + The color to use to plot vertical boundaries between cells or ``None`` + if not plotting cell boundaries. Typically, ``cell_boundary_color`` + will be used along with ``interface_color`` to outline cells both + horizontally and vertically. + + linewidth : float, optional + The width of outlines, interfaces and cell boundaries + + color_start_and_end : bool, optional + Whether to color the left and right axes of the transect, which is + useful if the transect is also being plotted in an inset or on top of + a horizontal field + + start_color : str, optional + The color of left axis marking the start of the transect if + ``plot_start_end == True`` + + end_color : str, optional + The color of right axis marking the end of the transect if + ``plot_start_end == True`` + """ + + if ax is None and out_filename is None: + raise ValueError('One of ax or out_filename must be supplied') + + use_mplstyle() + + create_fig = ax is None + if create_fig: + plt.figure(figsize=figsize) + ax = plt.subplot(111) + + z = ds_transect.zTransectNode + x = 1e-3 * ds_transect.dNode.broadcast_like(z) + + if mpas_field is not None: + if method == 'flat': + transect_field = interp_mpas_to_transect_cells(ds_transect, + mpas_field) + shading = 'flat' + elif method == 'bilinear': + transect_field = interp_mpas_to_transect_nodes(ds_transect, + mpas_field) + shading = 'gouraud' + else: + raise ValueError(f'Unsupported method: {method}') + + pc = ax.pcolormesh(x.values, z.values, transect_field.values, + shading=shading, cmap=cmap, vmin=vmin, vmax=vmax, + zorder=0) + ax.autoscale(tight=True) + if colorbar_label is not None: + plt.colorbar(pc, extend='both', shrink=0.7, ax=ax, + label=colorbar_label) + + _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth) + + _plot_outline(x, z, ds_transect.validNodes, ax, outline_color, + linewidth) + + ax.set_xlabel('transect distance (km)') + ax.set_ylabel('z (m)') + + if create_fig: + if title is not None: + plt.title(title) + plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) + plt.close() + + +def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth): + if cell_boundary_color is not None: + x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T + z_bnd = ds_transect.zCellBoundary.values.T + ax.plot(x_bnd, z_bnd, color=cell_boundary_color, linewidth=linewidth, + zorder=1) + + if interface_color is not None: + x_int = 1e-3 * ds_transect.dInterfaceSegment.values.T + z_int = ds_transect.zInterfaceSegment.values.T + ax.plot(x_int, z_int, color=interface_color, linewidth=linewidth, + zorder=2) + + if ssh_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_ssh = 1e-3 * ds_transect.dNode.values + z_ssh = ds_transect.ssh.where(valid).values + ax.plot(x_ssh, z_ssh, color=ssh_color, linewidth=linewidth, zorder=4) + + if seafloor_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_floor = 1e-3 * ds_transect.dNode.values + z_floor = ds_transect.zSeafloor.where(valid).values + ax.plot(x_floor, z_floor, color=seafloor_color, linewidth=linewidth, + zorder=5) + + if color_start_and_end: + ax.spines['left'].set_color(start_color) + ax.spines['left'].set_linewidth(4 * linewidth) + ax.spines['right'].set_color(end_color) + ax.spines['right'].set_linewidth(4 * linewidth) + + +def _plot_outline(x, z, valid_nodes, ax, outline_color, linewidth, + epsilon=1e-6): + if outline_color is not None: + # add a buffer of invalid values around the edge of the domain + valid = np.zeros((x.shape[0] + 2, x.shape[1] + 2), dtype=float) + z_buf = np.zeros(valid.shape, dtype=float) + x_buf = np.zeros(valid.shape, dtype=float) + + valid[1:-1, 1:-1] = valid_nodes.astype(float) + z_buf[1:-1, 1:-1] = z.values + z_buf[0, 1:-1] = z_buf[1, 1:-1] + z_buf[-1, 1:-1] = z_buf[-2, 1:-1] + z_buf[:, 0] = z_buf[:, 1] + z_buf[:, -1] = z_buf[:, -2] + + x_buf[1:-1, 1:-1] = x.values + x_buf[0, 1:-1] = x_buf[1, 1:-1] + x_buf[-1, 1:-1] = x_buf[-2, 1:-1] + x_buf[:, 0] = x_buf[:, 1] + x_buf[:, -1] = x_buf[:, -2] + + ax.contour(x_buf, z_buf, valid, levels=[1. - epsilon], + colors=outline_color, linewidths=linewidth, zorder=3) diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py new file mode 100644 index 000000000..a146cb46c --- /dev/null +++ b/polaris/ocean/viz/transect/vert.py @@ -0,0 +1,486 @@ +import numpy as np +import xarray as xr + +from polaris.ocean.viz.transect.horiz import ( + find_planar_transect_cells_and_weights, + find_spherical_transect_cells_and_weights, + make_triangle_tree, + mesh_to_triangles, +) + + +def compute_transect(x, y, ds_horiz_mesh, layer_thickness, bottom_depth, + min_level_cell, max_level_cell, spherical=False): + """ + build a sequence of quads showing the transect intersecting mpas cells. + This can be used to plot transects of fields with dimensions ``nCells`` and + ``nVertLevels`` using :py:func:`polaris.ocean.viz.plot_transect()` + + Parameters + ---------- + x : xarray.DataArray + The x or longitude coordinate of the transect + + y : xarray.DataArray + The y or latitude coordinate of the transect + + ds_horiz_mesh : xarray.Dataset + The horizontal MPAS mesh to use for plotting + + layer_thickness : xarray.DataArray + The layer thickness at a particular instant in time. + `layerThickness.isel(Time=tidx)` to select a particular time index + `tidx` if the original data array contains `Time`. + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + spherical : bool, optional + Whether the x and y coordinates are latitude and longitude in degrees + + Returns + ------- + ds_transect : xarray.Dataset + The transect dataset, see + :py:func:`polaris.ocean.viz.transect.vert.find_transect_levels_and_weights()` + for details + """ # noqa: E501 + + ds_tris = mesh_to_triangles(ds_horiz_mesh) + + triangle_tree = make_triangle_tree(ds_tris) + + if spherical: + ds_horiz_transect = find_spherical_transect_cells_and_weights( + x, y, ds_tris, ds_horiz_mesh, triangle_tree, degrees=True) + else: + ds_horiz_transect = find_planar_transect_cells_and_weights( + x, y, ds_tris, ds_horiz_mesh, triangle_tree) + + # mask horizontal transect to valid cells (max_level_cell >= 0) + cell_indices = ds_horiz_transect.horizCellIndices + seg_mask = max_level_cell.isel(nCells=cell_indices).values >= 0 + node_mask = np.zeros(ds_horiz_transect.sizes['nNodes'], dtype=bool) + node_mask[0:-1] = seg_mask + node_mask[1:] = np.logical_or(node_mask[1:], seg_mask) + + ds_horiz_transect = ds_horiz_transect.isel(nSegments=seg_mask, + nNodes=node_mask) + + ds_transect = find_transect_levels_and_weights( + ds_horiz_transect=ds_horiz_transect, layer_thickness=layer_thickness, + bottom_depth=bottom_depth, min_level_cell=min_level_cell, + max_level_cell=max_level_cell) + + ds_transect.compute() + + return ds_transect + + +def find_transect_levels_and_weights(ds_horiz_transect, layer_thickness, + bottom_depth, min_level_cell, + max_level_cell): + """ + Construct a vertical coordinate for a transect produced by + :py:func:`polaris.ocean.viz.transect.horiz.find_spherical_transect_cells_and_weights()` + or :py:func:`polaris.ocean.viz.transect.horiz.find_planar_transect_cells_and_weights()`. + Also, compute interpolation weights such that observations at points on the + original transect and with vertical coordinate ``transectZ`` can be + bilinearly interpolated to the nodes of the transect. + + Parameters + ---------- + ds_horiz_transect : xarray.Dataset + A dataset that defines nodes of the transect + + layer_thickness : xarray.DataArray + layer thicknesses on the MPAS mesh + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + Returns + ------- + ds_transect : xarray.Dataset + A dataset that contains nodes and cells that make up a 2D transect. + + There are ``nSegments`` horizontal and ``nHalfLevels`` vertical + transect cells (quadrilaterals), bounded by ``nHorizNodes`` horizontal + and ``nVertNodes`` vertical nodes (corners). + + In addition to the variables and coordinates in the input + ``ds_transect``, the output dataset contains: + + - ``validCells``, ``validNodes``: which transect cells and nodes + are valid (above the bathymetry and below the sea surface) + + - zTransectNode: the vertical height of each triangle node + - ssh, zSeaFloor: the sea-surface height and sea-floor height at + each node of each transect segment + + - ``cellIndices``: the MPAS-Ocean cell of a given transect segment + - ``levelIndices``: the MPAS-Ocean vertical level of a given + transect level + + - ``interpCellIndices``, ``interpLevelIndices``: the MPAS-Ocean + cells and levels from which the value at a given transect cell is + interpolated. This can involve up to + ``nHorizWeights * nVertWeights = 12`` different cells and levels. + - interpCellWeights: the weight to multiply each field value by + to perform interpolation to a transect cell. + + - ``dInterfaceSegment``, ``zInterfaceSegment`` - segments that can + be used to plot the interfaces between MPAS-Ocean layers + + - ``dCellBoundary``, ``zCellBoundary`` - segments that can + be used to plot the vertical boundaries between MPAS-Ocean cells + + Interpolation of a DataArray from MPAS cells and levels to transect + cells can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_cells()`. + Similarly, interpolation to transect nodes can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_nodes()`. + """ # noqa: E501 + if 'Time' in layer_thickness.dims: + raise ValueError('Please select a single time level in layer ' + 'thickness.') + + ds_transect_cells = ds_horiz_transect.rename({'nNodes': 'nHorizNodes'}) + + (z_half_interface, ssh, z_seafloor, interp_cell_indices, + interp_cell_weights, valid_transect_cells, + level_indices) = _get_vertical_coordinate( + ds_transect_cells, layer_thickness, bottom_depth, min_level_cell, + max_level_cell) + + ds_transect_cells['zTransectNode'] = z_half_interface + + ds_transect_cells['ssh'] = ssh + ds_transect_cells['zSeafloor'] = z_seafloor + + ds_transect_cells['cellIndices'] = ds_transect_cells.horizCellIndices + ds_transect_cells['levelIndices'] = level_indices + ds_transect_cells['validCells'] = valid_transect_cells + + d_interface_seg, z_interface_seg = _get_interface_segments( + z_half_interface, ds_transect_cells.dNode, valid_transect_cells) + + ds_transect_cells['dInterfaceSegment'] = d_interface_seg + ds_transect_cells['zInterfaceSegment'] = z_interface_seg + + d_cell_boundary, z_cell_boundary = _get_cell_boundary_segments( + ssh, z_seafloor, ds_transect_cells.dNode, + ds_transect_cells.horizCellIndices) + + ds_transect_cells['dCellBoundary'] = d_cell_boundary + ds_transect_cells['zCellBoundary'] = z_cell_boundary + + interp_level_indices, interp_cell_weights, valid_nodes = \ + _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells) + + ds_transect_cells['interpCellIndices'] = interp_cell_indices + ds_transect_cells['interpLevelIndices'] = interp_level_indices + ds_transect_cells['interpCellWeights'] = interp_cell_weights + ds_transect_cells['validNodes'] = valid_nodes + + dims = ['nSegments', 'nHalfLevels', 'nHorizNodes', 'nVertNodes', + 'nInterfaceSegments', 'nCellBoundaries', 'nHorizBounds', + 'nVertBounds', 'nHorizWeights', 'nVertWeights'] + for dim in ds_transect_cells.dims: + if dim not in dims: + dims.insert(0, dim) + ds_transect_cells = ds_transect_cells.transpose(*dims) + + return ds_transect_cells + + +def interp_mpas_to_transect_cells(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_cells : xarray.DataArray + The data array interpolated to transect cells with dimensions + ``nSegments`` and ``nHalfLevels`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + cell_indices = ds_transect.cellIndices + level_indices = ds_transect.levelIndices + + da_cells = da.isel(nCells=cell_indices, nVertLevels=level_indices) + da_cells = da_cells.where(ds_transect.validCells) + + return da_cells + + +def interp_mpas_to_transect_nodes(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect nodes, linearly interpolating fields between + the closest neighboring cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_nodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nHorizNodes`` and ``nVertNodes`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + interp_cell_indices = ds_transect.interpCellIndices + interp_level_indices = ds_transect.interpLevelIndices + interp_cell_weights = ds_transect.interpCellWeights + + da = da.isel(nCells=interp_cell_indices, nVertLevels=interp_level_indices) + + da_nodes = (da * interp_cell_weights).sum( + dim=('nHorizWeights', 'nVertWeights')) + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _get_vertical_coordinate(ds_transect, layer_thickness, bottom_depth, + min_level_cell, max_level_cell): + n_horiz_nodes = ds_transect.sizes['nHorizNodes'] + n_segments = ds_transect.sizes['nSegments'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + + # we assume below that there is a segment (whether valid or invalid) + # connecting each pair of adjacent nodes + assert n_horiz_nodes == n_segments + 1 + + interp_horiz_cell_indices = ds_transect.interpHorizCellIndices + interp_horiz_cell_weights = ds_transect.interpHorizCellWeights + + bottom_depth_interp = bottom_depth.isel(nCells=interp_horiz_cell_indices) + layer_thickness_interp = layer_thickness.isel( + nCells=interp_horiz_cell_indices) + + cell_mask_interp = _get_cell_mask(interp_horiz_cell_indices, + min_level_cell, max_level_cell, + n_vert_levels) + layer_thickness_interp = layer_thickness_interp.where(cell_mask_interp, 0.) + + ssh_interp = (-bottom_depth_interp + + layer_thickness_interp.sum(dim='nVertLevels')) + + interp_mask = np.logical_and(interp_horiz_cell_indices > 0, + cell_mask_interp) + + interp_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_cell_weights.sum(dim='nHorizWeights') + + cell_indices = ds_transect.horizCellIndices + + valid_cells = _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels) + + valid_cells = valid_cells.transpose('nSegments', 'nVertLevels').values + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_levels), dtype=bool) + valid_nodes[0:-1, :] = valid_cells + valid_nodes[1:, :] = np.logical_or(valid_nodes[1:, :], valid_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertLevels'), + data=valid_nodes) + + valid_weights = valid_nodes.broadcast_like(interp_cell_weights) + interp_cell_weights = \ + (interp_cell_weights / weight_sum).where(valid_weights) + + layer_thickness_transect = (layer_thickness_interp * + interp_cell_weights).sum(dim='nHorizWeights') + + interp_mask = max_level_cell.isel(nCells=interp_horiz_cell_indices) >= 0 + interp_horiz_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_horiz_cell_weights.sum(dim='nHorizWeights') + interp_horiz_cell_weights = \ + (interp_horiz_cell_weights / weight_sum).where(interp_mask) + + ssh_transect = (ssh_interp * + interp_horiz_cell_weights).sum(dim='nHorizWeights') + + z_bot = ssh_transect - layer_thickness_transect.cumsum(dim='nVertLevels') + z_mid = z_bot + 0.5 * layer_thickness_transect + + z_half_interfaces = [ssh_transect] + for z_index in range(n_vert_levels): + z_half_interfaces.extend([z_mid.isel(nVertLevels=z_index), + z_bot.isel(nVertLevels=z_index)]) + + z_half_interface = xr.concat(z_half_interfaces, dim='nVertNodes') + z_half_interface = z_half_interface.transpose('nHorizNodes', 'nVertNodes') + + z_seafloor = ssh_transect - layer_thickness_transect.sum( + dim='nVertLevels') + + valid_transect_cells = np.zeros((n_segments, 2 * n_vert_levels), + dtype=bool) + valid_transect_cells[:, 0::2] = valid_cells + valid_transect_cells[:, 1::2] = valid_cells + valid_transect_cells = xr.DataArray(dims=('nSegments', 'nHalfLevels'), + data=valid_transect_cells) + + level_indices = np.zeros(2 * n_vert_levels, dtype=int) + level_indices[0::2] = np.arange(n_vert_levels) + level_indices[1::2] = np.arange(n_vert_levels) + level_indices = xr.DataArray(dims=('nHalfLevels',), data=level_indices) + + return (z_half_interface, ssh_transect, z_seafloor, + interp_horiz_cell_indices, interp_cell_weights, + valid_transect_cells, level_indices) + + +def _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels): + level_indices = xr.DataArray(data=np.arange(n_vert_levels), + dims='nVertLevels') + min_level_cell = min_level_cell.isel(nCells=cell_indices) + max_level_cell = max_level_cell.isel(nCells=cell_indices) + + cell_mask = np.logical_and( + level_indices >= min_level_cell, + level_indices <= max_level_cell) + + cell_mask = np.logical_and(cell_mask, cell_indices >= 0) + + return cell_mask + + +def _get_interface_segments(z_half_interface, d_node, valid_transect_cells): + + d = d_node.broadcast_like(z_half_interface) + z_interface = z_half_interface.values[:, 0::2] + d = d.values[:, 0::2] + + n_segments = valid_transect_cells.sizes['nSegments'] + n_half_levels = valid_transect_cells.sizes['nHalfLevels'] + n_vert_levels = n_half_levels // 2 + + valid_segs = np.zeros((n_segments, n_vert_levels + 1), dtype=bool) + valid_segs[:, 0:-1] = valid_transect_cells.values[:, 1::2] + valid_segs[:, 1:] = np.logical_or(valid_segs[:, 1:], + valid_transect_cells.values[:, 0::2]) + + n_interface_segs = np.count_nonzero(valid_segs) + + d_seg = np.zeros((n_interface_segs, 2)) + z_seg = np.zeros((n_interface_segs, 2)) + d_seg[:, 0] = d[0:-1, :][valid_segs] + d_seg[:, 1] = d[1:, :][valid_segs] + z_seg[:, 0] = z_interface[0:-1, :][valid_segs] + z_seg[:, 1] = z_interface[1:, :][valid_segs] + + d_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=d_seg) + + z_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=z_seg) + + return d_seg, z_seg + + +def _get_cell_boundary_segments(ssh, z_seafloor, d_node, cell_indices): + + n_horiz_nodes = d_node.sizes['nHorizNodes'] + + cell_boundary = np.ones(n_horiz_nodes, dtype=bool) + cell_boundary[1:-1] = cell_indices.values[0:-1] != cell_indices.values[1:] + + n_cell_boundaries = np.count_nonzero(cell_boundary) + + d_seg = np.zeros((n_cell_boundaries, 2)) + z_seg = np.zeros((n_cell_boundaries, 2)) + d_seg[:, 0] = d_node.values[cell_boundary] + d_seg[:, 1] = d_seg[:, 0] + z_seg[:, 0] = ssh[cell_boundary] + z_seg[:, 1] = z_seafloor[cell_boundary] + + d_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=d_seg) + + z_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=z_seg) + + return d_seg, z_seg + + +def _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells): + n_horiz_nodes = interp_cell_indices.sizes['nHorizNodes'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + n_vert_nodes = 2 * n_vert_levels + 1 + n_vert_weights = 2 + + interp_level_indices = -1 * np.ones((n_vert_nodes, n_vert_weights), + dtype=int) + interp_level_indices[1:, 0] = level_indices.values + interp_level_indices[0:-1, 1] = level_indices.values + + interp_level_indices = xr.DataArray(dims=('nVertNodes', 'nVertWeights'), + data=interp_level_indices) + + half_level_thickness = 0.5 * layer_thickness.isel( + nCells=interp_cell_indices, nVertLevels=interp_level_indices) + half_level_thickness = half_level_thickness.where( + interp_level_indices >= 0, other=0.) + + # vertical weights are proportional to the half-level thickness + interp_cell_weights = half_level_thickness * interp_cell_weights.isel( + nVertLevels=interp_level_indices) + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_nodes), dtype=bool) + valid_nodes[0:-1, 0:-1] = valid_transect_cells + valid_nodes[1:, 0:-1] = np.logical_or(valid_nodes[1:, 0:-1], + valid_transect_cells) + valid_nodes[0:-1, 1:] = np.logical_or(valid_nodes[0:-1, 1:], + valid_transect_cells) + valid_nodes[1:, 1:] = np.logical_or(valid_nodes[1:, 1:], + valid_transect_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertNodes'), + data=valid_nodes) + + weight_sum = interp_cell_weights.sum(dim=('nHorizWeights', 'nVertWeights')) + out_mask = (weight_sum > 0.).broadcast_like(interp_cell_weights) + interp_cell_weights = (interp_cell_weights / weight_sum).where(out_mask) + + return interp_level_indices, interp_cell_weights, valid_nodes diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index fef04eda2..779bbbfa3 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -16,7 +16,10 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 cmap=None, cmap_set_under=None, cmap_set_over=None, cmap_scale='linear', cmap_title=None, figsize=None, vert_dim='nVertLevels', cell_mask=None, patches=None, - patch_mask=None): + patch_mask=None, transect_x=None, transect_y=None, + transect_color='black', transect_start='red', + transect_end='green', transect_linewidth=2., + transect_markersize=12.): """ Plot a horizontal field from a planar domain using x,y coordinates at a single time and depth slice. @@ -92,6 +95,27 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 A mask of where the field has patches from a previous call to ``plot_horiz_field()`` + transect_x : numpy.ndarray or xarray.DataArray, optional + The x coordinates of a transect to plot on the + + transect_y : numpy.ndarray or xarray.DataArray, optional + The y coordinates of a transect + + transect_color : str, optional + The color of the transect line + + transect_start : str or None, optional + The color of a dot marking the start of the transect + + transect_end : str or None, optional + The color of a dot marking the end of the transect + + transect_linewidth : float, optional + The width of the transect line + + transect_markersize : float, optional + The size of the transect start and end markers + Returns ------- patches : list of numpy.ndarray @@ -101,6 +125,19 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 patch_mask : numpy.ndarray A mask used to select entries in the field that have patches """ + if field_name not in ds: + raise ValueError( + f'{field_name} must be present in ds before plotting.') + + if patches is not None: + if patch_mask is None: + raise ValueError('You must supply both patches and patch_mask ' + 'from a previous call to plot_horiz_field()') + + if (transect_x is None) != (transect_y is None): + raise ValueError('You must supply both transect_x and transect_y or ' + 'neither') + use_mplstyle() create_fig = True @@ -118,10 +155,6 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if title is None: title = field_name - if field_name not in ds: - raise ValueError( - f'{field_name} must be present in ds before plotting.') - field = ds[field_name] if 'Time' in field.dims and t_index is None: @@ -133,11 +166,7 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if z_index is not None: field = field.isel({vert_dim: z_index}) - if patches is not None: - if patch_mask is None: - raise ValueError('You must supply both patches and patch_mask ' - 'from a previous call to plot_horiz_field()') - else: + if patches is None: if cell_mask is None: cell_mask = np.ones_like(field, type='bool') if 'nCells' in field.dims: @@ -190,6 +219,18 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 cbar = plt.colorbar(local_patches, extend='both', shrink=0.7, ax=ax) if cmap_title is not None: cbar.set_label(cmap_title) + + if transect_x is not None: + transect_x = 1e-3 * transect_x + transect_y = 1e-3 * transect_y + ax.plot(transect_x, transect_y, color=transect_color, + linewidth=transect_linewidth) + if transect_start is not None: + ax.plot(transect_x[0], transect_y[0], '.', color=transect_start, + markersize=transect_markersize) + if transect_end is not None: + ax.plot(transect_x[-1], transect_y[-1], '.', color=transect_end, + markersize=transect_markersize) if create_fig: plt.title(title) plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0.2)