SumBag.hs 5.23 KB
Newer Older
1
2
3
4
5
6
7
8
9
{-# LANGUAGE RoleAnnotations #-}

-- | A bag (multiset) that keeps the 'Monoid' combination of all of its
-- elements cached, so that 'sum' does not need to traverse the bag.
--
-- The bag is implemented as a binary search tree with size-based
-- rebalancing. Elements that compare equal (by 'Ord') share a single node
-- which records how many copies of them are present.
module Data.SumBag
  ( SumBag
  , empty
  , singleton
  , sum
  , insert
  , delete
  , elem
  , toAscList
  , fromList
  ) where

15
import Prelude hiding (sum, min, elem)
Hans-Peter Deifel's avatar
Hans-Peter Deifel committed
16
import Data.Foldable hiding (sum,elem)
17
18
19
20
21
import qualified Data.List.NonEmpty as NE

-- | A multiset of elements of type @a@ that caches the 'Monoid' combination
-- of its contents. Only the type is exported, so the tree representation
-- stays abstract to users of this module.
type SumBag = Tree

-- | Internal representation of a bag: a binary search tree ordered by
-- element value, with one node per distinct element.
data Tree a
  -- | The empty tree.
  = Leaf
  -- | Cached subtree metadata, the node's element (with all its copies),
  -- the subtree of smaller values and the subtree of larger values.
  | Node (MetaData a) (Element a) (Tree a) (Tree a)
  deriving (Show)

-- The ordering and the cached sums depend on the element type's 'Ord' and
-- 'Monoid' instances, so coercing a tree to a newtype with different
-- instances must be rejected.
type role Tree nominal

26
27
-- | Two bags are equal when their ascending element lists (each element
-- repeated once per copy) are equal. The internal tree shape is irrelevant.
--
-- Only 'Eq' is required: comparing the lists never needs 'Ord'.
instance Eq a => Eq (Tree a) where
  x == y = toAscList x == toAscList y
28

Hans-Peter Deifel's avatar
Hans-Peter Deifel committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
-- | Folds visit the elements in ascending order, with each element repeated
-- once per copy ('toAscList' order). 'fold' reuses the cached 'sum'.
--
-- 'null', 'length', 'minimum' and 'maximum' work directly on the tree:
--
-- * 'null' only inspects the root.
-- * 'length' adds up the per-node copy counts.
-- * 'minimum' and 'maximum' walk the leftmost and rightmost spine.
--
-- The spine walks rely on the tree being ordered by the element type's
-- 'Ord' instance. This holds because trees are only built through 'insert',
-- and the nominal role of 'Tree' prevents coercing to a type with a
-- different ordering.
--
-- TODO Other folds could also use explicit recursion instead of a
-- conversion to a list.
instance Foldable Tree where
  foldMap f = foldMap f . toAscList
  {-# INLINE foldMap #-}

  fold = sum
  {-# INLINE fold #-}

  toList = toAscList
  {-# INLINE toList #-}

  -- Every node holds at least one copy (its counts are a NonEmpty).
  null Leaf = True
  null Node{} = False

  length Leaf = 0
  length (Node _ e left right) =
    length left + NE.length (multiplicity e) + length right

  -- Same error messages as the class defaults on an empty structure.
  minimum Leaf = errorWithoutStackTrace "minimum: empty structure"
  minimum (Node _ e Leaf _) = value e
  minimum (Node _ _ left _) = minimum left

  maximum Leaf = errorWithoutStackTrace "maximum: empty structure"
  maximum (Node _ e _ Leaf) = value e
  maximum (Node _ _ _ right) = maximum right
  
44
45
46
47
48

-- | Information cached at every node about the subtree rooted there.
-- It is computed by 'node' from the node's element and its two subtrees.
data MetaData a = MetaData
  { nodeSize :: Int
    -- ^ Number of nodes (distinct elements) in the subtree.
  , nodeSum :: a
    -- ^ Monoidal combination of every element copy in the subtree.
  }
  deriving (Show)
50
51
52
53
54

-- | One distinct element of a bag, together with all of its copies.
data Element a = Element
  { value :: a
    -- ^ The stored representative. Later copies that compare equal are
    --   merged into this element and do not replace it.
  , multiplicity :: NE.NonEmpty a
    -- ^ Running totals, newest first. The head is 'value' combined with
    --   itself once per copy, and the list length is the number of copies.
  }
  deriving (Show)
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


-- | The bag containing no elements. Its 'sum' is 'mempty'.
empty :: SumBag a
empty = Leaf

-- | A bag holding exactly one copy of the given element.
--
-- The copy counts start as a one-element running-total list containing the
-- element itself. It is built with the total ':|' constructor rather than
-- the partial 'NE.fromList'.
singleton :: Monoid a => a -> SumBag a
singleton a =
  node (Element a (a NE.:| [])) Leaf Leaf

-- | Number of distinct elements in the bag. Copies of the same element
-- count once. The value is read from the root's cached metadata.
size :: SumBag a -> Int
size t = case t of
  Leaf -> 0
  Node meta _ _ _ -> nodeSize meta

-- | Monoidal combination of every element copy in the bag, read from the
-- root's cached metadata. Returns 'mempty' for the empty bag.
sum :: Monoid a => SumBag a -> a
sum t = case t of
  Leaf -> mempty
  Node meta _ _ _ -> nodeSum meta

-- | Add one copy of an element to the bag.
--
-- If no equal element (by 'Ord') is present, a new node is created and the
-- nodes on the path back to the root are rebalanced.
--
-- If an equal element is present, its copy count grows and its stored
-- representative is kept. The new copy therefore contributes the stored
-- 'value' to the sum, not @a@ itself.
insert :: (Ord a, Monoid a) => a -> SumBag a -> SumBag a
insert a Leaf = singleton a
insert a (Node _ e left right)
  | a < value e = balance1 e (insert a left) right
  | a > value e = balance1 e left (insert a right)
  -- Same element: the shape does not change, so no rebalancing is needed.
  | otherwise = node (addOnce e) left right

80
81
82
83
84
85
86
-- | Whether at least one copy of the element (compared by 'Ord') is in the
-- bag. Only one root-to-leaf path of the search tree is followed.
elem :: (Ord a) => a -> SumBag a -> Bool
elem needle = go
  where
    go Leaf = False
    go (Node _ e smaller larger)
      | needle < value e = go smaller
      | needle > value e = go larger
      | otherwise = True

87
88
89
90
91
92
93
94
95
96
97
98
99
100
-- | Remove one copy of an element (compared by 'Ord') from the bag.
--
-- If other copies remain, only that element's count decreases. When the
-- last copy goes, the node is spliced out and replaced by the smallest
-- element of its right subtree. Deleting an element that is not present
-- leaves the bag's contents unchanged.
delete :: (Ord a, Monoid a) => a -> SumBag a -> SumBag a
delete _ Leaf = Leaf
delete a (Node _ e left right)
  | a < value e = balance1 e (delete a left) right
  | a > value e = balance1 e left (delete a right)
  | otherwise = case delOnce e of
      Just fewer -> node fewer left right
      Nothing -> glue left right
  where
    -- Join the two children of a removed node into one tree.
    glue Leaf r = r
    glue l Leaf = l
    glue l r =
      let (successor, r') = delmin r
      in balance1 successor l r'

Hans-Peter Deifel's avatar
Hans-Peter Deifel committed
101
-- | All elements of the bag in ascending order, each repeated once per copy.
toAscList :: SumBag a -> [a]
toAscList = go []
  where
    -- Prepend the in-order elements of a tree to an accumulator.
    go acc Leaf = acc
    go acc (Node _ e smaller larger) =
      go (copiesOf e ++ go acc larger) smaller

    -- One occurrence of the stored value per entry in the counts list.
    copiesOf e = value e <$ NE.toList (multiplicity e)

-- | Build a bag by inserting the list's elements one at a time, starting
-- from the end of the list.
fromList :: (Ord a, Monoid a) => [a] -> SumBag a
fromList xs = foldr insert empty xs

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
-- Internal functions

-- | "Smart" constructor for 'Node': builds a node from an element and two
-- subtrees and computes the cached metadata (size and sum) from them.
--
-- The sum is combined in in-order (ascending) order: left subtree, then all
-- copies of this element, then the right subtree. For a commutative monoid
-- the order does not matter. For a non-commutative one, this order keeps the
-- cached sum equal to the combination of 'toAscList', independent of the
-- tree's shape (which depends on insertion order and rebalancing).
node :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
node e left right =
  let nodeData = MetaData
        { nodeSize = size left + 1 + size right
        , nodeSum = sum left <> NE.head (multiplicity e) <> sum right
        }
  in Node nodeData e left right

-- | Single left rotation. The root of the right subtree becomes the new
-- root, and the old root takes over that subtree's left child.
-- Fails if the right subtree is empty.
rotateSingleLeft :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
rotateSingleLeft top outerLeft heavy = case heavy of
  Node _ pivot middle outerRight ->
    node pivot (node top outerLeft middle) outerRight
  Leaf -> error "rotateSingleLeft called with empty right tree"

-- | Single right rotation. The root of the left subtree becomes the new
-- root, and the old root takes over that subtree's right child.
-- Fails if the left subtree is empty.
rotateSingleRight :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
rotateSingleRight top heavy outerRight = case heavy of
  Node _ pivot outerLeft middle ->
    node pivot outerLeft (node top middle outerRight)
  Leaf -> error "rotateSingleRight called with empty left tree"

-- | Double left rotation. The left child of the right subtree becomes the
-- new root, with the old root and the old right root as its children.
-- Fails unless the right subtree has a non-empty left child.
rotateDoubleLeft :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
rotateDoubleLeft top outerLeft heavy = case heavy of
  Node _ upper (Node _ pivot innerL innerR) outerRight ->
    node pivot (node top outerLeft innerL) (node upper innerR outerRight)
  _ -> error "rotateDoubleLeft called with too small right tree"
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

-- | Double right rotation. The right child of the left subtree becomes the
-- new root, with the old left root and the old root as its children.
-- Fails unless the left subtree has a non-empty right child.
rotateDoubleRight :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
rotateDoubleRight top heavy outerRight = case heavy of
  Node _ lower outerLeft (Node _ pivot innerL innerR) ->
    node pivot (node lower outerLeft innerL) (node top innerR outerRight)
  _ -> error "rotateDoubleRight called with too small left tree"


-- | Build a node from an element and two subtrees, rotating when one subtree
-- has more than 'balanceBound' times as many nodes as the other. It is
-- called after one subtree's size has changed by an insertion or deletion.
--
-- A single rotation is used when the heavy subtree's outer child is
-- strictly larger than its inner child; otherwise a double rotation is used.
balance1 :: Monoid a => Element a -> Tree a -> Tree a -> Tree a
balance1 a left right
  -- Fewer than two nodes below: nothing to balance
  | sizeL + sizeR < 2 = node a left right
  -- Right subtree is too heavy
  | sizeR > balanceBound * sizeL = case right of
      Node _ _ rleft rright
        | size rleft < size rright -> rotateSingleLeft a left right
        | otherwise -> rotateDoubleLeft a left right
      -- sizeR > 0 here, and every Node has size >= 1, so this cannot happen
      Leaf -> error "balance1: impossible, heavy right subtree is empty"
  -- Left subtree is too heavy
  | sizeL > balanceBound * sizeR = case left of
      Node _ _ lleft lright
        | size lright < size lleft -> rotateSingleRight a left right
        | otherwise -> rotateDoubleRight a left right
      -- sizeL > 0 here, and every Node has size >= 1, so this cannot happen
      Leaf -> error "balance1: impossible, heavy left subtree is empty"
  -- Neither subtree is too heavy: form the node straight away
  | otherwise = node a left right
  where
    sizeL = size left
    sizeR = size right

-- | Record one more copy of an element by pushing a new running total (the
-- previous total combined with 'value') onto its counts list.
addOnce :: Semigroup a => Element a -> Element a
addOnce e = e { multiplicity = newTotal NE.<| totals }
  where
    totals = multiplicity e
    newTotal = NE.head totals <> value e

-- | Remove one copy of an element by popping its newest running total.
-- Returns 'Nothing' when the last copy was removed, meaning the whole
-- element should disappear.
delOnce :: Element a -> Maybe (Element a)
delOnce e = fmap withTotals (NE.nonEmpty (NE.tail (multiplicity e)))
  where
    withTotals rest = e { multiplicity = rest }

-- | Split off the leftmost (smallest) element of a non-empty tree, with all
-- of its copies. Returns it together with the rebalanced remaining tree.
-- Fails on an empty tree.
delmin :: Monoid a => Tree a -> (Element a, Tree a)
delmin t = case t of
  Leaf -> error "delmin: Empty tree"
  Node _ e Leaf right -> (e, right)
  Node _ e left right ->
    case delmin left of
      (smallest, left') -> (smallest, balance1 e left' right)

-- | Weight-balance threshold used by 'balance1'. A node is rotated when one
-- of its subtrees has more than this many times as many nodes as the other.
balanceBound :: Int
balanceBound = 4