Skip to content

Commit

Permalink
Merge branch 'main' into 679_ljw_drought
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 authored Nov 7, 2023
2 parents 4645d35 + f0eaa5e commit c09be2a
Showing 1 changed file with 72 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ def parallel_coordinate_plot(
metric_names,
model_names,
models_to_highlight=list(),
models_to_highlight_by_line=True,
models_to_highlight_colors=None,
models_to_highlight_labels=None,
models_to_highlight_markers=["s", "o", "^", "*"],
models_to_highlight_markers_size=10,
fig=None,
ax=None,
figsize=(15, 5),
Expand All @@ -37,7 +40,10 @@ def parallel_coordinate_plot(
group2_name="group2",
comparing_models=None,
fill_between_lines=False,
fill_between_lines_colors=("green", "red"),
fill_between_lines_colors=("red", "green"),
arrow_between_lines=False,
arrow_between_lines_colors=("red", "green"),
arrow_alpha=1,
vertical_center=None,
vertical_center_line=False,
vertical_center_line_label=None,
Expand All @@ -50,9 +56,12 @@ def parallel_coordinate_plot(
- `data`: 2-d numpy array for metrics
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
- `model_names`: list, name of models for markers/lines (axis=0)
- `models_to_highlight`: list, default=None, List of models to highlight as lines
- `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
- `models_to_highlight_by_line`: bool, default=True, highlight as lines. If False, as marker
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
- `models_to_highlight_labels`: list, default=None, List of string labels for models to highlight as lines
- `models_to_highlight_markers`: list, matplotlib markers for models to highlight if as marker
- `models_to_highlight_markers_size`: float, size of matplotlib markers for models to highlight if as marker
- `fig`: `matplotlib.figure` instance to which the parallel coordinate plot is plotted.
If not provided, use current axes or create a new one. Optional.
- `ax`: `matplotlib.axes.Axes` instance to which the parallel coordinate plot is plotted.
Expand All @@ -76,7 +85,10 @@ def parallel_coordinate_plot(
- `group2_name`: string, needed for violin plot legend if splited to two groups, for the 2nd group. Default is 'group2'.
- `comparing_models`: tuple or list containing two strings for models to compare with colors filled between the two lines.
- `fill_between_lines`: bool, default=False, fill color between lines for models in comparing_models
- `fill_between_lines_colors`: tuple or list containing two strings for colors filled between the two lines. Default=('green', 'red')
- `fill_between_lines_colors`: tuple or list containing two strings of colors for filled between the two lines. Default=('red', 'green')
- `arrow_between_lines`: bool, default=False, place arrows between two lines for models in comparing_models
- `arrow_between_lines_colors`: tuple or list containing two strings of colors for arrow between the two lines. Default=('red', 'green')
- `arrow_alpha`: float, default=1, transparency of arrow (faction between 0 to 1)
- `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number
- `vertical_center_line`: bool, default=False, show median as line
- `vertical_center_line_label`: str, default=None, label in legend for the horizontal vertical center line. If not given, it will be automatically assigned. It can be turned off by "off"
Expand Down Expand Up @@ -231,7 +243,18 @@ def parallel_coordinate_plot(
else:
label = model

ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
if models_to_highlight_by_line:
ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
else:
ax.plot(
range(N),
zs[j, :],
models_to_highlight_markers[mh_index],
c=color,
label=label,
markersize=models_to_highlight_markers_size,
)

mh_index += 1
else:
if identify_all_models:
Expand All @@ -251,8 +274,8 @@ def parallel_coordinate_plot(
vertical_center_line_label = None
ax.plot(range(N), zs_middle, "-", c="k", label=vertical_center_line_label, lw=1)

# Fill between lines
if fill_between_lines and (comparing_models is not None):
# Compare two models
if comparing_models is not None:
if isinstance(comparing_models, tuple) or (
isinstance(comparing_models, list) and len(comparing_models) == 2
):
Expand All @@ -261,24 +284,49 @@ def parallel_coordinate_plot(
m2 = model_names.index(comparing_models[1])
y1 = zs[m1, :]
y2 = zs[m2, :]
ax.fill_between(
x,
y1,
y2,
where=y2 >= y1,
facecolor=fill_between_lines_colors[0],
interpolate=True,
alpha=0.5,
)
ax.fill_between(
x,
y1,
y2,
where=y2 <= y1,
facecolor=fill_between_lines_colors[1],
interpolate=True,
alpha=0.5,
)

# Fill between lines
if fill_between_lines:
ax.fill_between(
x,
y1,
y2,
where=(y2 > y1),
facecolor=fill_between_lines_colors[0],
interpolate=False,
alpha=0.5,
)
ax.fill_between(
x,
y1,
y2,
where=(y2 < y1),
facecolor=fill_between_lines_colors[1],
interpolate=False,
alpha=0.5,
)

if arrow_between_lines:
# Add vertical arrows
for xi, yi1, yi2 in zip(x, y1, y2):
if yi2 > yi1:
arrow_color = arrow_between_lines_colors[0]
elif yi2 < yi1:
arrow_color = arrow_between_lines_colors[1]
else:
arrow_color = None
arrow_length = yi2 - yi1
ax.arrow(
xi,
yi1,
0,
arrow_length,
color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha,
width=0.05,
head_width=0.15,
)

ax.set_xlim(-0.5, N - 0.5)
ax.set_xticks(range(N))
Expand Down

0 comments on commit c09be2a

Please sign in to comment.