Adding AsyncClient.

This commit is contained in:
sfigato 2019-04-28 15:33:53 +02:00
parent 27fc927254
commit b1837f80e4
Signed by: blallo
GPG Key ID: 0CBE577C9B72DC3F
2 changed files with 250 additions and 0 deletions

View File

@ -0,0 +1,132 @@
from urllib.parse import urlparse
from bonsai import LDAPClient
from phi.logging import get_logger
log = get_logger(__name__)
def parse_host(host):
"""
Helper function to decompose the host in the address
and the (optional) protocol and the (optional) port.
If missing, protocol defaults to "ldap" and port to 389,
in case protocol is missing or is "ldap", or 636, in case
protocol is "ldaps".
"""
if "://" not in host:
host = f"//{host}"
p = urlparse(host)
if p.scheme is not None and p.scheme != "":
proto = p.scheme
else:
proto = "ldap"
if p.port is not None:
port = p.port
else:
port = None
if port is not None:
addr = p.netloc.split(":")[0]
else:
addr = p.netloc
if proto == "ldap":
port = 389
elif proto == "ldaps":
port = 636
return proto, addr, port
def checked_port(provided, auto):
"""
Check consistency of ports given via the connection string
and the explicit parameter.
"""
_provided = provided is not None
if _provided and provided != auto:
log.warning(
"Explicitly provided port ({}) does not match "
"the automatically provided one ({}). The former prevails.".format(
provided, auto
)
)
return provided
if _provided:
return provided
return auto
def compose_dn_username(username, base_dn, ou=None, attribute_id=None):
"""
Output the distinguished name of the user to use as login.
"""
if base_dn in username:
return username
if ou is None:
return f"{attribute_id}={username},{base_dn}"
return f"{attribute_id}={username},ou={ou},{base_dn}"
class AsyncClient(LDAPClient):
"""
Wrapper around LDAPClient.
"""
def __init__(
self,
host=None,
port=None,
encryption=None,
ciphers=None,
validate=False,
ca_cert=None,
username=None,
password=None,
base_dn=None,
attribute_id="uid",
ou=None,
method="SIMPLE",
):
self.proto, self.host, _port = parse_host(host)
self.port = checked_port(port, _port)
self.full_uri = "{}://{}:{}".format(self.proto, self.host, self.port)
self.base_dn = base_dn
if encryption:
self._tls = True
else:
if self.proto == "ldaps":
raise ValueError(
'Incompatible provided protocol ("%s") and encryption configuration: TLS=%s',
self.proto,
encryption,
)
self._tls = False
super().__init__(self.full_uri, self._tls)
log.info(
"Connected at %s (TLS -> %s)", self.full_uri, "ON" if self.tls else "OFF"
)
if not validate:
self.set_cert_policy("never")
if ca_cert is not None:
self.set_ca_cert(ca_cert)
self.username = compose_dn_username(username, self.base_dn, ou, attribute_id)
self.password = password
self.method = method
self.set_auto_page_acquire(True)
self.set_credentials(self.method, user=self.username, password=self.password)

View File

@ -0,0 +1,118 @@
# -*- encoding: utf-8 -*-
from contextlib import contextmanager
import logging
import mock
import pytest
from phi.ldap.async_client import (
parse_host,
checked_port,
compose_dn_username,
AsyncClient,
)
BASE_DN = "dc=unit,dc=macaomilano,dc=org"
@contextmanager
def does_not_raise():
yield
@pytest.mark.parametrize(
"test_url, exp_proto, exp_addr, exp_port",
[
("1.3.1.2", "ldap", "1.3.1.2", 389),
("ldap://localhost:1312", "ldap", "localhost", 1312),
("localhost:1312", "ldap", "localhost", 1312),
("localhost", "ldap", "localhost", 389),
("ldap://localhost", "ldap", "localhost", 389),
("ldaps://localhost", "ldaps", "localhost", 636),
("ldaps://localhost:1312", "ldaps", "localhost", 1312),
],
)
def test_parse_host(test_url, exp_proto, exp_addr, exp_port):
proto, addr, port = parse_host(test_url)
assert proto == exp_proto
assert addr == exp_addr
assert port == exp_port
@pytest.mark.parametrize(
"manual, auto, exp_port", [(None, 389, 389), (1312, 389, 1312), (1312, 1312, 1312)]
)
def test_checked_port(manual, auto, exp_port, caplog):
port = checked_port(manual, auto)
if manual and manual != auto:
with caplog.at_level(logging.WARNING):
"The former prevails" in caplog.text
assert port == exp_port
@pytest.mark.parametrize(
"username, base_dn, ou, attribute_id, exp_dn",
[
(
f"uid=conte_mascetti,{BASE_DN}",
BASE_DN,
None,
"uid",
f"uid=conte_mascetti,{BASE_DN}",
),
("root", BASE_DN, None, "cn", f"cn=root,{BASE_DN}"),
("necchi", BASE_DN, "Hackers", "uid", f"uid=necchi,ou=Hackers,{BASE_DN}"),
("perozzi", BASE_DN, "Phrackers", "cn", f"cn=perozzi,ou=Phrackers,{BASE_DN}"),
],
)
def test_compose_dn_username(username, base_dn, ou, attribute_id, exp_dn):
dn = compose_dn_username(username, base_dn, ou, attribute_id)
assert dn == exp_dn
@pytest.mark.parametrize(
"url, encryption, validate, ca_cert, expectation",
[
("localhost", None, False, None, does_not_raise()),
("localhost", True, False, None, does_not_raise()),
("localhost", False, True, None, does_not_raise()),
("localhost", True, True, "path/to/cert.pem", does_not_raise()),
("ldaps://localhost", False, False, None, pytest.raises(ValueError)),
],
)
def test_AsyncClient_init(url, encryption, validate, ca_cert, expectation):
with expectation as exp:
cl = AsyncClient(
host=url,
port=389,
encryption=encryption,
ciphers=None,
validate=validate,
ca_cert=ca_cert,
username="conte_mascetti",
password="pass",
base_dn=BASE_DN,
ou="Hackers",
)
if exp is not None:
assert "Incompatible provided protocol" in str(exp.value)
return
assert cl.base_dn == BASE_DN
assert url in cl.full_uri
assert "389" in cl.full_uri
assert cl._tls if encryption else not cl._tls
if validate:
assert cl.cert_policy == -1
else:
assert cl.cert_policy == 0
if ca_cert:
assert cl.ca_cert == ca_cert
else:
assert cl.ca_cert == ""