AI Platform Core Components / AIPCC-7415

Finish pytree flattening path in create_native_op_schema

    • Type: Bug
    • Resolution: Unresolved
    • Priority: Undefined
    • Component: PyTorch
    • Sprint: PyTorch Sprint 20

          1. 🐛 Describe the bug

      ```
      static std::optional<std::pair<NativeOpSchema, /*ComputeMesh*/ py::object>>
      create_native_op_schema(
          const c10::OperatorHandle& op,
          py::handle py_op,
          torch::jit::Stack* stack) {
        // fused schema part of unwrap_to_op_info + recompute_comparison_key,
        // operating on IValues instead of Python stuff.

        py::object runtime_schema_info = get_runtime_schema_info_for_op(py_op);
        if (runtime_schema_info &&
            checked_istrue(py::handle(runtime_schema_info)
                               .attr(dtensor_interned_strings.needs_pytree)
                               .ptr())) {
          // Punting on pytree flattening in the fast path on IValues for
          // now since only a minority of ops need it.
          return std::nullopt;
        }
      ```

      Here, pytree isn't handled, which means we have to fall back to the legacy opinfo computation, which uses pytree. We should implement this directly in C++. In particular, a small but very useful increment of progress would be to update the handling below to support List[Tensor] arguments, which would cover hot ops like foreach and stack; a sketch of that case follows.
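
      A minimal sketch of what the List[Tensor] case could look like on the IValue side, assuming a hypothetical helper flatten_ivalue_tensors (the name and surrounding plumbing are illustrative, not the actual PyTorch code):

      ```
      // Hypothetical sketch only -- not the actual PyTorch implementation.
      #include <ATen/core/List.h>
      #include <ATen/core/Tensor.h>
      #include <ATen/core/ivalue.h>
      #include <vector>

      // Collect the Tensors contained in one IValue argument, covering the
      // plain Tensor and List[Tensor] cases that hot ops like stack and the
      // _foreach_* family hit.
      static void flatten_ivalue_tensors(
          const c10::IValue& iv,
          std::vector<at::Tensor>& out) {
        if (iv.isTensor()) {
          out.push_back(iv.toTensor());
        } else if (iv.isTensorList()) {
          // List[Tensor] argument, e.g. torch.stack(tensors) or a foreach op.
          const c10::List<at::Tensor> tensors = iv.toTensorList();
          for (size_t i = 0; i < tensors.size(); ++i) {
            out.push_back(tensors.get(i));
          }
        }
        // Other pytree-able containers (tuples, dicts, optionals) would still
        // need the full pytree treatment; this sketch punts on those, matching
        // the incremental scope suggested above.
      }
      ```

      In the fast path above, a helper along these lines would replace the early `return std::nullopt;` for the List[Tensor] case, keeping those ops on the IValue path instead of falling back to the Python opinfo computation.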

      If you're doing the full thing with an LLM, make sure to feed it the reference Python pytree implementation.

          2. Versions

      main

      cc @wanchaol @tianyu-l @wz337 @XilunWu @d4l3k @pragupta @SherlockNoMad

              Assignee: Arkadip Maitra (rh-ee-amaitra)
              Reporter: Arkadip Maitra (rh-ee-amaitra)
              Team: PyTorch Distributed
              Votes: 0
              Watchers: 2