import socket
import select
import time

# IPv4 address of the upstream resolver (Cloudflare's public DNS) that every
# client query is forwarded to, on port 53.
UPSTREAM_DNS = "1.1.1.1"

class DelayedDelivery:
    """A DNS reply that is being held back before it is sent to its client.

    Attributes:
        time: delay value for this reply; the main loop counts it down by
            the elapsed seconds and sends the reply once it is <= 0.
        addr: address of the client the reply is destined for.
        data: raw reply datagram as received from the upstream resolver.
    """

    def __init__(self, time, addr, data):
        # NOTE: the ``time`` parameter shadows the ``time`` module, but only
        # inside this method; the name is kept because callers may pass it
        # by keyword.
        self.time, self.addr, self.data = time, addr, data

# Delaying DNS proxy: listen for UDP queries on LISTEN_PORT, forward each one
# to UPSTREAM_DNS:53 through a fresh socket, hold the upstream reply for
# REPLY_DELAY seconds, then send it back to the client that asked.
LISTEN_ADDR = "0.0.0.0"
LISTEN_PORT = 53
REPLY_DELAY = 3.0        # seconds each upstream reply is held before delivery
UPSTREAM_TIMEOUT = 5.0   # seconds to wait for an upstream reply before giving up
POLL_INTERVAL = 0.1      # select() timeout, i.e. the scheduling granularity
MAX_DATAGRAM = 8192      # receive buffer size for a single datagram

server_socket = None
read_list = []       # sockets select() waits on: listener + in-flight upstream sockets
socks_to_addr = {}   # upstream socket -> client address that sent the query
sock_deadlines = {}  # upstream socket -> monotonic time at which it is abandoned
deliveries = []      # replies received from upstream, waiting out their delay


def _forget_upstream(sock):
    """Close an upstream socket and stop tracking it everywhere."""
    sock.close()
    read_list.remove(sock)
    socks_to_addr.pop(sock, None)
    sock_deadlines.pop(sock, None)


try:
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    print("Socket created successfully.")

    server_socket.bind((LISTEN_ADDR, LISTEN_PORT))
    print("Socket bound")

    read_list.append(server_socket)

    print("Server is listening... Press Ctrl+C to exit.")

    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    last_tick = time.monotonic()
    while True:
        readable, _, _ = select.select(read_list, [], [], POLL_INTERVAL)

        for sock in readable:
            if sock is server_socket:
                try:
                    data, addr = server_socket.recvfrom(MAX_DATAGRAM)
                except OSError as e:
                    # e.g. Windows surfaces an ICMP "port unreachable" from an
                    # earlier send as ConnectionResetError here; keep serving.
                    print(f"server failed to receive query: {e}")
                    continue
                print(f"server received query from {addr}")
                new_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                try:
                    # connect() pins the peer, so the kernel discards datagrams
                    # arriving on this port from anyone but the upstream resolver.
                    new_sock.connect((UPSTREAM_DNS, 53))
                    new_sock.send(data)
                except OSError as e:
                    print(f"server failed to forward query from {addr}: {e}")
                    new_sock.close()
                    continue
                socks_to_addr[new_sock] = addr
                sock_deadlines[new_sock] = time.monotonic() + UPSTREAM_TIMEOUT
                read_list.append(new_sock)
            else:
                client = socks_to_addr[sock]
                try:
                    data = sock.recv(MAX_DATAGRAM)
                except OSError as e:
                    print(f"upstream query for {client} failed: {e}")
                else:
                    print(f"server delaying reply to {client}")
                    # The whole tick is subtracted below, including the part
                    # that passed before this reply arrived; pre-credit that
                    # part so the reply is held for the full REPLY_DELAY.
                    deliveries.append(DelayedDelivery(
                        REPLY_DELAY + (time.monotonic() - last_tick), client, data))
                _forget_upstream(sock)

        # Measure the full tick (select + processing + previous deliveries).
        now = time.monotonic()
        elapsed = now - last_tick
        last_tick = now

        # Abandon upstream sockets that never answered so they don't pile up.
        expired = [s for s, deadline in sock_deadlines.items() if deadline <= now]
        for sock in expired:
            print(f"upstream timed out for {socks_to_addr[sock]}")
            _forget_upstream(sock)

        # Build the still-pending list instead of removing from `deliveries`
        # while iterating over it (which silently skipped elements).
        pending = []
        for delivery in deliveries:
            delivery.time -= elapsed
            if delivery.time > 0:
                pending.append(delivery)
                continue
            print(f"server delivering reply to {delivery.addr}")
            try:
                server_socket.sendto(delivery.data, delivery.addr)
            except OSError as e:
                print(f"server failed to deliver reply to {delivery.addr}: {e}")
        deliveries = pending

except PermissionError:
    print(f"Permission denied to bind to port {LISTEN_PORT}.")
except KeyboardInterrupt:
    print("Server is shutting down.")
except Exception as e:
    print(f"An error occurred: {e}")
finally:
    # Release every socket still open; close() on a closed socket is a no-op.
    for sock in list(socks_to_addr):
        sock.close()
    if server_socket is not None:
        server_socket.close()
