# -*- encoding: utf-8 -*-
"""Tests for ``get_dn`` and ``User`` from ``phi.async_ldap.new_model``."""
import asyncio

from async_generator import asynccontextmanager
from bonsai import LDAPEntry
import mock
import pytest

from phi.async_ldap.new_model import get_dn, User

BASE_DN = "dc=test,dc=domain,dc=tld"

# DN the tests below expect for the user "test_user" under BASE_DN.
TEST_USER_DN = f"uid=test_user,ou=Hackers,{BASE_DN}"


class MockClient:
    """Stand-in for the client object handed to ``User`` in these tests.

    It exposes ``base_dn`` and an async ``connect()`` context manager that
    yields a ``MagicMock`` connection, and records whether ``connect()`` was
    entered and which connection operations were invoked.

    Args:
        return_value: keyword-only (read from ``kwargs``); the value every
            mocked connection operation resolves to when awaited. Defaults
            to ``None``. Any other arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        self.return_value = kwargs.get("return_value")
        self.connect_called = False
        self.conn = mock.MagicMock()

    def connect_called_with_search(self):
        """Assert that ``search`` was called on the mocked connection."""
        self.conn.search.assert_called()

    def connect_called_with_add(self):
        """Assert that ``add`` was called on the mocked connection."""
        self.conn.add.assert_called()

    def connect_called_with_modify(self):
        """Assert that ``modify`` was called on the mocked connection."""
        self.conn.modify.assert_called()

    def connect_called_with_delete(self):
        """Assert that ``delete`` was called on the mocked connection."""
        self.conn.delete.assert_called()

    @property
    def base_dn(self):
        """Return the module-level ``BASE_DN``."""
        return BASE_DN

    @asynccontextmanager
    async def connect(self, *args, **kwargs):
        """Yield the mocked connection and flag that a connection was opened.

        All arguments are accepted and ignored.
        """
        self.connect_called = True

        async def _reply(*_args, **_kwargs):
            # Read at call time, so a test may change ``return_value`` after
            # the context has been entered.
            return self.return_value

        # Each operation gets its own fresh MagicMock whose call returns an
        # awaitable. NOTE: the mocks are replaced on every ``connect()``, so
        # call history from a previous context is discarded.
        for operation in ("search", "add", "modify", "delete"):
            setattr(self.conn, operation, mock.MagicMock(side_effect=_reply))

        yield self.conn


# The parametrize arguments below are evaluated at import time, so the client
# used to build the ``User`` case has to exist at module level.
cl = mock.MagicMock()
cl.base_dn = BASE_DN


@pytest.mark.parametrize(
    "input_obj, expected_result",
    [
        (User(cl, "test_user"), TEST_USER_DN),
        (LDAPEntry(TEST_USER_DN), TEST_USER_DN),
        (TEST_USER_DN, TEST_USER_DN),
    ],
)
def test_get_dn(input_obj, expected_result):
    """``get_dn`` yields the same DN for a ``User``, an ``LDAPEntry`` and a str."""
    assert get_dn(input_obj) == expected_result


def test_get_dn_raises():
    """``get_dn`` rejects an unsupported input with a ``ValueError``."""
    with pytest.raises(ValueError) as e:
        get_dn(object)
    # Checked after the ``with`` block: statements following the raising call
    # inside the block would never run.
    assert "Unacceptable input:" in str(e.value)


@pytest.mark.asyncio
async def test_User_add():
    """Saving a ``User`` built from a client and uid connects and calls ``add``."""
    _cl = MockClient(return_value=None)
    u = User(_cl, "test_user")
    assert u.dn == TEST_USER_DN

    await u.save()

    assert _cl.connect_called
    _cl.connect_called_with_add()