-
Story
-
Resolution: Unresolved
-
Undefined
-
None
-
None
-
None
-
-
- 🐛 Describe the bug
-
Reproduce example:
```python
import torch
def f(*args):
sym_0, sym_1, sym_2 = args
var_485 = torch.ones(sym_0)
return torch.unsafe_split(var_485, split_size=sym_1, dim=sym_2)
res = f((0,), 0, -1,)
print(res) # (tensor([]), )
res = torch.compile(f)((0,), 0, -1,) # ZeroDivisionError
print(res)
```
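The behavior differs between eager mode and Inductor: eager mode returns `(tensor([]),)`, while the compiled function raises `ZeroDivisionError`.
Also, if `split_size` is larger than the size of the input along `dim`, eager mode does not treat it as an error, but Inductor fails with a `list assignment index out of range` error. More range checks on the input parameters are needed.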
-
-
- Error logs
-
```
Traceback (most recent call last):
File "/home/yvesw/reborn2-expr/250216-bugs/test-5.py", line 12, in <module>
res = torch.compile(f)((0,), 0, -1,)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 752, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 737, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1402, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1057, in codegen_and_compile
graph.run(*example_inputs)
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/graph.py", line 851, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1436, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/fx/interpreter.py", line 236, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1139, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1129, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 462, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 1798, in split
FloorDiv(x_size + sizes - 1, sizes)
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/sympy/core/function.py", line 466, in __new__
result = super().__new__(cls, *args, **options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/sympy/core/function.py", line 307, in __new__
evaluated = cls.eval(*args)
^^^^^^^^^^^^^^^
File "/home/yvesw/miniconda3/envs/torch-preview/lib/python3.11/site-packages/torch/utils/_sympy/functions.py", line 223, in eval
raise ZeroDivisionError("division by zero")
torch._inductor.exc.InductorError: LoweringException: ZeroDivisionError: division by zero
target: aten.split.Tensor
args[0]: TensorBox(StorageBox(
Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
i0 = index
tmp0 = ops.constant(1, torch.float32)
return tmp0
,
ranges=[0],
origin_node=full_default,
origins=OrderedSet([full_default])
)
))
args[1]: 0
args[2]: -1
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
-
-
- Versions
-
PyTorch 2.7.0.dev20250209+cu124
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov