AI Platform Core Components
AIPCC-7401

[docs] Incorrect attention example in torch/onnx/_internal/exporter/_torchlib/ops/nn.py (syntax + scaling + mask issue)

    • Type: Feature
    • Resolution: Done

          1. Description

      There’s an incorrect code example in the docstring of:

      *File:* `torch/onnx/_internal/exporter/_torchlib/ops/nn.py`
      *Location:* Around line 123

          2. Current (Broken) Example

      ```python
      attn_weight = torch.softmax(
      (Q @ K.transpose(-2, -1) * attn_mask, dim=-1
      )
      ```

      This example has a syntax error (an unbalanced parenthesis and a misplaced comma) and also applies the attention mask incorrectly.

          3. Suggested Fix

      The example should follow the correct scaled dot-product attention formula used in PyTorch’s `MultiheadAttention`.

      ```python
      # Corrected version
      import math

      scale_factor = 1.0 / math.sqrt(Q.size(-1))
      attn_weight = torch.softmax(
          (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
      )
      ```
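
      As a sanity check, the corrected formula can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which implements the same computation. The tensor shapes and the zero (no-op) additive mask below are illustrative, not taken from the docstring:

      ```python
      import math

      import torch
      import torch.nn.functional as F

      torch.manual_seed(0)
      Q = torch.randn(2, 4, 8)  # (batch, seq_len, d_k)
      K = torch.randn(2, 4, 8)
      V = torch.randn(2, 4, 8)
      attn_mask = torch.zeros(2, 4, 4)  # additive mask; 0.0 means "not masked"

      # Corrected docstring formula
      scale_factor = 1.0 / math.sqrt(Q.size(-1))
      attn_weight = torch.softmax(
          (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
      )
      out = attn_weight @ V

      # Built-in reference implementation
      ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
      assert torch.allclose(out, ref, atol=1e-5)
      ```

      The two results agree because `scaled_dot_product_attention` also scales by `1 / sqrt(d_k)` and treats a floating-point `attn_mask` as additive.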

          4. Explanation

      1. *Adds missing scale factor*
      The dot product between Q and K should be scaled by `1 / sqrt(d_k)` to prevent large values before softmax.

      2. *Fixes syntax error*
      Removes the misplaced comma before `dim=-1`, ensuring `torch.softmax()` receives proper arguments.

      3. *Corrects mask usage*
      The mask should be *added*, not multiplied.
      Adding large negative values (e.g., `-inf`) ensures masked positions have zero probability after softmax.
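
      A minimal demonstration of point 3 (the values here are made up for illustration): adding `-inf` at a masked position drives its softmax weight to exactly zero, which multiplying the scores by a 0/1 mask would not do:

      ```python
      import torch

      scores = torch.tensor([1.0, 2.0, 3.0])
      mask = torch.tensor([0.0, float("-inf"), 0.0])  # mask out position 1

      weights = torch.softmax(scores + mask, dim=-1)
      # exp(-inf) == 0, so the masked position gets zero probability
      assert weights[1].item() == 0.0
      # the remaining probability mass still sums to 1
      assert abs(weights.sum().item() - 1.0) < 1e-6
      ```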

      cc @justinchuby @titaiwangms

              rh-ee-ggoswami Gaurav Goswami
              Votes: 0
              Watchers: 1


                  Estimated: 2m
                  Remaining: 2m
                  Logged: Not Specified