Polynomial.hs 10.7 KB
Newer Older
1
2
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
3
{-# LANGUAGE PartialTypeSignatures #-}
4
{-# LANGUAGE FlexibleContexts #-}
5
{-# LANGUAGE TemplateHaskell #-}
6
{-# LANGUAGE FlexibleInstances #-}
7
{-# LANGUAGE StrictData #-}
8
{-# LANGUAGE BangPatterns #-}
9
10

-- | Polynomial functor with co-products, products, exponentials and constants
11
module MA.Functors.Polynomial
12
13
14
15
  ( -- * Functor expression parser
    polynomial
    -- * Types exported for easier testing
  , Polynomial(..)
16
17
18
  , Sum(..)
  , Product(..)
  , Factor(..)
19
  , ConstSet(..)
20
  , Exponent(..)
21
22
23
24
25
  , SumValue(..)
  , ProductValue(..)
  , FactorValue(..)
  ) where

26
import           Data.Foldable
27
import           Control.Monad
28
import           Data.Bifunctor
29
import           Data.List (sort)
30
31
import           Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
32
33
import           Data.Maybe (fromMaybe)
import           Data.Tuple (swap)
34
import           GHC.Generics (Generic)
35

36
import qualified Data.Vector.Utils as V
37
import           Data.Text (Text)
38
import qualified Data.Text as T
39
40
41
import           Data.Vector (Vector)
import qualified Data.Vector as V
import           Text.Megaparsec
42
43
import           Data.Eq.Deriving (deriveEq1)
import           Text.Show.Deriving (deriveShow1)
44
import           Lens.Micro
45
import           Control.DeepSeq (NFData)
46
47
48
49
50

import           MA.Coalgebra.Parser
import           MA.Coalgebra.RefinementTypes
import qualified MA.Parser.Lexer as L
import           MA.Parser.Types
51
import           MA.RefinementInterface
52
import           MA.FunctorExpression.Parser
53
54

-- | A polynomial functor expression. The top level is always a coproduct
-- ('Sum') of products; the type parameter marks the positions where the
-- functor's argument occurs.
newtype Polynomial a = Polynomial (Sum a)
  deriving (Functor, Foldable, Traversable)
56

57
-- | Coproduct of one or more 'Product's, written @p1 + p2 + ...@.
newtype Sum a = Sum (NonEmpty (Product a))
  deriving (Functor, Foldable, Traversable)
59

60
-- | Product of one or more 'Factor's, written @f1 x f2 x ...@ (or with @×@).
newtype Product a = Product (NonEmpty (Factor a))
  deriving (Functor, Foldable, Traversable)
62

63
64
-- | The set of values a constant factor ranges over.
data ConstSet
  = -- | The integers, written @Z@ or @ℤ@.
    IntSet
    -- | The natural numbers, written @N@ or @ℕ@.
  | NatSet
    -- | @FiniteNatSet n@ is the set @{0, ..., n-1}@, written as the decimal @n@.
  | FiniteNatSet Int
    -- | An explicit set of names, written @{a, b, ...}@.
  | ExplicitSet (Vector Text)
  deriving (Show, Eq)

70
71
72
73
74
-- | The domain of an exponential factor, written after @^@.
data Exponent
  = FiniteNatExp Int
    -- ^ @FiniteNatExp n@ is the domain @{0, ..., n-1}@, written as the
    -- decimal @n@.
  | ExplicitExp (Vector Text)
    -- ^ An explicit domain of names, written @{a, b, ...}@. When built by
    -- 'parseExplicitExp' the names are sorted and free of duplicates.
  deriving (Show, Eq)

75
-- | A single factor of a 'Product'.
data Factor a
  = -- | A constant set; contains no occurrence of the argument.
    Const ConstSet
    -- | A single occurrence of the functor's argument.
  | Identity a
    -- | The argument raised to an 'Exponent', i.e. a map from the
    -- exponent's domain to the argument.
  | Exponential a Exponent
  deriving (Functor, Foldable, Traversable)
80

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
-- Lifted 'Eq1' and 'Show1' instances for the expression types.
-- The splices are ordered innermost type first ('Factor', 'Product', 'Sum',
-- 'Polynomial'), matching how each type wraps the previous one; keep this
-- order when adding types.
$(deriveEq1 ''Factor)
$(deriveEq1 ''Product)
$(deriveEq1 ''Sum)
$(deriveEq1 ''Polynomial)

$(deriveShow1 ''Factor)
$(deriveShow1 ''Product)
$(deriveShow1 ''Sum)
$(deriveShow1 ''Polynomial)

-- | Parser for polynomial functor expressions.
--
-- An expression that is nothing but a single bare occurrence of the argument
-- yields 'Left' with the inner result. Every other expression is wrapped in
-- 'Polynomial' and returned as 'Right'.
polynomial :: FunctorParser Polynomial
polynomial = FunctorParser $ \inner -> classify <$> parseSumExpr inner
  where
    classify (Sum (Product (Identity x :| []) :| [])) = Left x
    classify other = Right (Polynomial other)

-- | Parse one or more products separated by @+@.
parseSumExpr :: MonadParser m => m a -> m (Sum a)
parseSumExpr inner = Sum <$> ((:|) <$> summand <*> many (L.symbol "+" *> summand))
  where
    summand = parseProductExpr inner

-- | Parse one or more factors separated by @x@ or @×@.
parseProductExpr :: MonadParser m => m a -> m (Product a)
parseProductExpr inner = Product <$> ((:|) <$> factor <*> many (times *> factor))
  where
    factor = parseFactorExpr inner
    times = L.symbol "x" <|> L.symbol "×"

110
111
-- | Parse a single factor: a constant set, or an occurrence of the argument
-- that may be raised to an exponent.
parseFactorExpr :: MonadParser m => m a -> m (Factor a)
parseFactorExpr inner = fmap Const parseConstExpr <|> parseIdOrExp inner
112

113
114
115
-- | Parse a constant set. The accepted forms are:
--
-- * @Z@ or @ℤ@ (integers);
-- * @N@ or @ℕ@ (naturals);
-- * a decimal @n@ (a finite set of naturals);
-- * an explicit set of names @{a, b, ...}@.
parseConstExpr :: MonadParser m => m ConstSet
parseConstExpr =
      IntSet <$ (L.symbol "Z" <|> L.symbol "ℤ")
  <|> NatSet <$ (L.symbol "N" <|> L.symbol "ℕ")
  <|> FiniteNatSet <$> L.decimal
  <|> ExplicitSet . V.fromList <$> L.braces (L.name `sepBy` L.comma)
119

120
121
122
-- | Parse an occurrence of the argument. It may be followed by @^@ and an
-- exponent, which is either an explicit set of names in braces or a decimal.
parseIdOrExp :: MonadParser m => m a -> m (Factor a)
parseIdOrExp inner = inner >>= \x ->
  let power = try (L.symbol "^") *> (parseExplicitExp <|> parseFiniteNatExp)
  in (Exponential x <$> power) <|> pure (Identity x)

127
128
-- | Parse an explicit exponent domain @{a, b, ...}@. The names are stored in
-- sorted order, and duplicate names are rejected.
parseExplicitExp :: MonadParser m => m Exponent
parseExplicitExp = L.braces $ do
  domain <- V.sort . V.fromList <$> L.name `sepBy` L.comma
  if V.hasDuplicates domain
    then fail "exponential: domain must be uniquely defined"
    else return (ExplicitExp domain)

-- | Parse a finite exponent written as a decimal @n@, i.e. the domain
-- @{0, ..., n-1}@.
parseFiniteNatExp :: MonadParser m => m Exponent
parseFiniteNatExp = fmap FiniteNatExp L.decimal
138

139
-- | A value of a polynomial functor: the index of the chosen coproduct
-- component, together with the value of that component's product.
data SumValue a = SumValue Int (ProductValue a)
  deriving (Eq, Ord, Show, Functor, Generic, NFData)
142
143
144

-- | Value of a product: one 'FactorValue' per factor, in factor order.
data ProductValue a =
  ProductValue (Vector (FactorValue a))
  deriving (Eq, Ord, Show, Functor, Generic, NFData)
146
147

-- | Value of a single factor.
data FactorValue a
  = -- | A constant. For an explicit set of names this is the index of the
    -- name in that set; for @Z@, @N@ and finite sets @n@ it is the parsed
    -- number itself.
    ConstValue Int
    -- | One occurrence of the argument (one placeholder).
  | IdValue a
    -- | One entry per element of the exponent's domain; essentially the same
    -- as a product of 'IdValue's.
  | ExponentialValue (Vector a)
  deriving (Eq, Ord, Show, Functor, Generic, NFData)
155

156
157
158
159

-- | Classification of a successor edge, as computed by 'update'.
-- The constructor order matters: 'update' builds a 'Three' with
-- @toEnum (fromEnum a + fromEnum b)@ for two 'Bool's, so 'ToRest', 'ToCompound'
-- and 'ToSub' correspond to 0, 1 and 2 of them being 'True'.
data Three = ToRest | ToCompound | ToSub
  deriving (Show, Eq, Ord, Enum)

160
-- | H1: the shape of a successor structure, i.e. a 'SumValue' in which every
-- successor is replaced by @()@.
type instance H1 Polynomial = SumValue ()
161
162
163
164
165
166
167
168

-- | Tuple @(a, b)@ of
--
-- [a]: Index of this edge in the product on the second level
-- [b]: Index of this edge in the exponential on the third level
--
-- For 'Identity' factors, @b@ is 0. The top-level coproduct does not appear
-- in the label at all, because the chosen component is already recorded in
-- 'H1'.
type instance Label Polynomial = (Int, Int)
170
171
172
173

-- | Defined as H2: a 'SumValue' that marks every successor edge with a
-- 'Bool'. 'init' marks all edges 'True'. 'update' returns two such weights:
-- one marking the edges it classifies as 'ToSub', and one marking those it
-- classifies as 'ToCompound'.
type instance Weight Polynomial = SumValue Bool

174
175
176
-- | H3: every successor edge classified as 'ToRest', 'ToCompound' or 'ToSub'.
type instance H3 Polynomial = SumValue Three

-- | A morphism point is parsed by 'parseSum1'. The parser is driven by the
-- functor expression, whose leaves are the parsers for the successors.
instance ParseMorphism Polynomial where
  parseMorphismPoint (Polynomial summands) = parseSum1 summands

----------- Coproducts parser

-- | Parse either a single product or an injection into the coproduct,
-- depending on the number of summands.
parseSum1 ::
  MonadParser m => Sum (m a) -> m (SumValue (), Vector (a, Label Polynomial))
parseSum1 s@(Sum summands) = case summands of
  -- Only one summand, so the @inj i@ prefix is optional. We first try to
  -- parse @inj i@ and fall back to the product parser only if that fails.
  -- This avoids surprises when a constant called @inj@ exists and the
  -- input starts with @inj@.
  only :| [] ->
    (try parseSumPrefix >>= parseSum s)
      <|> (first (SumValue 0) <$> parseProduct1 only)
  -- Several summands: the @inj i@ prefix is mandatory.
  _ -> (parseSumPrefix >>= parseSum s) <?> "coproduct injection"
196
197
198
199
200
201
202

-- | Parse the prefix @inj i@ and return the decimal index @i@.
parseSumPrefix :: MonadParser m => m Int
parseSumPrefix = do
  _ <- L.symbol "inj"
  L.decimal

-- | Parse the product belonging to summand @i@ of the coproduct. This is the
-- body of the syntax @inj i _@; the @inj i@ prefix has already been consumed.
-- Fails if @i@ is not a valid summand index.
parseSum ::
     MonadParser m => Sum (m a) -> Int -> m (SumValue (), Vector (a, Label Polynomial))
parseSum (Sum summands) i
  | i < 0 || i >= length summands =
      fail ("polynomial: injection " ++ show i ++ " is out of bounds")
  | otherwise =
      -- Safe: i was bounds-checked by the guard above.
      first (SumValue i) <$> parseProduct1 (summands NonEmpty.!! i)

----------- Products parser

-- | Parse either a single factor without parentheses or a parenthesised
-- tuple. Only a product with exactly one factor may omit the parentheses.
parseProduct1 ::
     MonadParser m => Product (m a) -> m (ProductValue (), Vector ((a, (Int, Int))))
parseProduct1 p@(Product (factor :| [])) =
  (asProduct <$> parseFactor factor) <|> parseProduct p
  where
    -- A lone factor sits at position 0 of the product, so every successor
    -- (x, j) becomes (x, (0, j)); a factor without successors gives none.
    asProduct (h1, succs) =
      ( ProductValue (V.singleton h1)
      , maybe V.empty (V.map (\(x, j) -> (x, (0, j)))) succs
      )
parseProduct1 other = parseProduct other

-- | Parse a parenthesised, comma-separated tuple with one entry per factor.
-- Each successor is labelled @(i, j)@, where @i@ is its factor's position in
-- the product and @j@ is its index within that factor.
parseProduct ::
     MonadParser m => Product (m a) -> m (ProductValue (), Vector (a, (Int, Int)))
parseProduct (Product allFactors@(f :| fs)) =
  label ("a product of " ++ show (length allFactors) ++ " element(s)") $ L.parens $ do
    hd <- parseFactor f
    tl <- traverse (\x -> L.comma *> parseFactor x) fs
    let (h1s, succs) = V.unzip (V.fromList (hd : tl))
        -- Tag every successor of factor i with its position; Nothing means
        -- the factor has no successors.
        tagAt i = maybe V.empty (V.map (\(x, j) -> (x, (i, j))))
    return (ProductValue h1s, fold (V.imap tagAt succs))
238
239
240

----------- Factor parser

241
-- | Parse the value of a single factor.
--
-- Returns the factor's H1 value together with its successors. Each successor
-- is paired with its index inside the factor. Constants have no successors,
-- so they return 'Nothing'.
parseFactor :: MonadParser m => Factor (m a) -> m (FactorValue (), Maybe (Vector (a, Int)))
parseFactor (Const (ExplicitSet names)) = do
  !h1 <- ConstValue <$> someName names
  return (h1, Nothing) -- const has no successors
parseFactor (Const IntSet) = do
  x <- L.signed L.decimal <?> "integer"
  return (ConstValue x, Nothing)
parseFactor (Const NatSet) = do
  x <- L.decimal <?> "natural number"
  return (ConstValue x, Nothing)
parseFactor (Const (FiniteNatSet n)) = do
  x <- L.decimal <?> ("natural number smaller than " ++ show n)
  -- Valid values are 0 .. n-1.
  unless (x < n) $
    fail ("out of range constant: " ++ show x ++
          " (must be smaller than " ++ show n ++ ")")
  return (ConstValue x, Nothing)
parseFactor (Identity inner) = do
  successor <- inner
  -- Exactly one successor, at index 0 of this factor.
  return (IdValue (), Just (V.singleton (successor, 0)))
parseFactor (Exponential inner e) = L.braces $ do
  -- Parse "value: successor" pairs. Each pair becomes (successor, index of
  -- value in the domain), and the pairs are sorted by that index.
  successors <- V.sortOn snd . V.fromList <$>
    (flip (,) <$> parseExpValue e <*> (L.colon *> inner))
    `sepBy` L.comma

  -- The sorted indices must equal 0 .. size-1, so every domain element is
  -- mapped exactly once.
  unless (allExpValues e == V.map snd successors) $
    fail ("exponential: map must be well-defined on " ++ showExp e)

  return ( ExponentialValue (V.replicate (length successors) ())
         , Just successors
         )
272

273
274
275
276
277
278
279
280
281
-- | Parse one element of an exponent's domain and return its index.
-- For an explicit domain this is the position of the name; for a finite
-- domain @n@ it is the parsed number itself, which must be smaller than @n@.
parseExpValue :: MonadParser m => Exponent -> m Int
parseExpValue (ExplicitExp names) = someName names
parseExpValue (FiniteNatExp n) = do
  x <- L.decimal
  -- L.decimal only yields non-negative numbers, so only the upper bound
  -- needs checking.
  unless (x < n) $
    fail ("Value " ++ show x ++
          " is out of bounds (must be smaller than " ++ show n ++ ")")
  return x

282
283
284
-- | All indices of an exponent's domain, in ascending order: from @0@ up to,
-- but excluding, the size of the domain.
allExpValues :: Exponent -> Vector Int
allExpValues e = V.enumFromN 0 (domainSize e)
  where
    domainSize (ExplicitExp names) = length names
    domainSize (FiniteNatExp n) = n
285
286
287
288
289

-- | Render an exponent's domain for use in error messages.
showExp :: Exponent -> String
showExp (ExplicitExp names) = show names
showExp (FiniteNatExp n)
  -- The domain is {0, ..., n-1} (cf. 'allExpValues'), so the largest
  -- element shown is n - 1; a non-positive n is the empty domain.
  | n <= 0 = "{}"
  | otherwise = "{0.." ++ show (n - 1) ++ "}"

290
-- | Parse a name and return its index in the given vector.
--
-- If the name is not an element of the vector, the parser fails with a
-- message naming the offending name and the accepted names. Only 'L.name'
-- itself is wrapped in 'try', so an unknown name still fails after
-- consuming it.
someName :: MonadParser m => Vector Text -> m Int
someName v = do
  name <- try L.name
  case V.elemIndex name v of
    Just i -> return i
    Nothing ->
      fail ("unknown name " ++ show name ++ ", expected one of " ++ show v)
295
296
297
298
299

-- | Refinement interface for polynomial functors. Weights are 'SumValue's
-- that mark every successor edge with a 'Bool'. Each edge is identified by
-- the @(i, j)@ position that 'fmapIndex' passes for it.
instance RefinementInterface Polynomial where
  -- Initially every successor edge is marked 'True'.
  init :: H1 Polynomial -> [Label Polynomial] -> Weight Polynomial
  init h1 _ = fmap (const True) h1

  -- Classify every edge (see 'up'), then split the classification into two
  -- weights (see 'val'). The first weight marks the edges classified
  -- 'ToSub'; the last marks those classified 'ToCompound'.
  update ::
       [Label Polynomial]
    -> Weight Polynomial
    -> (Weight Polynomial, H3 Polynomial, Weight Polynomial)
  update !labs !w = {-# SCC polynomial #-} val $! (up $! (labs, w))
    where
      val :: H3 Polynomial -> (Weight Polynomial, H3 Polynomial, Weight Polynomial)
      val !h3 =
        let !toS = {-# SCC a #-} fmap (== ToSub) h3
            !toC = {-# SCC a #-} fmap (== ToCompound) h3
        in
          (toS, h3, toC)

      -- An edge's classification depends on two flags: whether it is
      -- marked in the weight, and whether its label occurs in labs.
      -- Both flags set gives ToSub, exactly one gives ToCompound, and
      -- neither gives ToRest.
      -- NOTE(review): `elem` scans the label list once per edge.
      up :: ([Label Polynomial], Weight Polynomial) -> H3 Polynomial
      up (!labels, !weight) = {-# SCC a #-} (fmapIndex $! (\i j bi -> bi +? ((i,j) `elem` labels))) $! weight

      -- Counts how many of the two flags are True (0, 1 or 2) and maps the
      -- count onto Three via its derived Enum instance.
      (+?) :: Bool -> Bool -> Three
      (+?) !a !b = {-# SCC a #-}  toEnum (fromEnum a + fromEnum b)
318

319
-- | Map over a 'SumValue', passing each element its edge position @(i, j)@.
-- Here @i@ is the index of its factor in the product, and @j@ is its index
-- inside an exponential factor; @j@ is 0 for all other factors. Constants
-- hold no elements, so @f@ is never called for them.
-- The top-level coproduct index is kept unchanged.
-- NOTE(review): the bang patterns and @$!@ force the result's spine; V.imap'
-- is presumably a strict variant of imap — confirm in Data.Vector.Utils.
fmapIndex :: forall a b. (Int -> Int -> a -> b) -> SumValue a -> SumValue b
fmapIndex f (SumValue !s (ProductValue !factors)) =
  let !res = V.imap' fmapFactor factors
  in (SumValue $! s) $! (ProductValue $! res)

  where
    fmapFactor :: Int -> FactorValue a -> FactorValue b
    fmapFactor !i (ExponentialValue !as) = ExponentialValue (V.imap' (f i) as)
    fmapFactor !i !other = (fmap $! (f i 0)) $! other