flatten auth chain iterations

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-29 08:39:44 +00:00
parent eb7d893c86
commit 50acfe7832
8 changed files with 90 additions and 111 deletions

View file

@ -6,8 +6,9 @@ use std::{
};
use conduwuit::{
debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, PduId,
RawPduId, Result,
debug_error, err, info, trace, utils,
utils::{stream::ReadyExt, string::EMPTY},
warn, Error, PduEvent, PduId, RawPduId, Result,
};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{
@ -54,7 +55,7 @@ pub(super) async fn get_auth_chain(
.rooms
.auth_chain
.event_ids_iter(room_id, once(event_id.as_ref()))
.await?
.ready_filter_map(Result::ok)
.count()
.await;

View file

@ -1,7 +1,7 @@
use std::{borrow::Borrow, iter::once};
use axum::extract::State;
use conduwuit::{Error, Result};
use conduwuit::{utils::stream::ReadyExt, Error, Result};
use futures::StreamExt;
use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
@ -48,7 +48,7 @@ pub(crate) async fn get_event_authorization_route(
.rooms
.auth_chain
.event_ids_iter(room_id, once(body.event_id.borrow()))
.await?
.ready_filter_map(Result::ok)
.filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect()

View file

@ -238,8 +238,6 @@ async fn create_join_event(
.rooms
.auth_chain
.event_ids_iter(room_id, starting_events)
.await?
.map(Ok)
.broad_and_then(|event_id| async move {
services.rooms.timeline.get_pdu_json(&event_id).await
})

View file

@ -56,8 +56,6 @@ pub(crate) async fn get_room_state_route(
.rooms
.auth_chain
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
.await?
.map(Ok)
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })
.and_then(|pdu| {
services

View file

@ -2,7 +2,7 @@ use std::{borrow::Borrow, iter::once};
use axum::extract::State;
use conduwuit::{at, err, Result};
use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use ruma::{api::federation::event::get_room_state_ids, OwnedEventId};
use super::AccessCheck;
@ -44,10 +44,8 @@ pub(crate) async fn get_room_state_ids_route(
.rooms
.auth_chain
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
.await?
.map(|id| (*id).to_owned())
.collect()
.await;
.try_collect()
.await?;
Ok(get_room_state_ids::v1::Response { auth_chain_ids, pdu_ids })
}

View file

@ -4,6 +4,7 @@ use std::{
collections::{BTreeSet, HashSet, VecDeque},
fmt::Debug,
sync::Arc,
time::Instant,
};
use conduwuit::{
@ -14,7 +15,7 @@ use conduwuit::{
},
validated, warn, Err, Result,
};
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{EventId, OwnedEventId, RoomId};
use self::data::Data;
@ -30,6 +31,8 @@ struct Services {
timeline: Dep<rooms::timeline::Service>,
}
type Bucket<'a> = BTreeSet<(u64, &'a EventId)>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
@ -45,42 +48,22 @@ impl crate::Service for Service {
}
#[implement(Service)]
pub async fn event_ids_iter<'a, I>(
pub fn event_ids_iter<'a, I>(
&'a self,
room_id: &RoomId,
room_id: &'a RoomId,
starting_events: I,
) -> Result<impl Stream<Item = OwnedEventId> + Send + '_>
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
let stream = self
.get_event_ids(room_id, starting_events)
.await?
.into_iter()
.stream();
Ok(stream)
}
#[implement(Service)]
pub async fn get_event_ids<'a, I>(
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<Vec<OwnedEventId>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
let chain = self.get_auth_chain(room_id, starting_events).await?;
let event_ids = self
.services
.short
.multi_get_eventid_from_short(chain.into_iter().stream())
.ready_filter_map(Result::ok)
.collect()
.await;
Ok(event_ids)
self.get_auth_chain(room_id, starting_events)
.map_ok(|chain| {
self.services
.short
.multi_get_eventid_from_short(chain.into_iter().stream())
.ready_filter(Result::is_ok)
})
.try_flatten_stream()
}
#[implement(Service)]
@ -94,9 +77,9 @@ where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
const BUCKET: Bucket<'_> = BTreeSet::new();
let started = std::time::Instant::now();
let started = Instant::now();
let mut starting_ids = self
.services
.short
@ -120,53 +103,7 @@ where
let full_auth_chain: Vec<ShortEventId> = buckets
.into_iter()
.try_stream()
.broad_and_then(|chunk| async move {
let chunk_key: Vec<ShortEventId> = chunk.iter().map(at!(0)).collect();
if chunk_key.is_empty() {
return Ok(Vec::new());
}
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
return Ok(cached.to_vec());
}
let chunk_cache: Vec<_> = chunk
.into_iter()
.try_stream()
.broad_and_then(|(shortid, event_id)| async move {
if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await {
return Ok(cached.to_vec());
}
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
debug!(
?event_id,
elapsed = ?started.elapsed(),
"Cache missed event"
);
Ok(auth_chain)
})
.try_collect()
.map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map_ok(|mut chunk_cache: Vec<_>| {
chunk_cache.sort_unstable();
chunk_cache.dedup();
chunk_cache
})
.await?;
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
debug!(
chunk_cache_length = ?chunk_cache.len(),
elapsed = ?started.elapsed(),
"Cache missed chunk",
);
Ok(chunk_cache)
})
.broad_and_then(|chunk| self.get_auth_chain_outer(room_id, started, chunk))
.try_collect()
.map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect())
.map_ok(|mut full_auth_chain: Vec<_>| {
@ -174,6 +111,7 @@ where
full_auth_chain.dedup();
full_auth_chain
})
.boxed()
.await?;
debug!(
@ -185,6 +123,60 @@ where
Ok(full_auth_chain)
}
/// Resolves the auth chain for one bucket (`chunk`) of starting events.
///
/// Resolution order:
/// 1. An empty bucket yields an empty chain without touching the cache.
/// 2. If a chain is cached under the bucket's full key (the short ids of
///    all its events), a copy of that cached chain is returned.
/// 3. Otherwise every event in the bucket is resolved on its own, using
///    its single-event cache entry when present and
///    `get_auth_chain_inner` when not. Fresh per-event results are cached,
///    the per-event chains are merged, sorted and deduplicated, and the
///    merged chain is cached under the bucket key before being returned.
///
/// `started` is only used to report elapsed time in the debug logs.
///
/// Returns the first error produced by `get_auth_chain_inner`, if any; in
/// that case nothing is cached under the bucket key.
#[implement(Service)]
async fn get_auth_chain_outer(
    &self,
    room_id: &RoomId,
    started: Instant,
    chunk: Bucket<'_>,
) -> Result<Vec<ShortEventId>> {
    // The bucket's cache key is the list of short ids it contains.
    let chunk_key: Vec<ShortEventId> = chunk.iter().map(at!(0)).collect();

    // Nothing to resolve, and nothing worth caching.
    if chunk_key.is_empty() {
        return Ok(Vec::new());
    }

    // Fast path: the whole bucket was resolved before.
    if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
        return Ok(cached.to_vec());
    }

    // Slow path: resolve each event individually; one chain per event.
    let per_event: Vec<Vec<_>> = chunk
        .into_iter()
        .try_stream()
        .broad_and_then(|(shortid, event_id)| async move {
            if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await {
                return Ok(cached.to_vec());
            }

            let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
            self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
            debug!(
                ?event_id,
                elapsed = ?started.elapsed(),
                "Cache missed event"
            );

            Ok(auth_chain)
        })
        .try_collect()
        .await?;

    // Merge the per-event chains into one sorted, duplicate-free chain.
    let mut chunk_cache: Vec<_> = per_event.into_iter().flatten().collect();
    chunk_cache.sort_unstable();
    chunk_cache.dedup();

    // Remember the merged result under the bucket key for next time.
    self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
    debug!(
        chunk_cache_length = ?chunk_cache.len(),
        elapsed = ?started.elapsed(),
        "Cache missed chunk",
    );

    Ok(chunk_cache)
}
#[implement(Service)]
#[tracing::instrument(name = "inner", level = "trace", skip(self, room_id))]
async fn get_auth_chain_inner(

View file

@ -44,18 +44,11 @@ pub async fn resolve_state(
let auth_chain_sets: Vec<HashSet<OwnedEventId>> = fork_states
.iter()
.try_stream()
.wide_and_then(|state| async move {
let starting_events = state.values().map(Borrow::borrow);
let auth_chain = self
.services
.wide_and_then(|state| {
self.services
.auth_chain
.get_event_ids(room_id, starting_events)
.await?
.into_iter()
.collect();
Ok(auth_chain)
.event_ids_iter(room_id, state.values().map(Borrow::borrow))
.try_collect()
})
.try_collect()
.await?;

View file

@ -10,7 +10,7 @@ use conduwuit::{
utils::stream::{BroadbandExt, IterStream},
PduEvent, Result,
};
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId};
// TODO: if we know the prev_events of the incoming event we can avoid the
@ -140,10 +140,9 @@ pub(super) async fn state_at_incoming_resolved(
let auth_chain: HashSet<OwnedEventId> = self
.services
.auth_chain
.get_event_ids(room_id, starting_events.into_iter())
.await?
.into_iter()
.collect();
.event_ids_iter(room_id, starting_events.into_iter())
.try_collect()
.await?;
auth_chain_sets.push(auth_chain);
fork_states.push(state);