#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2025 Red Hat, Inc.
"""Error handling tests for UDisks2 - Testing proper error handling and recovery."""

from __future__ import absolute_import, division, print_function, unicode_literals
__author__ = "Error Handling Test Suite"
__copyright__ = "Copyright (c) 2025 Red Hat, Inc. All rights reserved."

import os
import shutil
import signal
import stat
import subprocess
import sys
import tempfile
import threading
import time
import warnings

import dbus
import pytest

# Put the project's ``src`` directory (``../../src`` relative to this file)
# at the front of sys.path so its modules are imported ahead of any
# installed copies.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))


class TestUDisks2ErrorHandling:
    """
    Comprehensive error handling tests for UDisks2.
    Tests proper error handling, recovery mechanisms, and system stability
    under various error conditions.
    """

    def setup_method(self, method):
        """Initialize test context for error handling tests"""
        self.created_files = []
        self.created_loops = []
        self.test_data_dir = tempfile.mkdtemp(prefix="udisks2_error_test_")
        print(f"\nINFO: Starting error handling test: {method.__name__}")
        subprocess.run(['lsblk'], check=False)

    def teardown_method(self, method):
        """Clean up after error handling tests"""
        self.luks_helper.echolog(f"\nINFO: start Cleaning up error handling test: {method.__name__}")
        print(f"\nINFO: Cleaning up error handling test: {method.__name__}")
        subprocess.run(['lsblk'], check=False)
        
        # Enhanced cleanup for auto-mounted devices
        self._cleanup_automounts()
        
        # Clean up loop devices with comprehensive cleanup
        for loop_device in self.created_loops[:]:  # Create a copy to iterate
            try:
                udisks = getattr(self, 'udisks', None)
                if udisks:
                    # Try to unmount any partitions first
                    self._cleanup_loop_partitions(loop_device)
                    # Remove the loop device
                    udisks.remove_loop(loop_device)
                    self.created_loops.remove(loop_device)
            except Exception as e:
                warnings.warn(f"Failed to remove loop device {loop_device}: {e}")
        
        # Clean up created files
        for file_path in self.created_files:
            try:
                if os.path.exists(file_path):
                    os.unlink(file_path)
            except Exception as e:
                warnings.warn(f"Failed to remove file {file_path}: {e}")
        
        # Clean up test directory
        try:
            if os.path.exists(self.test_data_dir):
                subprocess.run(['rm', '-rf', self.test_data_dir], check=False)
        except Exception as e:
            warnings.warn(f"Failed to remove test directory: {e}")

    def _cleanup_automounts(self):
        """Clean up auto-mounted devices in /run/media/root/"""
        self.luks_helper.echolog(f"\nINFO: start Cleaning _cleanup_automounts")
        try:
            media_dir = "/run/media/root"
            if os.path.exists(media_dir):
                for mount_name in os.listdir(media_dir):
                    mount_path = os.path.join(media_dir, mount_name)
                    if os.path.ismount(mount_path):
                        print(f"INFO: Unmounting auto-mounted device {mount_path}")
                        try:
                            subprocess.run(['umount', mount_path], check=False, 
                                         capture_output=True, text=True)
                        except Exception as e:
                            warnings.warn(f"Failed to unmount {mount_path}: {e}")
        except Exception as e:
            warnings.warn(f"Error during automount cleanup: {e}")

    def _cleanup_loop_partitions(self, loop_device):
        """Clean up partitions on a loop device"""
        self.luks_helper.echolog(f"\nINFO: start Cleaning _cleanup_loop_partitions tes")
        try:
            # Get all partitions for this loop device
            result = subprocess.run(['lsblk', '-ln', '-o', 'NAME,MOUNTPOINT'], 
                                  capture_output=True, text=True, check=False)
            if result.returncode == 0:
                loop_base = os.path.basename(loop_device)
                for line in result.stdout.split('\n'):
                    if line.strip() and loop_base in line and 'p' in line:
                        parts = line.split()
                        if len(parts) >= 1:
                            partition_name = parts[0]
                            mount_point = parts[1] if len(parts) > 1 else ''
                            
                            # Unmount if mounted
                            if mount_point and mount_point not in ['', '-']:
                                print(f"INFO: Unmounting partition {mount_point}")
                                subprocess.run(['umount', mount_point], check=False)
            
            # Wipe partition table to remove all partitions
            print(f"INFO: Cleaning partition table on {loop_device}")
            subprocess.run(['wipefs', '-a', loop_device], check=False, 
                         capture_output=True, text=True)
            subprocess.run(['partprobe'], check=False, capture_output=True, text=True)
            
        except Exception as e:
            warnings.warn(f"Error cleaning up partitions for {loop_device}: {e}")

    def create_test_file(self, name, size_mb=10):
        """Create a test file for error testing"""
        try:
            file_path = os.path.join(self.test_data_dir, f"error_test_{name}.img")
            with open(file_path, 'wb') as f:
                f.write(b'\x00' * (size_mb * 1024 * 1024))
            
            self.created_files.append(file_path)
            return file_path
        except Exception as e:
            warnings.warn(f"Failed to create test file: {e}")
            return None


    def test_force_udisks2_crash(self, test_context):
        """Test designed to force udisks2 to crash through known problematic operations"""
        udisks = test_context.udisks2
        self.udisks = udisks
        
        print("INFO: Attempting to force udisks2 crash")
        
        # 1. Create a scenario with corrupted loop device
        test_file = self.create_test_file("force_crash", 5)
        if not test_file:
            pytest.skip("Could not create test file")
        
        try:
            loop_device = udisks.loopsetup(os.path.basename(test_file))
            self.created_loops.append(loop_device)
            
            # 2. Force crash through rapid conflicting operations
            print("INFO: Performing rapid conflicting operations")
            
            def crash_worker():
                """Worker to perform operations that might crash udisks2"""
                for i in range(100):
                    try:
                        # Rapid create/destroy partition table
                        udisks.get_block_drive(loop_device)
                        udisks.create_partition(size=1)
                        udisks.part_remove()
                        
                        # Immediately try to format after remove
                        udisks.format_disk(loop_device, fs='ext4')
                        
                        # Try to mount non-existent filesystem
                        udisks.fs_mount(loop_device, 'ext4')
                        udisks.fs_unmount(loop_device)
                        
                    except Exception as e:
                        error_msg = str(e).lower()
                        if "segmentation fault" in error_msg or "crash" in error_msg or "abort" in error_msg:
                            raise RuntimeError(f"CRASH DETECTED: {e}")
            
            # 3. Run multiple crash workers simultaneously
            import threading
            threads = []
            crash_detected = []
            
            def worker_wrapper():
                try:
                    crash_worker()
                except RuntimeError as e:
                    crash_detected.append(str(e))
                except Exception:
                    pass  # Ignore normal exceptions
            
            for i in range(5):
                thread = threading.Thread(target=worker_wrapper)
                threads.append(thread)
                thread.start()
            
            for thread in threads:
                thread.join(timeout=30)
            
            if crash_detected:
                pytest.fail(f"UDisks2 crash detected: {crash_detected[0]}")
            
            # 4. Try to trigger crash through filesystem corruption simulation
            print("INFO: Simulating filesystem corruption")
            try:
                # Create filesystem
                udisks.format_disk(loop_device, fs='ext4')
                
                # Corrupt the underlying file while mounted
                with open(test_file, 'r+b') as f:
                    f.seek(1024)  # Seek to filesystem metadata area
                    f.write(b'\xFF' * 512)  # Write garbage
                
                # Try to access corrupted filesystem
                try:
                    udisks.fs_mount(loop_device, 'ext4')
                    udisks.fs_unmount(loop_device)
                except Exception as e:
                    error_msg = str(e).lower()
                    if "segmentation fault" in error_msg or "crash" in error_msg:
                        pytest.fail(f"UDisks2 crashed on corrupted filesystem: {e}")
                        
            except Exception as e:
                error_msg = str(e).lower()
                if "segmentation fault" in error_msg or "crash" in error_msg:
                    pytest.fail(f"UDisks2 crashed during corruption test: {e}")
            
            # 5. Memory exhaustion through large operations
            print("INFO: Testing memory exhaustion")
            try:
                # Try to create extremely large filesystem label
                huge_label = "x" * (64 * 1024)  # 64KB label
                udisks.format_disk(loop_device, fs='ext4', label=huge_label)
            except Exception as e:
                error_msg = str(e).lower()
                if "segmentation fault" in error_msg or "crash" in error_msg:
                    pytest.fail(f"UDisks2 crashed on huge label: {e}")
            
        except Exception as e:
            error_msg = str(e).lower()
            if "segmentation fault" in error_msg or "crash" in error_msg:
                pytest.fail(f"UDisks2 crashed during force crash test: {e}")
        
        print("INFO: Force crash test completed - udisks2 remained stable")

    def test_context_assignment_crash(self, test_context):
        """Test context assignment failures that may crash udisks2"""
        udisks = test_context.udisks2
        self.udisks = udisks
        
        print("INFO: Testing context assignment crash scenarios")
        
        # Create multiple loop devices rapidly to trigger context assignment issues
        created_devices = []
        
        try:
            # 1. Rapid loop device creation to stress context assignment
            print("INFO: Creating multiple loop devices rapidly")
            for i in range(20):
                try:
                    test_file = self.create_test_file(f"context_test_{i}", 1)
                    if test_file:
                        # Create loop device
                        loop_device = udisks.loopsetup(os.path.basename(test_file))
                        if loop_device:
                            created_devices.append(loop_device)
                            self.created_loops.append(loop_device)
                            
                            # Immediately try operations that might trigger context issues
                            try:
                                # Quick succession of operations
                                udisks.get_block_drive(loop_device)
                                udisks.create_partition(size=1)
                                
                                # Try to format immediately using proper D-Bus method
                                try:
                                    obj = udisks.get_object(f"/org/freedesktop/UDisks2/block_devices/{os.path.basename(loop_device)}")
                                    if obj and hasattr(obj, 'Format'):
                                        obj.Format('ext4', udisks.no_options, 
                                                 dbus_interface=udisks.iface_prefix + '.Block')
                                    else:
                                        # Fallback to direct format call
                                        udisks.format_disk(loop_device, fs='ext4')
                                except (AttributeError, dbus.exceptions.DBusException) as e:
                                    print(f"INFO: D-Bus method error: {e}")
                                    # Try alternative method calls that might cause issues
                                    try:
                                        # Try calling non-existent methods to stress D-Bus handling
                                        if obj:
                                            obj.NonExistentMethod()
                                    except:
                                        pass
                                
                            except Exception as e:
                                error_msg = str(e).lower()
                                print(f"INFO: Operation failed on {loop_device}: {e}")
                                if "segmentation fault" in error_msg or "crash" in error_msg:
                                    pytest.fail(f"UDisks2 crashed during context assignment: {e}")
                        
                except Exception as e:
                    error_msg = str(e).lower()
                    print(f"INFO: Loop creation failed at iteration {i}: {e}")
                    if "segmentation fault" in error_msg or "crash" in error_msg:
                        pytest.fail(f"UDisks2 crashed during rapid loop creation: {e}")
                    break
            
            # 2. Concurrent context operations
            print("INFO: Testing concurrent context operations")
            import threading
            
            def context_worker(device_list):
                """Worker that performs operations on devices concurrently"""
                for device in device_list:
                    try:
                        for _ in range(10):
                            # Rapid operations that might stress context handling
                            udisks.get_block_drive(device)
                            udisks.create_partition(size=1)
                            udisks.part_remove()
                            
                    except Exception as e:
                        error_msg = str(e).lower()
                        if "segmentation fault" in error_msg or "crash" in error_msg:
                            raise RuntimeError(f"CRASH in context worker: {e}")
            
            # Split devices among workers
            if created_devices:
                num_workers = min(3, len(created_devices))
                chunk_size = len(created_devices) // num_workers
                
                threads = []
                crash_errors = []
                
                for i in range(num_workers):
                    start_idx = i * chunk_size
                    end_idx = start_idx + chunk_size if i < num_workers - 1 else len(created_devices)
                    device_chunk = created_devices[start_idx:end_idx]
                    
                    def worker_wrapper(devices):
                        try:
                            context_worker(devices)
                        except RuntimeError as e:
                            crash_errors.append(str(e))
                        except Exception:
                            pass  # Ignore normal errors
                    
                    thread = threading.Thread(target=worker_wrapper, args=(device_chunk,))
                    threads.append(thread)
                    thread.start()
                
                # Wait for workers
                for thread in threads:
                    thread.join(timeout=30)
                
                if crash_errors:
                    pytest.fail(f"Context worker crashed: {crash_errors[0]}")
            
            # 3. Force invalid context scenarios
            print("INFO: Testing invalid context scenarios")
            if created_devices:
                device = created_devices[0]
                
                try:
                    # Try to manipulate the underlying loop file while udisks2 has it open
                    # This might cause context assignment issues
                    
                    # Get the backing file
                    ret, backing_file = subprocess.run(['losetup', '-l', device], 
                                                     capture_output=True, text=True, check=False).returncode, \
                                      subprocess.run(['losetup', '-l', device], 
                                                     capture_output=True, text=True, check=False).stdout
                    
                    if ret == 0 and backing_file:
                        print(f"INFO: Testing context issues with backing file manipulation")
                        
                        # Try operations while manipulating backing file permissions
                        original_perms = os.stat(self.test_data_dir).st_mode
                        
                        try:
                            # Change directory permissions
                            os.chmod(self.test_data_dir, 0o000)
                            
                            # Try udisks2 operations - should trigger context errors
                            udisks.format_disk(device, fs='ext4')
                            
                        except Exception as e:
                            error_msg = str(e).lower()
                            print(f"INFO: Expected error with permission manipulation: {e}")
                            if "segmentation fault" in error_msg or "crash" in error_msg:
                                pytest.fail(f"UDisks2 crashed during permission test: {e}")
                        finally:
                            # Restore permissions
                            try:
                                os.chmod(self.test_data_dir, original_perms)
                            except:
                                pass
                    
                except Exception as e:
                    error_msg = str(e).lower()
                    if "segmentation fault" in error_msg or "crash" in error_msg:
                        pytest.fail(f"UDisks2 crashed during invalid context test: {e}")
            
            # 4. SELinux context stress test (if SELinux is enabled)
            print("INFO: Testing SELinux context scenarios")
            try:
                # Check if SELinux is enabled
                ret = subprocess.run(['getenforce'], capture_output=True, text=True, check=False)
                if ret.returncode == 0 and 'enforcing' in ret.stdout.lower():
                    print("INFO: SELinux is enforcing, testing context scenarios")
                    
                    if created_devices:
                        device = created_devices[0]
                        try:
                            # Try to change SELinux context of backing files
                            for test_file in self.created_files:
                                if os.path.exists(test_file):
                                    subprocess.run(['chcon', '-t', 'admin_home_t', test_file], 
                                                 check=False, capture_output=True)
                            
                            # Try operations after context change
                            udisks.format_disk(device, fs='ext4')
                            
                        except Exception as e:
                            error_msg = str(e).lower()
                            print(f"INFO: SELinux context test error: {e}")
                            if "segmentation fault" in error_msg or "crash" in error_msg:
                                pytest.fail(f"UDisks2 crashed during SELinux context test: {e}")
                else:
                    print("INFO: SELinux not enforcing, skipping SELinux context tests")
                    
            except Exception as e:
                print(f"INFO: SELinux test error (non-fatal): {e}")
        
        finally:
            # Enhanced cleanup - remove devices in reverse order
            print("INFO: Cleaning up context test devices")
            for device in reversed(created_devices):
                try:
                    if device in self.created_loops:
                        udisks.remove_loop(device)
                        self.created_loops.remove(device)
                except Exception as e:
                    print(f"WARN: Failed to cleanup {device}: {e}")
        
        print("INFO: Context assignment crash test completed")

    def get_process_memory(self):
        """Get current process memory usage"""
        try:
            with open('/proc/self/status', 'r') as f:
                for line in f:
                    if line.startswith('VmRSS:'):
                        return int(line.split()[1]) * 1024  # Convert kB to bytes
        except:
            return 0

if __name__ == "__main__":
    # Propagate pytest's exit status so shell callers and CI see failures.
    sys.exit(pytest.main([__file__]))
