#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import os
import subprocess
import sys
import threading
import time
import unittest
import uuid

import proton.handlers
import proton.reactor
import proton.utils


class ReconnectingTestClient:
    """Repeatedly (re)connect a BlockingConnection receiver and sample GC object counts.

    Intended to be run in a background thread via ``run``.  Each time the
    client manages to connect and subscribe, it records the number of objects
    tracked by the garbage collector in ``object_counts``.  When the connection
    breaks, it closes the link and connection and reconnects.  ``done`` is set
    when ``run`` exits, whether it finished normally or raised, so the driving
    test knows when to stop restarting the broker.
    """

    def __init__(self, hostport):
        # type: (str) -> None
        self.hostport = hostport

        # One entry per successful (re)subscription, appended by count_objects("loop").
        self.object_counts = []
        # Set when run() exits (normally or via an exception).
        self.done = threading.Event()

    def count_objects(self, message):
        # type: (str) -> None
        """Run a full GC pass and print how many objects the collector tracks.

        Only samples tagged ``"loop"`` are stored in ``object_counts``; every
        other message is diagnostic output only.
        """
        gc.collect()
        n = len(gc.get_objects())
        if message == "loop":
            self.object_counts.append(n)
        print(message, n)

    def run(self):
        # type: () -> None
        """Connect, subscribe and receive until the connection breaks; repeat 5 times.

        Connect/subscribe failures are retried every 0.5 s.  Once subscribed,
        the receiver polls with a ``SLEEP``-second timeout.  Any error other
        than ``proton.Timeout`` is treated as a lost connection: the link and
        connection are closed (best effort) before the next cycle.  ``done`` is
        always set on exit, so a crash here cannot make the test spin forever.
        """
        ADDR = "testing123"
        HEARTBEAT = 5
        SLEEP = 5

        recv = None
        conn = None
        try:
            for _ in range(5):
                subscribed = False
                while not subscribed:
                    # Forget any previous connection so only the one opened in
                    # this attempt can be cleaned up below.
                    conn = None
                    try:
                        conn = proton.utils.BlockingConnection(self.hostport, ssl_domain=None, heartbeat=HEARTBEAT)
                        recv = conn.create_receiver(ADDR, name=str(uuid.uuid4()), dynamic=False, options=None)
                        subscribed = True
                    except Exception as e:
                        print("received exception %s on connect/subscribe, retry" % e)
                        # create_receiver() can fail after the connection was
                        # opened; close it so it is not abandoned (and counted
                        # as a leak) on the next sample.
                        if conn is not None:
                            try:
                                conn.close()
                            except Exception as close_error:
                                print("conn close() after failed subscribe failed: %s" % close_error)
                            conn = None
                        time.sleep(0.5)

                self.count_objects("loop")
                print("connected")
                while subscribed:
                    try:
                        print()
                        recv.receive(SLEEP)
                    except proton.Timeout:
                        pass
                    except Exception as e:
                        print(e)
                        try:
                            recv.close()
                            recv = None
                        except Exception:
                            self.count_objects("link close() failed")
                        try:
                            conn.close()
                            conn = None
                            self.count_objects("conn closed")
                        except Exception:
                            self.count_objects("conn close() failed")
                        subscribed = False
        finally:
            self.done.set()


class BlockingConnectionObjectLeakTests(unittest.TestCase):
    """Check that reconnecting a BlockingConnection does not accumulate Python objects."""

    def test_blocking_connection_object_leak(self):
        """Restart a broker repeatedly and assert the client's object count stays flat.

        A broker subprocess (``ENTMQCL-1578_broker.py`` next to this file) is
        started, kept alive for about 3 s and then killed.  This repeats until
        the ReconnectingTestClient thread reports it is done.  The first line
        the broker writes to stdout is used as the client's host:port.  Later
        brokers are started with ``-b <host:port>`` so the client can reconnect
        to the same address.
        """
        gc.collect()

        thread = None
        client = None

        hostport = ""

        while client is None or not client.done.is_set():
            # Reset every round so the finally clause never re-kills a broker
            # from a previous round if Popen() itself fails.
            broker_process = None
            try:
                params = []
                if hostport:
                    params = ['-b', hostport]
                broker_process = subprocess.Popen(
                    args=[sys.executable, os.path.join(os.path.dirname(__file__), 'ENTMQCL-1578_broker.py')] + params,
                    stdout=subprocess.PIPE,
                    universal_newlines=True,
                )
                # readline() keeps the trailing newline; strip it so the address
                # is clean both for the client and for the next broker's -b arg.
                reported = broker_process.stdout.readline().strip()
                if not reported:
                    # An empty read means the broker exited without reporting an
                    # address; continuing would lose the port and loop forever.
                    self.fail("broker exited without reporting its host:port")
                hostport = reported

                if client is None:
                    client = ReconnectingTestClient(hostport)
                    thread = threading.Thread(target=client.run)
                    # Daemon so a failing test cannot hang interpreter exit while
                    # the client keeps retrying against a dead broker.
                    thread.daemon = True
                    thread.start()

                time.sleep(3)
            finally:
                if broker_process is not None:
                    broker_process.kill()
                    broker_process.wait()
                    broker_process.stdout.close()
            time.sleep(0.3)

        thread.join()
        print("client.object_counts:", client.object_counts)
        object_counts = client.object_counts[1:]  # drop first, measurement error
        # Fail cleanly (instead of an IndexError below) when too few samples exist.
        self.assertTrue(object_counts, "too few object count samples: %r" % (client.object_counts,))

        diffs = [c - object_counts[0] for c in object_counts]
        self.assertEqual([0] * 4, diffs, "Object counts should not be increasing")


if __name__ == "__main__":
    # Run the leak test directly when this file is executed as a script.
    unittest.main(module="__main__")
