#!/usr/bin/env python3
"""
================================================================================
SECURITY REGRESSION TEST: CWE-78 - Command Injection
================================================================================

CWE: CWE-78 (OS Command Injection)
Location: torch/_inductor/compiler_bisector.py#L728 (near line 731)
Reference: https://cwe.mitre.org/data/definitions/78.html

PURPOSE:
This test verifies that the security fix for command injection in subprocess.run
is effective. The vulnerability occurs when run_cmd is user-controlled and could
contain shell metacharacters.

EXPECTED BEHAVIOR (if patched):
- run_cmd is a list, not a string
- subprocess.run does not use shell=True
- Command arguments are validated
- Environment variables are validated

VULNERABILITY INDICATOR (if NOT patched):
- run_cmd is a string and shell=True is used
- Shell metacharacters in run_cmd are not escaped
- Arbitrary code execution via run_cmd

================================================================================
"""

import sys
import os
import traceback

# Track test results: one (test name, status, details) tuple is appended by
# each test function, and run_all_tests() prints them in its summary.
test_results = []


def _collect_call_text(lines, start, marker='subprocess.run', max_lines=40):
    """
    Return the source text of the call that starts on ``lines[start]``.

    The text begins at ``marker`` and extends over following lines until the
    parenthesis opened by the call is closed again (or ``max_lines`` lines
    have been consumed). Parentheses are counted textually, so parentheses
    inside string literals or comments can end the span early or late; this
    is a heuristic, not a parser.
    """
    head = lines[start][lines[start].index(marker):]
    if '(' not in head:
        # Not a call on this line (e.g. a mention in a comment or an alias);
        # do not drag unrelated following lines into the inspected text.
        return head

    pieces = []
    depth = 0
    for offset, raw in enumerate(lines[start:start + max_lines]):
        text = head if offset == 0 else raw
        pieces.append(text)
        for ch in text:
            if ch == '(':
                depth += 1
            elif ch == ')' and depth > 0:
                depth -= 1
                if depth == 0:
                    return ''.join(pieces)
    return ''.join(pieces)


def test_1_subprocess_run_safety():
    """
    Test: Verify subprocess.run is called safely

    Looks for the first ``subprocess.run`` occurrence on source lines 726-732
    of compiler_bisector.py and inspects the *whole* call (including
    continuation lines of a multi-line call) for a ``shell=`` argument.

    Returns:
        True when the call does not pass ``shell=True``; False when it does,
        when shell usage cannot be determined, when the call is not found in
        the expected window, or when the file cannot be read. A matching
        (name, status, details) tuple is appended to ``test_results``.
    """
    print("\n" + "="*70)
    print("TEST 1: Subprocess.run Safety")
    print("="*70)

    try:
        # Explicit encoding so the result does not depend on the locale.
        with open('/pytorch/torch/_inductor/compiler_bisector.py', 'r',
                  encoding='utf-8') as f:
            lines = f.readlines()

        # Check line 728
        target_line = None
        line_num = None
        call_text = ''
        for i in range(725, 732):
            if i < len(lines) and 'subprocess.run' in lines[i]:
                target_line = lines[i]
                line_num = i + 1
                call_text = _collect_call_text(lines, i)
                break

        if target_line:
            print(f"Line {line_num}: {target_line.strip()}")

            # Drop all whitespace so "shell = True" and arguments placed on
            # continuation lines are matched the same as "shell=True".
            compact_call = ''.join(call_text.split())

            # Check if shell=True is used
            if 'shell=True' in compact_call:
                print("[FAIL] Uses shell=True - vulnerable")
                test_results.append(("Subprocess Safety", "FAIL", "shell=True"))
                return False
            elif 'shell=False' in compact_call or 'shell=' not in compact_call:
                print("[PASS] Does not use shell=True")
                test_results.append(("Subprocess Safety", "PASS", "No shell=True"))
                return True
            else:
                # shell= is present but bound to something other than a
                # literal True/False (e.g. a variable).
                print("[WARN] Could not determine shell usage")
                test_results.append(("Subprocess Safety", "WARN", "Unclear"))
                return False
        else:
            print("[WARN] subprocess.run not found at expected line")
            test_results.append(("Subprocess Safety", "WARN", "Not found"))
            return False

    except Exception as e:
        print(f"[ERROR] Test failed: {e}")
        traceback.print_exc()
        test_results.append(("Subprocess Safety", "ERROR", str(e)))
        return False


def test_2_run_cmd_construction():
    """
    Test: Verify run_cmd is constructed as a list

    Collects the lines mentioning ``run_cmd`` while walking backwards from
    index 728 down to index 651 of compiler_bisector.py, then inspects the
    five references nearest to the call. The nearest reference that is
    conclusive decides the outcome, so a string reassignment close to the
    call is not masked by an older list assignment further up.

    Returns:
        True when the deciding reference builds a list (``run_cmd = [`` or
        ``run_cmd.append`` / ``run_cmd.extend``); False when it assigns a
        string literal, when nothing conclusive is found, when no reference
        exists, or when the file cannot be read. A matching
        (name, status, details) tuple is appended to ``test_results``.
    """
    print("\n" + "="*70)
    print("TEST 2: run_cmd Construction")
    print("="*70)

    try:
        # Explicit encoding so the result does not depend on the locale.
        with open('/pytorch/torch/_inductor/compiler_bisector.py', 'r',
                  encoding='utf-8') as f:
            content = f.read()

        # Search for run_cmd construction
        lines = content.split('\n')

        # Look backwards from line 728 for run_cmd definition
        run_cmd_lines = [
            (i + 1, lines[i])
            for i in range(728, 650, -1)
            if i < len(lines) and 'run_cmd' in lines[i]
        ]

        if not run_cmd_lines:
            print("[WARN] run_cmd construction not found")
            test_results.append(("run_cmd Construction", "WARN", "Not found"))
            return False

        print(f"Found {len(run_cmd_lines)} run_cmd references")

        # Only the five references nearest to the call are inspected.
        for line_num, line in run_cmd_lines[:5]:
            print(f"Line {line_num}: {line.strip()}")

            if 'run_cmd = [' in line:
                print(f"  [PASS] run_cmd is a list at line {line_num}")
                test_results.append(("run_cmd Construction", "PASS", "List form"))
                return True

            if 'run_cmd.append' in line or 'run_cmd.extend' in line:
                print("  [PASS] run_cmd built as list")
                test_results.append(("run_cmd Construction", "PASS", "List building"))
                return True

            # A string literal on the right-hand side, with or without a
            # literal prefix such as f"..." or r'...', means string form.
            _, assigned, rhs = line.partition('run_cmd = ')
            if assigned and rhs.lstrip('fFrRbBuU')[:1] in ('"', "'"):
                print(f"  [FAIL] run_cmd is a string at line {line_num}")
                test_results.append(("run_cmd Construction", "FAIL", "String form"))
                return False

        print("[WARN] Could not determine run_cmd type")
        test_results.append(("run_cmd Construction", "WARN", "Unclear"))
        return False

    except Exception as e:
        print(f"[ERROR] Test failed: {e}")
        traceback.print_exc()
        test_results.append(("run_cmd Construction", "ERROR", str(e)))
        return False


def test_3_environment_variable_safety():
    """
    Test: Verify environment variables are handled safely

    Joins source lines at indices 710-734 of compiler_bisector.py and looks
    for an ``env = `` assignment in that window. Without one the check is
    recorded as INFO and counts as passing; with one it passes only if the
    window also contains a ``"TORCH_BISECT_`` string.

    Returns:
        True or False as described above; False if the file cannot be read.
        A (name, status, details) tuple is appended to ``test_results``.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST 3: Environment Variable Safety")
    print(banner)

    try:
        with open('/pytorch/torch/_inductor/compiler_bisector.py', 'r') as f:
            source_lines = f.readlines()

        # Context window around line 728
        window = ''.join(source_lines[710:735])

        # No explicit env assignment nearby: informational, treated as OK
        if 'env = ' not in window:
            print("[INFO] No explicit env parameter")
            test_results.append(("Environment Safety", "INFO", "No env param"))
            return True

        print("[PASS] Environment variables are explicitly set")

        # env is assigned but no controlled variable name shows up
        if '"TORCH_BISECT_' not in window:
            print("[WARN] Environment variables may need validation")
            test_results.append(("Environment Safety", "WARN", "May need validation"))
            return False

        print("[PASS] Uses controlled environment variable names")
        test_results.append(("Environment Safety", "PASS", "Controlled vars"))
        return True

    except Exception as exc:
        print(f"[ERROR] Test failed: {exc}")
        traceback.print_exc()
        test_results.append(("Environment Safety", "ERROR", str(exc)))
        return False


def run_all_tests():
    """
    Run all regression tests for this security fix

    Executes each test function in order, counting a truthy return as passed
    and a falsy return or an escaped exception as failed, then prints a
    summary of the results recorded during this invocation.

    Returns:
        "PATCHED" when no test failed, otherwise "UNPATCHED".
    """
    print("="*70)
    print("SECURITY REGRESSION TEST: CWE-78 Command Injection")
    print("Verifying fix for: subprocess.run (compiler_bisector.py:728)")
    print("="*70)

    tests = [
        test_1_subprocess_run_safety,
        test_2_run_cmd_construction,
        test_3_environment_variable_safety,
    ]

    # test_results is module-level and never reset; remember where this run
    # starts so the summary does not repeat entries from an earlier call.
    first_new_result = len(test_results)

    passed = 0
    failed = 0

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except Exception as e:
            # The test functions catch their own errors; this guards against
            # anything that still escapes so the remaining tests run.
            print(f"[ERROR] Test crashed: {e}")
            traceback.print_exc()
            failed += 1

    # Summary
    print("\n" + "="*70)
    print("TEST SUMMARY")
    print("="*70)

    for name, status, details in test_results[first_new_result:]:
        print(f"{status}: {name} - {details}")

    print(f"\nPassed: {passed}, Failed: {failed}")

    # Final verdict
    if failed == 0:
        print("\nResult: PATCHED (all security tests passed)")
        return "PATCHED"
    else:
        print("\nResult: UNPATCHED (security fix may be incomplete)")
        return "UNPATCHED"


if __name__ == "__main__":
    # Script entry point: exit status 0 only when every check passed.
    result = run_all_tests()
    print(f"\nFinal Status: {result}")
    exit_code = 1 if result != "PATCHED" else 0
    sys.exit(exit_code)
