From 020dd3d65267aecddfb15cf35b6e056257d705c8 Mon Sep 17 00:00:00 2001 From: Ou Ku Date: Tue, 23 Jun 2026 16:26:00 +0200 Subject: [PATCH] add function for violin plot --- climanet/utils.py | 121 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/climanet/utils.py b/climanet/utils.py index 719fd8c..f927d08 100644 --- a/climanet/utils.py +++ b/climanet/utils.py @@ -10,6 +10,7 @@ import matplotlib.pyplot as plt from matplotlib.colors import TwoSlopeNorm +import matplotlib.ticker as mticker def regrid_to_boundary_centered_grid(da: xr.DataArray, roll=False) -> xr.DataArray: @@ -373,3 +374,123 @@ def plot_histograms( ) plt.show() + + +def plot_nobs_vs_err( + nobs: xr.DataArray, err_baseline: xr.DataArray, err_predictions: xr.DataArray +): + """Plot number of observations vs error for each month. + + The three inputs are expected to be xarray DataArrays with dimensions (time, lat, lon). + They should share the same spatial and temporal coordinates. + + Parameters + ---------- + nobs : xr.DataArray + Number of observations per grid cell per month. Dimensions: (time, lat, lon) + err_baseline : xr.DataArray + Baseline error per grid cell per month. Dimensions: (time, lat, lon) + err_predictions : xr.DataArray + Prediction error per grid cell per month. Dimensions: (time, lat, lon) + """ + + fig, axes = plt.subplots(nobs.sizes["time"], 1, figsize=(5 * nobs.sizes["time"], 8)) + if nobs.sizes["time"] == 1: + axes = [axes] + + for i, ax in enumerate(axes): + ax.set_title(f"Month = {i}") + + # Get unique number of observations for this month, ignoring NaNs and zeros + n_obs_unique = np.unique(nobs.isel(time=i).values.flatten()) + n_obs_unique = n_obs_unique[~np.isnan(n_obs_unique)] + n_obs_unique = n_obs_unique.astype(int) + n_obs_unique = n_obs_unique[n_obs_unique > 0] + + err_by_n_obs_baseline = [] + err_by_n_obs_predictions = [] + + for n_obs in n_obs_unique: + # Baseline error + err_arr = ( + err_baseline.isel(time=i) + .where(nobs.isel(time=i) == n_obs) + .values.flatten() + ) + err_arr = err_arr[~np.isnan(err_arr)] + if len(err_arr) == 0: + err_arr = np.array([np.nan]) + err_by_n_obs_baseline.append(np.abs(err_arr)) + + # Prediction error + err_arr = ( + err_predictions.isel(time=i) + .where(nobs.isel(time=i) == n_obs) + .values.flatten() + ) + err_arr = err_arr[~np.isnan(err_arr)] + if len(err_arr) == 0: + err_arr = np.array([np.nan]) + err_by_n_obs_predictions.append(np.abs(err_arr)) + + h1 = ax.violinplot( + err_by_n_obs_baseline, + positions=n_obs_unique, + showmedians=True, + showextrema=True, + points=500, + ) + h2 = ax.violinplot( + err_by_n_obs_predictions, + positions=n_obs_unique, + showmedians=True, + showextrema=True, + points=500, + ) + + # Style: thinner outlines + less prominent extrema + for body in h1["bodies"]: + body.set_facecolor("tab:blue") + body.set_edgecolor("tab:blue") + body.set_alpha(0.45) + body.set_linewidth(0.5) + + for body in h2["bodies"]: + body.set_facecolor("tab:orange") + body.set_edgecolor("tab:orange") + body.set_alpha(0.45) + body.set_linewidth(0.5) + + for h in (h1, h2): + h["cmedians"].set_linewidth(0.9) + h["cmedians"].set_alpha(0.9) + + h["cbars"].set_linewidth(0.35) + h["cbars"].set_alpha(0.2) + h["cmins"].set_linewidth(0.35) + h["cmins"].set_alpha(0.2) + h["cmaxes"].set_linewidth(0.35) + h["cmaxes"].set_alpha(0.2) + + ax.set_xlabel("Number of Daily Observations") + ax.set_ylabel("Log Absolute Error (K)") + + # Non-linear y-axis: keeps detail near 0 and compresses larger values. + ax.set_yscale("symlog", linthresh=0.05, linscale=0.8, base=10) + + # Show major ticks as plain decimals instead of scientific/log notation. + ax.yaxis.set_major_locator( + mticker.SymmetricalLogLocator(base=10, linthresh=0.05) + ) + ax.yaxis.set_major_formatter( + mticker.FuncFormatter(lambda y, _: f"{y:.3f}".rstrip("0").rstrip(".")) + ) + ax.yaxis.set_minor_formatter(mticker.NullFormatter()) + + ax.legend( + [h1["bodies"][0], h2["bodies"][0]], + ["Baseline", "Prediction"], + loc="upper right", + ) + + plt.tight_layout()