#!/usr/bin/python

import os
from time import time

import ldap

BASEDN = os.environ.get("BASEDN", "dc=example,dc=com")
MAXITER = int(os.environ.get("MAXITER", 1000))
TRACELEVEL = int(os.environ.get("TRACELEVEL", -1))
FILTER = os.environ.get(
    "LDAP_FILTER", "(&(objectClass=nsPerson)(memberOf=cn=*,dc=example,dc=com))"
)


def search_ext(srv):
    """Run one subtree search capped at a single entry and drain its results.

    Issues ``srv.search_ext`` against ``BASEDN`` with ``FILTER`` and
    ``sizelimit=1``, then fetches the responses one message at a time
    (``srv.result(msgid, 0)``).  For every search-entry / search-result
    message that carries data, the first element of the first item is
    printed (the entry DN in python-ldap's result format).  The loop stops
    at the first such message whose data list is empty.

    Args:
        srv: A bound python-ldap connection object.

    Returns:
        The last tuple returned by ``srv.result``, or ``None`` when an
        ``ldap.TIMEOUT`` / ``ldap.SIZELIMIT_EXCEEDED`` was raised before any
        message had been received.
    """
    # Pre-bind so the final ``return`` cannot hit an UnboundLocalError when
    # the very first call raises one of the handled exceptions.
    res = None
    try:
        ldapid = srv.search_ext(BASEDN, ldap.SCOPE_SUBTREE, FILTER, sizelimit=1)
        while True:
            res = srv.result(ldapid, 0)
            # Other message types are ignored; keep polling.
            if res[0] not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_RESULT):
                continue
            if res[1] == []:
                break
            print(res[1][0][0])
    except ldap.TIMEOUT:
        print("timed out")
    except ldap.SIZELIMIT_EXCEEDED:
        # Presumably expected: the request itself asks for sizelimit=1.
        pass
    return res


def search_ext_paged(srv):
    """Run one paged subtree search (page size 1) and drain its results.

    Issues ``srv.search_ext`` against ``BASEDN`` with ``FILTER`` and a
    critical ``SimplePagedResultsControl`` of size 1 (no ``sizelimit``),
    then fetches the responses one message at a time via
    ``srv.result3(msgid, 0)``.

    For search-entry / search-result messages:
      * non-empty data: the first element of the first item is printed
        (the entry DN in python-ldap's result format);
      * empty data: the loop stops only when the environment variable
        ``PROJQUAY_3810_FIXED`` is set to a non-zero integer; otherwise the
        whole result tuple is printed and polling continues.  This is the
        deliberate reproduction of the bug this script demonstrates.

    Args:
        srv: A bound python-ldap connection object.

    Returns:
        The last tuple returned by ``srv.result3``, or ``None`` when an
        ``ldap.TIMEOUT`` / ``ldap.SIZELIMIT_EXCEEDED`` was raised before any
        message had been received.
    """
    # Loop-invariant: read and parse the toggle once per call instead of
    # once per received message.
    bug_fixed = bool(int(os.environ.get("PROJQUAY_3810_FIXED", "0")))
    lc = ldap.controls.libldap.SimplePagedResultsControl(
        criticality=True, size=1, cookie=""
    )
    # Pre-bind so the final ``return`` cannot hit an UnboundLocalError when
    # the very first call raises one of the handled exceptions.
    res = None
    try:
        ldapid = srv.search_ext(BASEDN, ldap.SCOPE_SUBTREE, FILTER, serverctrls=[lc])
        while True:
            res = srv.result3(ldapid, 0)
            # Other message types are ignored; keep polling.
            if res[0] not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_RESULT):
                continue
            if res[1] == []:
                # that's our bug: without the fix we never leave the loop here
                if bug_fixed:
                    break
                print(res)
            else:
                print(res[1][0][0])
    except ldap.TIMEOUT:
        print("timed out")
    except ldap.SIZELIMIT_EXCEEDED:
        pass
    return res


def search_ext_paged_ext(srv):
    """Run one paged subtree search (page size 1) that is also size-limited.

    Issues ``srv.search_ext`` against ``BASEDN`` with ``FILTER``, a critical
    ``SimplePagedResultsControl`` of size 1 and ``sizelimit=1``, then
    fetches the responses one message at a time via
    ``srv.result3(msgid, 0)``.  For every search-entry / search-result
    message that carries data, the first element of the first item is
    printed (the entry DN in python-ldap's result format).  The loop stops
    at the first such message whose data list is empty.

    Args:
        srv: A bound python-ldap connection object.

    Returns:
        The last tuple returned by ``srv.result3``, or ``None`` when an
        ``ldap.TIMEOUT`` / ``ldap.SIZELIMIT_EXCEEDED`` was raised before any
        message had been received.
    """
    lc = ldap.controls.libldap.SimplePagedResultsControl(
        criticality=True, size=1, cookie=""
    )
    # Pre-bind so the final ``return`` cannot hit an UnboundLocalError when
    # the very first call raises one of the handled exceptions.
    res = None
    try:
        ldapid = srv.search_ext(
            BASEDN, ldap.SCOPE_SUBTREE, FILTER, serverctrls=[lc], sizelimit=1
        )
        while True:
            res = srv.result3(ldapid, 0)
            # Other message types are ignored; keep polling.
            if res[0] not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_RESULT):
                continue
            if res[1] == []:
                break
            print(res[1][0][0])
    except ldap.TIMEOUT:
        print("timed out")
    except ldap.SIZELIMIT_EXCEEDED:
        # Presumably expected: the request itself asks for sizelimit=1.
        pass
    return res


# The plain search_s query is the slowest variant; do not use it at all.
# starttime = time()
# for _ in range(MAXITER):
#    srv.search_s(BASEDN, ldap.SCOPE_SUBTREE, FILTER)
# stoptime = time()
# print(f"search_s took %0.5f" % (stoptime-starttime))

# Benchmark: MAXITER calls of search_ext_paged() over one fresh connection.
srv = ldap.initialize(
    os.environ.get("LDAP_URI", "ldaps://ldap.example.com"), trace_level=TRACELEVEL
)
try:
    srv.set_option(ldap.OPT_TIMEOUT, 10.0)
    srv.simple_bind_s(os.environ.get("BINDDN"), os.environ.get("BINDPWD"))
    starttime = time()
    for _ in range(MAXITER):
        sep = search_ext_paged(srv)
    stoptime = time()
    elapsed = stoptime - starttime
    # MAXITER comes from the environment and may be 0; avoid ZeroDivisionError.
    average = elapsed / MAXITER if MAXITER > 0 else 0.0
    print(f"search_ext paginated took total={elapsed:.5f} avg={average:.5f}")
finally:
    # Release the connection even when the bind or one of the searches raises.
    srv.unbind()
    del srv

# Benchmark: MAXITER calls of search_ext_paged_ext() over one fresh connection.
srv = ldap.initialize(
    os.environ.get("LDAP_URI", "ldaps://ldap.example.com"), trace_level=TRACELEVEL
)
try:
    srv.set_option(ldap.OPT_TIMEOUT, 10.0)
    srv.simple_bind_s(os.environ.get("BINDDN"), os.environ.get("BINDPWD"))
    starttime = time()
    for _ in range(MAXITER):
        sepe = search_ext_paged_ext(srv)
    stoptime = time()
    elapsed = stoptime - starttime
    # MAXITER comes from the environment and may be 0; avoid ZeroDivisionError.
    average = elapsed / MAXITER if MAXITER > 0 else 0.0
    print(f"search_ext paginated ext took total={elapsed:.5f} avg={average:.5f}")
finally:
    # Release the connection even when the bind or one of the searches raises.
    srv.unbind()
    del srv

# Benchmark: MAXITER calls of search_ext() over one fresh connection.
srv = ldap.initialize(
    os.environ.get("LDAP_URI", "ldaps://ldap.example.com"), trace_level=TRACELEVEL
)
try:
    srv.set_option(ldap.OPT_TIMEOUT, 10.0)
    srv.simple_bind_s(os.environ.get("BINDDN"), os.environ.get("BINDPWD"))
    starttime = time()
    for _ in range(MAXITER):
        se = search_ext(srv)
    stoptime = time()
    elapsed = stoptime - starttime
    # MAXITER comes from the environment and may be 0; avoid ZeroDivisionError.
    average = elapsed / MAXITER if MAXITER > 0 else 0.0
    print(f"search_ext took total={elapsed:.5f} avg={average:.5f}")
finally:
    # Release the connection even when the bind or one of the searches raises.
    srv.unbind()
    del srv
