diff --git a/src/phi/ldap/async_client.py b/src/phi/ldap/async_client.py new file mode 100644 index 0000000..039f6b9 --- /dev/null +++ b/src/phi/ldap/async_client.py @@ -0,0 +1,132 @@ +from urllib.parse import urlparse + +from bonsai import LDAPClient + +from phi.logging import get_logger + + +log = get_logger(__name__) + + +def parse_host(host): + """ + Helper function to decompose the host in the address + and the (optional) protocol and the (optional) port. + If missing, protocol defaults to "ldap" and port to 389, + in case protocol is missing or is "ldap", or 636, in case + protocol is "ldaps". + """ + if "://" not in host: + host = f"//{host}" + + p = urlparse(host) + if p.scheme is not None and p.scheme != "": + proto = p.scheme + else: + proto = "ldap" + + if p.port is not None: + port = p.port + else: + port = None + + if port is not None: + addr = p.netloc.split(":")[0] + else: + addr = p.netloc + if proto == "ldap": + port = 389 + elif proto == "ldaps": + port = 636 + + return proto, addr, port + + +def checked_port(provided, auto): + """ + Check consistency of ports given via the connection string + and the explicit parameter. + """ + _provided = provided is not None + + if _provided and provided != auto: + log.warning( + "Explicitly provided port ({}) does not match " + "the automatically provided one ({}). The former prevails.".format( + provided, auto + ) + ) + return provided + + if _provided: + return provided + + return auto + + +def compose_dn_username(username, base_dn, ou=None, attribute_id=None): + """ + Output the distinguished name of the user to use as login. + """ + if base_dn in username: + return username + + if ou is None: + return f"{attribute_id}={username},{base_dn}" + + return f"{attribute_id}={username},ou={ou},{base_dn}" + + +class AsyncClient(LDAPClient): + """ + Wrapper around LDAPClient. + """ + + def __init__( + self, + host=None, + port=None, + encryption=None, + ciphers=None, + validate=False, + ca_cert=None, + username=None, + password=None, + base_dn=None, + attribute_id="uid", + ou=None, + method="SIMPLE", + ): + self.proto, self.host, _port = parse_host(host) + self.port = checked_port(port, _port) + self.full_uri = "{}://{}:{}".format(self.proto, self.host, self.port) + self.base_dn = base_dn + + if encryption: + self._tls = True + else: + if self.proto == "ldaps": + raise ValueError( + 'Incompatible provided protocol ("%s") and encryption configuration: TLS=%s', + self.proto, + encryption, + ) + self._tls = False + + super().__init__(self.full_uri, self._tls) + log.info( + "Connected at %s (TLS -> %s)", self.full_uri, "ON" if self.tls else "OFF" + ) + + if not validate: + self.set_cert_policy("never") + + if ca_cert is not None: + self.set_ca_cert(ca_cert) + + self.username = compose_dn_username(username, self.base_dn, ou, attribute_id) + self.password = password + self.method = method + + self.set_auto_page_acquire(True) + self.set_credentials(self.method, user=self.username, password=self.password) diff --git a/test/test_ldap_async_client.py b/test/test_ldap_async_client.py new file mode 100644 index 0000000..e903daf --- /dev/null +++ b/test/test_ldap_async_client.py @@ -0,0 +1,118 @@ +# -*- encoding: utf-8 -*- + +from contextlib import contextmanager +import logging + +import mock +import pytest + +from phi.ldap.async_client import ( + parse_host, + checked_port, + compose_dn_username, + AsyncClient, +) + + +BASE_DN = "dc=unit,dc=macaomilano,dc=org" + + +@contextmanager +def does_not_raise(): + yield + + +@pytest.mark.parametrize( + "test_url, exp_proto, exp_addr, exp_port", + [ + ("1.3.1.2", "ldap", "1.3.1.2", 389), + ("ldap://localhost:1312", "ldap", "localhost", 1312), + ("localhost:1312", "ldap", "localhost", 1312), + ("localhost", "ldap", "localhost", 389), + ("ldap://localhost", "ldap", "localhost", 389), + ("ldaps://localhost", "ldaps", "localhost", 636), + ("ldaps://localhost:1312", "ldaps", "localhost", 1312), + ], +) +def test_parse_host(test_url, exp_proto, exp_addr, exp_port): + proto, addr, port = parse_host(test_url) + + assert proto == exp_proto + assert addr == exp_addr + assert port == exp_port + + +@pytest.mark.parametrize( + "manual, auto, exp_port", [(None, 389, 389), (1312, 389, 1312), (1312, 1312, 1312)] +) +def test_checked_port(manual, auto, exp_port, caplog): + port = checked_port(manual, auto) + if manual and manual != auto: + with caplog.at_level(logging.WARNING): + "The former prevails" in caplog.text + + assert port == exp_port + + +@pytest.mark.parametrize( + "username, base_dn, ou, attribute_id, exp_dn", + [ + ( + f"uid=conte_mascetti,{BASE_DN}", + BASE_DN, + None, + "uid", + f"uid=conte_mascetti,{BASE_DN}", + ), + ("root", BASE_DN, None, "cn", f"cn=root,{BASE_DN}"), + ("necchi", BASE_DN, "Hackers", "uid", f"uid=necchi,ou=Hackers,{BASE_DN}"), + ("perozzi", BASE_DN, "Phrackers", "cn", f"cn=perozzi,ou=Phrackers,{BASE_DN}"), + ], +) +def test_compose_dn_username(username, base_dn, ou, attribute_id, exp_dn): + dn = compose_dn_username(username, base_dn, ou, attribute_id) + + assert dn == exp_dn + + +@pytest.mark.parametrize( + "url, encryption, validate, ca_cert, expectation", + [ + ("localhost", None, False, None, does_not_raise()), + ("localhost", True, False, None, does_not_raise()), + ("localhost", False, True, None, does_not_raise()), + ("localhost", True, True, "path/to/cert.pem", does_not_raise()), + ("ldaps://localhost", False, False, None, pytest.raises(ValueError)), + ], +) +def test_AsyncClient_init(url, encryption, validate, ca_cert, expectation): + with expectation as exp: + cl = AsyncClient( + host=url, + port=389, + encryption=encryption, + ciphers=None, + validate=validate, + ca_cert=ca_cert, + username="conte_mascetti", + password="pass", + base_dn=BASE_DN, + ou="Hackers", + ) + + if exp is not None: + assert "Incompatible provided protocol" in str(exp.value) + return + + assert cl.base_dn == BASE_DN + assert url in cl.full_uri + assert "389" in cl.full_uri + assert cl._tls if encryption else not cl._tls + if validate: + assert cl.cert_policy == -1 + else: + assert cl.cert_policy == 0 + if ca_cert: + assert cl.ca_cert == ca_cert + else: + assert cl.ca_cert == ""