Skip to content

Commit

Permalink
Merge pull request #257 from BillHallahan/reducer_conc
Browse files Browse the repository at this point in the history
Reducer conc
  • Loading branch information
QHWU1228 authored Dec 29, 2024
2 parents 84dac79 + 717c9ee commit e42324d
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 86 deletions.
51 changes: 42 additions & 9 deletions src/G2/Execution/Rules.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import qualified Data.HashMap.Lazy as HM
import qualified Data.List as L
import qualified Data.Sequence as S
import G2.Data.Utils
import qualified G2.Data.UFMap as UF

import Control.Exception

stdReduce :: (Solver solver, Simplifier simplifier) => Sharing -> SymbolicFuncEval t -> solver -> simplifier -> State t -> Bindings -> IO (Rule, [(State t, ())], Bindings)
Expand Down Expand Up @@ -269,7 +271,8 @@ evalCase s@(State { expr_env = eenv
dbind = [(bind, mexpr)]
expr' = liftCaseBinds dbind expr
pbinds = zip params ar'
(eenv', expr'', ng', news) = liftBinds pbinds eenv expr' ng
(eenv', expr'', ng', news) = liftBinds kv pbinds eenv expr' ng

in
assert (length params == length ar')
( RuleEvalCaseData news
Expand Down Expand Up @@ -397,6 +400,7 @@ concretizeVarExpr' s@(State {expr_env = eenv, type_env = tenv, known_values = kv
, concretized = [mexpr_id]
}, ngen'')
where

-- Make sure that the parameters do not conflict in their symbolic reps.
olds = map idName params
clean_olds = map cleanName olds
Expand Down Expand Up @@ -1106,22 +1110,51 @@ addExtConds s ng e1 ais e2 stck =
in
([strue, sfalse], ng)

-- | Inject binds into the eenv. The LHS of the [(Id, Expr)] are treated as
-- seed values for the names.
liftBinds :: [(Id, Expr)] -> E.ExprEnv -> Expr -> NameGen ->
-- This function aims to extract pairs of types being coerced between. Given a coercion t1 :~ t2, the tuple (t1, t2) is returned.
extractTypes :: KnownValues -> Id -> (Type, Type)
extractTypes kv (Id _ (TyApp (TyApp (TyApp (TyApp (TyCon n _) _) _) n1) n2)) =
(if KV.tyCoercion kv == n
then
(n1, n2)
else
error "ExtractTypes: the center of the pattern is not a coercion")
extractTypes _ _ = error "ExtractTypes: The type of the pattern doesn't have four nested TyApp while its corresponding scrutinee is a coercion"

liftBinds :: KnownValues -> [(Id, Expr)] -> E.ExprEnv -> Expr -> NameGen ->
(E.ExprEnv, Expr, NameGen, [Name])
liftBinds binds eenv expr ngen = (eenv', expr', ngen', news)
where
(bindsLHS, bindsRHS) = unzip binds
liftBinds kv binds eenv expr ngen = (eenv', expr'', ngen', news)

olds = map (idName) bindsLHS
where
-- Converts type variables into corresponding types as determined by coercions
-- For example, in 'E a b c' where
-- 'a ~# Int', 'b ~# Float', 'c ~# String'
-- The code simply does the following:
-- 'E a b c' -> 'E Int Float String'
(coercion, value_args) = L.partition (\(_, e) -> case e of
Coercion _ -> True
_ -> False) binds

extract_tys = map (extractTypes kv . fst) coercion

uf_map = foldM (\uf_map' (t1, t2) -> T.unify' uf_map' t1 t2) UF.empty extract_tys

expr' = case uf_map of
Nothing -> expr
Just uf_map' -> L.foldl' (\e (n,t) -> retype (Id n (typeOf t)) t e) expr (HM.toList $ UF.toSimpleMap uf_map')

-- bindsLHS is the pattern
-- bindsRHS is the scrutinee
(bindsLHS, bindsRHS) = unzip value_args

olds = map idName bindsLHS
(news, ngen') = freshSeededNames olds ngen

olds_news = HM.fromList $ zip olds news
expr' = renamesExprs olds_news expr

eenv' = E.insertExprs (zip news bindsRHS) eenv

expr'' = renamesExprs olds_news expr'

liftBind :: Id -> Expr -> E.ExprEnv -> Expr -> NameGen ->
(E.ExprEnv, Expr, NameGen, Name)
liftBind bindsLHS bindsRHS eenv expr ngen = (eenv', expr', ngen', new)
Expand Down
2 changes: 1 addition & 1 deletion src/G2/Translation/HaskellCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ createDecls pg s = mapM_ runDecls . createDeclsStr pg s
adjustDynFlags :: Ghc ()
adjustDynFlags = do
dyn <- getSessionDynFlags
let dyn' = xopt_set (xopt_set dyn MagicHash) UnboxedTuples
let dyn' = foldl' xopt_set dyn [MagicHash, UnboxedTuples, DataKinds]
dyn'' = wopt_unset dyn' Opt_WarnOverlappingPatterns
_ <- setSessionDynFlags dyn''
return ()
Expand Down
12 changes: 12 additions & 0 deletions tests/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,18 @@ extensionTests = testGroup "Extensions"
, ("callF3", 400, [AtLeast 2])
, ("callG", 400, [AtLeast 1])
, ("callG2", 400, [AtLeast 1]) ]

, checkInputOutputs "tests/TestFiles/Extensions/GADTs1.hs" [ ("vecZipConc", 400, [Exactly 1])
, ("vecZipConc2", 400, [Exactly 1])
, ("vecHeadEx", 400, [Exactly 1])
, ("doubleVec", 400, [Exactly 1])
, ("tailVec", 400, [Exactly 1])
, ("tailPairVec", 400, [Exactly 1])
, ("exampleExpr1", 400, [Exactly 1])
, ("exampleExpr2", 400, [Exactly 1])
, ("exampleExpr3", 400, [Exactly 1])
, ("exampleExpr4", 400, [Exactly 1])
, ("exampleExpr5", 400, [Exactly 1]) ]
]

baseTests :: TestTree
Expand Down
146 changes: 70 additions & 76 deletions tests/TestFiles/Extensions/GADTs1.hs
Original file line number Diff line number Diff line change
@@ -1,86 +1,84 @@
{-# LANGUAGE GADTs, DataKinds, KindSignatures, RankNTypes, TypeFamilies #-}
{-# LANGUAGE GADTs, DataKinds, KindSignatures, RankNTypes, TypeFamilies, FlexibleInstances, FlexibleContexts #-}

module GADTS1 where

import GHC.TypeLits
import Data.Kind

data ShapeType = Circle | Rectangle
-- example of recursive GADT
data Expr a where
Lit :: Int -> Expr Int
Add :: Expr Int -> Expr Int -> Expr Int
IsZero :: Expr Int -> Expr Bool
If :: Expr Bool -> Expr a -> Expr a -> Expr a

data Shape where
CircleShape :: Double -> Shape
RectangleShape :: Double -> Double -> Shape
instance Eq (Expr a) where
(Lit x) == (Lit y) = x == y

(Add e1 e2) == (Add e3 e4) = e1 == e3 && e2 == e4

(IsZero e1) == (IsZero e2) = e1 == e2

(If c1 t1 f1) == (If c2 t2 f2) = c1 == c2 && t1 == t2 && f1 == f2

-- If the constructors are different, the expressions are not equal
_ == _ = False

area :: Shape -> Double
area (CircleShape radius) = pi * radius * radius
area (RectangleShape width height) = width * height
eval :: Expr a -> a
eval (Lit n) = n
eval (Add x y) = eval x + eval y
eval (IsZero x) = eval x == 0
eval (If cond t e) = if eval cond then eval t else eval e

-- infixr :>
exampleConditional :: Expr Int
exampleConditional = If (IsZero (Lit 0)) (Lit 42) (Lit 0)

-- data HList where
-- Nil :: HList
-- (:>) :: forall a . (Num a, Show a) => a -> HList -> HList
evalEC :: Int
evalEC = eval exampleConditional

-- hlistHeadStr :: HList -> String
-- hlistHeadStr (x :> xs) = show (x + 1)
exampleExpr1 :: Expr Int
exampleExpr1 = Add (Lit 5) (Lit 3) -- 5 + 3

data MyList2 a = Nis | Conss a (MyList2 a)
evalExpr1 :: Int
evalExpr1 = eval exampleExpr1

lengthList2 :: MyList2 a -> Int
lengthList2 Nis = 0
lengthList2 (Conss _ xs) = 1 + lengthList2 xs
exampleExpr2 :: Expr Bool
exampleExpr2 = IsZero (Add (Lit 2) (Lit (-2))) -- 2 + (-2) == 0

-- this above is having an error that says
-- G2: No type found in typeWithStrName "MutVar#"
-- CallStack (from HasCallStack):
-- error, called at src/G2/Initialization/KnownValues.hs:127:10 in g2-0.2.0.0-inplace:G2.Initialization.KnownValues
evalExpr2 :: Bool
evalExpr2 = eval exampleExpr2

exampleExpr3 :: Expr Int
exampleExpr3 = If (IsZero (Lit 0)) (Lit 10) (Lit 20) -- if 0 == 0 then 10 else 20

data MyList a where
Ni :: MyList a
Cons :: a -> MyList a -> MyList a
evalExpr3 :: Int
evalExpr3 = eval exampleExpr3

-- recursion on recursive GADT
lengthList :: MyList a -> Int
lengthList Ni = 0
lengthList (Cons _ xs) = 1 + lengthList xs
exampleExpr4 :: Expr Int
exampleExpr4 = If (IsZero (Lit 1)) (Lit 10) (Lit 20) -- if 1 == 0 then 10 else 20

add2 :: a -> a -> MyList a -> MyList a
add2 a1 a2 li = Cons a2 $ Cons a1 li
evalExpr4 :: Int
evalExpr4 = eval exampleExpr4

addn :: [a] -> MyList a -> MyList a
addn [] a = a
addn (x:xs) a = addn xs (Cons x a)
exampleExpr5 :: Expr Bool
exampleExpr5 = IsZero (If (IsZero (Lit 0)) (Lit 0) (Lit 1)) -- isZero (if 0 == 0 then 0 else 1)

data MyExpr a where
Lt :: Int -> MyExpr Int
Mul :: MyExpr Int -> MyExpr Int -> MyExpr Int
Add :: MyExpr Int -> MyExpr Int -> MyExpr Int

evalMyExpr :: MyExpr a -> a
evalMyExpr (Lt a) = a
evalMyExpr (Mul a1 a2) = evalMyExpr a1 * evalMyExpr a2
evalMyExpr (Add a1 a2) = evalMyExpr a1 + evalMyExpr a2

testeval :: Int -> MyExpr Int
testeval a1 = testeval $ evalMyExpr $ Lt (2*a1)

checkeq :: Eq a => a -> a -> Bool
checkeq a a1 = a == a1

id2 :: a -> a
id2 x = x

idlr :: Either l r -> Either l r
idlr x = x
evalExpr5 :: Bool
evalExpr5 = eval exampleExpr5

data Peano = Succ Peano | Zero

data Vec :: Peano -> Type -> Type where
VNil :: Vec Zero a
VCons :: forall n a. a -> Vec n a -> Vec (Succ n) a

vecLength :: Vec n a -> Integer
instance Eq a => Eq (Vec Zero a) where
VNil == VNil = True

instance (Eq a, Eq (Vec n a)) => Eq (Vec (Succ n) a) where
(VCons x xs) == (VCons y ys) = x == y && xs == ys

vecLength :: Vec n a -> Int
vecLength VNil = 0
vecLength (VCons _ xs) = 1 + vecLength xs

Expand All @@ -94,34 +92,30 @@ vecZip (VCons x xs) (VCons y ys) = VCons (x, y) (vecZip xs ys)
vecZipConc :: Vec (Succ Zero) (Int, Char)
vecZipConc = vecZip (VCons 1 VNil) (VCons 'a' VNil)

vecZipConc2 :: Vec (Succ (Succ Zero)) (Int, Char)
vecZipConc2 = vecZip (VCons 1 (VCons 2 VNil)) (VCons 'a' (VCons 'b' VNil))

vecMap :: (a -> b) -> Vec n a -> Vec n b
vecMap _ VNil = VNil
vecMap f (VCons x xs) = VCons (f x) (vecMap f xs)

vecHeadEx :: Int
vecHeadEx = vecHead (VCons 1 (VCons 2 VNil))

-- have to run 400 steps for the result to show up instead of 200
doubleVec :: Vec (Succ (Succ Zero)) Int
doubleVec = vecMap (*2) (VCons 1 (VCons 2 VNil))

vecTail :: Vec (Succ n) a -> Vec n a
vecTail (VCons _ xs) = xs

tailVec :: Vec (Succ Zero) Char
tailVec = vecTail (VCons 'a' (VCons 'b' VNil))

tailPairVec :: Vec (Succ Zero) (Int, Char)
tailPairVec = vecTail $ vecZip (VCons 1 (VCons 2 VNil)) (VCons 'a' (VCons 'b' VNil))

-- Return all the elements of a list except the last one
vecInit :: Vec (Succ n) a -> Vec n a
vecInit (VCons x VNil) = VNil
vecInit (VCons x xs@(VCons y ys)) = VCons x (vecInit xs)


data Term a where
Lit :: Int -> Term Int
Pair :: Term a -> Term b -> Term (a,b)

eval2 :: Term a -> a
eval2 (Lit i) = i
eval2 (Pair a b) = (eval2 a, eval2 b)

data X (b :: Bool) where
XTrue :: X b -> X True
XFalse :: X False

getX :: X True
getX = walkX (XTrue XFalse)

walkX :: X b -> X b
walkX (XTrue x) = XTrue (walkX x)
walkX XFalse = XFalse

0 comments on commit e42324d

Please sign in to comment.