{-# LANGUAGE CPP        #-}
{-# LANGUAGE LambdaCase, ViewPatterns #-}
module PureSAT.PartialAssignment where

#define ASSERTING(x)

import PureSAT.Base
import PureSAT.LBool
import PureSAT.LitVar
import PureSAT.Prim

-------------------------------------------------------------------------------
-- Partial Assignment
-------------------------------------------------------------------------------

-- | A mutable partial assignment, stored as a byte array in the state
-- thread @s@.
--
-- The functions in this module index the array with @lit_to_var@ of a
-- literal and use three byte values: @0x0@ and @0x1@ are written by
-- @insertPartialAssignment@ (chosen by bit 0 of the literal), and @0xff@
-- is written on creation, extension and deletion. Lookup treats every
-- byte other than @0x0@ and @0x1@ as undefined.
newtype PartialAssignment s = PA (MutableByteArray s)

-- | Allocate a fresh partial assignment with every slot set to @0xff@.
--
-- The requested size is clamped from below by the view pattern, so the
-- backing array always has at least 4096 bytes. The argument is presumably
-- the number of variables (one byte per variable); confirm against callers.
newPartialAssignment :: Int -> ST s (PartialAssignment s)
newPartialAssignment (max 4096 -> size) = do
    arr <- newByteArray size
    -- NOTE(review): the array was just allocated with exactly @size@ bytes,
    -- so this shrink should not change its length. Kept to preserve the
    -- original sequence of primitive operations.
    shrinkMutableByteArray arr size
    -- Mark the whole array as unassigned.
    fillByteArray arr 0 size 0xff
    return (PA arr)

-- | Create an independent copy of a partial assignment.
--
-- A new byte array of the same size is allocated and the complete contents
-- of the argument are copied into it; the argument itself is not modified.
clonePartialAssignment :: PartialAssignment s -> ST s (PartialAssignment s)
clonePartialAssignment (PA old) = do
    n <- getSizeofMutableByteArray old
    new <- newByteArray n
    copyMutableByteArray new 0 old 0 n
    return (PA new)

-- | Overwrite the target assignment (second argument) with the contents of
-- the source assignment (first argument).
--
-- Only the common prefix is copied: the number of bytes transferred is the
-- smaller of the two array sizes. Any bytes of the target beyond that
-- prefix are left untouched.
copyPartialAssignment :: PartialAssignment s -> PartialAssignment s -> ST s ()
copyPartialAssignment (PA src) (PA tgt) = do
    n <- getSizeofMutableByteArray src
    m <- getSizeofMutableByteArray tgt
    let size = min n m
    copyMutableByteArray tgt 0 src 0 size

-- | Grow a partial assignment by one slot and mark the new slot as @0xff@.
--
-- Returns the assignment wrapping the resized array. Callers should use the
-- returned value from then on: per the @primitive@ documentation the array
-- passed to @resizeMutableByteArray@ must not be accessed afterwards.
extendPartialAssignment :: PartialAssignment s -> ST s (PartialAssignment s)
extendPartialAssignment (PA arr) = do
    size <- getSizeofMutableByteArray arr
    arr' <- resizeMutableByteArray arr (size + 1)
    -- The old size is the index of the freshly added last byte.
    writeByteArray arr' size (0xff :: Word8)
    return (PA arr')

-- | Look up the value of a literal under the partial assignment.
--
-- The byte at index @lit_to_var l@ is read and combined with bit 0 of the
-- literal:
--
-- * stored @0x0@: @LTrue@ when bit 0 of the literal is clear, else @LFalse@;
-- * stored @0x1@: @LTrue@ when bit 0 of the literal is set, else @LFalse@;
-- * any other byte: @LUndef@.
lookupPartialAssignment :: Lit -> PartialAssignment s -> ST s LBool
lookupPartialAssignment (MkLit l) (PA arr) = do
    readByteArray arr (lit_to_var l) >>= \case
        0x0 -> return (if y then LFalse else LTrue)
        0x1 -> return (if y then LTrue else LFalse)
        _   -> return LUndef
  where
    -- Bit 0 of the raw literal encoding.
    y :: Bool
    y = testBit l 0
    {-# INLINE y #-}

-- | Record a literal in the partial assignment.
--
-- Writes @0x1@ to the slot at index @lit_to_var l@ when bit 0 of the
-- literal is set and @0x0@ otherwise. The previous content of the slot is
-- overwritten unconditionally.
--
-- The first statement of the body is a CPP macro call; with the macro
-- definition at the top of this file it expands to nothing. Its argument
-- checks that the slot currently holds @0xff@.
insertPartialAssignment :: Lit -> PartialAssignment s -> ST s ()
insertPartialAssignment (MkLit l) (PA arr) = do
    ASSERTING(readByteArray arr (lit_to_var l) >>= \x -> assertST "insert" (x == (0xff :: Word8)))
    writeByteArray arr (lit_to_var l) (if testBit l 0 then 0x1 else 0x0 :: Word8)

-- | Remove a literal from the partial assignment.
--
-- Writes @0xff@ to the slot at index @lit_to_var l@, regardless of what the
-- slot held before and regardless of bit 0 of the literal.
deletePartialAssignment :: Lit -> PartialAssignment s -> ST s ()
deletePartialAssignment (MkLit l) (PA arr) = do
    writeByteArray arr (lit_to_var l) (0xff :: Word8)

-- | Emit the contents of the partial assignment through @traceM@.
--
-- Every slot of the backing array is visited in index order. A slot holding
-- @0x0@ contributes the literal @MkLit (var_to_lit i)@, a slot holding
-- @0x1@ contributes the negation of that literal, and any other byte
-- contributes nothing. The collected literals are shown after the prefix
-- @PartialAssignment @.
tracePartialAssignment :: PartialAssignment s -> ST s ()
tracePartialAssignment (PA arr) = do
    n <- getSizeofMutableByteArray arr
    lits <- go n [] 0
    traceM $ "PartialAssignment " ++ show lits
  where
    -- Walk indices i .. n-1, accumulating literals in reverse order.
    -- No local type signature: it would need the outer type variable in
    -- scope, which plain Haskell signatures do not provide.
    go n acc i
        | i < n
        , let l = MkLit (var_to_lit i)
        = readByteArray arr i >>= \case
            0x0 -> go n (    l : acc) (i + 1)
            0x1 -> go n (neg l : acc) (i + 1)
            _   -> go n          acc  (i + 1)

        | otherwise
        = return (reverse acc)

-- | Assert that a literal is true under the partial assignment.
--
-- Does nothing when the lookup yields @LTrue@. For any other result,
-- @assertST@ is invoked with a @False@ condition and a message that
-- includes the shown lookup result (presumably this reports an assertion
-- failure; see the definition of @assertST@).
assertLiteralInPartialAssignment :: Lit -> PartialAssignment s -> ST s ()
assertLiteralInPartialAssignment l pa =
    lookupPartialAssignment l pa >>= \case
        LTrue -> return ()
        x     -> assertST ("lit in partial: " ++ show x) False

-- | Assert that a literal is undefined under the partial assignment.
--
-- The lookup result is passed to @assertST@ together with the condition
-- that it equals @LUndef@; the message includes the shown lookup result.
assertLiteralUndef :: Lit -> PartialAssignment s -> ST s ()
assertLiteralUndef l pa =
    lookupPartialAssignment l pa >>= \x ->
    assertST ("assertLiteralUndef: " ++ show x) (x == LUndef)