From 5fb9483abd4e3b1471fd22eb058a0f16de771553 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Sat, 28 Oct 2023 15:18:01 +0200 Subject: [PATCH 01/14] Update to mpas_tools 0.27.0 --- deploy/default.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/default.cfg b/deploy/default.cfg index 697f21984..38bc13a07 100644 --- a/deploy/default.cfg +++ b/deploy/default.cfg @@ -24,7 +24,7 @@ geometric_features = 1.2.0 jigsaw = 0.9.14 jigsawpy = 0.3.3 mache = 1.16.0 -mpas_tools = 0.25.0 +mpas_tools = 0.27.0 otps = 2021.10 parallelio = 2.6.0 From 9ebaf1e6f79e5ec97d963976b77053aa5f14ec6b Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Wed, 25 Oct 2023 14:59:11 +0200 Subject: [PATCH 02/14] Add functions for computing and plotting transects --- polaris/ocean/viz/__init__.py | 1 + polaris/ocean/viz/transect.py | 159 ++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 polaris/ocean/viz/__init__.py create mode 100644 polaris/ocean/viz/transect.py diff --git a/polaris/ocean/viz/__init__.py b/polaris/ocean/viz/__init__.py new file mode 100644 index 000000000..365aaec1c --- /dev/null +++ b/polaris/ocean/viz/__init__.py @@ -0,0 +1 @@ +from polaris.ocean.viz.transect import compute_transect, plot_transect diff --git a/polaris/ocean/viz/transect.py b/polaris/ocean/viz/transect.py new file mode 100644 index 000000000..07780b67e --- /dev/null +++ b/polaris/ocean/viz/transect.py @@ -0,0 +1,159 @@ +import cmocean # noqa: blah +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.tri import Triangulation +from mpas_tools.ocean.transects import ( + find_transect_levels_and_weights, + get_outline_segments, + interp_mpas_to_transect_triangles, +) +from mpas_tools.viz import mesh_to_triangles +from mpas_tools.viz.transects import ( + find_planar_transect_cells_and_weights, + find_transect_cells_and_weights, + make_triangle_tree, +) + + +def compute_transect(x, y, ds_3d_mesh, spherical=False): + """ + build a sequence of triangles showing the transect intersecting mpas cells + + Parameters + ---------- + x : xarray.DataArray + The x or longitude coordinate of the transect + + y : xarray.DataArray + The y or latitude coordinate of the transect + + ds_3d_mesh : xarray.Dataset + The MPAS-Ocean mesh to use for plotting + + spherical : bool, optional + Whether the x and y coordinates are latitude and longitude in degrees + + Returns + ------- + ds_transect : xarray.Dataset + The transect dataset + """ + + ds_tris = mesh_to_triangles(ds_3d_mesh) + + triangle_tree = make_triangle_tree(ds_tris) + + if spherical: + ds_transect = find_transect_cells_and_weights( + x, y, ds_tris, ds_3d_mesh, triangle_tree, degrees=True) + else: + ds_transect = find_planar_transect_cells_and_weights( + x, y, ds_tris, ds_3d_mesh, triangle_tree) + + cell_indices = ds_transect.horizCellIndices + mask = ds_3d_mesh.maxLevelCell.isel(nCells=cell_indices) > 0 + ds_transect = ds_transect.isel(nSegments=mask) + + ds_transect = find_transect_levels_and_weights( + ds_transect, ds_3d_mesh.layerThickness, + ds_3d_mesh.bottomDepth, ds_3d_mesh.maxLevelCell - 1) + + if 'landIceFraction' in ds_3d_mesh: + interp_cell_indices = ds_transect.interpHorizCellIndices + interp_cell_weights = ds_transect.interpHorizCellWeights + land_ice_fraction = ds_3d_mesh.landIceFraction.isel( + nCells=interp_cell_indices) + land_ice_fraction = (land_ice_fraction * interp_cell_weights).sum( + dim='nHorizWeights') + ds_transect['landIceFraction'] = land_ice_fraction + + ds_transect['x'] = ds_transect.dNode.isel( + nSegments=ds_transect.segmentIndices, + nHorizBounds=ds_transect.nodeHorizBoundsIndices) + + ds_transect['z'] = ds_transect.zTransectNode + + ds_transect.compute() + + return ds_transect + + +def plot_transect(ds_transect, mpas_field, out_filename, title, + colorbar_label=None, cmap=None, figsize=(12, 6), dpi=200): + """ + plot a transect showing the field on the MPAS-Ocean mesh and save to a file + + Parameters + ---------- + ds_transect : xarray.Dataset + A transect dataset from + :py:func:`polaris.ocean.viz.compute_transect()` + + mpas_field : xarray.DataArray + The MPAS-Ocean 3D field to plot + + out_filename : str + The png file to write out to + + title : str + The title of the plot + + colorbar_label : str, optional + The colorbar label + + cmap : str, optional + The name of a colormap to use + + figsize : tuple, optional + The size of the figure in inches + + dpi : int, optional + The dots per inch of the image + + """ + transect_field = interp_mpas_to_transect_triangles(ds_transect, + mpas_field) + + x_outline, z_outline = get_outline_segments(ds_transect) + x_outline = 1e-3 * x_outline + + tri_mask = np.logical_not(transect_field.notnull().values) + + triangulation_args = _get_ds_triangulation_args(ds_transect) + + triangulation_args['mask'] = tri_mask + + tris = Triangulation(**triangulation_args) + plt.figure(figsize=figsize) + plt.tripcolor(tris, facecolors=transect_field.values, shading='flat', + cmap=cmap) + plt.plot(x_outline, z_outline, 'k') + if colorbar_label is not None: + plt.colorbar(label=colorbar_label) + plt.title(title) + plt.xlabel('x (km)') + plt.ylabel('z (m)') + + plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) + plt.close() + + +def _get_ds_triangulation_args(ds_transect): + """ + get arguments for matplotlib Triangulation from triangulation dataset + """ + + n_transect_triangles = ds_transect.sizes['nTransectTriangles'] + d_node = ds_transect.dNode.isel( + nSegments=ds_transect.segmentIndices, + nHorizBounds=ds_transect.nodeHorizBoundsIndices) + x = 1e-3 * d_node.values.ravel() + + z_transect_node = ds_transect.zTransectNode + y = z_transect_node.values.ravel() + + tris = np.arange(3 * n_transect_triangles).reshape( + (n_transect_triangles, 3)) + triangulation_args = dict(x=x, y=y, triangles=tris) + + return triangulation_args From 69ef39a28707fb732b0d453a0a0f761a1876f52d Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 6 Nov 2023 14:44:14 +0100 Subject: [PATCH 03/14] Compute transects on quads The code added here is similar to functionality in MPAS-Tools but it works with data on quads, rather than triangles. Visualization on triangles has proven to have a number of major drawbacks. It is not very intuitive, functionality for matplotlib functions and methods on triangles have fewer features and are less robust than their counterparts for quads, and there does not seem to be support for labeled contours on triangles. --- polaris/ocean/viz/transect/__init__.py | 0 polaris/ocean/viz/transect/horiz.py | 726 +++++++++++++++++++++++++ polaris/ocean/viz/transect/vert.py | 477 ++++++++++++++++ 3 files changed, 1203 insertions(+) create mode 100644 polaris/ocean/viz/transect/__init__.py create mode 100644 polaris/ocean/viz/transect/horiz.py create mode 100644 polaris/ocean/viz/transect/vert.py diff --git a/polaris/ocean/viz/transect/__init__.py b/polaris/ocean/viz/transect/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/polaris/ocean/viz/transect/horiz.py b/polaris/ocean/viz/transect/horiz.py new file mode 100644 index 000000000..0a06cddd5 --- /dev/null +++ b/polaris/ocean/viz/transect/horiz.py @@ -0,0 +1,726 @@ +import numpy as np +import xarray as xr +from mpas_tools.transects import ( + cartesian_to_lon_lat, + lon_lat_to_cartesian, + subdivide_great_circle, + subdivide_planar, +) +from mpas_tools.vector import Vector +from scipy.spatial import cKDTree +from shapely.geometry import LineString, Point + + +def make_triangle_tree(ds_tris): + """ + Make a KD-Tree for finding triangle edges that are near enough to transect + segments that they might intersect + + Parameters + ---------- + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + + Returns + ------- + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh + """ + + n_triangles = ds_tris.sizes['nTriangles'] + n_nodes = ds_tris.sizes['nNodes'] + node_coords = np.zeros((n_triangles * n_nodes, 3)) + node_coords[:, 0] = ds_tris.xNode.values.ravel() + node_coords[:, 1] = ds_tris.yNode.values.ravel() + node_coords[:, 2] = ds_tris.zNode.values.ravel() + + next_tri, next_node = np.meshgrid( + np.arange(n_triangles), np.mod(np.arange(n_nodes) + 1, 3), + indexing='ij') + nextIndices = n_nodes * next_tri.ravel() + next_node.ravel() + + # edge centers are half way between adjacent nodes (ignoring great-circle + # distance) + edgeCoords = 0.5 * (node_coords + node_coords[nextIndices, :]) + + tree = cKDTree(data=edgeCoords, copy_data=True) + return tree + + +def find_spherical_transect_cells_and_weights( + lon_transect, lat_transect, ds_tris, ds_mesh, tree, degrees=True, + earth_radius=None, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + lon_transect : xarray.DataArray + The longitude of segments making up the transect + + lat_transect : xarray.DataArray + The latitude of segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from ``make_triangle_tree()`` + + degrees : bool, optional + Whether ``lon_transect`` and ``lat_transect`` are in degrees (as + opposed to radians). + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + earth_radius : float, optional + The radius of the Earth in meters, taken from the `sphere_radius` + global attribute if not provided + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also includes the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The Cartesian and lon/lat + coordinates of these nodes are ``xCartNode``, ``yCartNode``, + ``zCartNode``, ``lonNode`` and ``latNode``. The distance along the + transect of each intersection is ``dNode``. The index of the triangle + and the first triangle node in ``ds_tris`` associated with each + intersection node are given by ``horizTriangleIndices`` and + ``horizTriangleNodeIndices``, respectively. The second node on the + triangle for the edge associated with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell that a given node belongs to is given by + ``horizCellIndices``. Each node also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``lonTransect`` and ``latTransect`` are included in the + dataset, along with Cartesian coordinates ``xCartTransect``, + ``yCartTransect``, `zCartTransect``, and ``dTransect``, the + great-circle distance along the transect of each original transect + point. In order to interpolate values (e.g. observations) from the + original transect points to the intersection nodes, linear + interpolation indices ``transectIndicesOnHorizNode`` and weights + ``transectWeightsOnHorizNode`` are provided. The values at nodes are + found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + if earth_radius is None: + earth_radius = ds_mesh.attrs['sphere_radius'] + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + x, y, z = lon_lat_to_cartesian(lon_transect, lat_transect, earth_radius, + degrees) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + z_node = ds_tris.zNode.values.ravel() + + d_transect = np.zeros(lon_transect.shape) + + d_node = None + x_out = None + y_out = None + z_out = None + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + transectv0 = Vector(x[seg_index].values, + y[seg_index].values, + z[seg_index].values) + transectv1 = Vector(x[seg_index + 1].values, + y[seg_index + 1].values, + z[seg_index + 1].values) + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, z_sub, _, _ = subdivide_great_circle( + x[sub_slice].values, y[sub_slice].values, z[sub_slice].values, + subdivision_res, earth_radius) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + coords[:, 2] = z_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + n0_indices_cand = np.array(list(unique_indices)) + + if len(n0_indices_cand) == 0: + continue + + tris_cand = n0_indices_cand // n_nodes + next_node_index = np.mod(n0_indices_cand + 1, n_nodes) + n1_indices_cand = n_nodes * tris_cand + next_node_index + + n0_cand = Vector(x_node[n0_indices_cand], + y_node[n0_indices_cand], + z_node[n0_indices_cand]) + n1_cand = Vector(x_node[n1_indices_cand], + y_node[n1_indices_cand], + z_node[n1_indices_cand]) + + intersect = Vector.intersects(n0_cand, n1_cand, transectv0, + transectv1) + + n0_inter = Vector(n0_cand.x[intersect], + n0_cand.y[intersect], + n0_cand.z[intersect]) + n1_inter = Vector(n1_cand.x[intersect], + n1_cand.y[intersect], + n1_cand.z[intersect]) + + tris_inter = tris_cand[intersect] + n0_indices_inter = n0_indices_cand[intersect] + n1_indices_inter = n1_indices_cand[intersect] + + intersections = Vector.intersection(n0_inter, n1_inter, transectv0, + transectv1) + intersections = Vector(earth_radius * intersections.x, + earth_radius * intersections.y, + earth_radius * intersections.z) + + angular_distance = transectv0.angular_distance(intersections) + + d_node_local = d_start + earth_radius * angular_distance + + d_start += earth_radius * transectv0.angular_distance(transectv1) + + node0_inter = np.mod(n0_indices_inter, n_nodes) + node1_inter = np.mod(n1_indices_inter, n_nodes) + + node_weights = (intersections.angular_distance(n1_inter) / + n0_inter.angular_distance(n1_inter)) + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = intersections.x + y_out = intersections.y + z_out = intersections.z + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, intersections.x) + y_out = np.append(y_out, intersections.y) + z_out = np.append(z_out, intersections.z) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + lon_out, lat_out = cartesian_to_lon_lat(x_out, y_out, z_out, earth_radius, + degrees) + + valid_segs = seg_tris >= 0 + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + + ds_out = xr.Dataset() + ds_out['xCartNode'] = (('nNodes',), x_out) + ds_out['yCartNode'] = (('nNodes',), y_out) + ds_out['zCartNode'] = (('nNodes',), z_out) + ds_out['dNode'] = (('nNodes',), d_node) + ds_out['lonNode'] = (('nNodes',), lon_out) + ds_out['latNode'] = (('nNodes',), lat_out) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, dtype=int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['lonTransect'] = lon_transect + ds_out['latTransect'] = lat_transect + ds_out['xCartTransect'] = x + ds_out['yCartTransect'] = y + ds_out['zCartTransect'] = z + ds_out['dTransect'] = (lon_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def find_planar_transect_cells_and_weights( + x_transect, y_transect, ds_tris, ds_mesh, tree, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + x_transect : xarray.DataArray + The x points defining segments making up the transect + + y_transect : xarray.DataArray + The y points defining segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + `:py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from ``make_triangle_tree()`` + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also include the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The planar coordinates of these + nodes are ``xNode`` and ``yNode``. The distance along the transect of + each intersection is ``dNode``. The index of the triangle and the first + triangle node in ``ds_tris`` associated with each intersection node are + given by ``horizTriangleIndices`` and ``horizTriangleNodeIndices``, + respectively. The second node on the triangle for the edge associated + with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell that a given node belongs to is given by + ``horizCellIndices``. Each node also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``xTransect`` and ``yTransect`` are included in the + dataset, along with ``dTransect``, the distance along the transect of + each original transect point. In order to interpolate values (e.g. + observations) from the original transect points to the intersection + nodes, linear interpolation indices ``transectIndicesOnHorizNode`` and + weights ``transectWeightsOnHorizNode`` are provided. The values at + nodes are found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x = x_transect + y = y_transect + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + + coordNode = np.zeros((len(x_node), 2)) + coordNode[:, 0] = x_node + coordNode[:, 1] = y_node + + d_transect = np.zeros(x_transect.shape) + + d_node = None + x_out = np.array([]) + y_out = np.array([]) + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, _, _ = subdivide_planar( + x[sub_slice].values, y[sub_slice].values, subdivision_res) + + start_point = Point(x_transect[seg_index].values, + y_transect[seg_index].values) + end_point = Point(x_transect[seg_index + 1].values, + y_transect[seg_index + 1].values) + + segment = LineString([start_point, end_point]) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + start_indices = np.array(list(unique_indices)) + + if len(start_indices) == 0: + continue + + tris_cand = start_indices // n_nodes + next_node_index = np.mod(start_indices + 1, n_nodes) + end_indices = n_nodes * tris_cand + next_node_index + + intersecting_nodes = list() + tris_inter_list = list() + x_intersection_list = list() + y_intersection_list = list() + node_weights_list = list() + node0_inter_list = list() + node1_inter_list = list() + distances_list = list() + + for index in range(len(start_indices)): + start = start_indices[index] + end = end_indices[index] + + node0 = Point(coordNode[start, 0], coordNode[start, 1]) + node1 = Point(coordNode[end, 0], coordNode[end, 1]) + + edge = LineString([node0, node1]) + if segment.intersects(edge): + point = segment.intersection(edge) + intersecting_nodes.append((node0, node1, start, end, edge)) + + if isinstance(point, LineString): + raise ValueError('A triangle edge exactly coincides with ' + 'a transect segment and I can\'t handle ' + 'that case. Try moving the transect a ' + 'tiny bit.') + elif not isinstance(point, Point): + raise ValueError(f'Unexpected intersection type {point}') + + x_intersection_list.append(point.x) + y_intersection_list.append(point.y) + + start_to_intersection = LineString([start_point, point]) + + weight = (LineString([point, node1]).length / + LineString([node0, node1]).length) + + node_weights_list.append(weight) + node0_inter_list.append(np.mod(start, n_nodes)) + node1_inter_list.append(np.mod(end, n_nodes)) + distances_list.append(start_to_intersection.length) + tris_inter_list.append(tris_cand[index]) + + distances = np.array(distances_list) + x_intersection = np.array(x_intersection_list) + y_intersection = np.array(y_intersection_list) + node_weights = np.array(node_weights_list) + node0_inter = np.array(node0_inter_list, dtype=int) + node1_inter = np.array(node1_inter_list, dtype=int) + tris_inter = np.array(tris_inter_list, dtype=int) + + d_node_local = d_start + distances + + d_start += segment.length + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = x_intersection + y_out = y_intersection + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, x_intersection) + y_out = np.append(y_out, y_intersection) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + z_out = np.zeros(x_out.shape) + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + valid_segs = seg_tris >= 0 + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + + ds_out = xr.Dataset() + ds_out['xNode'] = (('nNodes',), x_out) + ds_out['yNode'] = (('nNodes',), y_out) + ds_out['dNode'] = (('nNodes',), d_node) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['xTransect'] = x + ds_out['yTransect'] = y + ds_out['dTransect'] = (x_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def interp_mpas_horiz_to_transect_nodes(ds_transect, da): + """ + Interpolate a 2D (``nCells``) MPAS DataArray to transect nodes, linearly + interpolating fields between the closest neighboring cells + + Parameters + ---------- + ds_transect : xr.Dataset + A dataset that defines an MPAS transect, the results of calling + ``find_spherical_transect_cells_and_weights()`` or + ``find_planar_transect_cells_and_weights()`` + + da : xr.DataArray + An MPAS 2D field with dimensions `nCells`` (possibly among others) + + Returns + ------- + da_nodes : xr.DataArray + The data array interpolated to transect nodes with dimensions + ``nNodes`` (in addition to whatever dimensions were in ``da`` besides + ``nCells``) + """ + interp_cell_indices = ds_transect.interpHorizCellIndices + interp_cell_weights = ds_transect.interpHorizCellWeights + da = da.isel(nCells=interp_cell_indices) + da_nodes = (da * interp_cell_weights).sum(dim='nHorizWeights') + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _sort_intersections(d_node, tris, nodes, x_out, y_out, z_out, interp_cells, + cell_weights, epsilon): + """ sort nodes by distance and define segment between them """ + + sort_indices = np.argsort(d_node) + d_sorted = d_node[sort_indices] + + # make a list of indices for each unique value of d + d = d_sorted[0] + unique_d_indices = [sort_indices[0]] + unique_d_all_indices = [[sort_indices[0]]] + for index, next_d, in zip(sort_indices[1:], d_sorted[1:]): + if next_d - d < epsilon: + # this d value is effectively the same as the last, so we'll treat + # it as the same + unique_d_all_indices[-1].append(index) + else: + # this is a new d, so we'll add to a new list + d = next_d + unique_d_indices.append(index) + unique_d_all_indices.append([index]) + + # there is a segment between each unique d, though some are invalid (do + # not correspond to a triangle) + seg_tris_list = list() + seg_nodes_list = list() + + index0 = unique_d_indices[0] + indices0 = unique_d_all_indices[0] + d0 = d_node[index0] + + indices = [index0] + ds = [d0] + for seg_index in range(len(unique_d_all_indices) - 1): + indices1 = unique_d_all_indices[seg_index + 1] + index1 = unique_d_indices[seg_index + 1] + d1 = d_node[index1] + + # are there any triangles in common between this d value and the next? + tris0 = tris[indices0] + tris1 = tris[indices1] + both = set(tris0).intersection(set(tris1)) + + if len(both) > 0: + tri = both.pop() + seg_tris_list.append(tri) + indices.append(index1) + ds.append(d1) + + # the triangle nodes are the 2 corresponding to the same triangle + # in the original list + index0 = indices0[np.where(tris0 == tri)[0][0]] + index1 = indices1[np.where(tris1 == tri)[0][0]] + seg_nodes_list.append([nodes[index0], nodes[index1]]) + else: + # this is an invalid segment so we need to insert and extra invalid + # node to allow for proper masking + seg_tris_list.extend([-1, -1]) + seg_nodes_list.extend([[-1, -1], [-1, -1]]) + indices.extend([index0, index1]) + ds.extend([0.5 * (d0 + d1), d1]) + + index0 = index1 + indices0 = indices1 + d0 = d1 + + indices = np.array(indices, dtype=int) + d_node = np.array(ds, dtype=float) + seg_tris = np.array(seg_tris_list, dtype=int) + seg_nodes = np.array(seg_nodes_list, dtype=int) + + valid_nodes = np.ones(len(indices), dtype=bool) + valid_nodes[1:-1] = np.logical_or(seg_tris[0:-1] >= 0, + seg_tris[1:] > 0) + + x_out = x_out[indices] + y_out = y_out[indices] + z_out = z_out[indices] + + interp_cells = interp_cells[indices, :] + cell_weights = cell_weights[indices, :] + + return (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py new file mode 100644 index 000000000..dd898b2b0 --- /dev/null +++ b/polaris/ocean/viz/transect/vert.py @@ -0,0 +1,477 @@ +import numpy as np +import xarray as xr +from mpas_tools.viz import mesh_to_triangles + +from polaris.ocean.viz.transect.horiz import ( + find_planar_transect_cells_and_weights, + find_spherical_transect_cells_and_weights, + interp_mpas_horiz_to_transect_nodes, + make_triangle_tree, +) + + +def compute_transect(x, y, ds_3d_mesh, spherical=False): + """ + build a sequence of quads showing the transect intersecting mpas cells + + Parameters + ---------- + x : xarray.DataArray + The x or longitude coordinate of the transect + + y : xarray.DataArray + The y or latitude coordinate of the transect + + ds_3d_mesh : xarray.Dataset + The MPAS-Ocean mesh to use for plotting + + spherical : bool, optional + Whether the x and y coordinates are latitude and longitude in degrees + + Returns + ------- + ds_transect : xarray.Dataset + The transect dataset, see + :py:func:`polaris.ocean.viz.transect.vert.find_transect_levels_and_weights()` + for details + """ # noqa: E501 + + ds_tris = mesh_to_triangles(ds_3d_mesh) + + triangle_tree = make_triangle_tree(ds_tris) + + if spherical: + ds_horiz_transect = find_spherical_transect_cells_and_weights( + x, y, ds_tris, ds_3d_mesh, triangle_tree, degrees=True) + else: + ds_horiz_transect = find_planar_transect_cells_and_weights( + x, y, ds_tris, ds_3d_mesh, triangle_tree) + + # mask horizontal transect to valid cells (maxLevelCell > 0) + cell_indices = ds_horiz_transect.horizCellIndices + seg_mask = ds_3d_mesh.maxLevelCell.isel(nCells=cell_indices).values > 0 + node_mask = np.zeros(ds_horiz_transect.sizes['nNodes'], dtype=bool) + node_mask[0:-1] = seg_mask + node_mask[1:] = np.logical_or(node_mask[1:], seg_mask) + + ds_horiz_transect = ds_horiz_transect.isel(nSegments=seg_mask, + nNodes=node_mask) + + ds_transect = find_transect_levels_and_weights( + ds_horiz_transect=ds_horiz_transect, + layer_thickness=ds_3d_mesh.layerThickness, + bottom_depth=ds_3d_mesh.bottomDepth, + min_level_cell=ds_3d_mesh.minLevelCell - 1, + max_level_cell=ds_3d_mesh.maxLevelCell - 1) + + # interpolate the land-ice fraction so we can plot an overlying ice shelf + if 'landIceFraction' in ds_3d_mesh: + ds_transect['landIceFraction'] = interp_mpas_horiz_to_transect_nodes( + ds_transect, ds_3d_mesh.landIceFraction) + + ds_transect.compute() + + return ds_transect + + +def find_transect_levels_and_weights(ds_horiz_transect, layer_thickness, + bottom_depth, min_level_cell, + max_level_cell): + """ + Construct a vertical coordinate for a transect produced by + :py:func:`polaris.ocean.viz.transect.horiz.find_spherical_transect_cells_and_weights()` + or :py:func:`polaris.ocean.viz.transect.horiz.find_planar_transect_cells_and_weights()`. + Also, compute interpolation weights such that observations at points on the + original transect and with vertical coordinate ``transectZ`` can be + bilinearly interpolated to the nodes of the transect. + + Parameters + ---------- + ds_horiz_transect : xarray.Dataset + A dataset that defines nodes of the transect + + layer_thickness : xarray.DataArray + layer thicknesses on the MPAS mesh + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + Returns + ------- + ds_transect : xarray.Dataset + A dataset that contains nodes and cells that make up a 2D transect. + + There are ``nSegments`` horizontal and ``nHalfLevels`` vertical + transect cells (quadrilaterals), bounded by ``nHorizNodes`` horizontal + and ``nVertNodes`` vertical nodes (corners). + + In addition to the variables and coordinates in the input + ``ds_transect``, the output dataset contains: + + - ``validCells``, ``validNodes``: which transect cells and nodes + are valid (above the bathymetry and below the sea surface) + + - zTransectNode: the vertical height of each triangle node + - ssh, zSeaFloor: the sea-surface height and sea-floor height at + each node of each transect segment + + - ``cellIndices``: the MPAS-Ocean cell of a given transect segment + - ``levelIndices``: the MPAS-Ocean vertical level of a given + transect level + + - ``interpCellIndices``, ``interpLevelIndices``: the MPAS-Ocean + cells and levels from which the value at a given transect cell is + interpolated. This can involve up to + ``nHorizWeights * nVertWeights = 12`` different cells and levels. + - interpCellWeights: the weight to multiply each field value by + to perform interpolation to a transect cell. + + - ``dInterfaceSegment``, ``zInterfaceSegment`` - segments that can + be used to plot the interfaces between MPAS-Ocean layers + + - ``dCellBoundary``, ``zCellBoundary`` - segments that can + be used to plot the vertical boundaries between MPAS-Ocean cells + + Interpolation of a DataArray from MPAS cells and levels to transect + cells can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_cells()`. + Similarly, interpolation to transect nodes can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_nodes()`. + """ # noqa: E501 + if 'Time' in layer_thickness.dims: + raise ValueError('Please select a single time level in layer ' + 'thickness.') + + ds_transect_cells = ds_horiz_transect.rename({'nNodes': 'nHorizNodes'}) + + (z_half_interface, ssh, z_seafloor, interp_cell_indices, + interp_cell_weights, valid_transect_cells, + level_indices) = _get_vertical_coordinate( + ds_transect_cells, layer_thickness, bottom_depth, min_level_cell, + max_level_cell) + + ds_transect_cells['zTransectNode'] = z_half_interface + + ds_transect_cells['ssh'] = ssh + ds_transect_cells['zSeafloor'] = z_seafloor + + ds_transect_cells['cellIndices'] = ds_transect_cells.horizCellIndices + ds_transect_cells['levelIndices'] = level_indices + ds_transect_cells['validCells'] = valid_transect_cells + + d_interface_seg, z_interface_seg = _get_interface_segments( + z_half_interface, ds_transect_cells.dNode, valid_transect_cells) + + ds_transect_cells['dInterfaceSegment'] = d_interface_seg + ds_transect_cells['zInterfaceSegment'] = z_interface_seg + + d_cell_boundary, z_cell_boundary = _get_cell_boundary_segments( + ssh, z_seafloor, ds_transect_cells.dNode, + ds_transect_cells.horizCellIndices) + + ds_transect_cells['dCellBoundary'] = d_cell_boundary + ds_transect_cells['zCellBoundary'] = z_cell_boundary + + interp_level_indices, interp_cell_weights, valid_nodes = \ + _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells) + + ds_transect_cells['interpCellIndices'] = interp_cell_indices + ds_transect_cells['interpLevelIndices'] = interp_level_indices + ds_transect_cells['interpCellWeights'] = interp_cell_weights + ds_transect_cells['validNodes'] = valid_nodes + + dims = ['nSegments', 'nHalfLevels', 'nHorizNodes', 'nVertNodes', + 'nInterfaceSegments', 'nCellBoundaries', 'nHorizBounds', + 'nVertBounds', 'nHorizWeights', 'nVertWeights'] + for dim in ds_transect_cells.dims: + if dim not in dims: + dims.insert(0, dim) + ds_transect_cells = ds_transect_cells.transpose(*dims) + + return ds_transect_cells + + +def interp_mpas_to_transect_cells(ds_transect, da): + """ + Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray + to transect cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_cells : xarray.DataArray + The data array interpolated to transect cells with dimensions + ``nSegments`` and ``nHalfLevels`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + cell_indices = ds_transect.cellIndices + level_indices = ds_transect.levelIndices + + da_cells = da.isel(nCells=cell_indices, nVertLevels=level_indices) + da_cells = da_cells.where(ds_transect.validCells) + + return da_cells + + +def interp_mpas_to_transect_nodes(ds_transect, da): + """ + Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray + to transect nodes, linearly interpolating fields between the closest + neighboring cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_nodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nHorizNodes`` and ``nVertNodes`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + interp_cell_indices = ds_transect.interpCellIndices + interp_level_indices = ds_transect.interpLevelIndices + interp_cell_weights = ds_transect.interpCellWeights + + da = da.isel(nCells=interp_cell_indices, nVertLevels=interp_level_indices) + + da_nodes = (da * interp_cell_weights).sum( + dim=('nHorizWeights', 'nVertWeights')) + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _get_vertical_coordinate(ds_transect, layer_thickness, bottom_depth, + min_level_cell, max_level_cell): + n_horiz_nodes = ds_transect.sizes['nHorizNodes'] + n_segments = ds_transect.sizes['nSegments'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + + # we assume below that there is a segment (whether valid or invalid) + # connecting each pair of adjacent nodes + assert n_horiz_nodes == n_segments + 1 + + interp_horiz_cell_indices = ds_transect.interpHorizCellIndices + interp_horiz_cell_weights = ds_transect.interpHorizCellWeights + + bottom_depth_interp = bottom_depth.isel(nCells=interp_horiz_cell_indices) + layer_thickness_interp = layer_thickness.isel( + nCells=interp_horiz_cell_indices) + + cell_mask_interp = _get_cell_mask(interp_horiz_cell_indices, + min_level_cell, max_level_cell, + n_vert_levels) + layer_thickness_interp = layer_thickness_interp.where(cell_mask_interp, 0.) + + ssh_interp = (-bottom_depth_interp + + layer_thickness_interp.sum(dim='nVertLevels')) + + interp_mask = np.logical_and(interp_horiz_cell_indices > 0, + cell_mask_interp) + + interp_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_cell_weights.sum(dim='nHorizWeights') + + cell_indices = ds_transect.horizCellIndices + + valid_cells = _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels) + + valid_cells = valid_cells.transpose('nSegments', 'nVertLevels').values + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_levels), dtype=bool) + valid_nodes[0:-1, :] = valid_cells + valid_nodes[1:, :] = np.logical_or(valid_nodes[1:, :], valid_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertLevels'), + data=valid_nodes) + + valid_weights = valid_nodes.broadcast_like(interp_cell_weights) + interp_cell_weights = \ + (interp_cell_weights / weight_sum).where(valid_weights) + + layer_thickness_transect = (layer_thickness_interp * + interp_cell_weights).sum(dim='nHorizWeights') + + interp_mask = max_level_cell.isel(nCells=interp_horiz_cell_indices) >= 0 + interp_horiz_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_horiz_cell_weights.sum(dim='nHorizWeights') + interp_horiz_cell_weights = \ + (interp_horiz_cell_weights / weight_sum).where(interp_mask) + + ssh_transect = (ssh_interp * + interp_horiz_cell_weights).sum(dim='nHorizWeights') + + z_bot = ssh_transect - layer_thickness_transect.cumsum(dim='nVertLevels') + z_mid = z_bot + 0.5 * layer_thickness_transect + + z_half_interfaces = [ssh_transect] + for z_index in range(n_vert_levels): + z_half_interfaces.extend([z_mid.isel(nVertLevels=z_index), + z_bot.isel(nVertLevels=z_index)]) + + z_half_interface = xr.concat(z_half_interfaces, dim='nVertNodes') + z_half_interface = z_half_interface.transpose('nHorizNodes', 'nVertNodes') + + z_seafloor = ssh_transect - layer_thickness_transect.sum( + dim='nVertLevels') + + valid_transect_cells = np.zeros((n_segments, 2 * n_vert_levels), + dtype=bool) + valid_transect_cells[:, 0::2] = valid_cells + valid_transect_cells[:, 1::2] = valid_cells + valid_transect_cells = xr.DataArray(dims=('nSegments', 'nHalfLevels'), + data=valid_transect_cells) + + level_indices = np.zeros(2 * n_vert_levels, dtype=int) + level_indices[0::2] = np.arange(n_vert_levels) + level_indices[1::2] = np.arange(n_vert_levels) + level_indices = xr.DataArray(dims=('nHalfLevels',), data=level_indices) + + return (z_half_interface, ssh_transect, z_seafloor, + interp_horiz_cell_indices, interp_cell_weights, + valid_transect_cells, level_indices) + + +def _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels): + level_indices = xr.DataArray(data=np.arange(n_vert_levels), + dims='nVertLevels') + min_level_cell = min_level_cell.isel(nCells=cell_indices) + max_level_cell = max_level_cell.isel(nCells=cell_indices) + + cell_mask = np.logical_and( + level_indices >= min_level_cell, + level_indices <= max_level_cell) + + cell_mask = np.logical_and(cell_mask, cell_indices >= 0) + + return cell_mask + + +def _get_interface_segments(z_half_interface, d_node, valid_transect_cells): + + d = d_node.broadcast_like(z_half_interface) + z_interface = z_half_interface.values[:, 0::2] + d = d.values[:, 0::2] + + n_segments = valid_transect_cells.sizes['nSegments'] + n_half_levels = valid_transect_cells.sizes['nHalfLevels'] + n_vert_levels = n_half_levels // 2 + + valid_segs = np.zeros((n_segments, n_vert_levels + 1), dtype=bool) + valid_segs[:, 0:-1] = valid_transect_cells.values[:, 1::2] + valid_segs[:, 1:] = np.logical_or(valid_segs[:, 1:], + valid_transect_cells.values[:, 0::2]) + + n_interface_segs = np.count_nonzero(valid_segs) + + d_seg = np.zeros((n_interface_segs, 2)) + z_seg = np.zeros((n_interface_segs, 2)) + d_seg[:, 0] = d[0:-1, :][valid_segs] + d_seg[:, 1] = d[1:, :][valid_segs] + z_seg[:, 0] = z_interface[0:-1, :][valid_segs] + z_seg[:, 1] = z_interface[1:, :][valid_segs] + + d_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=d_seg) + + z_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=z_seg) + + return d_seg, z_seg + + +def _get_cell_boundary_segments(ssh, z_seafloor, d_node, cell_indices): + + n_horiz_nodes = d_node.sizes['nHorizNodes'] + + cell_boundary = np.ones(n_horiz_nodes, dtype=bool) + cell_boundary[1:-1] = cell_indices.values[0:-1] != cell_indices.values[1:] + + n_cell_boundaries = np.count_nonzero(cell_boundary) + + d_seg = np.zeros((n_cell_boundaries, 2)) + z_seg = np.zeros((n_cell_boundaries, 2)) + d_seg[:, 0] = d_node.values[cell_boundary] + d_seg[:, 1] = d_seg[:, 0] + z_seg[:, 0] = ssh[cell_boundary] + z_seg[:, 1] = z_seafloor[cell_boundary] + + d_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=d_seg) + + z_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=z_seg) + + return d_seg, z_seg + + +def _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells): + n_horiz_nodes = interp_cell_indices.sizes['nHorizNodes'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + n_vert_nodes = 2 * n_vert_levels + 1 + n_vert_weights = 2 + + interp_level_indices = -1 * np.ones((n_vert_nodes, n_vert_weights), + dtype=int) + interp_level_indices[1:, 0] = level_indices.values + interp_level_indices[0:-1, 1] = level_indices.values + + interp_level_indices = xr.DataArray(dims=('nVertNodes', 'nVertWeights'), + data=interp_level_indices) + + half_level_thickness = 0.5 * layer_thickness.isel( + nCells=interp_cell_indices, nVertLevels=interp_level_indices) + half_level_thickness = half_level_thickness.where( + interp_level_indices >= 0, other=0.) + + # vertical weights are proportional to the half-level thickness + interp_cell_weights = half_level_thickness * interp_cell_weights.isel( + nVertLevels=interp_level_indices) + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_nodes), dtype=bool) + valid_nodes[0:-1, 0:-1] = valid_transect_cells + valid_nodes[1:, 0:-1] = np.logical_or(valid_nodes[1:, 0:-1], + valid_transect_cells) + valid_nodes[0:-1, 1:] = np.logical_or(valid_nodes[0:-1, 1:], + valid_transect_cells) + valid_nodes[1:, 1:] = np.logical_or(valid_nodes[1:, 1:], + valid_transect_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertNodes'), + data=valid_nodes) + + weight_sum = interp_cell_weights.sum(dim=('nHorizWeights', 'nVertWeights')) + out_mask = (weight_sum > 0.).broadcast_like(interp_cell_weights) + interp_cell_weights = (interp_cell_weights / weight_sum).where(out_mask) + + return interp_level_indices, interp_cell_weights, valid_nodes From 4ca3362c4fbab07633247e41fa62a3b23cbeb8c8 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 6 Nov 2023 14:48:50 +0100 Subject: [PATCH 04/14] Switch transect viz to use polaris transects on quads --- polaris/ocean/viz/__init__.py | 3 +- polaris/ocean/viz/transect.py | 159 ------------------------ polaris/ocean/viz/transect/plot.py | 188 +++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 160 deletions(-) delete mode 100644 polaris/ocean/viz/transect.py create mode 100644 polaris/ocean/viz/transect/plot.py diff --git a/polaris/ocean/viz/__init__.py b/polaris/ocean/viz/__init__.py index 365aaec1c..0efd37287 100644 --- a/polaris/ocean/viz/__init__.py +++ b/polaris/ocean/viz/__init__.py @@ -1 +1,2 @@ -from polaris.ocean.viz.transect import compute_transect, plot_transect +from polaris.ocean.viz.transect.plot import plot_transect +from polaris.ocean.viz.transect.vert import compute_transect diff --git a/polaris/ocean/viz/transect.py b/polaris/ocean/viz/transect.py deleted file mode 100644 index 07780b67e..000000000 --- a/polaris/ocean/viz/transect.py +++ /dev/null @@ -1,159 +0,0 @@ -import cmocean # noqa: blah -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.tri import Triangulation -from mpas_tools.ocean.transects import ( - find_transect_levels_and_weights, - get_outline_segments, - interp_mpas_to_transect_triangles, -) -from mpas_tools.viz import mesh_to_triangles -from mpas_tools.viz.transects import ( - find_planar_transect_cells_and_weights, - find_transect_cells_and_weights, - make_triangle_tree, -) - - -def compute_transect(x, y, ds_3d_mesh, spherical=False): - """ - build a sequence of triangles showing the transect intersecting mpas cells - - Parameters - ---------- - x : xarray.DataArray - The x or longitude coordinate of the transect - - y : xarray.DataArray - The y or latitude coordinate of the transect - - ds_3d_mesh : xarray.Dataset - The MPAS-Ocean mesh to use for plotting - - spherical : bool, optional - Whether the x and y coordinates are latitude and longitude in degrees - - Returns - ------- - ds_transect : xarray.Dataset - The transect dataset - """ - - ds_tris = mesh_to_triangles(ds_3d_mesh) - - triangle_tree = make_triangle_tree(ds_tris) - - if spherical: - ds_transect = find_transect_cells_and_weights( - x, y, ds_tris, ds_3d_mesh, triangle_tree, degrees=True) - else: - ds_transect = find_planar_transect_cells_and_weights( - x, y, ds_tris, ds_3d_mesh, triangle_tree) - - cell_indices = ds_transect.horizCellIndices - mask = ds_3d_mesh.maxLevelCell.isel(nCells=cell_indices) > 0 - ds_transect = ds_transect.isel(nSegments=mask) - - ds_transect = find_transect_levels_and_weights( - ds_transect, ds_3d_mesh.layerThickness, - ds_3d_mesh.bottomDepth, ds_3d_mesh.maxLevelCell - 1) - - if 'landIceFraction' in ds_3d_mesh: - interp_cell_indices = ds_transect.interpHorizCellIndices - interp_cell_weights = ds_transect.interpHorizCellWeights - land_ice_fraction = ds_3d_mesh.landIceFraction.isel( - nCells=interp_cell_indices) - land_ice_fraction = (land_ice_fraction * interp_cell_weights).sum( - dim='nHorizWeights') - ds_transect['landIceFraction'] = land_ice_fraction - - ds_transect['x'] = ds_transect.dNode.isel( - nSegments=ds_transect.segmentIndices, - nHorizBounds=ds_transect.nodeHorizBoundsIndices) - - ds_transect['z'] = ds_transect.zTransectNode - - ds_transect.compute() - - return ds_transect - - -def plot_transect(ds_transect, mpas_field, out_filename, title, - colorbar_label=None, cmap=None, figsize=(12, 6), dpi=200): - """ - plot a transect showing the field on the MPAS-Ocean mesh and save to a file - - Parameters - ---------- - ds_transect : xarray.Dataset - A transect dataset from - :py:func:`polaris.ocean.viz.compute_transect()` - - mpas_field : xarray.DataArray - The MPAS-Ocean 3D field to plot - - out_filename : str - The png file to write out to - - title : str - The title of the plot - - colorbar_label : str, optional - The colorbar label - - cmap : str, optional - The name of a colormap to use - - figsize : tuple, optional - The size of the figure in inches - - dpi : int, optional - The dots per inch of the image - - """ - transect_field = interp_mpas_to_transect_triangles(ds_transect, - mpas_field) - - x_outline, z_outline = get_outline_segments(ds_transect) - x_outline = 1e-3 * x_outline - - tri_mask = np.logical_not(transect_field.notnull().values) - - triangulation_args = _get_ds_triangulation_args(ds_transect) - - triangulation_args['mask'] = tri_mask - - tris = Triangulation(**triangulation_args) - plt.figure(figsize=figsize) - plt.tripcolor(tris, facecolors=transect_field.values, shading='flat', - cmap=cmap) - plt.plot(x_outline, z_outline, 'k') - if colorbar_label is not None: - plt.colorbar(label=colorbar_label) - plt.title(title) - plt.xlabel('x (km)') - plt.ylabel('z (m)') - - plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) - plt.close() - - -def _get_ds_triangulation_args(ds_transect): - """ - get arguments for matplotlib Triangulation from triangulation dataset - """ - - n_transect_triangles = ds_transect.sizes['nTransectTriangles'] - d_node = ds_transect.dNode.isel( - nSegments=ds_transect.segmentIndices, - nHorizBounds=ds_transect.nodeHorizBoundsIndices) - x = 1e-3 * d_node.values.ravel() - - z_transect_node = ds_transect.zTransectNode - y = z_transect_node.values.ravel() - - tris = np.arange(3 * n_transect_triangles).reshape( - (n_transect_triangles, 3)) - triangulation_args = dict(x=x, y=y, triangles=tris) - - return triangulation_args diff --git a/polaris/ocean/viz/transect/plot.py b/polaris/ocean/viz/transect/plot.py new file mode 100644 index 000000000..73ec7d1cd --- /dev/null +++ b/polaris/ocean/viz/transect/plot.py @@ -0,0 +1,188 @@ +import cmocean # noqa: F401 +import matplotlib.pyplot as plt +import numpy as np + +from polaris.ocean.viz.transect.vert import ( + interp_mpas_to_transect_cells, + interp_mpas_to_transect_nodes, +) +from polaris.viz.style import use_mplstyle + + +def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, + title=None, vmin=None, vmax=None, colorbar_label=None, + cmap=None, figsize=(12, 6), dpi=200, method='flat', + outline_color='black', ssh_color=None, seafloor_color=None, + interface_color=None, cell_boundary_color=None, + linewidth=1.0): + """ + plot a transect showing the field on the MPAS-Ocean mesh and save to a file + + Parameters + ---------- + ds_transect : xarray.Dataset + A transect dataset from + :py:func:`polaris.ocean.viz.compute_transect()` + + mpas_field : xarray.DataArray + The MPAS-Ocean 3D field to plot + + out_filename : str, optional + The png file to write out to + + ax : matplotlib.axes.Axes + Axes to plot to if making a multi-panel figure + + title : str + The title of the plot + + vmin : float, optional + The minimum values for the colorbar + + vmax : float, optional + The maximum values for the colorbar + + colorbar_label : str, optional + The colorbar label, or ``None`` if no colorbar is to be included. + Use an empty string to display a colorbar without a label. + + cmap : str, optional + The name of a colormap to use + + figsize : tuple, optional + The size of the figure in inches + + dpi : int, optional + The dots per inch of the image + + method : {'flat', 'bilinear'}, optional + The type of interpolation to use in plots. ``flat`` means constant + values over each MPAS cell. ``bilinear`` means smooth interpolation + between horizontally between cell centers and vertical between the + middle of layers. + + outline_color : str or None, optional + The color to use to outline the transect or ``None`` for no outline + + ssh_color : str or None, optional + The color to use to plot the SSH (sea surface height) or ``None`` if + not plotting the SSH (except perhaps as part of the outline) + + seafloor_color : str or None, optional + The color to use to plot the seafloor depth or ``None`` if not plotting + the seafloor depth (except perhaps as part of the outline) + + interface_color : str or None, optional + The color to use to plot interfaces between layers or ``None`` if + not plotting the layer interfaces + + cell_boundary_color : str or None, optional + The color to use to plot vertical boundaries between cells or ``None`` + if not plotting cell boundaries. Typically, ``cell_boundary_color`` + will be used along with ``interface_color`` to outline cells both + horizontally and vertically. + + linewidth : float, optional + The width of outlines, interfaces and cell boundaries + """ + + if ax is None and out_filename is None: + raise ValueError('One of ax or out_filename must be supplied') + + use_mplstyle() + + create_fig = ax is None + if create_fig: + plt.figure(figsize=figsize) + ax = plt.subplot(111) + + z = ds_transect.zTransectNode + x = 1e-3 * ds_transect.dNode.broadcast_like(z) + + if mpas_field is not None: + if method == 'flat': + transect_field = interp_mpas_to_transect_cells(ds_transect, + mpas_field) + shading = 'flat' + elif method == 'bilinear': + transect_field = interp_mpas_to_transect_nodes(ds_transect, + mpas_field) + shading = 'gouraud' + else: + raise ValueError(f'Unsupported method: {method}') + + pc = ax.pcolormesh(x.values, z.values, transect_field.values, + shading=shading, cmap=cmap, vmin=vmin, vmax=vmax, + zorder=0) + ax.autoscale(tight=True) + if colorbar_label is not None: + plt.colorbar(pc, extend='both', shrink=0.7, ax=ax, + label=colorbar_label) + + _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, linewidth) + + _plot_outline(x, z, ds_transect.validNodes, ax, outline_color, + linewidth) + + ax.set_xlabel('transect distance (km)') + ax.set_ylabel('z (m)') + + if create_fig: + if title is not None: + plt.title(title) + plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) + plt.close() + + +def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, linewidth): + if cell_boundary_color is not None: + x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T + z_bnd = ds_transect.zCellBoundary.values.T + ax.plot(x_bnd, z_bnd, color=cell_boundary_color, linewidth=linewidth, + zorder=1) + + if interface_color is not None: + x_int = 1e-3 * ds_transect.dInterfaceSegment.values.T + z_int = ds_transect.zInterfaceSegment.values.T + ax.plot(x_int, z_int, color=interface_color, linewidth=linewidth, + zorder=2) + + if ssh_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_ssh = 1e-3 * ds_transect.dNode.values + z_ssh = ds_transect.ssh.where(valid).values + ax.plot(x_ssh, z_ssh, color=ssh_color, linewidth=linewidth, zorder=4) + + if seafloor_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_floor = 1e-3 * ds_transect.dNode.values + z_floor = ds_transect.zSeafloor.where(valid).values + ax.plot(x_floor, z_floor, color=seafloor_color, linewidth=linewidth, + zorder=5) + + +def _plot_outline(x, z, valid_nodes, ax, outline_color, linewidth, + epsilon=1e-6): + if outline_color is not None: + # add a buffer of invalid values around the edge of the domain + valid = np.zeros((x.shape[0] + 2, x.shape[1] + 2), dtype=float) + z_buf = np.zeros(valid.shape, dtype=float) + x_buf = np.zeros(valid.shape, dtype=float) + + valid[1:-1, 1:-1] = valid_nodes.astype(float) + z_buf[1:-1, 1:-1] = z.values + z_buf[0, 1:-1] = z_buf[1, 1:-1] + z_buf[-1, 1:-1] = z_buf[-2, 1:-1] + z_buf[:, 0] = z_buf[:, 1] + z_buf[:, -1] = z_buf[:, -2] + + x_buf[1:-1, 1:-1] = x.values + x_buf[0, 1:-1] = x_buf[1, 1:-1] + x_buf[-1, 1:-1] = x_buf[-2, 1:-1] + x_buf[:, 0] = x_buf[:, 1] + x_buf[:, -1] = x_buf[:, -2] + + ax.contour(x_buf, z_buf, valid, levels=[1. - epsilon], + colors=outline_color, linewidths=linewidth, zorder=3) From 65474b547e7586145d97f2b1365f9f93d355d067 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Tue, 7 Nov 2023 13:55:51 +0100 Subject: [PATCH 05/14] Handle periodic meshes in `mesh_to_triangles()` Polaris' version of the funciton handles periodicity in planar as well as spherical meshes. --- polaris/ocean/viz/transect/horiz.py | 173 +++++++++++++++++++++++++++- polaris/ocean/viz/transect/vert.py | 2 +- 2 files changed, 169 insertions(+), 6 deletions(-) diff --git a/polaris/ocean/viz/transect/horiz.py b/polaris/ocean/viz/transect/horiz.py index 0a06cddd5..6e16e9198 100644 --- a/polaris/ocean/viz/transect/horiz.py +++ b/polaris/ocean/viz/transect/horiz.py @@ -11,6 +11,130 @@ from shapely.geometry import LineString, Point +def mesh_to_triangles(ds_mesh): + """ + Construct a dataset in which each MPAS cell is divided into the triangles + connecting pairs of adjacent vertices to cell centers. + + Parameters + ---------- + ds_mesh : xarray.Dataset + An MPAS mesh + + Returns + ------- + ds_tris : xarray.Dataset + A dataset that defines triangles connecting pairs of adjacent vertices + to cell centers as well as the cell index that each triangle is in and + cell indices and weights for interpolating data defined at cell centers + to triangle nodes. ``ds_tris`` includes variables ``triCellIndices``, + the cell that each triangle is part of; ``nodeCellIndices`` and + ``nodeCellWeights``, the indices and weights used to interpolate from + MPAS cell centers to triangle nodes; Cartesian coordinates ``xNode``, + ``yNode``, and ``zNode``; and ``lonNode``` and ``latNode`` in radians. + ``lonNode`` is guaranteed to be within 180 degrees of the cell center + corresponding to ``triCellIndices``. Nodes always have a + counterclockwise winding. + + """ + n_vertices_on_cell = ds_mesh.nEdgesOnCell.values + vertices_on_cell = ds_mesh.verticesOnCell.values - 1 + cells_on_vertex = ds_mesh.cellsOnVertex.values - 1 + + on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES' + is_periodic = False + x_period = None + y_period = None + if not on_a_sphere: + is_periodic = ds_mesh.attrs['is_periodic'].strip() == 'YES' + if is_periodic: + x_period = ds_mesh.attrs['x_period'] + y_period = ds_mesh.attrs['y_period'] + + kite_areas_on_vertex = ds_mesh.kiteAreasOnVertex.values + + n_triangles = np.sum(n_vertices_on_cell) + + max_edges = ds_mesh.sizes['maxEdges'] + n_cells = ds_mesh.sizes['nCells'] + if ds_mesh.sizes['vertexDegree'] != 3: + raise ValueError('mesh_to_triangles only supports meshes with ' + 'vertexDegree = 3') + + # find the third vertex for each triangle + next_vertex = -1 * np.ones(vertices_on_cell.shape, int) + for i_vertex in range(max_edges): + valid = i_vertex < n_vertices_on_cell + invalid = np.logical_not(valid) + vertices_on_cell[invalid, i_vertex] = -1 + nv = n_vertices_on_cell[valid] + cell_indices = np.arange(0, n_cells)[valid] + i_next = np.where(i_vertex < nv - 1, i_vertex + 1, 0) + next_vertex[:, i_vertex][valid] = ( + vertices_on_cell[cell_indices, i_next]) + + valid = vertices_on_cell >= 0 + vertices_on_cell = vertices_on_cell[valid] + next_vertex = next_vertex[valid] + + # find the cell index for each triangle + tri_cell_indices, _ = np.meshgrid(np.arange(0, n_cells), + np.arange(0, max_edges), + indexing='ij') + tri_cell_indices = tri_cell_indices[valid] + + # find list of cells and weights for each triangle node + node_cell_indices = -1 * np.ones((n_triangles, 3, 3), dtype=int) + node_cell_weights = np.zeros((n_triangles, 3, 3)) + + # the first node is at the cell center, so the value is just the one from + # that cell + node_cell_indices[:, 0, 0] = tri_cell_indices + node_cell_weights[:, 0, 0] = 1. + + # the other 2 nodes are associated with vertices + node_cell_indices[:, 1, :] = cells_on_vertex[vertices_on_cell, :] + node_cell_weights[:, 1, :] = kite_areas_on_vertex[vertices_on_cell, :] + node_cell_indices[:, 2, :] = cells_on_vertex[next_vertex, :] + node_cell_weights[:, 2, :] = kite_areas_on_vertex[next_vertex, :] + + weight_sum = np.sum(node_cell_weights, axis=2) + for i_node in range(3): + node_cell_weights[:, :, i_node] = ( + node_cell_weights[:, :, i_node] / weight_sum) + + ds_tris = xr.Dataset() + ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices) + ds_tris['nodeCellIndices'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_indices) + ds_tris['nodeCellWeights'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_weights) + + # get Cartesian and lon/lat coordinates of each node + for prefix in ['x', 'y', 'z', 'lat', 'lon']: + out_var = f'{prefix}Node' + cell_var = f'{prefix}Cell' + vertex_var = f'{prefix}Vertex' + coord = np.zeros((n_triangles, 3)) + coord[:, 0] = ds_mesh[cell_var].values[tri_cell_indices] + coord[:, 1] = ds_mesh[vertex_var].values[vertices_on_cell] + coord[:, 2] = ds_mesh[vertex_var].values[next_vertex] + ds_tris[out_var] = (('nTriangles', 'nNodes'), coord) + + # nothing obvious we can do about triangles containing the poles + + if on_a_sphere: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='lonNode', + period=2 * np.pi) + elif is_periodic: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='xNode', + period=x_period) + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='yNode', + period=y_period) + + return ds_tris + + def make_triangle_tree(ds_tris): """ Make a KD-Tree for finding triangle edges that are near enough to transect @@ -20,7 +144,7 @@ def make_triangle_tree(ds_tris): ---------- ds_tris : xarray.Dataset A dataset that defines triangles, the results of calling - :py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` Returns ------- @@ -65,14 +189,15 @@ def find_spherical_transect_cells_and_weights( ds_tris : xarray.Dataset A dataset that defines triangles, the results of calling - :py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` ds_mesh : xarray.Dataset A data set with the full MPAS mesh. tree : scipy.spatial.cKDTree A tree of edge centers from triangles making up an MPAS mesh, the - return value from ``make_triangle_tree()`` + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` degrees : bool, optional Whether ``lon_transect`` and ``lat_transect`` are in degrees (as @@ -347,14 +472,15 @@ def find_planar_transect_cells_and_weights( ds_tris : xarray.Dataset A dataset that defines triangles, the results of calling - `:py:func:`mpas_tools.viz.mesh_to_triangles.mesh_to_triangles()` + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` ds_mesh : xarray.Dataset A data set with the full MPAS mesh. tree : scipy.spatial.cKDTree A tree of edge centers from triangles making up an MPAS mesh, the - return value from ``make_triangle_tree()`` + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` subdivision_res : float, optional Resolution in m to use to subdivide the transect when looking for @@ -724,3 +850,40 @@ def _sort_intersections(d_node, tris, nodes, x_out, y_out, z_out, interp_cells, return (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, cell_weights, valid_nodes) + + +def _fix_periodic_tris(ds_tris, periodic_var, period): + """ + make sure the given node coordinate on tris is within one period of the + cell center + """ + coord_node = ds_tris[periodic_var].values + coord_cell = coord_node[:, 0] + n_triangles = ds_tris.sizes['nTriangles'] + copy_pos = np.zeros(coord_cell.shape, dtype=bool) + copy_neg = np.zeros(coord_cell.shape, dtype=bool) + for i_node in [1, 2]: + mask = coord_node[:, i_node] - coord_cell > 0.5 * period + copy_pos = np.logical_or(copy_pos, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] - period + mask = coord_node[:, i_node] - coord_cell < -0.5 * period + copy_neg = np.logical_or(copy_neg, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] + period + + pos_indices = np.nonzero(copy_pos)[0] + neg_indices = np.nonzero(copy_neg)[0] + tri_indices = np.append(np.append(np.arange(0, n_triangles), + pos_indices), neg_indices) + + ds_new = xr.Dataset(ds_tris) + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + ds_new = ds_new.isel(nTriangles=tri_indices) + coord_node = ds_new[periodic_var].values + + pos_slice = slice(n_triangles, n_triangles + len(pos_indices)) + coord_node[pos_slice, :] = coord_node[pos_slice, :] + period + neg_slice = slice(n_triangles + len(pos_indices), + n_triangles + len(pos_indices) + len(neg_indices)) + coord_node[neg_slice, :] = coord_node[neg_slice, :] - period + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + return ds_new diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py index dd898b2b0..1823f3873 100644 --- a/polaris/ocean/viz/transect/vert.py +++ b/polaris/ocean/viz/transect/vert.py @@ -1,12 +1,12 @@ import numpy as np import xarray as xr -from mpas_tools.viz import mesh_to_triangles from polaris.ocean.viz.transect.horiz import ( find_planar_transect_cells_and_weights, find_spherical_transect_cells_and_weights, interp_mpas_horiz_to_transect_nodes, make_triangle_tree, + mesh_to_triangles, ) From 29f19d37e1f2d9f7be668b4cbeca503daf2dc7e8 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Tue, 7 Nov 2023 21:36:00 +0100 Subject: [PATCH 06/14] Add transect plotting to `plot_horiz_field()` --- polaris/viz/planar.py | 61 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index fef04eda2..779bbbfa3 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -16,7 +16,10 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 cmap=None, cmap_set_under=None, cmap_set_over=None, cmap_scale='linear', cmap_title=None, figsize=None, vert_dim='nVertLevels', cell_mask=None, patches=None, - patch_mask=None): + patch_mask=None, transect_x=None, transect_y=None, + transect_color='black', transect_start='red', + transect_end='green', transect_linewidth=2., + transect_markersize=12.): """ Plot a horizontal field from a planar domain using x,y coordinates at a single time and depth slice. @@ -92,6 +95,27 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 A mask of where the field has patches from a previous call to ``plot_horiz_field()`` + transect_x : numpy.ndarray or xarray.DataArray, optional + The x coordinates of a transect to plot on the + + transect_y : numpy.ndarray or xarray.DataArray, optional + The y coordinates of a transect + + transect_color : str, optional + The color of the transect line + + transect_start : str or None, optional + The color of a dot marking the start of the transect + + transect_end : str or None, optional + The color of a dot marking the end of the transect + + transect_linewidth : float, optional + The width of the transect line + + transect_markersize : float, optional + The size of the transect start and end markers + Returns ------- patches : list of numpy.ndarray @@ -101,6 +125,19 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 patch_mask : numpy.ndarray A mask used to select entries in the field that have patches """ + if field_name not in ds: + raise ValueError( + f'{field_name} must be present in ds before plotting.') + + if patches is not None: + if patch_mask is None: + raise ValueError('You must supply both patches and patch_mask ' + 'from a previous call to plot_horiz_field()') + + if (transect_x is None) != (transect_y is None): + raise ValueError('You must supply both transect_x and transect_y or ' + 'neither') + use_mplstyle() create_fig = True @@ -118,10 +155,6 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if title is None: title = field_name - if field_name not in ds: - raise ValueError( - f'{field_name} must be present in ds before plotting.') - field = ds[field_name] if 'Time' in field.dims and t_index is None: @@ -133,11 +166,7 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if z_index is not None: field = field.isel({vert_dim: z_index}) - if patches is not None: - if patch_mask is None: - raise ValueError('You must supply both patches and patch_mask ' - 'from a previous call to plot_horiz_field()') - else: + if patches is None: if cell_mask is None: cell_mask = np.ones_like(field, type='bool') if 'nCells' in field.dims: @@ -190,6 +219,18 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 cbar = plt.colorbar(local_patches, extend='both', shrink=0.7, ax=ax) if cmap_title is not None: cbar.set_label(cmap_title) + + if transect_x is not None: + transect_x = 1e-3 * transect_x + transect_y = 1e-3 * transect_y + ax.plot(transect_x, transect_y, color=transect_color, + linewidth=transect_linewidth) + if transect_start is not None: + ax.plot(transect_x[0], transect_y[0], '.', color=transect_start, + markersize=transect_markersize) + if transect_end is not None: + ax.plot(transect_x[-1], transect_y[-1], '.', color=transect_end, + markersize=transect_markersize) if create_fig: plt.title(title) plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0.2) From 186b016774fc1f62fb620cfd922d1f007f8bc2ec Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Tue, 7 Nov 2023 22:45:53 +0100 Subject: [PATCH 07/14] Add start/end colors on transect axes --- polaris/ocean/viz/transect/plot.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/polaris/ocean/viz/transect/plot.py b/polaris/ocean/viz/transect/plot.py index 73ec7d1cd..890310c2e 100644 --- a/polaris/ocean/viz/transect/plot.py +++ b/polaris/ocean/viz/transect/plot.py @@ -14,7 +14,7 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, cmap=None, figsize=(12, 6), dpi=200, method='flat', outline_color='black', ssh_color=None, seafloor_color=None, interface_color=None, cell_boundary_color=None, - linewidth=1.0): + linewidth=1.0, transect_start='red', transect_end='green'): """ plot a transect showing the field on the MPAS-Ocean mesh and save to a file @@ -84,6 +84,12 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, linewidth : float, optional The width of outlines, interfaces and cell boundaries + + transect_start : str or None, optional + The color of left axis marking the start of the transect + + transect_end : str or None, optional + The color of right axis marking the end of the transect """ if ax is None and out_filename is None: @@ -120,7 +126,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, label=colorbar_label) _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, - ssh_color, seafloor_color, linewidth) + ssh_color, seafloor_color, transect_start, transect_end, + linewidth) _plot_outline(x, z, ds_transect.validNodes, ax, outline_color, linewidth) @@ -136,7 +143,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, - ssh_color, seafloor_color, linewidth): + ssh_color, seafloor_color, transect_start, transect_end, + linewidth): if cell_boundary_color is not None: x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T z_bnd = ds_transect.zCellBoundary.values.T @@ -162,6 +170,14 @@ def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, ax.plot(x_floor, z_floor, color=seafloor_color, linewidth=linewidth, zorder=5) + if transect_start is not None: + ax.spines['left'].set_color(transect_start) + ax.spines['left'].set_linewidth(4 * linewidth) + + if transect_end is not None: + ax.spines['right'].set_color(transect_end) + ax.spines['right'].set_linewidth(4 * linewidth) + def _plot_outline(x, z, valid_nodes, ax, outline_color, linewidth, epsilon=1e-6): From de8d5bfb2ce13190e369ca1b7c435c5c53ccb2fb Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Tue, 7 Nov 2023 15:31:47 +0100 Subject: [PATCH 08/14] Add transect plots to baroclinic channel --- .../ocean/tasks/baroclinic_channel/init.py | 24 +++++++++++-- polaris/ocean/tasks/baroclinic_channel/viz.py | 34 +++++++++++++++++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 983efb7ec..183618133 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -8,6 +8,7 @@ from polaris import Step from polaris.mesh.planar import compute_planar_hex_nx_ny from polaris.ocean.vertical import init_vertical_coord +from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -162,8 +163,27 @@ def run(self): write_netcdf(ds, 'initial_state.nc') cell_mask = ds.maxLevelCell >= 1 - plot_horiz_field(ds, ds_mesh, 'temperature', - 'initial_temperature.png', cell_mask=cell_mask) + plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'initial_normal_velocity.png', cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) + + x_mid = 0.5 * (x_min + x_max) + + y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) + x = x_mid * xr.ones_like(y) + + ds_transect = compute_transect(x=x, y=y, ds_3d_mesh=ds.isel(Time=0), + spherical=False) + + field_name = 'temperature' + plot_transect(ds_transect=ds_transect, + mpas_field=ds[field_name].isel(Time=0), + title=f'{field_name} at x={1e-3 * x_mid:.1f} km', + out_filename=f'initial_{field_name}_section.png', + vmin=9., vmax=13., cmap='cmo.thermal', + colorbar_label=r'$^\circ$C') + + plot_horiz_field(ds, ds_mesh, 'temperature', 'initial_temperature.png', + vmin=9., vmax=13., cmap='cmo.thermal', + cell_mask=cell_mask, transect_x=x, transect_y=y) diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index aa2416c92..a49257643 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -3,6 +3,7 @@ import xarray as xr from polaris import Step +from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -42,9 +43,6 @@ def run(self): ds = xr.load_dataset('output.nc') t_index = ds.sizes['Time'] - 1 cell_mask = ds_init.maxLevelCell >= 1 - plot_horiz_field(ds, ds_mesh, 'temperature', - 'final_temperature.png', t_index=t_index, - cell_mask=cell_mask) max_velocity = np.max(np.abs(ds.normalVelocity.values)) plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'final_normalVelocity.png', @@ -52,3 +50,33 @@ def run(self): vmin=-max_velocity, vmax=max_velocity, cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) + + x_cell = ds_mesh.xCell + y_cell = ds_mesh.yCell + + x_min = x_cell.min().values + x_max = x_cell.max().values + y_min = y_cell.min().values + y_max = y_cell.max().values + + x_mid = 0.5 * (x_min + x_max) + + y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) + x = x_mid * xr.ones_like(y) + + ds_transect = compute_transect(x=x, y=y, + ds_3d_mesh=ds_init.isel(Time=0), + spherical=False) + + field_name = 'temperature' + mpas_field = ds[field_name].isel(Time=t_index) + plot_transect(ds_transect=ds_transect, mpas_field=mpas_field, + title=f'{field_name} at x={1e-3 * x_mid:.1f} km', + out_filename=f'final_{field_name}_section.png', + vmin=9., vmax=13., cmap='cmo.thermal', + colorbar_label=r'$^\circ$C') + + plot_horiz_field(ds, ds_mesh, 'temperature', 'final_temperature.png', + t_index=t_index, vmin=9., vmax=13., + cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x, + transect_y=y) From 0e2fd483e70ffd510f4cb9898cbdb2eabf153ec7 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Sat, 28 Oct 2023 21:22:10 +0200 Subject: [PATCH 09/14] Add docs --- docs/developers_guide/ocean/api.md | 19 +++++++++++++++++++ docs/developers_guide/ocean/framework.md | 10 ++++++++++ 2 files changed, 29 insertions(+) diff --git a/docs/developers_guide/ocean/api.md b/docs/developers_guide/ocean/api.md index bf31cdcca..23ce0fa8c 100644 --- a/docs/developers_guide/ocean/api.md +++ b/docs/developers_guide/ocean/api.md @@ -314,3 +314,22 @@ vertical.zlevel.compute_z_level_resting_thickness vertical.zstar.init_z_star_vertical_coord ``` + +### Visualization + +```{eval-rst} +.. currentmodule:: polaris.ocean.viz + +.. autosummary:: + :toctree: generated/ + + compute_transect + plot_transect + transect.horiz.find_spherical_transect_cells_and_weights + transect.horiz.find_planar_transect_cells_and_weights + transect.horiz.make_triangle_tree + transect.horiz.mesh_to_triangles + transect.vert.find_transect_levels_and_weights + transect.vert.interp_mpas_to_transect_cells + transect.vert.interp_mpas_to_transect_nodes + diff --git a/docs/developers_guide/ocean/framework.md b/docs/developers_guide/ocean/framework.md index dfcea88ca..6ae754bd1 100644 --- a/docs/developers_guide/ocean/framework.md +++ b/docs/developers_guide/ocean/framework.md @@ -422,3 +422,13 @@ density, which is horizontally constant and increases with depth. The {py:func}`polaris.ocean.rpe.compute_rpe()` is used to compute the RPE as a function of time in a series of one or more output files. The RPE is stored in `rpe.csv` and also returned as a numpy array for plotting and analysis. + +## Visualization + +The `polaris.ocean.viz` module provides functions for making plots that are +specific to the ocean component. + +The `polaris.ocean.viz.transect` modules includes functions for computing +({py:func}`polaris.ocean.viz.compute_transect()`) and plotting +({py:func}`polaris.ocean.viz.plot_transect()`) transects through the ocean +from a sequence of x-y or latitude-longitude coordinates. From dba67c88d15ab3a9e6f3503bf06b75c4d35b715f Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Thu, 9 Nov 2023 10:52:59 +0100 Subject: [PATCH 10/14] Add more parameters to `compute_transect()` This should allow the horizontal mesh and constituents of the vertical coordinate to come from different data sets as needed. --- .../ocean/tasks/baroclinic_channel/init.py | 7 ++- polaris/ocean/tasks/baroclinic_channel/viz.py | 10 +++-- polaris/ocean/viz/transect/vert.py | 45 +++++++++++-------- 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 183618133..372e4f8f7 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -173,8 +173,11 @@ def run(self): y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) x = x_mid * xr.ones_like(y) - ds_transect = compute_transect(x=x, y=y, ds_3d_mesh=ds.isel(Time=0), - spherical=False) + ds_transect = compute_transect( + x=x, y=y, ds_horiz_mesh=ds_mesh, + layer_thickness=ds.layerThickness.isel(Time=0), + bottom_depth=ds.bottomDepth, min_level_cell=ds.minLevelCell - 1, + max_level_cell=ds.maxLevelCell - 1, spherical=False) field_name = 'temperature' plot_transect(ds_transect=ds_transect, diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index a49257643..48de48f55 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -64,9 +64,13 @@ def run(self): y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) x = x_mid * xr.ones_like(y) - ds_transect = compute_transect(x=x, y=y, - ds_3d_mesh=ds_init.isel(Time=0), - spherical=False) + ds_transect = compute_transect( + x=x, y=y, ds_horiz_mesh=ds_mesh, + layer_thickness=ds.layerThickness.isel(Time=t_index), + bottom_depth=ds_init.bottomDepth, + min_level_cell=ds_init.minLevelCell - 1, + max_level_cell=ds_init.maxLevelCell - 1, + spherical=False) field_name = 'temperature' mpas_field = ds[field_name].isel(Time=t_index) diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py index 1823f3873..4b8e5314c 100644 --- a/polaris/ocean/viz/transect/vert.py +++ b/polaris/ocean/viz/transect/vert.py @@ -4,13 +4,13 @@ from polaris.ocean.viz.transect.horiz import ( find_planar_transect_cells_and_weights, find_spherical_transect_cells_and_weights, - interp_mpas_horiz_to_transect_nodes, make_triangle_tree, mesh_to_triangles, ) -def compute_transect(x, y, ds_3d_mesh, spherical=False): +def compute_transect(x, y, ds_horiz_mesh, layer_thickness, bottom_depth, + min_level_cell, max_level_cell, spherical=False): """ build a sequence of quads showing the transect intersecting mpas cells @@ -22,8 +22,22 @@ def compute_transect(x, y, ds_3d_mesh, spherical=False): y : xarray.DataArray The y or latitude coordinate of the transect - ds_3d_mesh : xarray.Dataset - The MPAS-Ocean mesh to use for plotting + ds_horiz_mesh : xarray.Dataset + The horizontal MPAS mesh to use for plotting + + layer_thickness : xarray.DataArray + The layer thickness at a particular instant in time. + `layerThickness.isel(Time=tidx)` to select a particular time index + `tidx` if the original data array contains `Time`. + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh spherical : bool, optional Whether the x and y coordinates are latitude and longitude in degrees @@ -36,20 +50,20 @@ def compute_transect(x, y, ds_3d_mesh, spherical=False): for details """ # noqa: E501 - ds_tris = mesh_to_triangles(ds_3d_mesh) + ds_tris = mesh_to_triangles(ds_horiz_mesh) triangle_tree = make_triangle_tree(ds_tris) if spherical: ds_horiz_transect = find_spherical_transect_cells_and_weights( - x, y, ds_tris, ds_3d_mesh, triangle_tree, degrees=True) + x, y, ds_tris, ds_horiz_mesh, triangle_tree, degrees=True) else: ds_horiz_transect = find_planar_transect_cells_and_weights( - x, y, ds_tris, ds_3d_mesh, triangle_tree) + x, y, ds_tris, ds_horiz_mesh, triangle_tree) - # mask horizontal transect to valid cells (maxLevelCell > 0) + # mask horizontal transect to valid cells (max_level_cell >= 0) cell_indices = ds_horiz_transect.horizCellIndices - seg_mask = ds_3d_mesh.maxLevelCell.isel(nCells=cell_indices).values > 0 + seg_mask = max_level_cell.isel(nCells=cell_indices).values >= 0 node_mask = np.zeros(ds_horiz_transect.sizes['nNodes'], dtype=bool) node_mask[0:-1] = seg_mask node_mask[1:] = np.logical_or(node_mask[1:], seg_mask) @@ -58,16 +72,9 @@ def compute_transect(x, y, ds_3d_mesh, spherical=False): nNodes=node_mask) ds_transect = find_transect_levels_and_weights( - ds_horiz_transect=ds_horiz_transect, - layer_thickness=ds_3d_mesh.layerThickness, - bottom_depth=ds_3d_mesh.bottomDepth, - min_level_cell=ds_3d_mesh.minLevelCell - 1, - max_level_cell=ds_3d_mesh.maxLevelCell - 1) - - # interpolate the land-ice fraction so we can plot an overlying ice shelf - if 'landIceFraction' in ds_3d_mesh: - ds_transect['landIceFraction'] = interp_mpas_horiz_to_transect_nodes( - ds_transect, ds_3d_mesh.landIceFraction) + ds_horiz_transect=ds_horiz_transect, layer_thickness=layer_thickness, + bottom_depth=bottom_depth, min_level_cell=min_level_cell, + max_level_cell=max_level_cell) ds_transect.compute() From 6d58edbb54c903355cac6e020a66e558b223618b Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 13 Nov 2023 08:40:20 -0700 Subject: [PATCH 11/14] Clean up docstrings for computing transects Remove confusing references to "3D" and indicate that only fields with `nCells` and `nVertLevels` as dimensions can be plotted. --- polaris/ocean/viz/transect/vert.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py index 4b8e5314c..a146cb46c 100644 --- a/polaris/ocean/viz/transect/vert.py +++ b/polaris/ocean/viz/transect/vert.py @@ -12,7 +12,9 @@ def compute_transect(x, y, ds_horiz_mesh, layer_thickness, bottom_depth, min_level_cell, max_level_cell, spherical=False): """ - build a sequence of quads showing the transect intersecting mpas cells + build a sequence of quads showing the transect intersecting mpas cells. + This can be used to plot transects of fields with dimensions ``nCells`` and + ``nVertLevels`` using :py:func:`polaris.ocean.viz.plot_transect()` Parameters ---------- @@ -208,8 +210,8 @@ def find_transect_levels_and_weights(ds_horiz_transect, layer_thickness, def interp_mpas_to_transect_cells(ds_transect, da): """ - Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray - to transect cells + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect cells Parameters ---------- @@ -218,7 +220,7 @@ def interp_mpas_to_transect_cells(ds_transect, da): ``find_transect_levels_and_weights()`` da : xarray.DataArray - An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` (possibly among others) Returns @@ -240,9 +242,9 @@ def interp_mpas_to_transect_cells(ds_transect, da): def interp_mpas_to_transect_nodes(ds_transect, da): """ - Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray - to transect nodes, linearly interpolating fields between the closest - neighboring cells + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect nodes, linearly interpolating fields between + the closest neighboring cells Parameters ---------- @@ -251,7 +253,7 @@ def interp_mpas_to_transect_nodes(ds_transect, da): ``find_transect_levels_and_weights()`` da : xarray.DataArray - An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` (possibly among others) Returns From 501360b18594e994cecfb81dd87267b92e31b778 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 13 Nov 2023 08:41:31 -0700 Subject: [PATCH 12/14] Make coloring the start and end of a transect off by default --- polaris/ocean/viz/transect/plot.py | 34 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/polaris/ocean/viz/transect/plot.py b/polaris/ocean/viz/transect/plot.py index 890310c2e..3ebac1e2f 100644 --- a/polaris/ocean/viz/transect/plot.py +++ b/polaris/ocean/viz/transect/plot.py @@ -14,7 +14,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, cmap=None, figsize=(12, 6), dpi=200, method='flat', outline_color='black', ssh_color=None, seafloor_color=None, interface_color=None, cell_boundary_color=None, - linewidth=1.0, transect_start='red', transect_end='green'): + linewidth=1.0, color_start_and_end=False, + start_color='red', end_color='green'): """ plot a transect showing the field on the MPAS-Ocean mesh and save to a file @@ -85,11 +86,18 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, linewidth : float, optional The width of outlines, interfaces and cell boundaries - transect_start : str or None, optional - The color of left axis marking the start of the transect + color_start_and_end : bool, optional + Whether to color the left and right axes of the transect, which is + useful if the transect is also being plotted in an inset or on top of + a horizontal field - transect_end : str or None, optional - The color of right axis marking the end of the transect + start_color : str, optional + The color of left axis marking the start of the transect if + ``plot_start_end == True`` + + end_color : str, optional + The color of right axis marking the end of the transect if + ``plot_start_end == True`` """ if ax is None and out_filename is None: @@ -126,8 +134,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, label=colorbar_label) _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, - ssh_color, seafloor_color, transect_start, transect_end, - linewidth) + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth) _plot_outline(x, z, ds_transect.validNodes, ax, outline_color, linewidth) @@ -143,8 +151,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, - ssh_color, seafloor_color, transect_start, transect_end, - linewidth): + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth): if cell_boundary_color is not None: x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T z_bnd = ds_transect.zCellBoundary.values.T @@ -170,12 +178,10 @@ def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, ax.plot(x_floor, z_floor, color=seafloor_color, linewidth=linewidth, zorder=5) - if transect_start is not None: - ax.spines['left'].set_color(transect_start) + if color_start_and_end: + ax.spines['left'].set_color(start_color) ax.spines['left'].set_linewidth(4 * linewidth) - - if transect_end is not None: - ax.spines['right'].set_color(transect_end) + ax.spines['right'].set_color(end_color) ax.spines['right'].set_linewidth(4 * linewidth) From 0d6d10f9c8ee5aa81c7b59ff6628add20b374e78 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 13 Nov 2023 08:42:09 -0700 Subject: [PATCH 13/14] Update baroclinic channel transect plots Color the start and end axes of the transects (now off by default). Use the median to compute the x location of the transect. Use vertices to determine the start and end of the transect in y (so the transects doesn't start in the middle of a cell). Use the min and max of temperature to determine `vmin` and `vmax`, rather than hard-coding. --- .../ocean/tasks/baroclinic_channel/init.py | 12 +++++++---- polaris/ocean/tasks/baroclinic_channel/viz.py | 20 ++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 372e4f8f7..5093857ae 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -168,7 +168,9 @@ def run(self): 'initial_normal_velocity.png', cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) - x_mid = 0.5 * (x_min + x_max) + y_min = ds_mesh.yVertex.min().values + y_max = ds_mesh.yVertex.max().values + x_mid = ds_mesh.xCell.median().values y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) x = x_mid * xr.ones_like(y) @@ -180,13 +182,15 @@ def run(self): max_level_cell=ds.maxLevelCell - 1, spherical=False) field_name = 'temperature' + vmin = ds[field_name].min().values + vmax = ds[field_name].max().values plot_transect(ds_transect=ds_transect, mpas_field=ds[field_name].isel(Time=0), title=f'{field_name} at x={1e-3 * x_mid:.1f} km', out_filename=f'initial_{field_name}_section.png', - vmin=9., vmax=13., cmap='cmo.thermal', - colorbar_label=r'$^\circ$C') + vmin=vmin, vmax=vmax, cmap='cmo.thermal', + colorbar_label=r'$^\circ$C', color_start_and_end=True) plot_horiz_field(ds, ds_mesh, 'temperature', 'initial_temperature.png', - vmin=9., vmax=13., cmap='cmo.thermal', + vmin=vmin, vmax=vmax, cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x, transect_y=y) diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index 48de48f55..617253e71 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -51,15 +51,9 @@ def run(self): cmap='cmo.balance', show_patch_edges=True, cell_mask=cell_mask) - x_cell = ds_mesh.xCell - y_cell = ds_mesh.yCell - - x_min = x_cell.min().values - x_max = x_cell.max().values - y_min = y_cell.min().values - y_max = y_cell.max().values - - x_mid = 0.5 * (x_min + x_max) + y_min = ds_mesh.yVertex.min().values + y_max = ds_mesh.yVertex.max().values + x_mid = ds_mesh.xCell.median().values y = xr.DataArray(data=np.linspace(y_min, y_max, 2), dims=('nPoints',)) x = x_mid * xr.ones_like(y) @@ -73,14 +67,16 @@ def run(self): spherical=False) field_name = 'temperature' + vmin = ds[field_name].min().values + vmax = ds[field_name].max().values mpas_field = ds[field_name].isel(Time=t_index) plot_transect(ds_transect=ds_transect, mpas_field=mpas_field, title=f'{field_name} at x={1e-3 * x_mid:.1f} km', out_filename=f'final_{field_name}_section.png', - vmin=9., vmax=13., cmap='cmo.thermal', - colorbar_label=r'$^\circ$C') + vmin=vmin, vmax=vmax, cmap='cmo.thermal', + colorbar_label=r'$^\circ$C', color_start_and_end=True) plot_horiz_field(ds, ds_mesh, 'temperature', 'final_temperature.png', - t_index=t_index, vmin=9., vmax=13., + t_index=t_index, vmin=vmin, vmax=vmax, cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x, transect_y=y) From e269794079aedf6b02329c407ab4c2c7c07c7915 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 13 Nov 2023 13:02:50 -0700 Subject: [PATCH 14/14] Update docs --- docs/developers_guide/ocean/framework.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/developers_guide/ocean/framework.md b/docs/developers_guide/ocean/framework.md index 6ae754bd1..736c03c61 100644 --- a/docs/developers_guide/ocean/framework.md +++ b/docs/developers_guide/ocean/framework.md @@ -431,4 +431,6 @@ specific to the ocean component. The `polaris.ocean.viz.transect` modules includes functions for computing ({py:func}`polaris.ocean.viz.compute_transect()`) and plotting ({py:func}`polaris.ocean.viz.plot_transect()`) transects through the ocean -from a sequence of x-y or latitude-longitude coordinates. +from a sequence of x-y or latitude-longitude coordinates. Currently, only +transects on xarray data arrays with dimensions `nCells` by `nVertLevels` are +supported.