RefinablePartition.hs 15.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE MagicHash #-}

module Data.RefinablePartition
  ( -- * Types
    RefinablePartition
10
  , State
11
  , Block(..)
12
13
14
15
16
17
18
19
20
  -- * Construction
  , make
  -- * Accessors
  , numBlocks
  , blockSize
  , blockOfState
  , statesOfBlock
  -- * Marking
  , mark
21
  , unmark
22
23
  , isMarked
  , hasMarked
24
  , markedStates
25
26
27
  -- * Splitting
  , splitMarked
  , splitBy
28
  , splitByM
29
  , groupBy
30
31
  -- * Conversion
  , freeze
32
  , unsafeStatesOfBlock
33
34
  ) where

35
import           Control.Monad (forM_, when, foldM)
36
37
38
39
40
41
42
43
44
import           Control.Monad.ST
import           Data.Ord (comparing)

import           Data.Primitive.MutVar
import qualified Data.Vector as V
import qualified Data.Vector.Algorithms.Heap as VM
import           Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed.Mutable as VU
45
import qualified Data.Vector.Unboxed as VU (convert, freeze, unsafeFreeze, Vector)
46
47
48
import           Lens.Micro
import           Lens.Micro.TH

Hans-Peter Deifel's avatar
Hans-Peter Deifel committed
49
import           Data.Vector.Utils (iforM_)
50
import qualified Data.Vector.Unboxed.Mutable.Utils as VU
51
import           Data.Partition.Common (State,Block(..))
52
53
import           Data.Partition (Partition)
import qualified Data.Partition as Partition
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167

-- Per-state bookkeeping record. One of these is stored for every state in the
-- '_states' vector of a 'RefinablePartition', indexed by the state itself.
data StateRepr = StateRepr
  { _block :: {-# UNPACK #-} Block
    -- Block that this state currently belongs to.
  , _location :: {-# UNPACK #-} Int
    -- Index of this state inside the '_statesByBlock' array.
  } deriving (Show)

-- Generates the lenses 'block' and 'location' used throughout this module.
makeLenses ''StateRepr

-- Per-block bookkeeping record. A block owns the contiguous index range
-- [_startOffset, _endOffset) of the '_statesByBlock' array; the marked states
-- of the block occupy the prefix [_startOffset, _unmarkedOffset) of that range.
data BlockRepr = BlockRepr
  { _startOffset :: {-# UNPACK #-} Int
    -- Index of the first state of the block.
  , _endOffset :: {-# UNPACK #-} Int -- exclusive
    -- Index one past the last state of the block.
  , _unmarkedOffset :: {-# UNPACK #-} Int
    -- Index of the first unmarked state; equal to '_startOffset' when nothing
    -- is marked and to '_endOffset' when everything is marked.
  } deriving (Show)

-- Generates the lenses 'startOffset', 'endOffset' and 'unmarkedOffset'.
makeLenses ''BlockRepr

-- | Refinable partition type.
--
-- Should be thought of as a partition {A₁,..} of the set of 'State' values
-- 0..n-1, where n is the number of states given to 'make'. The sets @Aᵢ@ are
-- called blocks.
--
-- This type is by nature mutable and can be mutated with the operations in this
-- module. The @s@ type variable is there to support the ST monad.
data RefinablePartition s = Partition
  { _blockCount :: MutVar s Int
    -- Number of blocks currently in use; incremented whenever a new block is
    -- allocated by a split.
  , _statesByBlock :: VU.MVector s State
    -- All states, arranged so that each block occupies one contiguous range
    -- (see 'BlockRepr').
  , _states :: MVector s StateRepr
    -- Bookkeeping record of every state, indexed by the state.
  , _blocks :: MVector s BlockRepr
    -- Bookkeeping record of every block, indexed by the block number. 'make'
    -- allocates one slot per state so that splitting has room for new blocks.
  }

-- Generates the lenses 'blockCount', 'statesByBlock', 'states' and 'blocks'.
makeLenses ''RefinablePartition

-- | Build a fresh mutable refinable partition from an initial assignment of
-- states to blocks.
--
-- Calls 'error' if there are more blocks than states or if the number of
-- blocks is zero.
--
-- No state is marked in the resulting partition.
make :: Int -- ^ Number of states n
     -> Int -- ^ Number of initial blocks m
     -> (State -> Block)
     -- ^ Initial partition assigning a block to each state. It is applied to
     -- every state in 0..n-1 and must only return blocks in 0..m-1.
     -> ST s (RefinablePartition s)
make numStates numBlocks initPart
  | numStates < numBlocks = error "RefinablePartition.make: More blocks than states"
  | numBlocks == 0 = error "RefinablePartition.make: zero blocks"
  | otherwise = do
      byBlock <- VU.new numStates
      stateTable <- VM.new numStates
      -- One slot per state: splitting can create new blocks, but never more
      -- blocks than there are states.
      blockTable <- VM.new numStates
      counter <- newMutVar numBlocks

      -- First pass: collect the member states of every initial block.
      buckets <- VM.replicate numBlocks []
      forM_ [0 .. numStates - 1] $ \st ->
        case initPart st of
          Block blk -> VM.modify buckets (st :) blk
      frozenBuckets <- V.unsafeFreeze buckets

      -- Second pass: lay the blocks out one after another in 'byBlock'.
      nextFree <- newMutVar 0
      iforM_ frozenBuckets $ \blk members -> do
        first <- readMutVar nextFree
        let past = first + length members

        forM_ (zip [first ..] members) $ \(pos, st) -> do
          VU.write byBlock pos st
          VM.write stateTable st (StateRepr (Block blk) pos)

        writeMutVar nextFree past
        -- unmarked offset == start offset: nothing is marked initially
        VM.write blockTable blk (BlockRepr first past first)

      return (Partition counter byBlock stateTable blockTable)

-- | Number of blocks the partition currently consists of.
--
-- Runtime: O(1)
numBlocks :: RefinablePartition s -> ST s Int
numBlocks = readMutVar . _blockCount

-- | Number of states contained in the given block.
--
-- Runtime: O(1)
blockSize :: RefinablePartition s -> Block -> ST s Int
blockSize p b = do
  BlockRepr{..} <- getBlock p b
  return (_endOffset - _startOffset)

-- | Look up the block that currently contains the given state.
--
-- Runtime: O(1)
blockOfState :: RefinablePartition s -> State -> ST s Block
blockOfState p s = (^. block) <$> getState p s

-- | Record this state as "marked".
--
-- The marked states of a block are kept in a contiguous prefix of the block's
-- range, which is what allows 'splitMarked' to split a block into its marked
-- and unmarked states efficiently.
--
-- Marking a state that is already marked is a no-op.
--
-- Runtime: O(1)
mark :: RefinablePartition s -> State -> ST s ()
mark partition s = do
  StateRepr {..} <- getState partition s
  BlockRepr {..} <- getBlock partition _block

  -- Only unmarked states (location at or behind the unmarked offset) may grow
  -- the marked prefix. Without this guard, marking a state twice would bump
  -- the offset a second time and silently mark some unrelated state.
  when (_location >= _unmarkedOffset) $ do
    -- Move the state to the first unmarked slot, unless it is already there.
    when (_location > _unmarkedOffset) $
      swap partition _location _unmarkedOffset

    -- Extend the marked prefix by that one slot.
    setBlock partition _block $
      unmarkedOffset %~ (+1)

  where
    -- Swap two indices in the statesByBlock array and keep the location
    -- fields of the two affected states in sync.
    swap :: RefinablePartition s -> Int -> Int -> ST s ()
    swap p a b = do
      setStateAt p a $ location .~ b
      setStateAt p b $ location .~ a

      VU.swap (p^.statesByBlock) a b

-- | Revert the marking of a state.
--
-- See 'mark'. Unmarking a state that is not marked is a no-op.
--
-- Runtime: O(1)
unmark :: RefinablePartition s -> State -> ST s ()
unmark partition s = do
  StateRepr {..} <- getState partition s
  BlockRepr {..} <- getBlock partition _block

  -- A state is marked iff its location lies strictly before the unmarked
  -- offset (see 'isMarked'). Only then is there anything to revert; acting on
  -- an unmarked state would swap with a slot that may not even belong to this
  -- block and shrink the marked prefix by an unrelated state.
  when (_location < _unmarkedOffset) $ do
    -- Move the state to the last marked slot ...
    swap partition _location (_unmarkedOffset - 1)

    -- ... and drop that slot from the marked prefix.
    setBlock partition _block $
      unmarkedOffset %~ subtract 1

  where
    -- Swap two indices in the statesByBlock array and keep the location
    -- fields of the two affected states in sync.
    swap :: RefinablePartition s -> Int -> Int -> ST s ()
    swap p a b = do
      setStateAt p a $ location .~ b
      setStateAt p b $ location .~ a

      VU.swap (p^.statesByBlock) a b

-- | Check whether at least one state of the block is marked.
--
-- Runtime: O(1)
hasMarked :: RefinablePartition s -> Block -> ST s Bool
hasMarked p b = do
  BlockRepr{..} <- getBlock p b
  -- the marked prefix is non-empty iff the two offsets differ
  return (_startOffset /= _unmarkedOffset)

-- | Decide whether a state is marked.
--
-- A state is marked iff it is located before the unmarked offset of its block.
--
-- Runtime: O(1)
isMarked :: RefinablePartition s -> State -> ST s Bool
isMarked partition s = do
  StateRepr{..} <- getState partition s
  BlockRepr{..} <- getBlock partition _block
  return (_location < _unmarkedOffset)

228
229
230
-- | Return a copy of the marked states of a block.
--
-- Runtime: O(n) for n == number of marked states in this block
markedStates :: RefinablePartition s -> Block -> ST s (VU.Vector State)
markedStates partition b = do
  BlockRepr{..} <- getBlock partition b
  -- the marked states form the prefix [_startOffset, _unmarkedOffset)
  let markedCount = _unmarkedOffset - _startOffset
      markedRegion = VU.slice _startOffset markedCount (_statesByBlock partition)
  VU.freeze markedRegion
238

239
-- | Return a copy of all states in a given block.
--
-- The returned vector is independent of the partition, so later mutations of
-- the partition do not affect it.
--
-- Runtime: O(n) for n == number of states in this block
statesOfBlock :: RefinablePartition s -> Block -> ST s (VU.Vector State)
statesOfBlock partition b = do
  BlockRepr{..} <- getBlock partition b
  VU.freeze $
    VU.slice _startOffset (_endOffset - _startOffset) (partition^.statesByBlock)

-- | Return a vector of all states in a given block without copying.
--
-- Unlike 'statesOfBlock', the result is produced with 'VU.unsafeFreeze' and
-- therefore shares its memory with the partition. It must not be used anymore
-- once the partition is mutated (e.g. by marking or splitting), as such
-- mutations would be visible through the supposedly immutable vector.
unsafeStatesOfBlock :: RefinablePartition s -> Block -> ST s (VU.Vector State)
unsafeStatesOfBlock partition b = do
  BlockRepr{..} <- getBlock partition b
  VU.unsafeFreeze $
    VU.slice _startOffset (_endOffset - _startOffset) (partition^.statesByBlock)

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309

-- | Split a block into two blocks: one for its marked and one for its unmarked
-- states.
--
-- Returns a tuple of (marked block, unmarked block). If there are no marked or
-- no unmarked states, the respective component is 'Nothing' and the block is
-- not split.
--
-- The larger of the two parts inherits the identity (i.e. block number) of the
-- old block; on a tie, the unmarked part keeps it. Afterwards, none of the
-- involved blocks has marked states.
--
-- Runtime: O(min(number of marked states, number of unmarked states))
splitMarked :: RefinablePartition s -> Block -> ST s (Maybe Block, Maybe Block)
splitMarked partition b = do
  BlockRepr{..} <- getBlock partition b

  let markedCount   = _unmarkedOffset - _startOffset
      unmarkedCount = _endOffset - _unmarkedOffset

      result
        -- nothing marked: leave the block alone
        | markedCount == 0 = return (Nothing, Just b)

        -- nothing unmarked: only forget the marks
        | unmarkedCount == 0 = do
            setBlock partition b (unmarkedOffset .~ _startOffset)
            return (Just b, Nothing)

        -- the marked part is the smaller one and moves to a new block, while
        -- the old block now begins at the first unmarked state
        | markedCount <= unmarkedCount = do
            fresh <- newBlock partition _startOffset _unmarkedOffset
            setBlock partition b (startOffset .~ _unmarkedOffset)
            return (Just fresh, Just b)

        -- the unmarked part is the smaller one and moves to a new block, while
        -- the old block now ends after the last marked state and forgets
        -- its marks
        | otherwise = do
            fresh <- newBlock partition _unmarkedOffset _endOffset
            setBlock partition b $
              (unmarkedOffset .~ _startOffset) . (endOffset .~ _unmarkedOffset)
            return (Just b, Just fresh)

  result

-- | Split a block into two new blocks according to a given predicate.
--
-- Returns a tuple @(a, b)@ where @a@ is the block with all states for which the
-- predicate returns True and @b@ contains the rest of the states. If one of the
-- two would be empty, it is 'Nothing' and the block is not split.
--
-- The largest block in @[a,b]@ shares the identity (i.e. the block number) with
-- the original block; on a tie, @b@ keeps it.
--
-- NOTE(review): the states of the block are reordered without regard to their
-- marks, so this is presumably only meant for blocks without marked states.
-- When the block is split, both resulting blocks have no marked states.
--
-- Runtime: O(number of states in the old block)
splitBy :: RefinablePartition s -> Block -> (State -> Bool) -> ST s (Maybe Block, Maybe Block)
splitBy !partition !b !predicate = do
  block@BlockRepr{..} <- getBlock partition b

  -- afterwards, the matching states are in [_startOffset, splitPoint)
  splitPoint <- VU.partition (partition^.statesByBlock) predicate _startOffset _endOffset

  if splitPoint == _startOffset then -- no matching states
    return (Nothing, Just b)
  else if splitPoint == _endOffset then -- all states match
    return (Just b, Nothing)
  else do
    -- update location for all states, because split moves them
    updateLocations partition block

    let beforeSplitNum = splitPoint - _startOffset
        afterSplitNum = _endOffset - splitPoint

    if beforeSplitNum <= afterSplitNum then do
      -- the old block keeps the (larger) non-matching part
      setBlock partition b $
        (unmarkedOffset .~ splitPoint) . (startOffset .~ splitPoint)
      new <- newBlock partition _startOffset splitPoint
      return (Just new, Just b)
    else do
      -- The old block keeps the (larger) matching part. Reset the unmarked
      -- offset here as well: a stale value could otherwise lie behind the new
      -- end of the block.
      setBlock partition b $
        (unmarkedOffset .~ _startOffset) . (endOffset .~ splitPoint)
      new <- newBlock partition splitPoint _endOffset
      return (Just b, Just new)

-- | Monadic version of 'splitBy'
--
-- This is the same as 'splitBy', but accepts a monadic predicate. This
-- predicate can be called any number of times and should not have side effects,
-- but can be used to read from mutable vectors and the like.
--
-- Returns a tuple @(a, b)@ where @a@ is the block with all states for which the
-- predicate returns True and @b@ contains the rest of the states. If one of the
-- two would be empty, it is 'Nothing' and the block is not split. The larger
-- part keeps the block number of the original block; on a tie, @b@ keeps it.
--
-- NOTE(review): the states of the block are reordered without regard to their
-- marks, so this is presumably only meant for blocks without marked states.
-- When the block is split, both resulting blocks have no marked states.
--
-- Runtime: O(number of states in the old block)
splitByM :: RefinablePartition s -> Block -> (State -> ST s Bool)
         -> ST s (Maybe Block, Maybe Block)
splitByM !partition !b !predicate = do
  -- FIXME Remove duplication with splitBy
  block@BlockRepr{..} <- getBlock partition b

  -- afterwards, the matching states are in [_startOffset, splitPoint)
  splitPoint <- VU.partitionM (partition^.statesByBlock) predicate _startOffset _endOffset

  if splitPoint == _startOffset then -- no matching states
    return (Nothing, Just b)
  else if splitPoint == _endOffset then -- all states match
    return (Just b, Nothing)
  else do
    -- update location for all states, because split moves them
    updateLocations partition block

    let beforeSplitNum = splitPoint - _startOffset
        afterSplitNum = _endOffset - splitPoint

    if beforeSplitNum <= afterSplitNum then do
      -- the old block keeps the (larger) non-matching part
      setBlock partition b $
        (unmarkedOffset .~ splitPoint) . (startOffset .~ splitPoint)
      new <- newBlock partition _startOffset splitPoint
      return (Just new, Just b)
    else do
      -- The old block keeps the (larger) matching part. Reset the unmarked
      -- offset here as well: a stale value could otherwise lie behind the new
      -- end of the block.
      setBlock partition b $
        (unmarkedOffset .~ _startOffset) . (endOffset .~ splitPoint)
      new <- newBlock partition splitPoint _endOffset
      return (Just b, Just new)

-- | Split a block into new blocks according to some attribute of its states.
--
-- The result is a maximally coarse list of blocks, such that all states in a
-- new block have the same value for the given attribute. The blocks are
-- returned in the order in which they are laid out after sorting the states by
-- the attribute.
--
-- One of the blocks inherits the identity of the old block.
--
-- Runtime: O(n*log(n)) where n is the number of states in the old block.
groupBy :: Ord a => RefinablePartition s -> Block -> (State -> a) -> ST s [Block]
groupBy partition b predicate = do
  block <- getBlock partition b

  let start = block^.startOffset
      end   = block^.endOffset

  -- Sort the states of the block by the attribute, so that states with equal
  -- attribute become neighbours ...
  VM.sortBy (comparing predicate) $
    VU.slice start (end-start) (partition^.statesByBlock)

  -- ... and repair the location fields, since sorting moved states around.
  updateLocations partition block

  -- positions at which a new group begins
  indices <- VU.groupBy (partition^.statesByBlock) predicate start end

  -- Cut the remaining block at the given index by pretending that everything
  -- before the index is marked. The finished blocks are collected in reverse
  -- order to keep each step O(1).
  let cutAt (currentBlock, finished) index = do
        setBlock partition currentBlock $ unmarkedOffset .~ index
        parts <- splitMarked partition currentBlock
        case parts of
          (Just previousBlock, Just nextBlock) ->
            return (nextBlock, previousBlock : finished)
          _ ->
            -- splitMarked only declines to split if the index is not strictly
            -- inside the block, which would be a bug in the group boundaries
            error "RefinablePartition.groupBy: group boundary not strictly inside block"

  (lastBlock, finished) <- foldM cutAt (b, []) indices

  return (reverse (lastBlock : finished))

405
406
407
408
409
410
411
412
413
414
415
-- | Freeze the current refinable partition into an immutable one.
--
-- The immutable 'Partition' is built with 'Partition.fromFunctionM' from the
-- current number of blocks, the number of states and 'blockOfState' as the
-- lookup function. The refinable partition itself is not modified.
--
-- Runtime: O(number of states)
freeze :: RefinablePartition s -> ST s Partition
freeze partition =
  numBlocks partition >>= \blockTotal ->
    Partition.fromFunctionM blockTotal stateTotal (blockOfState partition)
  where
    stateTotal = VU.length (_statesByBlock partition)

416
417
-- helpers
-- Read the bookkeeping record of a block. The index is not bounds-checked.
getBlock :: RefinablePartition s -> Block -> ST s BlockRepr
getBlock !partition (Block ix) = VM.unsafeRead (partition ^. blocks) ix
{-# INLINE getBlock #-}
420
421
422

-- Apply a function to the bookkeeping record of a block. The index is not
-- bounds-checked.
setBlock :: RefinablePartition s -> Block -> (BlockRepr -> BlockRepr) -> ST s ()
setBlock partition (Block ix) update =
  VM.unsafeModify (partition ^. blocks) update ix
{-# INLINE setBlock #-}
424
425
426

-- Read the bookkeeping record of a state. The index is not bounds-checked.
getState :: RefinablePartition s -> State -> ST s StateRepr
getState partition st = VM.unsafeRead (_states partition) st
{-# INLINE getState #-}
428
429
430

-- Apply a function to the bookkeeping record of a state.
setState :: RefinablePartition s -> State -> (StateRepr -> StateRepr) -> ST s ()
setState partition st update = VM.modify (_states partition) update st
{-# INLINE setState #-}
432
433
434
435

-- Apply a function to the bookkeeping record of whichever state is currently
-- stored at the given index of the statesByBlock array.
setStateAt :: RefinablePartition s -> Int -> (StateRepr -> StateRepr) -> ST s ()
setStateAt partition loc setter = do
  occupant <- VU.read (_statesByBlock partition) loc
  setState partition occupant setter
{-# INLINE setStateAt #-}
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464

-- Create a new block that covers the index range [beginning, end) of the
-- statesByBlock array and make all states in that range members of it. The new
-- block has no marked states.
--
-- Runtime: O(end - beginning)
newBlock :: RefinablePartition s -> Int -> Int -> ST s Block
newBlock partition beginning end = do
  fresh <- allocateBlock partition

  let repr = BlockRepr { _startOffset = beginning
                       , _endOffset = end
                       , _unmarkedOffset = beginning -- nothing marked
                       }
  setBlock partition fresh (const repr)

  -- let every state in the range point to its new block
  mapM_ (\pos -> setStateAt partition pos (block .~ fresh)) (blockIndices repr)

  return fresh

-- Reserve the next unused block number by bumping the block counter. The
-- bookkeeping record of the returned block is not initialized here.
allocateBlock :: RefinablePartition s -> ST s Block
allocateBlock partition = do
  let counter = partition ^. blockCount
  fresh <- readMutVar counter
  writeMutVar counter (fresh + 1)
  return (Block fresh)

-- Recompute the location field of every state in the range of the given block
-- from the state's actual position in the statesByBlock array.
updateLocations :: RefinablePartition s -> BlockRepr -> ST s ()
updateLocations partition block = mapM_ relocate (blockIndices block)
  where
    relocate pos = setStateAt partition pos (location .~ pos)

-- All indices of the statesByBlock array that belong to the given block, in
-- ascending order.
blockIndices :: BlockRepr -> [Int]
blockIndices BlockRepr{..} = [_startOffset .. _endOffset - 1]