diff --git a/src/Algorithm.hs b/src/Algorithm.hs index 45381ce3355bc6d02df6fc0699fcd594e2f90f50..c7b412747a8370334dbc382010491e2cf70ee05f 100644 --- a/src/Algorithm.hs +++ b/src/Algorithm.hs @@ -15,14 +15,16 @@ module Algorithm import Prelude hiding (pred) -import Control.Monad.ST import Control.Monad -import Data.STRef -import Data.Maybe (maybeToList) +import Control.Monad.ST import Control.Monad.ST.Unsafe (unsafeSTToIO) +import Data.Function (on) +import Data.List (delete,maximumBy) +import Data.Maybe (maybeToList) +import Data.STRef import System.IO.Unsafe (unsafeDupablePerformIO) -import Control.Monad.Extra (unlessM, whenM, whileM) +import Control.Monad.Extra (unlessM, whenM, whileM, ifM) import Control.Monad.Reader import Data.Tuple.Extra (snd3) import Data.Vector (Vector) @@ -47,6 +49,7 @@ data AlgoState s h = AlgoState , pred :: Vector [EdgeRef] , partition :: RefinablePartition s , h3Cache :: MVector s (RI.H3 h) + , sort :: Sort } -- nextSize is the node count of the target set. All target indices in the given @@ -54,10 +57,11 @@ data AlgoState s h = AlgoState -- -- returns (initial queue content, algo state) initialize :: forall h s. RefinementInterface h - => Encoding (RI.Label h) (RI.H1 h) + => Sort + -> Encoding (RI.Label h) (RI.H1 h) -> Int -> ST s ([Block], AlgoState s h) -initialize encoding nextSize = do +initialize sort encoding nextSize = do toSub <- VM.replicate (size encoding) [] lastW <- VM.new (length (edges encoding)) predMutable <- VM.replicate nextSize [] @@ -83,11 +87,11 @@ initialize encoding nextSize = do return (blocks, AlgoState {..}) -type SplitM s h = ReaderT (AlgoState s h) (ST s) +type SplitM s h = ReaderT (AlgoState s h, BlockQueue s) (ST s) split :: RefinementInterface h => Vector State -> SplitM s h () split statesOfS = do - as <- ask + (as, _) <- ask touchedBlocks <- collectTouchedBlocks statesOfS forM_ touchedBlocks $ \(b, v0) -> do @@ -96,7 +100,7 @@ split statesOfS = do splitBlock b updateBlock :: forall s h. RefinementInterface h => Block -> RI.H3 h -> SplitM s h () -updateBlock b v0 = ask >>= \as -> lift $ do +updateBlock b v0 = ask >>= \(as, _) -> lift $ do markB <- Partition.markedStates (partition as) b forM_ markB $ \x -> do pc <- (fromEdgeRef . head <$> VM.read (toSub as) x) @@ -115,7 +119,7 @@ updateBlock b v0 = ask >>= \as -> lift $ do -- b must have at least one marked state splitBlock :: RefinementInterface h => Block -> SplitM s h () -splitBlock b = ask >>= \as -> lift $ do +splitBlock b = ask >>= \(as, queue) -> lift $ do -- b has marked states, so b1 is guaranteed to be non-empty (Just b1, bunmarked) <- Partition.splitMarked (partition as) b @@ -127,7 +131,7 @@ splitBlock b = ask >>= \as -> lift $ do (Just b1', b2) <- Partition.splitByM (partition as) b1 (\i -> (==pmc) <$> VM.read (h3Cache as) i) - _blocks <- ((b1':maybeToList bunmarked) ++) <$> case b2 of + blocks <- ((b1':maybeToList bunmarked) ++) <$> case b2 of Nothing -> return [] -- NOTE: We need to use unsafePerformIO here, because all vector sortBy -- functions expect a pure predicate. Since our predicate is only monadic @@ -136,12 +140,30 @@ splitBlock b = ask >>= \as -> lift $ do Just b2' -> Partition.groupBy (partition as) b2' (\i -> unsafeDupablePerformIO (unsafeSTToIO (VM.read (h3Cache as) i))) - -- TODO Insert blocks except a largest into queue - undefined + let s = sort as + enqueueSorted = Queue.enqueue queue . (s,) + + ifM ((s,b) `Queue.elem` queue) (mapM_ enqueueSorted blocks) $ do + deleteLargest (\x -> Partition.blockSize (partition as) x) (maybeAdd b blocks) + >>= mapM_ enqueueSorted + +-- | Remove one largest element from the list +-- +-- TODO This could probably be more efficient +deleteLargest :: Eq e => (e -> ST s Int) -> [e] -> ST s [e] +deleteLargest sizeFunction lst = do + zipWithSize <- traverse (\x -> (,x) <$> sizeFunction x) lst + return (delete (snd (maximumBy (compare `on` fst) zipWithSize)) lst) + +-- | Add element to list if it isn't already there +maybeAdd :: Eq e => e -> [e] -> [e] +maybeAdd e lst + | e `elem` lst = lst + | otherwise = e : lst collectTouchedBlocks :: forall s h. RefinementInterface h => Vector State -> SplitM s h [(Block, RI.H3 h)] collectTouchedBlocks statesOfS = do - as <- ask + (as, _) <- ask markedBlocks <- lift $ newSTRef [] @@ -172,7 +194,7 @@ initializeAll encodings = do sorts <- iforM (V.zip encodings (rotateVectorLeft sizes)) $ \sort (Morphism (h :: h) encoding, nextSize) -> do - (blocks, state) <- initialize @h encoding nextSize + (blocks, state) <- initialize @h sort encoding nextSize mapM_ (Queue.enqueue queue . (sort,)) blocks return (SomeAlgoState state)