Model Training¶
This guide covers training DDR models to learn optimal routing parameters from observed streamflow data.
Overview¶
DDR training optimizes a Kolmogorov-Arnold Network (KAN) to predict physical routing parameters (Manning's n, channel geometry) from catchment attributes. The training loop:
- Reads lateral inflow (Q') from unit catchment predictions
- Predicts routing parameters using the KAN
- Routes flow through the river network using Muskingum-Cunge
- Computes loss against observed streamflow
- Backpropagates gradients through the entire system
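The steps above can be sketched end-to-end with a toy example. Everything below is illustrative (the network, the "routing" step, and all shapes are stand-ins, not DDR's actual API); the point is that every operation is differentiable, so the loss gradient flows back through routing into the parameter network:

```python
import torch

torch.manual_seed(0)

n_reaches, n_steps = 4, 16
attributes = torch.randn(n_reaches, 3)           # normalized catchment attributes
lateral_inflow = torch.rand(n_reaches, n_steps)  # stand-in for Q' predictions
observed = torch.rand(n_steps)                   # stand-in for gauge observations

# Stand-in for the KAN: a small MLP predicting one routing parameter per reach
net = torch.nn.Sequential(
    torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

for epoch in range(3):
    params = torch.sigmoid(net(attributes))        # bounded parameters per reach
    routed = (params * lateral_inflow).sum(dim=0)  # placeholder "routing" step
    loss = torch.nn.functional.l1_loss(routed, observed)  # MAE, as in DDR
    optimizer.zero_grad()
    loss.backward()  # gradients flow through routing into the network
    optimizer.step()
```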
Quick Start¶
Configuration¶
Essential Training Options¶
```yaml
mode: training
geodataset: lynker_hydrofabric  # or merit

experiment:
  epochs: 5           # Number of training epochs
  batch_size: 64      # Gauges per batch
  learning_rate:
    1: 0.005          # LR for epoch 1
    3: 0.001          # LR for epoch 3+
  rho: 365            # Training window (days)
  warmup: 3           # Warmup days excluded from loss
  shuffle: true       # Shuffle training data
  checkpoint: null    # Resume from checkpoint (optional)
```
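The `learning_rate` mapping is keyed by epoch: per the comments above, each rate applies from its epoch onward until a later key takes over. One plausible reading of that lookup, as a sketch:

```python
def lr_for_epoch(schedule: dict[int, float], epoch: int) -> float:
    """Return the learning rate in effect for `epoch`: the value keyed by
    the largest schedule epoch that is <= epoch."""
    applicable = [e for e in schedule if e <= epoch]
    return schedule[max(applicable)]

schedule = {1: 0.005, 3: 0.001}
lr_for_epoch(schedule, 1)  # 0.005
lr_for_epoch(schedule, 2)  # 0.005 (no change until epoch 3)
lr_for_epoch(schedule, 4)  # 0.001 (the "epoch 3+" rate)
```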
KAN Configuration¶
```yaml
kan:
  hidden_size: 21           # Hidden layer size (recommend 2n+1)
  num_hidden_layers: 2      # Number of hidden layers
  input_var_names:          # Catchment attributes as inputs
    - aridity
    - meanelevation
    - meanP
    - log10_uparea
    # ... more attributes
  learnable_parameters:     # Parameters to learn
    - n                     # Manning's roughness
    - q_spatial             # Leopold & Maddock shape exponent
    - p_spatial             # Leopold & Maddock width coefficient
  grid: 50                  # KAN grid size
  k: 2                      # KAN spline order
```
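Since the learnable parameters are physical quantities, raw network outputs are typically squashed into plausible ranges before routing. A minimal sketch of that pattern, with illustrative bounds (Manning's n for natural channels is commonly quoted in roughly the 0.01-0.15 range; DDR's actual bounds may differ):

```python
import torch

# Illustrative parameter bounds; not DDR's actual values.
BOUNDS = {"n": (0.01, 0.15), "q_spatial": (0.0, 1.0), "p_spatial": (0.0, 1.0)}

def to_physical(raw: torch.Tensor, name: str) -> torch.Tensor:
    """Map unbounded network outputs into the (lo, hi) range via a sigmoid."""
    lo, hi = BOUNDS[name]
    return lo + (hi - lo) * torch.sigmoid(raw)

raw = torch.randn(5)      # raw network outputs for 5 reaches
n = to_physical(raw, "n") # strictly inside (0.01, 0.15)
```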
Training Process¶
1. Data Loading¶
DDR uses PyTorch DataLoaders with custom collate functions:
```python
from torch.utils.data import DataLoader, RandomSampler

dataset = cfg.geodataset.get_dataset_class(cfg=cfg)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=cfg.experiment.batch_size,
    sampler=RandomSampler(dataset),
    collate_fn=dataset.collate_fn,
)
```
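To see the pattern in isolation, here is a self-contained toy dataset with a custom `collate_fn`; DDR's real dataset returns richer routing structures, so names and fields here are purely illustrative:

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

class ToyGaugeDataset(Dataset):
    """Toy stand-in: one flow series per gauge."""

    def __init__(self, n_gauges: int = 10, n_steps: int = 24):
        self.flows = torch.rand(n_gauges, n_steps)

    def __len__(self):
        return len(self.flows)

    def __getitem__(self, idx):
        return {"gauge_id": idx, "flow": self.flows[idx]}

    @staticmethod
    def collate_fn(batch):
        # Stack per-gauge samples into batched tensors
        return {
            "gauge_id": torch.tensor([b["gauge_id"] for b in batch]),
            "flow": torch.stack([b["flow"] for b in batch]),
        }

dataset = ToyGaugeDataset()
dataloader = DataLoader(
    dataset,
    batch_size=4,
    sampler=RandomSampler(dataset),
    collate_fn=ToyGaugeDataset.collate_fn,
)
batch = next(iter(dataloader))  # batch["flow"] has shape (4, 24)
```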
2. Forward Pass¶
For each batch:
```python
# Get lateral inflows
streamflow_predictions = flow(routing_dataclass=routing_dataclass)

# Predict parameters from attributes
spatial_params = nn(inputs=routing_dataclass.normalized_spatial_attributes)

# Route flow through network
dmc_output = routing_model(
    routing_dataclass=routing_dataclass,
    spatial_parameters=spatial_params,
    streamflow=streamflow_predictions,
)
```
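For intuition about the routing step itself, the textbook Muskingum recurrence for a single reach is sketched below. Muskingum-Cunge (as used in DDR) derives the storage constant K and weighting factor X from channel properties, which is where the learned Manning's n and geometry parameters enter; the per-timestep update has this same form:

```python
def muskingum_coeffs(K: float, X: float, dt: float) -> tuple[float, float, float]:
    """Muskingum routing coefficients; they always sum to 1."""
    denom = 2 * K * (1 - X) + dt
    c0 = (dt - 2 * K * X) / denom
    c1 = (dt + 2 * K * X) / denom
    c2 = (2 * K * (1 - X) - dt) / denom
    return c0, c1, c2

def route(inflow: list[float], K: float, X: float, dt: float) -> list[float]:
    """O[t] = c0*I[t] + c1*I[t-1] + c2*O[t-1], starting from steady state."""
    c0, c1, c2 = muskingum_coeffs(K, X, dt)
    outflow = [inflow[0]]
    for t in range(1, len(inflow)):
        outflow.append(c0 * inflow[t] + c1 * inflow[t - 1] + c2 * outflow[t - 1])
    return outflow

# An inflow pulse is attenuated and delayed as it moves through the reach
out = route([10, 30, 60, 40, 20, 10], K=2.0, X=0.2, dt=1.0)
```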
3. Loss Computation¶
Loss is computed on daily-averaged discharge after warmup using MAE (L1) loss:
```python
# Downsample to daily
daily_runoff = ddr_functions.downsample(dmc_output["runoff"], rho=num_days)

# Compute MAE loss (excluding warmup period)
loss = torch.nn.functional.l1_loss(
    input=daily_runoff[:, warmup:],
    target=observations[:, warmup:],
)
```
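The daily downsampling amounts to averaging sub-daily routed discharge within each day. A minimal sketch, assuming hourly timesteps (`ddr_functions.downsample` may differ in detail):

```python
import torch

def downsample_daily(hourly: torch.Tensor, steps_per_day: int = 24) -> torch.Tensor:
    """Average (gauges, days*steps_per_day) down to (gauges, days)."""
    n_gauges, n_steps = hourly.shape
    return hourly.reshape(n_gauges, n_steps // steps_per_day, steps_per_day).mean(dim=-1)

hourly = torch.arange(48, dtype=torch.float32).unsqueeze(0)  # 1 gauge, 2 days
daily = downsample_daily(hourly)  # means of hours 0-23 and 24-47
```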
4. Checkpointing¶
Models are saved periodically during training, so long runs can be resumed after an interruption.
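As a sketch, the generic PyTorch checkpoint pattern looks like the following; the keys and values here are illustrative, and DDR's actual checkpoint contents may differ. Saving the optimizer state alongside the model is what allows training to resume mid-run:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save enough state to resume exactly where training stopped
checkpoint = {
    "epoch": 3,
    "mini_batch": 17,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Later: restore and continue
restored = torch.load("checkpoint.pt")
model.load_state_dict(restored["model_state_dict"])
optimizer.load_state_dict(restored["optimizer_state_dict"])
start_epoch = restored["epoch"]
```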
Resuming Training¶
To resume from a checkpoint, set the `checkpoint` option in the experiment configuration to the path of a saved checkpoint file. Training will resume from the saved epoch and mini-batch.
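For example, with a hypothetical output path:

```yaml
experiment:
  checkpoint: ./outputs/run_01/checkpoint_epoch_3.pt  # hypothetical path
```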
Monitoring¶
Training progress is logged to the output directory. See the log file for details on loss values, learning rate changes, and parameter statistics.
Tips¶
- Start with smaller batch sizes (8-16) when debugging
- Use a warmup period (3+ days) to let the routing state stabilize before it contributes to the loss
- Monitor for NaN losses, which may indicate unstable routing parameters
- Save checkpoints frequently; training can take hours to days
Next Steps¶
- Model Testing: Evaluate trained models
- Benchmarks: Compare against other models