(** The "plug-in" functions for each individual logic.
    @author Florian Widmann
 *)


open CoAlgMisc
open CoolUtils
open CoAlgLogicUtils
open Gmlmip

module S = MiscSolver

(** [mkRuleList_MultiModalK sort bs sl] returns, as a plain list, the rules
    of multi-modal K that are applicable to the sequent [bs].

    Arguments:
    - [sort]: sort of the formulae in [bs] (needed to disambiguate hashing)
    - [bs]: tableau sequent consisting of modal atoms, i.e. the premiss, a
      conjunctive set of formulae represented by hash values
    - [sl]: sorts of the functor arguments; exactly one sort is expected

    Result: one (propagation function, conclusion) pair per diamond in [bs]:
    {v
      { ( λ[bs1]. { ∃R.C } ∪ { ∀R.D | ∀R.D ∈ bs, D ∈ bs1 }
        , [(s1, { C } ∪ { D | ∀R.D ∈ bs })]
        )
      | ∃R.C ∈ bs
      }
    v}
    A propagation function maps an unsatisfiable subset of the conclusion to
    an unsatisfiable subset of the premiss; this is the hook for backjumping
    (propagating the whole premiss would always be safe). *)
let mkRuleList_MultiModalK sort bs sl : rule list =
  assert (List.length sl = 1);
  (* [isBoxOver role g] holds iff g = ∀role.D for some D *)
  let isBoxOver (role : int) g =
    lfGetType sort g = AxF && lfGetDest3 sort g = role
  in
  (* Propagation function of the rule generated for [diamond]: blame the
     diamond itself plus every box over the same role whose body occurs in
     the single set passed in [bsl]. *)
  let dep diamond bsl =
    assert (List.length bsl = 1);
    let core = List.hd bsl in
    let blamed = bsetMake () in
    bsetAdd blamed diamond;
    let role = lfGetDest3 sort diamond in
    let blameBox g =
      if isBoxOver role g && bsetMem core (lfGetDest1 sort g)
      then bsetAdd blamed g
    in
    bsetIter blameBox bs;
    blamed
  in
  (* Fold step: formulae that are not diamonds contribute no rule. *)
  let getRules f acc =
    if lfGetType sort f <> ExF then acc
    else begin
      (* f = ∃R.C: the successor holds C and the body of every ∀R.D in bs *)
      let role = lfGetDest3 sort f in
      let successor = bsetMake () in
      bsetAdd successor (lfGetDest1 sort f);
      let inheritBody g =
        if isBoxOver role g then bsetAdd successor (lfGetDest1 sort g)
      in
      bsetIter inheritBody bs;
      let s1 = List.hd sl in
      (dep f, lazylistFromList [(s1, successor)]) :: acc
    end
  in
  bsetFold getRules bs []

(** State expander for multi-modal K: the rule list computed by
    [mkRuleList_MultiModalK], converted into a lazy list. *)
let mkRule_MultiModalK sort bs sl : stateExpander =
  lazylistFromList (mkRuleList_MultiModalK sort bs sl)

(* TODO: test it with:
     make && ./coalg sat <<< $'<R> False \n [R] False \n [R] True'
*)
(** State expander for multi-modal KD: all rules of multi-modal K plus, for
    every role R that occurs in a box but in no diamond of [bs], one rule
    that enforces an R-successor:
    {v
      mkRule_MultiModalKD sort bs [s1]
        = { ( λ[bs1]. { ∀R.D ∈ bs | D ∈ bs1 }
            , [(s1, { D | ∀R.D ∈ bs })]
            )
          | R ∈ roles
          }
          ∪ mkRule_MultiModalK sort bs [s1]
    v}
    [sl] is expected to hold exactly one sort. *)
let mkRule_MultiModalKD sort bs sl : stateExpander =
  assert (List.length sl = 1); (* functor has just one argument *)
  let s1 = List.hd sl in
  let hasType ty formula = lfGetType sort formula = ty in
  let roleOf formula = lfGetDest3 sort formula in
  (* NOTE(review): the existence of bsetMakeRealEmpty suggests that a fresh
     set may not start out empty; confirm that S.makeBS () does, otherwise
     [roles] carries a spurious entry. *)
  let roles = S.makeBS () in
  (* pass 1: record the role of every box *)
  bsetIter
    (fun formula ->
      if hasType AxF formula
      then ignore (S.addBSNoChk roles (roleOf formula)))
    bs;
  (* pass 2: drop the roles that already have a diamond (optimization); this
     has to run after pass 1 has finished *)
  bsetIter
    (fun formula ->
      if hasType ExF formula then S.remBS roles (roleOf formula))
    bs;
  let isBoxOver r formula = hasType AxF formula && roleOf formula = r in
  (* Fold step: prepend the successor-enforcing rule for role [r]. *)
  let enforceSuccessor r acc =
    (* propagation: { ∀r.D ∈ bs | D ∈ bs1 } for the single set bs1 in bsl *)
    let dep bsl =
      assert (List.length bsl = 1);
      let core = List.hd bsl in
      let blamed = bsetMake () in
      bsetIter
        (fun formula ->
          if isBoxOver r formula && bsetMem core (lfGetDest1 sort formula)
          then bsetAdd blamed formula)
        bs;
      blamed
    in
    (* conclusion: { D | ∀r.D ∈ bs } *)
    let succs = bsetMake () in
    bsetIter
      (fun formula ->
        if isBoxOver r formula
        then bsetAdd succs (lfGetDest1 sort formula))
      bs;
    (dep, lazylistFromList [(s1, succs)]) :: acc
  in
  (* start from the rules of K and extend them with the enforcing rules *)
  let kRules = mkRuleList_MultiModalK sort bs sl in
  lazylistFromList (S.foldBS enforceSuccessor roles kRules)


(* CoalitionLogic: helper functions *)
(*val subset : bitset -> bitset -> bool*)
(** [bsetlen a] is the number of elements visited when traversing [a]. *)
let bsetlen (a: bset) : int =
    bsetFold (fun _ count -> count + 1) a 0

(** [subset a b] holds iff every element of [a] is a member of [b] and [a]
    has fewer elements than [b], i.e. the inclusion is strict. *)
let subset (a: bset) (b: bset) : bool =
    let included =
        bsetFold (fun formula ok -> bsetMem b formula && ok) a true
    in
    (* the sizes are only compared once inclusion is established *)
    included && bsetlen a < bsetlen b

(** [bsetForall a f] holds iff [f] holds for every element of [a].
    There is no early exit: [f] is applied to each element of [a]. *)
let bsetForall (a: bset) (f: CoAlgMisc.localFormula -> bool) : bool =
    let allHold = ref true in
    bsetIter (fun formula -> if not (f formula) then allHold := false) a;
    !allHold

(** [bsetExists a f] holds iff [f] holds for at least one element of [a];
    it is the dual of [bsetForall]. *)
let bsetExists (a: bset) (f: CoAlgMisc.localFormula -> bool) : bool =
    let noneHolds = bsetForall a (fun formula -> not (f formula)) in
    not noneHolds

(** [compatible sort a formula1] holds iff [disjointAgents sort formula1 g]
    holds for every element [g] of [a]. *)
let compatible sort (a: bset) formula1 =
    bsetForall a (fun formula2 -> disjointAgents sort formula1 formula2)

(*
    CoalitionLogic: tableau rules for satisfiability

    Rule 1:

     /\  n
    /  \i=1 [C_i] a_i       n ≥ 0,
  ——————————————————————    C_i pairwise disjoint
     /\  n
    /  \i=1  a_i

    Rule 2:

     /\ n                         /\ m
    /  \i=1 [C_i] a_i  /\ <D> b  /  \j=1 <N> c_j    n,m ≥ 0
  ————————————————————————————————————————————————  C_i pairwise disjoint
     /\ n                         /\ m              all C_i ⊆ D
    /  \i=1 a_i        /\    b   /  \j=1 c_j
*)

(* Not yet implemented: backjumping hooks. 

  E.g. in Rule 1, if a subset I of {a_1, ..., a_n} is unsat, then {[C_i] a_i : i \in I} is
  already unsat.
*)

(** State expander for coalition logic.

    Boxes are the [EnforcesF] formulae of [bs], diamonds the [AllowsF] ones.
    A diamond whose agent list contains every agent returned by
    [cl_get_agents ()] is an N-diamond; all other diamonds are candidates for
    the diamond D of rule 2.  The generated rules are:
    - rule 2, once per candidate D = <d> b and per element of
      [maxDisjoints] of the boxes whose agent list is included in d: the
      conclusion holds b, the bodies of all N-diamonds and the bodies of the
      selected boxes;
    - rule 2 without a D, only if there are N-diamonds but no candidate D,
      once per element of [maxDisjoints] of all boxes: the conclusion holds
      the bodies of all N-diamonds and of the selected boxes;
    - rule 1, once per element of [maxDisjoints] of all boxes: the conclusion
      holds the bodies of the selected boxes.

    Backjumping is not implemented: every propagation function returns the
    whole premiss [bs]. *)
let mkRule_CL sort bs sl : stateExpander =
  assert (List.length sl = 1); (* TODO: Why? *)
  let s1 = List.hd sl in
  let ofType ty = bsetFilter bs (fun f -> lfGetType sort f = ty) in
  let boxes = ofType EnforcesF in
  let diamonds = ofType AllowsF in
  let disjoints = maxDisjoints sort boxes in
  let nCandsEmpty = ref true in
  let nCands = bsetMakeRealEmpty () in (* all N-diamonds *)
  let dCandsEmpty = ref true in
  (* Filter predicate selecting the candidates for D.  As a side effect it
     collects the N-diamonds in nCands and maintains both emptiness flags, so
     it has to be applied to each diamond exactly once. *)
  let isProperCoalition diamond =
    let coalition = lfGetDestAg sort diamond in
    let full =
      TArray.all (fun agent -> TArray.elem agent coalition) (cl_get_agents ())
    in
    if full then begin
      bsetAdd nCands diamond;
      nCandsEmpty := false
    end else
      dCandsEmpty := false;
    not full
  in
  (* candidates for D in rule 2; implicitly fills nCands *)
  let dCands = bsetFilter diamonds isProperCoalition in
  (* bodies of all modal formulae in a set *)
  let bodiesOf set = bsetFold (fun f acc -> lfGetDest1 sort f :: acc) set [] in
  (* must be computed after dCands, which is what fills nCands *)
  let c_j : localFormula list = bodiesOf nCands in
  let propagateAll = fun _ -> bs in (* no backjumping *)
  (* Rule 2 for one selection of boxes:
        coalitions /\ (formulae behind extras)
       ————————————————————————————————————————
           a_i     /\        extras
  *)
  let rule2 extras acc coalitions =
    let a_i = bodiesOf coalitions in
    let children = bsetMakeRealEmpty () in
    List.iter (bsetAdd children) extras;
    List.iter (bsetAdd children) a_i;
    (propagateAll, lazylistFromList [(s1, children)]) :: acc
  in
  (* rule 2 for a diamond <d> b where d is a proper subset of the agents *)
  let getRule2 diamDb acc =
    let d = lfGetDestAg sort diamDb in
    let b = lfGetDest1 sort diamDb in
    let withinD box = TArray.included (lfGetDestAg sort box) d in
    let maxdisj = maxDisjoints sort (bsetFilter boxes withinD) in
    List.fold_left (rule2 (b :: c_j)) acc maxdisj
  in
  let rules = bsetFold getRule2 dCands [] in
  (* if there are N-diamonds but no diamond with a proper subset of the
     agents, rule 2 is created explicitly, once for all N-diamonds *)
  let rules =
    if !nCandsEmpty || not !dCandsEmpty then rules
    else List.fold_left (rule2 c_j) rules (maxDisjoints sort boxes)
  in
  (* Rule 1:
        coalitions
       ————————————
           a_i
  *)
  let getRule1 acc coalitions =
    let a_i : bset = bsetMakeRealEmpty () in
    List.iter (bsetAdd a_i) (bodiesOf coalitions);
    (propagateAll, lazylistFromList [(s1, a_i)]) :: acc
  in
  lazylistFromList (List.fold_left getRule1 rules disjoints)

(** State expander for graded modal logic.

    For every role occurring in a [MoreThanF] (diamond) or [MaxExceptF] (box)
    formula of [bs], the modalities over that role are encoded as
    (isDiamond, number, body as int) triples and handed to [gml_rules].
    Each rule conclusion returned is a list of conjunctions of
    (formula as int, polarity) literals and becomes one rule with one
    successor set per conjunction.  Every propagation function returns the
    whole premiss [bs].

    @raise CoAlgFormula.CoAlgException if the negation of a formula that
    occurs as a negative literal is missing. *)
let mkRule_GML sort bs sl : stateExpander =
  assert (List.length sl = 1);
  let s1 = List.hd sl in
  let ofType ty = bsetFilter bs (fun f -> lfGetType sort f = ty) in
  let diamonds = ofType MoreThanF in
  let boxes = ofType MaxExceptF in
  (* NOTE(review): the existence of bsetMakeRealEmpty suggests that a fresh
     set may not start out empty; confirm that S.makeBS () does. *)
  let roles = S.makeBS () in
  let addRole formula =
    ignore (S.addBSNoChk roles (lfGetDest3 sort formula))
  in
  bsetIter addRole boxes;
  bsetIter addRole diamonds;
  (* turn a (formula as int, polarity) literal back into a formula *)
  let toFormula (atom, positive) =
    let lf = lfFromInt atom in
    if positive then lf
    else
      match lfGetNeg sort lf with
      | Some negated -> negated
      | None -> raise (CoAlgFormula.CoAlgException ("Negation of formula missing"))
  in
  (* one conjunction of literals becomes one successor set of sort s1 *)
  let toSuccessor conj =
    let res = bsetMake () in
    List.iter (fun literal -> bsetAdd res (toFormula literal)) conj;
    (s1, res)
  in
  (* one rule conclusion (a disjunction of conjunctions) becomes one rule *)
  let toRule rc acc =
    ((fun _ -> bs), lazylistFromList (List.map toSuccessor rc)) :: acc
  in
  let addRule role acc =
    let collect isDiamond m found =
      if lfGetDest3 sort m = role
      then (isDiamond, lfGetDestNum sort m, lfToInt (lfGetDest1 sort m)) :: found
      else found
    in
    let fromDiamonds = bsetFold (collect true) diamonds [] in
    let fromBoxes = bsetFold (collect false) boxes [] in
    let premise : (bool*int*int) list = List.append fromDiamonds fromBoxes in
    List.fold_right toRule (gml_rules premise) acc
  in
  lazylistFromList (S.foldBS addRule roles [])

(** State expander for probabilistic modal logic.

    All [AtLeastProbF] (diamond) and [LessProbFailF] (box) formulae of [bs]
    are encoded as (isDiamond, nominator, denominator, body as int) tuples
    and handed to [pml_rules].  Each rule conclusion returned is a list of
    conjunctions of (formula as int, polarity) literals and becomes one rule
    with one successor set per conjunction.  Every propagation function
    returns the whole premiss [bs].

    @raise CoAlgFormula.CoAlgException if the negation of a formula that
    occurs as a negative literal is missing. *)
let mkRule_PML sort bs sl : stateExpander =
  assert (List.length sl = 1);
  let s1 = List.hd sl in
  let ofType ty = bsetFilter bs (fun f -> lfGetType sort f = ty) in
  let diamonds = ofType AtLeastProbF in
  let boxes = ofType LessProbFailF in
  let encode isDiamond m acc =
    let nominator = lfGetDestNum sort m in
    let denominator = lfGetDestNum2 sort m in
    (isDiamond, nominator, denominator, lfToInt (lfGetDest1 sort m)) :: acc
  in
  let fromDiamonds = bsetFold (encode true) diamonds [] in
  let fromBoxes = bsetFold (encode false) boxes [] in
  let premise : (bool*int*int*int) list = List.append fromDiamonds fromBoxes in
  let conclusion = pml_rules premise in
  (* turn a (formula as int, polarity) literal back into a formula *)
  let toFormula (f_int, positive) =
    let lf = lfFromInt f_int in
    if positive then lf
    else
      match lfGetNeg sort lf with
      | Some negated -> negated
      | None -> raise (CoAlgFormula.CoAlgException ("Negation of formula missing"))
  in
  (* one conjunction of literals becomes one successor set of sort s1 *)
  let toSuccessor conj =
    let res = bsetMake () in
    List.iter (fun literal -> bsetAdd res (toFormula literal)) conj;
    (s1, res)
  in
  (* one rule conclusion (a disjunction of conjunctions) becomes one rule *)
  let toRule rc acc =
    ((fun _ -> bs), lazylistFromList (List.map toSuccessor rc)) :: acc
  in
  lazylistFromList (List.fold_right toRule conclusion [])

(** State expander for the constant functor with value set [colors].

    The colors of the [ConstF] (positive) and [ConstnF] (negative) formulae
    of [bs] are collected.  A single rule with an empty conclusion is
    produced if some color occurs both positively and negatively, if there
    are as many negative literals as there are colors, or if there is more
    than one positive literal; otherwise no rule is produced.  The
    propagation function returns the whole premiss (no backjumping). *)
let mkRule_Const colors sort bs sl : stateExpander =
  assert (List.length sl = 1);    (* just one (formal) argument *)
  let classify (f:localFormula) (pos, neg) =
    let col = lfGetDest3 sort f in
    match lfGetType sort f with
      | ConstF  -> (col :: pos, neg)
      | ConstnF -> (pos, col :: neg)
      | _       -> (pos, neg)
  in
  let (pos, neg) = bsetFold classify bs ([], []) in (* pos/neg literals *)
  let clash = List.exists (fun c -> List.mem c neg) pos in (* =a /\ ~ = a *)
  let allneg = List.length colors = List.length neg in
  let twopos =
    match pos with
      | _ :: _ :: _ -> true
      | _ -> false
  in
  if clash || allneg || twopos
  then lazylistFromList [((fun _ -> bs), lazylistFromList [])]
  else lazylistFromList []

(** State expander for the identity functor: a single rule whose conclusion
    holds the argument of every [IdF] formula of [bs].  The propagation
    function returns the [IdF] formulae of [bs] whose argument is a member of
    the single set it is given. *)
let mkRule_Identity sort bs sl : stateExpander =
  assert (List.length sl = 1); (* Identity has one argument *)
  let s1 = List.hd sl in
  let isId f = lfGetType sort f = IdF in
  let dep bsl =
    assert (List.length bsl = 1);
    let core = List.hd bsl in
    let blamed = bsetMake () in
    bsetIter
      (fun f ->
        if isId f && bsetMem core (lfGetDest1 sort f) then bsetAdd blamed f)
      bs;
    blamed
  in
  let unwrapped = bsetMake () in
  bsetIter
    (fun f -> if isId f then bsetAdd unwrapped (lfGetDest1 sort f))
    bs;
  lazylistFromList [(dep, lazylistFromList [(s1, unwrapped)])]


(** Placeholder for the default implication functor.
    @raise CoAlgFormula.CoAlgException always, as it is not implemented. *)
let mkRule_DefaultImplication sort bs sl : stateExpander =
  let message = "Default Implication Not yet implemented." in
  raise (CoAlgFormula.CoAlgException message)
  
(** State expander for the choice functor, which takes two arguments.

    A single rule is produced whose conclusion consists of two sets: the
    first components of all [ChcF] formulae of [bs] under the first sort of
    [sl], and their second components under the second sort.  The
    propagation function returns every [ChcF] formula of [bs] whose first
    component is in the first given set or whose second component is in the
    second given set. *)
let mkRule_Choice sort bs sl : stateExpander =
  assert (List.length sl = 2);
  let isChoice f = lfGetType sort f = ChcF in
  let dep bsl =
    assert (List.length bsl = 2);
    let coreLeft = List.nth bsl 0 in
    let coreRight = List.nth bsl 1 in
    let blamed = bsetMake () in
    let blame f =
      if isChoice f
         && (bsetMem coreLeft (lfGetDest1 sort f)
             || bsetMem coreRight (lfGetDest2 sort f))
      then bsetAdd blamed f
    in
    bsetIter blame bs;
    blamed
  in
  let lefts = bsetMake () in
  let rights = bsetMake () in
  bsetIter
    (fun f ->
      if isChoice f then begin
        bsetAdd lefts (lfGetDest1 sort f);
        bsetAdd rights (lfGetDest2 sort f)
      end)
    bs;
  let s1 = List.nth sl 0 in
  let s2 = List.nth sl 1 in
  lazylistFromList [(dep, lazylistFromList [(s1, lefts); (s2, rights)])]


(** State expander for the fusion functor, which takes two arguments.

    Two rules are produced, one per projection.  The rule for projection 0
    has as conclusion the arguments of all [FusF] formulae of [bs] whose
    projection index is 0 (under the first sort of [sl]); the rule for
    projection 1 collects the arguments of all remaining [FusF] formulae
    (under the second sort).  The propagation function of projection p
    returns the [FusF] formulae of [bs] with index p whose argument is a
    member of the single set it is given. *)
let mkRule_Fusion sort bs sl : stateExpander =
  assert (List.length sl = 2);
  let isFusion f = lfGetType sort f = FusF in
  let dep proj bsl =
    assert (List.length bsl = 1);
    let core = List.hd bsl in
    let blamed = bsetMake () in
    let blame f =
      if isFusion f
         && lfGetDest3 sort f = proj
         && bsetMem core (lfGetDest1 sort f)
      then bsetAdd blamed f
    in
    bsetIter blame bs;
    blamed
  in
  let firsts = bsetMake () in
  let seconds = bsetMake () in
  bsetIter
    (fun f ->
      if isFusion f then begin
        (* index 0 selects the first component, anything else the second *)
        let target = if lfGetDest3 sort f = 0 then firsts else seconds in
        bsetAdd target (lfGetDest1 sort f)
      end)
    bs;
  let s1 = List.nth sl 0 in
  let s2 = List.nth sl 1 in
  let ruleFirst = (dep 0, lazylistFromList [(s1, firsts)]) in
  let ruleSecond = (dep 1, lazylistFromList [(s2, seconds)]) in
  lazylistFromList [ruleFirst; ruleSecond]


(** Selects the plug-in function that belongs to a logic, where the logic is
    given as a value of the type [functors].  For [Constant], the attached
    colors are passed on to the plug-in function. *)
let getExpandingFunctionProducer logic =
  match logic with
  | Constant colors -> mkRule_Const colors
  | MultiModalK -> mkRule_MultiModalK
  | MultiModalKD -> mkRule_MultiModalKD
  | CoalitionLogic -> mkRule_CL
  | GML -> mkRule_GML
  | PML -> mkRule_PML
  | Identity -> mkRule_Identity
  | DefaultImplication -> mkRule_DefaultImplication
  | Choice -> mkRule_Choice
  | Fusion -> mkRule_Fusion