Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions climanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def _run_one_batch(model: torch.nn.Module, batch: dict):
) # (B, M, H, W)

# Compute masked loss
return compute_masked_loss(
pred, batch["monthly_patch"], batch["land_mask_patch"]
)
return compute_masked_loss(pred, batch["monthly_patch"], batch["land_mask_patch"])


def _compute_stats(dataset: Dataset):
Expand All @@ -35,7 +33,7 @@ def _compute_stats(dataset: Dataset):

def _initialize_decoder(model: torch.nn.Module, dataset: Dataset):
mean, std = _compute_stats(dataset)
decoder = model.module.decoder if hasattr(model, 'module') else model.decoder
decoder = model.module.decoder if hasattr(model, "module") else model.decoder
with torch.no_grad():
decoder.bias.copy_(torch.from_numpy(mean))
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)
Expand All @@ -58,6 +56,7 @@ def train_monthly_model(
device: str = "cpu",
verbose: bool = True,
dataloader_num_workers: int = 2,
verbose_epoch_interval: int = 20,
):
"""Train the model to predict monthly data from daily data.
Args:
Expand All @@ -75,6 +74,7 @@ def train_monthly_model(
verbose: whether to print training progress
dataloader_num_workers: how many subprocesses to use for data loading.
See torch DataLoader docs for details.
verbose_epoch_interval: how often to print training progress (in epochs)
"""
# Initialize the model
model = model.to(device)
Expand All @@ -87,7 +87,7 @@ def train_monthly_model(
batch_size=batch_size,
shuffle=shuffle,
pin_memory=use_cuda,
num_workers=dataloader_num_workers, # for data loading
num_workers=dataloader_num_workers, # for data loading
persistent_workers=True, # keep workers alive between epochs
)

Expand Down Expand Up @@ -160,7 +160,7 @@ def train_monthly_model(
)
writer.add_scalar("Loss/validation", avg_epoch_loss, epoch)

if verbose and epoch % 20 == 0:
if verbose and epoch % verbose_epoch_interval == 0:
gap = avg_epoch_loss - avg_train_loss
print(f"Epoch {epoch}: gap between train and val loss: {gap:.6f}")

Expand Down