{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Cardano.Crypto.Libsodium.MLockedSeed
where

import Cardano.Crypto.DirectSerialise
import Cardano.Crypto.Libsodium.C (
  c_sodium_randombytes_buf,
 )
import Cardano.Crypto.Libsodium.MLockedBytes (
  MLockedSizedBytes,
  mlsbCopyWith,
  mlsbFinalize,
  mlsbNewWith,
  mlsbNewZeroWith,
  mlsbUseAsCPtr,
  mlsbUseAsSizedPtr,
 )
import Cardano.Crypto.Libsodium.Memory (
  MLockedAllocator,
  mlockedMalloc,
 )
import Cardano.Foreign (SizedPtr)
import Control.DeepSeq (NFData)
import Control.Monad.Class.MonadST (MonadST)
import Data.Proxy (Proxy (..))
import Data.Word (Word8)
import Foreign.C.Types (CSize)
import Foreign.Ptr (Ptr, castPtr)
import GHC.TypeLits (Natural)
import GHC.TypeNats (KnownNat, natVal)
import NoThunks.Class (NoThunks)

-- | A seed of @n@ bytes kept in mlocked memory.
--
-- Keeping the seed mlocked prevents it from leaking to disk through swapping,
-- and from being recovered by reclaiming or scanning memory after its content
-- has been moved.
--
-- This is a thin wrapper around 'MLockedSizedBytes'. Its 'NFData' and
-- 'NoThunks' instances are taken over unchanged from that type.
newtype MLockedSeed n = MLockedSeed
  { mlockedSeedMLSB :: MLockedSizedBytes n
  -- ^ The underlying mlocked byte buffer holding the seed.
  }
  deriving newtype (NFData, NoThunks)

-- | Serialise a seed by passing the callback a pointer straight into the
-- seed's mlocked buffer, together with its length of @n@ bytes.
instance KnownNat n => DirectSerialise (MLockedSeed n) where
  directSerialise push seed =
    mlockedSeedUseAsCPtr seed (\bytesPtr -> push (castPtr bytesPtr) byteCount)
    where
      -- Length in bytes, read from the type-level size @n@ of the seed.
      byteCount = fromIntegral @Natural @CSize (natVal seed)

-- | Deserialise a seed. A fresh seed is allocated with 'mlockedSeedNew', and
-- the callback fills its @n@ bytes in place through a pointer into the
-- mlocked buffer.
--
-- NOTE(review): if the callback throws, this code neither returns nor
-- finalises the freshly allocated seed. Whether the allocator reclaims it some
-- other way is not visible from this module.
instance KnownNat n => DirectDeserialise (MLockedSeed n) where
  directDeserialise pull = do
    fresh <- mlockedSeedNew
    -- Length in bytes, read from the type-level size @n@ of the seed.
    let byteCount = fromIntegral @Natural @CSize (natVal fresh)
    mlockedSeedUseAsCPtr fresh $ \bytesPtr ->
      pull (castPtr bytesPtr) byteCount
    pure fresh

-- | Run an action on the 'MLockedSizedBytes' inside a seed, then wrap the
-- buffer the action produces back up as a seed.
withMLockedSeedAsMLSB ::
  Functor m =>
  (MLockedSizedBytes n -> m (MLockedSizedBytes n)) ->
  MLockedSeed n ->
  m (MLockedSeed n)
withMLockedSeedAsMLSB action (MLockedSeed mlsb) = MLockedSeed <$> action mlsb

-- | Copy a seed, using the default 'mlockedMalloc' allocator for the copy
-- (via 'mlsbCopyWith').
mlockedSeedCopy :: (KnownNat n, MonadST m) => MLockedSeed n -> m (MLockedSeed n)
mlockedSeedCopy (MLockedSeed mlsb) = MLockedSeed <$> mlsbCopyWith mlockedMalloc mlsb

-- | Copy a seed, passing the given allocator to 'mlsbCopyWith' for the copy.
mlockedSeedCopyWith ::
  (KnownNat n, MonadST m) =>
  MLockedAllocator m ->
  MLockedSeed n ->
  m (MLockedSeed n)
mlockedSeedCopyWith allocator (MLockedSeed mlsb) =
  MLockedSeed <$> mlsbCopyWith allocator mlsb

-- | Allocate a new seed of @n@ bytes with the default 'mlockedMalloc'
-- allocator, via 'mlsbNewWith'. The contents are whatever 'mlsbNewWith'
-- provides; the @NewZero@ variants use 'mlsbNewZeroWith' instead.
mlockedSeedNew :: (KnownNat n, MonadST m) => m (MLockedSeed n)
mlockedSeedNew = MLockedSeed <$> mlsbNewWith mlockedMalloc

-- | Allocate a new seed of @n@ bytes with the given allocator, via
-- 'mlsbNewWith'.
mlockedSeedNewWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSeed n)
mlockedSeedNewWith allocator = fmap MLockedSeed (mlsbNewWith allocator)

-- | Allocate a new seed of @n@ bytes with the default 'mlockedMalloc'
-- allocator, via 'mlsbNewZeroWith'.
mlockedSeedNewZero :: (KnownNat n, MonadST m) => m (MLockedSeed n)
mlockedSeedNewZero = fmap MLockedSeed (mlsbNewZeroWith mlockedMalloc)

-- | Allocate a new seed of @n@ bytes with the given allocator, via
-- 'mlsbNewZeroWith'.
mlockedSeedNewZeroWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSeed n)
mlockedSeedNewZeroWith = fmap MLockedSeed . mlsbNewZeroWith

-- | Create a seed of @n@ random bytes, allocated with the default
-- 'mlockedMalloc' allocator (see 'mlockedSeedNewRandomWith').
mlockedSeedNewRandom :: forall n. KnownNat n => IO (MLockedSeed n)
mlockedSeedNewRandom = mlockedSeedNewRandomWith @n mlockedMalloc

-- | Create a seed of @n@ random bytes using the given allocator.
--
-- The buffer is allocated with 'mlsbNewZeroWith'. All @n@ bytes are then
-- overwritten in place by libsodium's @randombytes_buf@, so the random data
-- is written directly into mlocked memory.
mlockedSeedNewRandomWith :: forall n. KnownNat n => MLockedAllocator IO -> IO (MLockedSeed n)
mlockedSeedNewRandomWith allocator = do
  mlsb <- mlsbNewZeroWith allocator
  mlsbUseAsCPtr mlsb $ \dst -> c_sodium_randombytes_buf dst byteCount
  pure (MLockedSeed mlsb)
  where
    -- Length in bytes, read from the type-level size @n@.
    byteCount :: CSize
    byteCount = fromIntegral @Natural @CSize (natVal (Proxy @n))

-- | Finalise a seed by finalising its underlying buffer with 'mlsbFinalize'.
mlockedSeedFinalize :: MonadST m => MLockedSeed n -> m ()
mlockedSeedFinalize (MLockedSeed mlsb) = mlsbFinalize mlsb

-- | Run an action with a raw byte pointer to the seed's mlocked buffer, via
-- 'mlsbUseAsCPtr'.
mlockedSeedUseAsCPtr :: MonadST m => MLockedSeed n -> (Ptr Word8 -> m b) -> m b
mlockedSeedUseAsCPtr = mlsbUseAsCPtr . mlockedSeedMLSB

-- | Run an action with a 'SizedPtr' (whose type carries the size @n@) to the
-- seed's mlocked buffer, via 'mlsbUseAsSizedPtr'.
mlockedSeedUseAsSizedPtr :: MonadST m => MLockedSeed n -> (SizedPtr n -> m b) -> m b
mlockedSeedUseAsSizedPtr (MLockedSeed mlsb) action = mlsbUseAsSizedPtr mlsb action