mirror of
https://github.com/mautrix/signal.git
synced 2025-03-14 14:15:36 +00:00
245 lines
9.2 KiB
Python
245 lines
9.2 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Awaitable, Callable, Dict
|
|
from uuid import UUID, uuid4
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
from mautrix.util import background_task
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
|
|
|
|
# Async callback receiving the decoded JSON message (a dict) for an RPC event.
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]

# Synthetic RPC event names, used only to register callbacks that fire when the
# client's socket connects or disconnects.
CONNECT_EVENT: str = "_socket_connected"
DISCONNECT_EVENT: str = "_socket_disconnected"

# Stream buffer limit for the signald socket: 1 << 20 bytes, i.e. 1 MiB.
_SOCKET_LIMIT = 1 << 20
|
|
|
|
|
|
class SignaldRPCClient:
    """Asyncio client for signald's newline-delimited JSON protocol over a UNIX socket.

    Every outgoing request carries a random UUID ``id``. Incoming lines that carry an
    ``id`` resolve the future registered for that id; lines without one are dispatched
    to the handlers registered for their ``type`` via :meth:`add_rpc_handler`. The
    synthetic ``CONNECT_EVENT`` / ``DISCONNECT_EVENT`` handlers run when the socket
    connects or drops. A background task keeps reconnecting until :meth:`disconnect`
    is called.
    """

    loop: asyncio.AbstractEventLoop
    log: TraceLogger

    socket_path: str
    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    is_connected: bool
    # Resolved once a connection is established; replaced by a fresh future on disconnect.
    _connect_future: asyncio.Future
    _communicate_task: asyncio.Task | None

    # Pending request futures, keyed by the request id they are waiting for.
    _response_waiters: dict[UUID, asyncio.Future]
    # Handlers for id-less incoming messages, keyed by message type.
    _rpc_event_handlers: dict[str, list[EventHandler]]

    def __init__(
        self,
        socket_path: str,
        log: TraceLogger | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> None:
        """Create a disconnected client; call :meth:`connect` to open the socket.

        Args:
            socket_path: Filesystem path of the signald UNIX socket.
            log: Logger to use; defaults to the ``mausignald`` logger.
            loop: Event loop used to create futures; defaults to
                ``asyncio.get_event_loop()``.
        """
        self.socket_path = socket_path
        self.log = log or logging.getLogger("mausignald")
        self.loop = loop or asyncio.get_event_loop()
        self._reader = None
        self._writer = None
        self._communicate_task = None
        self.is_connected = False
        self._connect_future = self.loop.create_future()
        self._response_waiters = {}
        self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
        # Fail every in-flight request when the socket drops instead of leaving it hanging.
        self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)

    async def wait_for_connected(self, timeout: float | None = None) -> bool:
        """Wait until the socket is connected.

        Args:
            timeout: Maximum seconds to wait, or ``None`` to wait indefinitely.

        Returns:
            Whether the client is connected once the wait finishes.

        Raises:
            asyncio.TimeoutError: If ``timeout`` elapses before a connection is made.
        """
        if self.is_connected:
            return True
        # Shield so a timeout cancels only this wait, not the shared connect future.
        await asyncio.wait_for(asyncio.shield(self._connect_future), timeout)
        return self.is_connected

    async def connect(self) -> None:
        """Start the background connection loop and wait for a successful connect.

        Returns immediately if a writer is already open. If a connection loop is
        already running (e.g. while it is reconnecting), this only waits for it
        instead of starting a second loop.
        """
        if self._writer is not None:
            return
        if self._communicate_task is None or self._communicate_task.done():
            self._communicate_task = asyncio.create_task(self._communicate_forever())
        # Shield so cancelling this caller does not cancel the shared future that
        # _communicate resolves later.
        await asyncio.shield(self._connect_future)

    async def _communicate_forever(self) -> None:
        """Reconnect in a loop until this task is cancelled; pause 30s after a crash."""
        while True:
            try:
                await self._communicate()
            except Exception:
                self.log.exception("Unknown error in signald socket")
                await asyncio.sleep(30)

    async def _communicate(self) -> None:
        """Run one connection: open the socket, read until it closes, then fire disconnect.

        A failed connection attempt is logged and followed by a 5 second pause so the
        surrounding loop does not spin.
        """
        try:
            self.log.debug(f"Connecting to {self.socket_path}...")
            self._reader, self._writer = await asyncio.open_unix_connection(
                self.socket_path, limit=_SOCKET_LIMIT
            )
        except OSError as e:
            self.log.error(f"Connection to {self.socket_path} failed: {e}")
            await asyncio.sleep(5)
            return

        read_loop = asyncio.create_task(self._try_read_loop())
        self.is_connected = True
        background_task.create(self._run_rpc_handler(CONNECT_EVENT, {}))
        # The future may already be done (e.g. cancelled by an unshielded waiter);
        # set_result would then raise InvalidStateError and abort this connection.
        if not self._connect_future.done():
            self._connect_future.set_result(True)

        await read_loop
        self.is_connected = False
        self._connect_future = self.loop.create_future()
        await self._run_rpc_handler(DISCONNECT_EVENT, {})

    async def disconnect(self) -> None:
        """Close the connection and stop the reconnect loop.

        Pending requests are failed with ``NotConnected``: cancelling the connection
        task can skip the normal ``DISCONNECT_EVENT`` dispatch that would otherwise
        abandon them.
        """
        writer = self._writer
        if writer is not None:
            writer.write_eof()
            await writer.drain()
        if self._communicate_task:
            self._communicate_task.cancel()
            self._communicate_task = None
        if writer is not None:
            # write_eof() only half-closes the stream; close() releases the socket.
            writer.close()
        self._writer = None
        self._reader = None
        self.is_connected = False
        self._connect_future = self.loop.create_future()
        await self._abandon_responses({})

    def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` for incoming id-less messages whose ``type`` is ``method``."""
        self._rpc_event_handlers.setdefault(method, []).append(handler)

    def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler``; raises ``ValueError`` if it was not registered."""
        self._rpc_event_handlers.setdefault(method, []).remove(handler)

    async def _run_rpc_handler(self, command: str, req: dict[str, Any]) -> None:
        """Await every handler registered for ``command``, in registration order.

        A handler's exception is logged and does not stop the remaining handlers.
        """
        try:
            handlers = self._rpc_event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for RPC request %s", command)
            self.log.trace("Data unhandled request: %s", req)
            return
        # Iterate over a snapshot: a handler may add or remove handlers while it is
        # awaited, and mutating the live list mid-iteration would skip entries.
        for handler in list(handlers):
            try:
                await handler(req)
            except Exception:
                self.log.exception("Exception in RPC event handler")

    def _run_response_handlers(self, req_id: UUID, command: str, req: dict[str, Any]) -> None:
        """Resolve the future waiting on ``req_id`` with the response ``req``.

        An ``unexpected_error`` response, or one with a top-level ``error`` string or
        object, resolves the future with an exception. Any other response resolves it
        with ``(command, req.get("data"))``.
        """
        try:
            waiter = self._response_waiters.pop(req_id)
        except KeyError:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if waiter.done():
            # Already resolved elsewhere; resolving it again would raise InvalidStateError.
            self.log.debug(f"Response waiter for {req_id} was already resolved")
            return
        data = req.get("data")
        if command == "unexpected_error":
            # "data" may be missing, null or not an object. The waiter has already been
            # popped, so every path here must resolve it or the caller hangs forever.
            if isinstance(data, dict) and "message" in data:
                waiter.set_exception(UnexpectedError(data["message"]))
            else:
                waiter.set_exception(UnexpectedError("Unexpected error with no message"))
        elif "error" in req and isinstance(req["error"], (str, dict)):
            waiter.set_exception(make_response_error(req))
        else:
            waiter.set_result((command, data))

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the socket and route it.

        Lines with an ``id`` are responses to our requests. Lines without one are
        dispatched as RPC events in a background task. Malformed lines are logged at
        debug level and dropped.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        # Valid JSON that is not an object (e.g. a list) would make req["type"] raise
        # TypeError rather than KeyError, so check the shape explicitly.
        if not isinstance(req, dict) or "type" not in req:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        req_type = req["type"]

        self.log.trace("Got data from server: %s", req)

        req_id = req.get("id")
        if req_id is None:
            background_task.create(self._run_rpc_handler(req_type, req))
            return
        try:
            parsed_id = UUID(str(req_id))
        except ValueError:
            self.log.debug(f"Got response with invalid ID from server: {line}")
            return
        self._run_response_handlers(parsed_id, req_type, req)

    async def _try_read_loop(self) -> None:
        """Run the read loop, log how it ended, and always tear down the streams."""
        # Capture this connection's writer up front so the cleanup below closes this
        # connection's socket and not one stored on self later.
        writer = self._writer
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")
        else:
            self.log.debug("Reader disconnected")
        finally:
            self._reader = None
            self._writer = None
            if writer is not None:
                # Nothing else closes the transport after the peer hangs up, so
                # release the socket here.
                writer.close()

    async def _read_loop(self) -> None:
        """Read newline-delimited UTF-8 lines until EOF and handle each one.

        A failure on one line is logged and skipped, so one bad message does not
        drop the connection.
        """
        while self._reader is not None and not self._reader.at_eof():
            line = await self._reader.readline()
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    def _create_request(
        self, command: str, req_id: UUID | None = None, **data: Any
    ) -> tuple[asyncio.Future, dict[str, Any]]:
        """Build a request payload and register a future for its response.

        Returns:
            ``(future, payload)``. The payload holds ``id`` (the UUID as a string),
            ``type`` (``command``), and every keyword argument merged in at the top level.
        """
        req_id = req_id or uuid4()
        req = {"id": str(req_id), "type": command, **data}
        self.log.debug("Request %s: %s", req_id, command)
        self.log.trace("Request %s: %s with data: %s", req_id, command, data)
        return self._wait_response(req_id), req

    def _wait_response(self, req_id: UUID) -> asyncio.Future:
        """Return the future for ``req_id``, creating and registering it if needed."""
        try:
            future = self._response_waiters[req_id]
        except KeyError:
            future = self._response_waiters[req_id] = self.loop.create_future()
        return future

    async def _abandon_responses(self, unused_data: dict[str, Any]) -> None:
        """Fail every pending request with ``NotConnected`` and forget it.

        Registered as a ``DISCONNECT_EVENT`` handler in ``__init__`` and also called
        by :meth:`disconnect`. ``unused_data`` is the (empty) event payload.
        """
        # Swap the map out first so abandoned futures do not accumulate forever, and a
        # late response for one of them cannot hit an already-resolved future.
        waiters, self._response_waiters = self._response_waiters, {}
        for req_id, waiter in waiters.items():
            if not waiter.done():
                self.log.trace(f"Abandoning response for {req_id}")
                waiter.set_exception(
                    NotConnected("Disconnected from signald before RPC completed")
                )

    async def _send_request(self, data: dict[str, Any]) -> None:
        """Serialize ``data`` as a single JSON line and write it to the socket.

        Raises:
            NotConnected: If there is no open connection.
        """
        writer = self._writer
        if writer is None:
            raise NotConnected("Not connected to signald")

        writer.write(json.dumps(data).encode("utf-8") + b"\n")
        await writer.drain()
        self.log.trace("Sent data to server: %s", data)

    async def _raw_request(
        self, command: str, req_id: UUID | None = None, **data: Any
    ) -> tuple[str, dict[str, Any]]:
        """Send a request and wait for its response.

        Returns:
            ``(response_type, response_data)``.

        Raises:
            NotConnected: If not connected, or if the connection drops before a reply.
            UnexpectedError: For an ``unexpected_error`` response.
            Exception: Whatever ``make_response_error`` builds for an ``error`` response.
        """
        req_id = req_id or uuid4()
        future, request = self._create_request(command, req_id, **data)
        try:
            await self._send_request(request)
        except Exception:
            # The request may never have reached signald, so nothing would ever
            # resolve this future; unregister it instead of leaking it.
            if self._response_waiters.get(req_id) is future:
                del self._response_waiters[req_id]
            raise
        # Shield so cancelling the caller leaves the future registered until the
        # response arrives or the connection drops.
        return await asyncio.shield(future)

    async def _request(self, command: str, expected_response: str, **data: Any) -> Any:
        """Send a request and return its ``data``, checking the response type.

        Raises:
            UnexpectedResponse: If the response ``type`` is not ``expected_response``.
        """
        resp_type, resp_data = await self._raw_request(command, **data)
        if resp_type != expected_response:
            raise UnexpectedResponse(resp_type, resp_data)
        return resp_data

    async def request_v1(self, command: str, **data: Any) -> Any:
        """Send a ``version="v1"`` request whose response type must equal ``command``."""
        return await self._request(command, expected_response=command, version="v1", **data)
|