Utils.hs 4.43 KB
Newer Older
1
-- | Additional functions for unboxed mutable vectors
2
module Data.Vector.Unboxed.Mutable.Utils
3
4
  ( -- * Partitioning
    partition
5
  , partitionM
6
  -- * Grouping
7
  , groupBy
8
  , groupByM
9
10
11
  ) where

import           Control.Monad (foldM)
12
import qualified Control.Monad.ST.Strict
13

14
import           Control.Monad.Primitive
15
16
import qualified Data.Vector.Unboxed.Mutable as VU

17
18
19
20
21
22
23
24
25
26
27
28
29
30
-- | Partition a region inside a vector according to some predicate.
--
-- Moves all elements for which the predicate is 'True' to the start and all
-- other elements to the end of the region. Returns the index of the first
-- element for which the predicate is 'False' (the end of the region if every
-- element satisfies the predicate). The relative order of the elements is not
-- preserved.
--
-- The region is scanned from the left for an element failing the predicate,
-- then from the right for an element satisfying it, and the two are swapped.
-- The predicate is therefore applied to each element at most once.
--
-- ==== __Example__
--
-- >>> import qualified Data.Vector.Unboxed as V
-- >>> vec' <- V.thaw (V.fromList [3, 2, 1] :: V.Vector Int)
-- >>> s <- partition vec' (<=2) 0 3
-- >>> vec <- V.freeze vec'
-- >>> V.splitAt s vec
-- ([1,2],[3])
partition :: (VU.Unbox a, PrimMonad m)
          => VU.MVector (PrimState m) a -- ^ Vector to modify
          -> (a -> Bool) -- ^ The predicate
          -> Int -- ^ The beginning of the region to partition (inclusive)
          -> Int -- ^ The end of the region to partition (exclusive)
          -> m Int
partition vec predicate = fromLeft
  where
    -- Skip over elements at the front that already satisfy the predicate.
    fromLeft lower upper
      | lower >= upper = return lower
      | otherwise = do
          x <- VU.read vec lower
          if predicate x
            then fromLeft (lower + 1) upper
            else fromRight lower upper

    -- Invariant: lower < upper and vec[lower] fails the predicate. Scan
    -- leftwards from upper-1 for an element that satisfies it, then swap.
    fromRight lower upper
      | upper - 1 <= lower = return lower -- only the failing vec[lower] is left
      | otherwise = do
          y <- VU.read vec (upper - 1)
          if predicate y
            -- After the swap both ends are in their final place.
            then VU.swap vec lower (upper - 1) >> fromLeft (lower + 1) (upper - 1)
            else fromRight lower (upper - 1)
{-# SPECIALIZE INLINE partition :: VU.MVector s Int -> (Int -> Bool) -> Int -> Int -> Control.Monad.ST.Strict.ST s Int #-}
52

53
-- | 'partition' with a monadic predicate.
--
-- Moves all elements for which the predicate returns 'True' to the start and
-- all other elements to the end of the region. Returns the index of the first
-- element for which the predicate returned 'False' (the end of the region if
-- every element satisfies the predicate). The relative order of the elements
-- is not preserved.
--
-- The predicate is run on each element of the region at most once, so its
-- effects happen at most once per element.
partitionM :: (VU.Unbox a, PrimMonad m)
           => VU.MVector (PrimState m) a -- ^ Vector to modify
           -> (a -> m Bool) -- ^ The predicate
           -> Int -- ^ The beginning of the region to partition (inclusive)
           -> Int -- ^ The end of the region to partition (exclusive)
           -> m Int
-- The recursion lives entirely in local workers. partitionM is then
-- non-recursive, so its INLINE pragma can take effect.
partitionM vec predicate = fromLeft
  where
    -- Skip over elements at the front that already satisfy the predicate.
    fromLeft lower upper
      | lower >= upper = return lower
      | otherwise = do
          ok <- predicate =<< VU.read vec lower
          if ok
            then fromLeft (lower + 1) upper
            else fromRight lower upper

    -- Invariant: lower < upper and vec[lower] fails the predicate. Scan
    -- leftwards from upper-1 for an element that satisfies it, then swap.
    fromRight lower upper
      | upper - 1 <= lower = return lower -- only the failing vec[lower] is left
      | otherwise = do
          ok <- predicate =<< VU.read vec (upper - 1)
          if ok
            -- After the swap both ends are in their final place.
            then VU.swap vec lower (upper - 1) >> fromLeft (lower + 1) (upper - 1)
            else fromRight lower (upper - 1)
{-# INLINE partitionM #-}
75
76


77
78
79
80
81
82
83
84
85
86
87
88
89
90
-- | Group a region inside a sorted vector according to a key function.
--
-- Given a vector that is already sorted with respect to the key function @p@,
-- this returns the indices at which a new block of consecutive equal elements
-- (according to @p@) begins. In other words, it returns, in ascending order,
-- every index @i@ with @lower < i < upper@ such that @p v[i-1]@ is different
-- from @p v[i]@. The start of the region itself is never included.
--
-- The vector is not modified. Only neighbouring elements are compared, so
-- the vector should already be sorted by @p@. Otherwise equal keys in
-- separate runs end up in separate blocks.
--
-- ==== __Example__
--
-- >>> import Data.Vector.Unboxed (thaw,fromList,Vector)
-- >>> v <- thaw (fromList [1,1,2,2,2,3,3] :: Vector Int)
-- >>> groupBy v id 0 7
-- [2,5]
groupBy :: forall a b m. (Eq b, VU.Unbox a, PrimMonad m)
        => VU.MVector (PrimState m) a -- ^ The vector to group
        -> (a -> b) -- ^ The key function elements are grouped by
        -> Int -- ^ The beginning of the region (inclusive)
        -> Int -- ^ The end of the region (exclusive)
        -> m [Int]
groupBy vec predicate = groupByM vec (return . predicate)
{-# INLINE groupBy #-}

-- | 'groupBy' with a monadic key function.
--
-- Returns, in ascending order, every index @i@ with @lower < i < upper@ such
-- that the key of @v[i]@ differs from the key of @v[i-1]@. The vector is not
-- modified.
--
-- When the region has at least two elements, the key function runs exactly
-- once per element. A region with fewer than two elements contains no block
-- boundary, so @[]@ is returned without running the key function at all.
groupByM :: forall a b m. (Eq b, VU.Unbox a, PrimMonad m)
         => VU.MVector (PrimState m) a -- ^ The vector to group
         -> (a -> m b) -- ^ The key function elements are grouped by
         -> Int -- ^ The beginning of the region (inclusive)
         -> Int -- ^ The end of the region (exclusive)
         -> m [Int]
groupByM vec predicate lower upper
  | upper - lower < 2 = return [] -- empty or single-element region
  | otherwise = do
      firstKey <- predicate =<< VU.read vec lower
      reverse . snd <$> foldM step (firstKey, []) [lower + 1 .. upper - 1]
  where
    -- Carry the previous element's key (not the element itself), so every
    -- key is computed only once. Prepend i when a new block starts there.
    step :: (b, [Int]) -> Int -> m (b, [Int])
    step (prevKey, starts) i = do
      key <- predicate =<< VU.read vec i
      -- '$!' evaluates the comparison now instead of building a thunk chain.
      return $! if key == prevKey then (key, starts) else (key, i : starts)
{-# INLINE groupByM #-}