diff --git a/lib/catnip/visualisation.py b/lib/catnip/visualisation.py index 714940a..9d7bc26 100644 --- a/lib/catnip/visualisation.py +++ b/lib/catnip/visualisation.py @@ -32,6 +32,7 @@ import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import iris.quickplot as qplt @@ -39,8 +40,12 @@ import iris from catnip.analysis import linear_regress, ci_interval import numpy as np +import cartopy.crs as ccrs + -def vector_plot(u_cube, v_cube, unrotate=False, npts=30, num_plot=111, title=""): +def vector_plot( + u_cube, v_cube, unrotate=False, npts=30, num_plot=111, title="", projection=None +): """ A plotting function to produce a quick wind vector plot. Output is a plot with windspeed @@ -64,14 +69,8 @@ def vector_plot(u_cube, v_cube, unrotate=False, npts=30, num_plot=111, title="") """ # x and y coords - try: - x = u_cube.coord(axis="x") - except iris.exceptions.CoordinateNotFoundError: - print("Error: more than one x coord found") - try: - y = u_cube.coord(axis="y") - except iris.exceptions.CoordinateNotFoundError: - print("Error: more than one y coord found") + x = u_cube.coord(axis="x") + y = u_cube.coord(axis="y") # if the wind vectors need to be unrotated # they are in the statement below @@ -92,19 +91,32 @@ def vector_plot(u_cube, v_cube, unrotate=False, npts=30, num_plot=111, title="") windspeed_cube = (u_cube ** 2 + v_cube ** 2) ** 0.5 # plot - transform = x.coord_system.as_cartopy_projection() - # use coord_system of input data to define plot projection - ax = plt.subplot(num_plot, projection=transform) + # use coord_system of input data to define plot projection if not specified by user + if projection is None: + if x.coord_system is None: + projection = ccrs.PlateCarree() + else: + projection = x.coord_system.as_cartopy_projection() + + # the original crs of the data is needed for quiver plot and setting extents + if x.coord_system is None: + print("No crs found, assuming PlateCarree") + orig_crs = ccrs.PlateCarree() + else: + orig_crs = x.coord_system.as_cartopy_projection() + + ax = plt.subplot(num_plot, projection=projection) qplt.contourf(windspeed_cube, 20) ax.quiver( u_cube.coord(x.standard_name).points[::npts], v_cube.coord(y.standard_name).points[::npts], u_cube.data[::npts, ::npts], v_cube.data[::npts, ::npts], + transform=orig_crs, ) # OTHER OPTIONS FOR QUIVER: scale = 1, headwidth = 3, width = 0.0015) ax.coastlines() - ax.set_extent([x.points[0], x.points[-1], y.points[0], y.points[-1]], transform) + ax.set_extent([x.points[0], x.points[-1], y.points[0], y.points[-1]], crs=orig_crs) plt.title(title) print("plot {} created".format(title))