{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedTuples #-}

module Cardano.Crypto.Util (
  Empty,
  SignableRepresentation (..),
  getRandomWord64,

  -- * Simple serialisation used in mock instances
  readBinaryWord64,
  writeBinaryWord64,
  readBinaryNatural,
  writeBinaryNatural,

  -- * Low level conversions
  bytesToInteger,
  bytesToNatural,
  naturalToBytes,
  byteArrayToInteger,
  byteArrayToNatural,
  naturalToByteArray,

  -- * Base16 conversion
  decodeHexByteString,
  decodeHexString,
  decodeHexStringQ,
)
where

import Cardano.Base.Bytes (byteStringToByteArray)
import Control.Monad (unless)
import Data.Array.Byte (ByteArray (..))
import Data.Bifunctor (first)
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Base16 as BS16
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.ByteString.Internal as BS
import Data.Char (isAscii)
import Data.Word
import Language.Haskell.TH
import Numeric.Natural

import Foreign.ForeignPtr (withForeignPtr)
import GHC.Exts (Addr#, Int#, Word#, sizeofByteArray#)
import qualified GHC.Exts as GHC
import qualified GHC.Natural as GHC

import Crypto.Random (MonadRandom (..))

import GHC.IO (unsafeDupablePerformIO)
import GHC.Num.Integer (integerFromAddr, integerFromByteArray)

-- | A class with no methods that every type is an instance of.
--
-- Because the single instance below covers all types, an @Empty a@
-- constraint is always satisfiable; it can serve as a "no constraint"
-- placeholder wherever a class of kind @Type -> Constraint@ is required.
class Empty a

instance Empty a

--
-- Signable
--

-- | Types that can be rendered as the bytes fed to signing and
-- verification.
class SignableRepresentation a where
  -- | The byte representation of a value that gets signed or verified.
  getSignableRepresentation :: a -> ByteString

-- | A 'ByteString' is already bytes, so it is its own signable
-- representation and is returned unchanged.
instance SignableRepresentation ByteString where
  getSignableRepresentation bytes = bytes

--
-- Random source used in some mock instances
--

-- | Draw a random 'Word64' by requesting eight random bytes from the
-- 'MonadRandom' source and reading them with 'readBinaryWord64'
-- (big-endian).
getRandomWord64 :: MonadRandom m => m Word64
getRandomWord64 = fmap readBinaryWord64 (getRandomBytes word64Width)
  where
    -- Number of bytes in a 'Word64'.
    word64Width :: Int
    word64Width = 8

--
-- Really simple serialisation used in some mock instances
--

-- | Read a 'ByteString' as a big-endian 'Word64'.
--
-- Each byte is shifted in from the right, so for inputs longer than eight
-- bytes the leading (most significant) bytes overflow out of the result
-- and only the last eight bytes matter. The empty string reads as @0@.
readBinaryWord64 :: ByteString -> Word64
readBinaryWord64 = BS.foldl' pushByte 0
  where
    pushByte :: Word64 -> Word8 -> Word64
    pushByte acc byte = (acc `unsafeShiftL` 8) + fromIntegral byte

-- | Read a 'ByteString' of any length as a big-endian 'Natural'.
-- The empty string reads as @0@.
readBinaryNatural :: ByteString -> Natural
readBinaryNatural = BS.foldl' pushByte 0
  where
    pushByte :: Natural -> Word8 -> Natural
    pushByte acc byte = (acc `unsafeShiftL` 8) + fromIntegral byte

-- | Encode a 'Word64' as exactly eight bytes, big-endian.
writeBinaryWord64 :: Word64 -> ByteString
writeBinaryWord64 value = BS.reverse leastSignificantFirst
  where
    -- Peel off the low byte eight times, producing little-endian output
    -- that is then reversed.
    (leastSignificantFirst, _) = BS.unfoldrN 8 emitByte value

    emitByte :: Word64 -> Maybe (Word8, Word64)
    emitByte w = Just (fromIntegral w, w `unsafeShiftR` 8)

-- | Encode a 'Natural' big-endian into a fixed number of bytes.
--
-- For a positive width the result is exactly that many bytes: small
-- values are zero-padded on the left, and values too large to fit keep
-- only their least significant bytes. A non-positive width yields the
-- empty string.
writeBinaryNatural :: Int -> Natural -> ByteString
writeBinaryNatural width value = BS.reverse leastSignificantFirst
  where
    -- Peel off the low byte @width@ times (little-endian), then reverse.
    (leastSignificantFirst, _) = BS.unfoldrN width emitByte value

    emitByte :: Natural -> Maybe (Word8, Natural)
    emitByte n = Just (fromIntegral n, n `unsafeShiftR` 8)

-- | Interpret the bytes of a 'ByteString' as an unsigned big-endian
-- 'Natural', via 'bytesToInteger'.
bytesToNatural :: ByteString -> Natural
bytesToNatural bs = GHC.naturalFromInteger (bytesToInteger bs)

-- | Interpret the contents of a 'ByteArray' as an unsigned big-endian
-- 'Natural', via 'byteArrayToInteger'.
byteArrayToNatural :: ByteArray -> Natural
byteArrayToNatural arr = GHC.naturalFromInteger (byteArrayToInteger arr)

-- | The inverse of 'bytesToNatural': encode a 'Natural' big-endian into
-- @width@ bytes using 'writeBinaryNatural'. Values needing more than
-- @width@ bytes keep only their least significant bytes, so the round trip
-- only holds when the value fits. Note that this is a naive implementation
-- and only suitable for tests.
naturalToBytes :: Int -> Natural -> ByteString
naturalToBytes width n = writeBinaryNatural width n

-- | The inverse of 'byteArrayToNatural': encode a 'Natural' big-endian into
-- a 'ByteArray' of @numBytes@ bytes (built with 'writeBinaryNatural').
--
-- Small values are zero-padded on the left. Values needing more than
-- @numBytes@ bytes keep only their least significant bytes, so the round
-- trip only holds when the value fits. Note that this is a naive
-- implementation and only suitable for tests.
naturalToByteArray :: Int -> Natural -> ByteArray
naturalToByteArray numBytes = byteStringToByteArray . writeBinaryNatural numBytes

-- | Interpret the bytes of a 'ByteString' as an unsigned big-endian
-- 'Integer'.
bytesToInteger :: ByteString -> Integer
bytesToInteger (BS.PS fp (GHC.I# off#) (GHC.I# len#)) =
  -- Running this IO purely should be safe: we only read from the
  -- ByteString's buffer (which is immutable) and the Integer is built in
  -- freshly allocated memory, so no mutation is involved.
  unsafeDupablePerformIO (withForeignPtr fp readFromBuffer)
  where
    -- The ByteString's bytes start @off#@ bytes past the buffer's base
    -- address and span @len#@ bytes; @1#@ requests big-endian decoding.
    readFromBuffer (GHC.Ptr base#) =
      importIntegerFromAddr (base# `GHC.plusAddr#` off#) (GHC.int2Word# len#) 1#

    -- 'integerFromAddr' with the address taken first and the byte count
    -- second. The trailing 'Int#' selects endianness (@1#@ = big endian).
    importIntegerFromAddr :: Addr# -> Word# -> Int# -> IO Integer
    importIntegerFromAddr addr# size# = integerFromAddr size# addr#

-- | Interpret the full contents of a 'ByteArray' as an unsigned big-endian
-- 'Integer'.
byteArrayToInteger :: ByteArray -> Integer
byteArrayToInteger (ByteArray arr#) =
  integerFromByteArray
    (GHC.int2Word# (sizeofByteArray# arr#)) -- read every byte of the array
    arr#
    0## -- starting at offset 0
    1# -- 1# selects big-endian (most significant byte first)

-- | Decode a base16 'ByteString' and check that the decoded result is
-- exactly @lenExpected@ bytes long.
--
-- Returns 'Left' with a human-readable message if the input is not valid
-- hex or decodes to a different number of bytes.
decodeHexByteString :: ByteString -> Int -> Either String ByteString
decodeHexByteString bsHex lenExpected =
  first ("Malformed hex: " ++) (BS16.decode bsHex) >>= checkLength
  where
    -- Accept the decoded bytes only when their length matches.
    checkLength :: ByteString -> Either String ByteString
    checkLength decoded
      | lenActual == lenExpected = Right decoded
      | otherwise =
          Left $
            "Expected in decoded form to be: "
              ++ show lenExpected
              ++ " bytes, but got: "
              ++ show lenActual
      where
        lenActual = BS.length decoded

-- | Decode a base16 'String' and check that the decoded result is exactly
-- @lenExpected@ bytes long.
--
-- Unlike 'decodeHexByteString', this function accepts an /optional/ @0x@
-- prefix (lowercase only), which is stripped before decoding. Input
-- without the prefix is decoded as-is.
--
-- Returns 'Left' if the input contains non-ASCII characters, is not valid
-- hex, or decodes to a number of bytes other than @lenExpected@.
decodeHexString :: String -> Int -> Either String ByteString
decodeHexString hexStr' lenExpected = do
  let hexStr :: String
      hexStr =
        case hexStr' of
          '0' : 'x' : str -> str
          str -> str
  -- 'BSC8.pack' keeps only the low 8 bits of each 'Char', so reject
  -- non-ASCII input up front rather than decoding a silently mangled string.
  unless (all isAscii hexStr) $
    Left $ "Input string contains invalid characters: " ++ hexStr
  decodeHexByteString (BSC8.pack hexStr) lenExpected

-- | Template Haskell splice that validates a hex 'String' (optionally
-- prefixed with @0x@) at compile time, requiring it to decode to exactly
-- @n@ bytes.
--
-- Invalid input makes compilation fail with the decoding error. On success
-- the splice yields an expression of type 'ByteString' that decodes the
-- same string again with 'decodeHexString'. That decode is already known to
-- succeed, so its 'error' branch is not expected to be reached.
decodeHexStringQ :: String -> Int -> Q Exp
decodeHexStringQ hexStr n =
  case decodeHexString hexStr n of
    -- Tag the failure with this function's name: the error can come from
    -- the ASCII check in 'decodeHexString', not only 'decodeHexByteString'.
    Left err -> fail $ "<decodeHexStringQ>: " ++ err
    Right _ -> [|either error id (decodeHexString hexStr n)|]