AI Platform Core Components / AIPCC-5814

The `input` argument of `nan_to_num()` accepts `complex` tensors, but the `nan`, `posinf` and `neginf` arguments don't accept `complex` values

    • Type: Story
    • Resolution: Unresolved
    • Priority: Undefined
    • Component: PyTorch
    • Sprint: PyTorch Sprint 17

      ### 🐛 Describe the bug

      The `input` argument of [nan_to_num()](https://pytorch.org/docs/stable/generated/torch.nan_to_num.html) accepts `complex` tensors, but the `nan`, `posinf` and `neginf` arguments reject `complex` values, even though they accept `int`, `float` and `bool` values, as shown below:

      ```python
      import torch

      my_tensor = torch.tensor([complex('nan'), complex('inf-infj'), complex('-inf+infj'), 4.+0.j])

      # int replacement values
      torch.nan_to_num(input=my_tensor, nan=1, posinf=2, neginf=3)
      # tensor([1.+0.j, 2.+3.j, 3.+2.j, 4.+0.j])

      # float replacement values
      torch.nan_to_num(input=my_tensor, nan=1., posinf=2., neginf=3.)
      # tensor([1.+0.j, 2.+3.j, 3.+2.j, 4.+0.j])

      # bool replacement values
      torch.nan_to_num(input=my_tensor, nan=True, posinf=False, neginf=True)
      # tensor([1.+0.j, 0.+1.j, 1.+0.j, 4.+0.j])

      # complex replacement values
      torch.nan_to_num(input=my_tensor, nan=1.+0.j, posinf=2.+0.j, neginf=3.+0.j) # Error
      ```

      > TypeError: nan_to_num(): argument 'nan' must be float, not complex

      > TypeError: nan_to_num(): argument 'posinf' must be float, not complex

      > TypeError: nan_to_num(): argument 'neginf' must be float, not complex
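The outputs above are consistent with `nan_to_num()` handling the real and imaginary parts of each complex element independently, using one real replacement value for both parts (NumPy documents the same per-component rule for `numpy.nan_to_num`). The pure-Python sketch below models that behaviour so the expected results can be checked without PyTorch. `nan_to_num_model` is a hypothetical helper, not a PyTorch API; its handling of `complex` replacement values (real part for the real component, imaginary part for the imaginary component) is only one possible interpretation of what this report asks for, and the defaults assume float64.

```python
import math
import sys


def _split(value):
    # Hypothetical rule: a complex replacement supplies separate (real, imag) values;
    # a real replacement (int, float or bool) is used for both components,
    # which matches the tensors printed in the report.
    if isinstance(value, complex):
        return value.real, value.imag
    return float(value), float(value)


def _replace(x, nan, posinf, neginf):
    # Replace a single real component if it is NaN or infinite.
    if math.isnan(x):
        return nan
    if x == math.inf:
        return posinf
    if x == -math.inf:
        return neginf
    return x


def nan_to_num_model(values, nan=0.0, posinf=None, neginf=None):
    # Defaults mirror the documented behaviour, assuming float64 components.
    if posinf is None:
        posinf = sys.float_info.max
    if neginf is None:
        neginf = -sys.float_info.max
    nan_re, nan_im = _split(nan)
    pos_re, pos_im = _split(posinf)
    neg_re, neg_im = _split(neginf)
    result = []
    for z in values:
        z = complex(z)
        re = _replace(z.real, nan_re, pos_re, neg_re)
        im = _replace(z.imag, nan_im, pos_im, neg_im)
        result.append(complex(re, im))
    return result
```

With real replacement values the model reproduces the tensors printed in the report. As an untested workaround in PyTorch itself, one could apply `torch.nan_to_num` to `my_tensor.real` and `my_tensor.imag` separately and recombine the results with `torch.complex`, passing the real and imaginary parts of the desired complex replacement to each call.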

      In addition, the `input` argument of `nan_to_num()` accepts `int`, `float` and `bool` tensors, as shown below:

      ```python
      import torch

      my_tensor = torch.tensor([1, 2, 3, 4])

      torch.nan_to_num(input=my_tensor)
      # tensor([1, 2, 3, 4])

      my_tensor = torch.tensor([1., 2., 3., 4.])

      torch.nan_to_num(input=my_tensor)
      # tensor([1., 2., 3., 4.])

      my_tensor = torch.tensor([True, False, True, False])

      torch.nan_to_num(input=my_tensor)
      # tensor([True, False, True, False])
      ```
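The real-valued cases can be modelled the same way. PyTorch documents that, by default, NaN is replaced with zero and positive/negative infinity with the greatest/least finite value representable by the input's dtype. The sketch below assumes float64 for those defaults (PyTorch's default float dtype is float32, so the exact limits differ there), and `nan_to_num_real_model` is a hypothetical name, not a PyTorch function. Integer and boolean values can never be NaN or infinite, which is why those tensors come back unchanged.

```python
import math
import sys


def nan_to_num_real_model(values, nan=0.0, posinf=None, neginf=None):
    # Defaults: greatest/least finite float64 values.
    if posinf is None:
        posinf = sys.float_info.max
    if neginf is None:
        neginf = -sys.float_info.max
    out = []
    for x in values:
        # int and bool elements are never NaN or infinite, so they pass through as-is.
        if isinstance(x, float):
            if math.isnan(x):
                x = nan
            elif x == math.inf:
                x = posinf
            elif x == -math.inf:
                x = neginf
        out.append(x)
    return out
```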

      ### Versions

      ```python
      import torch

      torch.__version__ # 2.3.1+cu121
      ```

      cc @ezyang @anjali411 @dylanbespalko @mruberry @Lezcano @nikitaved @amjames

              rh-ee-chleonar Christopher Leonard
              PyTorch Core