This commit is contained in:
Michael Albert 2025-03-13 23:51:47 +01:00 committed by GitHub
commit 69d1224dae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 166 additions and 2 deletions

View file

@ -0,0 +1,7 @@
Add 'on_user_search' callback to the third_party_rules section of the [module API](https://matrix-org.github.io/synapse/latest/modules/writing_a_module.html).
This allows to write modules that can trigger on search requests on the user directory and alter the results.
Possible use cases:
- Filter or group user search results
- Exclude MxIDs from the user search results
- Include MxIDs in the user search results even if they do not match the search criteria (e.g. a helper Bot)
Contributed by @awesome-michael.

View file

@ -315,6 +315,28 @@ identifier from an identity server (via a call to [`POST
If multiple modules implement this callback, Synapse runs them all in order.
### `on_user_search`
_First introduced in Synapse v1.xxx.0_
```python
async def on_user_search(results: SearchResult) -> None:
```
Called after a search in the user directory has been performed. The module is given
the search results in the SearchResult data format.
Modules can modify the `results` (e.g. by adding the address of a chatbot by default, filtering
the results for some criteria or grouping the results in a special manner. Be aware that altering
the results structure can affect the compatibility with the matrix specification and matrix clients),
or completely deny the user search by raising a `module_api.errors.SynapseError`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns without raising an exception, Synapse falls through to the next one. The
user search will be forbidden as soon as one of the callbacks raises an exception. If
this happens, Synapse will not call any of the subsequent implementations of this
callback.
## Example
The example below is a module that implements the third-party rules callback

View file

@ -118,6 +118,7 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
ON_USER_SEARCH_CALLBACK,
)
from synapse.push.httppusher import HttpPusher
from synapse.rest.client.login import LoginResponse
@ -385,6 +386,7 @@ class ModuleApi:
on_remove_user_third_party_identifier: Optional[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
on_user_search: Optional[ON_USER_SEARCH_CALLBACK] = None,
) -> None:
"""Registers callbacks for third party event rules capabilities.
@ -403,6 +405,7 @@ class ModuleApi:
on_threepid_bind=on_threepid_bind,
on_add_user_third_party_identifier=on_add_user_third_party_identifier,
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier,
on_user_search=on_user_search,
)
def register_presence_router_callbacks(

View file

@ -26,6 +26,7 @@ from twisted.internet.defer import CancelledError
from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContextBase
from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@ -54,6 +55,7 @@ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Await
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
ON_USER_SEARCH_CALLBACK = Callable[[Requester, SearchResult], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@ -75,6 +77,7 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
"on_create_room",
"check_threepid_can_be_invited",
"check_visibility_can_be_modified",
"on_user_search",
}
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
@ -185,6 +188,7 @@ class ThirdPartyEventRulesModuleApiCallbacks:
self._on_remove_user_third_party_identifier_callbacks: List[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = []
self._on_user_search_callbacks: List[ON_USER_SEARCH_CALLBACK] = []
def register_third_party_rules_callbacks(
self,
@ -210,6 +214,7 @@ class ThirdPartyEventRulesModuleApiCallbacks:
on_remove_user_third_party_identifier: Optional[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
on_user_search: Optional[ON_USER_SEARCH_CALLBACK] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
@ -257,6 +262,9 @@ class ThirdPartyEventRulesModuleApiCallbacks:
on_remove_user_third_party_identifier
)
if on_user_search is not None:
self._on_user_search_callbacks.append(on_user_search)
async def check_event_allowed(
self,
event: EventBase,
@ -597,3 +605,23 @@ class ThirdPartyEventRulesModuleApiCallbacks:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)
async def on_user_search(self, requester: Requester, results: SearchResult) -> None:
"""Intercept requests to search the user directory to maybe deny it (via an exception) or
update the result list, e.g. with a filtered version.
Args:
requester
results: The results of the search in the user directory.
"""
for callback in self._on_user_search_callbacks:
try:
await callback(requester, results)
except Exception as e:
# Don't silence the errors raised by this callback since we expect it to
# raise an exception to deny the search of users; instead make sure
# it's a SynapseError we can send to clients.
if not isinstance(e, SynapseError):
e = SynapseError(403, "User search is forbidden")
raise e

View file

@ -45,6 +45,9 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
self._third_party_event_rules = (
hs.get_module_api_callbacks().third_party_event_rules
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory
@ -83,6 +86,10 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id, search_term, limit
)
# Let the third party rules modify the result list if needed, or abort
# the search entirely with an exception.
await self._third_party_event_rules.on_user_search(requester, results)
return 200, results

View file

@ -33,13 +33,15 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
from synapse.rest.client import account, login, profile, room, user_directory
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap
from synapse.storage.databases.main.user_directory import SearchResult
from synapse.types import JsonDict, Requester, StateMap, UserProfile
from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
from tests.unittest import override_config
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
@ -100,6 +102,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
room.register_servlets,
profile.register_servlets,
account.register_servlets,
user_directory.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@ -1074,3 +1077,97 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
on_remove_user_third_party_identifier_callback_mock.assert_called_once()
args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
self.assertEqual(args, (user_id, "email", "foo@example.com"))
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_on_user_search(self) -> None:
"""Tests that the on_user_search module callback is correctly called on
searches in the user directory.
"""
# Register a mock callback.
m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
m
)
# make a search request
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": "foo"},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the callback has been called once.
m.assert_called_once()
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_on_user_search_modify_results(self) -> None:
"""Tests that the on_user_search module callback is correctly returning
the modified results list.
"""
result_list = [
UserProfile(
user_id="@foo:bar.com",
display_name="Foo",
avatar_url="mxc://bar.com/foo",
)
]
# patch the search callback so that it will modify the search result
async def search(
requester: Requester,
results: SearchResult,
) -> None:
results["results"] = result_list
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
search
)
# now send a search request
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"limit": 10, "search_term": "foo"},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# ... and check that it got modified
results = channel.json_body["results"]
self.assertEqual(results, result_list)
self.assertIn("results", channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_on_user_search_deny_request(self) -> None:
"""Tests that the on_user_search module returns the SynapseError to the API
when it is raised in the module.
"""
# patch the search so that it will raise an exception
async def search(requester: Requester, results: SearchResult) -> None:
raise SynapseError(401, "Unauthorized user search", "M_UNAUTHORIZED")
self.hs.get_module_api_callbacks().third_party_event_rules._on_user_search_callbacks.append(
search
)
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": "something"},
access_token=self.tok,
)
# Check the error code
self.assertEqual(channel.code, 401, channel.json_body)
# Check the JSON body has the correct error message
self.assertEqual(
channel.json_body,
{"errcode": "M_UNAUTHORIZED", "error": "Unauthorized user search"},
)