Diagnosing and fixing exploding gradients in a deep recurrent/Transformer training run
hard, subjective
General
You are training a deep sequence model in PyTorch. After a few hundred steps the loss suddenly becomes `NaN`, and you observe that the global gradient norm spikes to very large values right before the divergence.
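A minimal sketch of the instrumentation this scenario implies, assuming a standard `loss.backward()` training loop. It computes the global L2 gradient norm (the same quantity `torch.nn.utils.clip_grad_norm_` reports for `norm_type=2`) and per-parameter norms. It also checks the loss for finiteness *before* `backward()`, which separates a numerical problem in the forward pass or loss from a gradient blow-up. The helper names `global_grad_norm`, `per_layer_grad_norms`, `GradNormMonitor` and `diagnose_step`, and the window and spike thresholds, are hypothetical illustrations, not tuned values.

```python
import math
import statistics
from collections import deque

import torch
import torch.nn as nn


def global_grad_norm(model: nn.Module) -> float:
    """L2 norm of all parameter gradients taken together."""
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    # The L2 norm of per-tensor L2 norms equals the L2 norm of the concatenated gradient.
    return torch.norm(torch.stack(norms), 2).item()


def per_layer_grad_norms(model: nn.Module) -> dict:
    """Per-parameter gradient norms, to locate which layer blows up first."""
    return {name: p.grad.detach().norm(2).item()
            for name, p in model.named_parameters() if p.grad is not None}


class GradNormMonitor:
    """Hypothetical helper: flags a step when the global norm is non-finite or exceeds
    `spike_factor` times the median of the last `window` recorded norms."""

    def __init__(self, window: int = 100, spike_factor: float = 10.0, min_history: int = 10):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor
        self.min_history = min_history

    def check(self, gnorm: float) -> bool:
        if not math.isfinite(gnorm):
            return True  # inf/NaN gradients are always flagged and never recorded
        spike = (len(self.history) >= self.min_history
                 and gnorm > self.spike_factor * statistics.median(self.history))
        self.history.append(gnorm)
        return spike


def diagnose_step(model, loss_fn, x, y, monitor: GradNormMonitor) -> str:
    """One instrumented forward/backward pass; returns a short label for the step."""
    model.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), y)
    # A non-finite loss before backward points at the forward pass or the loss
    # itself (e.g. log(0) or overflow), not at exploding gradients.
    if not torch.isfinite(loss):
        return "nonfinite_loss"
    loss.backward()
    if monitor.check(global_grad_norm(model)):
        return "grad_spike"
    return "ok"
```

If the loss is still finite while the monitor fires, and the per-parameter norms show the spike concentrated in particular layers (for example recurrent or attention weights), that points at exploding gradients rather than a loss bug. Wrapping a suspect step in `torch.autograd.detect_anomaly()` can then locate the backward op that first produces a NaN, at a significant speed cost.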
Walk through how you would diagnose the root cause and stabilize training. In your answer, address: (a) how you would instrument the run to confirm it is an exploding-gradient problem (and distinguish it from, say, a bad learning rate or a numerical issue in the loss), (b) the architectural and initialization factors that make exploding gradients more likely, and (c) at least three concrete mitigation techniques, explaining the trade-offs of each. Be specific about the PyTorch API calls you would use.
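A minimal sketch combining three mitigations of the kind the question asks about, assuming a small LSTM model (`TinyLSTMRegressor`, `make_warmup_scheduler` and `train_step` are hypothetical names). It applies orthogonal initialization of the recurrent weights with `torch.nn.init.orthogonal_`, linear learning-rate warmup via `torch.optim.lr_scheduler.LambdaLR`, and global-norm clipping with `torch.nn.utils.clip_grad_norm_` placed between `backward()` and `optimizer.step()`. The `max_norm` and warmup length are placeholders, not recommendations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLSTMRegressor(nn.Module):
    """Hypothetical toy sequence model used only to illustrate the mitigations."""

    def __init__(self, d_in: int = 8, d_hidden: int = 32, num_layers: int = 2):
        super().__init__()
        self.lstm = nn.LSTM(d_in, d_hidden, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(d_hidden, 1)
        self.reset_recurrent_weights()

    def reset_recurrent_weights(self) -> None:
        for name, param in self.lstm.named_parameters():
            if "weight_hh" in name:
                # Gives the stacked (4*H x H) recurrent matrix orthonormal columns,
                # so all of its singular values are 1 at initialization.
                nn.init.orthogonal_(param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)          # out: (batch, time, H)
        return self.head(out[:, -1])   # predict from the last time step


def make_warmup_scheduler(optimizer, warmup_steps: int):
    """Linear warmup: the LR factor rises from 1/warmup_steps to 1, then stays at 1."""
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
    )


def train_step(model, optimizer, scheduler, x, y, max_norm: float = 1.0):
    optimizer.zero_grad(set_to_none=True)
    loss = F.mse_loss(model(x), y)
    loss.backward()
    # Rescales gradients in place so their global L2 norm is at most max_norm,
    # and returns the norm measured before clipping (worth logging).
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    if torch.isfinite(total_norm):
        optimizer.step()
        scheduler.step()
    # Otherwise the update is skipped so inf/NaN gradients never reach the optimizer state.
    return loss.item(), total_norm.item()
```

The trade-offs show up directly in this sketch. Clipping bounds the update size but also masks the symptom, which is why the pre-clip norm is returned for logging. Warmup slows early progress. Skipping non-finite steps keeps optimizer state (e.g. AdamW moment estimates) from being corrupted, but can hide a persistent bug if it fires often. Under mixed precision with `torch.cuda.amp.GradScaler`, gradients would also need `scaler.unscale_(optimizer)` before clipping so the threshold applies to the true gradient magnitudes.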