diff --git a/climanet/utils.py b/climanet/utils.py index 95ea29e..1f4848f 100644 --- a/climanet/utils.py +++ b/climanet/utils.py @@ -336,11 +336,14 @@ def plot_results( plt.show() -def plot_histograms(target, predictions, label="SST K", title=("Target", "Prediction")): +def plot_histograms( + target, predictions, label="SST K", legend_labels=("Target", "Prediction"), bins=30 +): + """Plot histograms of target and predictions in the same figure for comparison.""" fig, axs = plt.subplots( nrows=len(target.time), - ncols=2, - figsize=(12, 4 * len(target.time)), + ncols=1, + figsize=(8, 4 * len(target.time)), constrained_layout=True, ) @@ -352,28 +355,21 @@ def plot_histograms(target, predictions, label="SST K", title=("Target", "Predic target_t = target.isel(time=t) pred_t = predictions.isel(time=t) - title_1, title_2 = title - # Target histogram - axs[t, 0].hist( - target_t.values.flatten(), bins=30, alpha=0.7, color="blue", density=True + axs[t].hist( + target_t.values.flatten(), bins=bins, alpha=0.7, color="blue", density=True ) - axs[t, 0].set_title( - f"{title_1} Histogram, month={target.time.dt.strftime('%Y-%m-%d').values[t]}" - ) - axs[t, 0].set_xlabel(label) - axs[t, 0].set_ylabel("Frequency") - axs[t, 0].grid(True, alpha=0.3) + axs[t].set_xlabel(label) + axs[t].set_ylabel("Frequency") + axs[t].grid(True, alpha=0.3) - # Prediction histogram - axs[t, 1].hist( - pred_t.values.flatten(), bins=30, alpha=0.7, color="orange", density=True + # Prediction histogram (overlaid) + axs[t].hist( + pred_t.values.flatten(), bins=bins, alpha=0.7, color="orange", density=True ) - axs[t, 1].set_title( - f"{title_2} Histogram, month={target.time.dt.strftime('%Y-%m-%d').values[t]}" + axs[t].legend(legend_labels) + axs[t].set_title( + f"Histogram {legend_labels[0]} vs {legend_labels[1]}, month={t + 1}" ) - axs[t, 1].set_xlabel(label) - axs[t, 1].set_ylabel("Frequency") - axs[t, 1].grid(True, alpha=0.3) plt.show()