{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedTuples #-}

module Cardano.Crypto.PinnedSizedBytes (
  PinnedSizedBytes,

  -- * Initialization
  psbZero,

  -- * Conversions
  psbFromBytes,
  psbToBytes,
  psbFromByteString,
  psbFromByteStringCheck,
  psbToByteString,

  -- * C usage
  psbUseAsCPtr,
  psbUseAsCPtrLen,
  psbUseAsSizedPtr,
  psbCreate,
  psbCreateLen,
  psbCreateSized,
  psbCreateResult,
  psbCreateResultLen,
  psbCreateSizedResult,
  ptrPsbToSizedPtr,
) where

import Control.DeepSeq (NFData)
import Control.Monad.Class.MonadST (MonadST, stToIO)
import Control.Monad.Primitive (primitive_, touch)
import Control.Monad.ST (runST)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Data.Kind (Type)
import Data.Primitive.ByteArray (
  ByteArray (..),
  MutableByteArray (..),
  byteArrayContents,
  copyByteArrayToAddr,
  foldrByteArray,
  mutableByteArrayContents,
  newPinnedByteArray,
  unsafeFreezeByteArray,
  writeByteArray,
 )
import Data.Proxy (Proxy (..))
import Data.String (IsString (..))
import Data.Word (Word8)
import Foreign.C.Types (CSize)
import Foreign.Ptr (FunPtr, castPtr)
import Foreign.Storable (Storable (..))
import GHC.TypeLits (KnownNat, Nat, natVal)
import Language.Haskell.TH.Syntax (Q, TExp (..))
import Language.Haskell.TH.Syntax.Compat (Code (..), examineSplice)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))
import Numeric (showHex)
import System.IO.Unsafe (unsafeDupablePerformIO)

import GHC.Exts (Int (..), copyAddrToByteArray#)
import GHC.Ptr (Ptr (..))

import qualified Data.ByteString as BS
import qualified Data.Primitive as Prim

import Cardano.Crypto.Libsodium.C (c_sodium_compare)
import Cardano.Crypto.Util (decodeHexString)
import Cardano.Foreign

{- HLINT ignore "Reduce duplication" -}

-- $setup
-- >>> :set -XDataKinds -XTypeApplications -XOverloadedStrings
-- >>> import Cardano.Crypto.PinnedSizedBytes

-- | @n@ bytes, with a 'Storable' instance.
--
-- We have two @*Bytes@ types:
--
-- * @PinnedSizedBytes@ is backed by a pinned 'ByteArray'.
-- * @MLockedSizedBytes@ is backed by a 'ForeignPtr' to an @mlock@-ed memory region.
--
-- 'ByteString' is also a pinned datatype, but it is represented as a
-- 'ForeignPtr' plus an offset (and a size).
--
-- I'm sorry for adding more types for bytes. :(
newtype PinnedSizedBytes (n :: Nat) = PSB ByteArray
  deriving (NoThunks) via OnlyCheckWhnfNamed "PinnedSizedBytes" (PinnedSizedBytes n)
  deriving newtype (NFData)

-- | Renders the bytes as a double-quoted hex string, two digits per byte
-- (e.g. @"deadbeef"@). The precedence argument is ignored.
instance Show (PinnedSizedBytes n) where
  showsPrec _ (PSB bytes) = quote . hexDigits . quote
    where
      quote :: ShowS
      quote = showChar '"'

      hexDigits :: ShowS
      hexDigits = foldrByteArray (\byte rest -> hexByte byte . rest) id bytes

      -- 'showHex' drops leading zeros, so pad single-digit bytes with '0'.
      hexByte :: Word8 -> ShowS
      hexByte byte
        | byte < 16 = showChar '0' . showHex byte
        | otherwise = showHex byte

-- | Equality is decided by 'compare', so the comparison is done in constant
-- time for a given size @n@.
instance KnownNat n => Eq (PinnedSizedBytes n) where
  x == y = case compare x y of
    EQ -> True
    _ -> False

-- | Ordering is delegated to 'c_sodium_compare' over the @n@ bytes of both
-- operands; its sign is mapped to an 'Ordering'.
-- NOTE(review): the exact byte order used is whatever libsodium's
-- @sodium_compare@ defines — see its documentation.
instance KnownNat n => Ord (PinnedSizedBytes n) where
  compare lhs rhs =
    runST $
      psbUseAsCPtrLen lhs $ \lhsPtr len ->
        psbUseAsCPtr rhs $ \rhsPtr -> do
          -- The C call is pure for our purposes: it only reads both buffers.
          cmp <- unsafeIOToST (c_sodium_compare lhsPtr rhsPtr len)
          pure (cmp `compare` 0)

-- | Builds a 'PinnedSizedBytes' from a hex literal inside a typed
-- @TemplateHaskell@ splice. The string is decoded while the splice runs, so a
-- malformed string or a wrong length is reported as a compile-time error.
--
-- >>> import Cardano.Crypto.PinnedSizedBytes
-- >>> :set -XTemplateHaskell
-- >>> :set -XOverloadedStrings
-- >>> :set -XDataKinds
-- >>> print ($$("0xdeadbeef") :: PinnedSizedBytes 4)
-- "deadbeef"
-- >>> print ($$("deadbeef") :: PinnedSizedBytes 4)
-- "deadbeef"
-- >>> let bsb = $$("0xdeadbeef") :: PinnedSizedBytes 5
-- <interactive>:9:14: error:
--     • <PinnedSizedBytes>: Expected in decoded form to be: 5 bytes, but got: 4
--     • In the Template Haskell splice $$("0xdeadbeef")
--       In the expression: $$("0xdeadbeef") :: PinnedSizedBytes 5
--       In an equation for ‘bsb’:
--           bsb = $$("0xdeadbeef") :: PinnedSizedBytes 5
-- >>> let bsb = $$("nogood") :: PinnedSizedBytes 5
-- <interactive>:11:14: error:
--     • <PinnedSizedBytes>: Malformed hex: invalid character at offset: 0
--     • In the Template Haskell splice $$("nogood")
--       In the expression: $$("nogood") :: PinnedSizedBytes 5
--       In an equation for ‘bsb’: bsb = $$("nogood") :: PinnedSizedBytes 5
instance KnownNat n => IsString (Q (TExp (PinnedSizedBytes n))) where
  fromString hexStr =
    case decodeHexString hexStr expectedLen of
      Left err -> fail ("<PinnedSizedBytes>: " ++ err)
      -- Decoding already succeeded at compile time; the generated code
      -- repeats it at run time to build the value.
      Right _ ->
        examineSplice
          [||either error psbFromByteString (decodeHexString hexStr expectedLen)||]
    where
      expectedLen :: Int
      expectedLen = fromInteger (natVal (Proxy :: Proxy n))

-- | Wraps the 'Q' ('TExp' ..) instance above in a 'Code' splice.
instance KnownNat n => IsString (Code Q (PinnedSizedBytes n)) where
  fromString str = Code (fromString str)

-- | List the bytes in array order. See 'psbFromBytes' for the opposite
-- direction.
psbToBytes :: PinnedSizedBytes n -> [Word8]
psbToBytes psb = case psb of
  PSB bytes -> foldrByteArray (:) [] bytes

-- | Copy the bytes into a freshly packed 'BS.ByteString'.
psbToByteString :: PinnedSizedBytes n -> BS.ByteString
psbToByteString psb = BS.pack (psbToBytes psb)

-- | Build a 'PinnedSizedBytes' from a list of bytes, aligned to the right.
--
-- * If the list is shorter than @n@, it is padded with leading zeros.
-- * If it is longer, only its last @n@ bytes are kept.
--
-- For hex literals, see the @'IsString' ('Q' ('TExp' ('PinnedSizedBytes' n)))@
-- instance.
--
-- >>> psbToBytes . (id @(PinnedSizedBytes 4)) . psbFromBytes $ [1,2,3,4]
-- [1,2,3,4]
--
-- >>> psbToBytes . (id @(PinnedSizedBytes 4)) . psbFromBytes $ [1,2]
-- [0,0,1,2]
--
-- >>> psbToBytes . (id @(PinnedSizedBytes 4)) . psbFromBytes $ [1,2,3,4,5,6]
-- [3,4,5,6]
--
-- >>> psbToBytes . (id @(PinnedSizedBytes 0)) . psbFromBytes $ [1,2]
-- []
{-# DEPRECATED psbFromBytes "This is not referentially transparent" #-}
psbFromBytes :: forall n. KnownNat n => [Word8] -> PinnedSizedBytes n
psbFromBytes ws0
  -- 'pinnedByteArrayFromListN' throws on a requested length of zero, so the
  -- zero-sized case allocates an empty pinned array directly instead.
  | size == 0 = PSB (runST (newPinnedByteArray 0 >>= unsafeFreezeByteArray))
  | otherwise = PSB (pinnedByteArrayFromListN size ws)
  where
    size :: Int
    size = fromInteger (natVal (Proxy :: Proxy n))

    -- Right-align ws0 in a window of exactly 'size' bytes: reverse, pad the
    -- tail with zeros, take 'size', and reverse back.
    ws :: [Word8]
    ws =
      reverse $
        take size $
          (++ repeat 0) $
            reverse ws0

-- | Convert a 'BS.ByteString' into a 'PinnedSizedBytes'. The input must
-- contain exactly @n@ bytes (the type-level size); any other length is a
-- runtime 'error'. 'psbFromByteStringCheck' is the total alternative.
psbFromByteString :: KnownNat n => BS.ByteString -> PinnedSizedBytes n
psbFromByteString bs =
  maybe
    (error ("psbFromByteString: Size mismatch, got: " ++ show (BS.length bs)))
    id
    (psbFromByteStringCheck bs)

-- | Copy a 'BS.ByteString' into a new pinned array. Returns 'Nothing' unless
-- the input is exactly @n@ bytes long.
psbFromByteStringCheck :: forall n. KnownNat n => BS.ByteString -> Maybe (PinnedSizedBytes n)
psbFromByteStringCheck bs
  | BS.length bs /= size = Nothing
  | otherwise = Just (unsafeDupablePerformIO copyIn)
  where
    size :: Int
    size = fromInteger (natVal (Proxy :: Proxy n))

    -- Allocate a fresh pinned array and copy all 'size' bytes into it.
    copyIn :: IO (PinnedSizedBytes n)
    copyIn = BS.useAsCStringLen bs $ \(Ptr src#, _) -> do
      dst@(MutableByteArray dst#) <- newPinnedByteArray size
      primitive_ (copyAddrToByteArray# src# dst# 0# (case size of I# len# -> len#))
      PSB <$> unsafeFreezeByteArray dst

-- | A 'PinnedSizedBytes' whose bytes are all zero (built by padding an empty
-- list with 'psbFromBytes').
{-# DEPRECATED psbZero "This is not referentially transparent" #-}
psbZero :: KnownNat n => PinnedSizedBytes n
psbZero = psbFromBytes mempty

-- | Raw layout of exactly @n@ bytes. 'peek' and 'poke' copy all @n@ bytes;
-- the alignment is borrowed from a 'FunPtr'.
instance KnownNat n => Storable (PinnedSizedBytes n) where
  sizeOf _ = byteCount
    where
      byteCount :: Int
      byteCount = fromInteger (natVal (Proxy :: Proxy n))

  alignment _ = alignment (undefined :: FunPtr (Int -> Int))

  -- Copy @n@ bytes from the source address into a fresh pinned array.
  peek (Ptr src#) = do
    dst@(MutableByteArray dst#) <- newPinnedByteArray byteCount
    primitive_ (copyAddrToByteArray# src# dst# 0# (case byteCount of I# len# -> len#))
    PSB <$> unsafeFreezeByteArray dst
    where
      byteCount :: Int
      byteCount = fromInteger (natVal (Proxy :: Proxy n))

  -- Copy all @n@ bytes of the array to the destination address.
  poke dst (PSB src) = copyByteArrayToAddr (castPtr dst) src 0 byteCount
    where
      byteCount :: Int
      byteCount = fromInteger (natVal (Proxy :: Proxy n))

-- | Run an action on a 'Ptr' to the underlying bytes, with the size
-- \'forgotten\' for the duration of the action. Delegates to 'runAndTouch'.
--
-- = Note
--
-- The 'Ptr' given to the function argument /must not/ be used as the result of
-- type @r@.
{-# INLINE psbUseAsCPtr #-}
psbUseAsCPtr ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  MonadST m =>
  PinnedSizedBytes n ->
  (Ptr Word8 -> m r) ->
  m r
psbUseAsCPtr psb = case psb of
  PSB bytes -> runAndTouch bytes

-- | As 'psbUseAsCPtr', but also passes the number of usable bytes (@n@) as a
-- 'CSize'.
--
-- C APIs commonly take a pointer together with a length. Our data may have a
-- known size (hence 'PinnedSizedBytes'), while the C API accepts any length as
-- long as it is passed explicitly. This helper computes that length from the
-- type, so callers neither 'natVal' a 'Proxy' by hand nor risk a mismatch.
--
-- The same caveats apply to the use of this function as to the use of
-- 'psbUseAsCPtr'.
{-# INLINE psbUseAsCPtrLen #-}
psbUseAsCPtrLen ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  PinnedSizedBytes n ->
  (Ptr Word8 -> CSize -> m r) ->
  m r
psbUseAsCPtrLen (PSB bytes) k = runAndTouch bytes (\ptr -> k ptr byteCount)
  where
    byteCount :: CSize
    byteCount = fromIntegral (natVal (Proxy @n))

-- | As 'psbUseAsCPtr', but hands the action a 'SizedPtr' that keeps the size
-- @n@ in its type. The byte array is 'touch'ed after the action returns.
--
-- The same caveats apply to this use of this function as to the use of
-- 'psbUseAsCPtr'.
{-# INLINE psbUseAsSizedPtr #-}
psbUseAsSizedPtr ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  MonadST m =>
  PinnedSizedBytes n ->
  (SizedPtr n -> m r) ->
  m r
psbUseAsSizedPtr (PSB bytes) k =
  k sizedPtr >>= \result -> result <$ stToIO (touch bytes)
  where
    sizedPtr :: SizedPtr n
    sizedPtr = SizedPtr (castPtr (byteArrayContents bytes))

-- | As 'psbCreateResult', but for an initialization action that is run only
-- for its side effects: its @()@ result is discarded.
{-# INLINE psbCreate #-}
psbCreate ::
  forall (n :: Nat) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (Ptr Word8 -> m ()) ->
  m (PinnedSizedBytes n)
psbCreate f = fmap fst (psbCreateResult f)

-- | As 'psbCreateResultLen', but for an initialization action that is run only
-- for its side effects: its @()@ result is discarded.
{-# INLINE psbCreateLen #-}
psbCreateLen ::
  forall (n :: Nat) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (Ptr Word8 -> CSize -> m ()) ->
  m (PinnedSizedBytes n)
psbCreateLen f = fmap fst (psbCreateResultLen f)

-- | Allocate new pinned memory of @n@ bytes and run an initialization action
-- on it. The action also produces a result. Return the initialized memory (as
-- a 'PinnedSizedBytes') together with that result.
--
-- = Note
--
-- It is essential that @r@ is not the 'Ptr' given to the function argument.
-- Returning this 'Ptr' is /extremely/ unsafe:
--
-- * It breaks referential transparency guarantees by aliasing supposedly
-- immutable memory; and
-- * This 'Ptr' could refer to memory which has already been garbage collected,
-- which can lead to segfaults or out-of-bounds reads.
--
-- This poses both correctness /and/ security risks, so please don't do it.
{-# INLINE psbCreateResult #-}
psbCreateResult ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (Ptr Word8 -> m r) ->
  m (PinnedSizedBytes n, r)
-- The length argument supplied by 'psbCreateResultLen' is simply ignored.
psbCreateResult f = psbCreateResultLen (const . f)

-- | As 'psbCreateResult', but also passes the number of bytes we may write
-- (@n@) as a 'CSize'.
--
-- This function exists for two reasons:
--
-- * C libraries commonly take a pointer to data plus a length. /Our/ use case
-- may know the expected size, but the C function we call may be more general.
-- This helper simplifies calling such functions.
-- * We avoid 'natVal'ing a 'Proxy' /twice/, since we have to do it anyway.
--
-- The same caveats apply to this function as to 'psbCreateResult': the 'Ptr'
-- given to the function argument /must not/ be returned as @r@.
{-# INLINE psbCreateResultLen #-}
psbCreateResultLen ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (Ptr Word8 -> CSize -> m r) ->
  m (PinnedSizedBytes n, r)
psbCreateResultLen f = do
  buffer <- stToIO (newPinnedByteArray byteCount)
  result <- f (mutableByteArrayContents buffer) (fromIntegral byteCount)
  -- The buffer is never written again after this point, so freezing in place
  -- is fine.
  frozen <- stToIO (unsafeFreezeByteArray buffer)
  pure (PSB frozen, result)
  where
    byteCount :: Int
    byteCount = fromIntegral (natVal (Proxy @n))

-- | As 'psbCreateSizedResult', but for an initialization action that is run
-- only for its side effects.
{-# INLINE psbCreateSized #-}
psbCreateSized ::
  forall (n :: Nat) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (SizedPtr n -> m ()) ->
  m (PinnedSizedBytes n)
psbCreateSized k = psbCreate (\ptr -> k (SizedPtr (castPtr ptr)))

-- | As 'psbCreateResult', but gives the function argument a 'SizedPtr' that
-- keeps the size @n@ in its type. The same caveats apply as to
-- 'psbCreateResult': the 'SizedPtr' given to the function argument /must not/
-- be returned as @r@.
{-# INLINE psbCreateSizedResult #-}
psbCreateSizedResult ::
  forall (n :: Nat) (r :: Type) (m :: Type -> Type).
  (KnownNat n, MonadST m) =>
  (SizedPtr n -> m r) ->
  m (PinnedSizedBytes n, r)
psbCreateSizedResult f = psbCreateResult (\ptr -> f (SizedPtr (castPtr ptr)))

-- | Reinterpret a pointer to a 'PinnedSizedBytes' as a 'SizedPtr' of the same
-- size. This only casts the pointer; no bytes are copied.
ptrPsbToSizedPtr :: Ptr (PinnedSizedBytes n) -> SizedPtr n
ptrPsbToSizedPtr p = SizedPtr (castPtr p)

-------------------------------------------------------------------------------
-- derived from the primitive package
-------------------------------------------------------------------------------

-- | Create a pinned 'ByteArray' from a list of a known length. If the length
--   of the list does not match the given length, or if the given length is
--   zero or negative, then this throws an exception.
pinnedByteArrayFromListN :: forall a. Prim.Prim a => Int -> [a] -> ByteArray
pinnedByteArrayFromListN n ys
  | n == 0 = die "pinnedByteArrayFromListN" "list length zero #1"
  -- Reject negative sizes up front rather than passing a negative byte count
  -- to 'newPinnedByteArray'.
  | n < 0 = die "pinnedByteArrayFromListN" "negative list length"
  | otherwise = runST $ do
      -- 'headYs' only serves as a witness of type @a@ for 'Prim.sizeOf', which
      -- does not use its argument's value.
      let headYs :: a
          headYs = case ys of
            [] -> die "pinnedByteArrayFromListN" "list length zero #2"
            (y : _) -> y
      marr <- newPinnedByteArray (n * Prim.sizeOf headYs)
      -- Write list element number @ix@ at element index @ix@; fail if the list
      -- turns out to be shorter or longer than @n@.
      let go !ix []
            | ix == n = return ()
            | otherwise =
                die "pinnedByteArrayFromListN" "list length less than specified size"
          go !ix (x : xs)
            | ix < n = do
                writeByteArray marr ix x
                go (ix + 1) xs
            | otherwise =
                die "pinnedByteArrayFromListN" "list length greater than specified size"
      go 0 ys
      -- The array is not written again after this point, so freezing in place
      -- is safe.
      unsafeFreezeByteArray marr

-- | Abort via 'error' with a message naming this module and the failing
-- function, in the form @PinnedSizedBytes.<fun>: <problem>@.
die :: String -> String -> a
die fun problem = error (concat ["PinnedSizedBytes.", fun, ": ", problem])

-- | Run an action on the address of a 'ByteArray's contents, then 'touch' the
-- array so it is treated as live until the action has returned. The action's
-- result is passed through unchanged.
--
-- 'byteArrayContents' is only meaningful for pinned arrays, so callers must
-- pass a pinned 'ByteArray'.
--
-- NOTE(review): 'touch'-based liveness can be dropped by GHC if it can prove
-- the action never returns (e.g. it always throws); GHC's @keepAlive#@ is the
-- robust alternative — consider it if such actions are possible here.
{-# INLINE runAndTouch #-}
runAndTouch ::
  forall (a :: Type) (m :: Type -> Type).
  MonadST m =>
  ByteArray ->
  (Ptr Word8 -> m a) ->
  m a
runAndTouch ba f =
  f (byteArrayContents ba) >>= \result ->
    stToIO (touch ba) >> pure result