Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions climanet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import matplotlib.ticker as mticker


def regrid_to_boundary_centered_grid(da: xr.DataArray, roll=False) -> xr.DataArray:
Expand Down Expand Up @@ -373,3 +374,123 @@ def plot_histograms(
)

plt.show()


def _time_labels(da, n_times):
    """Return one title string per time step of ``da``.

    Datetime-like ``time`` coordinates are formatted as ``YYYY-MM-DD``. If the
    coordinate is not datetime-like (the ``.dt`` accessor is unavailable), the
    positional ``"Month = <index>"`` label is used instead.
    """
    try:
        return [str(label) for label in da["time"].dt.strftime("%Y-%m-%d").values]
    except (AttributeError, TypeError):
        return [f"Month = {i}" for i in range(n_times)]


def _abs_errors_by_count(err_month, nobs_month, counts):
    """Group the absolute errors of one time step by observation count.

    Returns ``(positions, groups)`` where ``groups[k]`` holds the non-NaN
    absolute errors of all cells whose observation count equals
    ``positions[k]``. Counts without any valid error value are dropped, so
    the violin plot never receives an empty or all-NaN dataset.
    """
    positions = []
    groups = []
    for count in counts:
        values = err_month.where(nobs_month == count).values.ravel()
        values = values[~np.isnan(values)]
        if values.size:
            positions.append(count)
            groups.append(np.abs(values))
    return positions, groups


def _draw_violins(ax, groups, positions, color):
    """Draw one styled violin series on ``ax``.

    Returns the dict produced by ``Axes.violinplot``, or ``None`` when there
    is nothing to draw.
    """
    if not groups:
        return None

    parts = ax.violinplot(
        groups,
        positions=positions,
        showmedians=True,
        showextrema=True,
        points=500,
    )

    # Style: thinner outlines + less prominent extrema
    for body in parts["bodies"]:
        body.set_facecolor(color)
        body.set_edgecolor(color)
        body.set_alpha(0.45)
        body.set_linewidth(0.5)

    parts["cmedians"].set_linewidth(0.9)
    parts["cmedians"].set_alpha(0.9)
    for key in ("cbars", "cmins", "cmaxes"):
        parts[key].set_linewidth(0.35)
        parts[key].set_alpha(0.2)

    return parts


def plot_nobs_vs_err(
    nobs: xr.DataArray, err_baseline: xr.DataArray, err_predictions: xr.DataArray
):
    """Plot number of observations vs absolute error for each time step.

    One subplot is drawn per entry along the ``time`` dimension. Within a
    subplot, every distinct positive observation count gets a pair of violins
    showing the distribution of the absolute baseline and prediction errors
    over the grid cells with that count.

    The three inputs are expected to be xarray DataArrays with dimensions
    (time, lat, lon). They should share the same spatial and temporal
    coordinates.

    Args:
        nobs: Number of observations per grid cell per time step.
            Dimensions: (time, lat, lon). NaN and non-positive counts are
            ignored.
        err_baseline: Baseline error per grid cell per time step.
            Dimensions: (time, lat, lon).
        err_predictions: Prediction error per grid cell per time step.
            Dimensions: (time, lat, lon).

    Raises:
        ValueError: If ``nobs`` has no entries along ``time``.
    """
    n_times = nobs.sizes["time"]
    if n_times < 1:
        raise ValueError("nobs must contain at least one time step")

    # Subplots are stacked vertically, so the figure height (not the width)
    # has to grow with the number of time steps.
    fig, axes = plt.subplots(
        n_times, 1, figsize=(8, 5 * n_times), squeeze=False
    )
    titles = _time_labels(nobs, n_times)

    for i, ax in enumerate(axes[:, 0]):
        ax.set_title(titles[i])

        nobs_month = nobs.isel(time=i)

        # Unique observation counts for this time step, ignoring NaNs and zeros
        counts = np.unique(nobs_month.values)
        counts = counts[~np.isnan(counts)].astype(int)
        counts = np.unique(counts[counts > 0])

        base_pos, base_groups = _abs_errors_by_count(
            err_baseline.isel(time=i), nobs_month, counts
        )
        pred_pos, pred_groups = _abs_errors_by_count(
            err_predictions.isel(time=i), nobs_month, counts
        )

        series = (
            (_draw_violins(ax, base_groups, base_pos, "tab:blue"), "Baseline"),
            (_draw_violins(ax, pred_groups, pred_pos, "tab:orange"), "Prediction"),
        )

        ax.set_xlabel("Number of Observations")
        # The values are plain absolute errors; only the axis is symlog-scaled.
        ax.set_ylabel("Symmetric log-scaled Absolute Error (K)")

        # Non-linear y-axis: keeps detail near 0 and compresses larger values.
        ax.set_yscale("symlog", linthresh=0.05, linscale=0.8, base=10)

        # Show major ticks as plain decimals instead of scientific/log notation.
        ax.yaxis.set_major_locator(
            mticker.SymmetricalLogLocator(base=10, linthresh=0.05)
        )
        ax.yaxis.set_major_formatter(
            mticker.FuncFormatter(lambda y, _: f"{y:.3f}".rstrip("0").rstrip("."))
        )
        ax.yaxis.set_minor_formatter(mticker.NullFormatter())

        # Only list the series that were actually drawn; a time step without
        # valid data would otherwise fail when indexing the empty body list.
        drawn = [(parts["bodies"][0], name) for parts, name in series if parts]
        if drawn:
            handles, names = zip(*drawn)
            ax.legend(handles, names, loc="upper right")

    # Lay out once, after every subplot has been populated.
    fig.tight_layout()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
plt.tight_layout()
plt.tight_layout()