#! /usr/bin/python3

from __future__ import annotations

import argparse
import contextlib
import errno
import functools
import io
import json
import logging
import os
import os.path
import socket
import struct
import sys
import typing

from systemd import journal

# Callable that transmits one frame (a chunk of bytes) to the server.
Send = typing.Callable[[bytes], None]
# Zero-argument factory returning a context manager that yields a ``Send``.
GetSend = typing.Callable[[], typing.ContextManager[Send]]
# Source of message lines, each one a ``bytes`` object.
LineIter = typing.Iterable[bytes]

# The two line terminators (LF and CRLF).
# NOTE(review): not referenced elsewhere in this file -- confirm it is still needed.
LINE_SEPS: typing.Final = (b"\n", b"\r\n")
# A line holding only "." (with either terminator); ``sendmail`` treats such a
# line as end-of-input when its ``dot_end`` option is enabled.
DOT_ENDS: typing.Final = (b".\n", b".\r\n")

# Module-wide logger; a journal handler is attached only when run as a script.
logger: typing.Final = logging.getLogger("stne-sendmail")


class ClientOSError(OSError):
    """An ``OSError`` hit while reading the message from the client side.

    Only ``errno`` and ``strerror`` are copied over from the wrapped error.
    """

    def __init__(self, e: OSError) -> None:
        code, text = e.errno, e.strerror
        super().__init__(code, text)


class ServerOSError(OSError):
    """An ``OSError`` hit while talking to the server.

    Only ``errno`` and ``strerror`` are copied over from the wrapped error.
    """

    def __init__(self, e: OSError) -> None:
        code, text = e.errno, e.strerror
        super().__init__(code, text)


class ServerConnectError(ServerOSError):
    """The connection to the server's socket could not be established."""


class EnqueueError(Exception):
    """The server answered a frame with a non-empty (error) reply."""


class NoRecipientsError(Exception):
    """A mail/mailx-style invocation named no recipients."""


def sendmail(
    lines: LineIter,
    send: Send,
    *,
    sender_name: str = "",
    sender_email: str = "",
    recipients: typing.Sequence[str] = (),
    add_recipients_from_headers: bool = False,
    dot_end: bool = True,
) -> None:
    """Submit one message through ``send``.

    The first frame is a JSON object holding the envelope options, every
    following frame is one line of the message, and a final empty frame marks
    the end of the message.

    Args:
        lines: The message, one ``bytes`` line per item.
        send: Callable that transmits a single frame.
        sender_name: Sent as ``"sender-name"``.
        sender_email: Sent as ``"sender-email"``.
        recipients: Sent as ``"rcpts"``; any sequence type is accepted.
        add_recipients_from_headers: Sent as ``"rcpts-from-headers"``.
        dot_end: When true, a line consisting only of "." ends the input and
            is not transmitted.

    Raises:
        ClientOSError: If reading from ``lines`` raises ``OSError``.
    """
    opts = json.dumps(
        {
            "sender-name": sender_name,
            "sender-email": sender_email,
            # json only serialises list/tuple, so normalise other sequences.
            "rcpts": list(recipients),
            "rcpts-from-headers": add_recipients_from_headers,
        }
    )

    send(opts.encode())

    it = iter(lines)
    while True:
        # Only the read is guarded: errors raised by send() must propagate
        # unchanged so they are not mistaken for client-side read failures.
        try:
            line = next(it)
        except StopIteration:
            break
        except OSError as e:
            raise ClientOSError(e) from e
        if dot_end and line in DOT_ENDS:
            break
        send(line)

    # An empty frame tells the server that the message is complete.
    send(b"")


def send_frame(f: io.BufferedIOBase, buf: bytes) -> None:
    """Write one length-prefixed frame to ``f`` and wait for the reply.

    A frame is a native-order unsigned 16-bit length followed by that many
    bytes; the reply uses the same layout. Because of the 16-bit prefix,
    ``buf`` may be at most 65535 bytes long (``struct.error`` otherwise).

    Args:
        f: Buffered read/write stream connected to the server.
        buf: Payload of the frame; may be empty.

    Raises:
        ServerOSError: On any I/O error, including a short write or the
            server closing the connection before a full reply arrived.
        EnqueueError: If the server replied with a non-empty message.
    """
    header = struct.pack("H", len(buf))

    try:
        for chunk in (header, buf):
            # Explicit checks instead of ``assert`` so that they survive -O
            # and are reported as server communication failures.
            if f.write(chunk) != len(chunk):
                raise OSError(errno.EIO, "short write to server")

        f.flush()

        size_field = f.read(2)
        if len(size_field) != 2:
            raise OSError(
                errno.ECONNRESET,
                "server closed the connection before replying",
            )

        (reply_len,) = struct.unpack("H", size_field)

        # Unix-style: empty reply means OK
        if not reply_len:
            return

        reply = f.read(reply_len)
        if len(reply) != reply_len:
            raise OSError(
                errno.ECONNRESET,
                "server closed the connection mid-reply",
            )
    except OSError as e:
        raise ServerOSError(e) from e

    # Never let an undecodable error message mask the actual failure.
    raise EnqueueError(reply.decode("utf-8", "replace"))


@contextlib.contextmanager
def sock_as_send(sock: socket.socket) -> typing.Iterator[Send]:
    """Yield a ``Send`` that frames data over ``sock``.

    The socket is wrapped in a buffered binary read/write file object, which
    is closed again when the context exits. The socket itself is not closed.
    """
    stream = sock.makefile("rwb")
    try:
        yield functools.partial(send_frame, stream)
    finally:
        # Buffered IO objects try to flush on close, which fails if the
        # connection is already gone. That failure is deliberately ignored.
        try:
            stream.close()
        except OSError:
            pass


def get_mta_sock_path() -> str:
    """Return the path of the stne-mta socket."""
    run_dir = "/run"
    # NOTE(review): looks like a guard for an unsubstituted build-time
    # placeholder ("@...@") -- confirm against the build system.
    if run_dir[:1] == "@":
        run_dir = "/run"

    sock_name = "stne-mta.sock"
    return os.path.join(run_dir, sock_name)


@contextlib.contextmanager
def connect(
    path: str = get_mta_sock_path(),
) -> typing.Iterator[Send]:
    """Connect to the Unix stream socket at ``path`` and yield a ``Send``.

    The socket is closed when the context exits.

    Raises:
        ServerConnectError: If connecting to ``path`` fails with ``OSError``.
    """
    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    with client:
        try:
            client.connect(path)
        except OSError as err:
            raise ServerConnectError(err)

        with sock_as_send(client) as frame_sender:
            yield frame_sender


def get_extra() -> dict[str, object]:
    """Collect extra fields to attach to log records.

    Returns a dict with the parent process's command line (empty if it cannot
    be read from ``/proc``) and this process's own command line.
    """
    parent_cmdline = ""
    try:
        with open(f"/proc/{os.getppid()}/cmdline", "rb") as proc_file:
            raw = proc_file.read()
    except OSError:
        pass
    else:
        # Arguments in /proc/<pid>/cmdline are NUL-separated.
        spaced = raw.replace(b"\x00", b" ")
        parent_cmdline = spaced.decode("utf-8", "replace").strip()

    own_cmdline = " ".join(sys.argv)
    return {
        "PARENT_CMDLINE": parent_cmdline,
        # This process may exit before journald captures _CMDLINE, so record
        # the command line explicitly to make sure it is preserved.
        "CMDLINE": own_cmdline,
    }


def log_error(
    msg: str,
    exc_info: typing.Optional[BaseException] = None,
    and_print: bool = True,
) -> None:
    """Log ``msg`` as an error and, unless disabled, echo it to stderr.

    Args:
        msg: The error message.
        exc_info: Exception to attach to the log record, if any.
        and_print: Whether to also print the message to stderr.
    """
    journal_fields = get_extra()
    logger.error(msg, exc_info=exc_info, extra=journal_fields)

    if not and_print:
        return

    print("sendmail [ERROR]", msg, file=sys.stderr)


class ArgumentParser(argparse.ArgumentParser):
    """``argparse.ArgumentParser`` that also logs usage errors."""

    def error(self, message: str) -> typing.NoReturn:
        """Log ``message``, then defer to the standard argparse error handling."""
        # Usage problems with this sendmail should get noticed, so record
        # them in the log. argparse prints the message itself, hence no
        # extra print here.
        logged = f"ArgumentParser: {message}"
        log_error(logged, and_print=False)

        super().error(message)


def _main(
    args: typing.Sequence[str],
    get_send: typing.Callable[[], typing.ContextManager[Send]],
    lines: LineIter,
) -> None:
    """Parse sendmail-style arguments and submit the message in ``lines``.

    Args:
        args: Command-line arguments, without the program name.
        get_send: Factory for the context manager that yields the frame
            sender.
        lines: The message to submit, one ``bytes`` line per item.
    """
    cli = ArgumentParser(
        "sendmail",
        description="stne-mta Sendmail Interface",
        add_help=False,
        # Options that were not given are left out of the namespace entirely,
        # so sendmail()'s own defaults apply.
        argument_default=argparse.SUPPRESS,
    )
    cli.add_argument(
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="Show this help message and exit",
    )
    cli.add_argument(
        "--version",
        action="version",
        version="0.1.0~dev62",
        help="Print version information and exit",
    )

    # Options that carry a value and map straight onto sendmail() keywords.
    for flag, placeholder, target, description in (
        ("-f", "EMAIL", "sender_email", "Set envelope sender"),
        ("-r", "EMAIL", "sender_email", "Obsolete form of -f flag"),
        ("-F", "NAME", "sender_name", "Set full name of sender"),
    ):
        cli.add_argument(flag, metavar=placeholder, dest=target, help=description)

    cli.add_argument(
        "-t",
        action="store_true",
        dest="add_recipients_from_headers",
        help="Add recipients from message headers to those given on the command line",
    )
    cli.add_argument(
        "-i",
        "-oi",
        action="store_false",
        dest="dot_end",
        help="""Don't interpret lines with only a "." character as end-of-input""",
    )
    cli.add_argument(
        "recipients",
        metavar="recipient",
        nargs="*",
        help="Email address(es) to send the message to",
    )

    # Everything below is accepted for compatibility and then discarded.
    hidden = {
        "dest": "_ignored",
        "default": argparse.SUPPRESS,
        "help": argparse.SUPPRESS,
    }

    ignored_with_value = (
        "-B",  # Body type (7BIT, 8BITMIME)
        "-C",  # Config file
        "-d",  # Debug flags
        "-h",  # Hop count
        "-L",  # Syslog ident
        "-N",  # DSN
        "-o",  # Options
        "-O",  # Options
        "-p",  # Set protocol used to receive message
        "-R",  # Amount of message to return on bounce
        "-V",  # Envelope ID
        "-X",  # Log file
    )
    for option in ignored_with_value:
        cli.add_argument(option, metavar="", **hidden)

    ignored_flags = (
        "-bm",  # Read mail from stdin and deliver to recipients
        "-em",  # Mail errors back to the sender
        "-m",  # From postfix: "(ignored) Backwards compatibility"
        "-n",  # Don't do aliasing
        "-U",  # Initial user submission
        "-v",  # Verbose mode
    )
    for option in ignored_flags:
        cli.add_argument(option, action="store_const", const=None, **hidden)

    parsed = cli.parse_intermixed_args(args)
    send_opts = {
        key: value for key, value in vars(parsed).items() if key != "_ignored"
    }

    with get_send() as send:
        sendmail(lines, send, **send_opts)


# Based on mail1() in bsd-mailx's send.c
def _mailx_compat(
    prog_name: str,
    argv: typing.Sequence[str],
    lines: LineIter,
) -> tuple[typing.Sequence[str], LineIter]:
    parser = ArgumentParser(
        prog_name,
        description=f"stne-mta {prog_name} interface",
        add_help=False,
        argument_default=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="Show this help message and exit",
    )
    parser.add_argument(
        "--version",
        action="version",
        version="0.1.0~dev62",
        help="Print version information and exit",
    )
    parser.add_argument(
        "-b",
        action="append",
        dest="bcc",
        help="Blind Carbon Copy Recipient list, comma-separated",
        default=None,
    )
    parser.add_argument(
        "-c",
        action="append",
        dest="cc",
        help="Carbon Copy Recipient list, comma-separated",
        default=None,
    )
    parser.add_argument(
        "-E",
        action="store_true",
        dest="skip_empty",
        help="Don't send messages with an empty body",
        default=False,
    )
    parser.add_argument(
        "-s",
        metavar="SUBJECT",
        dest="subject",
        help="Set the message's subject field",
        default="",
    )
    parser.add_argument(
        "-r",
        metavar="FROM",
        dest="sender",
        help="Set From: address",
        default="",
    )
    parser.add_argument(
        "recipients",
        metavar="recipient",
        nargs="*",
        help="Email address(es) to send the message to",
        default=None,
    )

    args_ignores = [
        "-u",  # "Pretend to be", only relevant for local mail
    ]
    for ignore in args_ignores:
        parser.add_argument(
            ignore,
            metavar="",
            dest="_ignored",
            default=argparse.SUPPRESS,
            help=argparse.SUPPRESS,
        )

    flag_ignores = [
        "-d",  # Debug
        "-f",  # User is specifying file to "edit" with Mail
        "-I",  # Interactive
        "-i",  # User wants to ignore interrupts
        "-n",  # Avoid initial header printing
        "-N",  # User doesn't want to source /usr/lib/Mail.rc
        "-v",  # Verbose
    ]
    for ignore in flag_ignores:
        parser.add_argument(
            ignore,
            metavar="",
            dest="_ignored",
            default=argparse.SUPPRESS,
            help=argparse.SUPPRESS,
        )

    args = parser.parse_intermixed_args(argv)

    if not args.recipients:
        raise NoRecipientsError()

    msg = io.BytesIO()
    argv = ["-i", "-t"]
    if args.sender:
        msg.write(f"From: {args.sender}\n".encode())
        argv.extend(["-f", args.sender])
    msg.write(f"To: {", ".join(args.recipients)}\n".encode())
    argv.extend(args.recipients)
    if args.subject:
        msg.write(f"Subject: {args.subject}\n".encode())
    if args.cc:
        msg.write(f"Cc: {", ".join(args.cc)}\n".encode())
    if args.bcc:
        msg.write(f"Bcc: {", ".join(args.bcc)}\n".encode())
    msg.write("\n".encode())

    n = len(msg.getbuffer())
    msg.writelines(lines)
    if args.skip_empty and len(msg.getbuffer()) == n:
        sys.exit(0)

    msg.seek(0)
    return argv, msg


def main(
    argv: typing.Sequence[str] = sys.argv,
    lines: LineIter = sys.stdin.buffer,
    get_send: GetSend = connect,
) -> None:
    """Command-line entry point.

    Dispatches on the name the program was invoked as: ``mail`` and ``mailx``
    go through the mailx compatibility layer first, everything else is handled
    as sendmail. Known failures are logged and end the process with exit
    status 1; ``SystemExit`` passes through untouched; any other exception is
    logged and re-raised.

    Args:
        argv: Full argument vector, including the program name.
        lines: The message, one ``bytes`` line per item.
        get_send: Factory for the context manager that yields the frame
            sender.
    """
    try:
        invoked_as = os.path.basename(argv[0])
        rest = argv[1:]
        if invoked_as in ("mail", "mailx"):
            rest, lines = _mailx_compat(invoked_as, rest, lines)
        return _main(rest, get_send, lines)
    except SystemExit:
        raise
    except (NoRecipientsError, ClientOSError, ServerOSError, EnqueueError) as err:
        # ServerConnectError is checked before its base class ServerOSError.
        if isinstance(err, NoRecipientsError):
            summary = "You must specify at least 1 recipient"
        elif isinstance(err, ClientOSError):
            summary = f"Error reading from stdin: {err}"
        elif isinstance(err, ServerConnectError):
            summary = f"Failed to connect to server: {err}"
        elif isinstance(err, ServerOSError):
            summary = f"Failed to communicate with server: {err}"
        else:
            summary = f"Failed to enqueue message: {err}"
        log_error(summary, exc_info=err)
    except BaseException as err:
        log_error("Unexpected exception", exc_info=err)
        raise

    # Only reached after one of the handled failures above.
    sys.exit(1)


if __name__ == "__main__":
    # Attach the journal handler only when executed as a script, so that
    # merely importing this module (e.g. from tests) does not log to the
    # journal.
    logger.addHandler(journal.JournalHandler(SYSLOG_IDENTIFIER="stne-sendmail"))

    main()
