knocking implementation

Signed-off-by: strawberry <strawberry@puppygock.gay>

add sync bit of knocking

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2025-01-11 18:43:54 -05:00
parent fabd3cf567
commit 5a1c41e66b
14 changed files with 978 additions and 117 deletions

View file

@ -1,4 +1,5 @@
use std::{
borrow::Borrow,
collections::{BTreeMap, HashMap, HashSet},
net::IpAddr,
sync::Arc,
@ -8,7 +9,7 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
debug, debug_info, debug_warn, err, info,
pdu::{self, gen_event_id_canonical_json, PduBuilder},
pdu::{gen_event_id_canonical_json, PduBuilder},
result::FlatOk,
trace,
utils::{self, shuffle, IterStream, ReadyExt},
@ -19,6 +20,7 @@ use ruma::{
api::{
client::{
error::ErrorKind,
knock::knock_room,
membership::{
ban_user, forget_room, get_member_events, invite_user, join_room_by_id,
join_room_by_id_or_alias,
@ -37,11 +39,12 @@ use ruma::{
},
StateEventType,
},
state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, RoomVersionId, ServerName, UserId,
state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId,
OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId,
};
use service::{
appservice::RegistrationInfo,
pdu::gen_event_id,
rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent},
Services,
};
@ -348,6 +351,116 @@ pub(crate) async fn join_room_by_id_or_alias_route(
Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id })
}
/// # `POST /_matrix/client/*/knock/{roomIdOrAlias}`
///
/// Tries to knock the room to ask permission to join for the sender user.
#[tracing::instrument(skip_all, fields(%client), name = "knock")]
pub(crate) async fn knock_room_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<knock_room::v3::Request>,
) -> Result<knock_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body;
let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) {
| Ok(room_id) => {
banned_room_check(
&services,
sender_user,
Some(&room_id),
room_id.server_name(),
client,
)
.await?;
let mut servers = body.via.clone();
servers.extend(
services
.rooms
.state_cache
.servers_invite_via(&room_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
servers.extend(
services
.rooms
.state_cache
.invite_state(sender_user, &room_id)
.await
.unwrap_or_default()
.iter()
.filter_map(|event| event.get_field("sender").ok().flatten())
.filter_map(|sender: &str| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()),
);
if let Some(server) = room_id.server_name() {
servers.push(server.to_owned());
}
servers.sort_unstable();
servers.dedup();
shuffle(&mut servers);
(servers, room_id)
},
| Err(room_alias) => {
let (room_id, mut servers) = services
.rooms
.alias
.resolve_alias(&room_alias, Some(body.via.clone()))
.await?;
banned_room_check(
&services,
sender_user,
Some(&room_id),
Some(room_alias.server_name()),
client,
)
.await?;
let addl_via_servers = services
.rooms
.state_cache
.servers_invite_via(&room_id)
.map(ToOwned::to_owned);
let addl_state_servers = services
.rooms
.state_cache
.invite_state(sender_user, &room_id)
.await
.unwrap_or_default();
let mut addl_servers: Vec<_> = addl_state_servers
.iter()
.map(|event| event.get_field("sender"))
.filter_map(FlatOk::flat_ok)
.map(|user: &UserId| user.server_name().to_owned())
.stream()
.chain(addl_via_servers)
.collect()
.await;
addl_servers.sort_unstable();
addl_servers.dedup();
shuffle(&mut addl_servers);
servers.append(&mut addl_servers);
(servers, room_id)
},
};
knock_room_by_id_helper(&services, sender_user, &room_id, body.reason.clone(), &servers)
.boxed()
.await
}
/// # `POST /_matrix/client/v3/rooms/{roomId}/leave`
///
/// Tries to leave the sender user from a room.
@ -403,6 +516,17 @@ pub(crate) async fn invite_user_route(
)));
}
if let Ok(target_user_membership) = services
.rooms
.state_accessor
.get_member(&body.room_id, user_id)
.await
{
if target_user_membership.membership == MembershipState::Ban {
return Err!(Request(Forbidden("User is banned from this room.")));
}
}
if recipient_ignored_by_sender {
// silently drop the invite to the recipient if they've been ignored by the
// sender, pretend it worked
@ -862,7 +986,7 @@ async fn join_room_by_id_helper_remote(
.hash_and_sign_event(&mut join_event_stub, &room_version_id)?;
// Generate event id
let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?;
let event_id = gen_event_id(&join_event_stub, &room_version_id)?;
// Add event_id back
join_event_stub
@ -1030,7 +1154,7 @@ async fn join_room_by_id_helper_remote(
};
let auth_check = state_res::event_auth::auth_check(
&state_res::RoomVersion::new(&room_version_id).expect("room version is supported"),
&state_res::RoomVersion::new(&room_version_id)?,
&parsed_join_pdu,
None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()),
@ -1043,10 +1167,10 @@ async fn join_room_by_id_helper_remote(
}
info!("Compressing state from send_join");
let compressed = state
.iter()
.stream()
.then(|(&k, id)| services.rooms.state_compressor.compress_state_event(k, id))
let compressed: HashSet<_> = services
.rooms
.state_compressor
.compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow())))
.collect()
.await;
@ -1282,7 +1406,7 @@ async fn join_room_by_id_helper_local(
.hash_and_sign_event(&mut join_event_stub, &room_version_id)?;
// Generate event id
let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?;
let event_id = gen_event_id(&join_event_stub, &room_version_id)?;
// Add event_id back
join_event_stub
@ -1392,6 +1516,7 @@ async fn make_join_request(
);
make_join_response_and_server =
Err!(BadServerResponse("No server available to assist in joining."));
return make_join_response_and_server;
}
}
@ -1569,7 +1694,7 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) {
for room_id in all_rooms {
// ignore errors
if let Err(e) = leave_room(services, user_id, &room_id, None).await {
warn!(%room_id, %user_id, %e, "Failed to leave room");
warn!(%user_id, "Failed to leave {room_id} remotely: {e}");
}
services.rooms.state_cache.forget(&room_id, user_id);
@ -1585,11 +1710,15 @@ pub async fn leave_room(
//use conduwuit::utils::stream::OptionStream;
use futures::TryFutureExt;
// Ask a remote server if we don't have this room
// Ask a remote server if we don't have this room and are not knocking on it
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), room_id)
.await && !services
.rooms
.state_cache
.is_knocked(user_id, room_id)
.await
{
if let Err(e) = remote_leave_room(services, user_id, room_id).await {
@ -1601,7 +1730,8 @@ pub async fn leave_room(
.rooms
.state_cache
.invite_state(user_id, room_id)
.map_err(|_| services.rooms.state_cache.left_state(user_id, room_id))
.or_else(|_| services.rooms.state_cache.knock_state(user_id, room_id))
.or_else(|_| services.rooms.state_cache.left_state(user_id, room_id))
.await
.ok();
@ -1683,13 +1813,6 @@ async fn remote_leave_room(
let mut make_leave_response_and_server =
Err!(BadServerResponse("No server available to assist in leaving."));
let invite_state = services
.rooms
.state_cache
.invite_state(user_id, room_id)
.await
.map_err(|_| err!(Request(BadState("User is not invited."))))?;
let mut servers: HashSet<OwnedServerName> = services
.rooms
.state_cache
@ -1698,13 +1821,39 @@ async fn remote_leave_room(
.collect()
.await;
servers.extend(
invite_state
.iter()
.filter_map(|event| event.get_field("sender").ok().flatten())
.filter_map(|sender: &str| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()),
);
if let Ok(invite_state) = services
.rooms
.state_cache
.invite_state(user_id, room_id)
.await
{
servers.extend(
invite_state
.iter()
.filter_map(|event| event.get_field("sender").ok().flatten())
.filter_map(|sender: &str| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()),
);
} else if let Ok(knock_state) = services
.rooms
.state_cache
.knock_state(user_id, room_id)
.await
{
servers.extend(
knock_state
.iter()
.filter_map(|event| event.get_field("sender").ok().flatten())
.filter_map(|sender: &str| UserId::parse(sender).ok())
.filter_map(|sender| {
if !services.globals.user_is_local(sender) {
Some(sender.server_name().to_owned())
} else {
None
}
}),
);
}
if let Some(room_id_server_name) = room_id.server_name() {
servers.insert(room_id_server_name.to_owned());
@ -1779,7 +1928,7 @@ async fn remote_leave_room(
.hash_and_sign_event(&mut leave_event_stub, &room_version_id)?;
// Generate event id
let event_id = pdu::gen_event_id(&leave_event_stub, &room_version_id)?;
let event_id = gen_event_id(&leave_event_stub, &room_version_id)?;
// Add event_id back
leave_event_stub
@ -1805,3 +1954,514 @@ async fn remote_leave_room(
Ok(())
}
async fn knock_room_by_id_helper(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
reason: Option<String>,
servers: &[OwnedServerName],
) -> Result<knock_room::v3::Response> {
let state_lock = services.rooms.state.mutex.lock(room_id).await;
if services
.rooms
.state_cache
.is_invited(sender_user, room_id)
.await
{
debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock");
return Err!(Request(Forbidden(
"You cannot knock on a room you are already invited/accepted to."
)));
}
if services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock");
return Err!(Request(Forbidden("You cannot knock on a room you are already joined in.")));
}
if services
.rooms
.state_cache
.is_knocked(sender_user, room_id)
.await
{
debug_warn!("{sender_user} is already knocked in {room_id}");
return Ok(knock_room::v3::Response { room_id: room_id.into() });
}
if let Ok(membership) = services
.rooms
.state_accessor
.get_member(room_id, sender_user)
.await
{
if membership.membership == MembershipState::Ban {
debug_warn!("{sender_user} is banned from {room_id} but attempted to knock");
return Err!(Request(Forbidden("You cannot knock on a room you are banned from.")));
}
}
let server_in_room = services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), room_id)
.await;
let local_knock = server_in_room
|| servers.is_empty()
|| (servers.len() == 1 && services.globals.server_is_ours(&servers[0]));
if local_knock {
knock_room_helper_local(services, sender_user, room_id, reason, servers, state_lock)
.boxed()
.await?;
} else {
knock_room_helper_remote(services, sender_user, room_id, reason, servers, state_lock)
.boxed()
.await?;
}
Ok(knock_room::v3::Response::new(room_id.to_owned()))
}
async fn knock_room_helper_local(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
reason: Option<String>,
servers: &[OwnedServerName],
state_lock: RoomMutexGuard,
) -> Result {
debug_info!("We can knock locally");
let room_version_id = services.rooms.state.get_room_version(room_id).await?;
if matches!(
room_version_id,
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
) {
return Err!(Request(Forbidden("This room does not support knocking.")));
}
let content = RoomMemberEventContent {
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
blurhash: services.users.blurhash(sender_user).await.ok(),
reason: reason.clone(),
..RoomMemberEventContent::new(MembershipState::Knock)
};
// Try normal knock first
let Err(error) = services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(sender_user.to_string(), &content),
sender_user,
room_id,
&state_lock,
)
.await
else {
return Ok(());
};
if servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0]))
{
return Err(error);
}
warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock");
let (make_knock_response, remote_server) =
make_knock_request(services, sender_user, room_id, servers).await?;
info!("make_knock finished");
let room_version_id = make_knock_response.room_version;
if !services.server.supported_room_version(&room_version_id) {
return Err!(BadServerResponse(
"Remote room version {room_version_id} is not supported by conduwuit"
));
}
let mut knock_event_stub = serde_json::from_str::<CanonicalJsonObject>(
make_knock_response.event.get(),
)
.map_err(|e| {
err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
})?;
knock_event_stub.insert(
"origin".to_owned(),
CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()),
);
knock_event_stub.insert(
"origin_server_ts".to_owned(),
CanonicalJsonValue::Integer(
utils::millis_since_unix_epoch()
.try_into()
.expect("Timestamp is valid js_int value"),
),
);
knock_event_stub.insert(
"content".to_owned(),
to_canonical_value(RoomMemberEventContent {
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
blurhash: services.users.blurhash(sender_user).await.ok(),
reason,
..RoomMemberEventContent::new(MembershipState::Knock)
})
.expect("event is valid, we just created it"),
);
// In order to create a compatible ref hash (EventID) the `hashes` field needs
// to be present
services
.server_keys
.hash_and_sign_event(&mut knock_event_stub, &room_version_id)?;
// Generate event id
let event_id = gen_event_id(&knock_event_stub, &room_version_id)?;
// Add event_id
knock_event_stub
.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into()));
// It has enough fields to be called a proper event now
let knock_event = knock_event_stub;
info!("Asking {remote_server} for send_knock in room {room_id}");
let send_knock_request = federation::knock::send_knock::v1::Request {
room_id: room_id.to_owned(),
event_id: event_id.clone(),
pdu: services
.sending
.convert_to_outgoing_federation_event(knock_event.clone())
.await,
};
let send_knock_response = services
.sending
.send_federation_request(&remote_server, send_knock_request)
.await?;
info!("send_knock finished");
services
.rooms
.short
.get_or_create_shortroomid(room_id)
.await;
info!("Parsing knock event");
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone())
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
info!("Updating membership locally to knock state with provided stripped state events");
services
.rooms
.state_cache
.update_membership(
room_id,
sender_user,
parsed_knock_pdu
.get_content::<RoomMemberEventContent>()
.expect("we just created this"),
sender_user,
Some(send_knock_response.knock_room_state),
None,
false,
)
.await?;
info!("Appending room knock event locally");
services
.rooms
.timeline
.append_pdu(
&parsed_knock_pdu,
knock_event,
vec![(*parsed_knock_pdu.event_id).to_owned()],
&state_lock,
)
.await?;
Ok(())
}
async fn knock_room_helper_remote(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
reason: Option<String>,
servers: &[OwnedServerName],
state_lock: RoomMutexGuard,
) -> Result {
info!("Knocking {room_id} over federation.");
let (make_knock_response, remote_server) =
make_knock_request(services, sender_user, room_id, servers).await?;
info!("make_knock finished");
let room_version_id = make_knock_response.room_version;
if !services.server.supported_room_version(&room_version_id) {
return Err!(BadServerResponse(
"Remote room version {room_version_id} is not supported by conduwuit"
));
}
let mut knock_event_stub: CanonicalJsonObject =
serde_json::from_str(make_knock_response.event.get()).map_err(|e| {
err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
})?;
knock_event_stub.insert(
"origin".to_owned(),
CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()),
);
knock_event_stub.insert(
"origin_server_ts".to_owned(),
CanonicalJsonValue::Integer(
utils::millis_since_unix_epoch()
.try_into()
.expect("Timestamp is valid js_int value"),
),
);
knock_event_stub.insert(
"content".to_owned(),
to_canonical_value(RoomMemberEventContent {
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
blurhash: services.users.blurhash(sender_user).await.ok(),
reason,
..RoomMemberEventContent::new(MembershipState::Knock)
})
.expect("event is valid, we just created it"),
);
// In order to create a compatible ref hash (EventID) the `hashes` field needs
// to be present
services
.server_keys
.hash_and_sign_event(&mut knock_event_stub, &room_version_id)?;
// Generate event id
let event_id = gen_event_id(&knock_event_stub, &room_version_id)?;
// Add event_id
knock_event_stub
.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into()));
// It has enough fields to be called a proper event now
let knock_event = knock_event_stub;
info!("Asking {remote_server} for send_knock in room {room_id}");
let send_knock_request = federation::knock::send_knock::v1::Request {
room_id: room_id.to_owned(),
event_id: event_id.clone(),
pdu: services
.sending
.convert_to_outgoing_federation_event(knock_event.clone())
.await,
};
let send_knock_response = services
.sending
.send_federation_request(&remote_server, send_knock_request)
.await?;
info!("send_knock finished");
services
.rooms
.short
.get_or_create_shortroomid(room_id)
.await;
info!("Parsing knock event");
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone())
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
info!("Going through send_knock response knock state events");
let state = send_knock_response
.knock_room_state
.iter()
.map(|event| serde_json::from_str::<CanonicalJsonObject>(event.clone().into_json().get()))
.filter_map(Result::ok);
let mut state_map: HashMap<u64, OwnedEventId> = HashMap::new();
for event in state {
let Some(state_key) = event.get("state_key") else {
debug_warn!("send_knock stripped state event missing state_key: {event:?}");
continue;
};
let Some(event_type) = event.get("type") else {
debug_warn!("send_knock stripped state event missing event type: {event:?}");
continue;
};
let Ok(state_key) = serde_json::from_value::<String>(state_key.clone().into()) else {
debug_warn!("send_knock stripped state event has invalid state_key: {event:?}");
continue;
};
let Ok(event_type) = serde_json::from_value::<StateEventType>(event_type.clone().into())
else {
debug_warn!("send_knock stripped state event has invalid event type: {event:?}");
continue;
};
let event_id = gen_event_id(&event, &room_version_id)?;
let shortstatekey = services
.rooms
.short
.get_or_create_shortstatekey(&event_type, &state_key)
.await;
services.rooms.outlier.add_pdu_outlier(&event_id, &event);
state_map.insert(shortstatekey, event_id.clone());
}
info!("Compressing state from send_knock");
let compressed: HashSet<_> = services
.rooms
.state_compressor
.compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow())))
.collect()
.await;
debug!("Saving compressed state");
let HashSetCompressStateEvent {
shortstatehash: statehash_before_knock,
added,
removed,
} = services
.rooms
.state_compressor
.save_state(room_id, Arc::new(compressed))
.await?;
debug!("Forcing state for new room");
services
.rooms
.state
.force_state(room_id, statehash_before_knock, added, removed, &state_lock)
.await?;
let statehash_after_knock = services
.rooms
.state
.append_to_state(&parsed_knock_pdu)
.await?;
info!("Updating membership locally to knock state with provided stripped state events");
services
.rooms
.state_cache
.update_membership(
room_id,
sender_user,
parsed_knock_pdu
.get_content::<RoomMemberEventContent>()
.expect("we just created this"),
sender_user,
Some(send_knock_response.knock_room_state),
None,
false,
)
.await?;
info!("Appending room knock event locally");
services
.rooms
.timeline
.append_pdu(
&parsed_knock_pdu,
knock_event,
vec![(*parsed_knock_pdu.event_id).to_owned()],
&state_lock,
)
.await?;
info!("Setting final room state for new room");
// We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist
services
.rooms
.state
.set_room_state(room_id, statehash_after_knock, &state_lock);
Ok(())
}
async fn make_knock_request(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
servers: &[OwnedServerName],
) -> Result<(federation::knock::create_knock_event_template::v1::Response, OwnedServerName)> {
let mut make_knock_response_and_server =
Err!(BadServerResponse("No server available to assist in knocking."));
let mut make_knock_counter: usize = 0;
for remote_server in servers {
if services.globals.server_is_ours(remote_server) {
continue;
}
info!("Asking {remote_server} for make_knock ({make_knock_counter})");
let make_knock_response = services
.sending
.send_federation_request(
remote_server,
federation::knock::create_knock_event_template::v1::Request {
room_id: room_id.to_owned(),
user_id: sender_user.to_owned(),
ver: services.server.supported_room_versions().collect(),
},
)
.await;
trace!("make_knock response: {make_knock_response:?}");
make_knock_counter = make_knock_counter.saturating_add(1);
make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone()));
if make_knock_response_and_server.is_ok() {
break;
}
if make_knock_counter > 40 {
warn!(
"50 servers failed to provide valid make_knock response, assuming no server can \
assist in knocking."
);
make_knock_response_and_server =
Err!(BadServerResponse("No server available to assist in knocking."));
return make_knock_response_and_server;
}
}
make_knock_response_and_server
}

View file

@ -33,8 +33,8 @@ use ruma::{
self,
v3::{
Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom,
LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State as RoomState,
Timeline, ToDevice,
KnockState, KnockedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms,
State as RoomState, Timeline, ToDevice,
},
DeviceLists, UnreadNotificationsCount,
},
@ -266,6 +266,35 @@ pub(crate) async fn build_sync_events(
invited_rooms
});
let knocked_rooms = services
.rooms
.state_cache
.rooms_knocked(sender_user)
.fold_default(|mut knocked_rooms: BTreeMap<_, _>, (room_id, knock_state)| async move {
// Get and drop the lock to wait for remaining operations to finish
let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await;
drop(insert_lock);
let knock_count = services
.rooms
.state_cache
.get_knock_count(&room_id, sender_user)
.await
.ok();
// Knocked before last sync
if Some(since) >= knock_count {
return knocked_rooms;
}
let knocked_room = KnockedRoom {
knock_state: KnockState { events: knock_state },
};
knocked_rooms.insert(room_id, knocked_room);
knocked_rooms
});
let presence_updates: OptionFuture<_> = services
.globals
.allow_local_presence()
@ -300,7 +329,7 @@ pub(crate) async fn build_sync_events(
.users
.remove_to_device_events(sender_user, sender_device, since);
let rooms = join3(joined_rooms, left_rooms, invited_rooms);
let rooms = join4(joined_rooms, left_rooms, invited_rooms, knocked_rooms);
let ephemeral = join3(remove_to_device_events, to_device_events, presence_updates);
let top = join5(account_data, ephemeral, device_one_time_keys_count, keys_changed, rooms)
.boxed()
@ -308,7 +337,7 @@ pub(crate) async fn build_sync_events(
let (account_data, ephemeral, device_one_time_keys_count, keys_changed, rooms) = top;
let ((), to_device_events, presence_updates) = ephemeral;
let (joined_rooms, left_rooms, invited_rooms) = rooms;
let (joined_rooms, left_rooms, invited_rooms, knocked_rooms) = rooms;
let (joined_rooms, mut device_list_updates, left_encrypted_users) = joined_rooms;
device_list_updates.extend(keys_changed);
@ -349,7 +378,7 @@ pub(crate) async fn build_sync_events(
leave: left_rooms,
join: joined_rooms,
invite: invited_rooms,
knock: BTreeMap::new(), // TODO
knock: knocked_rooms,
},
to_device: ToDevice { events: to_device_events },
};

View file

@ -113,9 +113,18 @@ pub(crate) async fn sync_events_v4_route(
.collect()
.await;
let all_knocked_rooms: Vec<_> = services
.rooms
.state_cache
.rooms_knocked(sender_user)
.map(|r| r.0)
.collect()
.await;
let all_rooms = all_joined_rooms
.iter()
.chain(all_invited_rooms.iter())
.chain(all_knocked_rooms.iter())
.map(Clone::clone)
.collect();

View file

@ -99,6 +99,7 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::join_room_by_id_route)
.ruma_route(&client::join_room_by_id_or_alias_route)
.ruma_route(&client::joined_members_route)
.ruma_route(&client::knock_room_route)
.ruma_route(&client::leave_room_route)
.ruma_route(&client::forget_room_route)
.ruma_route(&client::joined_rooms_route)
@ -204,8 +205,10 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&server::get_room_state_route)
.ruma_route(&server::get_room_state_ids_route)
.ruma_route(&server::create_leave_event_template_route)
.ruma_route(&server::create_knock_event_template_route)
.ruma_route(&server::create_leave_event_v1_route)
.ruma_route(&server::create_leave_event_v2_route)
.ruma_route(&server::create_knock_event_v1_route)
.ruma_route(&server::create_join_event_template_route)
.ruma_route(&server::create_join_event_v1_route)
.ruma_route(&server::create_join_event_v2_route)

View file

@ -6,8 +6,9 @@ use ruma::{
api::{client::error::ErrorKind, federation::membership::create_invite},
events::room::member::{MembershipState, RoomMemberEventContent},
serde::JsonObject,
CanonicalJsonValue, OwnedEventId, OwnedUserId, UserId,
CanonicalJsonValue, OwnedUserId, UserId,
};
use service::pdu::gen_event_id;
use crate::Ruma;
@ -86,12 +87,7 @@ pub(crate) async fn create_invite_route(
.map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?;
// Generate event id
let event_id = OwnedEventId::parse(format!(
"${}",
ruma::signatures::reference_hash(&signed_event, &body.room_version)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are valid event ids");
let event_id = gen_event_id(&signed_event, &body.room_version)?;
// Add event_id back
signed_event.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.to_string()));
@ -115,12 +111,12 @@ pub(crate) async fn create_invite_route(
let mut invite_state = body.invite_room_state.clone();
let mut event: JsonObject = serde_json::from_str(body.event.get())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?;
.map_err(|e| err!(Request(BadJson("Invalid invite event PDU: {e}"))))?;
event.insert("event_id".to_owned(), "$placeholder".into());
let pdu: PduEvent = serde_json::from_value(event.into())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event."))?;
.map_err(|e| err!(Request(BadJson("Invalid invite event PDU: {e}"))))?;
invite_state.push(pdu.to_stripped_state_event());

View file

@ -1,5 +1,5 @@
use axum::extract::State;
use conduwuit::Err;
use conduwuit::{debug_warn, Err};
use ruma::{
api::{client::error::ErrorKind, federation::knock::create_knock_event_template},
events::room::member::{MembershipState, RoomMemberEventContent},
@ -15,7 +15,8 @@ use crate::{service::pdu::PduBuilder, Error, Result, Ruma};
///
/// Creates a knock template.
pub(crate) async fn create_knock_event_template_route(
State(services): State<crate::State>, body: Ruma<create_knock_event_template::v1::Request>,
State(services): State<crate::State>,
body: Ruma<create_knock_event_template::v1::Request>,
) -> Result<create_knock_event_template::v1::Response> {
if !services.rooms.metadata.exists(&body.room_id).await {
return Err!(Request(NotFound("Room is unknown to this server.")));
@ -39,8 +40,8 @@ pub(crate) async fn create_knock_event_template_route(
.contains(body.origin())
{
warn!(
"Server {} for remote user {} tried knocking room ID {} which has a server name that is globally \
forbidden. Rejecting.",
"Server {} for remote user {} tried knocking room ID {} which has a server name \
that is globally forbidden. Rejecting.",
body.origin(),
&body.user_id,
&body.room_id,
@ -63,29 +64,44 @@ pub(crate) async fn create_knock_event_template_route(
if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) {
return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion {
room_version: room_version_id,
},
ErrorKind::IncompatibleRoomVersion { room_version: room_version_id },
"Room version does not support knocking.",
));
}
if !body.ver.contains(&room_version_id) {
return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion {
room_version: room_version_id,
},
ErrorKind::IncompatibleRoomVersion { room_version: room_version_id },
"Your homeserver does not support the features required to knock on this room.",
));
}
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
if let Ok(membership) = services
.rooms
.state_accessor
.get_member(&body.room_id, &body.user_id)
.await
{
if membership.membership == MembershipState::Ban {
debug_warn!(
"Remote user {} is banned from {} but attempted to knock",
&body.user_id,
&body.room_id
);
return Err!(Request(Forbidden("You cannot knock on a room you are banned from.")));
}
}
let (_pdu, mut pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Knock)),
PduBuilder::state(
body.user_id.to_string(),
&RoomMemberEventContent::new(MembershipState::Knock),
),
&body.user_id,
&body.room_id,
&state_lock,

View file

@ -9,7 +9,7 @@ use serde_json::value::to_raw_value;
use super::make_join::maybe_strip_event_id;
use crate::{service::pdu::PduBuilder, Ruma};
/// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}`
/// # `GET /_matrix/federation/v1/make_leave/{roomId}/{eventId}`
///
/// Creates a leave template.
pub(crate) async fn create_leave_event_template_route(
@ -21,7 +21,9 @@ pub(crate) async fn create_leave_event_template_route(
}
if body.user_id.server_name() != body.origin() {
return Err!(Request(BadJson("Not allowed to leave on behalf of another server/user.")));
return Err!(Request(Forbidden(
"Not allowed to leave on behalf of another server/user."
)));
}
// ACL check origin

View file

@ -6,6 +6,7 @@ pub(super) mod hierarchy;
pub(super) mod invite;
pub(super) mod key;
pub(super) mod make_join;
pub(super) mod make_knock;
pub(super) mod make_leave;
pub(super) mod media;
pub(super) mod openid;
@ -13,6 +14,7 @@ pub(super) mod publicrooms;
pub(super) mod query;
pub(super) mod send;
pub(super) mod send_join;
pub(super) mod send_knock;
pub(super) mod send_leave;
pub(super) mod state;
pub(super) mod state_ids;
@ -28,6 +30,7 @@ pub(super) use hierarchy::*;
pub(super) use invite::*;
pub(super) use key::*;
pub(super) use make_join::*;
pub(super) use make_knock::*;
pub(super) use make_leave::*;
pub(super) use media::*;
pub(super) use openid::*;
@ -35,6 +38,7 @@ pub(super) use publicrooms::*;
pub(super) use query::*;
pub(super) use send::*;
pub(super) use send_join::*;
pub(super) use send_knock::*;
pub(super) use send_leave::*;
pub(super) use state::*;
pub(super) use state_ids::*;

View file

@ -186,14 +186,13 @@ async fn create_join_event(
.map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?;
let origin: OwnedServerName = serde_json::from_value(
serde_json::to_value(
value
.get("origin")
.ok_or_else(|| err!(Request(BadJson("Event missing origin property."))))?,
)
.expect("CanonicalJson is valid json value"),
value
.get("origin")
.ok_or_else(|| err!(Request(BadJson("Event does not have an origin server name."))))?
.clone()
.into(),
)
.map_err(|e| err!(Request(BadJson(warn!("origin field is not a valid server name: {e}")))))?;
.map_err(|e| err!(Request(BadJson("Event has an invalid origin server name: {e}"))))?;
let mutex_lock = services
.rooms

View file

@ -1,7 +1,8 @@
use axum::extract::State;
use conduwuit::{err, pdu::gen_event_id_canonical_json, warn, Err, Error, PduEvent, Result};
use conduwuit::{err, pdu::gen_event_id_canonical_json, warn, Err, PduEvent, Result};
use futures::FutureExt;
use ruma::{
api::{client::error::ErrorKind, federation::knock::send_knock},
api::federation::knock::send_knock,
events::{
room::member::{MembershipState, RoomMemberEventContent},
StateEventType,
@ -17,7 +18,8 @@ use crate::Ruma;
///
/// Submits a signed knock event.
pub(crate) async fn create_knock_event_v1_route(
State(services): State<crate::State>, body: Ruma<send_knock::v1::Request>,
State(services): State<crate::State>,
body: Ruma<send_knock::v1::Request>,
) -> Result<send_knock::v1::Response> {
if services
.globals
@ -26,7 +28,8 @@ pub(crate) async fn create_knock_event_v1_route(
.contains(body.origin())
{
warn!(
"Server {} tried knocking room ID {} who has a server name that is globally forbidden. Rejecting.",
"Server {} tried knocking room ID {} who has a server name that is globally \
forbidden. Rejecting.",
body.origin(),
&body.room_id,
);
@ -41,7 +44,8 @@ pub(crate) async fn create_knock_event_v1_route(
.contains(&server.to_owned())
{
warn!(
"Server {} tried knocking room ID {} which has a server name that is globally forbidden. Rejecting.",
"Server {} tried knocking room ID {} which has a server name that is globally \
forbidden. Rejecting.",
body.origin(),
&body.room_id,
);
@ -50,7 +54,7 @@ pub(crate) async fn create_knock_event_v1_route(
}
if !services.rooms.metadata.exists(&body.room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
return Err!(Request(NotFound("Room is unknown to this server.")));
}
// ACL check origin server
@ -74,44 +78,42 @@ pub(crate) async fn create_knock_event_v1_route(
let event_type: StateEventType = serde_json::from_value(
value
.get("type")
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing type property."))?
.ok_or_else(|| err!(Request(InvalidParam("Event has no event type."))))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event has invalid event type."))?;
.map_err(|e| err!(Request(InvalidParam("Event has invalid event type: {e}"))))?;
if event_type != StateEventType::RoomMember {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
return Err!(Request(InvalidParam(
"Not allowed to send non-membership state event to knock endpoint.",
));
)));
}
let content: RoomMemberEventContent = serde_json::from_value(
value
.get("content")
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing content property"))?
.ok_or_else(|| err!(Request(InvalidParam("Membership event has no content"))))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event content is empty or invalid"))?;
.map_err(|e| err!(Request(InvalidParam("Event has invalid membership content: {e}"))))?;
if content.membership != MembershipState::Knock {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Not allowed to send a non-knock membership event to knock endpoint.",
));
return Err!(Request(InvalidParam(
"Not allowed to send a non-knock membership event to knock endpoint."
)));
}
// ACL check sender server name
let sender: OwnedUserId = serde_json::from_value(
value
.get("sender")
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing sender property."))?
.ok_or_else(|| err!(Request(InvalidParam("Event has no sender user ID."))))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?;
.map_err(|e| err!(Request(BadJson("Event sender is not a valid user ID: {e}"))))?;
services
.rooms
@ -127,36 +129,32 @@ pub(crate) async fn create_knock_event_v1_route(
let state_key: OwnedUserId = serde_json::from_value(
value
.get("state_key")
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing state_key property."))?
.ok_or_else(|| err!(Request(InvalidParam("Event does not have a state_key"))))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "state_key is invalid or not a user ID."))?;
.map_err(|e| err!(Request(BadJson("Event does not have a valid state_key: {e}"))))?;
if state_key != sender {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"State key does not match sender user",
));
return Err!(Request(InvalidParam("state_key does not match sender user of event.")));
};
let origin: OwnedServerName = serde_json::from_value(
serde_json::to_value(
value
.get("origin")
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?,
)
.expect("CanonicalJson is valid json value"),
value
.get("origin")
.ok_or_else(|| err!(Request(BadJson("Event does not have an origin server name."))))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?;
.map_err(|e| err!(Request(BadJson("Event has an invalid origin server name: {e}"))))?;
let mut event: JsonObject = serde_json::from_str(body.pdu.get())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?;
.map_err(|e| err!(Request(InvalidParam("Invalid knock event PDU: {e}"))))?;
event.insert("event_id".to_owned(), "$placeholder".into());
let pdu: PduEvent = serde_json::from_value(event.into())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?;
.map_err(|e| err!(Request(InvalidParam("Invalid knock event PDU: {e}"))))?;
let mutex_lock = services
.rooms
@ -169,19 +167,18 @@ pub(crate) async fn create_knock_event_v1_route(
.rooms
.event_handler
.handle_incoming_pdu(&origin, &body.room_id, &event_id, value.clone(), true)
.boxed()
.await?
.ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?;
drop(mutex_lock);
let knock_room_state = services.rooms.state.summary_stripped(&pdu).await;
services
.sending
.send_pdu_room(&body.room_id, &pdu_id)
.await?;
Ok(send_knock::v1::Response {
knock_room_state,
})
let knock_room_state = services.rooms.state.summary_stripped(&pdu).await;
Ok(send_knock::v1::Response { knock_room_state })
}

View file

@ -1,6 +1,6 @@
use conduwuit::{implement, is_false, Err, Result};
use conduwuit_service::Services;
use futures::{future::OptionFuture, join, FutureExt};
use futures::{future::OptionFuture, join, FutureExt, StreamExt};
use ruma::{EventId, RoomId, ServerName};
pub(super) struct AccessCheck<'a> {
@ -31,6 +31,15 @@ pub(super) async fn check(&self) -> Result {
.state_cache
.server_in_room(self.origin, self.room_id);
// if any user on our homeserver is trying to knock this room, we'll need to
// acknowledge bans or leaves
let user_is_knocking = self
.services
.rooms
.state_cache
.room_members_knocked(self.room_id)
.count();
let server_can_see: OptionFuture<_> = self
.event_id
.map(|event_id| {
@ -42,14 +51,14 @@ pub(super) async fn check(&self) -> Result {
})
.into();
let (world_readable, server_in_room, server_can_see, acl_check) =
join!(world_readable, server_in_room, server_can_see, acl_check);
let (world_readable, server_in_room, server_can_see, acl_check, user_is_knocking) =
join!(world_readable, server_in_room, server_can_see, acl_check, user_is_knocking);
if !acl_check {
return Err!(Request(Forbidden("Server access denied.")));
}
if !world_readable && !server_in_room {
if !world_readable && !server_in_room && user_is_knocking == 0 {
return Err!(Request(Forbidden("Server is not in room.")));
}

View file

@ -184,6 +184,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "roomuserid_leftcount",
..descriptor::RANDOM
},
Descriptor {
name: "roomuserid_knockedcount",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuserid_privateread",
..descriptor::RANDOM_SMALL
@ -377,6 +381,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userroomid_leftstate",
..descriptor::RANDOM
},
Descriptor {
name: "userroomid_knockedstate",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userroomid_notificationcount",
..descriptor::RANDOM

View file

@ -10,7 +10,7 @@ use conduwuit::{
warn, Result,
};
use database::{serialize_key, Deserialized, Ignore, Interfix, Json, Map};
use futures::{future::join4, pin_mut, stream::iter, Stream, StreamExt};
use futures::{future::join5, pin_mut, stream::iter, Stream, StreamExt};
use itertools::Itertools;
use ruma::{
events::{
@ -51,11 +51,13 @@ struct Data {
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuserid_knockedcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
userroomid_knockedstate: Arc<Map>,
}
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
@ -81,11 +83,13 @@ impl crate::Service for Service {
roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(),
roomuserid_joined: args.db["roomuserid_joined"].clone(),
roomuserid_leftcount: args.db["roomuserid_leftcount"].clone(),
roomuserid_knockedcount: args.db["roomuserid_knockedcount"].clone(),
roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(),
serverroomids: args.db["serverroomids"].clone(),
userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
userroomid_joined: args.db["userroomid_joined"].clone(),
userroomid_leftstate: args.db["userroomid_leftstate"].clone(),
userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(),
},
}))
}
@ -336,6 +340,9 @@ impl Service {
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@ -352,12 +359,13 @@ impl Service {
// (timo) TODO
let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
let count = self.services.globals.next_count().unwrap();
self.db
.userroomid_leftstate
.raw_put(&userroom_id, Json(leftstate));
self.db.roomuserid_leftcount.raw_put(&roomuser_id, count);
self.db
.roomuserid_leftcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
@ -365,6 +373,44 @@ impl Service {
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
/// Direct DB function to directly mark a user as knocked. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_knocked(
&self,
user_id: &UserId,
room_id: &RoomId,
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
self.db
.userroomid_knockedstate
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
self.db
.roomuserid_knockedcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@ -528,6 +574,20 @@ impl Service {
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns an iterator over all knocked members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_knocked<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_knockedcount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
@ -538,6 +598,16 @@ impl Service {
.deserialized()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
.roomuserid_knockedcount
.qry(&key)
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
@ -576,6 +646,25 @@ impl Service {
.ignore_err()
}
/// Returns an iterator over all rooms a user is currently knocking.
#[tracing::instrument(skip(self), level = "trace")]
pub fn rooms_knocked<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
let prefix = (user_id, Interfix);
self.db
.userroomid_knockedstate
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as()?)))
.ignore_err()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn invite_state(
&self,
@ -593,6 +682,23 @@ impl Service {
})
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn knock_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_knockedstate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn left_state(
&self,
@ -641,6 +747,12 @@ impl Service {
self.db.userroomid_joined.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_knockedstate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
@ -659,9 +771,10 @@ impl Service {
user_id: &UserId,
room_id: &RoomId,
) -> Option<MembershipState> {
let states = join4(
let states = join5(
self.is_joined(user_id, room_id),
self.is_left(user_id, room_id),
self.is_knocked(user_id, room_id),
self.is_invited(user_id, room_id),
self.once_joined(user_id, room_id),
)
@ -670,8 +783,9 @@ impl Service {
match states {
| (true, ..) => Some(MembershipState::Join),
| (_, true, ..) => Some(MembershipState::Leave),
| (_, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, true) => Some(MembershipState::Ban),
| (_, _, true, ..) => Some(MembershipState::Knock),
| (_, _, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, false, true) => Some(MembershipState::Ban),
| _ => None,
}
}
@ -747,6 +861,7 @@ impl Service {
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut knockedcount = 0_u64;
let mut joined_servers = HashSet::new();
self.room_members(room_id)
@ -764,8 +879,19 @@ impl Service {
.unwrap_or(0),
);
knockedcount = knockedcount.saturating_add(
self.room_members_knocked(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
self.db.roomid_joinedcount.raw_put(room_id, joinedcount);
self.db.roomid_invitedcount.raw_put(room_id, invitedcount);
self.db
.roomuserid_knockedcount
.raw_put(room_id, knockedcount);
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
@ -820,7 +946,6 @@ impl Service {
self.db
.userroomid_invitestate
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
self.db
.roomuserid_invitecount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
@ -831,6 +956,9 @@ impl Service {
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
if let Some(servers) = invite_via.filter(is_not_empty!()) {
self.add_servers_invite_via(room_id, servers).await;
}

View file

@ -498,14 +498,15 @@ impl Service {
.expect("This state_key was previously validated");
let content: RoomMemberEventContent = pdu.get_content()?;
let invite_state = match content.membership {
| MembershipState::Invite =>
let stripped_state = match content.membership {
| MembershipState::Invite | MembershipState::Knock =>
self.services.state.summary_stripped(pdu).await.into(),
| _ => None,
};
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
// Update our membership info, we do this here incase a user is invited or
// knocked and immediately leaves we need the DB to record the invite or
// knock event for auth
self.services
.state_cache
.update_membership(
@ -513,7 +514,7 @@ impl Service {
target_user_id,
content,
&pdu.sender,
invite_state,
stripped_state,
None,
true,
)