#!/usr/bin/python3

import glob
import os
import string
import subprocess
import sys
import unittest
import errno
import time
import re
import socket
import tempfile
import binascii

import importlib.machinery
import importlib.util

import parent
import testlib
import testvm

# Exported to the environment so that processes started from here see it too:
# Python children then run with unbuffered output.
os.environ.update(PYTHONUNBUFFERED='1')
# Do not write bytecode caches for modules imported from this point on
sys.dont_write_bytecode = True


class Test:
    """Bookkeeping record for one test method that is run as a subprocess.

    The constructor stores the static description of the test; the class-level
    defaults below hold the per-run state until the runner overwrites them on
    the instance.
    """

    # combined stdout/stderr captured from the most recent run (bytes)
    output = b""
    # how many times this test has been scheduled for a retry so far
    retries = 0
    # subprocess.Popen object of the current or most recent run
    process = None

    def __init__(self, test_id, command, timeout, nondestructive):
        # test_id: running number used in the TAP result line
        # command: argv list; first item is the test script, last the test name
        self.test_id, self.command = test_id, command
        # timeout: maximum run time in seconds
        # nondestructive: whether the test is flagged as non-destructive
        self.timeout, self.nondestructive = timeout, nondestructive


def test_name(test):
    return "{0} {1} {2}{3}".format(test.test_id, test.command[0], test.command[-1], " [ND]" if test.nondestructive else "")


def flush_stdout():
    """Flush sys.stdout, retrying every 0.1s for as long as it would block."""
    flushed = False
    while not flushed:
        try:
            sys.stdout.flush()
        except BlockingIOError:
            # stdout cannot take more data right now; wait a little and retry
            time.sleep(0.1)
        else:
            flushed = True


def print_test(test, print_tap=True):
    """Write a finished test's captured output, optionally with a TAP result.

    The raw bytes in ``test.output`` are copied to stdout line by line. When
    *print_tap* is true, a result line is printed afterwards:

    * "ok <name>" if the test process exited with 0,
    * "ok <name> <reason>" if it exited with 77 (skipped); the reason is the
      last line of the output, or empty if there was no output,
    * "ok <name> " if the output contains "# SKIP ",
    * "not ok <name>" otherwise.

    Args:
        test: the finished Test; its ``process`` and ``output`` are read.
        print_tap: whether to print the TAP result line after the output.
    """
    for line in test.output.splitlines(keepends=True):
        while line:
            try:
                sys.stdout.buffer.write(line)
                break
            except BlockingIOError as e:
                # only part of the line was accepted; retry with the remainder
                line = line[e.characters_written:]
                time.sleep(0.1)
    flush_stdout()

    if not print_tap:
        return

    returncode = test.process.returncode
    if returncode == 0:
        print("ok " + test_name(test))
    elif returncode == 77 or b"# SKIP " in test.output:
        # If the test was skipped, add the last line (which contains the reason
        # for the skip) to the result
        reason = ""
        if returncode == 77:
            lines = test.output.splitlines()
            # a skipped test may not have printed anything at all, and its
            # output is not guaranteed to be valid UTF-8; neither must abort
            # the whole test run
            if lines:
                reason = lines[-1].strip().decode(errors="replace")
        print("ok {0} {1}".format(test_name(test), reason))
    else:
        print("not ok " + test_name(test))
    flush_stdout()

def finish_test(opts, test):
    """Evaluate a finished test, print its output and decide about a retry.

    A test that exited with 0 (passed) or 77 (skipped) is printed right away.
    For any other exit code, and unless ``opts.thorough`` is set, the output
    (with a "not ok" line appended) is piped through the external
    ``tests-policy`` command, which may rewrite it. If the result contains
    "# SKIP", the failure is not counted. Otherwise the test is retried up to
    two times, unless its output contains "Test completed, but found
    unexpected".

    Args:
        opts: parsed command line options; ``list`` and ``thorough`` are read.
        test: the finished Test; ``output`` and ``retries`` are updated.

    Returns:
        A tuple ``(retry_reason, exit_code)``. ``retry_reason`` is None if the
        test must not run again, otherwise a bytes object naming the reason.
        ``exit_code`` is 1 for a definite failure and 0 in all other cases.
    """
    if test.process.returncode in [0, 77]:
        print_test(test, not opts.list)
        return None, 0

    if not opts.thorough:
        cmd = ["tests-policy", testvm.DEFAULT_IMAGE]
        try:
            # hand the output over together with the failing result line
            test.output += ("not ok " + test_name(test)).encode()
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            changed = proc.communicate(test.output)[0]
            if proc.returncode == 0:
                if test.output != changed:
                    changed += b"\n"
                test.output = changed
        except OSError as ex:
            # a missing tests-policy executable is silently tolerated
            if ex.errno != errno.ENOENT:
                sys.stderr.write("Couldn't run tests-policy: {0}\n".format(str(ex)))

        if b"# SKIP" in test.output:
            print_test(test, print_tap=False)
            return None, 0

    # do we get a specific retry reason from tests-policy?
    # (raw literals: "\s" is not a valid escape sequence in a plain bytes literal)
    m = re.search(rb"\s*# RETRY (.*)$", test.output, re.MULTILINE)
    if m:
        retry_reason = m.group(1)
        # remove it from test output; we must not print it after the 3rd time,
        # and it is going to be printed separately below
        test.output = re.sub(rb"\s*# RETRY .*\n", b"", test.output)
    else:
        # HACK: many tests are unstable, always retry them 3 times
        retry_reason = b"be robust against unstable tests"

    unexpected_message = b"Test completed, but found unexpected" in test.output
    if test.retries < 2 and not unexpected_message:
        test.retries += 1
        test.output += b" # RETRY %i (%s)\n" % (test.retries, retry_reason)
        print_test(test, print_tap=opts.thorough)
        return retry_reason, 0

    # out of retries, or a failure that retrying is not attempted for
    print_test(test, print_tap=opts.thorough)
    return None, 1

def check_valid(filename):
    """Turn the base name of *filename* into a module name.

    Returns the base name with every "-" replaced by "_", or None if the base
    name contains anything other than ASCII letters, digits, "-" and "_".
    """
    base = os.path.basename(filename)
    permitted = set(string.ascii_letters) | set(string.digits) | {"-", "_"}
    if set(base) - permitted:
        return None
    return base.replace("-", "_")

def build_command(filename, test, opts):
    """Assemble the argv list that runs the single test *test* from *filename*.

    The list starts with the script, followed by the switches derived from
    *opts* (-t, -v, --nonet, -l, in that order), and ends with the test name.
    """
    switches = (
        (opts.trace, "-t"),
        (opts.verbosity, "-v"),
        (not opts.fetch, "--nonet"),
        (opts.list, "-l"),
    )
    enabled_flags = [flag for wanted, flag in switches if wanted]
    return [filename] + enabled_flags + [test]

def run(opts, image):
    """Discover the tests in ``opts.test_dir``, run them and print a TAP report.

    Every valid ``check-*`` file in the test directory is loaded as a module,
    and each of its unittest test methods becomes one subprocess invocation.
    Tests flagged as non-destructive are run one at a time against a shared
    machine (either the one given with ``--machine`` or the global machine
    from testlib); all other tests run concurrently. At most ``opts.jobs``
    test processes are active at once. Finished tests are evaluated by
    finish_test() and re-queued if it asks for a retry.

    Args:
        opts: parsed command line options (see main()).
        image: name of the image under test; only used to choose the sort
            direction of the serial tests.

    Returns:
        The number of tests that definitely failed.

    Raises:
        ValueError: if two different check-* files define a test class of the
            same name.
    """
    # Build the list of tests we'll parallelize and the ones we'll run serially
    test_loader = unittest.TestLoader()
    parallel_tests = []
    serial_tests = []
    test_id = 1
    result = 0
    # when only listing tests, run one child process at a time
    jobs = 1 if opts.list else opts.jobs
    start_time = time.time()
    serial_tests_duration = 0

    # Make sure tests can make relative imports
    sys.path.append(os.path.realpath(opts.test_dir))

    seen_classes = {}
    for filename in glob.glob(os.path.join(opts.test_dir, "check-*")):
        name = check_valid(filename)
        if not name or not os.path.isfile(filename):
            continue
        loader = importlib.machinery.SourceFileLoader(name, filename)
        module = importlib.util.module_from_spec(importlib.util.spec_from_loader(loader.name, loader))
        loader.exec_module(module)
        for test_suite in test_loader.loadTestsFromModule(module):
            for test in test_suite:
                # ensure that test classes are unique, so that they can be selected properly
                cls = test.__class__.__name__
                if seen_classes.get(cls) not in [None, filename]:
                    raise ValueError("test class %s in %s already defined in %s" % (cls, filename, seen_classes[cls]))
                seen_classes[cls] = filename

                test_method = getattr(test.__class__, test._testMethodName)
                test_str = "{0}.{1}".format(cls, test._testMethodName)
                # most tests should take much less than 10mins, so default to that;
                # longer tests can be annotated with @timeout(seconds)
                # check the test function first, fall back to the class's timeout
                test_timeout = getattr(test_method, "__timeout", getattr(test, "__timeout", 600))
                # positional test selection matches on substrings ...
                if opts.tests and not any(t in test_str for t in opts.tests):
                    continue
                # ... while exclusion requires the exact Class.method name
                if test_str in opts.exclude:
                    continue
                nd = getattr(test_method, "_testlib__non_destructive", False)
                test = Test(test_id, build_command(filename, test_str, opts), test_timeout, nd)
                if nd:
                    serial_tests.append(test)
                else:
                    if not opts.nondestructive:
                        parallel_tests.append(test)
                # note: ids are also consumed by destructive tests that get dropped above
                test_id += 1

    # sort serial tests by class/test name, to avoid spurious errors where failures depend on the order of execution
    # but let's make sure we always test them both ways around; hash the image name, which is robust, reproducible, and provides
    # an even distribution of both directions
    serial_tests.sort(key=lambda t: t.command[-1], reverse=bool(binascii.crc32(image.encode()) & 1))

    print("1..{0}".format(len(parallel_tests) + len(serial_tests)))
    flush_stdout()

    if serial_tests and not opts.list:
        if opts.machine:
            ssh_address = opts.machine
            web_address = opts.browser
        else:
            m = testlib.MachineCase.get_global_machine(restrict=not opts.enable_network)
            ssh_address = "{0}:{1}".format(m.ssh_address,
                                           m.ssh_port)
            web_address = "{0}:{1}".format(m.web_address,
                                           m.web_port)
        machine_args = ["--machine", ssh_address]
        # --browser may be omitted together with --machine; a None element in
        # the argument list would make Popen raise TypeError
        if web_address is not None:
            machine_args += ["--browser", web_address]
        for test in serial_tests:
            # Insert right before the trailing test name. A negative index
            # passed to insert() is unsuitable here: for a command without any
            # switches it ends up in front of the script itself.
            test.command[-1:-1] = machine_args

    running_tests = []
    serial_test = None
    serial_tests_len = len(serial_tests)
    while serial_tests or parallel_tests or running_tests:
        made_progress = False
        if len(running_tests) < jobs:
            test = None
            # at most one serial test at a time, as they share a machine
            if serial_tests and serial_test is None:
                test = serial_tests.pop(0)
                serial_test = test
                serial_test_start = time.time()
            elif parallel_tests:
                test = parallel_tests.pop(0)

            if test:
                made_progress = True
                # collect stdout and stderr of the test in an anonymous temporary file
                test.outfile = tempfile.TemporaryFile()
                test.process = subprocess.Popen(["timeout", str(test.timeout)] + test.command,
                                                stdout=test.outfile, stderr=subprocess.STDOUT)
                running_tests.append(test)

        # iterate over a copy, as finished tests get removed from the list
        for test in running_tests.copy():
            poll_result = test.process.poll()
            if poll_result is not None:
                made_progress = True
                test.outfile.flush()
                test.outfile.seek(0)
                test.output = test.outfile.read()
                test.outfile.close()
                running_tests.remove(test)
                retry_reason, test_result = finish_test(opts, test)
                result += test_result

                if test is serial_test:
                    serial_tests_duration += (time.time() - serial_test_start)

                    # sometimes our global machine gets messed up; also, tests that time out don't run cleanup handlers
                    # restart it to avoid an unbounded number of test retries and follow-up errors
                    # (timeout(1) exits with 124 when the time limit was hit)
                    if not opts.machine and (poll_result == 124 or (retry_reason and b"test harness" in retry_reason)):
                        # try hard to keep the test output consistent
                        testlib.MachineCase.kill_global_machine()
                        # bring it back with the same network restriction as initially
                        testlib.MachineCase.get_global_machine(restrict=not opts.enable_network)

                # run again if needed
                if retry_reason:
                    test.output = None
                    test.process = None
                    # put it at the front of its queue, so that it is retried next
                    if test is serial_test:
                        serial_tests.insert(0, test)
                    else:
                        parallel_tests.insert(0, test)

                if test is serial_test:
                    serial_test = None

        # Sleep if we didn't make progress
        if not made_progress and not opts.list:
            time.sleep(0.5)

    if not opts.list:
        testlib.MachineCase.kill_global_machine()

        duration = int(time.time() - start_time)
        hostname = socket.gethostname().split(".")[0]
        details = "[{0}s on {1}, {2} serial tests took {3}s]".format(duration, hostname, serial_tests_len, int(serial_tests_duration))
        print()
        if result > 0:
            print("# {0} TESTS FAILED {1}".format(result, details))
        else:
            print("# TESTS PASSED {0}".format(details))

    return result


def main():
    """Parse the command line, prepare the environment and run the tests.

    Besides the options from ``testlib.arg_parser()``, this adds the options
    that control parallelism, test selection and the use of an already
    running machine. It exports TEST_REVISION (current git HEAD, if it can be
    determined), TEST_BROWSER (default "chromium") and TEST_OS to the
    environment of the test processes.

    Returns:
        The result of run(), i.e. the number of failed tests.
    """
    parser = testlib.arg_parser(enable_sit=False)
    # a TEST_JOBS string from the environment is converted by argparse via type=int
    parser.add_argument('-j', '--jobs', type=int,
                        default=os.environ.get("TEST_JOBS", 1), help="Number of concurrent jobs")
    parser.add_argument('--thorough', action='store_true',
                        help='Thorough mode, no skipping known issues')
    parser.add_argument('-n', '--nondestructive', action='store_true',
                        help='Only consider @nondestructive tests')
    parser.add_argument('--machine', metavar="hostname[:port]",
                        default=None, help="Run tests against an already running machine;  implies --nondestructive")
    parser.add_argument('--browser', metavar="hostname[:port]",
                        default=None, help="When using --machine, use this cockpit web address")
    parser.add_argument('--test-dir', default=testvm.TEST_DIR,
                        help="Directory in which to glob check-* files")
    parser.add_argument('--exclude', action="append", default=[], metavar="TestClass.testName",
                        help="Exclude test (exact match only); can be specified multiple times")
    opts = parser.parse_args()

    if opts.machine:
        if opts.jobs > 1:
            parser.error("--machine cannot be used with concurrent jobs")
        opts.nondestructive = True

    # Tell any subprocesses what we are testing
    if "TEST_REVISION" not in os.environ:
        try:
            r = subprocess.run(["git", "rev-parse", "HEAD"],
                               universal_newlines=True, check=False, stdout=subprocess.PIPE)
        except OSError:
            # best effort only: without a runnable git, TEST_REVISION stays
            # unset, just as it does when git fails
            pass
        else:
            if r.returncode == 0:
                os.environ["TEST_REVISION"] = r.stdout.strip()

    # keep a TEST_BROWSER chosen by the caller
    os.environ.setdefault("TEST_BROWSER", "chromium")

    image = testvm.DEFAULT_IMAGE
    os.environ["TEST_OS"] = image

    return run(opts, image)


# Script entry point: the return value of main() becomes the exit status
if __name__ == '__main__':
    raise SystemExit(main())
