#!/usr/bin/python3
#
# Idea:

#   Start with a single variable. While the formula is still below the
#   target size, replace a random variable by a random
#   connective. Once the target size has been reached, replace all
#   unguarded variables by propositional atoms. Replace all remaining
#   variables by either an allowed fixpoint variable or a
#   propositional atom.

import random
import string
import argparse
import os


# Placeholder (Variable) nodes of the formula under construction that have not
# been expanded or finalized yet; main() resets this for every formula.
variables = list()
# Output syntax consulted by the __str__ methods; main() sets it before printing.
syntax = "cool"


def _gensym():
    i = 0
    while True:
        i = i + 1
        yield "G%d" % i


gensym = iter(_gensym())



class Connective:
    """Marker base class for formula connectives; it defines no behaviour itself."""



class Propositional(Connective):
    """Marker base class for the boolean connectives (conjunction, disjunction)."""



class And(Propositional):
    """Binary conjunction of two subformulas.

    Printed with ``/\\`` in TATL syntax and with ``&`` in every other syntax.
    """

    def __init__(self):
        # Both operands start out as undecided placeholders.
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        left = str(self._left)
        right = str(self._right)
        if syntax == 'tatl':
            # Raw string: the backslash is literal output, not an escape sequence.
            return r"((%s) /\ (%s))" % (left, right)
        return "((%s) & (%s))" % (left, right)

    def replace(self, what, withw):
        """Put ``withw`` into the child slot that currently holds ``what``."""
        if what is self._left:
            self._left = withw
        else:
            assert self._right is what
            self._right = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands; conjunction does not change guardedness."""
        self._left.finalize(atoms, guarded, unguarded, fixpoint)
        self._right.finalize(atoms, guarded, unguarded, fixpoint)



class Or(Propositional):
    """Binary disjunction of two subformulas.

    Printed with ``\\/`` in TATL syntax and with ``|`` in every other syntax.
    """

    def __init__(self):
        # Both operands start out as undecided placeholders.
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        left = str(self._left)
        right = str(self._right)
        if syntax == 'tatl':
            # Raw string: the backslash is literal output, not an escape sequence.
            return r"((%s) \/ (%s))" % (left, right)
        return "((%s) | (%s))" % (left, right)

    def replace(self, what, withw):
        """Put ``withw`` into the child slot that currently holds ``what``."""
        if what is self._left:
            self._left = withw
        else:
            assert self._right is what
            self._right = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands; disjunction does not change guardedness."""
        self._left.finalize(atoms, guarded, unguarded, fixpoint)
        self._right.finalize(atoms, guarded, unguarded, fixpoint)



class Modal(Connective):
    """Marker base class for the one-step modal operators (box and diamond)."""


class Box(Modal):
    """Universal one-step modality: ``AX`` in CTL syntax, ``[]`` otherwise."""

    def __init__(self):
        # The operand begins life as an undecided placeholder.
        self._subformula = Variable(self)

    def __str__(self):
        inner = str(self._subformula)
        template = "AX (%s)" if syntax == 'ctl' else "[] (%s)"
        return template % (inner,)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand; below a modality every bound variable is guarded."""
        now_guarded = guarded + unguarded
        self._subformula.finalize(atoms, now_guarded, [], fixpoint)



class Diamond(Modal):
    """Existential one-step modality: ``EX`` in CTL syntax, ``<>`` otherwise."""

    def __init__(self):
        # The operand begins life as an undecided placeholder.
        self._subformula = Variable(self)

    def __str__(self):
        inner = str(self._subformula)
        template = "EX (%s)" if syntax == 'ctl' else "<> (%s)"
        return template % (inner,)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand; below a modality every bound variable is guarded."""
        now_guarded = guarded + unguarded
        self._subformula.finalize(atoms, now_guarded, [], fixpoint)



class Fixpoint(Connective):
    """Marker base class for the fixpoint binders (mu and nu)."""

class Mu(Fixpoint):
    """Least fixpoint binder, printed ``μ`` in COOL syntax and ``mu`` otherwise."""

    def __init__(self):
        self._subformula = Variable(self)
        # Fresh name for the variable bound by this node.
        self._var = next(gensym)

    def __str__(self):
        body = str(self._subformula)
        template = "(μ %s . (%s))" if syntax == 'cool' else "(mu %s . (%s))"
        return template % (self._var, body)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the body with this node's variable offered as unguarded.

        Variables bound by an enclosing ``nu`` are dropped, so a mu body never
        refers to them (presumably to keep formulas alternation-free, cf. the
        'afmu' logic).
        """
        crossing = fixpoint == 'nu'
        outer_guarded = [] if crossing else guarded
        outer_unguarded = [] if crossing else unguarded
        self._subformula.finalize(atoms, outer_guarded, outer_unguarded + [self._var], 'mu')



class Nu(Fixpoint):
    """Greatest fixpoint binder, printed ``ν`` in COOL syntax and ``nu`` otherwise."""

    def __init__(self):
        self._subformula = Variable(self)
        # Fresh name for the variable bound by this node.
        self._var = next(gensym)

    def __str__(self):
        body = str(self._subformula)
        template = "(ν %s . (%s))" if syntax == 'cool' else "(nu %s . (%s))"
        return template % (self._var, body)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the body with this node's variable offered as unguarded.

        Variables bound by an enclosing ``mu`` are dropped, so a nu body never
        refers to them (presumably to keep formulas alternation-free, cf. the
        'afmu' logic).
        """
        crossing = fixpoint == 'mu'
        outer_guarded = [] if crossing else guarded
        outer_unguarded = [] if crossing else unguarded
        self._subformula.finalize(atoms, outer_guarded, outer_unguarded + [self._var], 'nu')



class CTL(Connective):
    """Marker base class for the CTL temporal operators (AG, AF, EU, ...)."""



class AG(CTL):
    """CTL operator printed as ``AG (φ)``."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        return "AG (%s)" % (str(self._subformula),)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        self._subformula.finalize(atoms, [*guarded, *unguarded], [], fixpoint)



class AF(CTL):
    """CTL operator printed as ``AF (φ)``."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        return "AF (%s)" % (str(self._subformula),)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        self._subformula.finalize(atoms, [*guarded, *unguarded], [], fixpoint)



class EG(CTL):
    """CTL operator printed as ``EG (φ)``."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        return "EG (%s)" % (str(self._subformula),)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        self._subformula.finalize(atoms, [*guarded, *unguarded], [], fixpoint)



class EF(CTL):
    """CTL operator printed as ``EF (φ)``."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        return "EF (%s)" % (str(self._subformula),)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        self._subformula.finalize(atoms, [*guarded, *unguarded], [], fixpoint)



class AU(CTL):
    """CTL until operator printed as ``A((φ) U (ψ))``."""

    def __init__(self):
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        return "A((%s) U (%s))" % (str(self._left), str(self._right))

    def replace(self, what, withw):
        """Put ``withw`` into whichever child slot holds ``what``."""
        if what != self._left:
            assert self._right == what
            self._right = withw
            return
        self._left = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands with the same variable context."""
        for child in (self._left, self._right):
            child.finalize(atoms, guarded, unguarded, fixpoint)



class AR(CTL):
    """CTL release operator printed as ``A((φ) R (ψ))``."""

    def __init__(self):
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        return "A((%s) R (%s))" % (str(self._left), str(self._right))

    def replace(self, what, withw):
        """Put ``withw`` into whichever child slot holds ``what``."""
        if what != self._left:
            assert self._right == what
            self._right = withw
            return
        self._left = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands with the same variable context."""
        for child in (self._left, self._right):
            child.finalize(atoms, guarded, unguarded, fixpoint)



class EU(CTL):
    """CTL until operator printed as ``E((φ) U (ψ))``."""

    def __init__(self):
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        return "E((%s) U (%s))" % (str(self._left), str(self._right))

    def replace(self, what, withw):
        """Put ``withw`` into whichever child slot holds ``what``."""
        if what != self._left:
            assert self._right == what
            self._right = withw
            return
        self._left = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands with the same variable context."""
        for child in (self._left, self._right):
            child.finalize(atoms, guarded, unguarded, fixpoint)



class ER(CTL):
    """CTL release operator printed as ``E((φ) R (ψ))``."""

    def __init__(self):
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        return "E((%s) R (%s))" % (str(self._left), str(self._right))

    def replace(self, what, withw):
        """Put ``withw`` into whichever child slot holds ``what``."""
        if what != self._left:
            assert self._right == what
            self._right = withw
            return
        self._left = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands with the same variable context."""
        for child in (self._left, self._right):
            child.finalize(atoms, guarded, unguarded, fixpoint)



class ATL(Connective):
    """Base class for ATL operators; supplies a lazily drawn random coalition.

    The coalition is drawn on first use and cached on the node, so printing
    the same formula repeatedly (e.g. in several syntaxes) reuses it.
    """

    # Agents are named "1" .. str(num_agents); subclasses/instances may override.
    num_agents = 5

    def coalition(self):
        """Return this node's coalition formatted for the current ``syntax``.

        Agents are comma-separated in TATL syntax and space-separated otherwise.

        Raises:
            ValueError: if ``num_agents`` is below 1 (no non-empty coalition exists).
        """
        if not hasattr(self, '_coalition'):
            if self.num_agents < 1:
                raise ValueError("ATL coalitions need at least one agent")
            coalition = []
            # Each agent joins with probability 1/2; redraw until non-empty.
            while not coalition:
                coalition = [str(agent) for agent in range(1, self.num_agents + 1)
                             if random.getrandbits(1) == 1]
            self._coalition = coalition

        separator = "," if syntax == 'tatl' else " "
        return separator.join(self._coalition)



class Next(ATL):
    """ATL next-step operator for a random coalition."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        # The operand is rendered before the coalition is requested.
        inner = str(self._subformula)
        agents = self.coalition()
        template = "<<%s>>X(%s)" if syntax == 'tatl' else "[{%s}](%s)"
        return template % (agents, inner)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        merged = guarded + unguarded
        self._subformula.finalize(atoms, merged, [], fixpoint)



class Always(ATL):
    """ATL 'always' operator; outside TATL it is spelled out as a nu-fixpoint."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        # The operand is rendered before the coalition is requested.
        inner = str(self._subformula)
        if syntax != 'tatl':
            return "(ν X . ((%s) & [{%s}]X))" % (inner, self.coalition())
        return "<<%s>>G(%s)" % (self.coalition(), inner)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        merged = guarded + unguarded
        self._subformula.finalize(atoms, merged, [], fixpoint)



class Eventually(ATL):
    """ATL 'eventually' operator; outside TATL it is spelled out as a mu-fixpoint."""

    def __init__(self):
        self._subformula = Variable(self)

    def __str__(self):
        # The operand is rendered before the coalition is requested.
        inner = str(self._subformula)
        if syntax != 'tatl':
            return "(μ X . ((%s) | [{%s}]X))" % (inner, self.coalition())
        return "<<%s>>F(%s)" % (self.coalition(), inner)

    def replace(self, what, withw):
        """Swap the placeholder ``what`` for ``withw``."""
        assert what == self._subformula
        self._subformula = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize the operand, treating all bound variables as guarded."""
        merged = guarded + unguarded
        self._subformula.finalize(atoms, merged, [], fixpoint)



class Until(ATL):
    """ATL 'until' operator; outside TATL it is spelled out as a mu-fixpoint."""

    def __init__(self):
        self._left = Variable(self)
        self._right = Variable(self)

    def __str__(self):
        # Operands are rendered (left, then right) before the coalition is requested.
        lhs, rhs = str(self._left), str(self._right)
        agents = self.coalition()
        if syntax == 'tatl':
            return "<<%s>>((%s) U (%s))" % (agents, lhs, rhs)
        return "(μ X . ((%s) | ((%s) & [{%s}]X)))" % (rhs, lhs, agents)

    def replace(self, what, withw):
        """Put ``withw`` into whichever child slot holds ``what``."""
        if what != self._left:
            assert self._right == what
            self._right = withw
            return
        self._left = withw

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Finalize both operands with the same variable context."""
        for child in (self._left, self._right):
            child.finalize(atoms, guarded, unguarded, fixpoint)



# Connective classes drawn from when a placeholder is expanded; main() fills
# this in for the selected logic (a class listed twice is drawn twice as often).
connectives = list()


class Variable:
    """Undecided placeholder leaf in a formula under construction.

    Every live placeholder is registered in the module-level ``variables``
    list; it is removed again when it is expanded (``replace``) or turned into
    a leaf (``finalize``).
    """

    def __init__(self, parent):
        # The node whose child slot this placeholder occupies.
        self._parent = parent
        variables.append(self)

    def __str__(self):
        return "(undecided)"

    def replace(self):
        """Expand this placeholder into a random connective from ``connectives``.

        The new connective registers its own placeholders; this one is
        unregistered and swapped out in its parent.
        """
        assert self._parent is not None
        replacement = random.choice(connectives)()
        variables.remove(self)
        self._parent.replace(self, replacement)

    def finalize(self, atoms, guarded, unguarded, fixpoint):
        """Turn this placeholder into a guarded fixpoint variable or a literal.

        Only ``guarded`` variables and ``atoms`` are candidates; an atom is
        negated (``~a``) with probability 1/2. ``unguarded`` and ``fixpoint``
        are accepted for interface uniformity but not used here.

        Raises:
            ValueError: if there are neither atoms nor guarded variables.
        """
        candidates = guarded + atoms
        if not candidates:
            raise ValueError(
                "no propositional atoms or guarded fixpoint variables available")
        choice = random.choice(candidates)
        if choice in atoms:
            choice = random.choice([choice, "~%s" % choice])

        variables.remove(self)
        self._parent.replace(self, choice)



def main(args):
    """Generate random formulas and write them below ``args.destdir``.

    ``args`` is the parsed command line:
      * ``logic`` selects the connective pool and the output syntaxes;
      * ``count`` formulas are generated;
      * each is grown by expanding ``constructors`` random placeholders;
      * the first ``atoms`` lowercase letters serve as propositions.

    Formula ``i`` is written to ``<destdir>/<syntax>/random.<i:04d>.<syntax>``
    for every output syntax of the logic. The output sub-directories must not
    exist yet (``os.makedirs`` raises ``FileExistsError`` otherwise).

    Raises:
        ValueError: if ``args.logic`` is not one of 'afmu', 'CTL', 'ATL'.
    """
    global connectives, variables, syntax

    # Per logic: the connective pool (duplicates bias random.choice) and the
    # syntaxes to print; each syntax name doubles as sub-directory and extension.
    if args.logic == 'afmu':
        connectives = [And, And, Or, Or, Box, Diamond, Mu, Nu]
        outputs = ['cool', 'mlsolver']
    elif args.logic == 'CTL':
        connectives = [And, And, Or, Or, Box, Diamond, AF, AG, EF, EG, AU, EU]
        outputs = ['ctl']
    elif args.logic == 'ATL':
        connectives = [And, And, Or, Or, Next, Always, Eventually, Until]
        outputs = ['cool', 'tatl']
    else:
        raise ValueError("unknown logic: %r" % (args.logic,))

    for name in outputs:
        os.makedirs(os.path.join(args.destdir, name))

    letters = string.ascii_lowercase[:args.atoms]

    for i in range(args.count):
        # Fresh placeholder registry; Variable() appends to this module global.
        variables = []
        formula = random.choice(connectives)()

        for _ in range(args.constructors):
            random.choice(variables).replace()

        formula.finalize(list(letters), [], [], 'none')

        for name in outputs:
            # The __str__ methods consult the module-level ``syntax``.
            syntax = name
            path = os.path.join(args.destdir, name, 'random.%04d.%s' % (i, name))
            # Explicit UTF-8: output may contain μ/ν, which the platform's
            # default encoding (e.g. cp1252 on Windows) cannot encode.
            with open(path, 'w', encoding='utf-8') as f:
                f.write(str(formula))

    print(args.destdir)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generator for random μ-Calculus-Formulas")
    parser.add_argument('--atoms', type=int, required=True,
                        help="Number of propositional arguments to use")
    parser.add_argument('--constructors', type=int, required=True,
                        help="Number of connectives to build")
    parser.add_argument('--count', type=int, required=True,
                        help="Number of formulas to generate")
    parser.add_argument('--destdir', type=str, required=True,
                        help="Directory to place created formulas in")
    parser.add_argument('--logic', choices=['CTL', 'ATL', 'afmu'], required=True,
                        help="Logic to generate formulas for")

    args = parser.parse_args()

    # Atoms are named by single lowercase letters, so at most 26 exist. With
    # none, leaves that are not guarded fixpoint variables have nothing to become.
    max_atoms = len(string.ascii_lowercase)
    if not 1 <= args.atoms <= max_atoms:
        parser.error("--atoms must be between 1 and %d" % max_atoms)

    main(args)