mirror of
https://github.com/element-hq/synapse.git
synced 2025-03-14 09:45:51 +00:00
Merge ae2829d9a7
into 59a15da433
This commit is contained in:
commit
69d1224dae
6 changed files with 166 additions and 2 deletions
7
changelog.d/18156.feature
Normal file
7
changelog.d/18156.feature
Normal file
|
@ -0,0 +1,7 @@
|
|||
Add 'on_user_search' callback to the third_party_rules section of the [module API](https://matrix-org.github.io/synapse/latest/modules/writing_a_module.html).
|
||||
This allows to write modules that can trigger on search requests on the user directory and alter the results.
|
||||
Possible use cases:
|
||||
- Filter or group user search results
|
||||
- Exclude MxIDs from the user search results
|
||||
- Include MxIDs in the user search results even if they do not match the search criteria (e.g. a helper Bot)
|
||||
Contributed by @awesome-michael.
|
|
@ -315,6 +315,28 @@ identifier from an identity server (via a call to [`POST
|
|||
|
||||
If multiple modules implement this callback, Synapse runs them all in order.
|
||||
|
||||
### `on_user_search`
|
||||
|
||||
_First introduced in Synapse v1.xxx.0_
|
||||
|
||||
```python
|
||||
async def on_user_search(results: SearchResult) -> None:
|
||||
```
|
||||
|
||||
Called after a search in the user directory has been performed. The module is given
|
||||
the search results in the SearchResult data format.
|
||||
|
||||
Modules can modify the `results` (e.g. by adding the address of a chatbot by default, filtering
|
||||
the results for some criteria or grouping the results in a special manner. Be aware that altering
|
||||
the results structure can affect the compatibility with the matrix specification and matrix clients),
|
||||
or completely deny the user search by raising a `module_api.errors.SynapseError`.
|
||||
|
||||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns without raising an exception, Synapse falls through to the next one. The
|
||||
user search will be forbidden as soon as one of the callbacks raises an exception. If
|
||||
this happens, Synapse will not call any of the subsequent implementations of this
|
||||
callback.
|
||||
|
||||
## Example
|
||||
|
||||
The example below is a module that implements the third-party rules callback
|
||||
|
|
|
@ -118,6 +118,7 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
|||
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
|
||||
ON_THREEPID_BIND_CALLBACK,
|
||||
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
|
||||
ON_USER_SEARCH_CALLBACK,
|
||||
)
|
||||
from synapse.push.httppusher import HttpPusher
|
||||
from synapse.rest.client.login import LoginResponse
|
||||
|
@ -385,6 +386,7 @@ class ModuleApi:
|
|||
on_remove_user_third_party_identifier: Optional[
|
||||
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
|
||||
] = None,
|
||||
on_user_search: Optional[ON_USER_SEARCH_CALLBACK] = None,
|
||||
) -> None:
|
||||
"""Registers callbacks for third party event rules capabilities.
|
||||
|
||||
|
@ -403,6 +405,7 @@ class ModuleApi:
|
|||
on_threepid_bind=on_threepid_bind,
|
||||
on_add_user_third_party_identifier=on_add_user_third_party_identifier,
|
||||
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier,
|
||||
on_user_search=on_user_search,
|
||||
)
|
||||
|
||||
def register_presence_router_callbacks(
|
||||
|
|
|
@ -26,6 +26,7 @@ from twisted.internet.defer import CancelledError
|
|||
from synapse.api.errors import ModuleFailedException, SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import UnpersistedEventContextBase
|
||||
from synapse.storage.databases.main.user_directory import SearchResult
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import Requester, StateMap
|
||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||
|
@ -54,6 +55,7 @@ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Await
|
|||
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
|
||||
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
|
||||
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
|
||||
ON_USER_SEARCH_CALLBACK = Callable[[Requester, SearchResult], Awaitable]
|
||||
|
||||
|
||||
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
||||
|
@ -75,6 +77,7 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
|||
"on_create_room",
|
||||
"check_threepid_can_be_invited",
|
||||
"check_visibility_can_be_modified",
|
||||
"on_user_search",
|
||||
}
|
||||
|
||||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
||||
|
@ -185,6 +188,7 @@ class ThirdPartyEventRulesModuleApiCallbacks:
|
|||
self._on_remove_user_third_party_identifier_callbacks: List[
|
||||
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
|
||||
] = []
|
||||
self._on_user_search_callbacks: List[ON_USER_SEARCH_CALLBACK] = []
|
||||
|
||||
def register_third_party_rules_callbacks(
|
||||
self,
|
||||
|
@ -210,6 +214,7 @@ class ThirdPartyEventRulesModuleApiCallbacks:
|
|||
on_remove_user_third_party_identifier: Optional[
|
||||
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
|
||||
] = None,
|
||||
on_user_search: Optional[ON_USER_SEARCH_CALLBACK] = None,
|
||||
) -> None:
|
||||
"""Register callbacks from modules for each hook."""
|
||||
if check_event_allowed is not None:
|
||||
|
@ -257,6 +262,9 @@ class ThirdPartyEventRulesModuleApiCallbacks:
|
|||
on_remove_user_third_party_identifier
|
||||
)
|
||||
|
||||
if on_user_search is not None:
|
||||
self._on_user_search_callbacks.append(on_user_search)
|
||||
|
||||
async def check_event_allowed(
|
||||
self,
|
||||
event: EventBase,
|
||||
|
@ -597,3 +605,23 @@ class ThirdPartyEventRulesModuleApiCallbacks:
|
|||
logger.exception(
|
||||
"Failed to run module API callback %s: %s", callback, e
|
||||
)
|
||||
|
||||
async def on_user_search(self, requester: Requester, results: SearchResult) -> None:
|
||||
"""Intercept requests to search the user directory to maybe deny it (via an exception) or
|
||||
update the result list, e.g. with a filtered version.
|
||||
|
||||
Args:
|
||||
requester
|
||||
results: The results of the search in the user directory.
|
||||
"""
|
||||
for callback in self._on_user_search_callbacks:
|
||||
try:
|
||||
await callback(requester, results)
|
||||
except Exception as e:
|
||||
# Don't silence the errors raised by this callback since we expect it to
|
||||
# raise an exception to deny the search of users; instead make sure
|
||||
# it's a SynapseError we can send to clients.
|
||||
if not isinstance(e, SynapseError):
|
||||
e = SynapseError(403, "User search is forbidden")
|
||||
|
||||
raise e
|
||||
|
|
|
@ -45,6 +45,9 @@ class UserDirectorySearchRestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
self._third_party_event_rules = (
|
||||
hs.get_module_api_callbacks().third_party_event_rules
|
||||
)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
|
||||
"""Searches for users in directory
|
||||
|
@ -83,6 +86,10 @@ class UserDirectorySearchRestServlet(RestServlet):
|
|||
user_id, search_term, limit
|
||||
)
|
||||
|
||||
# Let the third party rules modify the result list if needed, or abort
|
||||
# the search entirely with an exception.
|
||||
await self._third_party_event_rules.on_user_search(requester, results)
|
||||
|
||||
return 200, results
|
||||
|
||||
|
||||
|
|
|
@ -33,13 +33,15 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
|||
load_legacy_third_party_event_rules,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import account, login, profile, room
|
||||
from synapse.rest.client import account, login, profile, room, user_directory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, Requester, StateMap
|
||||
from synapse.storage.databases.main.user_directory import SearchResult
|
||||
from synapse.types import JsonDict, Requester, StateMap, UserProfile
|
||||
from synapse.util import Clock
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.module_api import ModuleApi
|
||||
|
@ -100,6 +102,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||
room.register_servlets,
|
||||
profile.register_servlets,
|
||||
account.register_servlets,
|
||||
user_directory.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
|
@ -1074,3 +1077,97 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||
on_remove_user_third_party_identifier_callback_mock.assert_called_once()
|
||||
args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
|
||||
self.assertEqual(args, (user_id, "email", "foo@example.com"))
|
||||
|
||||
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
||||
def test_on_user_search(self) -> None:
|
||||
"""Tests that the on_user_search module callback is correctly called on
|
||||
searches in the user directory.
|
||||
"""
|
||||
|
||||
# Register a mock callback.
|
||||
m = AsyncMock(return_value=None)
|
||||
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
|
||||
m
|
||||
)
|
||||
|
||||
# make a search request
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/user_directory/search",
|
||||
{"search_term": "foo"},
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that the callback has been called once.
|
||||
m.assert_called_once()
|
||||
|
||||
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
||||
def test_on_user_search_modify_results(self) -> None:
|
||||
"""Tests that the on_user_search module callback is correctly returning
|
||||
the modified results list.
|
||||
"""
|
||||
|
||||
result_list = [
|
||||
UserProfile(
|
||||
user_id="@foo:bar.com",
|
||||
display_name="Foo",
|
||||
avatar_url="mxc://bar.com/foo",
|
||||
)
|
||||
]
|
||||
|
||||
# patch the search callback so that it will modify the search result
|
||||
async def search(
|
||||
requester: Requester,
|
||||
results: SearchResult,
|
||||
) -> None:
|
||||
results["results"] = result_list
|
||||
|
||||
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
|
||||
search
|
||||
)
|
||||
|
||||
# now send a search request
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/user_directory/search",
|
||||
{"limit": 10, "search_term": "foo"},
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# ... and check that it got modified
|
||||
results = channel.json_body["results"]
|
||||
self.assertEqual(results, result_list)
|
||||
|
||||
self.assertIn("results", channel.json_body)
|
||||
self.assertEqual(1, len(channel.json_body["results"]))
|
||||
|
||||
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
||||
def test_on_user_search_deny_request(self) -> None:
|
||||
"""Tests that the on_user_search module returns the SynapseError to the API
|
||||
when it is raised in the module.
|
||||
"""
|
||||
|
||||
# patch the search so that it will raise an exception
|
||||
async def search(requester: Requester, results: SearchResult) -> None:
|
||||
raise SynapseError(401, "Unauthorized user search", "M_UNAUTHORIZED")
|
||||
|
||||
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
|
||||
search
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/user_directory/search",
|
||||
{"search_term": "something"},
|
||||
access_token=self.tok,
|
||||
)
|
||||
|
||||
# Check the error code
|
||||
self.assertEqual(channel.code, 401, channel.json_body)
|
||||
# Check the JSON body has the correct error message
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{"errcode": "M_UNAUTHORIZED", "error": "Unauthorized user search"},
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue