- Bug
- Resolution: Unresolved
- 🐛 Describe the bug
```cpp
static std::optional<std::pair<NativeOpSchema, /*ComputeMesh*/ py::object>>
create_native_op_schema(
const c10::OperatorHandle& op,
py::handle py_op,
torch::jit::Stack* stack) {
// fused schema part of unwrap_to_op_info + recompute_comparison_key,
// operating on IValues instead of Python stuff.
py::object runtime_schema_info = get_runtime_schema_info_for_op(py_op);
if (runtime_schema_info &&
checked_istrue(py::handle(runtime_schema_info)
.attr(dtensor_interned_strings.needs_pytree)
.ptr()))
```
Here, pytree isn't handled, which means we fall back to the legacy opinfo computation in Python using pytree. We should implement this directly in C++. In particular, a small but very useful increment of progress would be to extend this handling to support List[Tensor] arguments, which would cover hot ops like the foreach family and stack (see the sketch below).
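As a rough illustration, here is a minimal sketch of the List[Tensor] case, assuming a hypothetical `flatten_tensor_args` helper. `c10::IValue::isTensorList` and `toTensorVector` are real ATen APIs, but this is not the actual PyTorch dispatch code:
```cpp
// Hypothetical sketch, not the actual PyTorch implementation: flatten the
// Tensor arguments out of a dispatcher stack in C++, handling List[Tensor]
// so ops like aten::stack and the _foreach_* family don't have to fall
// back to the Python pytree path. (torch::jit::Stack is a
// std::vector<c10::IValue>, so we take that type directly.)
#include <vector>

#include <ATen/core/ivalue.h>

std::vector<at::Tensor> flatten_tensor_args(
    const std::vector<c10::IValue>& stack) {
  std::vector<at::Tensor> flat;
  for (const c10::IValue& iv : stack) {
    if (iv.isTensor()) {
      flat.push_back(iv.toTensor());
    } else if (iv.isTensorList()) {
      // The increment proposed above: descend one level into
      // List[Tensor] arguments instead of punting to Python.
      std::vector<at::Tensor> tensors = iv.toTensorVector();
      flat.insert(flat.end(), tensors.begin(), tensors.end());
    }
    // Scalars, None, dims, etc. carry no Tensors and are skipped.
  }
  return flat;
}
```
Note this only descends one level, which is all that foreach and stack need; arbitrarily nested structures would still require a full C++ port of the Python pytree logic.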
If you attempt the full implementation, and you're using an LLM to do it, make sure to feed it the reference Python pytree implementation.
- Versions
main
cc @wanchaol @tianyu-l @wz337 @XilunWu @d4l3k @pragupta @SherlockNoMad