Implement MSC4133 to support custom profile fields. (#17488)

Implementation of
[MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133)
to support custom profile fields. It is behind an experimental flag and
includes tests.


### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [x] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct
(run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))

---------

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
Patrick Cloke 2025-01-21 06:11:04 -05:00 committed by GitHub
parent 0a31cf18cd
commit ca290d325c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1039 additions and 26 deletions

View file

@ -0,0 +1 @@
Implement [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) for custom profile fields.

View file

@ -132,6 +132,10 @@ class Codes(str, Enum):
# connection.
UNKNOWN_POS = "M_UNKNOWN_POS"
# Part of MSC4133
PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE"
KEY_TOO_LARGE = "M_KEY_TOO_LARGE"
class CodeMessageException(RuntimeError):
"""An exception with integer code, a message string attributes and optional headers.

View file

@ -436,6 +436,9 @@ class ExperimentalConfig(Config):
("experimental", "msc4108_delegation_endpoint"),
)
# MSC4133: Custom profile fields
self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False)
# MSC4210: Remove legacy mentions
self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False)

View file

@ -32,7 +32,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
# Field name length is specced at 255 bytes.
MAX_CUSTOM_FIELD_LEN = 255
class ProfileHandler:
@ -90,7 +92,15 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user)
if profileinfo.display_name is None and profileinfo.avatar_url is None:
extra_fields = {}
if self.hs.config.experimental.msc4133_enabled:
extra_fields = await self.store.get_profile_fields(target_user)
if (
profileinfo.display_name is None
and profileinfo.avatar_url is None
and not extra_fields
):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
# Do not include display name or avatar if unset.
@ -99,6 +109,9 @@ class ProfileHandler:
ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name
if profileinfo.avatar_url is not None:
ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url
if extra_fields:
ret.update(extra_fields)
return ret
else:
try:
@ -403,6 +416,110 @@ class ProfileHandler:
return True
async def get_profile_field(
self, target_user: UserID, field_name: str
) -> JsonValue:
"""
Fetch a user's profile from the database for local users and over federation
for remote users.
Args:
target_user: The user ID to fetch the profile for.
field_name: The field to fetch the profile for.
Returns:
The value for the profile field or None if the field does not exist.
"""
if self.hs.is_mine(target_user):
try:
field_value = await self.store.get_profile_field(
target_user, field_name
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
return field_value
else:
try:
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": field_name},
ignore_backoff=True,
)
except RequestSendFailed as e:
raise SynapseError(502, "Failed to fetch profile") from e
except HttpResponseException as e:
raise e.to_synapse_error()
return result.get(field_name)
async def set_profile_field(
self,
target_user: UserID,
requester: Requester,
field_name: str,
new_value: JsonValue,
by_admin: bool = False,
deactivation: bool = False,
) -> None:
"""Set a new profile field for a user.
Args:
target_user: the user whose profile is to be changed.
requester: The user attempting to make this change.
field_name: The name of the profile field to update.
new_value: The new field value for this user.
by_admin: Whether this change was made by an administrator.
deactivation: Whether this change was made while deactivating the user.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(403, "Cannot set another user's profile")
await self.store.set_profile_field(target_user, field_name, new_value)
# Custom fields do not propagate into the user directory *or* rooms.
profile = await self.store.get_profileinfo(target_user)
await self._third_party_rules.on_profile_update(
target_user.to_string(), profile, by_admin, deactivation
)
async def delete_profile_field(
self,
target_user: UserID,
requester: Requester,
field_name: str,
by_admin: bool = False,
deactivation: bool = False,
) -> None:
"""Delete a field from a user's profile.
Args:
target_user: the user whose profile is to be changed.
requester: The user attempting to make this change.
field_name: The name of the profile field to remove.
by_admin: Whether this change was made by an administrator.
deactivation: Whether this change was made while deactivating the user.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's profile")
await self.store.delete_profile_field(target_user, field_name)
# Custom fields do not propagate into the user directory *or* rooms.
profile = await self.store.get_profileinfo(target_user)
await self._third_party_rules.on_profile_update(
target_user.to_string(), profile, by_admin, deactivation
)
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
@ -419,13 +536,24 @@ class ProfileHandler:
just_field = args.get("field", None)
response = {}
response: JsonDict = {}
try:
if just_field is None or just_field == "displayname":
if just_field is None or just_field == ProfileFields.DISPLAYNAME:
response["displayname"] = await self.store.get_profile_displayname(user)
if just_field is None or just_field == "avatar_url":
if just_field is None or just_field == ProfileFields.AVATAR_URL:
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
if self.hs.config.experimental.msc4133_enabled:
if just_field is None:
response.update(await self.store.get_profile_fields(user))
elif just_field not in (
ProfileFields.DISPLAYNAME,
ProfileFields.AVATAR_URL,
):
response[just_field] = await self.store.get_profile_field(
user, just_field
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)

View file

@ -92,6 +92,23 @@ class CapabilitiesRestServlet(RestServlet):
"enabled": self.config.experimental.msc3664_enabled,
}
if self.config.experimental.msc4133_enabled:
response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = {
"enabled": True,
}
# Ensure this is consistent with the legacy m.set_displayname and
# m.set_avatar_url.
disallowed = []
if not self.config.registration.enable_set_displayname:
disallowed.append("displayname")
if not self.config.registration.enable_set_avatar_url:
disallowed.append("avatar_url")
if disallowed:
response["capabilities"]["uk.tcpip.msc4133.profile_fields"][
"disallowed"
] = disallowed
return HTTPStatus.OK, response

View file

@ -21,10 +21,13 @@
"""This module contains REST servlets to do with profile: /profile/<paths>"""
import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ProfileFields
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -33,7 +36,8 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonValue, UserID
from synapse.util.stringutils import is_namedspaced_grammar
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -91,6 +95,11 @@ class ProfileDisplaynameRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
@ -101,9 +110,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
new_name = content["displayname"]
except Exception:
raise SynapseError(
code=400,
msg="Unable to parse name",
errcode=Codes.BAD_JSON,
400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM
)
propagate = _read_propagate(self.hs, request)
@ -166,6 +173,11 @@ class ProfileAvatarURLRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
@ -232,7 +244,180 @@ class ProfileRestServlet(RestServlet):
return 200, ret
class UnstableProfileFieldRestServlet(RestServlet):
PATTERNS = [
re.compile(
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
)
]
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
async def on_GET(
self, request: SynapseRequest, user_id: str, field_name: str
) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
if not field_name:
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
if not is_namedspaced_grammar(field_name):
raise SynapseError(
400,
"Field name does not follow Common Namespaced Identifier Grammar",
errcode=Codes.INVALID_PARAM,
)
user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)
if field_name == ProfileFields.DISPLAYNAME:
field_value: JsonValue = await self.profile_handler.get_displayname(user)
elif field_name == ProfileFields.AVATAR_URL:
field_value = await self.profile_handler.get_avatar_url(user)
else:
field_value = await self.profile_handler.get_profile_field(user, field_name)
return 200, {field_name: field_value}
async def on_PUT(
self, request: SynapseRequest, user_id: str, field_name: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
if not field_name:
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
if not is_namedspaced_grammar(field_name):
raise SynapseError(
400,
"Field name does not follow Common Namespaced Identifier Grammar",
errcode=Codes.INVALID_PARAM,
)
content = parse_json_object_from_request(request)
try:
new_value = content[field_name]
except KeyError:
raise SynapseError(
400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM
)
propagate = _read_propagate(self.hs, request)
requester_suspended = (
await self.hs.get_datastores().main.get_user_suspended_status(
requester.user.to_string()
)
)
if requester_suspended:
raise SynapseError(
403,
"Updating profile while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)
if field_name == ProfileFields.DISPLAYNAME:
await self.profile_handler.set_displayname(
user, requester, new_value, is_admin, propagate=propagate
)
elif field_name == ProfileFields.AVATAR_URL:
await self.profile_handler.set_avatar_url(
user, requester, new_value, is_admin, propagate=propagate
)
else:
await self.profile_handler.set_profile_field(
user, requester, field_name, new_value, is_admin
)
return 200, {}
async def on_DELETE(
self, request: SynapseRequest, user_id: str, field_name: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
if not field_name:
raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
if not is_namedspaced_grammar(field_name):
raise SynapseError(
400,
"Field name does not follow Common Namespaced Identifier Grammar",
errcode=Codes.INVALID_PARAM,
)
propagate = _read_propagate(self.hs, request)
requester_suspended = (
await self.hs.get_datastores().main.get_user_suspended_status(
requester.user.to_string()
)
)
if requester_suspended:
raise SynapseError(
403,
"Updating profile while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)
if field_name == ProfileFields.DISPLAYNAME:
await self.profile_handler.set_displayname(
user, requester, "", is_admin, propagate=propagate
)
elif field_name == ProfileFields.AVATAR_URL:
await self.profile_handler.set_avatar_url(
user, requester, "", is_admin, propagate=propagate
)
else:
await self.profile_handler.delete_profile_field(
user, requester, field_name, is_admin
)
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# The specific displayname / avatar URL / custom field endpoints *must* appear
# before their corresponding generic profile endpoint.
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
if hs.config.experimental.msc4133_enabled:
UnstableProfileFieldRestServlet(hs).register(http_server)

View file

@ -172,6 +172,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
# Simplified sliding sync
"org.matrix.simplified_msc3575": msc3575_enabled,
# Arbitrary key-value profile fields.
"uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
},
},
)

View file

@ -18,8 +18,13 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING, Optional
import json
from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
from canonicaljson import encode_canonical_json
from synapse.api.constants import ProfileFields
from synapse.api.errors import Codes, StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -27,13 +32,17 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, JsonValue, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# The number of bytes that the serialized profile can have.
MAX_PROFILE_SIZE = 65536
class ProfileWorkerStore(SQLBaseStore):
def __init__(
self,
@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue:
"""
Get a custom profile field for a user.
Args:
user_id: The user's ID.
field_name: The custom profile field name.
Returns:
The string value if the field exists, otherwise raises 404.
"""
def get_profile_field(txn: LoggingTransaction) -> JsonValue:
# This will error if field_name has double quotes in it, but that's not
# possible due to the grammar.
field_path = f'$."{field_name}"'
if isinstance(self.database_engine, PostgresEngine):
sql = """
SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?)
FROM profiles
WHERE user_id = ?
"""
txn.execute(
sql,
(field_path, field_name, user_id.localpart),
)
# Test exists first since value being None is used for both
# missing and a null JSON value.
exists, value = cast(Tuple[bool, JsonValue], txn.fetchone())
if not exists:
raise StoreError(404, "No row found")
return value
else:
sql = """
SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?)
FROM profiles
WHERE user_id = ?
"""
txn.execute(
sql,
(field_path, field_path, user_id.localpart),
)
# If value_type is None, then the value did not exist.
value_type, value = cast(
Tuple[Optional[str], JsonValue], txn.fetchone()
)
if not value_type:
raise StoreError(404, "No row found")
# If value_type is object or array, then need to deserialize the JSON.
# Scalar values are properly returned directly.
if value_type in ("object", "array"):
assert isinstance(value, str)
return json.loads(value)
return value
return await self.db_pool.runInteraction("get_profile_field", get_profile_field)
async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]:
"""
Get all custom profile fields for a user.
Args:
user_id: The user's ID.
Returns:
A dictionary of custom profile fields.
"""
result = await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcol="fields",
desc="get_profile_fields",
)
# The SQLite driver doesn't automatically convert JSON to
# Python objects
if isinstance(self.database_engine, Sqlite3Engine) and result:
result = json.loads(result)
return result or {}
async def create_profile(self, user_id: UserID) -> None:
"""
Create a blank profile for a user.
@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore):
desc="create_profile",
)
def _check_profile_size(
self,
txn: LoggingTransaction,
user_id: UserID,
new_field_name: str,
new_value: JsonValue,
) -> None:
# For each entry there are 4 quotes (2 each for key and value), 1 colon,
# and 1 comma.
PER_VALUE_EXTRA = 6
# Add the size of the current custom profile fields, ignoring the entry
# which will be overwritten.
if isinstance(txn.database_engine, PostgresEngine):
size_sql = """
SELECT
OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url)
FROM profiles
WHERE
user_id = ?
"""
txn.execute(
size_sql,
(new_field_name, user_id.localpart),
)
else:
size_sql = """
SELECT
LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url)
FROM profiles
WHERE
user_id = ?
"""
txn.execute(
size_sql,
# This will error if field_name has double quotes in it, but that's not
# possible due to the grammar.
(f'$."{new_field_name}"', user_id.localpart),
)
row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone())
# The values return null if the column is null.
total_bytes = (
# Discount the opening and closing braces to avoid double counting,
# but add one for a comma.
# -2 + 1 = -1
(row[0] - 1 if row[0] else 0)
+ (
row[1] + len("displayname") + PER_VALUE_EXTRA
if new_field_name != ProfileFields.DISPLAYNAME and row[1]
else 0
)
+ (
row[2] + len("avatar_url") + PER_VALUE_EXTRA
if new_field_name != ProfileFields.AVATAR_URL and row[2]
else 0
)
)
# Add the length of the field being added + the braces.
total_bytes += len(encode_canonical_json({new_field_name: new_value}))
if total_bytes > MAX_PROFILE_SIZE:
raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE)
async def set_profile_displayname(
self, user_id: UserID, new_displayname: Optional[str]
) -> None:
@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore):
name is removed.
"""
user_localpart = user_id.localpart
await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values={
"displayname": new_displayname,
"full_user_id": user_id.to_string(),
},
desc="set_profile_displayname",
def set_profile_displayname(txn: LoggingTransaction) -> None:
if new_displayname is not None:
self._check_profile_size(
txn, user_id, ProfileFields.DISPLAYNAME, new_displayname
)
self.db_pool.simple_upsert_txn(
txn,
table="profiles",
keyvalues={"user_id": user_localpart},
values={
"displayname": new_displayname,
"full_user_id": user_id.to_string(),
},
)
await self.db_pool.runInteraction(
"set_profile_displayname", set_profile_displayname
)
async def set_profile_avatar_url(
@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore):
removed.
"""
user_localpart = user_id.localpart
await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
desc="set_profile_avatar_url",
def set_profile_avatar_url(txn: LoggingTransaction) -> None:
if new_avatar_url is not None:
self._check_profile_size(
txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url
)
self.db_pool.simple_upsert_txn(
txn,
table="profiles",
keyvalues={"user_id": user_localpart},
values={
"avatar_url": new_avatar_url,
"full_user_id": user_id.to_string(),
},
)
await self.db_pool.runInteraction(
"set_profile_avatar_url", set_profile_avatar_url
)
async def set_profile_field(
self, user_id: UserID, field_name: str, new_value: JsonValue
) -> None:
"""
Set a custom profile field for a user.
Args:
user_id: The user's ID.
field_name: The name of the custom profile field.
new_value: The value of the custom profile field.
"""
# Encode to canonical JSON.
canonical_value = encode_canonical_json(new_value)
def set_profile_field(txn: LoggingTransaction) -> None:
self._check_profile_size(txn, user_id, field_name, new_value)
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import Json
# Note that the || jsonb operator is not recursive, any duplicate
# keys will be taken from the second value.
sql = """
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb))
ON CONFLICT (user_id)
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields
"""
txn.execute(
sql,
(
user_id.localpart,
user_id.to_string(),
field_name,
# Pass as a JSON object since we have passing bytes disabled
# at the database driver.
Json(json.loads(canonical_value)),
),
)
else:
# You may be tempted to use json_patch instead of providing the parameters
# twice, but that recursively merges objects instead of replacing.
sql = """
INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?)))
ON CONFLICT (user_id)
DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?))
"""
# This will error if field_name has double quotes in it, but that's not
# possible due to the grammar.
json_field_name = f'$."{field_name}"'
txn.execute(
sql,
(
user_id.localpart,
user_id.to_string(),
json_field_name,
canonical_value,
json_field_name,
canonical_value,
),
)
await self.db_pool.runInteraction("set_profile_field", set_profile_field)
async def delete_profile_field(self, user_id: UserID, field_name: str) -> None:
"""
Remove a custom profile field for a user.
Args:
user_id: The user's ID.
field_name: The name of the custom profile field.
"""
def delete_profile_field(txn: LoggingTransaction) -> None:
if isinstance(self.database_engine, PostgresEngine):
sql = """
UPDATE profiles SET fields = fields - ?
WHERE user_id = ?
"""
txn.execute(
sql,
(field_name, user_id.localpart),
)
else:
sql = """
UPDATE profiles SET fields = json_remove(fields, ?)
WHERE user_id = ?
"""
txn.execute(
sql,
# This will error if field_name has double quotes in it.
(f'$."{field_name}"', user_id.localpart),
)
await self.db_pool.runInteraction("delete_profile_field", delete_profile_field)
class ProfileStore(ProfileWorkerStore):
pass

View file

@ -0,0 +1,15 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 Patrick Cloke
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Custom profile fields.
ALTER TABLE profiles ADD COLUMN fields JSONB;

View file

@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
#
# At least one character, less than or equal to 255 characters. Must start with
# a-z, the rest is a-z, 0-9, -, _, or ..
#
# This doesn't check anything about validity of namespaces.
NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
def random_string(length: int) -> str:
"""Generate a cryptographically secure string of random letters.
@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
return True
def is_namedspaced_grammar(s: str) -> bool:
return bool(NAMESPACED_GRAMMAR.match(s))
def assert_valid_client_secret(client_secret: str) -> None:
"""Validate that a given string matches the client_secret defined by the spec"""
if (

View file

@ -142,6 +142,50 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
@override_config(
{
"enable_set_displayname": False,
"experimental_features": {"msc4133_enabled": True},
}
)
def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
self,
) -> None:
"""Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
self.assertEqual(
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
["displayname"],
)
@override_config(
{
"enable_set_avatar_url": False,
"experimental_features": {"msc4133_enabled": True},
}
)
def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
"""Test if set avatar_url is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
self.assertEqual(
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
["avatar_url"],
)
@override_config({"enable_3pid_changes": False})
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
"""Test if change 3pid is disabled that the server responds it."""

View file

@ -25,16 +25,20 @@ import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional
from canonicaljson import encode_canonical_json
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.utils import USE_POSTGRES_FOR_TESTS
class ProfileTestCase(unittest.HomeserverTestCase):
@ -480,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_get_missing_custom_field(self) -> None:
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_get_missing_custom_field_invalid_field_name(self) -> None:
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_get_custom_field_rejects_bad_username(self) -> None:
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field(self) -> None:
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
content={"custom_field": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.json_body, {"custom_field": "test"})
# Overwriting the field should work.
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
content={"custom_field": "new_Value"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
# Deleting the field should work.
channel = self.make_request(
"DELETE",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
content={},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_non_string(self) -> None:
"""Non-string fields are supported for custom fields."""
fields = {
"bool_field": True,
"array_field": ["test"],
"object_field": {"test": "test"},
"numeric_field": 1,
"null_field": None,
}
for key, value in fields.items():
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: value},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
channel = self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
# Check getting individual fields works.
for key, value in fields.items():
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.json_body, {key: value})
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_noauth(self) -> None:
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
content={"custom_field": "test"},
)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_size(self) -> None:
"""
Attempts to set a custom field name that is too long should get a 400 error.
"""
# Key is missing.
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
content={"": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Single key is too large.
key = "c" * 500
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
channel = self.make_request(
"DELETE",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
# Key doesn't match body.
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
content={"diff_key": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_profile_too_long(self) -> None:
"""
Attempts to set a custom field that would push the overall profile too large.
"""
# Get right to the boundary:
# len("displayname") + len("owner") + 5 = 21 for the displayname
# 1 + 65498 + 5 for key "a" = 65504
# 2 braces, 1 comma
# 3 + 21 + 65498 = 65522 < 65536.
key = "a"
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: "a" * 65498},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
# Get the entire profile.
channel = self.make_request(
"GET",
f"/_matrix/client/v3/profile/{self.owner}",
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
canonical_json = encode_canonical_json(channel.json_body)
# 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
# Be one below that so we can prove we're at the boundary.
self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
# Postgres stores JSONB with whitespace, while SQLite doesn't.
if USE_POSTGRES_FOR_TESTS:
ADDITIONAL_CHARS = 0
else:
ADDITIONAL_CHARS = 1
# The next one should fail, note the value has a (JSON) length of 2.
key = "b"
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: "1" + "a" * ADDITIONAL_CHARS},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
# Setting an avatar or (longer) display name should not work.
channel = self.make_request(
"PUT",
f"/profile/{self.owner}/displayname",
content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
channel = self.make_request(
"PUT",
f"/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://foo/bar"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
# Removing a single byte should work.
key = "b"
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: "" + "a" * ADDITIONAL_CHARS},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
# Finally, setting a field that already exists to a value that is <= in length should work.
key = "a"
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
content={key: ""},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_displayname(self) -> None:
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
displayname = self._get_displayname()
self.assertEqual(displayname, "test")
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_avatar_url(self) -> None:
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
content={"avatar_url": "mxc://test/good"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
avatar_url = self._get_avatar_url()
self.assertEqual(avatar_url, "mxc://test/good")
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_other(self) -> None:
"""Setting someone else's profile field should fail"""
channel = self.make_request(
"PUT",
f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
content={"custom_field": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.

View file

@ -20,7 +20,11 @@
#
from synapse.api.errors import SynapseError
from synapse.util.stringutils import assert_valid_client_secret, base62_encode
from synapse.util.stringutils import (
assert_valid_client_secret,
base62_encode,
is_namedspaced_grammar,
)
from .. import unittest
@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
self.assertEqual("001c", base62_encode(100, minwidth=4))
def test_namespaced_identifier(self) -> None:
self.assertTrue(is_namedspaced_grammar("test"))
self.assertTrue(is_namedspaced_grammar("m.test"))
self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
self.assertTrue(is_namedspaced_grammar("test"))
self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
# Must start with letter.
self.assertFalse(is_namedspaced_grammar("1test"))
self.assertFalse(is_namedspaced_grammar("-test"))
self.assertFalse(is_namedspaced_grammar("_test"))
self.assertFalse(is_namedspaced_grammar(".test"))
# Must contain only a-z, 0-9, -, _, ..
self.assertFalse(is_namedspaced_grammar("test/"))
self.assertFalse(is_namedspaced_grammar('test"'))
self.assertFalse(is_namedspaced_grammar("testö"))
# Must be < 255 characters.
self.assertFalse(is_namedspaced_grammar("t" * 256))