Skip to content
Snippets Groups Projects
Commit 27784c3f authored by Hans-Peter Deifel's avatar Hans-Peter Deifel
Browse files

Begin wta benchmark script

parent 041cd53f
No related branches found
No related tags found
1 merge request!1WIP: WTA Benchmarks
#!/usr/bin/env python3
import argparse
import os
import subprocess
import json
import numpy as np
import scipy.stats as st
samples = 5
def coalg_file(states, monoid, symbols, zero_frequency, i):
return "bench/wta_%s_%s_%s_%s_%d" % (monoid, symbols, zero_frequency,
states, i)
def generate(args):
generator = args.generator
states = args.states
monoid = args.monoid
symbols = args.symbols
zero_frequency = args.zero_frequency
os.makedirs("bench", exist_ok=True)
for i in range(0, samples):
f = coalg_file(states, monoid, symbols, zero_frequency,
i) + ".coalgebra"
if os.path.exists(f):
continue
cmd = [
generator, "--states", states, "--monoid", monoid, "--symbols",
symbols, "--zero-frequency", zero_frequency
]
subprocess.run(cmd, stdout=open(f, "w+"))
def run_one(args, i):
copar = args.copar
states = args.states
monoid = args.monoid
symbols = args.symbols
zero_frequency = args.zero_frequency
f = coalg_file(states, monoid, symbols, zero_frequency, i) + ".coalgebra"
copar_args = [copar, 'refine', '--stats-json', f]
out = subprocess.run(
copar_args,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
check=True)
stats = json.loads(out.stderr.decode('utf-8'))
stats['monoid'] = monoid
stats['symbols'] = symbols
stats['zero-freq'] = zero_frequency
stats['i'] = i
return stats
def run_one_simple(args, i):
copar = args.copar
states = args.states
monoid = args.monoid
symbols = args.symbols
zero_frequency = args.zero_frequency
f = coalg_file(states, monoid, symbols, zero_frequency, i) + ".coalgebra"
copar_args = [copar, 'refine', f]
subprocess.run(
copar_args,
stdout=subprocess.DEVNULL,
check=True)
def confidence(vals):
"""Compute the 95% confidence intervall (CI) for the mean with the student
distribution.
Returns a tuple of (mean, lower, upper), where lower and upper are the bounds
of the CI"""
# For a larger sample size (> 30), we could also use the normal
# distribution.
#
# This code is taken from
# https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data/34474255#34474255
mean = np.mean(vals)
ci = st.t.interval(
0.95, len(vals) - 1, loc=np.mean(vals), scale=st.sem(vals))
return (mean, ci[0], ci[1])
def stddev(vals):
"""Compute the mean and standard deviation intervall on a sample.
This uses the corrected sample standard deviation."""
# see also:
# https://en.wikipedia.org/wiki/Standard_deviation#Corrected_sample_standard_deviation
mean = np.mean(vals)
std = np.std(vals, ddof=1)
return (mean, std)
def print_row(d, header, stddev):
keys = [
'i', 'states', 'edges', 'initial-partition-size',
'final-partition-size', 'explicit-final-partition-size',
'size1-skipped'
]
for k in [
'overall-duration', 'parse-duration', 'algorithm-duration',
'initialize-duration', 'refine-duration'
]:
keys.append(k)
if stddev:
keys.append(k + '-stddev')
values = [d[k] for k in keys]
if header:
print('\t'.join(keys))
else:
print('\t'.join(str(x) for x in values))
def run(args):
results = [run_one(args, i) for i in range(0, samples)]
def confidencekey(vals, k):
return confidence(list(float(x[k]) for x in vals))
def stddevkey(vals, k):
return stddev(list(float(x[k]) for x in vals))
combined = results[0].copy()
combined['i'] = samples
for k in [
'overall-duration', 'parse-duration', 'initialize-duration',
'refine-duration', 'algorithm-duration'
]:
ci = stddevkey(results, k)
combined[k] = str(ci[0])
combined[k + '-stddev'] = str(ci[1])
if args.indiv:
if args.header:
print_row(combined, True, stddev=False)
for res in results:
print_row(res, False, stddev=False)
else:
if args.header:
print_row(combined, True, stddev=args.stddev)
print_row(combined, False, stddev=args.stddev)
def test(args, states):
print("Trying %d..." % states)
args.states = str(states)
generate(args)
for i in range(0, samples):
try:
run_one_simple(args, i)
except subprocess.CalledProcessError:
return False
return True
def find_bad(args, good):
states = good*2
if test(args, states):
return find_bad(args, states)
else:
return (good, states)
def bisect_states(args):
states = args.start_states
good = args.good or 0
bad = args.bad
if bad is None:
if good and states < good:
states = good+1
if test(args, states):
(good, bad) = find_bad(args, states)
else:
bad = states
while good+1 < bad:
states = good + (bad-good)//2
if test(args, states):
good = states
else:
bad = states
print("First bad state count: %d" % bad)
def main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(required=True)
gen_parser = subparsers.add_parser('generate')
gen_parser.add_argument('generator')
gen_parser.add_argument('--states', required=True)
gen_parser.add_argument('--monoid', required=True)
gen_parser.add_argument('--symbols', required=True)
gen_parser.add_argument('--zero-frequency', required=True)
gen_parser.set_defaults(func=generate)
run_parser = subparsers.add_parser('run')
run_parser.add_argument('copar')
run_parser.add_argument('--states', required=True)
run_parser.add_argument('--monoid', required=True)
run_parser.add_argument('--symbols', required=True)
run_parser.add_argument('--zero-frequency', required=True)
run_parser.add_argument(
'--stddev', action='store_true', help="report stddev for timings")
run_parser.add_argument(
'--indiv', action='store_true', help="report individual samples")
run_parser.add_argument(
'--header', action='store_true', help="Print header row for table")
run_parser.set_defaults(func=run)
bisect_parser = subparsers.add_parser('bisect')
bisect_parser.add_argument('generator')
bisect_parser.add_argument('copar')
bisect_parser.add_argument('--monoid', required=True)
bisect_parser.add_argument('--symbols', required=True)
bisect_parser.add_argument('--zero-frequency', required=True)
bisect_parser.add_argument('--start-states', type=int, default=50)
bisect_parser.add_argument('--good', type=int)
bisect_parser.add_argument('--bad', type=int)
bisect_parser.set_defaults(func=bisect_states)
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment