#!/usr/bin/python3.6

from prometheus_client import start_http_server, Summary, Gauge, Counter
from dateutil.parser import parse

import sys
import time
import os
import re
import getopt
from datetime import datetime
from urllib.parse import urlparse
import logging

import xml.etree.ElementTree as ET

import subprocess
from collections import Counter
from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily, REGISTRY
from arc.utils import config


class LogReader:
    """Log reader helper that follows one log file across calls and logrotate.

    Each iteration yields only the lines appended since the previous
    iteration finished.  When the file at ``self.file`` has become shorter
    than the current read position (e.g. logrotate moved the old file away
    and a new, smaller one was created, or the file was truncated), the file
    is reopened and read from the beginning.
    """

    def __init__(self, file):
        self.file = file  # path of the log file to follow
        self.f = None     # open text handle, or None until opened successfully
        # Alt. self.f = open(self.file, 'r') to get failure on first open

    def _rotated(self):
        """Return True if the file on disk is shorter than our read position.

        If the position or the size cannot be determined (e.g. the file is
        momentarily missing during rotation), the current handle is kept.
        """
        # NOTE(review): text-mode tell() returns an opaque cookie; it equals
        # the byte offset when the decoder holds no pending state, which is
        # the normal case after whole lines were read -- confirm for
        # unusual locale encodings.
        try:
            return self.f.tell() > os.path.getsize(self.file)
        except OSError:
            return False

    def __iter__(self):
        """Yield all lines currently in self.file, starting where the last
        iteration left off.

        If the file has been rotated/overwritten it is reopened and read
        from the beginning.  If it cannot be opened, nothing is yielded and
        the open is retried on the next iteration.
        """
        if self.f is not None and self._rotated():
            self.f.close()
            self.f = None

        if self.f is None:
            try:
                # errors='replace': an undecodable byte becomes U+FFFD instead
                # of raising UnicodeError and skipping the rest of the chunk.
                self.f = open(self.file, 'r', errors='replace')
            except OSError:  # Limit to FileNotFoundError, PermissionError?
                # Leave the iterator and assume that things will improve
                # next time.
                return

        yield from self.f


class DTRCollector(object):
    """Custom collector exposing a snapshot of DTRs listed in ``dtrState``.

    A custom collector (instead of a plain Gauge) is used so that the counts
    are rebuilt from scratch on every collection and stale label sets vanish.
    """

    def collect(self):
        """Yield one ``arc_transfers`` gauge family, one sample per (state, share, host)."""
        tally = Counter()
        with open(dtrState, 'r') as state_file:
            for raw in state_file:
                # DTR-ID STATE PRIO SHARE [URL [HOST]]
                cols = raw.split()
                ncols = len(cols)
                if ncols < 4:
                    continue
                # 6 columns carry an explicit HOST; 5 columns are reported
                # as 'local'; any other width gets an empty host label.
                if ncols == 6:
                    where = cols[5]
                elif ncols == 5:
                    where = 'local'
                else:
                    where = ""
                tally[(cols[1], cols[3], where)] += 1

        family = GaugeMetricFamily('arc_transfers', 'ARC Transfers',
                                   labels=['state', 'share', 'host'])
        for (state, share, where), count in tally.items():
            family.add_metric([state, share, where], count)
        yield family


def getARCCoreCounts():
    """Update the ``arc_corecount`` gauge from the A-REX ``info.xml`` file.

    The file named by the module-level ``infoxml`` is read with
    ``sudo cat``.  Every GLUE2 ``ComputingShare`` (queue) may carry several
    ``OtherInfo`` entries of the form ``CoreCount=<STATE>=<count>``, e.g.::

        <OtherInfo>CoreCount=INLRMS:R=431</OtherInfo>
        <OtherInfo>CoreCount=PREPARING=2543</OtherInfo>
        <OtherInfo>CoreCount=INLRMS:E=1</OtherInfo>

    Each such entry sets ``arc_corecount{share=<Name>, state=<STATE>}`` to
    ``<count>``.  ``OtherInfo`` entries without the ``CoreCount=`` prefix
    are ignored.  All errors are logged, never raised.

    NOTE(review): the gauge is never cleared here, so a (share, state) pair
    that disappears from info.xml keeps its last value -- confirm whether
    that is intended.
    """
    namespace = {'glue': 'http://schemas.ogf.org/glue/2009/03/spec_2.0_r1'}
    prefix = 'CoreCount='
    try:
        # Run the cat command using sudo.  stdout=/stderr=PIPE plus
        # universal_newlines= is used instead of capture_output=/text=,
        # which only exist from Python 3.7 (the shebang names python3.6).
        result = subprocess.run(['sudo', 'cat', infoxml],
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                universal_newlines=True)
        if result.returncode != 0:
            logging.error(f"Failed to read the file: {result.stderr.strip()}")
            return

        root = ET.fromstring(result.stdout)

        # findall() returns a (possibly empty) list, never None.
        computing_shares = root.findall('.//glue:ComputingShare', namespace)
        if not computing_shares:
            logging.warning("No ComputingShare tag found in XML.")
            return

        for computing_share in computing_shares:
            name = computing_share.find('glue:Name', namespace)
            if name is None or not name.text:
                # Without a share name there is no label to report under;
                # skip this share but keep processing the others.
                logging.warning("Skipping ComputingShare without a Name.")
                continue
            share = name.text
            for other_info in computing_share.findall('.//glue:OtherInfo', namespace):
                text = (other_info.text or '').strip()
                # OtherInfo is a generic field; only CoreCount entries matter here.
                if not text.startswith(prefix):
                    continue
                state, sep, count = text[len(prefix):].partition('=')
                if not sep:
                    logging.error(f"Error fetching information about corecount per state: malformed OtherInfo {text!r}")
                    continue
                try:
                    arc_corecount.labels(share, state).set(int(count))
                except ValueError as e:
                    logging.error(f"Error parsing count value: {count} - {str(e)}")

    except FileNotFoundError as e:
        logging.error(f"File not found: {str(e)}")
    except ET.ParseError as e:
        logging.error(f"Error parsing XML file: {str(e)}")
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")




def getCoreCounts():
    """Useful metrics for pledge monitoring.

    Rebuilds the module-level ``corecount`` gauge from ``sacct`` output for
    jobs seen during the last hour.  For every job, the gauge labelled
    (account, user, partition, state) is increased by its allocated cores,
    or by its requested cores while nothing is allocated yet (e.g. PENDING).
    Job steps are skipped so each job is counted once.  If ``sacct`` cannot
    be run or fails, the error is logged and the gauge is left empty.
    """
    cmd = ['sacct', '-S', 'now-1hour', '-a',
           '--format=JobID,account,user,partition,State,alloccpus,ReqCPUS,node',
           '-n', '-P']
    try:
        # stdout=/stderr=PIPE + universal_newlines= instead of
        # capture_output=/text= (Python 3.7+); the shebang names python3.6.
        result = subprocess.run(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                universal_newlines=True)
    except OSError as e:
        logging.error(f"Failed to run sacct: {e}")
        result = None

    # Reset the gauge each time I scrape the metrics to get fresh numbers
    # of RUNNING, PENDING etc
    corecount.clear()

    if result is None:
        return
    if result.returncode != 0:
        logging.error(f"sacct failed: {result.stderr.strip()}")
        return

    for line in result.stdout.splitlines():
        parts = line.split('|')
        if len(parts) < 7:
            continue
        job_id, account, user, partition, state, alloc_cores, req_cores = parts[:7]
        # Job steps (<jobid>.batch, <jobid>.extern, <jobid>.<n>) repeat their
        # job's allocation; filtering on JobID (rather than grepping the whole
        # line for "batch") keeps jobs whose account/user/partition merely
        # contains "batch".
        if '.' in job_id:
            continue
        try:
            cores = int(alloc_cores)
            if cores == 0:
                # If job is pending, no cores yet allocated. Use the
                # requested cores instead.
                cores = int(req_cores)
        except ValueError:
            continue
        if 'CANCELLED' in state:
            # Typically CANCELLED by <uid>
            state = state.split('by')[0].strip()
        corecount.labels(account, user, partition, state).inc(cores)


def arcJobStats():
    """Set the ``arc_jobs`` gauge from the output of ``arcctl job stats``.

    Every output line with at least two whitespace-separated tokens sets
    ``jobs{state=<first token minus its last character>}`` to the second
    token (presumably lines look like ``STATE: count``).
    """
    cmd = ['arcctl', '-d', 'CRITICAL', 'job', 'stats']
    raw = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    for text_line in raw.decode().split('\n'):
        tokens = text_line.split()
        if len(tokens) <= 1:
            continue
        state_label = tokens[0][:-1]  # drop the trailing character (':')
        jobs.labels(state_label).set(tokens[1])


def arexLogStats():
    """Count job failures and state transitions found in new arex.log lines.

    Lines with fewer than 8 whitespace-separated words are ignored.  A line
    containing "Job failure detected" increments ``jstats{state="FAILED"}``;
    a line whose 7th word is ``State:`` increments ``jstats`` for its 8th
    word with ':' stripped from both ends.
    """
    for entry in arexlog:
        words = entry.split()
        if len(words) < 8:
            continue
        if "Job failure detected" in entry:
            jstats.labels("FAILED").inc()
        if words[6] == 'State:':
            jstats.labels(words[7].strip(':')).inc()


def heartBeat():
    """Set ``gmtime`` to the age in seconds of ``<controldir>/gm-heartbeat``."""
    heartbeat_file = ctrldir + '/gm-heartbeat'
    last_beat = os.stat(heartbeat_file).st_mtime
    gmtime.set(time.time() - last_beat)


def dtrStats():
    """Increment ``dstats`` once per new data-delivery log line whose 12th
    whitespace-separated field is ``TRANSFERRED``.

    Lines with fewer than 12 fields are ignored.
    """
    # delivery log
    for line in deliveryLog:
        fields = line.split()
        # At least 12 fields are needed to read fields[11]; the old
        # "< 11" guard let 11-field lines through and raised IndexError.
        if len(fields) > 11 and fields[11] == 'TRANSFERRED':
            dstats.inc()


def cdtrStats():
    """Accumulate central data-staging statistics from new log lines.

    Lines come from the module-level ``cdeliveryLog`` reader.  A DTR is
    tracked in the module-level ``dtrs`` dict (keyed by the DTR id token of
    the line) from its "Scheduler received new DTR" line until its
    "Finished successfully" line.  On success, ``cdstats['files']`` is
    incremented and, if a "Transfer finished" size was seen,
    ``cdstats['size']`` is increased by it.  Both are labelled by transfer
    type (upload/download), whether the file was found cached, and the
    remote domain.  The DTR record is dropped from ``dtrs`` once finished,
    even if its statistics could not be computed.

    Depends on the data-staging log level being at least INFO.
    """
    for line in cdeliveryLog:
        # [<date> <time>] [<loglevel>] [<pid>/<num>] DTR <dtr-id>: <msg>
        # <date>: YEAR-MM-DD
        # <time>: hh:mm:ss
        # <loglevel>: ERROR | WARNING | INFO | VERBOSE | DEBUG
        # <dtr-id>: <hex4>...<hex4>:
        # <hex4>: 4 lower case hex letters ([0-9a-f])
        j = line.split(maxsplit=6)
        if len(j) <= 6:
            continue

        dtr = j[5]
        msg = j[6]

        # Rest will do matches on msg with these definitions:
        # <rurl>: remote url (aka any URL except file:)
        # <surl>: source url, either <rurl> or file:/<path>
        # <durl>: destination url, either <rurl> or file:/<path>
        # One of <surl> and <durl> is file:
        # <dtr>: long dtr id, <hex8>-<hex4>+<hex4><hex8>
        # <hex8>: <hex4><hex4>
        # <share>: (<vo>:<...> | "_default") - (<direction>)
        # <direction>: "upload" | "download"
        # <prio>: number

        # Scheduler received new DTR <dtr> with source: <surl> destination: <durl> assigned to transfer share <share> with priority <prio>
        if msg.startswith("Scheduler received new DTR "):
            # A repeated announcement for the same id replaces the old record.
            msg = msg.split(maxsplit=10)
            d = dtrs[dtr] = {}
            if len(msg) > 9:
                # "<date> <time>" without the surrounding brackets
                d['start']  = j[0][1:] + ' ' + j[1][:-1]
                d['src']    = msg[7]
                d['dst']    = msg[9]
                d['cached'] = False
            continue

        # Only process already started DTRs
        if dtr not in dtrs:
            continue

        d = dtrs[dtr]

        # Delivery received new DTR <dtrid> with source: <surl> destination: <durl>
        if msg.startswith("Delivery received new DTR "):
            msg = msg.split(maxsplit=8)
            if len(msg) > 7:
                d['tsrc'] = msg[7]

        # DataDelivery:   <num> s:  6738944.0 kB   17289.4 kB/s   17549.3 kB/s    . . .
        elif msg.startswith('DataDelivery:  ') and ' s: ' in msg:
            msg = msg.split(maxsplit=6)
            if len(msg) > 5 and msg[1].isdigit():
                d['time'] = msg[1]
                d['rate'] = msg[5]

        # Transfer finished: <num> bytes transferred
        # Transfer finished: <num> bytes transferred : checksum adler32:<hex>
        elif msg.startswith("Transfer finished: "):
            msg = msg.split(maxsplit=3)
            try:
                d['size'] = int(msg[2])
            except (ValueError, IndexError):
                d['size'] = 0

        # File <rurl> is cached (<path>) - checking permissions
        elif msg.startswith("File ") and " is cached (/" in msg:
            d['cached'] = True

        # Finished successfully
        elif msg.startswith("Finished successfully"):
            try:
                d['end'] = j[0][1:] + ' ' + j[1][:-1]
                s = parse(d['start'])
                e = parse(d['end'])
                # NOTE(review): duration is computed but not exported yet.
                d['duration'] = (e - s).total_seconds()
                if d['src'].startswith("file:/"):
                    d['type'] = 'upload'
                    t = urlparse(d['dst'])
                else:
                    d['type'] = 'download'
                    # Prefer the source reported by Delivery when it was seen.
                    t = urlparse(d['tsrc'] if 'tsrc' in d else d['src'])
                # Host name minus its first label, with any port removed.
                d['domain'] = '.'.join(t.netloc.split('.')[1:]).split(":")[0]
                dm = d['domain']
                cdstats['files'].labels(d['type'], d['cached'], dm).inc()
                if 'size' in d:
                    cdstats['size'].labels(d['type'], d['cached'], dm).inc(d['size'])
            except Exception as err:
                # Best effort: an incomplete record (e.g. a start line that
                # was too short to parse) must not stop log processing.
                logging.debug(f"Skipping statistics for DTR {dtr}: {err!r}")
            finally:
                # The DTR is done; drop it even when its record was
                # incomplete, otherwise it would stay in dtrs forever.
                dtrs.pop(dtr, None)

        # Returning to generator
        elif msg.startswith("Returning to generator"):
            # NOTE(review): DTRs that never log "Finished successfully" (e.g.
            # failed transfers) stay in dtrs; dropping them here would stop
            # that leak, but first confirm this line never precedes
            # "Finished successfully" for the same DTR.
            pass


def process_request():
    """Run one collection cycle, refreshing every enabled metric source.

    A source is enabled when its module-level metric object (``jobs``,
    ``jstats``, ``gmtime``, ``dstats``, ``cdstats``, ``corecount``) is truthy;
    ``getARCCoreCounts`` always runs.  An exception from one source is
    logged with its traceback and does not prevent the remaining sources
    from running, nor does it propagate to the caller's polling loop.
    """
    steps = (
        (jobs, arcJobStats),
        (jstats, arexLogStats),
        (gmtime, heartBeat),
        (dstats, dtrStats),
        (cdstats, cdtrStats),
        (True, getARCCoreCounts),
        (corecount, getCoreCounts),
    )
    for enabled, step in steps:
        if not enabled:
            continue
        try:
            step()
        except Exception:
            logging.exception(f"Error while running {step.__name__}")


# def usage():
#     print("""
#     usage: arc-exporter.py [--port <port>] [--refresh <seconds>]
#             [-f space_command] [ -c <arex_config_file> | <dir1> [<dir2> [...]] ]
#     --help - This help
#
# """)


logging.basicConfig(level=logging.DEBUG)

# Parse arc.conf; the option lookups below read from config_dict.
config.parse_arc_conf()
config_dict = config.get_config_dict()


def _arc_conf_value(section, option, default):
    """Return arc.conf ``[section] option``, or *default* if either is missing."""
    try:
        return config_dict[section][option]
    except KeyError:
        return default


# XXX - can we actually accept not having controldir defined?
ctrldir = _arc_conf_value('arex', 'controldir', '')

# arc.conf might set a *specific* path to dtr.state
# https://www.nordugrid.org/arc/arc7/admins/reference.html#statefile
# Otherwise it is looked up in the control directory.
dtrState = _arc_conf_value('arex/data-staging', 'statefile', ctrldir + '/dtr.state')

infoxml = ctrldir + '/info.xml'

# Metric objects for the optional data sources.  Each stays None unless its
# source is configured below; process_request() only runs a source whose
# object is truthy.
jobs = jstats = gmtime = dstats = cdstats = corecount = None

# Number of cores in the different ARC states and per share
# (filled by getARCCoreCounts() from info.xml).
arc_corecount = Gauge('arc_corecount','Corecount per computing share and job state',['share','state'])

# Are we on a CE?  The existence of the dtr.state file is used as the test.
if os.path.exists(dtrState):
    # Per-scrape snapshot of the transfers listed in dtr.state.
    REGISTRY.register(DTRCollector())
    # Job state snapshot from "arcctl job stats" (see arcJobStats()).
    jobs = Gauge('arc_jobs', 'ARC Jobs', ['state'])
    # Job state counter from arex.log; only enabled when [arex] logfile is
    # set in arc.conf (otherwise jstats stays None and arexlog is undefined).
    try:
        arexLogFile = config_dict['arex']['logfile']
        arexlog = LogReader(arexLogFile)
        jstats = Gauge('arc_job_count', 'ARC Jobs', ['state'])
    except KeyError:
        pass
    # Age of the gm-heartbeat file time stamp (see heartBeat()).
    gmtime = Gauge('arc_arex_heartbeat', 'AREX Heartbeart')

# Data delivery service log; only enabled when
# [datadelivery-service] logfile is set.
try:
    deliveryLogFile = config_dict['datadelivery-service']['logfile']
    deliveryLog = LogReader(deliveryLogFile)
    dstats = Gauge('arc_delivery_count', 'ARC Data Delivery count')
except KeyError:
    pass

# Central data staging log; only enabled when [arex/data-staging] logfile is
# set.  dtrs holds the per-DTR records that cdtrStats() builds from the log.
try:
    cdeliveryLogFile = config_dict['arex/data-staging']['logfile']
    cdeliveryLog = LogReader(cdeliveryLogFile)
    dtrs = {}
    cdstats = {}
    cdstats['files'] = Gauge('arc_central_staging_files', 'ARC Central Data Staging Files', ['type', 'cached', 'domain'])
    cdstats['size'] = Gauge('arc_central_staging_size', 'ARC Central Data Staging Size', ['type', 'cached', 'domain'])
except KeyError:
    pass

# Are we using slurm?  A missing [lrms] block or "lrms" option (e.g. on a
# host that only runs the data delivery service) means "no" instead of a
# KeyError at start-up.
try:
    _lrms = config_dict['lrms']['lrms']
except KeyError:
    _lrms = ''
if _lrms.startswith('slurm'):
    # Cores per (account, user, partition, state); rebuilt each cycle by
    # getCoreCounts(), which calls corecount.clear().
    corecount = Gauge('arc_slurm_corecount', 'Corecount per slurm user and job state', ['account', 'user', 'partition', 'state'])
    if not hasattr(corecount, 'clear'):
        # The installed prometheus_client's Gauge has no clear(); bind an
        # equivalent method to this instance.
        # https://stackoverflow.com/questions/1015307/how-to-bind-an-unbound-method-without-calling-it
        # XXX - (almost) copy of Gauge.clear from version 0.10
        def _tf(self):
            with self._lock:
                self._metrics = {}
        setattr(corecount, 'clear', _tf.__get__(corecount))
        del _tf



if __name__ == '__main__':

    # Only "--port <n>" is understood; any other command line keeps the
    # default port 9101.
    cli = sys.argv[1:]
    if len(cli) == 2 and cli[0] == "--port":
        port = int(cli[1])
    else:
        port = 9101

    # Start the HTTP server that exposes the metrics.
    start_http_server(port)
    # Refresh all metric sources every 10 seconds, forever.
    while True:
        process_request()
        time.sleep(10)
