From 0820b5076103043551c23d5ab9e15f808d99d7ea Mon Sep 17 00:00:00 2001 From: Ou Ku Date: Thu, 18 Jun 2026 16:06:21 +0200 Subject: [PATCH] make verbose print epochs a parameter instead of hard-coded 20 --- climanet/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/climanet/train.py b/climanet/train.py index 8cc565f..5bb9982 100644 --- a/climanet/train.py +++ b/climanet/train.py @@ -20,9 +20,7 @@ def _run_one_batch(model: torch.nn.Module, batch: dict): ) # (B, M, H, W) # Compute masked loss - return compute_masked_loss( - pred, batch["monthly_patch"], batch["land_mask_patch"] - ) + return compute_masked_loss(pred, batch["monthly_patch"], batch["land_mask_patch"]) def _compute_stats(dataset: Dataset): @@ -35,7 +33,7 @@ def _compute_stats(dataset: Dataset): def _initialize_decoder(model: torch.nn.Module, dataset: Dataset): mean, std = _compute_stats(dataset) - decoder = model.module.decoder if hasattr(model, 'module') else model.decoder + decoder = model.module.decoder if hasattr(model, "module") else model.decoder with torch.no_grad(): decoder.bias.copy_(torch.from_numpy(mean)) decoder.scale.copy_(torch.from_numpy(std) + 1e-6) @@ -58,6 +56,7 @@ def train_monthly_model( device: str = "cpu", verbose: bool = True, dataloader_num_workers: int = 2, + verbose_epoch_interval: int = 20, ): """Train the model to predict monthly data from daily data. Args: @@ -75,6 +74,7 @@ def train_monthly_model( verbose: whether to print training progress dataloader_num_workers: how many subprocesses to use for data loading. See torch DataLoader docs for details. 
+ verbose_epoch_interval: number of epochs between progress printouts when verbose is True; used as a modulo divisor, so it must be non-zero """ # Initialize the model model = model.to(device) @@ -87,7 +87,7 @@ def train_monthly_model( batch_size=batch_size, shuffle=shuffle, pin_memory=use_cuda, - num_workers=dataloader_num_workers, # for data loading + num_workers=dataloader_num_workers, # for data loading persistent_workers=True, # keep workers alive between epochs ) @@ -160,7 +160,7 @@ def train_monthly_model( ) writer.add_scalar("Loss/validation", avg_epoch_loss, epoch) - if verbose and epoch % 20 == 0: + if verbose and epoch % verbose_epoch_interval == 0: gap = avg_epoch_loss - avg_train_loss print(f"Epoch {epoch}: gap between train and val loss: {gap:.6f}")