mirror of
https://github.com/element-hq/synapse.git
synced 2025-03-14 09:45:51 +00:00
Implement MSC4133 to support custom profile fields. (#17488)
Implementation of [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) to support custom profile fields. It is behind an experimental flag and includes tests. ### Pull Request Checklist <!-- Please read https://element-hq.github.io/synapse/latest/development/contributing_guide.html before submitting your pull request --> * [x] Pull request is based on the develop branch * [x] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should: - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.". - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. * [x] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) --------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
parent
0a31cf18cd
commit
ca290d325c
13 changed files with 1039 additions and 26 deletions
1
changelog.d/17488.feature
Normal file
1
changelog.d/17488.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) for custom profile fields.
|
|
@ -132,6 +132,10 @@ class Codes(str, Enum):
|
|||
# connection.
|
||||
UNKNOWN_POS = "M_UNKNOWN_POS"
|
||||
|
||||
# Part of MSC4133
|
||||
PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE"
|
||||
KEY_TOO_LARGE = "M_KEY_TOO_LARGE"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code, a message string attributes and optional headers.
|
||||
|
|
|
@ -436,6 +436,9 @@ class ExperimentalConfig(Config):
|
|||
("experimental", "msc4108_delegation_endpoint"),
|
||||
)
|
||||
|
||||
# MSC4133: Custom profile fields
|
||||
self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False)
|
||||
|
||||
# MSC4210: Remove legacy mentions
|
||||
self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
|
@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
MAX_DISPLAYNAME_LEN = 256
|
||||
MAX_AVATAR_URL_LEN = 1000
|
||||
# Field name length is specced at 255 bytes.
|
||||
MAX_CUSTOM_FIELD_LEN = 255
|
||||
|
||||
|
||||
class ProfileHandler:
|
||||
|
@ -90,7 +92,15 @@ class ProfileHandler:
|
|||
|
||||
if self.hs.is_mine(target_user):
|
||||
profileinfo = await self.store.get_profileinfo(target_user)
|
||||
if profileinfo.display_name is None and profileinfo.avatar_url is None:
|
||||
extra_fields = {}
|
||||
if self.hs.config.experimental.msc4133_enabled:
|
||||
extra_fields = await self.store.get_profile_fields(target_user)
|
||||
|
||||
if (
|
||||
profileinfo.display_name is None
|
||||
and profileinfo.avatar_url is None
|
||||
and not extra_fields
|
||||
):
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
# Do not include display name or avatar if unset.
|
||||
|
@ -99,6 +109,9 @@ class ProfileHandler:
|
|||
ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name
|
||||
if profileinfo.avatar_url is not None:
|
||||
ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url
|
||||
if extra_fields:
|
||||
ret.update(extra_fields)
|
||||
|
||||
return ret
|
||||
else:
|
||||
try:
|
||||
|
@ -403,6 +416,110 @@ class ProfileHandler:
|
|||
|
||||
return True
|
||||
|
||||
async def get_profile_field(
|
||||
self, target_user: UserID, field_name: str
|
||||
) -> JsonValue:
|
||||
"""
|
||||
Fetch a user's profile from the database for local users and over federation
|
||||
for remote users.
|
||||
|
||||
Args:
|
||||
target_user: The user ID to fetch the profile for.
|
||||
field_name: The field to fetch the profile for.
|
||||
|
||||
Returns:
|
||||
The value for the profile field or None if the field does not exist.
|
||||
"""
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
field_value = await self.store.get_profile_field(
|
||||
target_user, field_name
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
raise
|
||||
|
||||
return field_value
|
||||
else:
|
||||
try:
|
||||
result = await self.federation.make_query(
|
||||
destination=target_user.domain,
|
||||
query_type="profile",
|
||||
args={"user_id": target_user.to_string(), "field": field_name},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except RequestSendFailed as e:
|
||||
raise SynapseError(502, "Failed to fetch profile") from e
|
||||
except HttpResponseException as e:
|
||||
raise e.to_synapse_error()
|
||||
|
||||
return result.get(field_name)
|
||||
|
||||
async def set_profile_field(
|
||||
self,
|
||||
target_user: UserID,
|
||||
requester: Requester,
|
||||
field_name: str,
|
||||
new_value: JsonValue,
|
||||
by_admin: bool = False,
|
||||
deactivation: bool = False,
|
||||
) -> None:
|
||||
"""Set a new profile field for a user.
|
||||
|
||||
Args:
|
||||
target_user: the user whose profile is to be changed.
|
||||
requester: The user attempting to make this change.
|
||||
field_name: The name of the profile field to update.
|
||||
new_value: The new field value for this user.
|
||||
by_admin: Whether this change was made by an administrator.
|
||||
deactivation: Whether this change was made while deactivating the user.
|
||||
"""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(403, "Cannot set another user's profile")
|
||||
|
||||
await self.store.set_profile_field(target_user, field_name, new_value)
|
||||
|
||||
# Custom fields do not propagate into the user directory *or* rooms.
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self._third_party_rules.on_profile_update(
|
||||
target_user.to_string(), profile, by_admin, deactivation
|
||||
)
|
||||
|
||||
async def delete_profile_field(
|
||||
self,
|
||||
target_user: UserID,
|
||||
requester: Requester,
|
||||
field_name: str,
|
||||
by_admin: bool = False,
|
||||
deactivation: bool = False,
|
||||
) -> None:
|
||||
"""Delete a field from a user's profile.
|
||||
|
||||
Args:
|
||||
target_user: the user whose profile is to be changed.
|
||||
requester: The user attempting to make this change.
|
||||
field_name: The name of the profile field to remove.
|
||||
by_admin: Whether this change was made by an administrator.
|
||||
deactivation: Whether this change was made while deactivating the user.
|
||||
"""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's profile")
|
||||
|
||||
await self.store.delete_profile_field(target_user, field_name)
|
||||
|
||||
# Custom fields do not propagate into the user directory *or* rooms.
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self._third_party_rules.on_profile_update(
|
||||
target_user.to_string(), profile, by_admin, deactivation
|
||||
)
|
||||
|
||||
async def on_profile_query(self, args: JsonDict) -> JsonDict:
|
||||
"""Handles federation profile query requests."""
|
||||
|
||||
|
@ -419,13 +536,24 @@ class ProfileHandler:
|
|||
|
||||
just_field = args.get("field", None)
|
||||
|
||||
response = {}
|
||||
response: JsonDict = {}
|
||||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
if just_field is None or just_field == ProfileFields.DISPLAYNAME:
|
||||
response["displayname"] = await self.store.get_profile_displayname(user)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
if just_field is None or just_field == ProfileFields.AVATAR_URL:
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
|
||||
|
||||
if self.hs.config.experimental.msc4133_enabled:
|
||||
if just_field is None:
|
||||
response.update(await self.store.get_profile_fields(user))
|
||||
elif just_field not in (
|
||||
ProfileFields.DISPLAYNAME,
|
||||
ProfileFields.AVATAR_URL,
|
||||
):
|
||||
response[just_field] = await self.store.get_profile_field(
|
||||
user, just_field
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
|
|
@ -92,6 +92,23 @@ class CapabilitiesRestServlet(RestServlet):
|
|||
"enabled": self.config.experimental.msc3664_enabled,
|
||||
}
|
||||
|
||||
if self.config.experimental.msc4133_enabled:
|
||||
response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = {
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# Ensure this is consistent with the legacy m.set_displayname and
|
||||
# m.set_avatar_url.
|
||||
disallowed = []
|
||||
if not self.config.registration.enable_set_displayname:
|
||||
disallowed.append("displayname")
|
||||
if not self.config.registration.enable_set_avatar_url:
|
||||
disallowed.append("avatar_url")
|
||||
if disallowed:
|
||||
response["capabilities"]["uk.tcpip.msc4133.profile_fields"][
|
||||
"disallowed"
|
||||
] = disallowed
|
||||
|
||||
return HTTPStatus.OK, response
|
||||
|
||||
|
||||
|
|
|
@ -21,10 +21,13 @@
|
|||
|
||||
"""This module contains REST servlets to do with profile: /profile/<paths>"""
|
||||
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ProfileFields
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -33,7 +36,8 @@ from synapse.http.servlet import (
|
|||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, JsonValue, UserID
|
||||
from synapse.util.stringutils import is_namedspaced_grammar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -91,6 +95,11 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
@ -101,9 +110,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
new_name = content["displayname"]
|
||||
except Exception:
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg="Unable to parse name",
|
||||
errcode=Codes.BAD_JSON,
|
||||
400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
|
@ -166,6 +173,11 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
@ -232,7 +244,180 @@ class ProfileRestServlet(RestServlet):
|
|||
return 200, ret
|
||||
|
||||
|
||||
class UnstableProfileFieldRestServlet(RestServlet):
|
||||
PATTERNS = [
|
||||
re.compile(
|
||||
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
|
||||
)
|
||||
]
|
||||
CATEGORY = "Event sending requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester_user = None
|
||||
|
||||
if self.hs.config.server.require_auth_for_profile_requests:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user = requester.user
|
||||
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
field_value: JsonValue = await self.profile_handler.get_displayname(user)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
field_value = await self.profile_handler.get_avatar_url(user)
|
||||
else:
|
||||
field_value = await self.profile_handler.get_profile_field(user, field_name)
|
||||
|
||||
return 200, {field_name: field_value}
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
try:
|
||||
new_value = content[field_name]
|
||||
except KeyError:
|
||||
raise SynapseError(
|
||||
400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
|
||||
requester_suspended = (
|
||||
await self.hs.get_datastores().main.get_user_suspended_status(
|
||||
requester.user.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
if requester_suspended:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Updating profile while account is suspended is not allowed.",
|
||||
Codes.USER_ACCOUNT_SUSPENDED,
|
||||
)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
await self.profile_handler.set_displayname(
|
||||
user, requester, new_value, is_admin, propagate=propagate
|
||||
)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user, requester, new_value, is_admin, propagate=propagate
|
||||
)
|
||||
else:
|
||||
await self.profile_handler.set_profile_field(
|
||||
user, requester, field_name, new_value, is_admin
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, user_id: str, field_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if not UserID.is_valid(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester)
|
||||
|
||||
if not field_name:
|
||||
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
|
||||
|
||||
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
|
||||
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
|
||||
if not is_namedspaced_grammar(field_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Field name does not follow Common Namespaced Identifier Grammar",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
propagate = _read_propagate(self.hs, request)
|
||||
|
||||
requester_suspended = (
|
||||
await self.hs.get_datastores().main.get_user_suspended_status(
|
||||
requester.user.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
if requester_suspended:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Updating profile while account is suspended is not allowed.",
|
||||
Codes.USER_ACCOUNT_SUSPENDED,
|
||||
)
|
||||
|
||||
if field_name == ProfileFields.DISPLAYNAME:
|
||||
await self.profile_handler.set_displayname(
|
||||
user, requester, "", is_admin, propagate=propagate
|
||||
)
|
||||
elif field_name == ProfileFields.AVATAR_URL:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user, requester, "", is_admin, propagate=propagate
|
||||
)
|
||||
else:
|
||||
await self.profile_handler.delete_profile_field(
|
||||
user, requester, field_name, is_admin
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
# The specific displayname / avatar URL / custom field endpoints *must* appear
|
||||
# before their corresponding generic profile endpoint.
|
||||
ProfileDisplaynameRestServlet(hs).register(http_server)
|
||||
ProfileAvatarURLRestServlet(hs).register(http_server)
|
||||
ProfileRestServlet(hs).register(http_server)
|
||||
if hs.config.experimental.msc4133_enabled:
|
||||
UnstableProfileFieldRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -172,6 +172,8 @@ class VersionsRestServlet(RestServlet):
|
|||
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
|
||||
# Simplified sliding sync
|
||||
"org.matrix.simplified_msc3575": msc3575_enabled,
|
||||
# Arbitrary key-value profile fields.
|
||||
"uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -18,8 +18,13 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.constants import ProfileFields
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
|
@ -27,13 +32,17 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict, JsonValue, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
# The number of bytes that the serialized profile can have.
|
||||
MAX_PROFILE_SIZE = 65536
|
||||
|
||||
|
||||
class ProfileWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue:
|
||||
"""
|
||||
Get a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The custom profile field name.
|
||||
|
||||
Returns:
|
||||
The string value if the field exists, otherwise raises 404.
|
||||
"""
|
||||
|
||||
def get_profile_field(txn: LoggingTransaction) -> JsonValue:
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
field_path = f'$."{field_name}"'
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?)
|
||||
FROM profiles
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_path, field_name, user_id.localpart),
|
||||
)
|
||||
|
||||
# Test exists first since value being None is used for both
|
||||
# missing and a null JSON value.
|
||||
exists, value = cast(Tuple[bool, JsonValue], txn.fetchone())
|
||||
if not exists:
|
||||
raise StoreError(404, "No row found")
|
||||
return value
|
||||
|
||||
else:
|
||||
sql = """
|
||||
SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?)
|
||||
FROM profiles
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_path, field_path, user_id.localpart),
|
||||
)
|
||||
|
||||
# If value_type is None, then the value did not exist.
|
||||
value_type, value = cast(
|
||||
Tuple[Optional[str], JsonValue], txn.fetchone()
|
||||
)
|
||||
if not value_type:
|
||||
raise StoreError(404, "No row found")
|
||||
# If value_type is object or array, then need to deserialize the JSON.
|
||||
# Scalar values are properly returned directly.
|
||||
if value_type in ("object", "array"):
|
||||
assert isinstance(value, str)
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
return await self.db_pool.runInteraction("get_profile_field", get_profile_field)
|
||||
|
||||
async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]:
|
||||
"""
|
||||
Get all custom profile fields for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
|
||||
Returns:
|
||||
A dictionary of custom profile fields.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="fields",
|
||||
desc="get_profile_fields",
|
||||
)
|
||||
# The SQLite driver doesn't automatically convert JSON to
|
||||
# Python objects
|
||||
if isinstance(self.database_engine, Sqlite3Engine) and result:
|
||||
result = json.loads(result)
|
||||
return result or {}
|
||||
|
||||
async def create_profile(self, user_id: UserID) -> None:
|
||||
"""
|
||||
Create a blank profile for a user.
|
||||
|
@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
desc="create_profile",
|
||||
)
|
||||
|
||||
def _check_profile_size(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: UserID,
|
||||
new_field_name: str,
|
||||
new_value: JsonValue,
|
||||
) -> None:
|
||||
# For each entry there are 4 quotes (2 each for key and value), 1 colon,
|
||||
# and 1 comma.
|
||||
PER_VALUE_EXTRA = 6
|
||||
|
||||
# Add the size of the current custom profile fields, ignoring the entry
|
||||
# which will be overwritten.
|
||||
if isinstance(txn.database_engine, PostgresEngine):
|
||||
size_sql = """
|
||||
SELECT
|
||||
OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url)
|
||||
FROM profiles
|
||||
WHERE
|
||||
user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
size_sql,
|
||||
(new_field_name, user_id.localpart),
|
||||
)
|
||||
else:
|
||||
size_sql = """
|
||||
SELECT
|
||||
LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url)
|
||||
FROM profiles
|
||||
WHERE
|
||||
user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
size_sql,
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
(f'$."{new_field_name}"', user_id.localpart),
|
||||
)
|
||||
row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone())
|
||||
|
||||
# The values return null if the column is null.
|
||||
total_bytes = (
|
||||
# Discount the opening and closing braces to avoid double counting,
|
||||
# but add one for a comma.
|
||||
# -2 + 1 = -1
|
||||
(row[0] - 1 if row[0] else 0)
|
||||
+ (
|
||||
row[1] + len("displayname") + PER_VALUE_EXTRA
|
||||
if new_field_name != ProfileFields.DISPLAYNAME and row[1]
|
||||
else 0
|
||||
)
|
||||
+ (
|
||||
row[2] + len("avatar_url") + PER_VALUE_EXTRA
|
||||
if new_field_name != ProfileFields.AVATAR_URL and row[2]
|
||||
else 0
|
||||
)
|
||||
)
|
||||
|
||||
# Add the length of the field being added + the braces.
|
||||
total_bytes += len(encode_canonical_json({new_field_name: new_value}))
|
||||
|
||||
if total_bytes > MAX_PROFILE_SIZE:
|
||||
raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
async def set_profile_displayname(
|
||||
self, user_id: UserID, new_displayname: Optional[str]
|
||||
) -> None:
|
||||
|
@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
name is removed.
|
||||
"""
|
||||
user_localpart = user_id.localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"displayname": new_displayname,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
desc="set_profile_displayname",
|
||||
|
||||
def set_profile_displayname(txn: LoggingTransaction) -> None:
|
||||
if new_displayname is not None:
|
||||
self._check_profile_size(
|
||||
txn, user_id, ProfileFields.DISPLAYNAME, new_displayname
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"displayname": new_displayname,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_profile_displayname", set_profile_displayname
|
||||
)
|
||||
|
||||
async def set_profile_avatar_url(
|
||||
|
@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
removed.
|
||||
"""
|
||||
user_localpart = user_id.localpart
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
|
||||
desc="set_profile_avatar_url",
|
||||
|
||||
def set_profile_avatar_url(txn: LoggingTransaction) -> None:
|
||||
if new_avatar_url is not None:
|
||||
self._check_profile_size(
|
||||
txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"avatar_url": new_avatar_url,
|
||||
"full_user_id": user_id.to_string(),
|
||||
},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_profile_avatar_url", set_profile_avatar_url
|
||||
)
|
||||
|
||||
async def set_profile_field(
|
||||
self, user_id: UserID, field_name: str, new_value: JsonValue
|
||||
) -> None:
|
||||
"""
|
||||
Set a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The name of the custom profile field.
|
||||
new_value: The value of the custom profile field.
|
||||
"""
|
||||
|
||||
# Encode to canonical JSON.
|
||||
canonical_value = encode_canonical_json(new_value)
|
||||
|
||||
def set_profile_field(txn: LoggingTransaction) -> None:
|
||||
self._check_profile_size(txn, user_id, field_name, new_value)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import Json
|
||||
|
||||
# Note that the || jsonb operator is not recursive, any duplicate
|
||||
# keys will be taken from the second value.
|
||||
sql = """
|
||||
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb))
|
||||
ON CONFLICT (user_id)
|
||||
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id.localpart,
|
||||
user_id.to_string(),
|
||||
field_name,
|
||||
# Pass as a JSON object since we have passing bytes disabled
|
||||
# at the database driver.
|
||||
Json(json.loads(canonical_value)),
|
||||
),
|
||||
)
|
||||
else:
|
||||
# You may be tempted to use json_patch instead of providing the parameters
|
||||
# twice, but that recursively merges objects instead of replacing.
|
||||
sql = """
|
||||
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?)))
|
||||
ON CONFLICT (user_id)
|
||||
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?))
|
||||
"""
|
||||
# This will error if field_name has double quotes in it, but that's not
|
||||
# possible due to the grammar.
|
||||
json_field_name = f'$."{field_name}"'
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id.localpart,
|
||||
user_id.to_string(),
|
||||
json_field_name,
|
||||
canonical_value,
|
||||
json_field_name,
|
||||
canonical_value,
|
||||
),
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("set_profile_field", set_profile_field)
|
||||
|
||||
async def delete_profile_field(self, user_id: UserID, field_name: str) -> None:
|
||||
"""
|
||||
Remove a custom profile field for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
field_name: The name of the custom profile field.
|
||||
"""
|
||||
|
||||
def delete_profile_field(txn: LoggingTransaction) -> None:
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
UPDATE profiles SET fields = fields - ?
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(field_name, user_id.localpart),
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
UPDATE profiles SET fields = json_remove(fields, ?)
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
# This will error if field_name has double quotes in it.
|
||||
(f'$."{field_name}"', user_id.localpart),
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_profile_field", delete_profile_field)
|
||||
|
||||
|
||||
class ProfileStore(ProfileWorkerStore):
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 Patrick Cloke
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Custom profile fields.
|
||||
ALTER TABLE profiles ADD COLUMN fields JSONB;
|
|
@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
|
|||
#
|
||||
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||
|
||||
# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
|
||||
#
|
||||
# At least one character, less than or equal to 255 characters. Must start with
|
||||
# a-z, the rest is a-z, 0-9, -, _, or ..
|
||||
#
|
||||
# This doesn't check anything about validity of namespaces.
|
||||
NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
|
||||
|
||||
|
||||
def random_string(length: int) -> str:
|
||||
"""Generate a cryptographically secure string of random letters.
|
||||
|
@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def is_namedspaced_grammar(s: str) -> bool:
|
||||
return bool(NAMESPACED_GRAMMAR.match(s))
|
||||
|
||||
|
||||
def assert_valid_client_secret(client_secret: str) -> None:
|
||||
"""Validate that a given string matches the client_secret defined by the spec"""
|
||||
if (
|
||||
|
|
|
@ -142,6 +142,50 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"enable_set_displayname": False,
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
|
||||
self,
|
||||
) -> None:
|
||||
"""Test if set displayname is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
|
||||
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
|
||||
self.assertEqual(
|
||||
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
|
||||
["displayname"],
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"enable_set_avatar_url": False,
|
||||
"experimental_features": {"msc4133_enabled": True},
|
||||
}
|
||||
)
|
||||
def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
|
||||
"""Test if set avatar_url is disabled that the server responds it."""
|
||||
access_token = self.login(self.localpart, self.password)
|
||||
|
||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||
capabilities = channel.json_body["capabilities"]
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
||||
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
|
||||
self.assertEqual(
|
||||
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
|
||||
["avatar_url"],
|
||||
)
|
||||
|
||||
@override_config({"enable_3pid_changes": False})
|
||||
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
|
||||
"""Test if change 3pid is disabled that the server responds it."""
|
||||
|
|
|
@ -25,16 +25,20 @@ import urllib.parse
|
|||
from http import HTTPStatus
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, profile, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
@ -480,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
# The client requested ?propagate=true, so it should have happened.
|
||||
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_missing_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_missing_custom_field_invalid_field_name(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_get_custom_field_rejects_bad_username(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"custom_field": "test"})
|
||||
|
||||
# Overwriting the field should work.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "new_Value"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
|
||||
|
||||
# Deleting the field should work.
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_non_string(self) -> None:
|
||||
"""Non-string fields are supported for custom fields."""
|
||||
fields = {
|
||||
"bool_field": True,
|
||||
"array_field": ["test"],
|
||||
"object_field": {"test": "test"},
|
||||
"numeric_field": 1,
|
||||
"null_field": None,
|
||||
}
|
||||
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: value},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
|
||||
|
||||
# Check getting individual fields works.
|
||||
for key, value in fields.items():
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.json_body, {key: value})
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_noauth(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_size(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field name that is too long should get a 400 error.
|
||||
"""
|
||||
# Key is missing.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
|
||||
content={"": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Single key is too large.
|
||||
key = "c" * 500
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
|
||||
|
||||
# Key doesn't match body.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
|
||||
content={"diff_key": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_profile_too_long(self) -> None:
|
||||
"""
|
||||
Attempts to set a custom field that would push the overall profile too large.
|
||||
"""
|
||||
# Get right to the boundary:
|
||||
# len("displayname") + len("owner") + 5 = 21 for the displayname
|
||||
# 1 + 65498 + 5 for key "a" = 65504
|
||||
# 2 braces, 1 comma
|
||||
# 3 + 21 + 65498 = 65522 < 65536.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "a" * 65498},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Get the entire profile.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{self.owner}",
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
canonical_json = encode_canonical_json(channel.json_body)
|
||||
# 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
|
||||
# Be one below that so we can prove we're at the boundary.
|
||||
self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
|
||||
|
||||
# Postgres stores JSONB with whitespace, while SQLite doesn't.
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
ADDITIONAL_CHARS = 0
|
||||
else:
|
||||
ADDITIONAL_CHARS = 1
|
||||
|
||||
# The next one should fail, note the value has a (JSON) length of 2.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "1" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
# Setting an avatar or (longer) display name should not work.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/displayname",
|
||||
content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://foo/bar"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
|
||||
|
||||
# Removing a single byte should work.
|
||||
key = "b"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: "" + "a" * ADDITIONAL_CHARS},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Finally, setting a field that already exists to a value that is <= in length should work.
|
||||
key = "a"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
|
||||
content={key: ""},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_displayname(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
|
||||
content={"displayname": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
displayname = self._get_displayname()
|
||||
self.assertEqual(displayname, "test")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_avatar_url(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
|
||||
content={"avatar_url": "mxc://test/good"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
avatar_url = self._get_avatar_url()
|
||||
self.assertEqual(avatar_url, "mxc://test/good")
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
|
||||
def test_set_custom_field_other(self) -> None:
|
||||
"""Setting someone else's profile field should fail"""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
|
||||
content={"custom_field": "test"},
|
||||
access_token=self.owner_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""Stores metadata about files in the database.
|
||||
|
||||
|
|
|
@ -20,7 +20,11 @@
|
|||
#
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.stringutils import assert_valid_client_secret, base62_encode
|
||||
from synapse.util.stringutils import (
|
||||
assert_valid_client_secret,
|
||||
base62_encode,
|
||||
is_namedspaced_grammar,
|
||||
)
|
||||
|
||||
from .. import unittest
|
||||
|
||||
|
@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
|
|||
self.assertEqual("10", base62_encode(62))
|
||||
self.assertEqual("1c", base62_encode(100))
|
||||
self.assertEqual("001c", base62_encode(100, minwidth=4))
|
||||
|
||||
def test_namespaced_identifier(self) -> None:
|
||||
self.assertTrue(is_namedspaced_grammar("test"))
|
||||
self.assertTrue(is_namedspaced_grammar("m.test"))
|
||||
self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
|
||||
self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
|
||||
self.assertTrue(is_namedspaced_grammar("test"))
|
||||
self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
|
||||
|
||||
# Must start with letter.
|
||||
self.assertFalse(is_namedspaced_grammar("1test"))
|
||||
self.assertFalse(is_namedspaced_grammar("-test"))
|
||||
self.assertFalse(is_namedspaced_grammar("_test"))
|
||||
self.assertFalse(is_namedspaced_grammar(".test"))
|
||||
|
||||
# Must contain only a-z, 0-9, -, _, ..
|
||||
self.assertFalse(is_namedspaced_grammar("test/"))
|
||||
self.assertFalse(is_namedspaced_grammar('test"'))
|
||||
self.assertFalse(is_namedspaced_grammar("testö"))
|
||||
|
||||
# Must be < 255 characters.
|
||||
self.assertFalse(is_namedspaced_grammar("t" * 256))
|
||||
|
|
Loading…
Add table
Reference in a new issue