mautrix-signal/mausignald/signald.py
2020-09-28 02:03:47 +03:00

136 lines
6 KiB
Python

# Copyright (c) 2020 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
from uuid import uuid4
import asyncio
from mautrix.util.logging import TraceLogger
from .rpc import SignaldRPCClient
from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, FullGroup,
Profile)
# Type variable standing for the class of an event; used below to type both the
# event objects being dispatched and the classes that handlers are registered under.
T = TypeVar('T')
# An event handler is an async callable that receives one event and returns nothing.
EventHandler = Callable[[T], Awaitable[None]]
class SignaldClient(SignaldRPCClient):
    """Client for signald that adds event dispatch and request helpers to the RPC client.

    On construction the client registers itself as the RPC handler for ``"message"``
    payloads. Incoming payloads are deserialized into event objects and passed to the
    handlers registered for the event's exact class via :meth:`add_event_handler`.
    The remaining public methods are thin wrappers that build one signald request each.
    """

    # Registered event handlers, keyed by the exact event class they were registered for.
    _event_handlers: Dict[Type[T], List[EventHandler]]

    def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
                 log: Optional[TraceLogger] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Initialize the client and hook up the ``"message"`` RPC handler.

        Args:
            socket_path: Path of the signald socket; forwarded to the base class.
            log: Optional logger; forwarded to the base class.
            loop: Optional asyncio event loop; forwarded to the base class.
        """
        super().__init__(socket_path, log, loop)
        self._event_handlers = {}
        # Strong references to the URL callback tasks started by link(). The event loop
        # itself only keeps weak references to tasks, so they are held here until done.
        self._link_callback_tasks = set()
        self.add_rpc_handler("message", self._parse_message)

    def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
        """Register ``handler`` to be called for events of class ``event_class``.

        Registering the same handler more than once makes it run once per registration.
        """
        self._event_handlers.setdefault(event_class, []).append(handler)

    def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
        """Remove one registration of ``handler`` for ``event_class``.

        Raises:
            ValueError: If ``handler`` is not registered for ``event_class``.
        """
        # Use .get() instead of .setdefault() so that a failed removal does not leave an
        # empty handler list behind for an event class that was never registered.
        self._event_handlers.get(event_class, []).remove(handler)

    async def _run_event_handler(self, event: T) -> None:
        """Call every handler registered for the exact type of ``event``, in order.

        A warning is logged when no handler list exists for the event's type. Exceptions
        raised by a handler are logged and do not prevent later handlers from running.
        """
        handlers = self._event_handlers.get(type(event))
        if handlers is None:
            self.log.warning("No handlers for %s", type(event))
            return
        # Iterate over a snapshot: handlers are awaited, so the registry may be modified
        # (e.g. a handler removing itself) while dispatch is in progress, and mutating a
        # list that is being iterated would skip entries.
        for handler in list(handlers):
            try:
                await handler(event)
            except Exception:
                self.log.exception("Exception in event handler")

    async def _parse_message(self, data: Dict[str, Any]) -> None:
        """Deserialize an incoming RPC payload into an event and dispatch it.

        Args:
            data: Payload with a ``"type"`` key naming the event type and a ``"data"``
                key holding the serialized event.

        Raises:
            KeyError: If either key is missing or the event type is not known.
        """
        event_type = data["type"]
        event_data = data["data"]
        event_class = {
            "message": Message,
        }[event_type]
        event = event_class.deserialize(event_data)
        await self._run_event_handler(event)

    async def subscribe(self, username: str) -> bool:
        """Send a ``subscribe`` request for ``username``.

        Returns:
            ``True`` if the request completed, ``False`` if it raised
            :class:`UnexpectedError` (the failure is logged at debug level).
        """
        try:
            await self.request("subscribe", "subscribed", username=username)
            return True
        except UnexpectedError as e:
            self.log.debug("Failed to subscribe to %s: %s", username, e)
            return False

    async def link(self, url_callback: Callable[[str], Awaitable[None]],
                   device_name: str = "mausignald") -> Account:
        """Run the two-step ``link`` flow and return the resulting account.

        The first response to the request must be ``linking_uri``; its ``"uri"`` value is
        passed to ``url_callback``, which is started as a background task and not awaited.
        The method then waits for a second response on the same request ID, which must be
        ``linking_successful``.

        Args:
            url_callback: Async callable that receives the linking URI.
            device_name: Sent to signald as ``deviceName``.

        Returns:
            The :class:`Account` deserialized from the ``linking_successful`` response.

        Raises:
            UnexpectedResponse: If either response has an unexpected type.
            Exception: The error built by ``make_linking_error`` when either response is
                a ``linking_error``.
        """
        req_id = uuid4()
        resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
        if resp_type == "linking_error":
            raise make_linking_error(resp)
        elif resp_type != "linking_uri":
            raise UnexpectedResponse(resp_type, resp)

        # Keep a strong reference to the callback task until it finishes; a task that is
        # referenced only by the event loop may be garbage-collected before completing.
        callback_task = self.loop.create_task(url_callback(resp["uri"]))
        self._link_callback_tasks.add(callback_task)
        callback_task.add_done_callback(self._link_callback_tasks.discard)

        resp_type, resp = await self._wait_response(req_id)
        if resp_type == "linking_error":
            raise make_linking_error(resp)
        elif resp_type != "linking_successful":
            raise UnexpectedResponse(resp_type, resp)

        return Account.deserialize(resp)

    async def list_accounts(self) -> List[Account]:
        """Request the account list and return it as :class:`Account` objects."""
        data = await self.request("list_accounts", "account_list")
        return [Account.deserialize(acc) for acc in data["accounts"]]

    @staticmethod
    def _recipient_to_args(recipient: Union[Address, str]) -> Dict[str, Any]:
        """Build the request argument identifying ``recipient``.

        An :class:`Address` is serialized under ``recipientAddress``; any other value is
        passed through unchanged as ``recipientGroupId``.
        """
        if isinstance(recipient, Address):
            return {"recipientAddress": recipient.serialize()}
        else:
            return {"recipientGroupId": recipient}

    async def react(self, username: str, recipient: Union[Address, str],
                    reaction: Reaction) -> None:
        """Send a ``react`` request carrying ``reaction`` to a user or group.

        Args:
            username: Sent as the ``username`` field of the request.
            recipient: An :class:`Address` for a user, or a group ID string.
            reaction: The reaction to send; serialized into the request.
        """
        await self.request("react", "send_results", username=username,
                           reaction=reaction.serialize(),
                           **self._recipient_to_args(recipient))

    async def send(self, username: str, recipient: Union[Address, str], body: str,
                   quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                   timestamp: Optional[int] = None) -> None:
        """Send a ``send`` request with a message body to a user or group.

        Args:
            username: Sent as the ``username`` field of the request.
            recipient: An :class:`Address` for a user, or a group ID string.
            body: Sent as ``messageBody``.
            quote: Optional quote; serialized if given, otherwise sent as ``None``.
            attachments: Optional attachments; an empty list is sent if omitted.
            timestamp: Optional timestamp, passed through as-is.
        """
        serialized_quote = quote.serialize() if quote else None
        serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
        await self.request("send", "send_results", username=username, messageBody=body,
                           attachments=serialized_attachments, quote=serialized_quote,
                           timestamp=timestamp, **self._recipient_to_args(recipient))
        # TODO return something?

    async def mark_read(self, username: str, sender: Address, timestamps: List[int],
                        when: Optional[int] = None) -> None:
        """Send a ``mark_read`` request for messages from ``sender``.

        Sent through ``request_nowait`` (presumably without waiting for a response).

        Args:
            username: Sent as the ``username`` field of the request.
            sender: Address of the message sender; serialized as ``recipientAddress``.
            timestamps: Timestamps of the messages to mark as read.
            when: Optional value passed through as ``when``.
        """
        await self.request_nowait("mark_read", username=username, timestamps=timestamps, when=when,
                                  recipientAddress=sender.serialize())

    async def list_contacts(self, username: str) -> List[Contact]:
        """Request the contact list for ``username`` as :class:`Contact` objects.

        Each item obtained by iterating the ``contact_list`` response is deserialized.
        """
        contacts = await self.request("list_contacts", "contact_list", username=username)
        return [Contact.deserialize(contact) for contact in contacts]

    async def list_groups(self, username: str) -> List[FullGroup]:
        """Request the group list for ``username`` as :class:`FullGroup` objects."""
        resp = await self.request("list_groups", "group_list", username=username)
        return [FullGroup.deserialize(group) for group in resp["groups"]]

    async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
        """Request the profile of ``address``.

        Returns:
            The deserialized :class:`Profile`, or ``None`` when the request fails with an
            :class:`UnexpectedResponse` of type ``profile_not_available``.

        Raises:
            UnexpectedResponse: For any other unexpected response type.
        """
        try:
            resp = await self.request("get_profile", "profile", username=username,
                                      recipientAddress=address.serialize())
        except UnexpectedResponse as e:
            if e.resp_type == "profile_not_available":
                return None
            raise
        return Profile.deserialize(resp)

    async def set_profile(self, username: str, new_name: str) -> None:
        """Send a ``set_profile`` request with ``new_name`` as the ``name`` field."""
        await self.request("set_profile", "profile_set", username=username, name=new_name)