#!/usr/bin/env python3
import argparse
import re
import serial
import os
import shutil
import subprocess
import sys
import time

# Command line interface. ArgumentDefaultsHelpFormatter appends each option's
# default value to its --help text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Connect to serial console to execute stuff",
)

# Source of the serial device path (stored as args.qemu_log; argparse derives
# that attribute name from the option string).
parser.add_argument(
    "--qemu-log",
    required=True,
    help="QEMU log to look for serial port info in",
)

# Login credentials and the host name used to build the expected prompts.
parser.add_argument(
    "--hostname",
    default="trixie",
    help="hostname of the system for login process",
)
parser.add_argument(
    "--user",
    default="root",
    help="user name to use for login",
)
parser.add_argument(
    "--password",
    default="grml",
    help="password for login",
)

# Time/attempt budget. The defaults are given as strings; argparse passes
# string defaults through `type`, so both options end up as ints.
parser.add_argument(
    "--timeout",
    type=int,
    default="180",
    help="Maximum time for finding the login prompt, in seconds",
)
parser.add_argument(
    "--screenshot",
    default="screenshot.jpg",
    help="file name for screenshot captured via VNC on error",
)
parser.add_argument(
    "--tries",
    type=int,
    default="12",
    help="Number of retries for finding the login prompt",
)

# Boolean switch; a store_true option defaults to False when not given.
parser.add_argument(
    "--poweroff",
    action="store_true",
    help='send "poweroff" command after all other commands',
)

# One or more commands to run on the guest, in the order given.
parser.add_argument(
    "command",
    nargs="+",
    help="command to execute after logging in",
)


def print_ser_lines(ser):
    """Read all pending lines from *ser* and echo them to stdout.

    Each line is printed with a ``<<`` prefix. Lines are printed exactly as
    returned by ``ser.readlines()`` (byte strings for a serial port), without
    decoding.
    """
    received = ser.readlines()
    for raw_line in received:
        print("<<", raw_line)


def write_ser_line(ser, text):
    """Send *text* followed by a newline to *ser* and flush the port.

    The text is echoed to stdout with a ``>>`` prefix before it is sent, so
    the log shows what was typed on the console.
    """
    print(">>", text)
    payload = ("%s\n" % text).encode()
    ser.write(payload)
    ser.flush()


def wait_ser_text(ser, texts, timeout):
    """Read from *ser* until one of *texts* shows up or *timeout* expires.

    Args:
        ser: serial port object; its ``timeout`` attribute is set to 1 and
            left that way when the function returns.
        texts: a single string or a list of strings to look for.
        timeout: how long to keep polling, in seconds.

    Returns:
        The matching entry of *texts* (as ``str``), or ``False`` if the
        deadline passed without a match.
    """
    if isinstance(texts, str):
        texts = [texts]

    deadline = time.time() + timeout
    # Short per-read timeout so the deadline gets re-checked regularly
    # (pyserial interprets this value as seconds).
    ser.timeout = 1
    needles = [candidate.encode("utf-8") for candidate in texts]
    print(time.time(), "D: expecting one of", needles)

    match = False
    while match is False and deadline > time.time():
        # A whole batch is always processed (and logged) before the loop
        # condition is evaluated again, so a match on a later line of the
        # same batch replaces an earlier one.
        for raw_line in ser.readlines():
            print(time.time(), "<<", raw_line)  # raw_line is a byte string
            # Within one line, the first entry of `texts` that occurs wins.
            hit = next(
                (texts[pos] for pos, needle in enumerate(needles) if needle in raw_line),
                None,
            )
            if hit is not None:
                match = hit
    return match


def login(ser, hostname, user, password, timeout):
    """Log in on the serial console unless a shell is already open.

    Sends an empty line to provoke a prompt, then waits for either the shell
    prompt (``user@hostname``) or the login prompt (``hostname login:``). If
    the login prompt appears, the user name and password are entered and the
    shell prompt is awaited.

    Args:
        ser: serial port object connected to the console.
        hostname: host name used to build the expected prompts.
        user: account name to log in with.
        password: password for that account.
        timeout: base wait in seconds; the password prompt gets this long,
            the login and shell prompts four times as long.

    Raises:
        ValueError: if any of the expected prompts does not appear in time.
    """
    shell_prompt = "%s@%s" % (user, hostname)
    login_prompt = "%s login:" % hostname
    long_wait = timeout * 4

    write_ser_line(ser, "")  # send newline

    seen = wait_ser_text(ser, [shell_prompt, login_prompt], timeout=long_wait)
    if seen is False:
        raise ValueError("timeout waiting for login prompt")
    if seen == shell_prompt:
        # A shell is already open on the console; nothing to do.
        return

    print("Logging in...")
    write_ser_line(ser, user)
    got_password_prompt = wait_ser_text(ser, "Password:", timeout=timeout)
    if not got_password_prompt:
        raise ValueError("timeout waiting for password prompt")

    write_ser_line(ser, password)
    time.sleep(1)

    print(time.time(), "Waiting for shell prompt...")
    got_shell = wait_ser_text(ser, shell_prompt, timeout=long_wait)
    if not got_shell:
        raise ValueError("timeout waiting for shell prompt")


def capture_vnc_screenshot(screenshot_file):
    """Save a screenshot of the VNC display on localhost to *screenshot_file*.

    Runs the external ``vncsnapshot`` tool. This is best effort: if the tool
    is not installed or exits non-zero, only a warning is printed.
    """
    if shutil.which("vncsnapshot") is None:
        print("WARN: vncsnapshot not available, skipping vnc snapshot capturing.")
        return

    print("Trying to capture screenshot via vncsnapshot to", screenshot_file)

    # No check=True: a failing snapshot must not abort the caller.
    completed = subprocess.run(["vncsnapshot", "localhost", screenshot_file])
    if completed.returncode == 0:
        print("Screenshot file '%s' available" % os.path.abspath(screenshot_file))
    else:
        print("WARN: failed to capture vnc snapshot :(")


def find_serial_port_from_qemu_log(qemu_log, tries):
    """Poll the QEMU log for the serial device QEMU redirected to.

    The log file is re-read up to *tries* times, five seconds apart, looking
    for a line starting with ``char device redirected to <port>``. The log
    lines read during the last attempt are printed for diagnostics.

    Args:
        qemu_log: path of the QEMU log file (read as UTF-8).
        tries: maximum number of attempts; with 0 the file is never read.

    Returns:
        The device path from the first matching line, or ``None`` if no such
        line was found. The caller has to deal with ``None``.

    Raises:
        OSError: if the log file cannot be opened.
    """
    redirect_re = re.compile(r"char device redirected to ([^ ]+)")
    port = None
    # Pre-initialised so the diagnostic dump below also works when the loop
    # body never runs (tries == 0).
    qemu_log_messages = []
    for i in range(tries):
        print("Waiting for qemu to present serial console [try %s]" % (i, ))
        with open(qemu_log, "r", encoding="utf-8") as fp:
            qemu_log_messages = fp.read().splitlines()
        for line in qemu_log_messages:
            m = redirect_re.match(line)
            if m:
                port = m.group(1)
                break
        if port:
            break
        # Only wait when another attempt follows; sleeping after the final
        # attempt would just delay reporting the failure.
        if i + 1 < tries:
            time.sleep(5)

    print("qemu log (up to char device redirect) follows:")
    print("\n".join(qemu_log_messages))
    print()
    return port  # might be None, caller has to deal with it


def main():
    """Log into the guest via its serial console and run the given commands.

    Steps:
      1. Parse the command line and find the serial device in the QEMU log.
      2. Open the device and try to log in, up to ``--tries`` times and
         within a ``--timeout`` seconds budget.
      3. On success, run each command (echoing the console output) and
         optionally send ``poweroff``.
      4. On login failure, capture a VNC screenshot.

    The serial port is closed on every exit path once it has been opened.

    Raises:
        SystemExit: with status 1 if no serial port was found or the login
            did not succeed.
    """
    args = parser.parse_args()
    hostname = args.hostname
    password = args.password
    qemu_log = args.qemu_log
    user = args.user
    commands = args.command
    screenshot_file = args.screenshot

    port = find_serial_port_from_qemu_log(qemu_log, args.tries)
    if not port:
        print()
        print("E: no serial port found in qemu log", qemu_log)
        sys.exit(1)

    ser = serial.Serial(port, 115200)
    try:
        # Drop buffered data left over from before this session.
        ser.flushInput()
        ser.flushOutput()

        success = False
        ts_start = time.time()
        for i in range(args.tries):
            try:
                print("Logging into %s via serial console [try %s]" % (port, i))
                login(ser, hostname, user, password, 30)
                success = True
                break
            except Exception as except_inst:
                # Deliberately broad: any failure during login just counts
                # as a failed attempt and is retried after a short pause.
                print("Login failure (try %s):" % (i,), except_inst)
                time.sleep(5)
            if time.time() - ts_start > args.timeout:
                print("E: Timeout reached waiting for login prompt")
                break

        if not success:
            print("W: Running tests failed, saving screenshot")
            capture_vnc_screenshot(screenshot_file)
            sys.exit(1)

        write_ser_line(ser, "")
        # Longer read timeout than the one used while waiting for prompts,
        # so output of the commands below is collected.
        ser.timeout = 30
        print_ser_lines(ser)
        print("Running commands...")
        for command in commands:
            write_ser_line(ser, command)
            print_ser_lines(ser)
        if args.poweroff:
            print("Sending final poweroff command...")
            write_ser_line(ser, "poweroff")
            ser.flush()
            # after poweroff, the serial device will probably vanish. do not attempt reading from it anymore.
    finally:
        # Release the device on success, on sys.exit() and on errors alike.
        ser.close()


if __name__ == "__main__":
    main()
