Commit 1d9f4ec9 authored by Hans-Peter Deifel's avatar Hans-Peter Deifel
Browse files

Merge branch 'poly-speedup'

parents cbe5bfdc 8a0a8b38
......@@ -18,32 +18,28 @@ module MA.Functors.Polynomial
, Factor(..)
, ConstSet(..)
, Exponent(..)
, SumValue(..)
, ProductValue(..)
, FactorValue(..)
, PolyH1(..)
) where
import Data.Foldable
import Control.Monad
import Data.Bifunctor
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Maybe (fromMaybe)
import Data.Traversable
import Data.Word (Word8)
import GHC.Generics (Generic)
import Data.List (intersperse)
import Data.Semigroup ((<>))
import qualified Data.Vector.Utils as V
import Data.Text (Text)
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Text.Megaparsec
import Data.Eq.Deriving (deriveEq1)
import Text.Show.Deriving (deriveShow1)
import Lens.Micro
import Control.DeepSeq (NFData)
import Control.DeepSeq
import MA.PrettyShow
import MA.Coalgebra.Parser
import MA.Coalgebra.RefinementTypes
import qualified MA.Parser.Lexer as L
......@@ -144,54 +140,61 @@ parseExplicitExp = L.braces $ do
parseFiniteNatExp :: MonadParser m => m Exponent
parseFiniteNatExp = FiniteNatExp <$> L.decimal
-- Index into coproduct and corresponding product value
data SumValue a = SumValue Int (ProductValue a)
deriving (Eq, Ord, Show, Functor, Generic, NFData)
instance PrettyShow a => PrettyShow (SumValue a) where
prettyShow (SumValue i p) = "inj" <> prettyShow i <> " " <> prettyShow p
-- | Morally the type '3'.
--
-- We use Word8 for efficiency, so this should only ever have the values 0, 1
-- and 2 with the following meanings:
--
-- 0: Edge goes to rest
-- 1: Edge goes to compound block
-- 2: Edge goes to subblock
type Three = Word8
data ProductValue a =
ProductValue (Vector (FactorValue a))
deriving (Eq, Ord, Show, Functor, Generic, NFData)
toCompound :: Three
toCompound = 1
instance PrettyShow a => PrettyShow (ProductValue a) where
prettyShow (ProductValue v) =
"(" <> mconcat (intersperse ", " (map prettyShow (V.toList v))) <> ")"
toSub :: Three
toSub = 2
data FactorValue a
= -- | Index of constant in vector of possible values
ConstValue Int
-- | One placeholder
| IdValue a
-- | Essentially the same as a product of "IdValue"s
| ExponentialValue (Vector a)
deriving (Eq, Ord, Show, Functor, Generic, NFData)
instance PrettyShow a => PrettyShow (FactorValue a) where
prettyShow (ConstValue i) = prettyShow i
prettyShow (IdValue a) = prettyShow a
prettyShow (ExponentialValue v) =
"[" <> mconcat (intersperse ", " (map prettyShow (V.toList v))) <> "]"
-- | H1 for Polynomial
--
-- TODO: Use unboxed vector for constants
data PolyH1 = PolyH1
{ polyH1Summand :: {-# UNPACK #-} Int
-- ^ Index into sum
, polyH1Variables :: {-# UNPACK #-} Int
-- ^ Number of variable factors
, polyH1Constants :: {-# UNPACK #-} (Vector Int)
-- ^ Values of constant factors
}
deriving (Eq,Show,Ord,NFData,Generic)
data Three = ToRest | ToCompound | ToSub
deriving (Show, Eq, Ord, Enum)
type instance H1 Polynomial = PolyH1
type instance H1 Polynomial = SumValue ()
-- | Tuple @(a, b)@ of
-- | Index of this edge into the product
--
-- [a]: Index of this edge in the product on the second level
-- [b]: Index of this edge in the exponential on the third level
--
-- For "Identity"s, @b@ is 0. Also note that the top-level co-product doesn't
-- appear in the label at all. It already appears in "H1".
type instance Label Polynomial = (Int, Int)
-- Note that this ignores constant factors, so an edge to the X in (2, 5, X)
-- would have 0 as Label, not 2. Also, the top-level co-product doesn't appear
-- in the label at all. It already appears in "H1".
type instance Label Polynomial = Int
-- | Defined as H2
type instance Weight Polynomial = SumValue Bool
--
-- For each variable in the product, this is True if the corresponding successor
-- state belongs to the subblock and False otherwise.
type instance Weight Polynomial = VU.Vector Bool
-- | Defined as H3
--
-- Same as for Weight, but now with the three states described in the
-- documentation for Three.
type instance H3 Polynomial = VU.Vector Three
type instance H3 Polynomial = SumValue Three
instance ParseMorphism Polynomial where
parseMorphismPoint (Polynomial expr) = parseSum1 expr
......@@ -201,7 +204,7 @@ instance ParseMorphism Polynomial where
-- | Parse either a single product or an injection into the coproduct, depending
-- on the number of co-factors.
parseSum1 ::
MonadParser m => Sum (m a) -> m (SumValue (), Vector (a, Label Polynomial))
MonadParser m => Sum (m a) -> m (H1 Polynomial, Vector (a, Label Polynomial))
parseSum1 sum@(Sum (product :| [])) = do
-- only a single summand => parse product directly
......@@ -210,7 +213,7 @@ parseSum1 sum@(Sum (product :| [])) = do
-- This avoids strange situations where a constant calle 'inj' exists and the
-- input starts with inj.
(try parseSumPrefix >>= parseSum sum) <|>
(first (SumValue 0) <$> parseProduct1 product)
(first (uncurry (PolyH1 0)) <$> parseProduct1 product)
parseSum1 other = parseSumPrefix >>= parseSum other -- otherwise, require 'inj'
<?> "coproduct injection"
......@@ -220,63 +223,69 @@ parseSumPrefix = L.symbol "inj" *> L.decimal
-- | Parse an injection into the coproduct with the syntax 'inj i _'
parseSum ::
MonadParser m => Sum (m a) -> Int -> m (SumValue (), Vector (a, Label Polynomial))
MonadParser m => Sum (m a) -> Int -> m (H1 Polynomial, Vector (a, Label Polynomial))
parseSum (Sum summands) i = do
when (i < 0 || i >= length summands) $
fail ("polynomial: injection " ++ show i ++ " is out of bounds")
(h1, successors) <- parseProduct1 (summands NonEmpty.!! i)
return (SumValue i h1, successors)
return (uncurry (PolyH1 i) h1, successors)
----------- Products parser
-- | Parse either a single factor without parens or a tuple.
parseProduct1 ::
MonadParser m => Product (m a) -> m (ProductValue (), Vector ((a, (Int, Int))))
MonadParser m => Product (m a) -> m ((Int, Vector Int), Vector (a, Label Polynomial))
parseProduct1 product@(Product (factor :| [])) =
let mkProduct f = f & _1 %~ (ProductValue . V.singleton)
& _2 %~ (_Just . each . _2 %~ (0,) <&> fromMaybe V.empty)
let mkProduct = either (\i -> ((0, V.singleton i), V.empty))
(\v -> ((length v, V.empty), v))
in (mkProduct <$> parseFactor factor) <|> parseProduct product
parseProduct1 other = parseProduct other
parseProduct ::
MonadParser m => Product (m a) -> m (ProductValue (), Vector (a, (Int, Int)))
parseProduct
:: MonadParser m
=> Product (m a)
-> m ((Int, Vector Int), Vector (a, Label Polynomial))
parseProduct (Product l@(f :| fs)) =
label ("a product of " ++ show (length l) ++ " element(s)") $ L.parens $ do
factors <- V.cons
factors <-
V.cons
<$> parseFactor f
<*> (V.fromList <$> traverse (\x -> L.comma *> parseFactor x) fs)
let (h1, successors) = V.unzip factors
labeledSuccessors =
V.imap
(\i x -> x & _Just . traversed . _2 %~ (i,) & fromMaybe V.empty)
successors
let
constants = V.fromList (factors ^.. each . _Left)
labels = factors ^.. each . _Right
(numFactors, successors) = mapAccumL
(\cur f -> (cur + length f, f & each . _2 %~ (+ cur)))
0
labels
return ( ProductValue h1 , fold labeledSuccessors)
return ((numFactors, constants), V.concat successors)
----------- Factor parser
parseFactor :: MonadParser m => Factor (m a) -> m (FactorValue (), Maybe (Vector (a, Int)))
parseFactor :: MonadParser m => Factor (m a) -> m (Either Int (Vector (a, Int)))
parseFactor (Const (ExplicitSet names)) = do
!h1 <- ConstValue <$> someName names
return (h1, Nothing) -- const has no successors
!h1 <- Left <$> someName names
return h1 -- const has no successors
parseFactor (Const IntSet) = do
x <- L.signed L.decimal <?> "integer"
return (ConstValue x, Nothing)
return (Left x)
parseFactor (Const NatSet) = do
x <- L.decimal <?> "natural number"
return (ConstValue x, Nothing)
return (Left x)
parseFactor (Const (FiniteNatSet n)) = do
x <- L.decimal <?> ("natural number small than " ++ show n)
unless (x < n) $
fail ("out of range constant: " ++ show x ++
"(must be between 0 and " ++ show n ++ ")")
return (ConstValue x, Nothing)
return (Left x)
parseFactor (Identity inner) = do
successor <- inner
return (IdValue (), Just (V.singleton (successor, 0)))
return (Right (V.singleton (successor, 0)))
parseFactor (Exponential inner exp) = L.braces $ do
successors <- V.sortOn snd . V.fromList <$>
(flip (,) <$> parseExpValue exp
......@@ -286,9 +295,7 @@ parseFactor (Exponential inner exp) = L.braces $ do
unless (allExpValues exp == (V.map snd successors)) $
fail ("exponential: map must be well-defined on " ++ showExp exp)
return ( ExponentialValue (V.replicate (length successors) ())
, Just successors
)
return (Right successors)
parseExpValue :: MonadParser m => Exponent -> m Int
parseExpValue (ExplicitExp names) = someName names
......@@ -315,7 +322,7 @@ someName v = do
instance RefinementInterface Polynomial where
init :: H1 Polynomial -> [Label Polynomial] -> Weight Polynomial
init h1 _ = fmap (const True) h1
init h1 _ = VU.replicate (polyH1Variables h1) True
update ::
[Label Polynomial]
......@@ -325,23 +332,12 @@ instance RefinementInterface Polynomial where
where
val :: H3 Polynomial -> (Weight Polynomial, H3 Polynomial, Weight Polynomial)
val h3 =
let toS = fmap (== ToSub) h3
toC = fmap (== ToCompound) h3
let toS = VU.map (== toSub) h3
toC = VU.map (== toCompound) h3
in
(toS, h3, toC)
up :: ([Label Polynomial], Weight Polynomial) -> H3 Polynomial
up (labels, weight) = fmapIndex (\i j bi -> bi +? ((i,j) `elem` labels)) weight
(+?) :: Bool -> Bool -> Three
(+?) a b = toEnum (fromEnum a + fromEnum b)
fmapIndex :: forall a b. (Int -> Int -> a -> b) -> SumValue a -> SumValue b
fmapIndex f (SumValue !s (ProductValue !factors)) =
let !res = V.imap fmapFactor factors
in SumValue s (ProductValue res)
where
fmapFactor :: Int -> FactorValue a -> FactorValue b
fmapFactor i (ExponentialValue as) = ExponentialValue (V.imap (f i) as)
fmapFactor i other = (fmap (f i 0)) other
up (labels, weight) =
let labels'' = map (\i -> (i, toEnum (fromEnum (VU.unsafeIndex weight i)) + 1)) labels
in VU.unsafeUpd (VU.map (toEnum.fromEnum) weight :: VU.Vector Word8) labels''
......@@ -195,15 +195,15 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
it "parses a constant" $ do
morphp (mkPoly [[c ["a", "b", "c"]]]) "x: inj 0 (a)" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
morphp (mkPoly [[c ["a", "b", "c"]]]) "x: inj 0 (b)" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 1]))] []
encoding [(Sorted 1 (h1 0 0 [1]))] []
it "parses the identity" $
morphp (mkPoly [[Identity Variable]]) "x: inj 0 (y)\ny: inj 0 (x)" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [IdValue ()])), (Sorted 1 (mkVal 0 [IdValue ()]))]
[(1, (Sorted 1 (0, 0)), 0), (0, (Sorted 1 (0, 0)), 1)]
[(Sorted 1 (h1 0 1 [])), (Sorted 1 (h1 0 1 []))]
[(1, (Sorted 1 0), 0), (0, (Sorted 1 0), 1)]
it "gives a useful error if the injection index is out of bounds" $ do
morphp (mkPoly [[c ["a"]]]) `shouldFailOn `"x: inj 5 (a)"
......@@ -211,61 +211,52 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
it "parses a product of a constant and an X" $
morphp (mkPoly [[c ["a"], Identity Variable]]) "x: inj 0 (a, x)" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [ConstValue 0, IdValue ()]))]
[(0, (Sorted 1 (1, 0)), 0)]
encoding [(Sorted 1 (h1 0 1 [0]))] [(0, (Sorted 1 0), 0)]
it "parses a product of two elements" $
morphp (mkPoly [[Identity Variable, Identity Variable]]) "x: inj 0 (x, x)" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [IdValue (), IdValue ()]))]
[(0, (Sorted 1 (0, 0)), 0), (0, (Sorted 1 (1, 0)), 0)]
encoding [(Sorted 1 (h1 0 2 []))] [(0, (Sorted 1 0), 0), (0, (Sorted 1 1), 0)]
it "parses a sum of two constants" $
morphp (mkPoly [[c ["a"]], [c ["b"]]]) "x: inj 0 (a)\ny: inj 1 (b)" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [ConstValue 0])), (Sorted 1 (mkVal 1 [ConstValue 0]))]
[]
encoding [(Sorted 1 (h1 0 0 [0])), (Sorted 1 (h1 1 0 [0]))] []
it "parses X+(AxX)" $
morphp
(mkPoly [[Identity Variable], [c ["a"], Identity Variable]])
"x: inj 0 (y)\ny: inj 1 (a, x)" `shouldParse`
encoding
[ (Sorted 1 (mkVal 0 [IdValue ()]))
, (Sorted 1 (mkVal 1 [ConstValue 0, IdValue ()]))
[ (Sorted 1 (h1 0 1 []))
, (Sorted 1 (h1 1 1 [0]))
]
[(1, (Sorted 1 (1, 0)), 0), (0, (Sorted 1 (0, 0)), 1)]
[(1, (Sorted 1 0), 0), (0, (Sorted 1 0), 1)]
it "allows to omit 'inj' for co-products with only one factor" $ do
morphp (mkPoly [[c ["a"]]]) "x: (a)" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
morphp (mkPoly [[Identity Variable, Identity Variable]]) "x: (x, x)" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [IdValue (), IdValue ()]))]
[(0, (Sorted 1 (0, 0)), 0), (0, (Sorted 1 (1, 0)), 0)]
encoding [(Sorted 1 (h1 0 2 []))] [(0, (Sorted 1 0), 0), (0, (Sorted 1 1), 0)]
it "doesn't confuse a constant called inj and an injection" $ do
morphp (mkPoly [[c ["injection"]]]) "x: injection" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
morphp (mkPoly [[c ["inj"]]]) "x: inj" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
morphp (mkPoly [[c ["inj"]]]) "x: inj 0 inj" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
it "allows to omit parens for products with only one factor" $ do
morphp (mkPoly [[c ["a"]], [c ["b"]]]) "x: inj 0 a" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
it "allows to omit both 'inj' and parens" $ do
morphp (mkPoly [[c ["a"]]]) "x: a" `shouldParse`
encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] []
encoding [(Sorted 1 (h1 0 0 [0]))] []
it "parses an exponential" $ do
morphp (mkPoly [[e Variable ["a"]]]) "x: {a: x}" `shouldParse`
encoding
[(Sorted 1 (mkVal 0 [ExponentialValue (v [()])]))]
[(0, (Sorted 1 (0, 0)), 0)]
encoding [(Sorted 1 (h1 0 1 []))] [(0, (Sorted 1 0), 0)]
it "fails to parse an exponential that isn't totally defined" $ do
morphp (mkPoly [[e Variable ["a", "b"]]])
......@@ -280,35 +271,35 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
"x: inj0 {a: x, b: y}\ny: inj1 y"
`shouldParse`
(encoding
[ (Sorted 1 (mkVal 0 [ExponentialValue (v [(), ()])]))
, (Sorted 1 (mkVal 1 [IdValue ()]))
[ (Sorted 1 (h1 0 2 []))
, (Sorted 1 (h1 1 1 []))
]
[ (1, (Sorted 1 (0, 0)), 1)
, (0, (Sorted 1 (0, 0)), 0)
, (0, (Sorted 1 (0, 1)), 1)
[ (1, (Sorted 1 0), 1)
, (0, (Sorted 1 0), 0)
, (0, (Sorted 1 1), 1)
])
it "allows positive numbers as constants for integer set" $ do
morphp (mkPoly [[Const IntSet]]) "x: 5\ny:30" `shouldParse`
(encoding
[ (Sorted 1 (mkVal 0 [ConstValue 5]))
, (Sorted 1 (mkVal 0 [ConstValue 30]))
[ (Sorted 1 (h1 0 0 [5]))
, (Sorted 1 (h1 0 0 [30]))
]
[])
it "allows positive numbers as constants for integer set" $ do
morphp (mkPoly [[Const IntSet]]) "x: -3\ny:-77" `shouldParse`
(encoding
[ (Sorted 1 (mkVal 0 [ConstValue (-3)]))
, (Sorted 1 (mkVal 0 [ConstValue (-77)]))
[ (Sorted 1 (h1 0 0 [(-3)]))
, (Sorted 1 (h1 0 0 [(-77)]))
]
[])
it "allows positive numbers as constants for naturals set" $ do
morphp (mkPoly [[Const NatSet]]) "x: 5\ny:30" `shouldParse`
(encoding
[ (Sorted 1 (mkVal 0 [ConstValue 5]))
, (Sorted 1 (mkVal 0 [ConstValue 30]))
[ (Sorted 1 (h1 0 0 [5]))
, (Sorted 1 (h1 0 0 [30]))
]
[])
......@@ -317,8 +308,7 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
it "allows numbers in the correct range as constants for FiniteNatSet" $ do
let f = mkPoly [[Const (FiniteNatSet 4)]]
morphp f "x: 0" `shouldParse`
(encoding [(Sorted 1 (mkVal 0 [ConstValue 0]))] [])
morphp f "x: 0" `shouldParse` (encoding [(Sorted 1 (h1 0 0 [0]))] [])
morphp f `shouldSucceedOn` "x: 1"
morphp f `shouldSucceedOn` "x: 2"
morphp f `shouldSucceedOn` "x: 3"
......@@ -333,8 +323,8 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
morphp (mkPoly [[Exponential Variable (FiniteNatExp 2)]]) "x: {0: x, 1: x}"
`shouldParse`
(encoding
[(Sorted 1 (mkVal 0 [ExponentialValue (v [(), ()])]))]
[(0, (Sorted 1 (0, 0)), 0), (0, (Sorted 1 (0, 1)), 0)])
[(Sorted 1 (h1 0 2 []))]
[(0, (Sorted 1 0), 0), (0, (Sorted 1 1), 0)])
it "requires all values of the exponent set for FiniteNatExp" $ do
morphp (mkPoly [[Exponential Variable (FiniteNatExp 2)]])
......@@ -343,7 +333,7 @@ parseMorphismPointSpec = describe "parseMorphismPoint" $ do
it "can correctly parse constants from a set with shared prefixes" $ do
morphp (mkPoly [[Const (ExplicitSet (v ["a1", "a10"]))]]) "x: a10"
`shouldParse`
(encoding [(Sorted 1 (mkVal 0 [ConstValue 1]))] [])
(encoding [(Sorted 1 (h1 0 0 [1]))] [])
refineSpec :: Spec
refineSpec = describe "refining" $ do
......@@ -392,8 +382,10 @@ mkPoly :: [[Factor a]] -> Polynomial a
mkPoly =
Polynomial . Sum . NonEmpty.fromList . map (Product . NonEmpty.fromList)
mkVal :: Int -> [FactorValue a] -> SumValue a
mkVal i = SumValue i . ProductValue . v
h1 :: Int -> Int -> [Int] -> PolyH1
h1 s v = PolyH1 s v . V.fromList
v :: [a] -> Vector a
v = V.fromList
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment