{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Cardano.Crypto.Libsodium.Hash (
  SodiumHashAlgorithm (..),
  digestMLockedStorable,
  digestMLockedBS,
  expandHash,
  expandHashWith,
) where

import Data.Proxy (Proxy (..))
import Data.Word (Word8)
import Foreign.C.Types (CSize)
import Foreign.Ptr (castPtr, plusPtr)
import Foreign.Storable (Storable (poke))
import GHC.TypeLits

import Cardano.Crypto.Hash (HashAlgorithm (SizeHash))
import Cardano.Crypto.Libsodium.Hash.Class
import Cardano.Crypto.Libsodium.MLockedBytes.Internal
import Cardano.Crypto.Libsodium.Memory
import Control.Monad.Class.MonadST (MonadST (..))
import Control.Monad.Class.MonadThrow (MonadThrow)
import Control.Monad.ST.Unsafe (unsafeIOToST)

-------------------------------------------------------------------------------
-- Hash expansion
-------------------------------------------------------------------------------

-- | Expand a hash-sized value into two hash-sized values.
--
-- This is 'expandHashWith' specialised to 'mlockedMalloc' as the allocator
-- for the temporary buffers used while hashing.
expandHash ::
  forall h m proxy.
  (SodiumHashAlgorithm h, MonadST m, MonadThrow m) =>
  proxy h ->
  MLockedSizedBytes (SizeHash h) ->
  m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHash = expandHashWith mlockedMalloc

-- | Expand a hash-sized value @x@ into two hash-sized values by hashing it
-- with two different one-byte prefixes:
--
-- > (H(0x01 || x), H(0x02 || x))
--
-- where @H@ is the digest of the 'SodiumHashAlgorithm' selected by the
-- proxy argument.
--
-- The supplied 'MLockedAllocator' is used for the temporary
-- @prefix || x@ buffers (one per output, each of @SizeHash h + 1@ bytes).
-- The two results themselves are produced by 'naclDigestPtr', which is not
-- given this allocator.
expandHashWith ::
  forall h m proxy.
  (SodiumHashAlgorithm h, MonadST m, MonadThrow m) =>
  MLockedAllocator m ->
  proxy h ->
  MLockedSizedBytes (SizeHash h) ->
  m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHashWith allocator h (MLSB sfptr) =
  withMLockedForeignPtr sfptr $ \ptr -> do
    let -- Digest of @prefix || input@: the prefix byte is written at
        -- offset 0 of a fresh buffer and the 'size' input bytes are copied
        -- directly after it, so exactly 'size1' bytes are hashed.
        digestWithPrefix prefix =
          mlockedAllocaWith allocator size1 $ \ptr' ->
            stToIO . unsafeIOToST $ do
              poke ptr' (prefix :: Word8)
              copyMem (castPtr (plusPtr ptr' 1)) ptr size
              naclDigestPtr h ptr' (fromIntegral size1)
    l <- digestWithPrefix 1
    -- NOTE(review): if this second digest throws, @l@ is not released
    -- here; confirm whether results of 'naclDigestPtr' need explicit
    -- finalisation and, if so, guard @l@ with an exception handler.
    r <- digestWithPrefix 2
    pure (l, r)
  where
    -- Length of the hashed buffer: one prefix byte plus the input.
    size1 :: CSize
    size1 = size + 1

    -- Byte length of the input, taken from the type-level hash size.
    size :: CSize
    size = fromInteger $ natVal (Proxy @(SizeHash h))