-
Notifications
You must be signed in to change notification settings - Fork 0
add function for violin plot nr of obs vs err #61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||
|
|
||||||
| import matplotlib.pyplot as plt | ||||||
| from matplotlib.colors import TwoSlopeNorm | ||||||
| import matplotlib.ticker as mticker | ||||||
|
|
||||||
|
|
||||||
| def regrid_to_boundary_centered_grid(da: xr.DataArray, roll=False) -> xr.DataArray: | ||||||
|
|
@@ -373,3 +374,123 @@ def plot_histograms( | |||||
| ) | ||||||
|
|
||||||
| plt.show() | ||||||
|
|
||||||
|
|
||||||
| def plot_nobs_vs_err( | ||||||
| nobs: xr.DataArray, err_baseline: xr.DataArray, err_predictions: xr.DataArray | ||||||
| ): | ||||||
| """Plot number of observations vs error for each month. | ||||||
|
|
||||||
| The three inputs are expected to be xarray DataArrays with dimensions (time, lat, lon). | ||||||
| They should share the same spatial and temporal coordinates. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| nobs : xr.DataArray | ||||||
| Number of observations per grid cell per month. Dimensions: (time, lat, lon) | ||||||
| err_baseline : xr.DataArray | ||||||
| Baseline error per grid cell per month. Dimensions: (time, lat, lon) | ||||||
| err_predictions : xr.DataArray | ||||||
| Prediction error per grid cell per month. Dimensions: (time, lat, lon) | ||||||
| """ | ||||||
|
|
||||||
| fig, axes = plt.subplots(nobs.sizes["time"], 1, figsize=(5 * nobs.sizes["time"], 8)) | ||||||
| if nobs.sizes["time"] == 1: | ||||||
| axes = [axes] | ||||||
|
|
||||||
| for i, ax in enumerate(axes): | ||||||
| ax.set_title(f"Month = {i}") | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please use the actual time information for the title (rather than the loop index), as: |
||||||
|
|
||||||
| # Get unique number of observations for this month, ignoring NaNs and zeros | ||||||
| n_obs_unique = np.unique(nobs.isel(time=i).values.flatten()) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| n_obs_unique = n_obs_unique[~np.isnan(n_obs_unique)] | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| n_obs_unique = n_obs_unique.astype(int) | ||||||
| n_obs_unique = n_obs_unique[n_obs_unique > 0] | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| err_by_n_obs_baseline = [] | ||||||
| err_by_n_obs_predictions = [] | ||||||
|
|
||||||
| for n_obs in n_obs_unique: | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||||||
| # Baseline error | ||||||
| err_arr = ( | ||||||
| err_baseline.isel(time=i) | ||||||
| .where(nobs.isel(time=i) == n_obs) | ||||||
| .values.flatten() | ||||||
| ) | ||||||
| err_arr = err_arr[~np.isnan(err_arr)] | ||||||
| if len(err_arr) == 0: | ||||||
| err_arr = np.array([np.nan]) | ||||||
| err_by_n_obs_baseline.append(np.abs(err_arr)) | ||||||
|
|
||||||
| # Prediction error | ||||||
| err_arr = ( | ||||||
| err_predictions.isel(time=i) | ||||||
| .where(nobs.isel(time=i) == n_obs) | ||||||
| .values.flatten() | ||||||
| ) | ||||||
| err_arr = err_arr[~np.isnan(err_arr)] | ||||||
| if len(err_arr) == 0: | ||||||
| err_arr = np.array([np.nan]) | ||||||
| err_by_n_obs_predictions.append(np.abs(err_arr)) | ||||||
|
|
||||||
| h1 = ax.violinplot( | ||||||
| err_by_n_obs_baseline, | ||||||
| positions=n_obs_unique, | ||||||
| showmedians=True, | ||||||
| showextrema=True, | ||||||
| points=500, | ||||||
| ) | ||||||
| h2 = ax.violinplot( | ||||||
| err_by_n_obs_predictions, | ||||||
| positions=n_obs_unique, | ||||||
| showmedians=True, | ||||||
| showextrema=True, | ||||||
| points=500, | ||||||
| ) | ||||||
|
|
||||||
| # Style: thinner outlines + less prominent extrema | ||||||
| for body in h1["bodies"]: | ||||||
| body.set_facecolor("tab:blue") | ||||||
| body.set_edgecolor("tab:blue") | ||||||
| body.set_alpha(0.45) | ||||||
| body.set_linewidth(0.5) | ||||||
|
|
||||||
| for body in h2["bodies"]: | ||||||
| body.set_facecolor("tab:orange") | ||||||
| body.set_edgecolor("tab:orange") | ||||||
| body.set_alpha(0.45) | ||||||
| body.set_linewidth(0.5) | ||||||
|
|
||||||
| for h in (h1, h2): | ||||||
| h["cmedians"].set_linewidth(0.9) | ||||||
| h["cmedians"].set_alpha(0.9) | ||||||
|
|
||||||
| h["cbars"].set_linewidth(0.35) | ||||||
| h["cbars"].set_alpha(0.2) | ||||||
| h["cmins"].set_linewidth(0.35) | ||||||
| h["cmins"].set_alpha(0.2) | ||||||
| h["cmaxes"].set_linewidth(0.35) | ||||||
| h["cmaxes"].set_alpha(0.2) | ||||||
|
|
||||||
| ax.set_xlabel("Number of Daily Observations") | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this function work for hourly data too? |
||||||
| ax.set_ylabel("Log Absolute Error (K)") | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| # Non-linear y-axis: keeps detail near 0 and compresses larger values. | ||||||
| ax.set_yscale("symlog", linthresh=0.05, linscale=0.8, base=10) | ||||||
|
|
||||||
| # Show major ticks as plain decimals instead of scientific/log notation. | ||||||
| ax.yaxis.set_major_locator( | ||||||
| mticker.SymmetricalLogLocator(base=10, linthresh=0.05) | ||||||
| ) | ||||||
| ax.yaxis.set_major_formatter( | ||||||
| mticker.FuncFormatter(lambda y, _: f"{y:.3f}".rstrip("0").rstrip(".")) | ||||||
| ) | ||||||
| ax.yaxis.set_minor_formatter(mticker.NullFormatter()) | ||||||
|
|
||||||
| ax.legend( | ||||||
| [h1["bodies"][0], h2["bodies"][0]], | ||||||
| ["Baseline", "Prediction"], | ||||||
| loc="upper right", | ||||||
| ) | ||||||
|
|
||||||
| plt.tight_layout() | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are using `Args` and `Return` for the docstring style in most of the source code.