mirror of
https://github.com/girlbossceo/conduwuit.git
synced 2025-03-14 18:55:37 +00:00
use database for resolver caches
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
0c96891008
commit
49023aa295
8 changed files with 114 additions and 162 deletions
|
@ -1,7 +1,6 @@
|
|||
use std::fmt::Write;
|
||||
|
||||
use clap::Subcommand;
|
||||
use conduwuit::{utils::time, Result};
|
||||
use futures::StreamExt;
|
||||
use ruma::{events::room::message::RoomMessageEventContent, OwnedServerName};
|
||||
|
||||
use crate::{admin_command, admin_command_dispatch};
|
||||
|
@ -31,29 +30,19 @@ async fn destinations_cache(
|
|||
writeln!(self, "| Server Name | Destination | Hostname | Expires |").await?;
|
||||
writeln!(self, "| ----------- | ----------- | -------- | ------- |").await?;
|
||||
|
||||
let mut out = String::new();
|
||||
{
|
||||
let map = self
|
||||
.services
|
||||
.resolver
|
||||
.cache
|
||||
.destinations
|
||||
.read()
|
||||
.expect("locked");
|
||||
let mut destinations = self.services.resolver.cache.destinations().boxed();
|
||||
|
||||
for (name, &CachedDest { ref dest, ref host, expire }) in map.iter() {
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
if name != server_name {
|
||||
continue;
|
||||
}
|
||||
while let Some((name, CachedDest { dest, host, expire })) = destinations.next().await {
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
if name != server_name {
|
||||
continue;
|
||||
}
|
||||
|
||||
let expire = time::format(expire, "%+");
|
||||
writeln!(out, "| {name} | {dest} | {host} | {expire} |")?;
|
||||
}
|
||||
}
|
||||
|
||||
self.write_str(out.as_str()).await?;
|
||||
let expire = time::format(expire, "%+");
|
||||
self.write_str(&format!("| {name} | {dest} | {host} | {expire} |\n"))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(RoomMessageEventContent::notice_plain(""))
|
||||
}
|
||||
|
@ -65,29 +54,19 @@ async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessa
|
|||
writeln!(self, "| Server Name | IP | Port | Expires |").await?;
|
||||
writeln!(self, "| ----------- | --- | ----:| ------- |").await?;
|
||||
|
||||
let mut out = String::new();
|
||||
{
|
||||
let map = self
|
||||
.services
|
||||
.resolver
|
||||
.cache
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked");
|
||||
let mut overrides = self.services.resolver.cache.overrides().boxed();
|
||||
|
||||
for (name, &CachedOverride { ref ips, port, expire }) in map.iter() {
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
if name != server_name {
|
||||
continue;
|
||||
}
|
||||
while let Some((name, CachedOverride { ips, port, expire })) = overrides.next().await {
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
if name != server_name {
|
||||
continue;
|
||||
}
|
||||
|
||||
let expire = time::format(expire, "%+");
|
||||
writeln!(out, "| {name} | {ips:?} | {port} | {expire} |")?;
|
||||
}
|
||||
}
|
||||
|
||||
self.write_str(out.as_str()).await?;
|
||||
let expire = time::format(expire, "%+");
|
||||
self.write_str(&format!("| {name} | {ips:?} | {port} | {expire} |\n"))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(RoomMessageEventContent::notice_plain(""))
|
||||
}
|
||||
|
|
|
@ -221,10 +221,18 @@ pub(super) static MAPS: &[Descriptor] = &[
|
|||
name: "servercurrentevent_data",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "servername_destination",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "servername_educount",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "servername_override",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "servernameevent_data",
|
||||
cache_disp: CacheDisp::Unique,
|
||||
|
|
|
@ -29,7 +29,7 @@ impl ActualDest {
|
|||
impl super::Service {
|
||||
#[tracing::instrument(skip_all, level = "debug", name = "resolve")]
|
||||
pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> {
|
||||
let (result, cached) = if let Some(result) = self.get_cached_destination(server_name) {
|
||||
let (result, cached) = if let Ok(result) = self.cache.get_destination(server_name).await {
|
||||
(result, true)
|
||||
} else {
|
||||
self.validate_dest(server_name)?;
|
||||
|
@ -232,7 +232,7 @@ impl super::Service {
|
|||
|
||||
#[tracing::instrument(skip_all, name = "well-known")]
|
||||
async fn request_well_known(&self, dest: &str) -> Result<Option<String>> {
|
||||
if !self.has_cached_override(dest) {
|
||||
if !self.cache.has_override(dest).await {
|
||||
self.query_and_cache_override(dest, dest, 8448).await?;
|
||||
}
|
||||
|
||||
|
@ -315,7 +315,7 @@ impl super::Service {
|
|||
debug_info!("{overname:?} overriden by {hostname:?}");
|
||||
}
|
||||
|
||||
self.set_cached_override(overname, CachedOverride {
|
||||
self.cache.set_override(overname, CachedOverride {
|
||||
ips: override_ip.into_iter().take(MAX_IPS).collect(),
|
||||
port,
|
||||
expire: CachedOverride::default_expire(),
|
||||
|
|
|
@ -1,108 +1,103 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
net::IpAddr,
|
||||
sync::{Arc, RwLock},
|
||||
time::SystemTime,
|
||||
};
|
||||
use std::{net::IpAddr, sync::Arc, time::SystemTime};
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
use conduwuit::{
|
||||
trace,
|
||||
utils::{math::Expected, rand},
|
||||
at, implement,
|
||||
utils::{math::Expected, rand, stream::TryIgnore},
|
||||
Result,
|
||||
};
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
use database::{Cbor, Deserialized, Map};
|
||||
use futures::{Stream, StreamExt};
|
||||
use ruma::ServerName;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::fed::FedDest;
|
||||
|
||||
pub struct Cache {
|
||||
pub destinations: RwLock<WellKnownMap>, // actual_destination, host
|
||||
pub overrides: RwLock<TlsNameMap>,
|
||||
destinations: Arc<Map>,
|
||||
overrides: Arc<Map>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct CachedDest {
|
||||
pub dest: FedDest,
|
||||
pub host: String,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct CachedOverride {
|
||||
pub ips: IpAddrs,
|
||||
pub port: u16,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
pub type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
|
||||
pub type TlsNameMap = HashMap<String, CachedOverride>;
|
||||
|
||||
pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>;
|
||||
pub(crate) const MAX_IPS: usize = 3;
|
||||
|
||||
impl Cache {
|
||||
pub(super) fn new() -> Arc<Self> {
|
||||
pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
destinations: RwLock::new(WellKnownMap::new()),
|
||||
overrides: RwLock::new(TlsNameMap::new()),
|
||||
destinations: args.db["servername_destination"].clone(),
|
||||
overrides: args.db["servername_override"].clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl super::Service {
|
||||
pub fn set_cached_destination(
|
||||
&self,
|
||||
name: OwnedServerName,
|
||||
dest: CachedDest,
|
||||
) -> Option<CachedDest> {
|
||||
trace!(?name, ?dest, "set cached destination");
|
||||
self.cache
|
||||
.destinations
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, dest)
|
||||
}
|
||||
#[implement(Cache)]
|
||||
pub fn set_destination(&self, name: &ServerName, dest: CachedDest) {
|
||||
self.destinations.raw_put(name, Cbor(dest));
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
|
||||
self.cache
|
||||
.destinations
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.cloned()
|
||||
}
|
||||
#[implement(Cache)]
|
||||
pub fn set_override(&self, name: &str, over: CachedOverride) {
|
||||
self.overrides.raw_put(name, Cbor(over));
|
||||
}
|
||||
|
||||
pub fn set_cached_override(
|
||||
&self,
|
||||
name: &str,
|
||||
over: CachedOverride,
|
||||
) -> Option<CachedOverride> {
|
||||
trace!(?name, ?over, "set cached override");
|
||||
self.cache
|
||||
.overrides
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name.into(), over)
|
||||
}
|
||||
#[implement(Cache)]
|
||||
pub async fn get_destination(&self, name: &ServerName) -> Result<CachedDest> {
|
||||
self.destinations
|
||||
.get(name)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> {
|
||||
self.cache
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.cloned()
|
||||
}
|
||||
#[implement(Cache)]
|
||||
pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
|
||||
self.overrides
|
||||
.get(name)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn has_cached_override(&self, name: &str) -> bool {
|
||||
self.cache
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.contains_key(name)
|
||||
}
|
||||
#[implement(Cache)]
|
||||
#[must_use]
|
||||
pub async fn has_destination(&self, destination: &str) -> bool {
|
||||
self.destinations.exists(destination).await.is_ok()
|
||||
}
|
||||
|
||||
#[implement(Cache)]
|
||||
#[must_use]
|
||||
pub async fn has_override(&self, destination: &str) -> bool {
|
||||
self.overrides.exists(destination).await.is_ok()
|
||||
}
|
||||
|
||||
#[implement(Cache)]
|
||||
pub fn destinations(&self) -> impl Stream<Item = (&ServerName, CachedDest)> + Send + '_ {
|
||||
self.destinations
|
||||
.stream()
|
||||
.ignore_err()
|
||||
.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1 .0))
|
||||
}
|
||||
|
||||
#[implement(Cache)]
|
||||
pub fn overrides(&self) -> impl Stream<Item = (&ServerName, CachedOverride)> + Send + '_ {
|
||||
self.overrides
|
||||
.stream()
|
||||
.ignore_err()
|
||||
.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1 .0))
|
||||
}
|
||||
|
||||
impl CachedDest {
|
||||
|
|
|
@ -88,18 +88,20 @@ impl Resolve for Resolver {
|
|||
|
||||
impl Resolve for Hooked {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let cached: Option<CachedOverride> = self
|
||||
.cache
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name.as_str())
|
||||
.cloned();
|
||||
hooked_resolve(self.cache.clone(), self.server.clone(), self.resolver.clone(), name)
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
cached.map_or_else(
|
||||
|| resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed(),
|
||||
|cached| cached_to_reqwest(cached).boxed(),
|
||||
)
|
||||
async fn hooked_resolve(
|
||||
cache: Arc<Cache>,
|
||||
server: Arc<Server>,
|
||||
resolver: Arc<TokioAsyncResolver>,
|
||||
name: Name,
|
||||
) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match cache.get_override(name.as_str()).await {
|
||||
| Ok(cached) => cached_to_reqwest(cached).await,
|
||||
| Err(_) => resolve_to_reqwest(server, resolver, name).boxed().await,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,8 +6,9 @@ use std::{
|
|||
|
||||
use arrayvec::ArrayString;
|
||||
use conduwuit::utils::math::Expected;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
|
||||
pub enum FedDest {
|
||||
Literal(SocketAddr),
|
||||
Named(String, PortString),
|
||||
|
|
|
@ -4,9 +4,9 @@ mod dns;
|
|||
pub mod fed;
|
||||
mod tests;
|
||||
|
||||
use std::{fmt::Write, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use conduwuit::{utils, utils::math::Expected, Result, Server};
|
||||
use conduwuit::{Result, Server};
|
||||
|
||||
use self::{cache::Cache, dns::Resolver};
|
||||
use crate::{client, Dep};
|
||||
|
@ -25,7 +25,7 @@ struct Services {
|
|||
impl crate::Service for Service {
|
||||
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let cache = Cache::new();
|
||||
let cache = Cache::new(&args);
|
||||
Ok(Arc::new(Self {
|
||||
cache: cache.clone(),
|
||||
resolver: Resolver::build(args.server, cache)?,
|
||||
|
@ -36,38 +36,5 @@ impl crate::Service for Service {
|
|||
}))
|
||||
}
|
||||
|
||||
fn memory_usage(&self, out: &mut dyn Write) -> Result {
|
||||
use utils::bytes::pretty;
|
||||
|
||||
let (oc_count, oc_bytes) = self.cache.overrides.read()?.iter().fold(
|
||||
(0_usize, 0_usize),
|
||||
|(count, bytes), (key, val)| {
|
||||
(count.expected_add(1), bytes.expected_add(key.len()).expected_add(val.size()))
|
||||
},
|
||||
);
|
||||
|
||||
let (dc_count, dc_bytes) = self.cache.destinations.read()?.iter().fold(
|
||||
(0_usize, 0_usize),
|
||||
|(count, bytes), (key, val)| {
|
||||
(count.expected_add(1), bytes.expected_add(key.len()).expected_add(val.size()))
|
||||
},
|
||||
);
|
||||
|
||||
writeln!(out, "resolver_overrides_cache: {oc_count} ({})", pretty(oc_bytes))?;
|
||||
writeln!(out, "resolver_destinations_cache: {dc_count} ({})", pretty(dc_bytes))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_cache(&self) {
|
||||
self.cache.overrides.write().expect("write locked").clear();
|
||||
self.cache
|
||||
.destinations
|
||||
.write()
|
||||
.expect("write locked")
|
||||
.clear();
|
||||
self.resolver.resolver.clear_cache();
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ where
|
|||
let result = T::IncomingResponse::try_from_http_response(response);
|
||||
|
||||
if result.is_ok() && !actual.cached {
|
||||
resolver.set_cached_destination(dest.to_owned(), CachedDest {
|
||||
resolver.cache.set_destination(dest, CachedDest {
|
||||
dest: actual.dest.clone(),
|
||||
host: actual.host.clone(),
|
||||
expire: CachedDest::default_expire(),
|
||||
|
|
Loading…
Add table
Reference in a new issue