#! /usr/bin/python3

from __future__ import annotations

import argparse
import collections
import configparser
import contextlib
import email.headerregistry
import email.message
import email.parser
import email.policy
import email.utils
import enum
import fnmatch
import functools
import json
import logging
import math
import os
import pathlib
import pwd
import selectors
import signal
import smtplib
import socket
import ssl
import struct
import sys
import tempfile
import threading
import time
import types
import typing

from systemd import daemon, journal

if typing.TYPE_CHECKING:
    from _typeshed import HasFileno


# Name of this machine; used for local user addresses and the From template.
HOSTNAME: typing.Final = socket.gethostname()

logger: typing.Final = logging.getLogger("stne-mta")

# Message policy: allow raw 8-bit bodies and UTF-8 headers, and never refold
# header lines that came from the original source.
policy: typing.Final = email.policy.EmailPolicy(
    utf8=True,
    refold_source="none",
    cte_type="8bit",
)


def addrs_parse(v: typing.Any) -> typing.Sequence[email.headerregistry.Address]:
    """Parse ``str(v)`` as an address list, exactly as a ``To:`` value would be.

    Returns the (possibly empty) sequence of addresses found.
    """
    # The typing stubs don't know header_fetch_parse() hands back a header
    # object with ``.addresses``; cast() is a runtime no-op that says so.
    to_header = typing.cast(
        email.headerregistry.AddressHeader,
        policy.header_fetch_parse("To", str(v)),
    )
    return to_header.addresses


def addrs_uniq(
    addrs: typing.Iterable[email.headerregistry.Address],
) -> typing.Sequence[email.headerregistry.Address]:
    """Drop addresses whose addr-spec repeats, ignoring case.

    The result keeps the position of each addr-spec's first occurrence, but
    holds the *last* Address object seen for it.
    """
    by_spec: dict[str, email.headerregistry.Address] = {}
    for address in addrs:
        by_spec[address.addr_spec.casefold()] = address
    return tuple(by_spec.values())


class Idle:
    """Thread-safe work counter with a pollable "became idle" signal.

    ``inc()``/``dec()`` track outstanding work. When ``dec()`` brings the
    count to zero, a byte is written to an internal socketpair so a selector
    watching ``pollable()`` wakes up; ``drain()`` consumes those bytes. The
    socketpair only exists between ``__enter__`` and ``__exit__``.
    """

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.cnt = 0

    # TODO(as): use ExitStack?: https://github.com/python/mypy/issues/12875
    def __enter__(self) -> Idle:
        self._ir, self._iw = socket.socketpair()
        self._ir.setblocking(False)
        self._iw.setblocking(False)
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> None:
        self._ir.close()
        self._iw.close()

    def pollable(self) -> HasFileno:
        """Object to register for read-readiness; readable after going idle."""
        return self._ir

    def drain(self) -> int:
        """Consume all pending wake-up bytes and return how many were read."""
        n = 0

        while True:
            try:
                chunk = self._ir.recv(128)
            except BlockingIOError:
                break

            if not chunk:
                # The write end is closed: recv() would return b"" forever,
                # so stop instead of spinning.
                break

            n += len(chunk)

        return n

    def inc(self) -> None:
        with self.lock:
            self.cnt += 1

    def dec(self) -> None:
        """Decrement the counter and signal ``pollable()`` if it hit zero."""
        with self.lock:
            self.cnt -= 1
            is_idle = self.cnt == 0

        if is_idle:
            try:
                self._iw.send(b"0")
            except BlockingIOError:
                # Buffer too big? Then it's definitely already POLLIN, so
                # nothing else to do.
                pass

    def is_idle(self) -> bool:
        with self.lock:
            return self.cnt == 0


class ConfigError(Exception):
    """The configuration is unusable; this cannot be recovered from."""


class ConfigValueError(ConfigError):
    """A specific ``[group].key`` (or a whole ``[group]``) is invalid.

    ``key`` is empty when the error concerns the group itself.
    """

    def __init__(self, group: str, key: str, err_msg: str) -> None:
        # Forward the fields to the base class so ``args`` is populated:
        # repr() becomes meaningful and the exception can be pickled/copied
        # (unpickling re-calls the constructor with ``args``).
        super().__init__(group, key, err_msg)
        self.group = group
        self.key = key
        self.err_msg = err_msg

    def __str__(self) -> str:
        if self.key:
            return f"[{self.group}].{self.key}: {self.err_msg}"
        return f"[{self.group}]: {self.err_msg}"


class TLS(enum.Enum):
    """How the connection to the relay is secured."""

    STARTTLS = 1  # upgrade with STARTTLS; failing if unsupported
    Optional = 2  # upgrade with STARTTLS only if the relay supports it
    SMTPS = 3  # TLS from the first byte
    Off = 4  # plain text


class Config:
    """MTA settings plus recursive alias expansion.

    ``__init__`` sets the defaults; a loader (e.g. a subclass) overwrites them
    and then calls ``validate()``.
    """

    # Alias pattern (fnmatch syntax, matched against a bare local part) ->
    # the addresses that local part expands to.
    aliases: dict[str, typing.Sequence[email.headerregistry.Address]]

    def __init__(self) -> None:
        self.mta_from = "sendmail+{host}+{user}"
        self.mta_domain = ""
        self.aliases = {}
        self.relay_host = ""
        self.relay_port = 587
        self.relay_user = ""
        self.relay_pass = ""
        self.relay_tls = TLS.STARTTLS

    def validate(self) -> None:
        """Raise ConfigValueError/ConfigError if the settings are unusable."""
        mandatory = (
            ("mta", "from", self.mta_from),
            ("mta", "domain", self.mta_domain),
            ("relay", "host", self.relay_host),
        )
        for group, key, value in mandatory:
            if not value:
                raise ConfigValueError(group, key, "A value is required")

        # Expanding every alias once surfaces circular definitions now rather
        # than on the first message that happens to hit them.
        for targets in self.aliases.values():
            self.expand_aliases(targets)

    def format_from(
        self,
        *,
        user: str = "nobody",
        host: str = HOSTNAME,
        fmt_str: typing.Optional[str] = None,
    ) -> str:
        """Fill the ``{user}``/``{host}`` placeholders of the From template.

        A non-empty *fmt_str* is used instead of ``self.mta_from``.
        """
        template = fmt_str if fmt_str else self.mta_from
        return template.format(user=user, host=host)

    def _expand_local(
        self,
        local: str,
        resolving: list[str],
    ) -> typing.Iterator[email.headerregistry.Address]:
        """Expand one domain-less local part through every matching alias.

        *resolving* is the chain of local parts currently being expanded and
        is used to detect cycles. An unmatched local part becomes
        ``local@mta_domain``.
        """
        if local in resolving:
            cycle = resolving[resolving.index(local):] + [local]
            raise ConfigError("Circular aliases: " + " -> ".join(cycle))

        resolving.append(local)

        any_match = False
        for pattern, targets in self.aliases.items():
            if fnmatch.fnmatch(local, pattern):
                any_match = True
                yield from self._expand_aliases(targets, resolving)

        resolving.pop()

        if not any_match:
            yield email.headerregistry.Address(
                username=local,
                domain=self.mta_domain,
            )

    def _expand_aliases(
        self,
        addrs: typing.Sequence[email.headerregistry.Address],
        resolving: list[str],
    ) -> typing.Iterator[email.headerregistry.Address]:
        """Pass fully-qualified addresses through; expand bare local parts."""
        for addr in addrs:
            if not addr.domain:
                yield from self._expand_local(addr.username, resolving)
                continue
            yield addr

    def expand_aliases(
        self,
        addrs: typing.Sequence[email.headerregistry.Address],
    ) -> typing.Sequence[email.headerregistry.Address]:
        """Return *addrs* with aliases resolved and duplicates removed."""
        expanded = self._expand_aliases(addrs, [])
        return addrs_uniq(expanded)


class EtcConfig(Config):
    """Config loaded from systemd credentials.

    Reads ``etc_mta.ini`` and ``etc_relay.ini`` from ``$CREDENTIALS_DIRECTORY``.
    Each recognised key is removed from the parser once it is parsed, so
    whatever is left afterwards is reported as unrecognised.

    Raises:
        ConfigError / ConfigValueError: on any missing, malformed or unknown
            setting.
    """

    def __init__(self) -> None:
        super().__init__()

        creds_dir = os.environ.get("CREDENTIALS_DIRECTORY")
        if not creds_dir:
            raise ConfigError("CREDENTIALS_DIRECTORY env var not set")

        creds = pathlib.Path(creds_dir)

        cp = configparser.ConfigParser()
        try:
            # read() silently skips files that cannot be opened.
            cp.read(
                [
                    creds / "etc_mta.ini",
                    creds / "etc_relay.ini",
                ]
            )
        except configparser.Error as e:
            raise ConfigError(str(e)) from e

        self._parse("mta_from", cp, "mta", "from", self._from)
        self._parse("mta_domain", cp, "mta", "domain", str)
        self._parse_aliases(cp)
        self._parse("relay_host", cp, "relay", "host", str)
        self._parse("relay_user", cp, "relay", "username", str)
        self._parse("relay_pass", cp, "relay", "password", str)
        self._parse("relay_tls", cp, "relay", "tls", self._tls)

        # The default port depends on the TLS mode; an explicit port wins.
        if self.relay_tls is TLS.SMTPS:
            self.relay_port = 465

        self._parse("relay_port", cp, "relay", "port", self._port)

        # Everything recognised has been removed by now.
        for section in cp.sections():
            for key in cp[section].keys():
                raise ConfigValueError(section, key, "Unrecognized key")
            raise ConfigValueError(section, "", "Unrecognized group")

        self.validate()

    def _remove_option(
        self,
        cp: configparser.ConfigParser,
        section: str,
        key: str,
    ) -> None:
        """Drop *key* from *section*, and the section too once it is empty."""
        cp.remove_option(section, key)

        if not cp[section]:
            cp.remove_section(section)

    def _parse(
        self,
        attr_name: str,
        cp: configparser.ConfigParser,
        section: str,
        key: str,
        parse: typing.Callable[[str], typing.Any],
    ) -> None:
        """Set ``self.<attr_name>`` from ``[section].key`` if it is present.

        *parse* converts the raw string and signals bad values with
        ValueError. The key is consumed from *cp* afterwards.
        """
        try:
            raw_val = cp.get(section, key)
        except (configparser.NoSectionError, configparser.NoOptionError):
            return
        except configparser.Error as e:
            # e.g. a "%" in a password trips the default BasicInterpolation;
            # report it against the key rather than as a raw parser error.
            raise ConfigValueError(section, key, str(e)) from e

        try:
            val = parse(raw_val)
        except ValueError as e:
            raise ConfigValueError(section, key, str(e)) from e

        setattr(self, attr_name, val)
        self._remove_option(cp, section, key)

    def _parse_aliases(self, cp: configparser.ConfigParser) -> None:
        """Consume the optional ``[aliases]`` group into ``self.aliases``."""
        section = "aliases"

        if section not in cp:
            return

        # options() returns a fresh list, so removing keys while looping is
        # safe.
        for key in cp.options(section):
            try:
                val = cp.get(section, key)
            except configparser.Error as e:
                raise ConfigValueError(section, key, str(e)) from e

            self.aliases[key] = addrs_parse(val)
            self._remove_option(cp, section, key)

        # In case there are no aliases
        cp.remove_section(section)

    def _from(self, val: str) -> str:
        """Validate a From template by trial-formatting it."""
        try:
            self.format_from(fmt_str=val)
        except KeyError as e:
            raise ValueError("Unrecognized key: " + e.args[0]) from e
        except (IndexError, AttributeError) as e:
            # str.format() reports e.g. "{0}" or "{user.x}" with these rather
            # than ValueError, which _parse() would not turn into a
            # ConfigValueError.
            raise ValueError(f"Invalid format string: {e}") from e
        return val

    def _port(self, val: str) -> int:
        try:
            port = int(val)
        except ValueError:
            raise ValueError("Value must be a number")

        if port <= 0 or port > 65_535:
            raise ValueError("Port out of range")

        return port

    def _tls(self, val: str) -> TLS:
        """Map a case-insensitive TLS member name to the TLS enum."""
        lval = val.strip().lower()
        for name, member in TLS.__members__.items():
            if name.lower() == lval:
                return member

        raise ValueError("Unrecognized TLS value: " + val)


class EnvelopeError(Exception):
    """Non-recoverable errors with an Envelope"""

    def __init__(self, envl_id: str, msg: str) -> None:
        # Populate ``args`` so repr() is meaningful and the exception can be
        # pickled/copied (unpickling re-calls __init__ with ``args``).
        super().__init__(envl_id, msg)
        self.envl_id = envl_id
        self.msg = msg

    def __str__(self) -> str:
        return f"{self.envl_id}: {self.msg}"


class Envelope:
    """A queued message: routing metadata plus on-demand access to its bytes.

    Subclasses supply the storage: ``_get_msg()`` returns the raw message and
    ``rm()`` discards it.
    """

    # Number of times the Queue had to retry relaying
    retries: int = 0

    # If the Queue gave up on delivery
    bounced: bool = False

    def __init__(
        self,
        id: str,
        sender: str,
        rcpts: typing.Sequence[str],
        msg_size: int,
        received: float,
    ):
        self.id = id
        self.sender = sender
        # Own immutable copy of the recipient list.
        self.rcpts = tuple(rcpts)
        # Expected byte length of the message; enforced by get_msg().
        self.msg_size = msg_size
        self.received = received

    def get_log_extra(self, *, detailed: bool = False) -> dict[str, object]:
        """Structured logging fields describing this envelope.

        Always includes the id and retry count; *detailed* adds bounce state,
        sender, recipients, message size and receive time.
        """
        fields: dict[str, object] = {
            "ENVELOPE_ID": self.id,
            "ENVELOPE_RETRIES": self.retries,
        }
        if not detailed:
            return fields

        fields.update(
            ENVELOPE_BOUNCED=str(self.bounced),
            ENVELOPE_SENDER=self.sender,
            ENVELOPE_RCPTS=", ".join(self.rcpts),
            ENVELOPE_MSG_BYTES=self.msg_size,
            ENVELOPE_RECEIVED=self.received,
        )
        return fields

    def _get_msg(self) -> bytes:
        raise NotImplementedError

    def get_msg(self) -> bytes:
        """Return the message bytes, raising EnvelopeError on a size mismatch."""
        data = self._get_msg()
        actual = len(data)
        if actual == self.msg_size:
            return data

        raise EnvelopeError(
            self.id,
            f"expected message to be {self.msg_size} bytes, got {actual}",
        )

    def rm(self) -> None:
        raise NotImplementedError


class DiskEnvelope(Envelope):
    """An Envelope stored as one file, whose name is the envelope id.

    File layout: a single line of JSON metadata (``sender``, ``rcpts``,
    ``size``, ``received``) followed by the raw message bytes.
    """

    def __init__(self, path: pathlib.Path) -> None:
        self._path = path
        super().__init__(self._path.name, *self._read())

    @classmethod
    def _open(
        cls,
        path: pathlib.Path,
        mode: typing.Literal["rb", "wb"] = "rb",
    ) -> typing.IO[bytes]:
        try:
            return path.open(mode)
        except FileNotFoundError as e:
            raise EnvelopeError(path.name, "envelope file vanished") from e

    def _read(self) -> tuple[str, typing.Sequence[str], int, float]:
        """Parse the JSON header line into Envelope constructor arguments.

        Raises:
            EnvelopeError: if the header is missing, not JSON, not a JSON
                object, or lacks a required field.
        """
        with self._open(self._path) as f:
            header = f.readline()

        try:
            d = json.loads(header)
        except ValueError as e:
            # JSONDecodeError and (for non-UTF-8 garbage) UnicodeDecodeError
            # are both ValueErrors.
            raise EnvelopeError(self._path.name, f"broken json header: {e}") from e

        try:
            return (
                d["sender"],
                d["rcpts"],
                d["size"],
                d["received"],
            )
        except (KeyError, TypeError) as e:
            # TypeError: the header is valid JSON but not an object.
            raise EnvelopeError(self._path.name, f"broken json header: {e}") from e

    def _get_msg(self) -> bytes:
        with self._open(self._path) as f:
            f.readline()  # skip the JSON header
            return f.read()

    @classmethod
    def write(
        cls,
        path: typing.Union[str, pathlib.Path],
        sender: str,
        rcpts: typing.Sequence[str],
        msg: bytes,
    ) -> DiskEnvelope:
        """Write a new envelope file at *path* and load it back."""
        path = pathlib.Path(path)
        header = json.dumps(
            {
                "sender": sender,
                "rcpts": rcpts,
                "size": len(msg),
                "received": time.time(),
            }
        )

        with cls._open(path, "wb") as f:
            f.write(header.encode())
            f.write(b"\n")
            f.write(msg)
            # Put the data on stable storage before reporting it as stored,
            # so a crash right after the caller succeeds can't lose it.
            f.flush()
            os.fsync(f.fileno())

        return cls(path)

    def rm(self) -> None:
        """Delete the envelope file; failures are logged, not raised."""
        try:
            self._path.unlink(missing_ok=True)
        except OSError as e:
            logger.error(
                f"{self.id}: Failed to remove envelope from disk: {e}",
                exc_info=e,
                extra=self.get_log_extra(),
            )


class Store:
    """Interface for persistent envelope storage."""

    def iter(self) -> typing.Iterator[Envelope]:
        """Yield every envelope currently held by the store."""
        raise NotImplementedError

    def put(self, sender: str, rcpts: list[str], msg: bytes) -> Envelope:
        """Persist a message and return its Envelope."""
        raise NotImplementedError


class DiskStore(Store):
    """Store keeping one DiskEnvelope file per message in ``$STATE_DIRECTORY``."""

    def __init__(self) -> None:
        state_dir = os.environ.get("STATE_DIRECTORY")
        if not state_dir:
            raise ConfigError("STATE_DIRECTORY env var not set")

        self.path = pathlib.Path(state_dir)

    def iter(self) -> typing.Iterator[Envelope]:
        """Yield an Envelope for each envelope file in the state directory.

        Files that fail to parse are logged and deleted. Non-regular files
        (e.g. a stray subdirectory) are skipped; opening one would otherwise
        raise IsADirectoryError, which is not an EnvelopeError.
        """
        for path in self.path.iterdir():
            if not path.is_file():
                continue

            try:
                yield DiskEnvelope(path)
            except EnvelopeError as e:
                logger.warning(f"{path.name}: Invalid envelope in queue: {e}")
                try:
                    path.unlink(missing_ok=True)
                except OSError as rm_err:
                    logger.warning(
                        f"{path.name}: Failed to remove invalid envelope: {rm_err}"
                    )

    def put(self, sender: str, rcpts: list[str], msg: bytes) -> DiskEnvelope:
        """Write the message to a fresh ``ENVL-*`` file and return its envelope."""
        # Only used to reserve a unique file name in the state directory.
        with tempfile.NamedTemporaryFile(
            mode="wb",
            dir=self.path,
            prefix="ENVL-",
            delete=False,
        ) as f:
            name = f.name

        try:
            return DiskEnvelope.write(name, sender, rcpts, msg)
        except BaseException:
            # Don't leave an empty or half-written envelope behind.
            with contextlib.suppress(OSError):
                os.unlink(name)
            raise


class Relay:
    """Interface for whatever actually delivers a message onwards."""

    def close(self) -> None:
        """Release any held connection."""
        raise NotImplementedError

    def send(
        self,
        sender: str,
        rcpts: typing.Sequence[str],
        data: bytes,
    ) -> smtplib._SendErrs:
        """Deliver *data*; return the recipients that were refused."""
        raise NotImplementedError


class SMTPRelay(Relay):
    """Relay that hands messages to the configured SMTP server.

    One connection is opened lazily and reused across ``send()`` calls. It is
    dropped after any failure and re-opened on the next send.
    """

    _conn: typing.Optional[smtplib.SMTP] = None

    def __init__(self, cfg: Config, timeout: float = 10.0) -> None:
        self.cfg = cfg
        self.timeout = timeout

    def _connect(self) -> smtplib.SMTP:
        """Open a new connection, secure it per ``relay_tls`` and log in."""
        conn: smtplib.SMTP
        if self.cfg.relay_tls is TLS.SMTPS:
            conn = smtplib.SMTP_SSL(
                self.cfg.relay_host,
                self.cfg.relay_port,
                timeout=self.timeout,
                context=ssl.create_default_context(),
            )
        else:
            conn = smtplib.SMTP(
                self.cfg.relay_host,
                self.cfg.relay_port,
                timeout=self.timeout,
            )

        try:
            if self.cfg.relay_tls in (TLS.STARTTLS, TLS.Optional):
                try:
                    conn.starttls(context=ssl.create_default_context())
                except smtplib.SMTPNotSupportedError:
                    # Only TLS.STARTTLS insists on encryption.
                    if self.cfg.relay_tls is TLS.STARTTLS:
                        raise

            if self.cfg.relay_user or self.cfg.relay_pass:
                conn.login(self.cfg.relay_user, self.cfg.relay_pass)
        except BaseException:
            # Never keep a half-set-up session around.
            conn.close()
            raise

        return conn

    @staticmethod
    def _is_alive(conn: smtplib.SMTP) -> bool:
        """Probe a cached connection with NOOP.

        SMTP servers drop clients that stay idle (RFC 5321 recommends a
        5 minute server timeout). Without this check, the first message
        after a quiet period would fail with a disconnect.
        """
        try:
            code, _ = conn.noop()
        except (smtplib.SMTPException, OSError):
            return False
        return code == 250

    def _get_conn(self) -> smtplib.SMTP:
        """Return a usable connection, reconnecting if the cached one died."""
        if self._conn is not None:
            if self._is_alive(self._conn):
                return self._conn
            self.close()

        self._conn = self._connect()
        return self._conn

    def close(self) -> None:
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def send(
        self,
        sender: str,
        rcpts: typing.Sequence[str],
        data: bytes,
    ) -> smtplib._SendErrs:
        """Send *data*; return refused recipients as smtplib.sendmail() does."""
        try:
            return self._get_conn().sendmail(sender, rcpts, data)
        except BaseException:
            # After any failure the session state is unknown; start over.
            self.close()
            raise


class QueueQuit(Exception):
    """Raised inside the relay thread once the queue has been shut down."""


class Queue:
    """Base queue: persist a message in ``self.store``, then hand it to ``_put``."""

    store: Store

    def _put(self, envl: Envelope, msg: email.message.Message) -> None:
        raise NotImplementedError

    def put(
        self,
        sender: email.headerregistry.Address,
        rcpts: typing.Sequence[email.headerregistry.Address],
        msg: email.message.Message,
    ) -> Envelope:
        """Store *msg* for *rcpts* (must be non-empty) and enqueue it.

        Returns the Envelope created by the store.
        """
        assert rcpts

        rcpt_specs = [address.addr_spec for address in rcpts]
        raw = msg.as_bytes()
        envl = self.store.put(sender.addr_spec, rcpt_specs, raw)

        self._put(envl, msg)
        return envl


class RelayQueue(Queue):
    """Queue that relays envelopes one at a time on a background thread.

    Envelopes already in the store are enqueued on construction. Every
    queued envelope holds one ``Idle`` reference until it is relayed,
    bounced, or dropped by ``quit()``. Transient relay failures are retried
    with exponential backoff; other failures bounce (delete) the envelope.
    """

    _q: collections.deque[Envelope]
    _th: typing.Optional[threading.Thread] = None

    def __init__(
        self,
        cfg: Config,
        idle: Idle,
        store: Store,
        relay: Relay,
        backoff_factor: float = 5.0,
        backoff_max: float = 120.0,
    ) -> None:
        self.cfg = cfg
        self.idle = idle
        self.store = store
        self._relay = relay
        self._backoff_factor = backoff_factor
        self._backoff_max = backoff_max

        # Guards _q, _n and _done and wakes the relay thread. It is
        # re-entrant (Condition() defaults to an RLock), which matters because
        # _next() calls _status() while holding it.
        self._cond = threading.Condition()
        self._done = False
        self._q = collections.deque()
        # Envelopes queued or currently being relayed; only used for status.
        self._n = 0

        for p in self.store.iter():
            self._put(p)

        logger.debug(f"{len(self._q)} existing messages found in store")

    def __enter__(self) -> RelayQueue:
        with self._cond:
            assert not self._done
            assert not self._th

            self._th = threading.Thread(name="Queue", target=self._run)
            self._th.start()

        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> None:
        self.quit()

    def quit(self) -> None:
        """Stop the relay thread and wait for it to exit.

        Envelopes still waiting are dropped from memory but not removed from
        the store, so a persistent store re-enqueues them on the next start.
        """
        with self._cond:
            while self._q:
                self._q.pop()
                self._n -= 1
                self.idle.dec()

            self._done = True
            self._cond.notify()

            th = self._th

        if th:
            th.join()

    def _smtp_err(self, code: int, err: typing.Union[str, bytes]) -> str:
        if isinstance(err, bytes):
            err = err.decode("utf-8", "replace")
        return f"{code}-{err}"

    def _log(
        self,
        level: int,
        envl: Envelope,
        msg: str,
        *,
        exc_info: typing.Optional[BaseException] = None,
        detailed_extra: bool = False,
        **extra: object,
    ) -> None:
        """Log *msg* prefixed with the envelope id.

        Envelope fields, any SMTP code/error from *exc_info*, and *extra*
        are attached as structured ``extra`` fields.
        """
        if isinstance(exc_info, smtplib.SMTPResponseException):
            smtp_err = self._smtp_err(exc_info.smtp_code, exc_info.smtp_error)
            msg = f"{smtp_err}: {msg}"
            extra |= {
                "SMTP_CODE": exc_info.smtp_code,
                "SMTP_ERROR": exc_info.smtp_error,
            }
        elif exc_info is not None:
            msg = f"{msg}: {exc_info}"

        logger.log(
            level,
            f"{envl.id}: {msg}",
            exc_info=exc_info,
            extra=extra | envl.get_log_extra(detailed=detailed_extra),
        )

    _debug = functools.partialmethod(_log, logging.DEBUG)
    _info = functools.partialmethod(_log, logging.INFO)
    _warn = functools.partialmethod(_log, logging.WARNING)

    def _bounce(
        self,
        e: BaseException,
        envl: Envelope,
        msg: str,
    ) -> None:
        """Give up on *envl*: log why and remove it from the store."""
        envl.bounced = True
        self._warn(
            envl,
            f"Bouncing message: {msg}",
            exc_info=e,
            detailed_extra=True,
        )
        envl.rm()

    def _status(self, msg: str) -> None:
        """Report the queue size and *msg* to systemd."""
        with self._cond:
            n = self._n

        # Singular only for exactly one ("0 messages", "1 message").
        s = "" if n == 1 else "s"
        daemon.notify(f"STATUS={n} message{s} in queue. {msg}\n")

    def _put(
        self,
        envl: Envelope,
        msg: typing.Optional[email.message.Message] = None,
    ) -> None:
        self._info(
            envl,
            (
                "Message enqueued: "
                f"from=<{envl.sender}> "
                f"size={envl.msg_size} "
                f"nrcpts={len(envl.rcpts)}"
            ),
            detailed_extra=True,
        )

        with self._cond:
            if self._done:
                return

            self._q.append(envl)
            self._n += 1
            self.idle.inc()
            self._cond.notify()

    def _next(self, *, block: bool = True) -> typing.Optional[Envelope]:
        """Pop the next envelope.

        Returns None if the queue is empty and not *block*; raises QueueQuit
        once quit() has been called.
        """
        with self._cond:
            while not self._done:
                if self._q:
                    return self._q.popleft()
                if not block:
                    return None

                self._status("Waiting for new messages.")
                self._cond.wait()

        raise QueueQuit()

    def _run(self) -> None:
        try:
            while True:
                try:
                    self.relay_one(block=True)
                except QueueQuit:
                    return
        finally:
            self._relay.close()

    def relay_one(self, *, block: bool = False) -> bool:
        """Attempt to relay a single message from the queue.

        Returns True if a message was removed from the queue.
        """

        envl = self._next(block=block)
        if envl is None:
            return False

        try:
            self._relay_one(envl)
        finally:
            # Release the envelope's slot even if relaying blew up, so the
            # Idle count and the status line can't get stuck.
            self.idle.dec()
            with self._cond:
                self._n -= 1

        return True

    def _relay_one(self, envl: Envelope) -> None:
        """Relay *envl*, retrying transient failures with exponential backoff.

        Returns once the message was sent or bounced, or when quit() is
        called during a backoff wait.
        """
        backoff = 0

        self._debug(envl, "Relaying")

        while True:
            self._status("Relaying now.")

            try:
                return self._send(envl)
            except socket.gaierror as e:
                # A bare `raise` here would NOT fall through to the OSError
                # clause below; it would escape and kill the relay thread.
                # Treat every resolver failure as retryable.
                if e.errno == socket.EAI_AGAIN:
                    self._info(envl, "Temporary DNS resolve error", exc_info=e)
                else:
                    self._warn(envl, "Failed to resolve relay host", exc_info=e)
            except smtplib.SMTPServerDisconnected as e:
                self._info(envl, "Lost connection to relay", exc_info=e)
            except smtplib.SMTPConnectError as e:
                self._info(envl, "Failed to connect to relay", exc_info=e)
            except smtplib.SMTPHeloError as e:
                self._warn(envl, "Relay rejected HELO", exc_info=e)
            except smtplib.SMTPAuthenticationError as e:
                self._warn(envl, "Relay rejected login", exc_info=e)
            except smtplib.SMTPSenderRefused as e:
                return self._bounce(e, envl, f"Relay rejected sender=<{e.sender}>")
            except smtplib.SMTPDataError as e:
                return self._bounce(e, envl, "Relay rejected message data")
            except smtplib.SMTPException as e:
                return self._bounce(e, envl, "Unexpected SMTP exception")
            except EnvelopeError as e:
                return self._bounce(e, envl, "Unexpected envelope error")
            except OSError as e:
                self._warn(envl, "Unexpected OS exception", exc_info=e)
            except BaseException as e:
                return self._bounce(e, envl, "Unexpected relay exception")

            envl.retries += 1

            # Probably pointless, but keep backoff small to avoid having to
            # calculate large exponents (eg. 2**1_000_000) which can take a
            # bunch of CPU time.
            max_exp = math.ceil(math.log(self._backoff_max, 2))
            backoff = min(max_exp, backoff + 1)

            secs = min(self._backoff_max, self._backoff_factor * (2 ** (backoff - 1)))
            self._status(f"Relay failure; backing off for {secs}s.")

            with self._cond:
                if self._cond.wait_for(lambda: self._done, secs):
                    return

    def _send(self, envl: Envelope) -> None:
        """Send *envl* once, remove it from the store, and log the outcome.

        A full recipient refusal is treated like a partial one: the message
        is considered handled. Other exceptions propagate to _relay_one().
        """
        data = envl.get_msg()

        try:
            refused = self._relay.send(envl.sender, envl.rcpts, data)
        except smtplib.SMTPRecipientsRefused as e:
            refused = e.recipients

        envl.rm()

        for rcpt, reply in refused.items():
            smtp_err = self._smtp_err(reply[0], reply[1])
            self._warn(
                envl,
                f"Relay rejected rcpt=<{rcpt}>: {smtp_err}",
                detailed_extra=True,
                SMTP_CODE=reply[0],
                SMTP_ERROR=reply[1],
            )

        # Deduplicated, in envelope order (a set would make logs reorder).
        accepted = [rcpt for rcpt in dict.fromkeys(envl.rcpts) if rcpt not in refused]
        self._info(
            envl,
            f"Message relayed rcpts={len(accepted)}/{len(envl.rcpts)}",
            detailed_extra=True,
            SMTP_ACCEPTED=", ".join(accepted),
            SMTP_REFUSED=", ".join(refused.keys()),
        )


class FrameProto:
    """Interface for a framed, request/reply client transport."""

    def get_remote_uid(self) -> int:
        """Return the peer's uid, or a negative value if unknown."""
        raise NotImplementedError

    def read(self) -> bytes:
        """Read one frame; an empty frame is returned as b""."""
        raise NotImplementedError

    def write(self, buf: bytes) -> None:
        """Write *buf* as one frame."""
        raise NotImplementedError


class SocketFrameProto(FrameProto):
    """Length-prefixed framing over a stream socket.

    Each frame is a 2-byte length (struct ``"H"``: native byte order and
    size) followed by that many bytes; a zero length is an empty frame.
    Use as a context manager to create/close the buffered file wrapper.
    """

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def __enter__(self) -> SocketFrameProto:
        self._f = self._sock.makefile("rwb")
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> None:
        # BufferedIOs try to flush on close; this can cause an exception if
        # there's data in the buffer and the client is already closed. Since the
        # protocol is fully synced, there's no risk of data loss, so just ignore
        # any errors.
        with contextlib.suppress(OSError):
            self._f.close()

    def get_remote_uid(self) -> int:
        """Return the peer's uid from SO_PEERCRED, or -1 if unavailable."""
        uid = -1

        try:
            _pid, uid, _gid = struct.unpack(
                "3i",
                self._sock.getsockopt(
                    socket.SOL_SOCKET,
                    socket.SO_PEERCRED,
                    struct.calcsize("3i"),
                ),
            )
        except (OSError, struct.error) as e:
            logger.warning(f"Failed to get PEERCRED: {e}", exc_info=e)

        return uid

    def _read_exact(self, n: int) -> bytes:
        """Read exactly *n* bytes.

        A short read means the peer closed the connection mid-frame. Raise a
        ConnectionError (an OSError) instead of asserting, so the check
        survives ``python -O`` and callers treat it as a transport failure.
        """
        buf = self._f.read(n)
        if len(buf) != n:
            raise ConnectionError(f"short read: expected {n} bytes, got {len(buf)}")
        return buf

    def read(self) -> bytes:
        (n,) = struct.unpack("H", self._read_exact(2))
        if not n:
            return b""

        return self._read_exact(n)

    def write(self, buf: bytes) -> None:
        nb = struct.pack("H", len(buf))

        n = self._f.write(nb)
        assert n == len(nb)

        if buf:
            n = self._f.write(buf)
            assert n == len(buf)

        self._f.flush()


class ClientError(Exception):
    """The client did something wrong that can't be recovered from."""


class ClientOSError(OSError):
    """Any OSError originating from the client's socket"""

    def __init__(self, e: OSError) -> None:
        # Reuse the original constructor arguments. This yields (errno,
        # strerror) when the OS supplied them, and otherwise keeps the plain
        # message (e.g. OSError("...")), rather than "[Errno None] None".
        super().__init__(*e.args)


class ClientOptions:
    """Options for one message submission, as supplied by the client."""

    sender_name: str = ""
    sender_email: str = ""
    rcpts_from_headers: bool = False
    rcpts: list[email.headerregistry.Address]

    def __init__(self) -> None:
        # A per-instance list, so option objects never share recipients.
        self.rcpts = list()


class Client:
    """Serves a single sendmail client connection.

    Conversation: one JSON options frame, then message-data frames ended by
    an empty frame. Every frame gets a reply frame: empty on success,
    otherwise an error string. The reply to the final, empty frame is held
    back until the message has been handed to the queue.
    """

    def __init__(
        self,
        proto: FrameProto,
        cfg: Config,
        q: Queue,
        max_msg_size: int = 10 * 1000 * 1000,
    ) -> None:
        self.proto = proto
        self.cfg = cfg
        self.q = q
        self.max_msg_size = max_msg_size

    @contextlib.contextmanager
    def _proto_read(self) -> typing.Iterator[bytes]:
        """Read one frame and reply to it once the ``with`` body finishes.

        The reply is empty on success, the ClientError text if the body raised
        one, or a generic error for anything else (then re-raised). Transport
        OSErrors are raised as ClientOSError.
        """
        try:
            buf = self.proto.read()
        except OSError as e:
            raise ClientOSError(e) from e

        try:
            yield buf
        except ClientError as e:
            self._proto_reply(str(e))
            raise
        except BaseException:
            self._proto_reply("Internal server error")
            raise
        else:
            self._proto_reply("", may_raise=True)

    def _proto_reply(self, err_msg: str, may_raise: bool = False) -> None:
        """Send a reply frame (empty means success).

        Write errors raise ClientOSError if *may_raise*; otherwise they are
        only logged (the callers are already handling another error).
        """
        try:
            self.proto.write(err_msg.encode())
        except OSError as e:
            if may_raise:
                raise ClientOSError(e) from e
            else:
                logger.exception("Failed to reply to client message")

    def _push_orig(self, msg: email.message.Message, key: str) -> None:
        """Rename every *key* header to ``X-Orig-<key>``, keeping their order."""
        values = msg.get_all(key)
        if not values:
            return

        for value in values:
            msg.add_header("X-Orig-" + key, value)

        # Removes every instance of the original header.
        del msg[key]

    def _get_remote_user(self) -> str:
        """Peer's user name; "nobody" if unknown, "uid-<n>" without a passwd entry."""
        uid = self.proto.get_remote_uid()
        if uid < 0:
            return "nobody"

        try:
            return pwd.getpwuid(uid).pw_name
        except KeyError:
            return f"uid-{uid}"

    def _get_set_originator(
        self,
        msg: email.message.Message,
        opts: ClientOptions,
    ) -> email.headerregistry.Address:
        """Replace From/Sender with the MTA's own From address and return it.

        The original headers are kept as X-Orig-*. The local user (and any
        sender name/email the client gave) is recorded in X-Sendmail-*
        headers.
        """
        user = self._get_remote_user()
        user_addr = str(email.headerregistry.Address(username=user, domain=HOSTNAME))

        # Parsed address headers expose their mailboxes via ``.addresses``;
        # indexing the header itself would give its first *character*.
        from_addrs = getattr(msg["From"], "addresses", ())
        from_name = from_addrs[0].display_name if from_addrs else ""

        self._push_orig(msg, "Sender")
        self._push_orig(msg, "From")

        # Preserve local user information. Drop any copies supplied by the
        # client first, so these headers can't be spoofed.
        for name in ("X-Sendmail-User", "X-Sendmail-Name", "X-Sendmail-Sender"):
            del msg[name]

        msg["X-Sendmail-User"] = user_addr

        if opts.sender_name:
            msg["X-Sendmail-Name"] = opts.sender_name
        if opts.sender_email:
            msg["X-Sendmail-Sender"] = opts.sender_email

        originator = msg["From"] = email.headerregistry.Address(
            display_name=opts.sender_name or from_name or user_addr,
            username=self.cfg.format_from(user=user),
            domain=self.cfg.mta_domain,
        )

        return originator

    @staticmethod
    def _opt(hdr: typing.Any, key: str, typ: type) -> typing.Any:
        """Return ``hdr[key]`` after checking it is an instance of *typ*.

        Raises KeyError/TypeError on bad input. Explicit checks are used
        instead of ``assert``, which ``python -O`` strips.
        """
        val = hdr[key]
        if not isinstance(val, typ):
            raise TypeError(f"{key}: expected {typ.__name__}, got {type(val).__name__}")
        return val

    def _read_options(self) -> ClientOptions:
        """Read and validate the JSON options frame."""
        opts = ClientOptions()

        with self._proto_read() as buf:
            try:
                hdr = json.loads(buf)

                opts.sender_name = self._opt(hdr, "sender-name", str).strip()
                opts.sender_email = self._opt(hdr, "sender-email", str).strip()
                opts.rcpts_from_headers = self._opt(hdr, "rcpts-from-headers", bool)

                for rcpt in self._opt(hdr, "rcpts", list):
                    if not isinstance(rcpt, str):
                        raise TypeError(f"rcpts: expected str, got {type(rcpt).__name__}")
                    opts.rcpts.extend(addrs_parse(rcpt))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
                raise ClientError("client: Invalid protocol options") from e

        return opts

    @contextlib.contextmanager
    def _read_message(self) -> typing.Iterator[email.message.Message]:
        """Collect data frames up to the empty terminator; yield the message.

        The yield happens inside the terminator's ``_proto_read``, so its reply
        is only sent once the caller's ``with`` body has finished. Raises
        ClientError if more than ``max_msg_size`` bytes arrive.
        """
        n = 0
        parser = email.parser.BytesFeedParser(policy=policy)

        while True:
            with self._proto_read() as buf:
                if not buf:
                    # This is odd: even if the client sent nothing, the other
                    # sendmails will send an empty message, as long as there are
                    # recipients.
                    yield parser.close()
                    return

                n += len(buf)
                if n > self.max_msg_size:
                    raise ClientError("Message too large")

                parser.feed(buf)

    def _run(self) -> None:
        opts = self._read_options()

        with self._read_message() as msg:
            # Original recipients from sendmail args, for debugging. Replace
            # any client-supplied copy.
            del msg["X-Sendmail-Rcpts"]
            msg["X-Sendmail-Rcpts"] = ", ".join(str(addr) for addr in opts.rcpts)

            if opts.rcpts_from_headers:
                opts.rcpts.extend(
                    addr
                    for f in ("To", "Cc", "Bcc")
                    for header in msg.get_all(f, [])
                    for addr in header.addresses
                )

            if not opts.rcpts:
                raise ClientError("At least 1 recipient is required")

            opts.rcpts = list(self.cfg.expand_aliases(opts.rcpts))
            if not opts.rcpts:
                return

            # Do some basic MSA (RFC 6409) stuff
            del msg["Bcc"]
            del msg["Resend-Bcc"]
            del msg["Return-Path"]

            sender = self._get_set_originator(msg, opts)

            if not msg["To"]:
                del msg["To"]
                msg["To"] = opts.rcpts

            if "Date" not in msg:
                msg["Date"] = email.utils.localtime()

            if "Message-Id" not in msg:
                msg["Message-Id"] = email.utils.make_msgid()

            # Practically all mail going through sendmail is text/plain, but not
            # all is labeled as such. This is necessary to avoid some
            # MTAs/clients (Outlook, looking at you) from doing their own
            # conversions when not set.
            if "Content-Type" not in msg:
                msg.set_charset("utf-8")

            self.q.put(sender, opts.rcpts, msg)

    def run(self) -> None:
        """Handle the whole conversation; every error is logged, none raised."""
        try:
            self._run()
        except ClientError:
            logger.exception("Bad client conversation")
        except ClientOSError:
            logger.exception("Failed to communicate with client")
        except BaseException:
            logger.exception("Unexpected client exception")


class Server(contextlib.ExitStack):
    """Accept client connections on inherited listeners; serve each in a thread.

    Everything the server owns is entered into this ExitStack, so closing the
    server releases all of it: the idle tracker, the relay queue, the selector
    and the listening sockets. ``run`` drives a selector loop until
    ``idle.is_idle()`` reports true.
    """

    def __init__(
        self,
        listeners: typing.Iterable[socket.socket],
        *,
        cfg: Config,
        store: Store,
        relay: Relay,
    ) -> None:
        super().__init__()

        try:
            self.cfg = cfg
            self.idle = self.enter_context(Idle())
            self.q = self.enter_context(RelayQueue(self.cfg, self.idle, store, relay))
            self._sel = self.enter_context(selectors.DefaultSelector())

            # The Idle tracker's pollable end is watched by the same selector
            # as the listeners; its drain() runs when it becomes readable.
            self._sel.register(
                self.idle.pollable(),
                selectors.EVENT_READ,
                self.idle.drain,
            )

            self._listeners = [self._listen(sock) for sock in listeners]
        except BaseException:
            # A partially-built server must still release what it entered.
            self.close()
            raise

    def _listen(self, sock: socket.socket) -> socket.socket:
        """Take ownership of a listening socket and start accepting on it."""
        self.enter_context(sock)
        sock.setblocking(False)
        # Connections may already be waiting (presumably the one that caused
        # socket activation), so drain the backlog before relying on select().
        self._accept(sock)
        self._sel.register(
            sock,
            selectors.EVENT_READ,
            functools.partial(self._accept, sock),
        )
        return sock

    def _client(self, conn: socket.socket) -> None:
        """Client thread body: run one conversation, then release ``conn``.

        Always balances the ``idle.inc()`` done in ``add_client``.
        """
        # Don't set a timeout on client sockets. Consider cron(8): it pipes
        # process output to sendmail, and those processes might run for a while,
        # so a timeout might actually break things.
        with conn:
            try:
                with SocketFrameProto(conn) as proto:
                    Client(proto, self.cfg, self.q).run()
            except BaseException:
                # Basically a log.wtf(): Client.run() already logs its own errors.
                logger.exception("Unexpected client exception")
            finally:
                self.idle.dec()

    def _accept(self, sock: socket.socket) -> None:
        """Accept every connection currently pending on non-blocking ``sock``.

        Returns once the backlog is empty. A connection that the peer aborted
        before we accepted it is logged and skipped. Any other OSError is
        logged and ends this round; the (level-triggered) selector calls back
        while the socket stays readable. This keeps a persistent failure such
        as EMFILE from spinning here forever and starving the event loop.
        """
        # _on_signal may close a listener after select() reported it ready
        # but before this callback runs. accept() on it would fail with EBADF.
        while sock.fileno() != -1:
            try:
                conn, _ = sock.accept()
            except BlockingIOError:
                return
            except ConnectionAbortedError:
                logger.exception("accept failed")
            except OSError:
                logger.exception("accept failed")
                return
            else:
                self.add_client(conn)

    def _on_signal(
        self,
        signum: int,
        frame: typing.Optional[types.FrameType],
    ) -> None:
        """Begin shutdown: stop listening and tell the relay queue to quit.

        Already-connected clients are not touched here. ``run`` keeps looping
        until ``idle.is_idle()`` is true.
        """
        daemon.notify("STOPPING=1\n")

        while self._listeners:
            listener = self._listeners.pop()
            self._sel.unregister(listener)
            listener.close()

        self.q.quit()

    def handle_signals(self) -> None:
        """Install the shutdown handler for SIGINT and SIGQUIT.

        Must be called from the main thread (a requirement of signal.signal).
        """
        # NOTE(review): SIGTERM (systemd's default KillSignal) is not handled
        # and keeps its default action; confirm the unit file sets KillSignal.
        signal.signal(signal.SIGINT, self._on_signal)
        signal.signal(signal.SIGQUIT, self._on_signal)

    def add_client(self, conn: socket.socket) -> None:
        """Serve ``conn`` on a new thread, counting it as outstanding work.

        If the thread cannot be started, ``_client`` never runs to undo the
        count or close ``conn``. In that case both are done here before
        re-raising.
        """
        self.idle.inc()
        try:
            threading.Thread(
                name="Client",
                target=self._client,
                args=(conn,),
            ).start()
        except RuntimeError:
            self.idle.dec()
            conn.close()
            raise

    def run(self) -> None:
        """Tell systemd we are ready, then dispatch selector events until idle."""
        daemon.notify("READY=1\n")

        while not self.idle.is_idle():
            for key, _mask in self._sel.select():
                key.data()

        logger.debug("server is idle, exiting")


def stderr_is_journal() -> bool:
    """Return True if stderr is the stream systemd connected to the journal.

    systemd exports ``JOURNAL_STREAM`` as ``"<st_dev>:<st_ino>"`` (decimal) of
    the journal stream it attached. A match with stderr's fstat() means log
    records should go to the journal directly.

    Returns False, rather than raising, when the variable is unset/empty or
    when stderr is missing, closed, or has no real file descriptor (e.g. an
    in-memory stream).
    """
    want = os.getenv("JOURNAL_STREAM", "")
    if not want:
        # Not started by systemd with a journal stream: no fstat() needed.
        return False

    try:
        stat = os.fstat(sys.stderr.fileno())
    except (AttributeError, OSError, ValueError):
        # AttributeError: sys.stderr is None. ValueError: closed file.
        # io.UnsupportedOperation (no fd) subclasses both OSError and ValueError.
        return False

    return f"{stat.st_dev}:{stat.st_ino}" == want


def get_listeners() -> typing.Sequence[socket.socket]:
    """Wrap the file descriptors handed over via systemd socket activation.

    Each inherited fd must be a listening SOCK_STREAM socket. The first one
    that is not aborts with ConfigError.
    """

    def _adopt(fd: int) -> socket.socket:
        # Check the fd before wrapping it in a Python socket object.
        if not daemon.is_socket(fd, type=socket.SOCK_STREAM, listening=1):
            raise ConfigError("server: fds must be SOCK_STREAM listeners")
        return socket.socket(fileno=fd)

    return [_adopt(fd) for fd in daemon.listen_fds()]


def main() -> None:
    """Command-line entry point: parse options, set up logging, run the server.

    Exits with status 2 when configuration or the inherited listeners are
    invalid (ConfigError).
    """
    parser = argparse.ArgumentParser(
        prog="stne-mta",
        description="Mail Transfer Agent",
    )
    parser.add_argument(
        "--loglevel",
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="set the log level",
    )
    parser.add_argument("--version", action="version", version="0.1.0~dev62")
    args = parser.parse_args()

    # When stderr is the journal, hand records to JournalHandler with a bare
    # message format. Otherwise write a prefixed text line to stderr.
    if stderr_is_journal():
        log_setup = {
            "format": "%(message)s",
            "handlers": [journal.JournalHandler(SYSLOG_IDENTIFIER="stne-mta")],
        }
    else:
        log_setup = {"format": "%(name)s [%(levelname)-5.5s] %(message)s"}
    logging.basicConfig(level=args.loglevel.upper(), **log_setup)

    try:
        cfg = EtcConfig()
        store = DiskStore()
        relay = SMTPRelay(cfg)
        listeners = get_listeners()
    except ConfigError as err:
        logger.error(f"Config Error: {err}")
        sys.exit(2)

    with Server(listeners, cfg=cfg, store=store, relay=relay) as server:
        server.handle_signals()
        server.run()


# Run the MTA only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
