#!/usr/bin/env python3
"""
================================================================================
SECURITY REGRESSION TEST: CWE-78 - Command Injection
================================================================================

CVE/CWE: CWE-78 (OS Command Injection)
Location: scripts/release_notes/commitlist.py#L386
Reference: https://cwe.mitre.org/data/definitions/78.html

PURPOSE:
This test verifies that the security fix for command injection in git command
execution is effective. The vulnerability occurs when git log/merge-base
parameters are user-controlled and could contain shell metacharacters.

EXPECTED BEHAVIOR (if patched):
- Git commands use list form or properly escaped strings
- Version parameters are validated as valid git refs
- Shell metacharacters in refs are rejected or escaped
- No shell=True in subprocess calls

VULNERABILITY INDICATOR (if NOT patched):
- Git commands built with string formatting
- Shell metacharacters not escaped in version parameters
- subprocess with shell=True

================================================================================
"""

import sys
import os
import traceback

# Track test results as (test name, status, details) tuples. Each test function
# appends one entry per run; run_all_tests() prints them in its summary.
test_results = []


def test_1_git_command_construction():
    """
    Test: Verify git commands are safely constructed

    Scans source lines 376-390 of commitlist.py for ``cmd =`` assignments and
    reports string formatting and shlex usage on them. When that region calls
    ``run(cmd)`` and the file itself defines ``run()``, the verdict comes from
    how ``run()`` treats the shell; otherwise it comes from whether any shlex
    usage was seen on the ``cmd =`` lines.

    Returns:
        bool: True only when a PASS result was recorded. FAIL, WARN and ERROR
        outcomes all return False.

    Side effects:
        Prints progress to stdout and appends exactly one
        ``("Git Command Safety", status, details)`` tuple to ``test_results``.
    """
    source_path = '/pytorch/scripts/release_notes/commitlist.py'
    run_window_limit = 500

    def clip_to_definition(source, start, limit):
        """Return source[start:start+limit], cut before the next sibling definition."""
        line_start = source.rfind('\n', 0, start) + 1
        prefix = source[line_start:start]
        base_indent = len(prefix) - len(prefix.lstrip())
        kept = []
        window = source[start:start + limit]
        for index, text in enumerate(window.splitlines(keepends=True)):
            body = text.lstrip()
            indent = len(text) - len(body)
            opens_definition = body.startswith(('def ', 'async def ', 'class ', '@'))
            # A def/class/decorator at the same or a shallower indent belongs to
            # a different definition, so nothing from there on is inspected.
            if index > 0 and indent <= base_indent and opens_definition:
                break
            kept.append(text)
        return ''.join(kept)

    print("\n" + "="*70)
    print("TEST 1: Git Command Construction")
    print("="*70)

    try:
        # Read the file once with an explicit encoding so the result does not
        # depend on the locale's default codec.
        with open(source_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        full_content = ''.join(lines)

        # Lines 376-390 (1-indexed) surround the reported location (L379/L386).
        # Slicing copes with files shorter than the requested range.
        target_lines = list(enumerate(lines[375:390], start=376))

        safe_patterns = []

        for line_num, line in target_lines:
            if 'cmd =' not in line:
                continue
            print(f"Line {line_num}: {line.strip()}")

            # String formatting is only dangerous when the result reaches a
            # shell, so it is reported but does not decide the verdict.
            if 'f"' in line or 'f\'' in line or '.format(' in line:
                print("  [WARN] Uses string formatting")

            if 'shlex.quote' in line or 'shlex.split' in line:
                safe_patterns.append(f"Line {line_num}: shlex")
                print("  [PASS] Uses shlex for escaping")

        # Now check how cmd is used in run()
        func_content = ''.join(line for _, line in target_lines)
        if 'run(cmd)' in func_content:
            print("\nChecking run() function for shell usage...")
            run_start = full_content.find('def run(')

            if run_start != -1:
                # Inspect at most run_window_limit characters, and never past
                # the start of the next definition: a neighbouring function's
                # shell=True must not be attributed to run().
                run_content = clip_to_definition(
                    full_content, run_start, run_window_limit
                )

                if 'shell=True' in run_content:
                    print("[FAIL] run() uses shell=True")
                    test_results.append(("Git Command Safety", "FAIL", "shell=True detected"))
                    return False
                elif 'subprocess' in run_content and 'shell=False' in run_content:
                    print("[PASS] run() uses shell=False")
                    test_results.append(("Git Command Safety", "PASS", "shell=False"))
                    return True
                elif 'shlex.split' in run_content:
                    print("[PASS] run() uses shlex.split")
                    test_results.append(("Git Command Safety", "PASS", "shlex.split"))
                    return True
                else:
                    print("[WARN] Could not determine shell usage")
                    test_results.append(("Git Command Safety", "WARN", "Unclear"))
                    return False

        # run() is not defined in this file (or not called in the region):
        # fall back to the evidence gathered from the cmd assignments.
        if safe_patterns:
            print(f"[PASS] Found safe patterns: {len(safe_patterns)}")
            test_results.append(("Git Command Safety", "PASS", "Safe patterns found"))
            return True
        else:
            print("[WARN] Could not verify safety")
            test_results.append(("Git Command Safety", "WARN", "Cannot verify"))
            return False

    except Exception as e:
        # Top-level boundary for this check: any failure (missing file, decode
        # error, ...) is recorded as an ERROR result instead of aborting the run.
        print(f"[ERROR] Test failed: {e}")
        traceback.print_exc()
        test_results.append(("Git Command Safety", "ERROR", str(e)))
        return False


def test_2_version_parameter_validation():
    """
    Test: Verify version parameters are validated

    Locates ``get_commits_between`` in commitlist.py and looks inside it for
    textual evidence of validation: an ``assert``, a ``raise``/``except``, or a
    return-code check (``rc == 0`` / ``returncode``).

    NOTE(review): these are substring heuristics. They show that the function
    checks *something*; they do not prove that refs are validated against
    shell metacharacters.

    Returns:
        bool: True when at least one marker is found (PASS). False when the
        function is missing, no marker is found (WARN), or an error occurs.

    Side effects:
        Prints progress to stdout and appends exactly one
        ``("Version Validation", status, details)`` tuple to ``test_results``.
    """
    source_path = '/pytorch/scripts/release_notes/commitlist.py'
    window_limit = 1000

    def clip_to_definition(source, start, limit):
        """Return source[start:start+limit], cut before the next sibling definition."""
        line_start = source.rfind('\n', 0, start) + 1
        prefix = source[line_start:start]
        base_indent = len(prefix) - len(prefix.lstrip())
        kept = []
        window = source[start:start + limit]
        for index, text in enumerate(window.splitlines(keepends=True)):
            body = text.lstrip()
            indent = len(text) - len(body)
            opens_definition = body.startswith(('def ', 'async def ', 'class ', '@'))
            # A def/class/decorator at the same or a shallower indent starts a
            # different definition; its contents must not count as evidence.
            if index > 0 and indent <= base_indent and opens_definition:
                break
            kept.append(text)
        return ''.join(kept)

    print("\n" + "="*70)
    print("TEST 2: Version Parameter Validation")
    print("="*70)

    try:
        # Explicit encoding keeps the read independent of the locale default.
        with open(source_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Find get_commits_between function
        func_start = content.find('def get_commits_between(')
        if func_start == -1:
            print("[WARN] get_commits_between function not found")
            test_results.append(("Version Validation", "WARN", "Function not found"))
            return False

        # Up to window_limit characters, but stopping at the next definition so
        # that an assert/raise in a following method cannot produce a false PASS.
        func_content = clip_to_definition(content, func_start, window_limit)

        # Each entry: message to print, and the substrings that trigger it.
        checks = (
            ("[PASS] Function has assertions", ('assert',)),
            ("[PASS] Function has exception handling", ('raise', 'except')),
            ("[PASS] Function validates return codes", ('rc == 0', 'returncode')),
        )

        validation_found = False
        for message, markers in checks:
            if any(marker in func_content for marker in markers):
                print(message)
                validation_found = True

        if validation_found:
            test_results.append(("Version Validation", "PASS", "Has validation"))
            return True

        print("[WARN] Limited validation found")
        test_results.append(("Version Validation", "WARN", "Limited validation"))
        return False

    except Exception as e:
        # Top-level boundary for this check: record the failure as an ERROR
        # result instead of aborting the whole run.
        print(f"[ERROR] Test failed: {e}")
        traceback.print_exc()
        test_results.append(("Version Validation", "ERROR", str(e)))
        return False


def test_3_run_function_implementation():
    """
    Test: Verify run() function is implemented safely

    Locates ``def run(`` in commitlist.py and inspects its text: ``shell=True``
    fails the check; ``shlex.split``, or a ``subprocess.`` call together with a
    ``[`` character, passes it.

    NOTE(review): the list-based check is a substring heuristic (any ``[`` in
    the function counts); confirm manually when it is the only evidence.

    Returns:
        bool: True only when a PASS result was recorded. False when ``run()``
        is not defined in the file, uses ``shell=True``, cannot be classified,
        or an error occurs.

    Side effects:
        Prints progress to stdout and appends exactly one
        ``("run() Implementation", status, details)`` tuple to ``test_results``.
    """
    source_path = '/pytorch/scripts/release_notes/commitlist.py'
    window_limit = 800

    def clip_to_definition(source, start, limit):
        """Return source[start:start+limit], cut before the next sibling definition."""
        line_start = source.rfind('\n', 0, start) + 1
        prefix = source[line_start:start]
        base_indent = len(prefix) - len(prefix.lstrip())
        kept = []
        window = source[start:start + limit]
        for index, text in enumerate(window.splitlines(keepends=True)):
            body = text.lstrip()
            indent = len(text) - len(body)
            opens_definition = body.startswith(('def ', 'async def ', 'class ', '@'))
            # A def/class/decorator at the same or a shallower indent starts a
            # different definition; its code must not be judged as run()'s.
            if index > 0 and indent <= base_indent and opens_definition:
                break
            kept.append(text)
        return ''.join(kept)

    print("\n" + "="*70)
    print("TEST 3: run() Function Implementation")
    print("="*70)

    try:
        # Explicit encoding keeps the read independent of the locale default.
        with open(source_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Find run function (a single search serves as both test and offset).
        run_start = content.find('def run(')
        if run_start == -1:
            print("[WARN] run() function not found - may be imported")
            test_results.append(("run() Implementation", "WARN", "Not found"))
            return False

        # Up to window_limit characters, stopping at the next definition so a
        # neighbouring function cannot cause a false FAIL or a false PASS.
        run_content = clip_to_definition(content, run_start, window_limit)

        print("Analyzing run() function...")

        # Check for shell=True
        if 'shell=True' in run_content:
            print("[FAIL] run() uses shell=True")
            test_results.append(("run() Implementation", "FAIL", "shell=True"))
            return False

        # Check for shlex.split usage
        if 'shlex.split' in run_content:
            print("[PASS] run() uses shlex.split")
            test_results.append(("run() Implementation", "PASS", "shlex.split"))
            return True

        # Check for list-based subprocess calls
        if 'subprocess.' in run_content and '[' in run_content:
            print("[PASS] run() appears to use list-based subprocess")
            test_results.append(("run() Implementation", "PASS", "List-based"))
            return True

        print("[WARN] Could not determine safety")
        test_results.append(("run() Implementation", "WARN", "Unclear"))
        return False

    except Exception as e:
        # Top-level boundary for this check: record the failure as an ERROR
        # result instead of aborting the whole run.
        print(f"[ERROR] Test failed: {e}")
        traceback.print_exc()
        test_results.append(("run() Implementation", "ERROR", str(e)))
        return False


def run_all_tests():
    """Execute every regression check, print a summary, and return the verdict.

    Each check is called in order. A check that returns a falsy value, or that
    raises, is counted as failed; a raised exception is printed with its
    traceback and does not stop the remaining checks.

    Returns:
        str: "PATCHED" when no check failed, otherwise "UNPATCHED".
    """
    rule = "=" * 70

    print(rule)
    print("SECURITY REGRESSION TEST: CWE-78 Command Injection")
    print("Verifying fix for: git command execution (commitlist.py:386)")
    print(rule)

    checks = (
        test_1_git_command_construction,
        test_2_version_parameter_validation,
        test_3_run_function_implementation,
    )

    # One boolean per check, in execution order.
    outcomes = []
    for check in checks:
        try:
            succeeded = bool(check())
        except Exception as e:
            print(f"[ERROR] Test crashed: {e}")
            traceback.print_exc()
            succeeded = False
        outcomes.append(succeeded)

    passed = sum(outcomes)
    failed = len(outcomes) - passed

    # Summary
    print("\n" + rule)
    print("TEST SUMMARY")
    print(rule)

    for name, status, details in test_results:
        print(f"{status}: {name} - {details}")

    print(f"\nPassed: {passed}, Failed: {failed}")

    # Final verdict: a single failed check is enough to report UNPATCHED.
    if failed:
        print("\nResult: UNPATCHED (security fix may be incomplete)")
        return "UNPATCHED"

    print("\nResult: PATCHED (all security tests passed)")
    return "PATCHED"


if __name__ == "__main__":
    # Script entry point: exit status 0 means PATCHED, 1 means anything else.
    result = run_all_tests()
    print(f"\nFinal Status: {result}")
    sys.exit(int(result != "PATCHED"))
