Adding AsyncClient.
This commit is contained in:
parent
27fc927254
commit
b1837f80e4
132
src/phi/ldap/async_client.py
Normal file
132
src/phi/ldap/async_client.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from bonsai import LDAPClient
|
||||
|
||||
from phi.logging import get_logger
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_host(host):
|
||||
"""
|
||||
Helper function to decompose the host in the address
|
||||
and the (optional) protocol and the (optional) port.
|
||||
If missing, protocol defaults to "ldap" and port to 389,
|
||||
in case protocol is missing or is "ldap", or 636, in case
|
||||
protocol is "ldaps".
|
||||
"""
|
||||
if "://" not in host:
|
||||
host = f"//{host}"
|
||||
|
||||
p = urlparse(host)
|
||||
if p.scheme is not None and p.scheme != "":
|
||||
proto = p.scheme
|
||||
else:
|
||||
proto = "ldap"
|
||||
|
||||
if p.port is not None:
|
||||
port = p.port
|
||||
else:
|
||||
port = None
|
||||
|
||||
if port is not None:
|
||||
addr = p.netloc.split(":")[0]
|
||||
else:
|
||||
addr = p.netloc
|
||||
if proto == "ldap":
|
||||
port = 389
|
||||
elif proto == "ldaps":
|
||||
port = 636
|
||||
|
||||
return proto, addr, port
|
||||
|
||||
|
||||
def checked_port(provided, auto):
|
||||
"""
|
||||
Check consistency of ports given via the connection string
|
||||
and the explicit parameter.
|
||||
"""
|
||||
_provided = provided is not None
|
||||
|
||||
if _provided and provided != auto:
|
||||
log.warning(
|
||||
"Explicitly provided port ({}) does not match "
|
||||
"the automatically provided one ({}). The former prevails.".format(
|
||||
provided, auto
|
||||
)
|
||||
)
|
||||
return provided
|
||||
|
||||
if _provided:
|
||||
return provided
|
||||
|
||||
return auto
|
||||
|
||||
|
||||
def compose_dn_username(username, base_dn, ou=None, attribute_id=None):
|
||||
"""
|
||||
Output the distinguished name of the user to use as login.
|
||||
"""
|
||||
if base_dn in username:
|
||||
return username
|
||||
|
||||
if ou is None:
|
||||
return f"{attribute_id}={username},{base_dn}"
|
||||
|
||||
return f"{attribute_id}={username},ou={ou},{base_dn}"
|
||||
|
||||
|
||||
class AsyncClient(LDAPClient):
|
||||
"""
|
||||
Wrapper around LDAPClient.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host=None,
|
||||
port=None,
|
||||
encryption=None,
|
||||
ciphers=None,
|
||||
validate=False,
|
||||
ca_cert=None,
|
||||
username=None,
|
||||
password=None,
|
||||
base_dn=None,
|
||||
attribute_id="uid",
|
||||
ou=None,
|
||||
method="SIMPLE",
|
||||
):
|
||||
self.proto, self.host, _port = parse_host(host)
|
||||
self.port = checked_port(port, _port)
|
||||
self.full_uri = "{}://{}:{}".format(self.proto, self.host, self.port)
|
||||
self.base_dn = base_dn
|
||||
|
||||
if encryption:
|
||||
self._tls = True
|
||||
else:
|
||||
if self.proto == "ldaps":
|
||||
raise ValueError(
|
||||
'Incompatible provided protocol ("%s") and encryption configuration: TLS=%s',
|
||||
self.proto,
|
||||
encryption,
|
||||
)
|
||||
self._tls = False
|
||||
|
||||
super().__init__(self.full_uri, self._tls)
|
||||
log.info(
|
||||
"Connected at %s (TLS -> %s)", self.full_uri, "ON" if self.tls else "OFF"
|
||||
)
|
||||
|
||||
if not validate:
|
||||
self.set_cert_policy("never")
|
||||
|
||||
if ca_cert is not None:
|
||||
self.set_ca_cert(ca_cert)
|
||||
|
||||
self.username = compose_dn_username(username, self.base_dn, ou, attribute_id)
|
||||
self.password = password
|
||||
self.method = method
|
||||
|
||||
self.set_auto_page_acquire(True)
|
||||
self.set_credentials(self.method, user=self.username, password=self.password)
|
118
test/test_ldap_async_client.py
Normal file
118
test/test_ldap_async_client.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from phi.ldap.async_client import (
|
||||
parse_host,
|
||||
checked_port,
|
||||
compose_dn_username,
|
||||
AsyncClient,
|
||||
)
|
||||
|
||||
|
||||
BASE_DN = "dc=unit,dc=macaomilano,dc=org"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def does_not_raise():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_url, exp_proto, exp_addr, exp_port",
|
||||
[
|
||||
("1.3.1.2", "ldap", "1.3.1.2", 389),
|
||||
("ldap://localhost:1312", "ldap", "localhost", 1312),
|
||||
("localhost:1312", "ldap", "localhost", 1312),
|
||||
("localhost", "ldap", "localhost", 389),
|
||||
("ldap://localhost", "ldap", "localhost", 389),
|
||||
("ldaps://localhost", "ldaps", "localhost", 636),
|
||||
("ldaps://localhost:1312", "ldaps", "localhost", 1312),
|
||||
],
|
||||
)
|
||||
def test_parse_host(test_url, exp_proto, exp_addr, exp_port):
|
||||
proto, addr, port = parse_host(test_url)
|
||||
|
||||
assert proto == exp_proto
|
||||
assert addr == exp_addr
|
||||
assert port == exp_port
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"manual, auto, exp_port", [(None, 389, 389), (1312, 389, 1312), (1312, 1312, 1312)]
|
||||
)
|
||||
def test_checked_port(manual, auto, exp_port, caplog):
|
||||
port = checked_port(manual, auto)
|
||||
if manual and manual != auto:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
"The former prevails" in caplog.text
|
||||
|
||||
assert port == exp_port
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username, base_dn, ou, attribute_id, exp_dn",
|
||||
[
|
||||
(
|
||||
f"uid=conte_mascetti,{BASE_DN}",
|
||||
BASE_DN,
|
||||
None,
|
||||
"uid",
|
||||
f"uid=conte_mascetti,{BASE_DN}",
|
||||
),
|
||||
("root", BASE_DN, None, "cn", f"cn=root,{BASE_DN}"),
|
||||
("necchi", BASE_DN, "Hackers", "uid", f"uid=necchi,ou=Hackers,{BASE_DN}"),
|
||||
("perozzi", BASE_DN, "Phrackers", "cn", f"cn=perozzi,ou=Phrackers,{BASE_DN}"),
|
||||
],
|
||||
)
|
||||
def test_compose_dn_username(username, base_dn, ou, attribute_id, exp_dn):
|
||||
dn = compose_dn_username(username, base_dn, ou, attribute_id)
|
||||
|
||||
assert dn == exp_dn
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url, encryption, validate, ca_cert, expectation",
|
||||
[
|
||||
("localhost", None, False, None, does_not_raise()),
|
||||
("localhost", True, False, None, does_not_raise()),
|
||||
("localhost", False, True, None, does_not_raise()),
|
||||
("localhost", True, True, "path/to/cert.pem", does_not_raise()),
|
||||
("ldaps://localhost", False, False, None, pytest.raises(ValueError)),
|
||||
],
|
||||
)
|
||||
def test_AsyncClient_init(url, encryption, validate, ca_cert, expectation):
|
||||
with expectation as exp:
|
||||
cl = AsyncClient(
|
||||
host=url,
|
||||
port=389,
|
||||
encryption=encryption,
|
||||
ciphers=None,
|
||||
validate=validate,
|
||||
ca_cert=ca_cert,
|
||||
username="conte_mascetti",
|
||||
password="pass",
|
||||
base_dn=BASE_DN,
|
||||
ou="Hackers",
|
||||
)
|
||||
|
||||
if exp is not None:
|
||||
assert "Incompatible provided protocol" in str(exp.value)
|
||||
return
|
||||
|
||||
assert cl.base_dn == BASE_DN
|
||||
assert url in cl.full_uri
|
||||
assert "389" in cl.full_uri
|
||||
assert cl._tls if encryption else not cl._tls
|
||||
if validate:
|
||||
assert cl.cert_policy == -1
|
||||
else:
|
||||
assert cl.cert_policy == 0
|
||||
if ca_cert:
|
||||
assert cl.ca_cert == ca_cert
|
||||
else:
|
||||
assert cl.ca_cert == ""
|
Loading…
Reference in New Issue
Block a user