add state accessories for iterating state_keys of a type

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-02-07 01:16:46 +00:00 committed by strawberry
parent 59c073d0d8
commit e123a5b660

View file

@ -9,7 +9,7 @@ use conduwuit::{
PduEvent, Result,
};
use database::Deserialized;
use futures::{future::try_join, FutureExt, Stream, StreamExt, TryFutureExt};
use futures::{future::try_join, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt};
use ruma::{
events::{
room::member::{MembershipState, RoomMemberEventContent},
@ -69,7 +69,6 @@ where
}
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains(
&self,
shortstatehash: ShortStateHash,
@ -90,7 +89,18 @@ pub async fn state_contains(
}
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains_type(
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
) -> bool {
let state_keys = self.state_keys(shortstatehash, event_type);
pin_mut!(state_keys);
state_keys.next().await.is_some()
}
#[implement(super::Service)]
pub async fn state_contains_shortstatekey(
&self,
shortstatehash: ShortStateHash,
@ -125,7 +135,6 @@ pub async fn state_get(
/// Returns a single EventId from `room_id` with key (`event_type`,
/// `state_key`).
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_id<Id>(
&self,
shortstatehash: ShortStateHash,
@ -149,7 +158,6 @@ where
/// Returns a single EventId from `room_id` with key (`event_type`,
/// `state_key`).
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_shortid(
&self,
shortstatehash: ShortStateHash,
@ -177,6 +185,103 @@ pub async fn state_get_shortid(
.await?
}
/// Iterates the state_keys for an event_type in the state; current state
/// event_id included.
#[implement(super::Service)]
pub fn state_keys_with_ids<'a, Id>(
&'a self,
shortstatehash: ShortStateHash,
event_type: &'a StateEventType,
) -> impl Stream<Item = (String, Id)> + Send + 'a
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let state_keys_with_short_ids = self
.state_keys_with_shortids(shortstatehash, event_type)
.unzip()
.map(|(ssks, sids): (Vec<String>, Vec<u64>)| (ssks, sids))
.shared();
let state_keys = state_keys_with_short_ids
.clone()
.map(at!(0))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
let shorteventids = state_keys_with_short_ids
.map(at!(1))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
self.services
.short
.multi_get_eventid_from_short(shorteventids)
.zip(state_keys)
.ready_filter_map(|(eid, sk)| eid.map(move |eid| (sk, eid)).ok())
}
/// Iterates the state_keys for an event_type in the state; current state
/// event_id included.
#[implement(super::Service)]
pub fn state_keys_with_shortids<'a>(
&'a self,
shortstatehash: ShortStateHash,
event_type: &'a StateEventType,
) -> impl Stream<Item = (String, ShortEventId)> + Send + 'a {
let short_ids = self
.state_full_shortids(shortstatehash)
.expect_ok()
.unzip()
.map(|(ssks, sids): (Vec<u64>, Vec<u64>)| (ssks, sids))
.shared();
let shortstatekeys = short_ids
.clone()
.map(at!(0))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
let shorteventids = short_ids
.map(at!(1))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
self.services
.short
.multi_get_statekey_from_short(shortstatekeys)
.zip(shorteventids)
.ready_filter_map(|(res, id)| res.map(|res| (res, id)).ok())
.ready_filter_map(move |((event_type_, state_key), event_id)| {
event_type_.eq(event_type).then_some((state_key, event_id))
})
}
/// Iterates the state_keys for an event_type in the state
#[implement(super::Service)]
pub fn state_keys<'a>(
&'a self,
shortstatehash: ShortStateHash,
event_type: &'a StateEventType,
) -> impl Stream<Item = String> + Send + 'a {
let short_ids = self
.state_full_shortids(shortstatehash)
.expect_ok()
.map(at!(0));
self.services
.short
.multi_get_statekey_from_short(short_ids)
.ready_filter_map(Result::ok)
.ready_filter_map(move |(event_type_, state_key)| {
event_type_.eq(event_type).then_some(state_key)
})
}
/// Returns the state events removed between the interval (present in .0 but
/// not in .1)
#[implement(super::Service)]
@ -191,11 +296,10 @@ pub fn state_removed(
/// Returns the state events added between the interval (present in .1 but
/// not in .0)
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn state_added<'a>(
&'a self,
pub fn state_added(
&self,
shortstatehash: pair_of!(ShortStateHash),
) -> impl Stream<Item = (ShortStateKey, ShortEventId)> + Send + 'a {
) -> impl Stream<Item = (ShortStateKey, ShortEventId)> + Send + '_ {
let a = self.load_full_state(shortstatehash.0);
let b = self.load_full_state(shortstatehash.1);
try_join(a, b)
@ -239,7 +343,6 @@ pub fn state_full_pdus(
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn state_full_ids<'a, Id>(
&'a self,
shortstatehash: ShortStateHash,
@ -293,6 +396,7 @@ pub fn state_full_shortids(
}
#[implement(super::Service)]
#[tracing::instrument(name = "load", level = "debug", skip(self))]
async fn load_full_state(&self, shortstatehash: ShortStateHash) -> Result<Arc<CompressedState>> {
self.services
.state_compressor