{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Cardano.Crypto.Libsodium.MLockedBytes.Internal (
  -- * The MLockedSizedBytes type
  MLockedSizedBytes (..),
  SizedVoid,

  -- * Safe Functions
  mlsbNew,
  mlsbNewZero,
  mlsbZero,
  mlsbUseAsCPtr,
  mlsbUseAsSizedPtr,
  mlsbCopy,
  mlsbFinalize,
  mlsbCompare,
  mlsbEq,
  withMLSB,
  withMLSBChunk,
  mlsbNewWith,
  mlsbNewZeroWith,
  mlsbCopyWith,

  -- * Dangerous Functions
  traceMLSB,
  mlsbFromByteString,
  mlsbFromByteStringCheck,
  mlsbAsByteString,
  mlsbToByteString,
  mlsbFromByteStringWith,
  mlsbFromByteStringCheckWith,
) where

import Control.DeepSeq (NFData (..))
import Control.Monad.Class.MonadST
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Data.Proxy (Proxy (..))
import Data.Word (Word8)
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (castForeignPtr, newForeignPtr_)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.TypeLits (KnownNat, Nat, natVal)
import NoThunks.Class (NoThunks)

import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium.Memory
import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..))
import Cardano.Foreign

import Data.Bits (Bits, shiftL)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI

-- | An uninhabited type indexed by a type-level byte count. It exists so that
-- \"a pointer to a block of @size@ bytes that can only be manipulated through
-- the pointer, never as a plain Haskell value\" can be written as
-- @Ptr (SizedVoid size)@, @ForeignPtr (SizedVoid size)@, or
-- @MLockedForeignPtr (SizedVoid size)@.
data SizedVoid (size :: Nat)

-- | A block of raw memory of a known size, protected with @mlock()@.
--
-- The type-level 'Nat' is the length of the block in bytes. 'NoThunks' and
-- 'NFData' are inherited unchanged from the wrapped 'MLockedForeignPtr'.
newtype MLockedSizedBytes (n :: Nat) = MLSB (MLockedForeignPtr (SizedVoid n))
  deriving newtype (NoThunks, NFData)

-- | Renders only the type-level size of the block, e.g.
-- @MLockedSizedBytes[32]@. The mlocked contents themselves are /not/ read or
-- printed. Intended for testing and debugging output.
instance KnownNat n => Show (MLockedSizedBytes n) where
  show mlsb = "MLockedSizedBytes[" ++ show (natVal mlsb) ++ "]"

-- TODO: move this to test suite, with a newtype wrapper
-- show mlsb =
--   let bytes = BS.unpack $ mlsbAsByteString mlsb
--       hexstr = concatMap (printf "%02x") bytes
--   in "MLSB " ++ hexstr

-- | The smallest power of two that is @>= i@. Inputs @<= 1@ yield @1@.
--
-- Calls 'error' if no such power of two is representable in @n@. Without the
-- check, the doubling would wrap around (to 0 or a negative number) and the
-- search would never terminate.
nextPowerOf2 :: forall n. (Num n, Ord n, Bits n) => n -> n
nextPowerOf2 i = go 1
  where
    go :: n -> n
    go c
      | c >= i = c
      -- Doubling did not grow c, so the shift overflowed the type.
      | c' <= c = error "nextPowerOf2: no power of 2 >= input is representable"
      | otherwise = go c'
      where
        c' = c `shiftL` 1

-- | Print an 'MLockedSizedBytes' to stdout via its 'Show' instance.
-- Debugging aid only. It is deprecated so that leftover uses are flagged at
-- compile time.
traceMLSB :: KnownNat n => MLockedSizedBytes n -> IO ()
traceMLSB = print
{-# DEPRECATED traceMLSB "Don't leave traceMLSB in production" #-}

-- | Run an action with the raw pointer to the mlocked block. This delegates to
-- 'withMLockedForeignPtr'. Presumably, as with 'withForeignPtr', the pointer
-- should not be used after the action returns (TODO confirm against
-- 'withMLockedForeignPtr').
withMLSB :: forall b n m. MonadST m => MLockedSizedBytes n -> (Ptr (SizedVoid n) -> m b) -> m b
withMLSB (MLSB fptr) action = withMLockedForeignPtr fptr action

-- | Run an action on an @n'@-byte sub-block of an 'MLockedSizedBytes' that
-- starts @offset@ bytes into the parent block.
--
-- The chunk aliases the parent's memory; no copy is made. It is wrapped with
-- 'newForeignPtr_', so it has no finalizer of its own. Do not let it escape
-- the action, and do not finalize it: the memory belongs to the parent.
--
-- Calls 'error' if @offset@ is negative or if @offset + n'@ exceeds @n@.
withMLSBChunk ::
  forall b n n' m.
  (MonadST m, KnownNat n, KnownNat n') =>
  MLockedSizedBytes n ->
  Int ->
  (MLockedSizedBytes n' -> m b) ->
  m b
withMLSBChunk mlsb offset action
  | offset < 0 =
      error "Negative offset not allowed"
  | offset > parentSize - chunkSize =
      error $
        "Overrun ("
          ++ show offset
          ++ " + "
          ++ show chunkSize
          ++ " > "
          ++ show parentSize
          ++ ")"
  | otherwise =
      withMLSB mlsb $ \ptr -> do
        -- A finalizer-less ForeignPtr into the parent's memory.
        fptr <- stToIO . unsafeIOToST $ newForeignPtr_ (ptr `plusPtr` offset)
        action (MLSB $! SFP $! fptr)
  where
    chunkSize :: Int
    chunkSize = fromIntegral (natVal (Proxy @n'))
    parentSize :: Int
    parentSize = fromIntegral (natVal mlsb)

-- | Length of the block in bytes, taken from the type-level size @n@.
mlsbSize :: KnownNat n => MLockedSizedBytes n -> CSize
mlsbSize mlsb = fromInteger (natVal mlsb)

-- | Allocate a new 'MLockedSizedBytes' using 'mlockedMalloc'. The caller is
-- responsible for deallocating it ('mlsbFinalize') when done with it. The
-- contents of the memory block are undefined. (See also 'mlsbNewWith' and
-- 'mlsbNewZero'.)
mlsbNew :: forall n m. (KnownNat n, MonadST m) => m (MLockedSizedBytes n)
mlsbNew = mlsbNewWith mlockedMalloc

-- | Like 'mlsbNew', but allocates through the given 'MLockedAllocator'.
--
-- Requests @n@ bytes via 'mlockedAllocForeignPtrBytesWith'. The value passed
-- as its @align@ argument is the smallest power of two that is @>= n@
-- ('nextPowerOf2'). The contents of the new block are undefined.
mlsbNewWith :: forall n m. MLockedAllocator m -> (KnownNat n, MonadST m) => m (MLockedSizedBytes n)
mlsbNewWith allocator =
  MLSB <$> mlockedAllocForeignPtrBytesWith allocator size align
  where
    size :: CSize
    size = fromInteger (natVal (Proxy @n))
    align :: CSize
    align = nextPowerOf2 size

-- | Allocate a new 'MLockedSizedBytes' using 'mlockedMalloc', and pre-fill it
-- with zeroes. The caller is responsible for deallocating it ('mlsbFinalize')
-- when done with it. (See also 'mlsbNew'.)
mlsbNewZero :: forall n m. (KnownNat n, MonadST m) => m (MLockedSizedBytes n)
mlsbNewZero = mlsbNewZeroWith mlockedMalloc

-- | Like 'mlsbNewZero', but allocates through the given 'MLockedAllocator'.
-- The block is allocated with 'mlsbNewWith' and then cleared with 'mlsbZero'.
mlsbNewZeroWith ::
  forall n m. (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSizedBytes n)
mlsbNewZeroWith allocator = do
  mlsb <- mlsbNewWith allocator
  mlsbZero mlsb
  return mlsb

-- | Overwrite an existing 'MLockedSizedBytes' with zeroes. All @n@ bytes are
-- cleared via 'zeroMem'.
mlsbZero :: forall n m. (KnownNat n, MonadST m) => MLockedSizedBytes n -> m ()
mlsbZero mlsb = withMLSB mlsb $ \ptr -> zeroMem ptr (mlsbSize mlsb)

-- | Create a deep mlocked copy of an 'MLockedSizedBytes', allocated with
-- 'mlockedMalloc'. Like any freshly allocated block, the copy must be
-- released with 'mlsbFinalize' by the caller. (See also 'mlsbCopyWith'.)
mlsbCopy ::
  forall n m.
  (KnownNat n, MonadST m) =>
  MLockedSizedBytes n ->
  m (MLockedSizedBytes n)
mlsbCopy = mlsbCopyWith mlockedMalloc

-- | Like 'mlsbCopy', but allocates the copy through the given
-- 'MLockedAllocator'. All @n@ bytes of @src@ are copied into the new block
-- with 'copyMem'. The source is left untouched.
mlsbCopyWith ::
  forall n m.
  (KnownNat n, MonadST m) =>
  MLockedAllocator m ->
  MLockedSizedBytes n ->
  m (MLockedSizedBytes n)
mlsbCopyWith allocator src = mlsbUseAsCPtr src $ \ptrSrc -> do
  dst <- mlsbNewWith allocator
  withMLSB dst $ \ptrDst ->
    copyMem ptrDst (castPtr ptrSrc) (mlsbSize src)
  return dst

-- | Allocate a new 'MLockedSizedBytes' using 'mlockedMalloc', and fill it
-- with the contents of a 'BS.ByteString'. The size of the input is not
-- checked: at most the first @n@ bytes of the input are copied.
--
-- /Note:/ the input 'BS.ByteString' is a plain old Haskell value, so it has
-- already violated the secure-forgetting properties afforded by
-- 'MLockedSizedBytes'. This makes the function useless outside of testing. Use
-- 'mlsbNew' or 'mlsbNewZero' to create 'MLockedSizedBytes' values, and
-- manipulate them through 'withMLSB', 'mlsbUseAsCPtr', or 'mlsbUseAsSizedPtr'.
-- (See also 'mlsbFromByteStringCheck'.)
mlsbFromByteString ::
  forall n m.
  (KnownNat n, MonadST m) =>
  BS.ByteString -> m (MLockedSizedBytes n)
mlsbFromByteString = mlsbFromByteStringWith mlockedMalloc

-- | Like 'mlsbFromByteString', but allocates through the given
-- 'MLockedAllocator'.
--
-- At most @n@ bytes are copied from the input; any excess is ignored. If the
-- input is shorter than @n@ bytes, the rest of the block is zeroed. The result
-- is therefore deterministic and never exposes whatever the allocator left in
-- that memory.
mlsbFromByteStringWith ::
  forall n m.
  (KnownNat n, MonadST m) =>
  MLockedAllocator m -> BS.ByteString -> m (MLockedSizedBytes n)
mlsbFromByteStringWith allocator bs = do
  dst <- mlsbNewWith allocator
  let size = mlsbSize dst
  withMLSB dst $ \ptr -> stToIO . unsafeIOToST $
    BS.useAsCStringLen bs $ \(ptrBS, len) -> do
      let copied = min (fromIntegral len) size
      copyMem (castPtr ptr) ptrBS copied
      -- Zero-fill the tail so a short input leaves no stale bytes behind;
      -- this is a zero-length no-op when the input covers the whole block.
      zeroMem (ptr `plusPtr` fromIntegral copied) (size - copied)
  return dst

-- | Allocate a new 'MLockedSizedBytes' using 'mlockedMalloc', and fill it
-- with the contents of a 'BS.ByteString'. The size of the input is checked:
-- the result is 'Nothing' unless the input is exactly @n@ bytes long.
--
-- /Note:/ the input 'BS.ByteString' is a plain old Haskell value, so it has
-- already violated the secure-forgetting properties afforded by
-- 'MLockedSizedBytes'. This makes the function useless outside of testing. Use
-- 'mlsbNew' or 'mlsbNewZero' to create 'MLockedSizedBytes' values, and
-- manipulate them through 'withMLSB', 'mlsbUseAsCPtr', or 'mlsbUseAsSizedPtr'.
-- (See also 'mlsbFromByteString'.)
mlsbFromByteStringCheck ::
  forall n m.
  (KnownNat n, MonadST m) =>
  BS.ByteString ->
  m (Maybe (MLockedSizedBytes n))
mlsbFromByteStringCheck = mlsbFromByteStringCheckWith mlockedMalloc

-- | Like 'mlsbFromByteStringCheck', but allocates through the given
-- 'MLockedAllocator'. Returns 'Nothing' without allocating when the input
-- length differs from @n@.
mlsbFromByteStringCheckWith ::
  forall n m.
  (KnownNat n, MonadST m) =>
  MLockedAllocator m ->
  BS.ByteString ->
  m (Maybe (MLockedSizedBytes n))
mlsbFromByteStringCheckWith allocator bs
  | BS.length bs /= size = return Nothing
  | otherwise = Just <$> mlsbFromByteStringWith allocator bs
  where
    size :: Int
    size = fromInteger (natVal (Proxy @n))

-- | View an 'MLockedSizedBytes' as a 'BS.ByteString' of length @n@ without
-- copying. The result points directly at the mlocked memory.
--
-- /Note:/ the resulting 'BS.ByteString' still refers to secure memory, but
-- the types don't prevent it from being exposed. Any subsequent operation on
-- it (slicing and dicing, copying, conversion, packing/unpacking, etc.) may
-- create copies of the mlocked memory on the unprotected GHC heap and thus
-- leak secrets, so use this function with extreme care. The view also must
-- not be used after 'mlsbFinalize' has been called on the source.
mlsbAsByteString :: forall n. KnownNat n => MLockedSizedBytes n -> BS.ByteString
mlsbAsByteString mlsb@(MLSB (SFP fptr)) = BSI.PS (castForeignPtr fptr) 0 size
  where
    size :: Int
    size = fromIntegral (mlsbSize mlsb)

-- | Copy the @n@ bytes of an 'MLockedSizedBytes' into a fresh 'BS.ByteString'
-- (via 'BS.packCStringLen').
--
-- /Note:/ this function leaks mlocked memory to the Haskell heap and should
-- not be used in production code.
mlsbToByteString :: forall n m. (KnownNat n, MonadST m) => MLockedSizedBytes n -> m BS.ByteString
mlsbToByteString mlsb =
  withMLSB mlsb $ \ptr ->
    stToIO . unsafeIOToST $ BS.packCStringLen (castPtr ptr, size)
  where
    size :: Int
    size = fromIntegral (mlsbSize mlsb)

-- | Use an 'MLockedSizedBytes' value as a raw byte pointer. Care should be
-- taken never to copy the contents of the 'MLockedSizedBytes' value into
-- managed memory through the raw pointer, because that would violate the
-- secure-forgetting property of mlocked memory.
mlsbUseAsCPtr :: MonadST m => MLockedSizedBytes n -> (Ptr Word8 -> m r) -> m r
mlsbUseAsCPtr (MLSB x) k = withMLockedForeignPtr x (k . castPtr)

-- | Use an 'MLockedSizedBytes' value as a 'SizedPtr' of the same size. Care
-- should be taken never to copy the contents of the 'MLockedSizedBytes' value
-- into managed memory through the sized pointer, because that would violate
-- the secure-forgetting property of mlocked memory.
mlsbUseAsSizedPtr :: forall n r m. MonadST m => MLockedSizedBytes n -> (SizedPtr n -> m r) -> m r
mlsbUseAsSizedPtr (MLSB x) k = withMLockedForeignPtr x (k . SizedPtr . castPtr)

-- | Calls 'finalizeMLockedForeignPtr' on the underlying pointer.
-- This invalidates the argument, and any pointer or 'BS.ByteString' view
-- previously obtained from it. Neither may be used afterwards.
mlsbFinalize :: MonadST m => MLockedSizedBytes n -> m ()
mlsbFinalize (MLSB ptr) = finalizeMLockedForeignPtr ptr

-- | 'compareM' on 'MLockedSizedBytes': compares the two @n@-byte blocks with
-- 'c_sodium_compare' and turns its integer result into an 'Ordering' by
-- comparing it against 0.
--
-- NOTE(review): the libsodium docs describe @sodium_compare@ as a
-- constant-time comparison that treats the buffers as little-endian numbers.
-- The ordering is therefore not lexicographic byte order; confirm this is
-- what callers expect.
mlsbCompare ::
  forall n m. (MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Ordering
mlsbCompare (MLSB x) (MLSB y) =
  withMLockedForeignPtr x $ \x' ->
    withMLockedForeignPtr y $ \y' -> do
      res <- stToIO . unsafeIOToST $ c_sodium_compare x' y' size
      return (compare res 0)
  where
    size :: CSize
    size = fromInteger (natVal (Proxy @n))

-- | 'equalsM' on 'MLockedSizedBytes': 'True' exactly when 'mlsbCompare'
-- reports 'EQ'.
mlsbEq ::
  forall n m. (MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Bool
mlsbEq a b = (== EQ) <$> mlsbCompare a b