{-# LANGUAGE CPP #-}
module PureSAT.VarSet where

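-- Uncomment the define below to back VarSet with a Data.IntSet stored in an
-- STRef instead of PureSAT.SparseMaxHeap. That variant ignores weights and
-- always pops the smallest variable index.
-- #define INTSET_VARS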

import Control.Monad.ST (ST)
import Data.Coerce      (coerce)

import PureSAT.LitVar

#ifdef INTSET_VARS
import qualified Data.IntSet as IS
import           Data.STRef
    (STRef, modifySTRef, modifySTRef', newSTRef, readSTRef, writeSTRef)
#else
import PureSAT.SparseMaxHeap
#endif

-------------------------------------------------------------------------------
-- VarSet
-------------------------------------------------------------------------------

#ifdef INTSET_VARS
-- | Mutable set of variables: an immutable IntSet held in an STRef.
newtype VarSet s
    = VS (STRef s IS.IntSet)

-- | Create an empty set backed by a fresh STRef.
newVarSet :: ST s (VarSet s)
newVarSet = fmap VS (newSTRef IS.empty)

-- | Copy the set into an independent VarSet.
--
-- IntSet values are immutable, so placing the current value in a fresh
-- STRef is a complete copy: later updates to either set do not affect the
-- other.
cloneVarSet :: VarSet s -> ST s (VarSet s)
cloneVarSet (VS xs) = VS <$> (readSTRef xs >>= newSTRef)

-- | Number of variables currently in the set.
sizeofVarSet :: VarSet s -> ST s Int
sizeofVarSet (VS ref) = do
    set <- readSTRef ref
    return (IS.size set)

-- | Capacity is irrelevant for an IntSet, so the set is returned unchanged
-- and the requested capacity is ignored.
extendVarSet :: Int -> VarSet s -> ST s (VarSet s)
extendVarSet _capacity set = pure set

-- | The IntSet representation does not track weights, so this is a no-op.
--
-- The weight function uses Word, matching the SparseMaxHeap-backed version,
-- so callers typecheck whichever representation is selected.
weightVarSet :: Var -> (Word -> Word) -> VarSet s -> ST s ()
weightVarSet _ _ _ = return ()

-- | Scaling weights is likewise a no-op for the IntSet representation.
-- It exists so that both configurations provide the same functions.
scaleVarSet :: VarSet s -> (Word -> Word) -> ST s ()
scaleVarSet _ _ = return ()

-- | Add a variable to the set.
--
-- Uses the strict modifySTRef' so that repeated inserts do not accumulate
-- a chain of unevaluated IS.insert thunks inside the ref.
insertVarSet :: Var -> VarSet s -> ST s ()
insertVarSet (MkVar x) (VS xs) = modifySTRef' xs (IS.insert x)

-- | Remove a variable from the set. Removing an absent variable has no effect.
--
-- Uses the strict modifySTRef' so that repeated deletes do not accumulate
-- a chain of unevaluated IS.delete thunks inside the ref.
deleteVarSet :: Var -> VarSet s -> ST s ()
deleteVarSet (MkVar x) (VS xs) = modifySTRef' xs (IS.delete x)

-- | Remove every variable from the set.
clearVarSet :: VarSet s -> ST s ()
clearVarSet vs = case vs of
    VS ref -> writeSTRef ref IS.empty

-- | Remove and return the smallest variable index in the set.
--
-- Runs the first continuation when the set is empty. Otherwise it first
-- removes the smallest element from the set and then passes that element
-- to the second continuation.
minViewVarSet :: VarSet s -> ST s r -> (Var -> ST s r) -> ST s r
minViewVarSet (VS ref) no yes = readSTRef ref >>= maybe no takeMin . IS.minView
  where
    takeMin (v, rest) = writeSTRef ref rest >> yes (MkVar v)

#else

-- | Mutable set of variables, stored as a SparseHeap from
-- PureSAT.SparseMaxHeap keyed by the Int index inside each Var.
newtype VarSet s
    = VS (SparseHeap s)

-- | Number of variables currently in the set, as reported by sizeofSparseHeap.
sizeofVarSet :: VarSet s -> ST s Int
sizeofVarSet (VS xs) = sizeofSparseHeap xs

-- | Create an empty set backed by newSparseHeap 0.
--
-- NOTE(review): the 0 presumably means zero initial capacity, so the set
-- must be grown with extendVarSet before variables are inserted. Confirm
-- this against PureSAT.SparseMaxHeap.
newVarSet :: ST s (VarSet s)
newVarSet = VS <$> newSparseHeap 0

-- | Copy the set by cloning its underlying heap with cloneSparseHeap.
cloneVarSet :: VarSet s -> ST s (VarSet s)
cloneVarSet (VS xs) = VS <$> cloneSparseHeap xs

-- | Grow the underlying heap to the given capacity.
--
-- The result wraps the heap returned by extendSparseHeap. That heap may be
-- a different one from the input, so callers should continue with the
-- returned set.
extendVarSet :: Int -> VarSet s -> ST s (VarSet s)
extendVarSet capacity (VS xs) = VS <$> extendSparseHeap capacity xs

-- | Update the weight of a variable by applying the given function to it,
-- using modifyWeightSparseHeap on the Int index of the variable.
weightVarSet :: Var -> (Word -> Word) -> VarSet s -> ST s ()
weightVarSet (MkVar x) f (VS xs) = modifyWeightSparseHeap xs x f
{-# INLINE weightVarSet #-}

-- | Apply the given function to the weights in the heap, using
-- scaleWeightsSparseHeap.
scaleVarSet :: VarSet s -> (Word -> Word) -> ST s ()
scaleVarSet (VS xs) f = scaleWeightsSparseHeap xs f
{-# INLINE scaleVarSet #-}

-- | Insert a variable into the heap by its Int index, using insertSparseHeap.
insertVarSet :: Var -> VarSet s -> ST s ()
insertVarSet (MkVar x) (VS xs) = insertSparseHeap xs x

-- | Remove a variable from the heap by its Int index, using deleteSparseHeap.
deleteVarSet :: Var -> VarSet s -> ST s ()
deleteVarSet (MkVar x) (VS xs) = deleteSparseHeap xs x

-- | Remove every variable from the set, using clearSparseHeap.
clearVarSet :: VarSet s -> ST s ()
clearVarSet (VS xs) = clearSparseHeap xs

{-# INLINE minViewVarSet #-}
-- | Pop an element from the heap with popSparseHeap_.
--
-- Runs the first continuation when the heap is empty. Otherwise it passes
-- the popped variable to the second continuation.
--
-- NOTE(review): the heap comes from PureSAT.SparseMaxHeap, so despite the
-- name this presumably yields the highest-weight variable. Confirm this.
--
-- coerce turns the Var continuation into an Int continuation at no runtime
-- cost, because Var is a newtype over Int (MkVar).
minViewVarSet :: VarSet s -> ST s r -> (Var -> ST s r) -> ST s r
minViewVarSet (VS xs) no yes = popSparseHeap_ xs no (coerce yes)

#endif