Imagine this: you're training a language model, everything is set up, the system is running, and suddenly strange numbers start appearing. No errors, no crashes, just a silent divergence in probabilities that nothing obvious can explain. This is exactly how one of the most instructive debugging episodes began for the team at AI21, a company that develops its own language models.
«Same But Different»
When training language models with GRPO (Group Relative Policy Optimization), a reinforcement learning method, there's a vital check. The system generates text and records the «confidence» (probability) with which it chose each token. Then the exact same model, with the same weights and no updates in between, recalculates those values from scratch. The results should match almost perfectly: same weights, same input, same output.
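The check itself is simple to express. Below is a minimal Python sketch, assuming both sides report per-token log-probabilities (a common representation of this «confidence») for the same sampled tokens. The function name `logprob_mismatch` and the `1e-3` tolerance are hypothetical illustration choices, not AI21's code.

```python
def logprob_mismatch(sampler_logprobs, trainer_logprobs, tol=1e-3):
    """Compare per-token log-probabilities recorded by the generation engine
    with the ones recomputed by the training side (same weights, same tokens).

    Returns the largest absolute difference and whether it exceeds `tol`.
    The tolerance value is an illustrative assumption.
    """
    if len(sampler_logprobs) != len(trainer_logprobs):
        raise ValueError("both sides must score the same tokens")
    max_diff = max(
        (abs(a - b) for a, b in zip(sampler_logprobs, trainer_logprobs)),
        default=0.0,  # no tokens means nothing to disagree about
    )
    return max_diff, max_diff > tol


# Healthy case: only tiny floating-point noise between the two passes.
print(logprob_mismatch([-0.105, -2.302, -0.693], [-0.1050004, -2.3020001, -0.6929998]))
# Divergent case: one token's log-probability differs noticeably.
print(logprob_mismatch([-0.105, -2.302, -0.693], [-0.105, -0.916, -0.693]))
```

In a real pipeline this comparison would run over every sampled sequence in a batch; a periodic spike in the flagged results is the kind of symptom described next.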
But in the case of Jamba 3B, AI21's hybrid model that combines standard attention layers with Mamba layers, the values did not match. Moreover, they didn't diverge randomly but periodically: the glitch appeared roughly every 12 training steps, «disappeared», and then came back.
The nastiest part was that, from the outside, it looked like typical training instability. Such things happen, and it would have been very easy to write it all off as «noise.»
Finding the Lever
Distributed training systems are complex beasts. They simultaneously run an engine for text generation, a system for updating weights, coordination between multiple machines, and data transfer between components. When something breaks in such a system, it's tempting to start debugging everything at once, which usually leads to a dead end.
The team chose a different path: find a parameter whose change alters the nature of the failure, not just its intensity. Simply put – find a lever.
That lever turned out to be the number of generated texts per request. When researchers started increasing this number – 8, 16, 32, 64, 128 – they noticed something crucial: the periodicity of the glitches changed along with this parameter. At 128 texts per request, the glitch appeared on the very first step.
This observation changed the whole picture. If the failure is synchronized with the generation process, it means the problem is likely right there – rather than in the weight update system, machine synchronization, or the training algorithm itself.
From «Training Oddity» to «Reproducible Defect»
The next goal was to reproduce the bug in the simplest conditions possible. Ideally – at step zero, before training even began. This is important: when an error appears from the very first iteration, all accumulated effects drop out of the equation – gradient history, weight drift, and long-term training dynamics. All that's left is code execution.
With 128 texts per request, they succeeded. The bug was reproduced on the first step, reliably and consistently.
Another parameter narrowed the list of suspects further: the share of GPU memory allocated for the cache. At 50%, the bug vanished; at higher values, it reappeared. This pointed to how the inference engine (the text generation part) allocates and uses its cache in GPU memory.
Since Jamba is a hybrid architecture, it has two types of cache: one for the attention layers and another for the Mamba blocks. When the model was tested with attention only, the bug didn't reproduce. That put the culprit in the Mamba part.
Two Characters and Several Weeks
The source of the problem turned out to be low-level code that runs directly on the GPU, a kernel. One of its operations computes exactly where in memory to write the state for each cache element. That position is the product of two numbers: the element's index and the «stride», the distance in memory from one element's state to the next.
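Both numbers were declared as 32-bit unsigned integers, a type whose maximum value is 4,294,967,295 (2³² − 1, roughly 4.29 billion). That sounds like plenty, but the product of these two numbers could easily exceed that limit. When it does, unsigned arithmetic raises no error: the result silently wraps around, keeping only the remainder after division by 2³².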
Simply put: imagine an odometer that only goes up to 999,999 km and then resets to zero. The car drives a million kilometers – the odometer shows 000,000. No error, no warning. Just the wrong number.
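The odometer effect is easy to reproduce outside the GPU. The Python sketch below emulates fixed-width unsigned multiplication by keeping only the exact product modulo 2^bits, which is how unsigned integer arithmetic behaves in C and CUDA. The helper name `mul_unsigned` is hypothetical; the stride of 89,600 is the value from the real configuration described in the next paragraph.

```python
def mul_unsigned(a, b, bits):
    """Multiply two non-negative integers the way a `bits`-wide unsigned
    integer would: the exact product modulo 2**bits, with no error raised."""
    return (a * b) % (1 << bits)


STRIDE = 89_600  # stride in elements, taken from the article

for index in (47_934, 47_935):
    exact = index * STRIDE
    as_u32 = mul_unsigned(index, STRIDE, 32)  # what the buggy kernel computed
    as_u64 = mul_unsigned(index, STRIDE, 64)  # what a 64-bit type computes
    print(f"index={index}: exact={exact}, uint32={as_u32}, uint64={as_u64}")
# → index=47934: exact=4294886400, uint32=4294886400, uint64=4294886400
# → index=47935: exact=4294976000, uint32=8704, uint64=4294976000
```

Index 47,934 still fits, but index 47,935 wraps from 4,294,976,000 to a small offset of 8,704 without any error, while the 64-bit product stays exact.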
That's exactly what was happening. With a stride of 89,600 elements, the product first overflowed at element index 47,935 (47,935 × 89,600 = 4,294,976,000, just past the limit). In the actual configuration the cache contained 69,776 slots, so every slot from index 47,935 upward, about 31% of them, was being written to the wrong address in GPU memory. Data was going «nowhere», while the «correct» spots remained zeros.
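The figures in this paragraph follow directly from two inputs, the stride of 89,600 and the 69,776 slots. A short Python check, assuming offsets are computed as `index * stride` with slot indices running from 0 to `slots - 1` (the function name `overflow_stats` is hypothetical):

```python
U32_MAX = 2**32 - 1  # largest value a 32-bit unsigned integer can hold


def overflow_stats(stride, slots):
    """Return the first slot index whose offset (index * stride) no longer
    fits in 32 bits, how many slots are affected, and their fraction."""
    first_bad = U32_MAX // stride + 1      # smallest index with index*stride > U32_MAX
    bad_slots = max(0, slots - first_bad)  # indices first_bad .. slots-1
    return first_bad, bad_slots, bad_slots / slots


first_bad, bad, frac = overflow_stats(stride=89_600, slots=69_776)
print(first_bad, bad, f"{frac:.1%}")  # → 47935 21841 31.3%
```

Indices 0 through 47,934 are safe; the remaining 21,841 of 69,776 slots wrap. This is also consistent with the memory lever: a smaller cache means fewer slots, and with 47,935 slots or fewer no index ever overflows.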
No system crash. No warning. Just silently incorrect numbers at the output, which were then interpreted as training instability.
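The fix took two seconds: the data type was changed from 32-bit to 64-bit. The largest offset in this configuration, 69,775 × 89,600 ≈ 6.25 billion, fits comfortably in 64 bits, so nothing overflows anymore. That's it.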
Weeks of investigation. Two changed characters in the code.
Why It's Hard to Catch
Such bugs are particularly insidious in machine learning systems for several reasons.
First, they don't break the system explicitly. There's no exception, no invalid data format, no obvious signal. There are just slightly different numbers, which could easily be the result of a hundred other causes.
Second, they only manifest at a certain scale. With a small number of requests, the cache never filled up to the slot indices where the product overflows, so nothing went wrong. You had to deliberately «pressurize» the system for the bug to show itself at all.
Third, in distributed training systems, the symptom and the cause can be very far apart. An error in the GPU kernel during the text generation stage looked like instability during the training stage – in a completely different component of the system.
A Lesson Not About the Bug, But the Approach
The most valuable thing in this story isn't the defect itself, but the method of its discovery.
When a system behaves strangely, the first instinct is to check everything at once. This rarely helps. It's much more effective to look for a parameter that changes the structure of the failure. Not «at what value does the error get bigger», but «at what value does it start behaving differently.» This is what provides a clue about the nature of the problem.
The second principle is to narrow it down to the minimum. If a bug can be reproduced on the first step instead of the five-hundredth, it's worth the effort to get it there. The smaller the context, the cleaner the signal.
The third is to isolate subsystems. A large distributed pipeline is a bad place for debugging. A good place is a minimal script that reproduces the problem in isolation.
These principles apply not just to GPU kernels and not just to machine learning. It's simply good engineering practice: when you don't understand what broke, first understand where it broke. And for that, you need to know how to ask the system questions and correctly interpret the answers.