# mautrix-signal/mausignald/rpc.py
# 2023-02-12 12:40:34 +02:00
#
# 245 lines
# 9.2 KiB
# Python

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict
from uuid import UUID, uuid4
import asyncio
import json
import logging
from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
# Signature of an RPC event callback: an async callable that is awaited with
# the decoded message dict.
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]

# These are synthetic RPC events for registering callbacks on socket
# connect and disconnect.
CONNECT_EVENT = "_socket_connected"
DISCONNECT_EVENT = "_socket_disconnected"

# Stream buffer limit passed to asyncio.open_unix_connection() when connecting.
_SOCKET_LIMIT = 1024 * 1024  # 1 MiB
class SignaldRPCClient:
    """Asyncio client speaking newline-delimited JSON over the signald unix socket.

    Every outgoing request carries an ``id``; an incoming message with an
    ``id`` is treated as the response to that request and resolves the
    matching waiter future.  Incoming messages without an ``id`` are treated
    as events and dispatched, by their ``type``, to the handlers registered
    with :meth:`add_rpc_handler`.

    The synthetic ``CONNECT_EVENT`` and ``DISCONNECT_EVENT`` events are fired
    when the socket is connected and when the read loop ends.
    """

    loop: asyncio.AbstractEventLoop
    log: TraceLogger
    socket_path: str

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    is_connected: bool
    # Resolved once the socket is connected; replaced with a fresh future
    # every time the connection goes away.
    _connect_future: asyncio.Future
    _communicate_task: asyncio.Task | None
    # Futures of requests that are still waiting for a response, by request ID.
    _response_waiters: dict[UUID, asyncio.Future]
    _rpc_event_handlers: dict[str, list[EventHandler]]

    def __init__(
        self,
        socket_path: str,
        log: TraceLogger | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> None:
        """Prepare a client for ``socket_path`` without connecting yet.

        Args:
            socket_path: Filesystem path of the unix socket to connect to.
            log: Logger to use; defaults to the ``mausignald`` logger.
            loop: Event loop used to create futures; defaults to
                ``asyncio.get_event_loop()``.
        """
        self.socket_path = socket_path
        self.log = log or logging.getLogger("mausignald")
        self.loop = loop or asyncio.get_event_loop()
        self._reader = None
        self._writer = None
        self._communicate_task = None
        self.is_connected = False
        self._connect_future = self.loop.create_future()
        self._response_waiters = {}
        self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
        # Fail all in-flight requests whenever the socket goes away.
        self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)

    async def wait_for_connected(self, timeout: int | None = None) -> bool:
        """Wait until the socket is connected.

        Args:
            timeout: Maximum number of seconds to wait, or ``None`` for no limit.

        Returns:
            The value of ``is_connected`` after waiting.

        Raises:
            asyncio.TimeoutError: If ``timeout`` elapses first.
        """
        if self.is_connected:
            return True
        # Shield so that a timeout here does not cancel the shared future.
        await asyncio.wait_for(asyncio.shield(self._connect_future), timeout)
        return self.is_connected

    async def connect(self) -> None:
        """Start the background connection loop and wait for the first connection.

        Returns immediately if a writer already exists.
        """
        if self._writer is not None:
            return
        # Reuse a loop that is still trying to connect instead of spawning a
        # second one that would compete for the reader/writer attributes.
        if self._communicate_task is None or self._communicate_task.done():
            self._communicate_task = asyncio.create_task(self._communicate_forever())
        # Shield so a cancelled caller cannot cancel the shared connect future.
        await asyncio.shield(self._connect_future)

    async def _communicate_forever(self) -> None:
        """Run :meth:`_communicate` in an endless loop, reconnecting after it returns."""
        while True:
            try:
                await self._communicate()
            except Exception:
                self.log.exception("Unknown error in signald socket")
                await asyncio.sleep(30)

    async def _communicate(self) -> None:
        """Handle one connection: connect, read until the read loop ends, clean up.

        A failed connection attempt is logged and followed by a 5 second sleep
        before returning.
        """
        try:
            self.log.debug(f"Connecting to {self.socket_path}...")
            self._reader, self._writer = await asyncio.open_unix_connection(
                self.socket_path, limit=_SOCKET_LIMIT
            )
        except OSError as e:
            self.log.error(f"Connection to {self.socket_path} failed: {e}")
            await asyncio.sleep(5)
            return

        # Keep a local reference: the read loop resets self._writer to None
        # when it finishes, but the transport still has to be closed.
        writer = self._writer
        read_loop = asyncio.create_task(self._try_read_loop())
        self.is_connected = True
        background_task.create(self._run_rpc_handler(CONNECT_EVENT, {}))
        if not self._connect_future.done():
            self._connect_future.set_result(True)
        try:
            await read_loop
        finally:
            # Runs on normal EOF, on errors and on cancellation alike.
            writer.close()

        self.is_connected = False
        self._connect_future = self.loop.create_future()
        await self._run_rpc_handler(DISCONNECT_EVENT, {})

    async def disconnect(self) -> None:
        """Close the connection and stop the background connection loop.

        Pending requests are failed with ``NotConnected``.  Errors raised while
        sending EOF still propagate, but only after the client state has been
        reset.
        """
        writer = self._writer
        try:
            if writer is not None:
                writer.write_eof()
                await writer.drain()
        finally:
            if self._communicate_task:
                self._communicate_task.cancel()
                self._communicate_task = None
            if writer is not None:
                writer.close()
            self._writer = None
            self._reader = None
            self.is_connected = False
            self._connect_future = self.loop.create_future()
            # Cancelling the task skips the DISCONNECT_EVENT handlers, so
            # release the callers that are still waiting for a response here.
            await self._abandon_responses({})

    def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` to be awaited for every ``method`` event."""
        self._rpc_event_handlers.setdefault(method, []).append(handler)

    def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` from ``method``.

        Raises:
            ValueError: If ``handler`` is not registered for ``method``.
        """
        self._rpc_event_handlers.setdefault(method, []).remove(handler)

    async def _run_rpc_handler(self, command: str, req: dict[str, Any]) -> None:
        """Await every handler registered for ``command`` with ``req``.

        Exceptions raised by a handler are logged and do not prevent the
        remaining handlers from running.
        """
        try:
            handlers = self._rpc_event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for RPC request %s", command)
            self.log.trace("Data unhandled request: %s", req)
            return
        # Iterate over a snapshot: handlers may add or remove handlers while
        # they run, which would otherwise make the loop skip entries.
        for handler in list(handlers):
            try:
                await handler(req)
            except Exception:
                self.log.exception("Exception in RPC event handler")

    def _run_response_handlers(self, req_id: UUID, command: str, req: Any) -> None:
        """Resolve the waiter of request ``req_id`` with the response ``req``.

        The waiter receives ``(command, data)`` on success, ``UnexpectedError``
        for ``unexpected_error`` responses, or the error built by
        ``make_response_error`` when the response has a top-level ``error``.
        """
        try:
            waiter = self._response_waiters.pop(req_id)
        except KeyError:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if waiter.done():
            # Already cancelled or failed; setting a result would raise
            # InvalidStateError.
            self.log.debug(f"Waiter for {req_id} is already done, dropping response")
            return

        data = req.get("data")
        if command == "unexpected_error":
            try:
                message = data["message"]
            except (KeyError, TypeError):
                # TypeError covers responses that carry no "data" at all.
                message = "Unexpected error with no message"
            waiter.set_exception(UnexpectedError(message))
        # The error is read from the top level of the response, not from "data".
        elif "error" in req and isinstance(req["error"], (str, dict)):
            waiter.set_exception(make_response_error(req))
        else:
            waiter.set_result((command, data))

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the socket and route it as a response or an event."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        try:
            req_type = req["type"]
        except (KeyError, TypeError):
            # TypeError: valid JSON that is not an object (list, string, ...).
            self.log.debug(f"Got invalid request from server: {line}")
            return

        self.log.trace("Got data from server: %s", req)
        req_id = req.get("id")
        if req_id is None:
            # Events run in the background so slow handlers don't block reading.
            background_task.create(self._run_rpc_handler(req_type, req))
        else:
            self._run_response_handlers(UUID(req_id), req_type, req)

    async def _try_read_loop(self) -> None:
        """Run :meth:`_read_loop`, log how it ended and forget the streams."""
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")
        else:
            self.log.debug("Reader disconnected")
        finally:
            self._reader = None
            self._writer = None

    async def _read_loop(self) -> None:
        """Read lines until EOF (or until the reader is unset) and handle each one."""
        while self._reader is not None and not self._reader.at_eof():
            line = await self._reader.readline()
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    def _create_request(
        self, command: str, req_id: UUID | None = None, **data: Any
    ) -> tuple[asyncio.Future, dict[str, Any]]:
        """Build the request payload and register a waiter for its response.

        Args:
            command: Value of the request's ``type`` field.
            req_id: Request ID to use; a random UUID is generated if omitted.
            **data: Additional top-level fields of the request.

        Returns:
            The future that will receive the response, and the request dict.
        """
        req_id = req_id or uuid4()
        req = {"id": str(req_id), "type": command, **data}
        self.log.debug("Request %s: %s", req_id, command)
        self.log.trace("Request %s: %s with data: %s", req_id, command, data)
        return self._wait_response(req_id), req

    def _wait_response(self, req_id: UUID) -> asyncio.Future:
        """Return the waiter future for ``req_id``, creating it if necessary."""
        try:
            future = self._response_waiters[req_id]
        except KeyError:
            future = self._response_waiters[req_id] = self.loop.create_future()
        return future

    async def _abandon_responses(self, unused_data: dict[str, Any]) -> None:
        """Fail every pending request with ``NotConnected`` and forget the waiters."""
        # Swap the dict out first so the failed waiters don't pile up in it.
        abandoned, self._response_waiters = self._response_waiters, {}
        for req_id, waiter in abandoned.items():
            if not waiter.done():
                self.log.trace(f"Abandoning response for {req_id}")
                waiter.set_exception(
                    NotConnected("Disconnected from signald before RPC completed")
                )

    async def _send_request(self, data: dict[str, Any]) -> None:
        """Serialize ``data`` as one JSON line and write it to the socket.

        Raises:
            NotConnected: If there is no open connection.
        """
        if self._writer is None:
            raise NotConnected("Not connected to signald")
        self._writer.write(json.dumps(data).encode("utf-8") + b"\n")
        await self._writer.drain()
        self.log.trace("Sent data to server: %s", data)

    async def _raw_request(
        self, command: str, req_id: UUID | None = None, **data: Any
    ) -> tuple[str, dict[str, Any]]:
        """Send a request and wait for its response.

        Returns:
            The ``(type, data)`` pair of the response.

        Raises:
            NotConnected: If the socket is not connected, or the connection is
                lost before the response arrives.
        """
        # Only clean up waiters created by this call; a waiter that was
        # registered beforehand for the same ID belongs to someone else.
        owns_waiter = req_id is None or req_id not in self._response_waiters
        future, req = self._create_request(command, req_id, **data)
        try:
            await self._send_request(req)
        except Exception:
            if owns_waiter:
                # Sending failed, so this waiter would stay registered forever.
                self._response_waiters.pop(UUID(req["id"]), None)
            raise
        # Shield so a cancelled caller doesn't cancel the shared waiter.
        return await asyncio.shield(future)

    async def _request(self, command: str, expected_response: str, **data: Any) -> Any:
        """Send a request and return the response data, checking the response type.

        Raises:
            UnexpectedResponse: If the response type is not ``expected_response``.
        """
        resp_type, resp_data = await self._raw_request(command, **data)
        if resp_type != expected_response:
            raise UnexpectedResponse(resp_type, resp_data)
        return resp_data

    async def request_v1(self, command: str, **data: Any) -> Any:
        """Send a ``version="v1"`` request whose response type equals ``command``."""
        return await self._request(command, expected_response=command, version="v1", **data)