cache result of resolution at completion of resolution

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-22 07:56:24 +00:00
parent f75d9fa79e
commit 607e338ac2
3 changed files with 10 additions and 31 deletions

View file

@ -18,7 +18,6 @@ use super::{
pub(crate) struct ActualDest {
pub(crate) dest: FedDest,
pub(crate) host: String,
pub(crate) cached: bool,
}
impl ActualDest {
@ -29,10 +28,10 @@ impl ActualDest {
impl super::Service {
#[tracing::instrument(skip_all, level = "debug", name = "resolve")]
pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> {
let (CachedDest { dest, host, .. }, cached) =
let (CachedDest { dest, host, .. }, _cached) =
self.lookup_actual_dest(server_name).await?;
Ok(ActualDest { dest, host, cached })
Ok(ActualDest { dest, host })
}
pub(crate) async fn lookup_actual_dest(
@ -49,6 +48,7 @@ impl super::Service {
}
self.resolve_actual_dest(server_name, true)
.inspect_ok(|result| self.cache.set_destination(server_name, result))
.map_ok(|result| (result, false))
.boxed()
.await
@ -334,7 +334,7 @@ impl super::Service {
debug_info!("{overname:?} overriden by {hostname:?}");
}
self.cache.set_override(overname, CachedOverride {
self.cache.set_override(overname, &CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),

View file

@ -45,12 +45,12 @@ impl Cache {
}
#[implement(Cache)]
pub fn set_destination(&self, name: &ServerName, dest: CachedDest) {
pub fn set_destination(&self, name: &ServerName, dest: &CachedDest) {
self.destinations.raw_put(name, Cbor(dest));
}
#[implement(Cache)]
pub fn set_override(&self, name: &str, over: CachedOverride) {
pub fn set_override(&self, name: &str, over: &CachedOverride) {
self.overrides.raw_put(name, Cbor(over));
}

View file

@ -18,10 +18,7 @@ use ruma::{
CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId,
};
use crate::{
resolver,
resolver::{actual::ActualDest, cache::CachedDest},
};
use crate::resolver::actual::ActualDest;
impl super::Service {
#[tracing::instrument(
@ -73,16 +70,7 @@ impl super::Service {
debug!(?method, ?url, "Sending request");
match client.execute(request).await {
| Ok(response) =>
handle_response::<T>(
&self.services.resolver,
dest,
actual,
&method,
&url,
response,
)
.await,
| Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
| Err(error) =>
Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
}
@ -111,7 +99,6 @@ impl super::Service {
}
async fn handle_response<T>(
resolver: &resolver::Service,
dest: &ServerName,
actual: &ActualDest,
method: &Method,
@ -122,17 +109,9 @@ where
T: OutgoingRequest + Send,
{
let response = into_http_response(dest, actual, method, url, response).await?;
let result = T::IncomingResponse::try_from_http_response(response);
if result.is_ok() && !actual.cached {
resolver.cache.set_destination(dest, CachedDest {
dest: actual.dest.clone(),
host: actual.host.clone(),
expire: CachedDest::default_expire(),
});
}
result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
T::IncomingResponse::try_from_http_response(response)
.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
}
async fn into_http_response(