diff --git a/ma.cabal b/ma.cabal index 24b315f79c57b85af939a6e3a161e2fdd9c3f6b1..f6c2ec2a7165a3386e44570c72236c5a4e5092a1 100644 --- a/ma.cabal +++ b/ma.cabal @@ -105,7 +105,7 @@ library executable ma main-is: Main.hs - build-depends: base >=4.11 && <4.12 + build-depends: base >=4.11 && <4.13 hs-source-dirs: src/main default-language: Haskell2010 build-depends: ma diff --git a/src/Data/MorphismEncoding.hs b/src/Data/MorphismEncoding.hs index ea720ac9b46c9f941a7d13725c15ed7cdaada31a..682eaa23d92cdaf63e27b8e312ba64a650aaa067 100644 --- a/src/Data/MorphismEncoding.hs +++ b/src/Data/MorphismEncoding.hs @@ -23,6 +23,7 @@ module Data.MorphismEncoding , states , structure , edges + , edgeRefs , numEdges , typeOf , EdgeRef(..) @@ -36,6 +37,7 @@ import Control.DeepSeq import Data.Vector (Vector) import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as VU -- | A state is represented by an integer. -- @@ -49,9 +51,9 @@ newtype EdgeRef = EdgeRef { fromEdgeRef :: Int } deriving (Eq, Ord, Show) data Edge a = Edge - { from :: {-# UNPACK #-} State - , label :: a - , to :: {-# UNPACK #-} State + { from :: ~State + , label :: ~a + , to :: ~State } deriving (Show, Eq, Functor, Generic, NFData) @@ -60,23 +62,31 @@ data Edge a = Edge -- Nodes (states) are always the set @[0..size-1]@. data Encoding a h1 = Encoding { eStructure :: Vector h1 - , eEdges :: Vector (Edge a) + , eEdgesFrom :: VU.Vector State + , eEdgesLabel :: Vector a + , eEdgesTo :: VU.Vector State } deriving (Show, Eq, Generic, NFData) instance Bifunctor Encoding where - first f e = e{ eEdges = fmap (fmap f) (eEdges e) } + first f e = e{ eEdgesLabel = fmap f (eEdgesLabel e) } second f e = e{ eStructure = fmap f (eStructure e) } -- | Construct a new encoding from a vector of state labels and a vector of -- edges. -- -- Runtime: O(1) -new :: Vector h1 -- ^ State labels. - -> Vector (Edge a) -- ^ Edges. An edge must only reference states that are +new + :: Vector h1 -- ^ State labels. + -> Vector (Edge a) -- ^ Edges. An edge must only reference states that are -- actually defined in the first vector. - -> Encoding a h1 -new !eStructure !eEdges = Encoding {..} + -> Encoding a h1 +new !eStructure !eEdges = + let eEdgesFrom = V.convert $ V.map from eEdges + eEdgesTo = V.convert $ V.map to eEdges + eEdgesLabel = V.map label eEdges + in Encoding { .. } + -- | The number of states in the encoding -- @@ -96,14 +106,19 @@ states !self = [0..length (eStructure self)-1] -- -- Runtime: O(1) edges :: Encoding a h1 -> Vector (Edge a) -edges = eEdges +edges !self = V.imap mkEdge (eEdgesLabel self) + where mkEdge i l = Edge ((eEdgesFrom self) VU.! i) l ((eEdgesTo self) VU.! i) {-# INLINE edges #-} +edgeRefs :: Encoding a h1 -> [EdgeRef] +edgeRefs !self = map EdgeRef [0..numEdges self-1] +{-# INLINE edgeRefs #-} + -- | The number of edges in the encoding -- -- Runtime: O(1) numEdges :: Encoding a h1 -> Int -numEdges = V.length . edges +numEdges = V.length . eEdgesLabel {-# INLINE numEdges #-} -- | All state labels. @@ -126,5 +141,7 @@ typeOf !self !x = eStructure self V.! x -- -- Runtime: O(1) graph :: Encoding a h1 -> EdgeRef -> Edge a -graph !self (EdgeRef !e) = eEdges self V.! e +graph !self (EdgeRef !e) = Edge ((eEdgesFrom self) VU.! e) + ((eEdgesLabel self) V.! e) + ((eEdgesTo self) VU.! e) {-# INLINE graph #-} diff --git a/src/Data/RefinablePartition.hs b/src/Data/RefinablePartition.hs index 917f42016d0c762a619eb40a3f03af9b72029c8f..2b19c0a47e8c8e375645908e4ef06b89d2ef849c 100644 --- a/src/Data/RefinablePartition.hs +++ b/src/Data/RefinablePartition.hs @@ -54,29 +54,20 @@ import Data.Partition.Common (State,Block(..)) import Data.Partition (Partition) import qualified Data.Partition as Partition -data StateRepr = StateRepr - { _block :: Block - , _location :: Int - } deriving (Show) - -makeLenses ''StateRepr - -derivingUnbox "StateRepr" - [t| StateRepr -> (Block,Int) |] - [| \(StateRepr b l) -> (b, l) |] - [| \(b, l) -> StateRepr b l |] - -data BlockRepr = BlockRepr - { _startOffset :: Int - , _endOffset :: Int -- exclusive - , _unmarkedOffset :: Int - } deriving (Show) - -makeLenses ''BlockRepr -derivingUnbox "BlockRepr" - [t| BlockRepr -> (Int, Int, Int) |] - [| \(BlockRepr s e u) -> (s, e, u) |] - [| \(s, e, u) -> (BlockRepr s e u) |] + +data States s = States + { _block :: !(VU.MVector s Block) + , _location :: !(VU.MVector s Int) + } +makeLenses ''States + + +data Blocks s = Blocks + { _startOffset :: !(VU.MVector s Int) + , _endOffset :: !(VU.MVector s Int) -- exclusive + , _unmarkedOffset :: !(VU.MVector s Int) + } +makeLenses ''Blocks -- | Refinable partition type. -- @@ -86,10 +77,10 @@ derivingUnbox "BlockRepr" -- This type is by nature mutable and can be mutated with the operations in this -- module. The `s` type variable is there to support the ST monad. data RefinablePartition s = Partition - { _blockCount :: !(MutVar s Int) - , _statesByBlock :: !(VU.MVector s State) - , _states :: !(VU.MVector s StateRepr) - , _blocks :: !(VU.MVector s BlockRepr) + { _blockCount :: {-# UNPACK #-} !(MutVar s Int) + , _statesByBlock :: {-# UNPACK #-} !(VU.MVector s State) + , _states :: {-# UNPACK #-} !(States s) + , _blocks :: {-# UNPACK #-} !(Blocks s) } makeLenses ''RefinablePartition @@ -108,10 +99,15 @@ make numStates numBlocks initPart | otherwise = do statesByBlock <- VU.new numStates - states <- VU.new numStates + statesBlock <- VU.new numStates + statesLocation <- VU.new numStates + -- we need to reserve space for more blocks, to allow splitting to create them. -- There can be at most as many blocks as there are states - blocks <- VU.new numStates + blocksStart <- VU.new numStates + blocksEnd <- VU.new numStates + blocksUnmarked <- VU.new numStates + blockCount <- newMutVar numBlocks -- contains a list of states for each block @@ -129,24 +125,22 @@ make numStates numBlocks initPart forM_ stateList $ \s -> do stateLocation <- readMutVar currentLocation - modifyMutVar currentLocation (+1) + writeMutVar currentLocation (stateLocation+1) VU.write statesByBlock stateLocation s - VU.write states s StateRepr { _block = Block i - , _location = stateLocation - } + VU.write statesBlock s (Block i) + VU.write statesLocation s stateLocation endOfBlock <- readMutVar currentLocation - VU.write blocks i BlockRepr { _startOffset = beginningOfBlock - , _endOffset = endOfBlock - , _unmarkedOffset = beginningOfBlock - } + VU.write blocksStart i beginningOfBlock + VU.write blocksEnd i endOfBlock + VU.write blocksUnmarked i beginningOfBlock return Partition { _blockCount = blockCount , _statesByBlock = statesByBlock - , _states = states - , _blocks = blocks + , _states = States statesBlock statesLocation + , _blocks = Blocks blocksStart blocksEnd blocksUnmarked } @@ -164,20 +158,22 @@ make1 numStates statesByBlock <- VUU.thaw (VUU.generate numStates id) blockCount <- newMutVar 1 - states <- VUU.thaw (VUU.generate numStates (StateRepr 0)) + statesBlock <- VU.replicate numStates 0 + statesLocation <- VUU.thaw (VUU.generate numStates id) + -- states <- VUU.thaw (VUU.generate numStates (StateRepr 0)) - blocks <- VU.new numStates - VU.write blocks 0 - $ BlockRepr - { _startOffset = 0 - , _endOffset = numStates - , _unmarkedOffset = 0 - } + blocksStart <- VU.new numStates + blocksEnd <- VU.new numStates + blocksUnmarked <- VU.new numStates + + VU.write blocksStart 0 0 + VU.write blocksEnd 0 numStates + VU.write blocksUnmarked 0 0 return Partition { _blockCount = blockCount , _statesByBlock = statesByBlock - , _states = states - , _blocks = blocks + , _states = States statesBlock statesLocation + , _blocks = Blocks blocksStart blocksEnd blocksUnmarked } @@ -186,20 +182,24 @@ make1 numStates -- Runtime: O(1) numBlocks :: RefinablePartition s -> ST s Int numBlocks p = readMutVar (p ^. blockCount) +{-# INLINE numBlocks #-} -- | Return number of states in a given block. -- -- Runtime: O(1) blockSize :: RefinablePartition s -> Block -> ST s Int -blockSize p b = getBlock p b >>= \block -> - return (block^.endOffset - block^.startOffset) +blockSize !p !b = do + start <- getBlock p b startOffset + end <- getBlock p b endOffset + return (end - start) +{-# INLINE blockSize #-} -- | Return the block that a given state belongs to. -- -- Runtime: O(1) blockOfState :: RefinablePartition s -> State -> ST s Block -blockOfState p s = getState p s >>= \state -> - return (state^.block) +blockOfState p s = getState p s block +{-# INLINE blockOfState #-} -- | Record this state as "marked". -- @@ -210,24 +210,26 @@ blockOfState p s = getState p s >>= \state -> -- -- Runtime: O(1) mark :: RefinablePartition s -> State -> ST s () -mark partition s = do - StateRepr {..} <- getState partition s - BlockRepr {..} <- getBlock partition _block +mark !partition !s = do + loc <- getState partition s location + b <- getState partition s block + + unmarked <- getBlock partition b unmarkedOffset - when (_location > _unmarkedOffset) $ - swap partition _location _unmarkedOffset + when (loc > unmarked) $ + swap partition loc unmarked - setBlock partition _block $ - unmarkedOffset %~ (+1) + setBlock partition b unmarkedOffset (unmarked+1) where -- swap two indices in statesByBlock array swap :: RefinablePartition s -> Int -> Int -> ST s () swap partition a b = do - setStateAt partition a $ location .~ b - setStateAt partition b $ location .~ a + setStateAt partition a location b + setStateAt partition b location a - VU.swap (partition^.statesByBlock) a b + VU.unsafeSwap (partition^.statesByBlock) a b +{-# INLINE mark #-} -- | Revert the marking of a state. -- @@ -236,73 +238,81 @@ mark partition s = do -- Runtime: O(1) unmark :: RefinablePartition s -> State -> ST s () unmark partition s = do - StateRepr {..} <- getState partition s - BlockRepr {..} <- getBlock partition _block + loc <- VU.read (partition^.states.location) s + b <- VU.read (partition^.states.block) s - when (_location <= _unmarkedOffset) $ - swap partition _location (_unmarkedOffset-1) + unmarked <- getBlock partition b unmarkedOffset - setBlock partition _block $ - unmarkedOffset %~ subtract 1 + when (loc <= unmarked) $ + swap partition loc (unmarked-1) + + setBlock partition b unmarkedOffset (unmarked - 1) where -- swap two indices in statesByBlock array swap :: RefinablePartition s -> Int -> Int -> ST s () swap partition a b = do - setStateAt partition a $ location .~ b - setStateAt partition b $ location .~ a + setStateAt partition a location b + setStateAt partition b location a - VU.swap (partition^.statesByBlock) a b + VU.unsafeSwap (partition^.statesByBlock) a b -- | Decide if a block has marked states. -- -- Runtime: O(1) hasMarked :: RefinablePartition s -> Block -> ST s Bool -hasMarked p b = getBlock p b >>= \block -> - return (block^.startOffset /= block^.unmarkedOffset) +hasMarked !p !b = + (/=) <$> (getBlock p b startOffset) <*> (getBlock p b unmarkedOffset) +{-# INLINE hasMarked #-} -- | Decide wether a state is marked. -- -- Runtime: O(1) isMarked :: RefinablePartition s -> State -> ST s Bool isMarked partition s = do - state <- getState partition s - blk <- getBlock partition (state^.block) + loc <- getState partition s location + b <- getState partition s block + unmarked <- getBlock partition b unmarkedOffset - return $ state^.location < blk^.unmarkedOffset + return $ loc < unmarked +{-# INLINE isMarked #-} -- | Return the marked states of a block. -- -- Runtime O(n) for n == number of states in this block markedStates :: RefinablePartition s -> Block -> ST s (VU.Vector State) markedStates partition b = do - block <- getBlock partition b + start <- getBlock partition b startOffset + unmarked <- getBlock partition b unmarkedOffset - let len = block^.unmarkedOffset - block^.startOffset + let len = unmarked - start - VU.freeze (VU.slice (block^.startOffset) len (partition^.statesByBlock)) + VU.freeze (VU.slice start len (partition^.statesByBlock)) +{-# INLINE markedStates #-} -- | Return a vector of all states in a given block. -- -- Runtime: O(n) for n == number of states in this block statesOfBlock :: RefinablePartition s -> Block -> ST s (VU.Vector State) statesOfBlock partition b = do - block <- getBlock partition b + start <- getBlock partition b startOffset len <- blockSize partition b - let slice = VU.slice (block^.startOffset) len (partition^.statesByBlock) + let slice = VU.slice start len (partition^.statesByBlock) VU.freeze slice +{-# INLINE statesOfBlock #-} -- | Return a vector of all states in a given block. -- -- Runtime: O(n) for n == number of states in this block unsafeStatesOfBlock :: RefinablePartition s -> Block -> ST s (VU.Vector State) unsafeStatesOfBlock partition b = do - block <- getBlock partition b + start <- getBlock partition b startOffset len <- blockSize partition b - let slice = VU.slice (block^.startOffset) len (partition^.statesByBlock) + let slice = VU.slice start len (partition^.statesByBlock) VU.unsafeFreeze slice +{-# INLINE unsafeStatesOfBlock #-} -- | Split a block into two new blocks for its marked and unmarked states. @@ -316,29 +326,31 @@ unsafeStatesOfBlock partition b = do -- Runtime: O(number of marked states in the old block) splitMarked :: RefinablePartition s -> Block -> ST s (Maybe Block, Maybe Block) splitMarked partition b = do - block <- getBlock partition b + start <- getBlock partition b startOffset + end <- getBlock partition b endOffset + unmarked <- getBlock partition b unmarkedOffset - if block^.startOffset == block^.unmarkedOffset then -- nothing marked + if start == unmarked then -- nothing marked return (Nothing, Just b) - else if block^.unmarkedOffset == block^.endOffset then do -- nothing unmarked - setBlock partition b $ unmarkedOffset .~ (block^.startOffset) + else if unmarked == end then do -- nothing unmarked + setBlock partition b unmarkedOffset start return (Just b, Nothing) else do - let numMarked = (block^.unmarkedOffset) - (block^.startOffset) - numUnmarked = (block^.endOffset) - (block^.unmarkedOffset) + let numMarked = unmarked - start + numUnmarked = end - unmarked if numMarked <= numUnmarked then do -- new block for marked states - new <- newBlock partition (block^.startOffset) (block^.unmarkedOffset) + new <- newBlock partition start unmarked -- let old block begin at at the first unmarked state - setBlock partition b $ startOffset .~ (block^.unmarkedOffset) + setBlock partition b startOffset unmarked return (Just new, Just b) else do -- new block for unmarked states - new <- newBlock partition (block^.unmarkedOffset) (block^.endOffset) + new <- newBlock partition unmarked end -- let old block end at at the last marked state - setBlock partition b $ endOffset .~ (block^.unmarkedOffset) + setBlock partition b endOffset unmarked -- and reset marked status for all states - setBlock partition b $ unmarkedOffset .~ (block^.startOffset) + setBlock partition b unmarkedOffset start return (Just b, Just new) @@ -353,29 +365,30 @@ splitMarked partition b = do -- Runtime: O(number of states in the old block) splitBy :: RefinablePartition s -> Block -> (State -> Bool) -> ST s (Maybe Block, Maybe Block) splitBy !partition !b !predicate = do - block@BlockRepr{..} <- getBlock partition b + start <- getBlock partition b startOffset + end <- getBlock partition b endOffset - splitPoint <- VU.partition (partition^.statesByBlock) predicate _startOffset _endOffset + splitPoint <- VU.partition (partition^.statesByBlock) predicate start end - if splitPoint == _startOffset then -- no matching states + if splitPoint == start then -- no matching states return (Nothing, Just b) - else if splitPoint == _endOffset then -- all states match + else if splitPoint == end then -- all states match return (Just b, Nothing) else do -- update location for all states, because split moves them - updateLocations partition block + updateLocations partition start end - let beforeSplitNum = splitPoint - _startOffset - afterSplitNum = _endOffset- splitPoint + let beforeSplitNum = splitPoint - start + afterSplitNum = end - splitPoint if beforeSplitNum <= afterSplitNum then do - setBlock partition b $ startOffset .~ splitPoint - setBlock partition b $ unmarkedOffset .~ splitPoint - new <- newBlock partition _startOffset splitPoint + setBlock partition b startOffset splitPoint + setBlock partition b unmarkedOffset splitPoint + new <- newBlock partition start splitPoint return (Just new, Just b) else do - setBlock partition b $ endOffset .~ splitPoint - new <- newBlock partition splitPoint _endOffset + setBlock partition b endOffset splitPoint + new <- newBlock partition splitPoint end return (Just b, Just new) -- | Monadic version of 'splitBy' @@ -389,29 +402,30 @@ splitByM :: RefinablePartition s -> Block -> (State -> ST s Bool) -> ST s (Maybe Block, Maybe Block) splitByM !partition !b !predicate = do -- FIXME Remove duplication with splitBy - block@BlockRepr{..} <- getBlock partition b + start <- getBlock partition b startOffset + end <- getBlock partition b endOffset - splitPoint <- VU.partitionM (partition^.statesByBlock) predicate _startOffset _endOffset + splitPoint <- VU.partitionM (partition^.statesByBlock) predicate start end - if splitPoint == _startOffset then -- no matching states + if splitPoint == start then -- no matching states return (Nothing, Just b) - else if splitPoint == _endOffset then -- all states match + else if splitPoint == end then -- all states match return (Just b, Nothing) else do -- update location for all states, because split moves them - updateLocations partition block + updateLocations partition start end - let beforeSplitNum = splitPoint - _startOffset - afterSplitNum = _endOffset- splitPoint + let beforeSplitNum = splitPoint - start + afterSplitNum = end - splitPoint if beforeSplitNum <= afterSplitNum then do - setBlock partition b $ startOffset .~ splitPoint - setBlock partition b $ unmarkedOffset .~ splitPoint - new <- newBlock partition _startOffset splitPoint + setBlock partition b startOffset splitPoint + setBlock partition b unmarkedOffset splitPoint + new <- newBlock partition start splitPoint return (Just new, Just b) else do - setBlock partition b $ endOffset .~ splitPoint - new <- newBlock partition splitPoint _endOffset + setBlock partition b endOffset splitPoint + new <- newBlock partition splitPoint end return (Just b, Just new) -- | Split a block into new blocks according to some atttribute of its states. @@ -424,20 +438,18 @@ splitByM !partition !b !predicate = do -- Runtime: O(n*log(n)) where n is the number of states in the old block. groupBy :: Ord a => RefinablePartition s -> Block -> (State -> a) -> ST s [Block] groupBy partition b predicate = do - block <- getBlock partition b - - let start = block^.startOffset - end = block^.endOffset + start <- getBlock partition b startOffset + end <- getBlock partition b endOffset VM.sortBy (comparing predicate) $ VU.slice start (end-start) (partition^.statesByBlock) - updateLocations partition block + updateLocations partition start end indices <- VU.groupBy (partition^.statesByBlock) predicate start end let splitAt (currentBlock,newBlocks) index = do - setBlock partition currentBlock $ unmarkedOffset .~ index + setBlock partition currentBlock unmarkedOffset index (Just previousBlock, Just nextBlock) <- splitMarked partition currentBlock return (nextBlock, previousBlock:newBlocks) @@ -455,48 +467,50 @@ freeze partition = do let nrStates = VU.length (partition^.statesByBlock) nrBlocks <- numBlocks partition Partition.fromFunctionM nrBlocks nrStates (blockOfState partition) +{-# INLINE freeze #-} + + +getState :: VUU.Unbox a => RefinablePartition s -> State -> Lens' (States s) (VU.MVector s a) -> ST s a +getState !partition !s l = VU.unsafeRead (partition^.states.l) s + +setState :: VUU.Unbox a => RefinablePartition s -> State -> Lens' (States s) (VU.MVector s a) -> a -> ST s () +setState !partition !s l = VU.unsafeWrite (partition^.states.l) s --- helpers -getBlock :: RefinablePartition s -> Block -> ST s BlockRepr -getBlock !partition (Block b) = VU.unsafeRead (partition ^. blocks) b +setStateAt :: VUU.Unbox a => RefinablePartition s -> Int -> Lens' (States s) (VU.MVector s a) -> a -> ST s () +setStateAt !partition !loc l x = VU.unsafeRead (partition^.statesByBlock) loc >>= \state -> + setState partition state l x -setBlock :: RefinablePartition s -> Block -> (BlockRepr -> BlockRepr) -> ST s () -setBlock partition (Block b) setter = VU.unsafeModify (_blocks partition) setter b -getState :: RefinablePartition s -> State -> ST s StateRepr -getState partition s = VU.unsafeRead (partition^.states) s +getBlock :: VUU.Unbox a => RefinablePartition s -> Block -> Lens' (Blocks s) (VU.MVector s a) -> ST s a +getBlock !partition (Block b) l = VU.unsafeRead (partition^.blocks.l) b -setState :: RefinablePartition s -> State -> (StateRepr -> StateRepr) -> ST s () -setState partition s setter = VU.modify (partition^.states) setter s -setStateAt :: RefinablePartition s -> Int -> (StateRepr -> StateRepr) -> ST s () -setStateAt partition loc setter = VU.read (partition^.statesByBlock) loc >>= \state -> - setState partition state setter +setBlock :: VUU.Unbox a => RefinablePartition s -> Block -> Lens' (Blocks s) (VU.MVector s a) -> a -> ST s () +setBlock !partition (Block b) l = VU.unsafeWrite (partition^.blocks.l) b + newBlock :: RefinablePartition s -> Int -> Int -> ST s Block newBlock partition beginning end = do - let repr = BlockRepr beginning end beginning -- no marked blocks by default - blk <- allocateBlock partition - setBlock partition blk (const repr) + setBlock partition blk startOffset beginning + setBlock partition blk endOffset end + setBlock partition blk unmarkedOffset beginning -- no marked blocks by default -- update block of all contained states - forM_ (blockIndices repr) $ \pos -> - setStateAt partition pos (block .~ blk) + forM_ [beginning..end-1] $ \pos -> + setStateAt partition pos block blk return blk allocateBlock :: RefinablePartition s -> ST s Block allocateBlock partition = do current <- readMutVar (partition^.blockCount) - modifyMutVar (partition^.blockCount) (+1) + writeMutVar (partition^.blockCount) (current+1) return $ Block current --- Update location fields of all states in this block -updateLocations :: RefinablePartition s -> BlockRepr -> ST s () -updateLocations partition block = - forM_ (blockIndices block) $ \i -> - setStateAt partition i $ location .~ i -blockIndices :: BlockRepr -> [Int] -blockIndices block = [block^.startOffset..block^.endOffset-1] +-- Update location fields of all states in this block +updateLocations :: RefinablePartition s -> Int -> Int -> ST s () +updateLocations partition start end = + forM_ [start..end-1] $ \i -> + setStateAt partition i location i diff --git a/src/Data/Vector/Unboxed/Mutable/Utils.hs b/src/Data/Vector/Unboxed/Mutable/Utils.hs index 6ea2cdd62af94195c4e13372d000cb444edd876a..9d48d8cce45fefcebaf8c086cd9d66673930a985 100644 --- a/src/Data/Vector/Unboxed/Mutable/Utils.hs +++ b/src/Data/Vector/Unboxed/Mutable/Utils.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE BangPatterns #-} + -- | Additional functions for unboxed mutable vectors module Data.Vector.Unboxed.Mutable.Utils ( -- * Partitioning @@ -111,7 +113,7 @@ groupByM vec predicate lower upper where groupByImpl :: (a,[Int]) -> Int -> m (a,[Int]) - groupByImpl (current, accu) i + groupByImpl (current, !accu) i | i >= upper = return (current,accu) | otherwise = do x <- VU.read vec i diff --git a/src/MA/Algorithm/Initialize.hs b/src/MA/Algorithm/Initialize.hs index 36e77e82c3e649062bdcd44c3e0ec0775659218c..108ec235a7ee2400deae00d0e60380aa3306c644 100644 --- a/src/MA/Algorithm/Initialize.hs +++ b/src/MA/Algorithm/Initialize.hs @@ -1,4 +1,6 @@ {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE BangPatterns #-} + module MA.Algorithm.Initialize ( initialize ) where @@ -9,7 +11,6 @@ import Data.STRef import qualified Data.Vector as V import qualified Data.Vector.Mutable as VM -import Data.Vector.Utils import Data.MorphismEncoding import Data.RefinablePartition (Block) @@ -25,17 +26,19 @@ initialize :: forall h s. RefinementInterface h -> ST s ([Block], AlgoState s h) initialize encoding = do toSub <- VM.replicate (size encoding) [] - lastW <- VM.new (length (edges encoding)) + lastW <- VM.new (numEdges encoding) predMutable <- VM.replicate (size encoding) [] h3Cache <- VM.new (size encoding) - iforM_ (edges encoding) $ \i (Edge x _ y) -> do - VM.modify toSub (EdgeRef i:) x - VM.modify predMutable (EdgeRef i:) y + forM_ (edgeRefs encoding) $ \e -> case graph encoding e of + Edge x _ y -> do + VM.modify toSub (e:) x + VM.modify predMutable (e:) y forM_ (states encoding) $ \x -> do - outgoingLabels <- map (label . graph encoding) <$> VM.read toSub x - px <- newSTRef $ RI.init @h (typeOf encoding x) outgoingLabels + outgoingLabels <- VM.read toSub x + >>= mapM (\(!x) -> let !l = label (graph encoding x) in return l) + px <- newSTRef $! RI.init @h (typeOf encoding x) outgoingLabels VM.read toSub x >>= mapM_ (\(EdgeRef e) -> VM.write lastW e px) VM.write toSub x [] diff --git a/src/MA/Algorithm/Split.hs b/src/MA/Algorithm/Split.hs index 3fbdd9fa24b9728759eec9675f2a600d7193fd27..495d4e72f4b49bc7014dec719378c252209b8111 100644 --- a/src/MA/Algorithm/Split.hs +++ b/src/MA/Algorithm/Split.hs @@ -158,7 +158,7 @@ collectTouchedBlocks blockS = do markedBlocks <- lift $ newSTRef [] lift $ VU.forM_ statesOfS $ \y -> forM_ (pred as V.! y) $ \e -> do - let Edge x _ _ = graph (encoding as) e + let Edge !x _ _ = graph (encoding as) e b <- Partition.blockOfState (partition as) x unlessM ((==1) <$> Partition.blockSize (partition as) b) $ do diff --git a/src/MA/Algorithm/Types.hs b/src/MA/Algorithm/Types.hs index 19132720b506a52d5173b4f4f99cb513e13a6bb8..f7c49bb496ec7d52a5e793f00511bb10e450db0f 100644 --- a/src/MA/Algorithm/Types.hs +++ b/src/MA/Algorithm/Types.hs @@ -24,12 +24,12 @@ import Data.RefinablePartition ( RefinablePartition ) import MA.Coalgebra.RefinementTypes data AlgoState s h = AlgoState - { toSub :: MVector s [EdgeRef] - , lastW :: MVector s (STRef s (Weight h)) - , encoding :: Encoding (Label h) (H1 h) - , pred :: Vector [EdgeRef] - , partition :: RefinablePartition s - , h3Cache :: MVector s (H3 h) + { toSub :: {-# UNPACK #-} (MVector s [EdgeRef]) + , lastW :: {-# UNPACK #-} (MVector s (STRef s (Weight h))) + , encoding :: {-# UNPACK #-} (Encoding (Label h) (H1 h)) + , pred :: {-# UNPACK #-} (Vector [EdgeRef]) + , partition :: {-# UNPACK #-} (RefinablePartition s) + , h3Cache :: {-# UNPACK #-} (MVector s (H3 h)) } makeLensesFor diff --git a/src/MA/Functors/Distribution.hs b/src/MA/Functors/Distribution.hs index 2f8800f2dc2700c47ef2243260394ecc5587d050..7dd44ddcd100c281dbf9933a632fb3d0c1ddda96 100644 --- a/src/MA/Functors/Distribution.hs +++ b/src/MA/Functors/Distribution.hs @@ -1,6 +1,8 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE BangPatterns #-} + module MA.Functors.Distribution ( distribution , Distribution(..) @@ -54,9 +56,11 @@ instance ParseMorphism Distribution where return (h1, succs) instance RefinementInterface Distribution where - init _ _ = GroupWeight 0 1 - update weightsToS (GroupWeight toRest toC) = + init _ _ = mkGroupWeight 0 1 + update weightsToS w = let + !toRest = gwToCompound w + !toC = gwToSub w toS = sum weightsToS toCwithoutS = toC - toS isOk x = x >= 0 && x <= 1 @@ -66,4 +70,4 @@ instance RefinementInterface Distribution where else mkRes (0, 0, 1) where mkRes (a, b, c) = - ( GroupWeight (a + b) c, GroupH3 a b c, GroupWeight (a + c) b ) + ( mkGroupWeight (a + b) c, mkGroupH3 a b c, mkGroupWeight (a + c) b ) diff --git a/src/MA/Functors/GroupValued.hs b/src/MA/Functors/GroupValued.hs index 794ca202e2f1664529a515f18ccbef3a4809477c..e878dae7f2e5bdf1f9465313437d4eda0d714361 100644 --- a/src/MA/Functors/GroupValued.hs +++ b/src/MA/Functors/GroupValued.hs @@ -14,8 +14,8 @@ module MA.Functors.GroupValued , realValued , complexValued , GroupValued(..) - , GroupWeight(..) - , GroupH3(..) + , IsGroupWeight(..) + , IsGroupH3(..) , OrderedComplex(..) ) where @@ -78,11 +78,83 @@ complexValued = FunctorDescription ((L.symbol "C" <|> L.symbol "ℂ") >> L.symbol "^" >> pure GroupValued) } -data GroupWeight m = GroupWeight !m !m - deriving (Eq, Ord, Show) +class IsGroupWeight m where + data GroupWeight m + gwToCompound :: GroupWeight m -> m + gwToSub :: GroupWeight m -> m + mkGroupWeight :: m -> m -> GroupWeight m + +instance IsGroupWeight Int where + data GroupWeight Int = IntGroupWeight {-# UNPACK #-} !Int {-# UNPACK #-} !Int + gwToCompound (IntGroupWeight x _) = x + gwToSub (IntGroupWeight _ x) = x + mkGroupWeight = IntGroupWeight + +instance IsGroupWeight EqDouble where + data GroupWeight EqDouble = EqDoubleGroupWeight {-# UNPACK #-} !EqDouble {-# UNPACK #-} !EqDouble + gwToCompound (EqDoubleGroupWeight x _) = x + gwToSub (EqDoubleGroupWeight _ x) = x + mkGroupWeight = EqDoubleGroupWeight + +instance IsGroupWeight OrderedComplex where + data GroupWeight OrderedComplex = + OrderedComplexGroupWeight {-# UNPACK #-} !OrderedComplex {-# UNPACK #-} !OrderedComplex + gwToCompound (OrderedComplexGroupWeight x _) = x + gwToSub (OrderedComplexGroupWeight _ x) = x + mkGroupWeight = OrderedComplexGroupWeight + + +class IsGroupH3 m where + data GroupH3 m + h3ToRest :: GroupH3 m -> m + h3ToCompound :: GroupH3 m -> m + h3ToSub :: GroupH3 m -> m + mkGroupH3 :: m -> m -> m -> GroupH3 m + + +instance IsGroupH3 Int where + data GroupH3 Int = IntGroupH3 {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Int + h3ToRest (IntGroupH3 x _ _) = x + h3ToCompound (IntGroupH3 _ x _) = x + h3ToSub (IntGroupH3 _ _ x) = x + mkGroupH3 = IntGroupH3 + +instance IsGroupH3 EqDouble where + data GroupH3 EqDouble = EqDoubleGroupH3 {-# UNPACK #-} !EqDouble {-# UNPACK #-} !EqDouble {-# UNPACK #-} !EqDouble + h3ToRest (EqDoubleGroupH3 x _ _) = x + h3ToCompound (EqDoubleGroupH3 _ x _) = x + h3ToSub (EqDoubleGroupH3 _ _ x) = x + mkGroupH3 = EqDoubleGroupH3 + +instance IsGroupH3 OrderedComplex where + data GroupH3 OrderedComplex = OrderedComplexGroupH3 {-# UNPACK #-} !OrderedComplex {-# UNPACK #-} !OrderedComplex {-# UNPACK #-} !OrderedComplex + h3ToRest (OrderedComplexGroupH3 x _ _) = x + h3ToCompound (OrderedComplexGroupH3 _ x _) = x + h3ToSub (OrderedComplexGroupH3 _ _ x) = x + mkGroupH3 = OrderedComplexGroupH3 + +instance (Eq a, IsGroupH3 a) => Eq (GroupH3 a) where + x == y = + (h3ToRest x == h3ToRest y) + && (h3ToCompound x == h3ToCompound y) + && (h3ToSub x == h3ToSub y) + +instance (Ord a, IsGroupH3 a) => Ord (GroupH3 a) where + compare x y = + compare (h3ToRest x) (h3ToRest y) + <> compare (h3ToCompound x) (h3ToCompound y) + <> compare (h3ToSub x) (h3ToSub y) + +instance (Show a, IsGroupH3 a) => Show (GroupH3 a) where + showsPrec p x = + showParen (p > 10) + $ showString "GroupH3 " + . showsPrec 11 (h3ToRest x) + . showChar ' ' + . showsPrec 11 (h3ToCompound x) + . showChar ' ' + . showsPrec 11 (h3ToSub x) -data GroupH3 m = GroupH3 !m !m !m - deriving (Eq, Ord, Show) type instance Label (GroupValued m) = m type instance Weight (GroupValued m) = GroupWeight m @@ -119,20 +191,25 @@ instance ParseMorphism (GroupValued OrderedComplex) where parseMorphismPoint (GroupValued inner) = parseMorphismPointHelper inner (OrderedComplex <$> L.complex L.adouble) -instance (Num m, Ord m) => RefinementInterface (GroupValued m) where +instance (IsGroupWeight m, IsGroupH3 m, Ord m, Num m) => RefinementInterface (GroupValued m) where + {-# SPECIALIZE instance RefinementInterface (GroupValued Int) #-} + {-# SPECIALIZE instance RefinementInterface (GroupValued EqDouble) #-} + {-# SPECIALIZE instance RefinementInterface (GroupValued OrderedComplex) #-} init :: H1 (GroupValued m) -> [Label (GroupValued m)] -> Weight (GroupValued m) - init _ weights = GroupWeight 0 (sum weights) + init _ weights = mkGroupWeight 0 (sum weights) update :: [Label (GroupValued m)] -> Weight (GroupValued m) -> (Weight (GroupValued m), H3 (GroupValued m), Weight (GroupValued m)) - update weightsToS (GroupWeight toRest toC) = + update weightsToS !w = let + !toRest = gwToCompound w + !toC = gwToSub w !toS = sum weightsToS !toCwithoutS = toC - toS !toNotS = toRest + toCwithoutS !toNotC = toRest + toS in - ( GroupWeight toNotS toS - , GroupH3 toRest toCwithoutS toS - , GroupWeight toNotC toCwithoutS + ( mkGroupWeight toNotS toS + , mkGroupH3 toRest toCwithoutS toS + , mkGroupWeight toNotC toCwithoutS ) diff --git a/src/MA/Functors/Powerset.hs b/src/MA/Functors/Powerset.hs index 63f738485c940198c8ef70805bd9ec4169c30d49..fdef686010147af4f5d215d81cb454ede98f164b 100644 --- a/src/MA/Functors/Powerset.hs +++ b/src/MA/Functors/Powerset.hs @@ -1,11 +1,18 @@ {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE BangPatterns #-} module MA.Functors.Powerset ( Powerset(..) , powerset + -- * For testing + , PowerWeight(..) + , PowerH3(..) + , mkPowerH3 ) where import Control.Monad (when) +import Data.Word (Word8) +import Data.Bits import Text.Megaparsec import qualified Data.Vector as V @@ -32,17 +39,29 @@ powerset = FunctorDescription prefix ((L.symbol "P" <|> L.symbol "Ƥ") >> pure Powerset) } +data PowerWeight = PowerWeight {-# UNPACK #-} !Int {-# UNPACK #-} !Int + deriving (Show,Eq) + +newtype PowerH3 = PowerH3 Word8 + deriving (Show, Ord, Eq) + +mkPowerH3 :: Bool -> Bool -> Bool -> PowerH3 +mkPowerH3 !a !b !c = PowerH3 $ set 0 a .|. set 1 b .|. set 2 c + where set !i True = bit i + set !_ False = 0 +{-# INLINE mkPowerH3 #-} + -- | No edge labels type instance Label Powerset = () -- | Tuple of (|edgesToC\S|, |edgesToS|) -type instance Weight Powerset = (Int, Int) +type instance Weight Powerset = PowerWeight -- | Does this state have at least one successor? type instance H1 Powerset = Bool -- | Tuple of: -- - do we have edges to the rest? -- - do we have edges to C\S? -- - do we have edges to S? -type instance H3 Powerset = (Bool, Bool, Bool) +type instance H3 Powerset = PowerH3 instance ParseMorphism Powerset where parseMorphismPoint (Powerset inner) = do @@ -58,16 +77,16 @@ instance ParseMorphism Powerset where instance RefinementInterface Powerset where init :: H1 Powerset -> [Label Powerset] -> Weight Powerset - init _ = (0, ) . length + init _ = PowerWeight 0 . length update :: [Label Powerset] -> Weight Powerset -> (Weight Powerset, H3 Powerset, Weight Powerset) - update labels (toRest, toC) = + update labels (PowerWeight toRest toC) = let toS = length labels toCwithoutS = toC - toS - weightToS = (toRest + toCwithoutS, toS) - h3 = (toRest > 0, toCwithoutS > 0, toS > 0) - weightToCwithoutS = (toRest + toS, toCwithoutS) + !weightToS = PowerWeight (toRest + toCwithoutS) toS + !h3 = mkPowerH3 (toRest > 0) (toCwithoutS > 0) (toS > 0) + !weightToCwithoutS = PowerWeight (toRest + toS) toCwithoutS in (weightToS, h3, weightToCwithoutS) diff --git a/tests/MA/Algorithm/InitializeSpec.hs b/tests/MA/Algorithm/InitializeSpec.hs index 7806022602999bb5091db72b48f9d7ef1502c21f..c6001954b872aaab6b982bed38c145e535725ced 100644 --- a/tests/MA/Algorithm/InitializeSpec.hs +++ b/tests/MA/Algorithm/InitializeSpec.hs @@ -68,7 +68,8 @@ lastWSpec = describe "returned lastW vector" $ do getLastW [False, False] [] `shouldBe` [] it "works with some edges" $ - getLastW [True, True] [(0, 1), (0, 0), (1, 0)] `shouldBe` [(0, 2), (0, 2), (0, 1)] + getLastW [True, True] [(0, 1), (0, 0), (1, 0)] + `shouldBe` [PowerWeight 0 2, PowerWeight 0 2, PowerWeight 0 1] enc :: [h1] -> [(State, label, State)] -> Encoding label h1 diff --git a/tests/MA/Algorithm/SplitSpec.hs b/tests/MA/Algorithm/SplitSpec.hs index 4aa2e5d9a353abdc3fcd5ad002234e5b0146046e..1c0a7b58752300c679c40d1a4739be8f8c9a6837 100644 --- a/tests/MA/Algorithm/SplitSpec.hs +++ b/tests/MA/Algorithm/SplitSpec.hs @@ -61,14 +61,14 @@ collectTouchedBlocksSpec = describe "collectTouchedBlocks" $ do -- initially, C contains all blocks, so calling update with empty list of -- edges to S, the result is: No edges to "outside of C", all edges to C -- and no edges to S. - `shouldBe` [(False, True, False)] + `shouldBe` [mkPowerH3 False True False] it "marks the correct states" $ do let res = withState @Powerset (enc [True, True, False] [(0, (), 2), (1, (), 2)]) $ do [(b, _)] <- collectTouchedBlocks (Block 1) - p <- view (_1 . partitionL) + p <- view (_1 . partitionL) VU.toList <$> lift (Partition.markedStates p b) res `shouldMatchList` [0, 1] @@ -126,13 +126,18 @@ updateBlockSpec = describe "updateBlock" $ do updateBlock b v0 lw <- view (_1 . lastWL) >>= lift . V.freeze lift (lw & V.toList & mapM readSTRef) - in res `shouldBe` [(2, 1), (1, 2), (1, 2), (0, 1)] + in res + `shouldBe` [ PowerWeight 2 1 + , PowerWeight 1 2 + , PowerWeight 1 2 + , PowerWeight 0 1 + ] -- The idea here is that the edges from state 0 to block 1 cancel each other -- out and thus the state has a total weight of 0 and must be unmarked. it "unmarks states where H3 is v0" $ let res = - withState @(GroupValued Integer) + withState @(GroupValued Int) (enc [1, 1, 0, 0] [(0, 1, 2), (0, (-1), 3), (1, 1, 3), (0, 1, 1)] ) @@ -145,7 +150,7 @@ updateBlockSpec = describe "updateBlock" $ do it "caches H3 values for all non-v0 states" $ let res = - withState @(GroupValued Integer) + withState @(GroupValued Int) (enc [1, 1, 0, 0] [(0, 1, 2), (0, (-1), 3), (1, 1, 3), (0, 1, 1)] ) @@ -154,7 +159,7 @@ updateBlockSpec = describe "updateBlock" $ do updateBlock b v0 h3 <- view (_1 . h3CacheL) >>= lift . V.freeze return (h3 V.! 1) - in res `shouldBe` (GroupH3 0 0 1) + in res `shouldBe` (mkGroupH3 0 0 1) splitBlockSpec :: Spec @@ -170,29 +175,27 @@ splitBlockSpec = describe "splitBlock" $ do in res `shouldMatchList` [Block 0, Block 2] it "splits different H3s into different blocks" - $ let - res = - withState @(GroupValued Int) - (enc [3, 3, 3, 0] - [(0, 1, 3), (1, 2, 3), (2, 3, 3), (0, 2, 0), (1, 1, 1)] - ) - $ do - [(b, v0)] <- collectTouchedBlocks (Block 1) - updateBlock b v0 - splitBlock b + $ let res = + withState @(GroupValued Int) + (enc [3, 3, 3, 0] + [(0, 1, 3), (1, 2, 3), (2, 3, 3), (0, 2, 0), (1, 1, 1)] + ) + $ do + [(b, v0)] <- collectTouchedBlocks (Block 1) + updateBlock b v0 + splitBlock b in res `shouldMatchList` [Block 0, Block 2, Block 3] it "combines equal H3s into the same block" - $ let - res = - withState @(GroupValued Int) - (enc [3, 3, 3, 0] - [(0, 1, 3), (1, 1, 3), (2, 3, 3), (0, 2, 0), (1, 2, 1)] - ) - $ do - [(b, v0)] <- collectTouchedBlocks (Block 1) - updateBlock b v0 - splitBlock b + $ let res = + withState @(GroupValued Int) + (enc [3, 3, 3, 0] + [(0, 1, 3), (1, 1, 3), (2, 3, 3), (0, 2, 0), (1, 2, 1)] + ) + $ do + [(b, v0)] <- collectTouchedBlocks (Block 1) + updateBlock b v0 + splitBlock b in res `shouldMatchList` [Block 0, Block 2]