cache compressed state in a sorted structure for logarithmic queries with partial keys

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-31 15:50:09 +00:00
parent ea49b60273
commit 4add39d0fe
7 changed files with 118 additions and 73 deletions

View file

@ -46,7 +46,10 @@ use ruma::{
use service::{
appservice::RegistrationInfo,
pdu::gen_event_id,
rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent},
rooms::{
state::RoomMutexGuard,
state_compressor::{CompressedState, HashSetCompressStateEvent},
},
Services,
};
@ -1169,7 +1172,7 @@ async fn join_room_by_id_helper_remote(
}
info!("Compressing state from send_join");
let compressed: HashSet<_> = services
let compressed: CompressedState = services
.rooms
.state_compressor
.compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow())))
@ -2340,7 +2343,7 @@ async fn knock_room_helper_remote(
}
info!("Compressing state from send_knock");
let compressed: HashSet<_> = services
let compressed: CompressedState = services
.rooms
.state_compressor
.compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow())))

View file

@ -15,7 +15,7 @@ use ruma::{
OwnedEventId, RoomId, RoomVersionId,
};
use crate::rooms::state_compressor::CompressedStateEvent;
use crate::rooms::state_compressor::CompressedState;
#[implement(super::Service)]
#[tracing::instrument(name = "resolve", level = "debug", skip_all)]
@ -24,7 +24,7 @@ pub async fn resolve_state(
room_id: &RoomId,
room_version_id: &RoomVersionId,
incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
) -> Result<Arc<CompressedState>> {
trace!("Loading current room state ids");
let current_sstatehash = self
.services
@ -91,7 +91,7 @@ pub async fn resolve_state(
.await;
trace!("Compressing state...");
let new_room_state: HashSet<_> = self
let new_room_state: CompressedState = self
.services
.state_compressor
.compress_state_events(

View file

@ -1,10 +1,4 @@
use std::{
borrow::Borrow,
collections::{BTreeMap, HashSet},
iter::once,
sync::Arc,
time::Instant,
};
use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant};
use conduwuit::{
debug, debug_info, err, implement, trace,
@ -19,7 +13,10 @@ use ruma::{
};
use super::{get_room_version_id, to_room_version};
use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPduId};
use crate::rooms::{
state_compressor::{CompressedState, HashSetCompressStateEvent},
timeline::RawPduId,
};
#[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
@ -173,7 +170,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
incoming_pdu.prev_events.len()
);
let state_ids_compressed: Arc<HashSet<_>> = self
let state_ids_compressed: Arc<CompressedState> = self
.services
.state_compressor
.compress_state_events(

View file

@ -1,9 +1,4 @@
use std::{
collections::{HashMap, HashSet},
fmt::Write,
iter::once,
sync::Arc,
};
use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
use conduwuit::{
err,
@ -33,7 +28,7 @@ use crate::{
globals, rooms,
rooms::{
short::{ShortEventId, ShortStateHash},
state_compressor::{parse_compressed_state_event, CompressedStateEvent},
state_compressor::{parse_compressed_state_event, CompressedState},
},
Dep,
};
@ -102,10 +97,9 @@ impl Service {
&self,
room_id: &RoomId,
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
statediffnew: Arc<CompressedState>,
_statediffremoved: Arc<CompressedState>,
state_lock: &RoomMutexGuard,
) -> Result {
let event_ids = statediffnew
.iter()
@ -176,7 +170,7 @@ impl Service {
&self,
event_id: &EventId,
room_id: &RoomId,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
state_ids_compressed: Arc<CompressedState>,
) -> Result<ShortStateHash> {
const KEY_LEN: usize = size_of::<ShortEventId>();
const VAL_LEN: usize = size_of::<ShortStateHash>();
@ -209,12 +203,12 @@ impl Service {
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
let statediffnew: CompressedState = state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
let statediffremoved: CompressedState = parent_stateinfo
.full_state
.difference(&state_ids_compressed)
.copied()
@ -222,7 +216,7 @@ impl Service {
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
(state_ids_compressed, Arc::new(CompressedState::new()))
};
self.services.state_compressor.save_state_from_diff(
shortstatehash,
@ -300,10 +294,10 @@ impl Service {
// TODO: statehash with deterministic inputs
let shortstatehash = self.services.globals.next_count()?;
let mut statediffnew = HashSet::new();
let mut statediffnew = CompressedState::new();
statediffnew.insert(new);
let mut statediffremoved = HashSet::new();
let mut statediffremoved = CompressedState::new();
if let Some(replaces) = replaces {
statediffremoved.insert(*replaces);
}

View file

@ -11,6 +11,7 @@ use conduwuit::{
utils,
utils::{
math::{usize_from_f64, Expected},
result::FlatOk,
stream::{BroadbandExt, IterStream, ReadyExt, TryExpect},
},
Err, Error, PduEvent, Result,
@ -47,7 +48,7 @@ use crate::{
rooms::{
short::{ShortEventId, ShortStateHash, ShortStateKey},
state::RoomMutexGuard,
state_compressor::parse_compressed_state_event,
state_compressor::{compress_state_event, parse_compressed_state_event},
},
Dep,
};
@ -220,36 +221,88 @@ impl Service {
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let shortstatekey = self
.services
.short
.get_shortstatekey(event_type, state_key)
let shorteventid = self
.state_get_shortid(shortstatehash, event_type, state_key)
.await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.full_state;
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
let (_, shorteventid) = parse_compressed_state_event(*compressed);
self.services
.short
.get_eventid_from_short(shorteventid)
.await
}
/// Resolves the `ShortEventId` recorded in the state snapshot
/// `shortstatehash` under the key (`event_type`, `state_key`).
///
/// The (`event_type`, `state_key`) pair is first translated into its
/// `ShortStateKey`. The last layer's `full_state` is then probed with a
/// range query bounded by the smallest and largest compressed entries that
/// can be built for that key, and the first hit is decoded.
///
/// # Errors
///
/// Propagates any failure from the short state key lookup or from loading
/// the state info, and returns `Request(NotFound)` when the snapshot holds
/// no entry for the key.
///
/// # Panics
///
/// Panics if the loaded state info contains no layer at all.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_shortid(
    &self,
    shortstatehash: ShortStateHash,
    event_type: &StateEventType,
    state_key: &str,
) -> Result<ShortEventId> {
    let shortstatekey = self
        .services
        .short
        .get_shortstatekey(event_type, state_key)
        .await?;

    // Lowest and highest compressed entries that share this state key.
    let start = compress_state_event(shortstatekey, 0);
    let end = compress_state_event(shortstatekey, u64::MAX);

    self.services
        .state_compressor
        .load_shortstatehash_info(shortstatehash)
        .map_ok(|layers| {
            // Query the set in place; only the matching entry is copied out,
            // so there is no need to clone `full_state` first.
            // The upper bound is inclusive so that the entry equal to `end`
            // is still part of the searched interval.
            let found = layers
                .last()
                .expect("at least one layer")
                .full_state
                .range(start..=end)
                .next()
                .copied();

            found
                .map(parse_compressed_state_event)
                .map(at!(1))
                // Build the error lazily: only on the miss path.
                .ok_or_else(|| err!(Request(NotFound("Not found in room state"))))
        })
        .await?
}
/// Reports whether the state snapshot `shortstatehash` has an entry for the
/// key (`event_type`, `state_key`).
///
/// Yields `false` when the key cannot be translated into a `ShortStateKey`;
/// otherwise the answer comes from `state_contains_shortstatekey`.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains(
    &self,
    shortstatehash: ShortStateHash,
    event_type: &StateEventType,
    state_key: &str,
) -> bool {
    let lookup = self
        .services
        .short
        .get_shortstatekey(event_type, state_key)
        .await;

    match lookup {
        // A failed short-key lookup is reported as "not contained".
        | Err(_) => false,
        | Ok(shortstatekey) =>
            self.state_contains_shortstatekey(shortstatehash, shortstatekey)
                .await,
    }
}
/// Reports whether the state snapshot `shortstatehash` has an entry for
/// `shortstatekey`.
///
/// The last layer's `full_state` is probed with a range query bounded by the
/// smallest and largest compressed entries that can be built for the key.
/// A failure to load the state info is reported as `false`.
///
/// # Panics
///
/// Panics if the loaded state info contains no layer at all.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains_shortstatekey(
    &self,
    shortstatehash: ShortStateHash,
    shortstatekey: ShortStateKey,
) -> bool {
    // Lowest and highest compressed entries that share this state key.
    let start = compress_state_event(shortstatekey, 0);
    let end = compress_state_event(shortstatekey, u64::MAX);

    self.services
        .state_compressor
        .load_shortstatehash_info(shortstatehash)
        .map_ok(|layers| {
            // Query the set in place instead of cloning `full_state`; the
            // upper bound is inclusive so the entry equal to `end` is still
            // part of the searched interval.
            layers
                .last()
                .expect("at least one layer")
                .full_state
                .range(start..=end)
                .next()
                .copied()
        })
        .await
        .flat_ok()
        .is_some()
}
pub fn state_full_shortids(
&self,
shortstatehash: ShortStateHash,

View file

@ -1,5 +1,5 @@
use std::{
collections::{HashMap, HashSet},
collections::{BTreeSet, HashMap},
fmt::{Debug, Write},
mem::size_of,
sync::{Arc, Mutex},
@ -63,8 +63,8 @@ type StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>;
type ShortStateInfoVec = Vec<ShortStateInfo>;
type ParentStatesVec = Vec<ShortStateInfo>;
pub(crate) type CompressedState = HashSet<CompressedStateEvent>;
pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
pub type CompressedState = BTreeSet<CompressedStateEvent>;
pub type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -249,8 +249,8 @@ impl Service {
pub fn save_state_from_diff(
&self,
shortstatehash: ShortStateHash,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>,
statediffnew: Arc<CompressedState>,
statediffremoved: Arc<CompressedState>,
diff_to_sibling: usize,
mut parent_states: ParentStatesVec,
) -> Result {
@ -363,7 +363,7 @@ impl Service {
pub async fn save_state(
&self,
room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
new_state_ids_compressed: Arc<CompressedState>,
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
@ -396,12 +396,12 @@ impl Service {
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed
let statediffnew: CompressedState = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
let statediffremoved: CompressedState = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
@ -409,7 +409,7 @@ impl Service {
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(HashSet::new()))
(new_state_ids_compressed, Arc::new(CompressedState::new()))
};
if !already_existed {
@ -448,11 +448,11 @@ impl Service {
.take_if(|parent| *parent != 0);
debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride");
let num_values = value.len() / STRIDE;
let _num_values = value.len() / STRIDE;
let mut add_mode = true;
let mut added = HashSet::with_capacity(num_values);
let mut removed = HashSet::with_capacity(num_values);
let mut added = CompressedState::new();
let mut removed = CompressedState::new();
let mut i = STRIDE;
while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
@ -469,8 +469,6 @@ impl Service {
i = expected!(i + 2 * STRIDE);
}
added.shrink_to_fit();
removed.shrink_to_fit();
Ok(StateDiff {
parent,
added: Arc::new(added),
@ -507,7 +505,7 @@ impl Service {
#[inline]
#[must_use]
fn compress_state_event(
pub(crate) fn compress_state_event(
shortstatekey: ShortStateKey,
shorteventid: ShortEventId,
) -> CompressedStateEvent {
@ -523,7 +521,7 @@ fn compress_state_event(
#[inline]
#[must_use]
pub fn parse_compressed_state_event(
pub(crate) fn parse_compressed_state_event(
compressed_event: CompressedStateEvent,
) -> (ShortStateKey, ShortEventId) {
use utils::u64_from_u8;

View file

@ -49,7 +49,7 @@ use crate::{
account_data, admin, appservice,
appservice::NamespaceRegex,
globals, pusher, rooms,
rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent},
rooms::{short::ShortRoomId, state_compressor::CompressedState},
sending, server_keys, users, Dep,
};
@ -950,7 +950,7 @@ impl Service {
pdu: &'a PduEvent,
pdu_json: CanonicalJsonObject,
new_room_leafs: Leafs,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
state_ids_compressed: Arc<CompressedState>,
soft_fail: bool,
state_lock: &'a RoomMutexGuard,
) -> Result<Option<RawPduId>>