#!/usr/bin/env python3
# encoding: utf-8
"""
cache.py

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import re
import sys
import json
import glob
import time
import signal
import argparse
import subprocess
from enum import Enum


# Interpreter path: a non-empty INTERPRETER environment variable wins, then the
# __PYVENV_LAUNCHER__ variable, and finally the interpreter running this script.
INTERPRETER = os.environ.get('INTERPRETER', '') or os.environ.get('__PYVENV_LAUNCHER__', sys.executable)


class Alarm(Exception):
    """Exception raised by ``alarm_handler`` when an alarm signal is delivered."""


def flush(*args, **kwars):
    """Forward everything to ``print`` and then flush ``sys.stdout``."""
    print(*args, **kwars)
    stream = sys.stdout
    stream.flush()


def alarm_handler(number, frame):  # pylint: disable=W0613
    """Signal handler: ignore both arguments and raise ``Alarm``."""
    raise Alarm


def color(prefix, suffix):
    """Return the escape sequences ``ESC[<prefix>m`` and ``ESC[<suffix>m`` joined."""
    return ''.join(f'\033[{value}m' for value in (prefix, suffix))


class Port:
    """Hands out consecutive port numbers, starting at ``base``."""

    base = 1790

    @classmethod
    def get(cls):
        """Return the current ``base`` and advance it by one for the next caller."""
        allocated, cls.base = cls.base, cls.base + 1
        return allocated


class Path:
    """Filesystem locations used by the tests, all derived from this file's path.

    ``ROOT`` is the directory two levels above the folder containing this file.
    The ``ALL_*`` lists are computed once, when the class body is executed.
    """

    PROGRAM = os.path.realpath(__file__)
    ROOT = os.path.abspath(os.path.join(os.path.dirname(PROGRAM), '..', '..'))
    SRC = os.path.join(ROOT, 'src')

    ETC = os.path.join(ROOT, 'etc', 'exabgp')
    EXABGP = os.path.join(ROOT, 'sbin', 'exabgp')
    BGP = os.path.join(ROOT, 'qa', 'sbin', 'bgp')
    DECODING = os.path.join(ROOT, 'qa', 'decoding')
    ENCODING = os.path.join(ROOT, 'qa', 'encoding')

    # sorted so that the order of the tests does not depend on the filesystem
    ALL_ETC = sorted(glob.glob(os.path.join(ETC, '*.conf')))
    ALL_DECODING = sorted(glob.glob(os.path.join(DECODING, '*')))
    ALL_ENCODING = sorted(glob.glob(os.path.join(ENCODING, '*.ci')))

    @classmethod
    def setup_env(cls):
        """Point ``PYTHONPATH`` at the ``src`` folder (replacing any previous value)."""
        os.environ['PYTHONPATH'] = cls.SRC

    @staticmethod
    def etc(fname):
        """Return the absolute path of ``fname`` taken relative to ``Path.ETC``."""
        return os.path.abspath(os.path.join(Path.ETC, fname))

    @staticmethod
    def ci(fname, ext):
        """Return the absolute path of ``<fname>.<ext>`` inside ``Path.ENCODING``."""
        return os.path.abspath(os.path.join(Path.ENCODING, fname + '.' + ext))

    @classmethod
    def validate(cls):
        """Exit the program with a message if a required folder or file is missing."""
        requirements = (
            (os.path.isdir, cls.ETC, 'could not find etc folder'),
            (os.path.isdir, cls.ENCODING, 'could not find tests in the qa/encoding folder'),
            (os.path.isdir, cls.DECODING, 'could not find the tests in qa/decoding'),
            (os.path.isfile, cls.EXABGP, 'could not find exabgp'),
            (os.path.isfile, cls.BGP, 'could not find the sequence daemon'),
        )
        for present, location, complaint in requirements:
            if not present(location):
                sys.exit(complaint)


class Exec(object):
    """Wrapper around one child process whose output is collected at most once.

    ``code`` stays at -1 until the exit status has been collected; ``stdout``
    and ``stderr`` hold the captured bytes and ``message`` any collection error.
    """

    def __init__(self):
        self.code = -1
        self.stdout = b''
        self.stderr = b''
        self.message = ''
        self._process = None
        self.command = []

    def run(self, command):
        """Start ``command`` (a list of arguments) with both outputs piped.

        ``self.command`` is set to a printable, space-joined form of the command
        in which arguments containing a space are wrapped in single quotes.
        Returns ``self`` so that calls can be chained.
        """
        self.command = ' '.join([_ if ' ' not in _ else f"'{_}'" for _ in command])
        self._process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return self

    def ready(self):
        """Return True once the process has exited (or can no longer be polled)."""
        signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(1)
            polled = self._process.poll()
            signal.alarm(0)
        except Alarm:
            return False
        except (OSError, ValueError):
            # cancel the pending alarm, otherwise it fires later in unrelated code
            signal.alarm(0)
            return True
        return polled is not None

    def report(self, reason='issue with exabgp'):
        """Print ``reason`` followed by everything known about the process."""
        flush(reason)
        flush(f'command> {self.command}')
        flush(f'return: {self.code}')
        # errors='replace': reporting must not fail on output which is not UTF-8
        flush(f'stdout: {self.stdout.decode(errors="replace")}')
        flush(f'stderr: {self.stderr.decode(errors="replace")}')
        flush(f'message: {self.message}')

    def collect(self):
        """Wait (at most 15 seconds) for the process and store its output and code.

        Does nothing when some output or an exit code was already collected.
        Failures are recorded in ``self.message`` and never raised.
        """
        if self.stdout:
            return
        if self.stderr:
            return
        if self.code != -1:
            return

        signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(15)
            self.stdout, self.stderr = self._process.communicate()
            self.code = self._process.returncode
            signal.alarm(0)
        except ValueError as exc:  # I/O operation on closed file
            # cancel the pending alarm, otherwise it fires later in unrelated code
            signal.alarm(0)
            self.message = str(exc)
        except Alarm as exc:
            # Alarm carries no text of its own, so provide an explanation
            self.message = str(exc) or 'timed out waiting for the process to complete'

    def terminate(self):
        """Send SIGTERM to the process, then collect whatever it produced."""
        try:
            self._process.send_signal(signal.SIGTERM)
        except OSError:  # No such process, Errno 3
            pass
        self.collect()

    def __del__(self):
        """Terminate a still running child when the wrapper is discarded."""
        process = getattr(self, '_process', None)
        if process is None:
            return
        try:
            if process.poll() is None:
                self.terminate()
        except Exception:  # best effort only: a finalizer must never raise
            pass


# States a test record can be in; members are numbered from 1 in this order.
State = Enum('State', ['NONE', 'STARTING', 'RUNNING', 'FAIL', 'SUCCESS', 'SKIP'])


class Record:
    """One test entry: a one-character nickname, a name, its files and its state."""

    _index = 0
    _listing = '0123456789' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + 'abcdefghijklmnopqrstuvwxyz' + 'αβγδεζηθικλμνξοπρςστυφχψω'

    def __init__(self, nick, name):
        self.nick = nick
        self.name = name
        self.conf = dict()
        self.files = []
        # a record is skipped until it is explicitly activated
        self.state = State.SKIP

    @classmethod
    def new(cls, name):
        """Build a record for ``name`` using the next free nickname."""
        return cls(cls.next_nick(), name)

    @classmethod
    def next_nick(cls):
        """Return the next unused character of ``_listing`` and advance the index."""
        position = cls._index
        nick = cls._listing[position]
        cls._index = position + 1
        return nick

    def skip(self):
        """Mark the record as skipped; returns ``self``."""
        self.state = State.SKIP
        return self

    def fail(self):
        """Mark the record as failed; returns ``self``."""
        self.state = State.FAIL
        return self

    def activate(self):
        """Select the record for running (state ``NONE``); returns ``self``."""
        self.state = State.NONE
        return self

    def is_active(self):
        """Return True unless the record is skipped, failed or successful."""
        settled = (State.SKIP, State.FAIL, State.SUCCESS)
        return self.state not in settled

    def setup(self):
        """Advance ``NONE`` to ``STARTING`` or ``STARTING`` to ``RUNNING``."""
        following = {
            State.NONE: State.STARTING,
            State.STARTING: State.RUNNING,
        }
        self.state = following.get(self.state, self.state)

    def running(self):
        """Mark the record as running."""
        self.state = State.RUNNING

    def colored(self):
        """Return the colour escape codes followed by the nick or a result symbol."""
        with_nick = {
            State.NONE: (0, 30),  # BLACK
            State.STARTING: (1, 30),  # GRAY
            State.RUNNING: (0, 0),  # NORMAL
            State.FAIL: (0, 91),  # RED
        }
        if self.state in with_nick:
            return color(*with_nick[self.state]) + self.nick
        if self.state == State.SUCCESS:
            return color(1, 92) + '✓'  # GREEN
        if self.state == State.SKIP:
            return color(0, 34) + '✖'  # BLUE
        return None

    def result(self, success):
        """Record the outcome as ``SUCCESS`` or ``FAIL`` and return ``success``."""
        self.state = State.SUCCESS if success else State.FAIL
        return success


class Tests:
    """Registry of test records, addressable by nick and listed in nick order."""

    def __init__(self, klass):
        # klass.new(name) is called by new() to build every record
        self.klass = klass
        self._by_nick = {}
        self._ordered = []
        # number of tests printed on each line by listing()
        self._nl = 3

    def new(self, name):
        """Create a record for ``name`` through ``klass.new`` and register it."""
        test = self.klass.new(name)
        self._by_nick[test.nick] = test
        self._ordered.append(test.nick)
        self._ordered.sort()
        return test

    def enable_by_nick(self, nick):
        """Activate the test called ``nick``; return False if there is none."""
        if nick in self._by_nick:
            self._by_nick[nick].activate()
            return True
        return False

    def enable_all(self):
        """Activate every registered test."""
        for test in self._by_nick.values():
            test.activate()

    def get_by_nick(self, nick):
        """Return the test called ``nick``; raises ``KeyError`` if unknown."""
        return self._by_nick[nick]

    def registered(self):
        """Return all the tests, in nick order."""
        return [self._by_nick[nick] for nick in self._ordered]

    def selected(self):
        """Return the tests for which ``is_active()`` is true, in nick order."""
        return [self._by_nick[nick] for nick in self._ordered if self._by_nick[nick].is_active()]

    def _iterate(self):
        """Yield the tests in rows of at most ``_nl``; never yields an empty row."""
        # stepping by the row size stops exactly at the last test, so no trailing
        # empty row is produced when the count is a multiple of _nl (or zero)
        for start in range(0, len(self._ordered), self._nl):
            yield [self._by_nick[nick] for nick in self._ordered[start : start + self._nl]]

    def listing(self):
        """Print the nick and name of every test, ``_nl`` tests per line."""
        sys.stdout.write('\n')
        sys.stdout.write('The available tests are:\n')
        sys.stdout.write('\n')
        for tests in self._iterate():
            for test in tests:
                sys.stdout.write(f' {test.nick:2} {test.name:25}')
            sys.stdout.write('\n')
        sys.stdout.write('\n')
        sys.stdout.flush()

    def short_list(self):
        """Print the nicks of all the tests on one line, separated by spaces."""
        sys.stdout.write(' '.join([test.nick for test in self.registered()]))
        sys.stdout.write('\n')
        sys.stdout.flush()

    def display(self):
        """Redraw the one-line progress display in place."""
        for test in self.registered():
            sys.stdout.write(' %s' % test.colored())
        # reset the colours; the carriage return makes the next call overwrite the line
        sys.stdout.write('%s\r' % color(0, 0))
        sys.stdout.flush()


def parse_msg_options(msg_file):
    """Parse option directives from .msg file

    Returns the integer of the first ``option:tcp_connections:<n>`` line, or
    ``None`` when the file cannot be read, has no such line, or the value is
    not an integer.
    """
    try:
        with open(msg_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('option:tcp_connections:'):
                    # only the first directive is considered
                    return int(line.split(':')[-1])
    except (OSError, ValueError):
        # unreadable/undecodable file or non-numeric value: behave as if the
        # option was absent, anything else is a programming error and propagates
        return None
    return None


class EncodingTests(Tests):
    """Tests running an exabgp client against the qa ``bgp`` program.

    One test is created per ``*.ci`` file of ``Path.ENCODING``: the first line
    of that file lists the configuration file name(s), resolved through
    ``Path.etc``, and the ``.msg`` file with the same base name is the one
    passed to the ``bgp`` program.
    """

    class Test(Record, Exec):
        """One encoding test; its ``Exec`` side holds the server process."""

        def __init__(self, nick, name):
            Record.__init__(self, nick, name)
            Exec.__init__(self)
            # marker looked for in the server output to declare the test passed
            self._check = b'successful'

        def success(self):
            """Return True if the server exited with 0 and printed the marker.

            The outcome is reported on failure, and also on success when the
            DEBUG environment variable is set.
            """
            self.collect()
            if self.code == 0 and (self._check in self.stdout or self._check in self.stderr):
                if os.getenv('DEBUG', None) is not None:
                    self.report('completed successfully')
                return True

            self.report(f'failed with code {self.code}')
            return False

    # matches the "run <program> ;" lines of a configuration file
    API = re.compile(r'^\s*run\s+(.*)\s*?;\s*?$')

    def __init__(self):
        super().__init__(self.Test)

        for filename in Path.ALL_ENCODING:
            name = os.path.basename(filename)[:-3]  # strip the '.ci' extension
            test = self.new(name)
            with open(filename, 'r') as reader:
                content = reader.readline()
            test.conf['confs'] = [Path.etc(_) for _ in content.split()]
            test.conf['ci'] = Path.ci(name, 'ci')
            test.conf['msg'] = Path.ci(name, 'msg')
            test.conf['port'] = Port.get()
            test.files.extend(test.conf['confs'])
            test.files.append(test.conf['ci'])
            test.files.append(test.conf['msg'])

            # the programs run by the configuration are part of the test files too
            for conf in test.conf['confs']:
                with open(conf) as reader:
                    for line in reader:
                        found = self.API.match(line)
                        if not found:
                            continue
                        program = found.group(1)
                        if not program.startswith('/'):
                            program = Path.etc(program)
                        if program not in test.files:
                            test.files.append(program)

    def explain(self, nick):
        """Print the client and server commands of the test called ``nick``."""
        template = '\n'
        template += 'exabgp\n'
        template += '-' * 55 + '\n\n'
        template += '%(client)s\n\n\n'
        template += 'bgp deamon\n'
        template += '-' * 55 + '\n\n'
        template += '%(server)s\n\n\n'
        template += 'The following extra configuration options could be used\n'
        template += '-' * 55 + '\n\n'
        template += 'export exabgp_debug_rotate=true\n'
        template += 'export exabgp_debug_defensive=true\n'

        flush(
            template
            % {
                'client': self.client(nick),
                'server': self.server(nick),
            }
        )

    def client(self, nick):
        """Return the shell command starting exabgp for the test called ``nick``.

        Exits the program with 'no such test' when ``nick`` is unknown.
        """
        if not self.enable_by_nick(nick):
            sys.exit('no such test')
        test = self.get_by_nick(nick)

        # Parse tcp_connections option from .msg file, default to 1 if not specified
        tcp_connections = parse_msg_options(test.conf['msg']) or 1

        config = {
            'env': ' \\\n  '.join(
                [
                    'exabgp_version=5.0.0-0+test',
                    f'exabgp_tcp_connections={tcp_connections}',
                    'exabgp_api_cli=false',
                    'exabgp_debug_rotate=true',
                    'exabgp_debug_configuration=true',
                    "exabgp_tcp_bind=''",
                    'exabgp_tcp_port=%d' % test.conf['port'],
                    'INTERPRETER=%s' % INTERPRETER,
                ]
            ),
            'exabgp': Path.EXABGP,
            'confs': ' \\\n    '.join(test.conf['confs']),
        }
        return 'env \\\n  %(env)s \\\n   %(exabgp)s -d -p \\\n    %(confs)s' % config

    def server(self, nick):
        """Return the shell command starting the bgp program for the test ``nick``.

        Exits the program with 'no such test' when ``nick`` is unknown.
        """
        if not self.enable_by_nick(nick):
            sys.exit('no such test')
        test = self.get_by_nick(nick)

        config = {
            'env': ' \\\n  '.join(
                [
                    'exabgp_tcp_port=%d' % test.conf['port'],
                ]
            ),
            'interpreter': INTERPRETER,
            'bgp': Path.BGP,
            'msg': test.conf['msg'],
        }

        return 'env \\\n  %(env)s \\\n  %(interpreter)s %(bgp)s --view \\\n    %(msg)s' % config

    def _command(self, role, test):
        """Return the argument list re-invoking this script as ``role`` for ``test``."""
        return [
            sys.argv[0],
            'encoding',
            role,
            test.nick,
            '--port',
            f'{test.conf["port"]}',
        ]

    def dry(self):
        """Return, one per line, the commands ``run_selected`` would start."""
        result = []
        for test in self.selected():
            for role in ('--server', '--client'):
                result.append(' '.join(['>'] + self._command(role, test)))
        return '\n'.join(result)

    def run_selected(self, timeout):
        """Run every selected test and return True only if all of them passed.

        All the servers are started first, then all the clients; the tests are
        then polled until each pair has exited or ``timeout`` seconds elapsed.
        Tests still running at that point are failed and both of their
        processes are terminated.
        """
        success = True
        client = dict()

        for test in self.selected():
            test.setup()
            self.display()
            test.run(self._command('--server', test))
            time.sleep(0.005)

        time.sleep(0.02)

        for test in self.selected():
            test.setup()
            self.display()
            client[test.nick] = Exec().run(self._command('--client', test))
            time.sleep(0.005)

        exit_time = time.time() + timeout

        running = set(self.selected())

        while running and time.time() < exit_time:
            self.display()
            for test in list(running):
                if not test.ready():
                    continue
                if not client[test.nick].ready():
                    continue
                running.remove(test)
                client[test.nick].terminate()
                success = test.result(test.success()) and success
                self.display()
            time.sleep(0.1)

        for test in running:
            test.fail()
            test.terminate()
            # the client must be stopped too, otherwise it outlives the run
            client[test.nick].terminate()
            success = False

        self.display()
        return success


class DecodingTests(Tests):
    """Tests comparing the JSON printed by ``exabgp decode`` with an expected one.

    One test is created per file of ``Path.DECODING``. The first line of a file
    gives the message type followed by the family, the second the packet and
    the third the expected JSON. ``listing`` is inherited from ``Tests``.
    """

    class Test(Record, Exec):
        """One decoding test; its ``Exec`` side holds the decode process."""

        def __init__(self, nick, name):
            Record.__init__(self, nick, name)
            Exec.__init__(self)

        def _cleanup(self, decoded):
            """Remove, in place, the keys ignored by the comparison; return ``decoded``."""
            for key in ('exabgp', 'host', 'pid', 'ppid', 'time', 'version'):
                decoded.pop(key, None)
            return decoded

        def success(self):
            """Return True if the process printed exactly the expected JSON.

            Any output on stderr, no output at all, undecodable JSON or a
            mismatch is reported and counts as a failure.
            """
            self.collect()
            if self.stderr:
                self.report('stderr is \n' + self.stderr.decode())
                return False
            if not self.stdout:
                self.report('no stdout received')
                return False
            try:
                decoded = json.loads(self.stdout)
                self._cleanup(decoded)
            except Exception:
                self.report('issue, could not decode the JSON')
                return False
            if decoded != self.conf['json']:
                from pprint import pformat

                failure = 'issue, JSON does not match'
                failure += f'\ndecoded : {pformat(decoded)}\n'
                failure += f'\nexpected: {pformat(self.conf["json"])}'
                self.report(failure)
                return False
            return True

    def __init__(self):
        super().__init__(self.Test)

        for filename in Path.ALL_DECODING:
            name = os.path.basename(filename).split('.')[0]
            test = self.new(name)
            with open(filename, 'r') as reader:
                # line 1: message type, then the two words of the family
                words = reader.readline().split()
                test.conf['type'] = words[0]
                test.conf['family'] = '' if words[0] == 'open' else f'{words[1]} {words[2]}'
                # line 2: the packet, stored without any space
                packet = reader.readline().replace(' ', '').strip()
                test.conf['packet'] = packet
                # line 3: the expected JSON
                expected = reader.readline().strip()
                decoded = json.loads(expected)
                test.conf['json'] = test._cleanup(decoded)
            test.files.append(filename)

    def dry(self):
        """Return, one per line, a printable form of the commands to be run."""
        result = []
        for test in self.selected():
            result.append(
                ' '.join(
                    [
                        '>',
                        Path.EXABGP,
                        'decode',
                        f"'-f {test.conf['family']}'" if test.conf['family'] else '',
                        '--%s' % test.conf['type'],
                        test.conf['packet'],
                    ]
                )
            )
        return '\n'.join(result)

    def run_selected(self, timeout):
        """Run every selected test and return True only if all of them passed.

        Raises ``ValueError`` for a message type other than open, update or nlri.
        """
        success = True
        for test in self.selected():
            test.running()
            self.display()
            message = test.conf['type']
            if message == 'open':
                cmd = [
                    Path.EXABGP,
                    'decode',
                    '--%s' % test.conf['type'],
                    test.conf['packet'],
                ]
            elif message in ['update', 'nlri']:
                cmd = [
                    Path.EXABGP,
                    'decode',
                    '-f',
                    test.conf['family'],
                    '--%s' % test.conf['type'],
                    test.conf['packet'],
                ]
            else:
                raise ValueError(f'invalid message type: {message}')
            test.run(cmd)

        # success() waits for the process, so this settles the tests one by one
        for test in self.selected():
            self.display()
            success = test.result(test.success()) and success
            time.sleep(0.05)

        # NOTE(review): result() above leaves every selected test in SUCCESS or
        # FAIL, so selected() looks empty here and ``timeout`` has no effect;
        # kept as a safety net - confirm before relying on the timeout
        exit_time = time.time() + timeout
        running = set(self.selected())

        while running and time.time() < exit_time:
            self.display()
            for test in list(running):
                if not test.ready():
                    continue
                running.remove(test)
                success = test.result(test.success()) and success
                self.display()
            time.sleep(0.1)

        for test in running:
            test.fail()
            test.terminate()
            success = False

        self.display()
        return success


class ParsingTests(Tests):
    """Tests checking that ``exabgp validate`` accepts every configuration file.

    One test is created per ``*.conf`` file of ``Path.ETC``. ``listing`` is
    inherited from ``Tests``.
    """

    class Test(Record, Exec):
        """One parsing test; its ``Exec`` side holds the validate process."""

        def __init__(self, nick, name):
            Record.__init__(self, nick, name)
            Exec.__init__(self)

        def success(self):
            """Return True if the process exited with 0; report it otherwise."""
            self.collect()
            if self.code != 0:
                self.report('return code is not zero')
                return False
            return True

    def __init__(self):
        super().__init__(self.Test)

        for filename in Path.ALL_ETC:
            name = os.path.basename(filename).split('.')[0]
            test = self.new(name)
            test.conf['fname'] = filename
            test.files.append(filename)

    def dry(self):
        """Return, one per line, the commands ``run_selected`` would start."""
        result = []
        for test in self.selected():
            result.append(' '.join(['>', Path.EXABGP, 'validate', '-nrv', test.conf['fname']]))
        return '\n'.join(result)

    def run_selected(self, timeout):
        """Run every selected test and return True only if all of them passed."""
        success = True

        for test in self.selected():
            test.running()
            test.run([Path.EXABGP, 'validate', '-nrv', test.conf['fname']])
            time.sleep(0.005)

        time.sleep(0.02)

        # success() waits for the process, so this settles the tests one by one
        for test in self.selected():
            success = test.result(test.success()) and success
            time.sleep(0.005)

        # NOTE(review): result() above leaves every selected test in SUCCESS or
        # FAIL, so selected() looks empty here and ``timeout`` has no effect;
        # kept as a safety net - confirm before relying on the timeout
        exit_time = time.time() + timeout
        running = set(self.selected())

        while running and time.time() < exit_time:
            for test in list(running):
                if not test.ready():
                    continue
                running.remove(test)
                success = test.result(test.success()) and success
                self.display()
            time.sleep(0.1)

        for test in running:
            test.fail()
            test.terminate()
            success = False

        self.display()
        return success


def wait_status_to_exit_code(status):
    """Convert the value returned by ``os.system`` into a process exit code.

    On Unix ``os.system`` returns a wait status, not an exit code: an exit code
    of 1 is encoded as 256, which ``sys.exit`` would truncate to 0 (success).
    A normal exit gives the child's exit code, a death by signal 128 plus the
    signal number, and anything else 1.
    """
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return 128 + os.WTERMSIG(status)
    return 1


def add_test(subparser, name, tests, extra):
    """Register the ``name`` sub-command which drives the ``tests`` object.

    ``extra`` lists the optional flags to expose: 'dry', 'server', 'client',
    'list', 'short_list', 'edit', 'timeout' and 'port'. The function handling
    the parsed arguments is stored as the ``func`` default of the sub-parser.
    """
    sub = subparser.add_parser(name, help=f'run {name} test')
    if 'dry' in extra:
        sub.add_argument('--dry', help='show the action', action='store_true')
    if 'server' in extra:
        sub.add_argument('--server', help='start the server for a test', action='store_true')
    if 'client' in extra:
        sub.add_argument('--client', help='start the client for a test', action='store_true')
    if 'list' in extra:
        sub.add_argument('--list', help='list the files making a test', action='store_true')
    if 'short_list' in extra:
        sub.add_argument('--short-list', help='list test codes only (space separated)', action='store_true')
    if 'edit' in extra:
        sub.add_argument('--edit', help='edit the files making a test', action='store_true')
    if 'timeout' in extra:
        sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    if 'port' in extra:
        # NOTE(review): the parsed value is not read anywhere in this function
        sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.add_argument('test', help='name of the test to run', nargs='?', default=None)

    def func(parsed):
        """Perform the first requested action; most paths end with ``sys.exit``."""
        # the flag only exists when 'dry' was requested in extra
        dry = getattr(parsed, 'dry', False)

        if 'edit' in extra and parsed.edit:
            if not tests.enable_by_nick(parsed.test):
                sys.exit('no such test')
            test = tests.get_by_nick(parsed.test)
            if not test.files:
                sys.exit('no file for the test')
            editor = os.environ.get('EDITOR', 'vi')
            command = '%s %s' % (editor, ' '.join(test.files))
            flush(f'> {command}')
            if not dry:
                sys.exit(wait_status_to_exit_code(os.system(command)))
            return

        if 'list' in extra and parsed.list:
            tests.listing()
            return

        if 'short_list' in extra and parsed.short_list:
            tests.short_list()
            return

        if 'client' in extra and parsed.client:
            command = tests.client(parsed.test)
            flush(f'client> {command}')
            if not dry:
                sys.exit(wait_status_to_exit_code(os.system(command)))
            return

        if 'server' in extra and parsed.server:
            command = tests.server(parsed.test)
            flush(f'server> {command}')
            if not dry:
                sys.exit(wait_status_to_exit_code(os.system(command)))
            return

        # sub-commands without a --timeout option run with a timeout of 0
        if 'timeout' not in parsed:
            parsed.timeout = 0

        if parsed.test:
            if not tests.enable_by_nick(parsed.test):
                sys.exit(f'could not find test {parsed.test}')
        else:
            tests.enable_all()

        if dry:
            command = tests.dry()
            flush(command)
            sys.exit(0)

        passed = tests.run_selected(parsed.timeout)
        sys.stdout.write('\n')
        sys.exit(0 if passed else 1)

    sub.set_defaults(func=func)


if __name__ == '__main__':
    # check that the folders and programs exist before building the test lists
    Path.validate()
    Path.setup_env()

    decoding = DecodingTests()
    encoding = EncodingTests()
    parsing = ParsingTests()

    # options shared by the decoding and encoding sub-commands
    common = ['list', 'short_list', 'edit', 'dry', 'timeout', 'port']

    parser = argparse.ArgumentParser(description='The BGP swiss army knife of networking functional testing tool')
    subparser = parser.add_subparsers()

    add_test(subparser, 'decoding', decoding, list(common))
    add_test(subparser, 'encoding', encoding, common + ['server', 'client'])
    add_test(subparser, 'parsing', parsing, ['list', 'short_list', 'dry', 'edit'])

    parsed = parser.parse_args()
    if not vars(parsed):
        # no sub-command was given, so there is no func to call
        parser.print_help()
    else:
        parsed.func(parsed)
