Feature
Resolution: Done

*Description*
There’s an incorrect code example in the docstring of:
*File:* `torch/onnx/_internal/exporter/_torchlib/ops/nn.py`
*Location:* Around line 123

*Current (Broken) Example*
```python
attn_weight = torch.softmax(
(Q @ K.transpose(-2, -1) * attn_mask, dim=-1
)
```
This example does not parse (the parenthesis opened before `Q` is never closed), omits the `1 / sqrt(d_k)` scale factor, and multiplies the scores by the attention mask instead of adding it.

*Suggested Fix*
The example should follow the correct scaled dot-product attention formula used in PyTorch’s `MultiheadAttention`.
```python
import math

# Corrected version
scale_factor = 1.0 / math.sqrt(Q.size(-1))
attn_weight = torch.softmax(
    (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
)
```
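As a quick sanity check (a sketch, not part of the docstring: the tensor shapes and the comparison against `torch.nn.functional.scaled_dot_product_attention` are my own; `Q`/`K`/`V` are made-up `(batch, heads, seq, head_dim)` tensors and `attn_mask` is an additive float mask), the corrected formula should agree with PyTorch's built-in op:
```python
import math

import torch
import torch.nn.functional as F

# Made-up shapes for illustration: (batch, heads, seq_len, head_dim)
Q = torch.randn(2, 4, 8, 16)
K = torch.randn(2, 4, 8, 16)
V = torch.randn(2, 4, 8, 16)

# Additive float mask: 0.0 keeps a key position, -inf hides it.
attn_mask = torch.zeros(8, 8)
attn_mask[:, 4:] = float("-inf")

scale_factor = 1.0 / math.sqrt(Q.size(-1))
attn_weight = torch.softmax(
    (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
)
manual = attn_weight @ V

builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True
```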

*Explanation*
1. *Adds missing scale factor*
The dot product between Q and K should be scaled by `1 / sqrt(d_k)` to prevent large values before softmax.
2. *Fixes syntax error*
The original snippet never closes the parenthesis opened before `Q`, so it does not parse; closing it after the scaled scores lets `dim=-1` be passed to `torch.softmax()` as a keyword argument.
3. *Corrects mask usage*
The mask should be *added*, not multiplied.
Adding large negative values (e.g., `-inf`) drives the masked logits to `-inf`, so those positions get exactly zero probability after softmax (a short demo follows this list).
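A minimal demo of point 3, using made-up logits (the names `scores`, `additive_mask`, and `keep_mask` are illustrative, not from the docstring):
```python
import torch

# Made-up attention logits for one query over four keys.
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])

# Additive mask: 0.0 keeps a key, -inf hides it.
additive_mask = torch.tensor([[0.0, 0.0, float("-inf"), float("-inf")]])
print(torch.softmax(scores + additive_mask, dim=-1))
# ~[[0.73, 0.27, 0.00, 0.00]] -- hidden keys get exactly zero weight

# Multiplying by a 0/1 keep-mask instead only zeroes the logits,
# so the hidden keys still receive non-zero probability after softmax.
keep_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(torch.softmax(scores * keep_mask, dim=-1))
# ~[[0.61, 0.22, 0.08, 0.08]]
```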
cc @justinchuby @titaiwangms