diff --git a/async_tests/test_async_ldap_new_model.py b/async_tests/test_async_ldap_new_model.py index 778b0f7..d4da324 100644 --- a/async_tests/test_async_ldap_new_model.py +++ b/async_tests/test_async_ldap_new_model.py @@ -8,6 +8,7 @@ import pytest from phi.async_ldap.new_model import get_dn, User from phi.async_ldap.mixins import Member +from phi.exceptions import PhiCannotExecute BASE_DN = "dc=test,dc=domain,dc=tld" @@ -87,6 +88,34 @@ def test_get_dn_raises(): assert "Unacceptable input:" in str(e.value) +def test_repr(): + _cl = MockClient(return_value=None) + u = User(_cl, "test_user") + + assert repr(u) == f"" + + +def test_str(): + _cl = MockClient(return_value=None) + u = User(_cl, "test_user") + + assert str(u) == f"" + + +@pytest.mark.parametrize( + "input_obj", + [ + User(cl, "test_user"), + LDAPEntry(f"uid=test_user,ou=Hackers,{BASE_DN}"), + f"uid=test_user,ou=Hackers,{BASE_DN}", + ], +) +def test_eq(input_obj): + u = User(cl, "test_user") + + assert u == input_obj + + @pytest.mark.asyncio async def test_User_add(): _cl = MockClient(return_value=None) @@ -141,3 +170,27 @@ async def test_User_delete(): assert _cl.connect_called assert await _cl.connect_called_with_delete() + + +@pytest.mark.asyncio +async def test_User_get_invalid_attr_raises(): + _cl = MockClient(return_value=None) + u = User(_cl, "test_user") + + with pytest.raises(PhiCannotExecute) as ex: + _ = u["iDoNotExist"] + + assert "iDoNotExist" in str(ex.value) + assert "is not an allowed ldap attribute" in str(ex.value) + + +@pytest.mark.asyncio +async def test_User_set_invalid_attr_raises(): + _cl = MockClient(return_value=None) + u = User(_cl, "test_user") + + with pytest.raises(PhiCannotExecute) as ex: + u["iDoNotExist"] = "hello" + + assert "iDoNotExist" in str(ex.value) + assert "is not an allowed ldap attribute" in str(ex.value) diff --git a/integration_tests/test_model.py b/integration_tests/test_model.py index 5868430..20ae496 100644 --- a/integration_tests/test_model.py +++ b/integration_tests/test_model.py @@ -61,6 +61,27 @@ async def init_achilles(): return u +async def init_patroclus(): + u = User(cl, "patroclus") + u["cn"] = "Patroclus" + u["sn"] = "patroclus" + u["mail"] = "patroclus@phthia.gr" + u["userPassword"] = "WannabeAnHero" + + await u.save() + + return u + + +async def init_athena(): + s = Service(cl, "athena") + s["userPassword"] = "ἁ θεονόα" + + await s.save() + + return s + + async def init_group(group_name, members): g = Group(cl, group_name, member=members) @@ -82,6 +103,49 @@ async def test_User_init(): assert u == res +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_User_exists(): + async with clean_db(): + u1 = await init_achilles() + u2 = User(cl, "enea") + + assert await u1.exists() + assert not await u2.exists() + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_User_double_save_raises(): + async with clean_db(): + u = await init_achilles() + # Read all the data from the db + await u.sync() + + with pytest.raises(e.PhiEntryExists) as ex: + await u.save() + + assert u.dn in str(ex.value) + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_User_describe(): + async with clean_db(): + u = await init_achilles() + + res = await u.describe() + + assert res == { + "ou": "Hackers", + "uid": "achilles", + "cn": "Achilles", + "sn": "achilles", + "dn": f"uid=achilles,ou=Hackers,{BASE_DN}", + "mail": "achilles@phthia.gr", + } + + @pytest.mark.asyncio @pytest.mark.integration_test async def test_User_modify(): @@ -101,6 +165,19 @@ async def test_User_modify(): assert u[attr] == res[attr] +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_User_modify_raises(): + """Modifying a not-yet-existing user raises.""" + async with clean_db(): + u = User(cl, "enea") + + with pytest.raises(e.PhiEntryDoesNotExist) as ex: + await u.modify() + + assert u.dn in str(ex.value) + + @pytest.mark.asyncio @pytest.mark.integration_test async def test_User_delete(): @@ -116,6 +193,34 @@ async def test_User_delete(): assert u.dn in str(ex.value) +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_Service_init(): + async with clean_db(): + s = await init_athena() + + r = Robots(cl) + + res = await r.search("athena") + + assert s == res + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_Service_describe(): + async with clean_db(): + s = await init_athena() + + res = await s.describe() + + assert res == { + "ou": "Robots", + "uid": "athena", + "dn": f"uid=athena,ou=Robots,{BASE_DN}", + } + + @pytest.mark.asyncio @pytest.mark.integration_test async def test_Group_init(): @@ -128,4 +233,130 @@ async def test_Group_init(): res = await c.search("achaeans") assert g == res - assert [u] == [a for a in g.get_members()] + assert [u] == [a async for a in g.get_members()] + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_Group_describe(): + async with clean_db(): + u1 = await init_achilles() + u2 = await init_patroclus() + g = await init_group("achaeans", [u1, u2]) + + res = await g.describe() + + assert res == { + "ou": "Congregations", + "cn": "achaeans", + "dn": f"cn=achaeans,ou=Congregations,{BASE_DN}", + "member": [u1.dn, u2.dn], + } + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_Group_add_member(): + async with clean_db(): + u = await init_achilles() + a = await init_athena() + g1 = await init_group("achaeans", [u]) + g2 = await init_group("gods", [u]) + + await g2.add_member(a) + + m1 = [m async for m in g1.get_members()] + m2 = [m async for m in g2.get_members()] + + assert u in m1 + assert u in m2 + assert a not in m1 + assert a in m2 + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_Group_remove_member(): + async with clean_db(): + u = await init_achilles() + a = await init_athena() + g = await init_group("achaeans", [u, a]) + + m = [a async for a in g.get_members()] + + assert u in m + assert a in m + + await g.remove_member(a) + + assert [u] == [el async for el in g.get_members()] + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_User_groups(): + async with clean_db(): + u = await init_achilles() + a = await init_athena() + g1 = await init_group("achaeans", [u]) + g2 = await init_group("gods", [u, a]) + + res1 = await u.groups() + res2 = await a.groups() + + assert g1 in res1 + assert g2 in res1 + assert g1 not in res2 + assert g2 in res2 + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_OU_delete_raises(): + async with clean_db(): + u1 = await init_achilles() + u2 = await init_patroclus() + a = await init_athena() + g1 = await init_group("achaeans", [u1, u2]) + g2 = await init_group("gods", [a, u1]) + h = Hackers(cl) + _saved_val = h.delete_cascade + h.delete_cascade = False + + assert not h.delete_cascade + + with pytest.raises(e.PhiCannotExecute) as ex: + await h.delete() + + assert "delete_cascade is not set" in str(ex.value) + + h.delete_cascade = _saved_val + + +@pytest.mark.asyncio +@pytest.mark.integration_test +async def test_OU_delete_cascade(): + async with clean_db(): + u1 = await init_achilles() + u2 = await init_patroclus() + a = await init_athena() + g1 = await init_group("achaeans", [u1, u2]) + g2 = await init_group("gods", [a, u1]) + h = Hackers(cl) + _saved_val = h.delete_cascade + h.delete_cascade = True + + assert h.delete_cascade + + await h.delete() + g2_members = [e async for e in g2.get_members()] + h_members = [e async for e in h] + + assert not await u1.exists() + assert not await u2.exists() + assert h_members == [] + assert not await g1.exists() + assert u1 not in g2_members + assert a in g2_members + + h.delete_cascade = _saved_val diff --git a/src/phi/async_ldap/mixins.py b/src/phi/async_ldap/mixins.py index 3194ee5..0ae52aa 100644 --- a/src/phi/async_ldap/mixins.py +++ b/src/phi/async_ldap/mixins.py @@ -15,6 +15,15 @@ from phi.exceptions import ( from phi.security import hash_pass, handle_password +def de_listify(elems): + if not isinstance(elems, list): + return elems + if len(elems) == 1: + return elems[0] + else: + return elems + + async def build_heritage(obj, child_class, attribute_id="uid"): """ Given the object and the child class, yields the @@ -44,40 +53,17 @@ class Singleton(object): return cls._instances[name] -def get_value(obj, attr): - """ - Return the tuple (attribute_name, attribute_value) from obj. Extract the value, - either it being a constant or the result of a function call. - """ - if not isinstance(getattr(type(obj), attr), property) and callable( - getattr(type(obj), attr) - ): - return attr, getattr(obj, attr)() - else: - return attr, getattr(obj, attr) - - class Entry(object): """ Mixin to interact with LDAP. """ - def get_all_ldap_attributes(self): - return [get_value(self, attr) for attr in self.ldap_attributes] - def __repr__(self): return f"<{self.__class__.__name__} {self.dn}>" def __str__(self): return f"<{self.__class__.__name__} {self.dn}>" - def __iter__(self): - for k, v in self.get_all_ldap_attributes(): - yield k, v - - def __dict__(self): - return dict(self) - async def _create_new(self): self._entry["objectClass"] = self.object_class async with self.client.connect(is_async=True) as conn: @@ -87,10 +73,10 @@ class Entry(object): async def _get(self): async with self.client.connect(is_async=True) as conn: # This returns a list of dicts. It should always contain only one item: - # the one we are interested. + # the one we are interested in. _res = await conn.search(self.dn, 0) if len(_res) == 0: - raise PhiEntryDoesNotExist() + raise PhiEntryDoesNotExist(self.dn) elif len(_res) > 1: raise PhiUnexpectedRuntimeValue( "return value should be no more than one", res @@ -105,10 +91,15 @@ class Entry(object): async def _delete(self): async with self.client.connect(is_async=True) as conn: - await conn.delete(self.dn) + await conn.delete(self.dn, recursive=self.delete_cascade) async def describe(self): - return dict(await self._get()) + _internal = await self._get() + values = dict((k, de_listify(_internal[k])) for k in self.ldap_attributes) + if "userPassword" in self.ldap_attributes: + values.pop("userPassword") + values["dn"] = self.dn + return values @property def delete_cascade(self): @@ -140,7 +131,10 @@ class OrganizationalUnit(object): self.children = build_heritage(self, self.child_class, self.child_class.id_tag) self._entry = LDAPEntry(self.dn) for k, v in kwargs.items(): - self._entry[k] = v + if k in self.ldap_attributes: + self._entry[k] = v + if "delete_cascade" in kwargs: + self.delete_cascade = delete_cascade def __aiter__(self): return self @@ -164,12 +158,21 @@ class OrganizationalUnit(object): return f"ou={self.ou},{self.base_dn}" async def save(self): - try: - await self._create_new() - except bonsai.errors.AlreadyExists: - raise PhiEntryExists(self.dn) + """ + This function iterates over the OU's children and invokes its `save` method, + ignoring errors from yet existing ones. + """ + async for child in self: + try: + await child.save() + except PhiEntryExists: + pass async def search(self, member_name): + """ + This function allows one to search through the OU's children. The search + function is the one from the underlying library (bonsai) and is strict as such. + """ result = None async with self.client.connect(is_async=True) as conn: result = await conn.search( @@ -226,9 +229,31 @@ class Member(object): It provides the methods to interact with the LDAP db. To properly use, `ou`, `object_class` and `ldap_attributes` class attributes must be specified when inheriting. + + ## Usage + + The initialization needs an `phi.async_ldap.AsyncClient` and a `name`, that is used + as value in the identification attribute (i.e. `uid`). + This inits an object in memory that may or may not exist in the ldap database yet. + To test it, one can invoke the async method `exists` or may try to `sync`, handling + the corresponding exception (`PhiEntryDoesNotExist`). + To save a new instance, one can `save`. The instance accepts dict-like get and set + on the aforementioned `ldap_attributes`. Once an attribute value has been modified, + one can invoke `modify` to persist the changes. + To remove an instance from the database, one can invoke `delete`. + + ## Comparisons + + The comparison operation with a `Member` is quite loose: it returns `True` with + either: + - an instance of the same `type` (i.e. the same class whose this mixin is used + into) whose `dn` matches + - an `LDAPEntry` whose `dn` matches + - a string matching the `dn` """ def __init__(self, client, name, **kwargs): + super().__init__() self.client = client self.base_dn = client.base_dn self.name = name @@ -244,7 +269,7 @@ class Member(object): elif isinstance(other, str): return other == self.dn elif isinstance(other, LDAPEntry): - return str(other) == self.dn + return other["dn"] == self.dn else: return False @@ -255,30 +280,61 @@ class Member(object): def __setitem__(self, attr, val): if attr not in self.ldap_attributes: raise PhiCannotExecute( - f"{attr} is not an allowed ldap_attribute: {self.ldap_attributes}" + f"{attr} is not an allowed ldap attribute: {self.ldap_attributes}" ) self._entry[attr] = val def __getitem__(self, attr): if attr not in self.ldap_attributes: raise PhiCannotExecute( - f"{attr} is not an allowed ldap_attribute: {self.ldap_attributes}" + f"{attr} is not an allowed ldap attribute: {self.ldap_attributes}" ) return self._entry[attr][0] async def save(self): + """ + This method persists on the ldap database an inited instance. Raises + `PhiEntryExists` in case of a yet existing instance. Raises a specific error if + the instance misses any of the needed attributes (accoding to + `ldap_attributes`). + """ try: await self._create_new() except bonsai.errors.AlreadyExists: raise PhiEntryExists(self.dn) async def modify(self): + """ + This method saves the changes made to the instance on the ldap database. Raises + `PhiEntryDoesNotExist` in case of an instance not yet persisted. + """ await self._modify() async def delete(self): + """ + This method removes the instance from the database. Raises + `PhiEntryDoesNotExist` in case the entry does not exist. + """ await self._delete() async def sync(self): + """ + This method reads the `ldap_attributes` of an existing instance from the ldap + database and assigns the values to `self`. It is needed at first instantiation + of the object, in case an instance exists on the database. + """ res = await self._get() _hydrate(self, res) return self + + async def exists(self): + """ + This method returns `True` if the instance exists on the ldap database, `False` + if it does not. It might raise `PhiUnexpectedRuntimeValue` if the ldap state is + inconsistent. + """ + try: + _ = await self.sync() + return True + except PhiEntryDoesNotExist: + return False diff --git a/src/phi/async_ldap/new_model.py b/src/phi/async_ldap/new_model.py index 2459728..5e08e68 100644 --- a/src/phi/async_ldap/new_model.py +++ b/src/phi/async_ldap/new_model.py @@ -6,7 +6,22 @@ from multidict import MultiDict from phi.async_ldap import mixins -class User(mixins.Singleton, mixins.Entry, mixins.Member): +def parse_dn(dn): + return MultiDict(e.split("=") for e in dn.split(",")) + + +def get_dn(obj): + if isinstance(obj, mixins.Entry): + return obj.dn + elif isinstance(obj, LDAPEntry): + return obj["dn"] + elif isinstance(obj, str): + return obj + else: + raise ValueError(f"Unacceptable input: {obj}") + + +class User(mixins.Member, mixins.Entry, mixins.Singleton): object_class = [ "inetOrgPerson", "simpleSecurityObject", @@ -19,61 +34,104 @@ class User(mixins.Singleton, mixins.Entry, mixins.Member): ou = "Hackers" ldap_attributes = ["uid", "ou", "cn", "sn", "mail", "userPassword"] + async def iter_groups(self): # To be monkeypatched later + pass # pragma: no cover -class Hackers(mixins.Singleton, mixins.Entry, mixins.OrganizationalUnit): + async def groups(self): + return [g async for g in self.iter_groups()] + + async def delete(self): + async for group in self.iter_groups(): + await group.remove_member(self) + await super().delete() + + +class Hackers(mixins.OrganizationalUnit, mixins.Entry, mixins.Singleton): _instances = dict() # type: ignore ou = "Hackers" child_class = User -class Service(mixins.Singleton, mixins.Entry, mixins.Member): +class Service(mixins.Member, mixins.Entry, mixins.Singleton): object_class = ["simpleSecurityObject", "account", "top"] _instances = dict() # type: ignore id_tag = "uid" ou = "Robots" ldap_attributes = ["uid", "ou", "userPassword"] + async def iter_groups(self): # To be monkeypatched later + pass # pragma: no cover -class Robots(mixins.Singleton, mixins.Entry, mixins.OrganizationalUnit): + async def groups(self): + return [g async for g in self.iter_groups()] + + async def delete(self): + async for group in self.iter_groups(): + await group.remove_member(self) + await super().delete() + + +class Robots(mixins.OrganizationalUnit, mixins.Entry, mixins.Singleton): _instances = dict() # type: ignore ou = "Robots" child_class = Service -class Group(mixins.Singleton, mixins.Entry, mixins.Member): +class Group(mixins.Member, mixins.Entry, mixins.Singleton): object_class = ["groupOfNames", "top"] _instances = dict() # type: ignore id_tag = "cn" ou = "Congregations" ldap_attributes = ["cn", "ou", "member"] memeber_classes = {"Hackers": User, "Robots": Service} + empty = False async def add_member(self, member): - self._entry["member"].append(get_dn(member)) + member_dn = get_dn(member) + self._entry["member"].append(member_dn) await self.modify() async def remove_member(self, member): - self._entry["member"] = [get_dn(m) for m in self.get_members() if member != m] - await self.modify() + new_group_members = [get_dn(m) async for m in self.get_members() if member != m] + if len(new_group_members) == 0: + await self.delete() + self.empty = True + else: + self._entry["member"] = new_group_members + await self.modify() - def get_members(self): + async def get_members(self): + await self.sync() for member in self._entry.get("member", []): - dn = MultiDict(e.split("=") for e in member.split(",")) + dn = parse_dn(member) yield self.memeber_classes.get(dn["ou"])(self.client, dn["uid"]) -class Congregations(mixins.Singleton, mixins.Entry, mixins.OrganizationalUnit): +class Congregations(mixins.OrganizationalUnit, mixins.Entry, mixins.Singleton): _instances = dict() # type: ignore ou = "Congregations" child_class = Group -def get_dn(obj): - if isinstance(obj, mixins.Entry): - return obj.dn - elif isinstance(obj, LDAPEntry): - return obj["dn"] - elif isinstance(obj, str): - return obj - else: - raise ValueError(f"Unacceptable input: {obj}") +# We define this async method here **after** `User`, `Service` and `Group` have been +# defined, in order to avoid definition loops that would prevent the code from running. +# Indeed, this function explicitely uses `Group` but is needed as a `User` and `Service` +# method. In turn, `Group` definition relies on both `User` and `Service` being yet +# defined. +async def iter_groups(self): + async with self.client.connect(is_async=True) as conn: + res = await conn.search(f"{self.dn}", 2, attrlist=["memberOf"]) + if not res or len(res) == 0: + return + elif len(res) == 1: + for group in res[0].get("memberOf", []): + yield Group(self.client, parse_dn(get_dn(group))["cn"]) + else: + raise PhiUnexpectedRuntimeValue( + "return value should be no more than one", res + ) + + +# Monkeypatch +User.iter_groups = iter_groups # type: ignore +Service.iter_groups = iter_groups # type: ignore