#!/usr/bin/env python3
"""Collect and summarize evaluation results"""

import argparse
import os
import sys
import typing as T
from pathlib import Path

import numpy
import yaml as pyyaml

# Per-key converters applied to a parsed value before it is stored.
# 'duration_time:u' values are divided by 1000**3
# (presumably nanoseconds -> seconds; TODO confirm the unit of the input).
VALUE_CONVERSIONS = {
    'duration_time:u': lambda raw: raw / 1000 / 1000 / 1000,
}

# measurement key -> all values collected for it
Measurements = dict[str, list[float]]
# target name -> its measurements
Data = dict[str, Measurements]


def collect(result_dir) -> T.Optional[Data]:
    """Collect the measurements stored in the ``*.stats`` files of result_dir.

    Every ``*.stats`` file is read as ';'-separated text whose first two
    lines are skipped. In each remaining line, field 0 holds the value and
    field 2 holds the measurement key. The target name is the file name up to
    the first '.'. Values of files sharing a target are accumulated into one
    list per key.

    Lines without a key (including blank or truncated lines) and values
    reported as '<not supported>' are skipped. Values that cannot be parsed
    as float are stored as NaN. Keys listed in VALUE_CONVERSIONS are
    converted before being stored. Files yielding no values are reported on
    stderr and ignored.

    Returns:
        target -> measurement key -> list of values, or None if result_dir
        is not an existing directory.
    """
    if not result_dir or not os.path.isdir(result_dir):
        print(f'{result_dir} is not a directory', file=sys.stderr)
        return None

    data: Data = {}

    # sorted so targets and values come out in a deterministic order
    for result_file_path in sorted(Path(result_dir).iterdir()):
        if result_file_path.suffix != '.stats':
            continue
        target = result_file_path.name.split('.')[0]
        results: dict[str, float] = {}
        with open(result_file_path, 'r', encoding='utf-8') as result_file:
            # the first two lines are a header
            for line in result_file.readlines()[2:]:
                fields = line.split(';')
                # blank or truncated lines would raise IndexError on fields[2]
                if len(fields) < 3:
                    continue
                key, raw_value = fields[2], fields[0]

                if not key or raw_value == '<not supported>':
                    continue

                try:
                    value = float(raw_value)
                except ValueError as val_err:
                    print(
                        f'{val_err} occured during value conversion of {key}',
                        file=sys.stderr)
                    results[key] = numpy.nan
                    continue

                if key in VALUE_CONVERSIONS:
                    value = VALUE_CONVERSIONS[key](value)
                results[key] = value

        if not results:
            print(f'Warning: empty result file {result_file_path}',
                  file=sys.stderr)
            continue

        # merge this file's values into the lists of its target
        target_data = data.setdefault(target, {})
        for key, value in results.items():
            target_data.setdefault(key, []).append(value)

    return data


# values of one measurement lying beyond its whiskers
Outliers = list[float]
# stat name ('mean', 'median', ...) -> value; 'outliers' maps to a list
DescriptiveStats = dict[str, T.Union[float, Outliers]]
# measurement key -> descriptive stats of that measurement
TargetStats = dict[str, DescriptiveStats]
# target name -> stats of all its measurements
Stats = dict[str, TargetStats]


def calc_stats(data: Data) -> Stats:
    """Calculate and return descriptive stats of all measurements in data.

    For every target and measurement key the returned mapping holds these
    entries as plain ``float`` values, so pyyaml can dump them:
    ``mean``, ``std``, ``min``, ``max``, ``median``, ``upper_quartile``,
    ``lower_quartile``, ``lower_whisker`` and ``upper_whisker``.
    It also holds ``outliers``, the list of values outside the whiskers.

    The whiskers are the most extreme values still within 1.5 * IQR of the
    lower/upper quartile. Everything beyond them is an outlier.

    Measurement keys without any values are skipped. The value lists in
    ``data`` are left untouched; statistics are computed on sorted copies.
    """
    stats: Stats = {}
    for target, measurements in data.items():
        target_stats: TargetStats = {}
        stats[target] = target_stats
        for measure, raw_values in measurements.items():
            if not raw_values:
                # nothing to describe; min/max would raise IndexError
                continue
            # sort a copy: sorting in place would reorder the caller's data
            values = sorted(raw_values)

            measure_stats: DescriptiveStats = {}
            target_stats[measure] = measure_stats

            measure_stats['mean'] = numpy.mean(values)
            measure_stats['std'] = numpy.std(values)
            measure_stats['min'] = values[0]
            measure_stats['max'] = values[-1]
            measure_stats['median'] = numpy.median(values)
            upper_quartile = float(numpy.percentile(values, 75))
            measure_stats['upper_quartile'] = upper_quartile
            lower_quartile = float(numpy.percentile(values, 25))
            measure_stats['lower_quartile'] = lower_quartile
            iqr = upper_quartile - lower_quartile

            # lower whisker: smallest value >= Q1 - 1.5 * IQR
            lower_fence = lower_quartile - 1.5 * iqr
            i = 0
            while values[i] < lower_fence:
                i += 1
            measure_stats['lower_whisker'] = values[i]
            outliers = values[:i]

            # upper whisker: largest value <= Q3 + 1.5 * IQR
            upper_fence = upper_quartile + 1.5 * iqr
            i = len(values) - 1
            while values[i] > upper_fence:
                i -= 1
            measure_stats['upper_whisker'] = values[i]
            outliers += values[i + 1:]
            measure_stats['outliers'] = outliers

            # convert every scalar to a builtin float to easily dump it
            # using pyyaml (numpy scalars are not safe_dump-able)
            for key, value in measure_stats.items():
                if not isinstance(value, list):
                    measure_stats[key] = float(value)
    return stats


def summarize(stats: Stats,
              keys=None,
              desc_stats=None,
              format_str=None) -> bool:
    """Print a summary for each selected key of the collected stats.

    Args:
        stats: target name -> measurement key -> descriptive stats.
        keys: measurement keys to summarize. When empty or None, the first
            key of the first target whose name contains 'duration_time' is
            used.
        desc_stats: names of the descriptive stats to print per key.
            Defaults to every stat available for the key in the first target
            that has it.
        format_str: optional ``str.format`` template. When given, it is filled
            with one target's stats for the key (plus ``target``) and printed
            once per target, replacing the per-stat listing.

    Returns:
        True on success. False if there is nothing to summarize, no default
        key could be found, or a requested key exists in no target. Targets
        lacking a requested key are reported on stderr and skipped.
    """
    if not stats:
        print('no data to summarize', file=sys.stderr)
        return False

    # find duration_time if no specific keys are given
    if not keys:
        first_target_stats = next(iter(stats.values()))
        keys = [key for key in first_target_stats
                if 'duration_time' in key][:1]
        if not keys:
            # an assert would vanish under `python -O` and crash on None later
            print('no duration_time measurement found, please specify keys',
                  file=sys.stderr)
            return False

    success = True
    for key in keys:
        print(f'{key}:')
        # only targets that actually measured this key can be summarized
        key_stats = {}
        for target, target_stats in stats.items():
            if key in target_stats:
                key_stats[target] = target_stats[key]
            else:
                print(f'Warning: no {key} measurement for {target}',
                      file=sys.stderr)
        if not key_stats:
            success = False
            continue

        if format_str:
            for target, measure_stats in key_stats.items():
                print(format_str.format(**measure_stats, target=target))
            continue

        dstats = desc_stats or next(iter(key_stats.values())).keys()
        for stat in dstats:
            print(f'{stat}:')
            for target, measure_stats in key_stats.items():
                print(f'\t{target}-{stat}: {measure_stats[stat]}')

    return success


def collect_and_summarize(args: argparse.Namespace) -> int:
    """Collect data, compute its statistics and print or dump them.

    Args:
        args: parsed command line. Reads ``result_dir``,
            ``implementations``, ``yaml``, ``keys``, ``desc_stats`` and
            ``format``.

    Returns:
        The process exit code: 0 on success, 1 on failure.
    """
    data = collect(args.result_dir)
    if not data:
        print('No data to collect', file=sys.stderr)
        return 1

    stats = calc_stats(data)
    if not stats:
        print('No stats calculated', file=sys.stderr)
        return 1

    if args.implementations:
        stats = {v: s for v, s in stats.items() if v in args.implementations}
        if not stats:
            # otherwise --yaml would silently dump an empty mapping
            print(f'No stats for the selected implementations '
                  f'{args.implementations}', file=sys.stderr)
            return 1

    if args.yaml:
        print(pyyaml.safe_dump(stats))
        return 0

    if not summarize(stats=stats,
                     keys=args.keys,
                     desc_stats=args.desc_stats,
                     format_str=args.format):
        print(f'Failed to summarize {args.desc_stats} of {args.keys} in stats',
              file=sys.stderr)
        return 1

    return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-k',
                        '--keys',
                        help='keys to summarize (default: the first '
                        'duration_time key)',
                        nargs='*')
    # selects which descriptive stats are printed; all of them by default
    parser.add_argument('-s',
                        '--desc-stats',
                        help='descriptive stats to print, e.g. mean median '
                        '(default: all)',
                        nargs='*')
    parser.add_argument("-i",
                        "--implementations",
                        help="implementations to summarize (default: all)",
                        nargs='+')
    parser.add_argument("-f",
                        "--format",
                        help="Format populated by the available stats",
                        type=str)
    parser.add_argument('--yaml',
                        help='dump statistics as yaml',
                        action='store_true')
    parser.add_argument('result_dir',
                        help='directory containing the results to summarize')

    _args = parser.parse_args()

    print('### Summary ###')
    sys.exit(collect_and_summarize(_args))