diff --git a/.gitignore b/.gitignore index aee1772..d36b981 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ result/ cabal.project.local tags /.stack-work/ +/.ghc.environment* diff --git a/arrayfire.cabal b/arrayfire.cabal index d7474af..bda0066 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -1,6 +1,6 @@ cabal-version: 3.0 name: arrayfire -version: 0.7.1.0 +version: 0.8.0.0 synopsis: Haskell bindings to the ArrayFire general-purpose GPU library homepage: https://github.com/arrayfire/arrayfire-haskell license: BSD-3-Clause @@ -177,6 +177,7 @@ test-suite test ArrayFire.ImageSpec ArrayFire.IndexSpec ArrayFire.LAPACKSpec + ArrayFire.NumericalSpec ArrayFire.RandomSpec ArrayFire.SignalSpec ArrayFire.SparseSpec diff --git a/flake.lock b/flake.lock index c767330..3851d27 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1692792214, - "narHash": "sha256-voZDQOvqHsaReipVd3zTKSBwN7LZcUwi3/ThMxRZToU=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "1721b3e7c882f75f2301b00d48a2884af8c448ae", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nix-filter": { "locked": { - "lastModified": 1687178632, - "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", + "lastModified": 1757882181, + "narHash": "sha256-+cCxYIh2UNalTz364p+QYmWHs0P+6wDhiWR4jDIKQIU=", "owner": "numtide", "repo": "nix-filter", - "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", + "rev": "59c44d1909c72441144b93cf0f054be7fe764de5", "type": "github" }, "original": { @@ -35,11 +35,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1692638711, - "narHash": "sha256-J0LgSFgJVGCC1+j5R2QndadWI1oumusg6hCtYAzLID4=", + "lastModified": 1780243769, + "narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=", "owner": "nixos", "repo": "nixpkgs", - "rev": "91a22f76cd1716f9d0149e8a5c68424bb691de15", + "rev": "331800de5053fcebacf6813adb5db9c9dca22a0c", "type": "github" }, "original": { diff --git a/include/algorithm.h b/include/algorithm.h index 8894a73..c36f8d3 100644 --- a/include/algorithm.h +++ b/include/algorithm.h @@ -34,3 +34,12 @@ af_err af_sort_by_key(af_array *out_keys, af_array *out_values, const af_array k af_err af_set_unique(af_array *out, const af_array in, const bool is_sorted); af_err af_set_union(af_array *out, const af_array first, const af_array second, const bool is_unique); af_err af_set_intersect(af_array *out, const af_array first, const af_array second, const bool is_unique); +af_err af_sum_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_sum_by_key_nan(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim, const double nanval); +af_err af_product_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_product_by_key_nan(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim, const double nanval); +af_err af_min_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_max_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_all_true_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_any_true_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_count_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); diff --git a/include/blas.h b/include/blas.h index d872069..e70bdba 100644 --- a/include/blas.h +++ b/include/blas.h @@ -5,3 +5,4 @@ af_err af_dot(af_array *out, const af_array lhs, const af_array rhs, const af_ma af_err af_dot_all(double *real, double *imag, const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs); af_err af_transpose(af_array *out, af_array in, const bool conjugate); af_err af_transpose_inplace(af_array in, const bool conjugate); +af_err af_gemm(af_array *out, const af_mat_prop optLhs, const af_mat_prop optRhs, const void *alpha, const af_array lhs, const af_array rhs, const void *beta); diff --git a/include/statistics.h b/include/statistics.h index c3ddd9b..59f37bc 100644 --- a/include/statistics.h +++ b/include/statistics.h @@ -15,3 +15,4 @@ af_err af_stdev_all(double *real, double *imag, const af_array in); af_err af_median_all(double *realVal, double *imagVal, const af_array in); af_err af_corrcoef(double *realVal, double *imagVal, const af_array X, const af_array Y); af_err af_topk(af_array *values, af_array *indices, const af_array in, const int k, const int dim, const af_topk_function order); +af_err af_meanvar(af_array *mean, af_array *var, const af_array in, const af_array weights, const af_var_bias bias, const dim_t dim); diff --git a/src/ArrayFire/Algorithm.hs b/src/ArrayFire/Algorithm.hs index b7fccba..8fdf369 100644 --- a/src/ArrayFire/Algorithm.hs +++ b/src/ArrayFire/Algorithm.hs @@ -26,6 +26,9 @@ -------------------------------------------------------------------------------- module ArrayFire.Algorithm where +import Data.Word (Word32) +import Foreign.C.Types (CBool) + import ArrayFire.FFI import ArrayFire.Internal.Algorithm import ArrayFire.Internal.Types @@ -152,13 +155,13 @@ max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n) -- [1 1 1 1] -- 0 allTrue - :: forall a. AFType a + :: AFType a => Array a -- ^ Array input -> Int -- ^ Dimension along which to see if all elements are True - -> Array a - -- ^ Will contain the maximum of all values in the input array along dim + -> Array CBool + -- ^ Will contain 1 where all elements along dim are true, 0 otherwise allTrue x (fromIntegral -> n) = x `op1` (\p a -> af_all_true p a n) @@ -169,13 +172,13 @@ allTrue x (fromIntegral -> n) = -- [1 1 1 1] -- 0 anyTrue - :: forall a . AFType a + :: AFType a => Array a -- ^ Array input -> Int - -- ^ Dimension along which to see if all elements are True - -> Array a - -- ^ Returns if all elements are true + -- ^ Dimension along which to see if any elements are True + -> Array CBool + -- ^ Will contain 1 where any element along dim is true, 0 otherwise anyTrue x (fromIntegral -> n) = (x `op1` (\p a -> af_any_true p a n)) @@ -193,7 +196,7 @@ count -- ^ Dimension along which to count -> Array Int -- ^ Count of all elements along dimension -count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n) +count x (fromIntegral -> n) = x `op1` (\p a -> af_count p a n) -- | Sum all elements in an 'Array' along all dimensions -- @@ -323,7 +326,7 @@ imin -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the minimum of all values along dim, will also contain the location of minimum of all values in in along dim imin a (fromIntegral -> n) = op2p a (\x y z -> af_imin x y z n) @@ -343,7 +346,7 @@ imax -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the maximum of all values in in along dim, will also contain the location of maximum of all values in in along dim imax a (fromIntegral -> n) = op2p a (\x y z -> af_imax x y z n) @@ -471,8 +474,8 @@ where' :: AFType a => Array a -- ^ Is the input array. - -> Array a - -- ^ will contain indices where input array is non-zero + -> Array Word32 + -- ^ Indices where input array is non-zero where' = (`op1` af_where) -- | First order numerical difference along specified dimension. @@ -565,7 +568,7 @@ sortIndex -- ^ Dimension along `sortIndex` is performed -> Bool -- ^ Return results in ascending order - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Contains the sorted, contains indices for original input sortIndex a (fromIntegral -> n) (fromIntegral . fromEnum -> b) = a `op2p` (\p1 p2 p3 -> af_sort_index p1 p2 p3 n b) @@ -657,3 +660,137 @@ setIntersect -- ^ Intersection of first and second array setIntersect a1 a2 (fromIntegral . fromEnum -> b) = op2 a1 a2 (\x y z -> af_set_intersect x y z b) + +-- | Sum values in 'Array' grouped by keys along a dimension. +-- +-- Each contiguous run of equal keys in @keys@ produces one output element. +-- Returns @(keys_out, vals_out)@. +-- +-- >>> sumByKey (vector @Int 5 [1,1,2,2,2]) (vector @Double 5 [10,20,1,2,3]) 0 +-- (ArrayFire Array +-- [2 1 1 1] +-- 1 2, +-- ArrayFire Array +-- [2 1 1 1] +-- 30.0000 6.0000) +sumByKey + :: AFType a + => Array Int + -- ^ Keys array (contiguous equal keys form a group) + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension along which to reduce + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key ko vo k v dim) + +-- | 'sumByKey' replacing NaN values with a substitute before summing. +sumByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key_nan ko vo k v dim nanval) + +-- | Product of values in 'Array' grouped by keys along a dimension. +productByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +productByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_product_by_key ko vo k v dim) + +-- | 'productByKey' replacing NaN values with a substitute before multiplying. +productByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) +productByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_product_by_key_nan ko vo k v dim nanval) + +-- | Minimum of values in 'Array' grouped by keys along a dimension. +minByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +minByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_min_by_key ko vo k v dim) + +-- | Maximum of values in 'Array' grouped by keys along a dimension. +maxByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +maxByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_max_by_key ko vo k v dim) + +-- | True if all values are true within each key group. +allTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array a) +allTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_all_true_by_key ko vo k v dim) + +-- | True if any value is true within each key group. +anyTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array a) +anyTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_any_true_by_key ko vo k v dim) + +-- | Count non-zero values within each key group. +countByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +countByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_count_by_key ko vo k v dim) diff --git a/src/ArrayFire/Arith.hs b/src/ArrayFire/Arith.hs index ec2cc25..c603849 100644 --- a/src/ArrayFire/Arith.hs +++ b/src/ArrayFire/Arith.hs @@ -28,7 +28,7 @@ -------------------------------------------------------------------------------- module ArrayFire.Arith where -import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFrac) +import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFloat) import Data.Coerce import Data.Proxy @@ -512,7 +512,7 @@ not -- ^ Input 'Array' -> Array CBool -- ^ Result of 'not' on an 'Array' -not = flip op1d af_not +not = flip op1 af_not -- | Bitwise and the values in one 'Array' against another 'Array' -- @@ -526,10 +526,10 @@ bitAnd -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bitwise and bitAnd x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 1 -- | Bitwise and the values in one 'Array' against another 'Array' @@ -546,10 +546,10 @@ bitAndBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool + -> Array a -- ^ Result of bitwise and bitAndBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 batch -- | Bitwise or the values in one 'Array' against another 'Array' @@ -564,10 +564,10 @@ bitOr -- ^ First input -> Array a -- ^ Second input - -> Array CBool - -- ^ Result of bit or + -> Array a + -- ^ Result of bitwise or bitOr x y = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 1 -- | Bitwise or the values in one 'Array' against another 'Array' @@ -584,10 +584,10 @@ bitOrBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit or + -> Array a + -- ^ Result of bitwise or bitOrBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 batch -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -602,10 +602,10 @@ bitXor -- ^ First input -> Array a -- ^ Second input - -> Array CBool - -- ^ Result of bit xor + -> Array a + -- ^ Result of bitwise xor bitXor x y = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 1 -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -622,10 +622,10 @@ bitXorBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit xor + -> Array a + -- ^ Result of bitwise xor bitXorBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 batch -- | Left bit shift the values in one 'Array' against another 'Array' @@ -640,10 +640,10 @@ bitShiftL -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bit shift left bitShiftL x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 1 -- | Left bit shift the values in one 'Array' against another 'Array' @@ -660,10 +660,10 @@ bitShiftLBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool + -> Array a -- ^ Result of bit shift left bitShiftLBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 batch -- | Right bit shift the values in one 'Array' against another 'Array' @@ -678,10 +678,10 @@ bitShiftR -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bit shift right bitShiftR x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 1 -- | Right bit shift the values in one 'Array' against another 'Array' @@ -698,10 +698,10 @@ bitShiftRBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit shift left + -> Array a + -- ^ Result of bit shift right bitShiftRBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 batch -- | Cast one 'Array' into another @@ -717,7 +717,7 @@ cast -> Array b -- ^ Result of cast cast afArr = - coerce $ afArr `op1` (\x y -> af_cast x y dtyp) + coerce $ afArr `op1` (\x y -> ArrayFire.Internal.Arith.af_cast x y dtyp) where dtyp = afType (Proxy @b) @@ -1315,12 +1315,12 @@ atan2Batched x y (fromIntegral . fromEnum -> batch) = do -- (9.0000,9.0000) -- (10.0000,10.0000) cplx2 - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a - -- ^ First input - -> Array a - -- ^ Second input + -- ^ First input (real part) -> Array a + -- ^ Second input (imaginary part) + -> Array (Complex a) -- ^ Result of cplx2 cplx2 x y = x `op2` y $ \arr arr1 arr2 -> @@ -1342,14 +1342,14 @@ cplx2 x y = -- (9.0000,9.0000) -- (10.0000,10.0000) cplx2Batched - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a - -- ^ First input + -- ^ First input (real part) -> Array a - -- ^ Second input + -- ^ Second input (imaginary part) -> Bool -- ^ Use batch - -> Array a + -> Array (Complex a) -- ^ Result of cplx2 cplx2Batched x y (fromIntegral . fromEnum -> batch) = do x `op2` y $ \arr arr1 arr2 -> @@ -1371,11 +1371,11 @@ cplx2Batched x y (fromIntegral . fromEnum -> batch) = do -- (9.0000,0.0000) -- (10.0000,0.0000) cplx - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a -- ^ Input array - -> Array a - -- ^ Result of calling 'atan' + -> Array (Complex a) + -- ^ Complex array with input as real part and zero imaginary part cplx = flip op1 af_cplx -- | Execute real @@ -1385,12 +1385,12 @@ cplx = flip op1 af_cplx -- [1 1 1 1] -- 10.0000 real - :: (AFType a, AFType (Complex b), RealFrac a, RealFrac b) - => Array (Complex b) + :: (RealFloat a, AFType a, AFType (Complex a)) + => Array (Complex a) -- ^ Input array -> Array a - -- ^ Result of calling 'real' -real = flip op1d af_real + -- ^ Real part of each element +real = flip op1 af_real -- | Execute imag -- @@ -1399,12 +1399,12 @@ real = flip op1d af_real -- [1 1 1 1] -- 11.0000 imag - :: (AFType a, AFType (Complex b), RealFrac a, RealFrac b) - => Array (Complex b) + :: (RealFloat a, AFType a, AFType (Complex a)) + => Array (Complex a) -- ^ Input array -> Array a - -- ^ Result of calling 'imag' -imag = flip op1d af_imag + -- ^ Imaginary part of each element +imag = flip op1 af_imag -- | Execute conjg -- @@ -2043,7 +2043,7 @@ isZero :: AFType a => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ Result of calling 'isZero' isZero = (`op1` af_iszero) @@ -2066,7 +2066,7 @@ isInf :: (Real a, AFType a) => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ will contain 1's where input is Inf or -Inf, and 0 otherwise. isInf = (`op1` af_isinf) @@ -2086,9 +2086,9 @@ isInf = (`op1` af_isinf) -- 1 -- 1 isNaN - :: forall a. (AFType a, Real a) + :: (AFType a, Real a) => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ Will contain 1's where input is NaN, and 0 otherwise. isNaN = (`op1` af_isnan) diff --git a/src/ArrayFire/Array.hs b/src/ArrayFire/Array.hs index b0abc01..9b14e0c 100644 --- a/src/ArrayFire/Array.hs +++ b/src/ArrayFire/Array.hs @@ -177,21 +177,30 @@ mkArray -- ^ Returned array {-# NOINLINE mkArray #-} mkArray dims xs = - unsafePerformIO $ do - when (Prelude.length (take size xs) < size) $ do - let msg = "Invalid elements provided. " - <> "Expected " - <> show size - <> " elements received " - <> show (Prelude.length xs) - throwIO (AFException SizeError 203 msg) - dataPtr <- castPtr <$> newArray (Prelude.take size xs) + unsafePerformIO . mask_ $ do let ndims = fromIntegral (Prelude.length dims) alloca $ \arrayPtr -> do zeroOutArray arrayPtr dimsPtr <- newArray (DimT . fromIntegral <$> dims) - throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType - free dataPtr >> free dimsPtr + if size == 0 + then onException + (do throwAFError =<< af_create_handle arrayPtr ndims dimsPtr dType + free dimsPtr) + (free dimsPtr) + else do + when (Prelude.length (Prelude.take size xs) < size) $ do + free dimsPtr + let msg = "Invalid elements provided. " + <> "Expected " + <> show size + <> " elements received " + <> show (Prelude.length xs) + throwIO (AFException SizeError 203 msg) + dataPtr <- castPtr <$> newArray (Prelude.take size xs) + onException + (do throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType + free dataPtr >> free dimsPtr) + (free dataPtr >> free dimsPtr) arr <- peek arrayPtr Array <$> newForeignPtr af_release_array_finalizer arr where @@ -200,6 +209,46 @@ mkArray dims xs = -- af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type); +-- | Constructs an 'Array' from a 'Storable' 'Vector', avoiding the intermediate list allocation of 'mkArray'. +-- +-- The vector's pinned buffer is passed directly to @af_create_array@. +-- Throws 'AFException' if the vector length does not match the product of the given dimensions. +-- +-- >>> fromVector @Double [3] (Data.Vector.Storable.fromList [1,2,3]) +-- ArrayFire Array +-- [3 1 1 1] +-- 1.0000 +-- 2.0000 +-- 3.0000 +fromVector + :: forall a + . AFType a + => [Int] + -- ^ Dimensions + -> Vector a + -- ^ Source storable vector + -> Array a +{-# NOINLINE fromVector #-} +fromVector dims vec = + unsafePerformIO . mask_ $ do + let size = Prelude.product dims + ndims = fromIntegral (Prelude.length dims) + dType = afType (Proxy @a) + when (V.length vec /= size) $ + throwIO $ AFException SizeError 203 $ + "fromVector: dimension product " <> show size <> + " does not match vector length " <> show (V.length vec) + alloca $ \arrayPtr -> do + zeroOutArray arrayPtr + dimsPtr <- newArray (DimT . fromIntegral <$> dims) + onException + (V.unsafeWith vec $ \ptr -> do + throwAFError =<< af_create_array arrayPtr (castPtr ptr) ndims dimsPtr dType + free dimsPtr) + (free dimsPtr) + arr <- peek arrayPtr + Array <$> newForeignPtr af_release_array_finalizer arr + -- | Copies an 'Array' to a new 'Array' -- -- >>> copyArray (scalar @Double 10) @@ -479,11 +528,12 @@ isSparse a = toEnum . fromIntegral $ (a `infoFromArray` af_is_sparse) -- >>> toVector (vector @Double 10 [1..]) -- [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0] toVector :: forall a . AFType a => Array a -> Vector a -toVector arr@(Array fptr) = do +{-# NOINLINE toVector #-} +toVector arr@(Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do let len = getElements arr size = len * getSizeOf (Proxy @a) - ptr <- mallocBytes (len * size) + ptr <- mallocBytes size throwAFError =<< af_get_data_ptr (castPtr ptr) arrPtr newFptr <- newForeignPtr finalizerFree ptr pure $ unsafeFromForeignPtr0 newFptr len @@ -500,6 +550,7 @@ toList = V.toList . toVector -- >>> getScalar (scalar @Double 22.0) :: Double -- 22.0 getScalar :: forall a b . (Storable a, AFType b) => Array b -> a +{-# NOINLINE getScalar #-} getScalar (Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do alloca $ \ptr -> do diff --git a/src/ArrayFire/BLAS.hs b/src/ArrayFire/BLAS.hs index 321980a..463edeb 100644 --- a/src/ArrayFire/BLAS.hs +++ b/src/ArrayFire/BLAS.hs @@ -31,8 +31,15 @@ -------------------------------------------------------------------------------- module ArrayFire.BLAS where +import Control.Exception (mask_) import Data.Complex +import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.Marshal.Alloc (alloca) +import Foreign.Ptr (castPtr) +import Foreign.Storable (peek, poke) +import System.IO.Unsafe (unsafePerformIO) +import ArrayFire.Exception import ArrayFire.FFI import ArrayFire.Internal.BLAS import ArrayFire.Internal.Types @@ -167,3 +174,43 @@ transposeInPlace -> IO () transposeInPlace arr (fromIntegral . fromEnum -> b) = arr `inPlace` (`af_transpose_inplace` b) + +-- | General Matrix Multiply: C = alpha * op(A) * op(B) + beta * C_prev +-- +-- More general than 'matmul': supports scaling and accumulation. +-- When @beta = 0@, equivalent to @alpha * op(A) * op(B)@. +-- +-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) 0.0 +-- ArrayFire Array +-- [2 2 1 1] +-- 3.0000 5.0000 +-- 4.0000 6.0000 +gemm + :: AFType a + => MatProp + -- ^ Transformation applied to A ('None', 'Trans', or 'CTrans') + -> MatProp + -- ^ Transformation applied to B ('None', 'Trans', or 'CTrans') + -> a + -- ^ Scalar alpha + -> Array a + -- ^ Matrix A + -> Array a + -- ^ Matrix B + -> a + -- ^ Scalar beta (use 0 for pure multiply) + -> Array a + -- ^ Result C = alpha * op(A) * op(B) + beta * C_prev +gemm opA opB alpha (Array fptrA) (Array fptrB) beta = + unsafePerformIO . mask_ $ + withForeignPtr fptrA $ \ptrA -> + withForeignPtr fptrB $ \ptrB -> + alloca $ \pOut -> + alloca $ \pAlpha -> + alloca $ \pBeta -> do + zeroOutArray pOut + poke pAlpha alpha + poke pBeta beta + throwAFError =<< af_gemm pOut (toMatProp opA) (toMatProp opB) (castPtr pAlpha) ptrA ptrB (castPtr pBeta) + Array <$> (newForeignPtr af_release_array_finalizer =<< peek pOut) +{-# NOINLINE gemm #-} diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index 8bcfe54..7edab2c 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -42,13 +42,37 @@ import Foreign.Storable import System.IO.Unsafe import Unsafe.Coerce +import Data.Bits + import ArrayFire.Exception import ArrayFire.FFI +import ArrayFire.Internal.Array (af_get_dims) import ArrayFire.Internal.Data import ArrayFire.Internal.Defines import ArrayFire.Internal.Types import ArrayFire.Arith +-- | Bitwise complement of every element in an 'Array' +-- +-- >>> A.bitNot (A.scalar @Int32 0) +-- ArrayFire Array +-- [1 1 1 1] +-- -1 +bitNot + :: (AFType a, Bits a) + => Array a + -> Array a +bitNot arr = arr `bitXor` ones + where + (d0, d1, d2, d3) = arr `infoFromArray4` af_get_dims + ones = constant + [ fromIntegral d0 + , fromIntegral d1 + , fromIntegral d2 + , fromIntegral d3 + ] + (complement zeroBits) + -- | Creates an 'Array' from a scalar value from given dimensions -- -- >>> constant @Double [2,2] 2.0 @@ -63,6 +87,7 @@ constant -> a -- ^ Scalar value -> Array a +{-# NOINLINE constant #-} constant dims val = case dtyp of x | x == c64 -> @@ -191,7 +216,7 @@ constant dims val = -- | Creates a range of values in an Array -- --- >>> range @Double [10] (-1) +-- >>> arange @Double [10] (-1) -- ArrayFire Array -- [10 1 1 1] -- 0.0000 @@ -204,14 +229,15 @@ constant dims val = -- 7.0000 -- 8.0000 -- 9.0000 -range +arange :: forall a . AFType a => [Int] -> Int -> Array a -range dims (fromIntegral -> k) = unsafePerformIO $ do - ptr <- alloca $ \ptrPtr -> mask_ $ do +{-# NOINLINE arange #-} +arange dims (fromIntegral -> k) = unsafePerformIO . mask_ $ do + ptr <- alloca $ \ptrPtr -> do withArray (fromIntegral <$> dims) $ \dimArray -> do throwAFError =<< af_range ptrPtr n dimArray k typ peek ptrPtr @@ -252,10 +278,11 @@ iota -- ^ is array containing the number of repetitions of the unit dimensions -> Array a -- ^ is the generated array -iota dims tdims = unsafePerformIO $ do +{-# NOINLINE iota #-} +iota dims tdims = unsafePerformIO . mask_ $ do let dims' = take 4 (dims ++ repeat 1) tdims' = take 4 (tdims ++ repeat 1) - ptr <- alloca $ \ptrPtr -> mask_ $ do + ptr <- alloca $ \ptrPtr -> do zeroOutArray ptrPtr withArray (fromIntegral <$> dims') $ \dimArray -> withArray (fromIntegral <$> tdims') $ \tdimArray -> do @@ -280,6 +307,7 @@ identity => [Int] -- ^ Dimensions -> Array a +{-# NOINLINE identity #-} identity dims = unsafePerformIO . mask_ $ do let dims' = take 4 (dims ++ repeat 1) ptr <- alloca $ \ptrPtr -> mask_ $ do @@ -303,7 +331,7 @@ identity dims = unsafePerformIO . mask_ $ do -- 1.0000 0.0000 -- 0.0000 2.0000 diagCreate - :: AFType (a :: *) + :: AFType a => Array a -- ^ is the input array which is the diagonal -> Int @@ -320,7 +348,7 @@ diagCreate x (fromIntegral -> n) = -- 1.0000 -- 4.0000 diagExtract - :: AFType (a :: *) + :: AFType a => Array a -> Int -> Array a @@ -339,7 +367,7 @@ diagExtract x (fromIntegral -> n) = -- join :: Int - -> Array (a :: *) + -> Array a -> Array a -> Array a join (fromIntegral -> n) arr1 arr2 = op2 arr1 arr2 (\p a b -> af_join p n a b) @@ -357,6 +385,7 @@ joinMany :: Int -> [Array a] -> Array a +{-# NOINLINE joinMany #-} joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do newPtr <- alloca $ \aPtr -> do zeroOutArray aPtr @@ -385,7 +414,7 @@ withManyForeignPtr fptrs action = go [] fptrs -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- tile - :: Array (a :: *) + :: Array a -> [Int] -> Array a tile a (take 4 . (++repeat 1) -> [x,y,z,w]) = @@ -406,7 +435,7 @@ tile _ _ = error "impossible" -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- reorder - :: Array (a :: *) + :: Array a -> [Int] -> Array a reorder a (take 4 . (++ repeat 0) -> [x,y,z,w]) = @@ -424,7 +453,7 @@ reorder _ _ = error "impossible" -- 2.0000 -- shift - :: Array (a :: *) + :: Array a -> Int -> Int -> Int @@ -441,10 +470,10 @@ shift a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegra -- 1.0000 2.0000 3.0000 -- moddims - :: forall a - . Array (a :: *) + :: Array a -> [Int] -> Array a +{-# NOINLINE moddims #-} moddims (Array fptr) dims = unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do newPtr <- alloca $ \aPtr -> do diff --git a/src/ArrayFire/FFI.hs b/src/ArrayFire/FFI.hs index e776ace..f110581 100644 --- a/src/ArrayFire/FFI.hs +++ b/src/ArrayFire/FFI.hs @@ -30,6 +30,12 @@ import Foreign.C import Foreign.Marshal.Alloc import System.IO.Unsafe +foreign import ccall unsafe "af_cast" + af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr + +foreign import ccall unsafe "af_release_array" + af_release_array_ffi :: AFArray -> IO AFErr + op3 :: Array b -> Array a @@ -38,7 +44,7 @@ op3 -> Array a {-# NOINLINE op3 #-} op3 (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -57,7 +63,7 @@ op3Int -> Array a {-# NOINLINE op3Int #-} op3Int (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -75,7 +81,7 @@ op2 -> Array c {-# NOINLINE op2 #-} op2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- @@ -92,7 +98,7 @@ op2bool -> Array CBool {-# NOINLINE op2bool #-} op2bool (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- @@ -106,10 +112,10 @@ op2bool (Array fptr1) (Array fptr2) op = op2p :: Array a -> (Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr) - -> (Array a, Array a) + -> (Array a, Array b) {-# NOINLINE op2p #-} op2p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -125,7 +131,7 @@ op3p -> (Array a, Array a, Array a) {-# NOINLINE op3p #-} op3p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -144,7 +150,7 @@ op3p1 -> (Array a, Array a, Array a, b) {-# NOINLINE op3p1 #-} op3p1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z,g) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -167,7 +173,7 @@ op2p2 -> (Array a, Array a) {-# NOINLINE op2p2 #-} op2p2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do withForeignPtr fptr2 $ \ptr2 -> do @@ -179,6 +185,39 @@ op2p2 (Array fptr1) (Array fptr2) op = fptrB <- newForeignPtr af_release_array_finalizer y pure (Array fptrA, Array fptrB) +op2p2kv + :: Array Int + -> Array a + -> (Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> IO AFErr) + -> (Array Int, Array a) +{-# NOINLINE op2p2kv #-} +op2p2kv (Array fptr1) (Array fptr2) op = + unsafePerformIO . mask_ $ do + (x, y) <- + withForeignPtr fptr1 $ \ptr1 -> + withForeignPtr fptr2 $ \ptr2 -> do + castedKey <- alloca $ \p -> do + throwAFError =<< af_cast p ptr1 s32 + peek p + alloca $ \ptrOutput1 -> + alloca $ \ptrOutput2 -> do + onException + (throwAFError =<< op ptrOutput1 ptrOutput2 castedKey ptr2) + (af_release_array_ffi castedKey) + _ <- af_release_array_ffi castedKey + outKey <- peek ptrOutput1 + outVal <- peek ptrOutput2 + finalKey <- alloca $ \p -> do + onException + (throwAFError =<< af_cast p outKey s64) + (af_release_array_ffi outKey) + peek p + _ <- af_release_array_ffi outKey + pure (finalKey, outVal) + fptrA <- newForeignPtr af_release_array_finalizer x + fptrB <- newForeignPtr af_release_array_finalizer y + pure (Array fptrA, Array fptrB) + createArray' :: (Ptr AFArray -> IO AFErr) -> IO (Array a) @@ -238,29 +277,13 @@ opw1 (Window fptr) op throwAFError =<< op p ptr peek p -op1d - :: Array a - -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array b -{-# NOINLINE op1d #-} -op1d (Array fptr1) op = - unsafePerformIO $ do - withForeignPtr fptr1 $ \ptr1 -> do - ptr <- - alloca $ \ptrInput -> do - throwAFError =<< op ptrInput ptr1 - peek ptrInput - fptr <- newForeignPtr af_release_array_finalizer ptr - pure (Array fptr) - - op1 :: Array a -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array a + -> Array b {-# NOINLINE op1 #-} op1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do ptr <- alloca $ \ptrInput -> do @@ -304,7 +327,7 @@ op1b -> (b, Array a) {-# NOINLINE op1b #-} op1b (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . mask_ $ withForeignPtr fptr1 $ \ptr1 -> do (y,x) <- alloca $ \ptrInput1 -> do @@ -396,7 +419,7 @@ infoFromFeatures -> a {-# NOINLINE infoFromFeatures #-} infoFromFeatures (Features fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 @@ -431,7 +454,7 @@ infoFromArray -> a {-# NOINLINE infoFromArray #-} infoFromArray (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 @@ -444,7 +467,7 @@ infoFromArray2 -> (a,b) {-# NOINLINE infoFromArray2 #-} infoFromArray2 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -459,7 +482,7 @@ infoFromArray22 -> (a,b) {-# NOINLINE infoFromArray22 #-} infoFromArray22 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do withForeignPtr fptr2 $ \ptr2 -> do alloca $ \ptrInput1 -> do @@ -474,7 +497,7 @@ infoFromArray3 -> (a,b,c) {-# NOINLINE infoFromArray3 #-} infoFromArray3 (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . mask_ $ withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -491,7 +514,7 @@ infoFromArray4 -> (a,b,c,d) {-# NOINLINE infoFromArray4 #-} infoFromArray4 (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . mask_ $ withForeignPtr fptr1 $ \ptr1 -> alloca $ \ptrInput1 -> alloca $ \ptrInput2 -> diff --git a/src/ArrayFire/Features.hs b/src/ArrayFire/Features.hs index a84f58d..0920bb2 100644 --- a/src/ArrayFire/Features.hs +++ b/src/ArrayFire/Features.hs @@ -17,6 +17,7 @@ -------------------------------------------------------------------------------- module ArrayFire.Features where +import Control.Exception (mask_) import Foreign.Marshal import Foreign.Storable import Foreign.ForeignPtr @@ -34,8 +35,9 @@ import ArrayFire.Exception createFeatures :: Int -> Features +{-# NOINLINE createFeatures #-} createFeatures (fromIntegral -> n) = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do ptr <- alloca $ \ptrInput -> do throwAFError =<< ptrInput `af_create_features` n diff --git a/src/ArrayFire/Graphics.hs b/src/ArrayFire/Graphics.hs index e657625..e996eaa 100644 --- a/src/ArrayFire/Graphics.hs +++ b/src/ArrayFire/Graphics.hs @@ -492,13 +492,13 @@ drawVectorField2d -> Array a -- ^ is an 'Array' with the x-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the x-axis directions -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis directions -> Cell - -- ^ is the window handle + -- ^ is structure 'Cell' that has the properties that are used for the current rendering. -> IO () drawVectorField2d (Window w) (Array fptr1) (Array fptr2) (Array fptr3) (Array fptr4) cell = mask_ $ do diff --git a/src/ArrayFire/Image.hs b/src/ArrayFire/Image.hs index 9ae11d8..d63ed06 100644 --- a/src/ArrayFire/Image.hs +++ b/src/ArrayFire/Image.hs @@ -25,7 +25,6 @@ import Data.Word import ArrayFire.Internal.Types import ArrayFire.Internal.Image import ArrayFire.FFI -import ArrayFire.Arith -- | Calculates the gradient of an image -- @@ -260,7 +259,7 @@ histogram -> Array Word32 -- ^ (type u32) is the histogram for input array in histogram a (fromIntegral -> b) c d = - cast (a `op1` (\ptr x -> af_histogram ptr x b c d)) + a `op1` (\ptr x -> af_histogram ptr x b c d) -- | Dilation(morphological operator) for images. -- diff --git a/src/ArrayFire/Index.hs b/src/ArrayFire/Index.hs index ae1eaa4..3734c5a 100644 --- a/src/ArrayFire/Index.hs +++ b/src/ArrayFire/Index.hs @@ -10,6 +10,7 @@ -- Functions for indexing into an 'Array' -- -------------------------------------------------------------------------------- +{-# LANGUAGE FlexibleInstances #-} module ArrayFire.Index where import ArrayFire.Internal.Index @@ -29,6 +30,7 @@ index -> [Seq] -- ^ 'Seq' to use for indexing -> Array a +{-# NOINLINE index #-} index (Array fptr) seqs = unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do alloca $ \aptr -> @@ -41,65 +43,156 @@ index (Array fptr) seqs = n = fromIntegral (length seqs) -- | Lookup an Array by keys along a specified dimension -lookup - :: Array a +lookup + :: Array a -- ^ Input Array - -> Array Int + -> Array Int -- ^ Indices - -> Int + -> Int -- ^ Dimension -> Array a lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n) --- | A special value representing the entire axis of an 'Array'. -span :: Seq -span = Seq 1 1 0 -- From include/af/seq.h - -- Hard-coded here because FFI cannot import static const values. - --- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' range defined by 'Seq' indices -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- >>> let a = vector \@Double 5 [1..] +-- >>> assignSeq a [Seq 1 3 1] (vector \@Double 3 [0,0,0]) -- @ +assignSeq + :: Array a + -- ^ Destination array + -> [Seq] + -- ^ Indices defining the range to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignSeq #-} +assignSeq (Array fptr) seqs (Array rhsFptr) = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> + withArray (toAFSeq <$> seqs) $ \sptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_seq aptr ptr n sptr rhsPtr + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = fromIntegral (length seqs) + +-- | Index into an 'Array' using generalized 'Index' values (arrays or sequences) +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> indexGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] -- @ --- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a --- assignSeq = error "Not implemneted" +indexGen + :: Array a + -- ^ Input array + -> [Index] + -- ^ List of 'Index' values (one per dimension) + -> Array a + -- ^ Indexed result +{-# NOINLINE indexGen #-} +indexGen (Array fptr) indices = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_index_gen aptr ptr (fromIntegral n) iptr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () --- af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' using generalized 'Index' values -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> let b = matrix \@Double (2,2) [[0,0],[0,0]] +-- >>> assignGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] b -- @ +assignGen + :: Array a + -- ^ Destination array + -> [Index] + -- ^ List of 'Index' values defining the range to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignGen #-} +assignGen (Array fptr) indices (Array rhsFptr) = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_gen aptr ptr (fromIntegral n) iptr rhsPtr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () + +-- | A special 'Seq' value representing the entire axis of an 'Array'. +-- Hard-coded from include\/af\/seq.h because FFI cannot import static const values. +afSpan :: Seq +afSpan = Seq 1 1 0 + +-- | Select the full extent of a dimension. Use in tuple indices where you want all elements along an axis. +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- arr ! (range 0 2, full, at 1) -- @ --- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- indexGen = error "Not implemneted" +full :: Index +full = SeqIndex False afSpan + +-- | Convert index expressions to a list of 'Index'. +-- Supports a single 'Index' or tuples of up to four 'Index' values +-- (matching ArrayFire's maximum of 4 dimensions). +class ToIndexList a where + toIndexList :: a -> [Index] + +instance ToIndexList Index where + toIndexList x = [x] --- af_err af_assingn_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +instance ToIndexList (Index, Index) where + toIndexList (a, b) = [a, b] + +instance ToIndexList (Index, Index, Index) where + toIndexList (a, b, c) = [a, b, c] + +instance ToIndexList (Index, Index, Index, Index) where + toIndexList (a, b, c, d) = [a, b, c, d] + +-- | Lift a 'Seq' to an 'Index' for use in tuple-based indexing. +idx :: Seq -> Index +idx s = SeqIndex False s + +-- | Index an 'Array'. Accepts a single 'Index' or a tuple of up to four. -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- arr ! at 0 -- 1D: element 0 +-- arr ! range 1 3 -- 1D: rows 1-3 +-- arr ! (range 0 2, at 1) -- 2D +-- arr ! (range 0 2, full, at 1) -- 3D, full second axis -- @ +(!) :: ToIndexList ix => Array a -> ix -> Array a +a ! ix = indexGen a (toIndexList ix) +infixl 9 ! + +-- | Assign into a range of an 'Array'. Lens-style: use with '(&)'. +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- arr & range 1 3 .~ src +-- arr & (range 0 1, at 2) .~ src -- @ --- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- assignGen = error "Not implemneted" - --- af_err af_create_indexers(af_index_t** indexers); --- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim); --- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch); --- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch); --- af_err af_release_indexers(af_index_t* indexers); +(.~) :: ToIndexList ix => ix -> Array a -> Array a -> Array a +(ix .~ rhs) arr = assignGen arr (toIndexList ix) rhs +infixr 4 .~ diff --git a/src/ArrayFire/Internal/Algorithm.hsc b/src/ArrayFire/Internal/Algorithm.hsc index c683a0d..7c20814 100644 --- a/src/ArrayFire/Internal/Algorithm.hsc +++ b/src/ArrayFire/Internal/Algorithm.hsc @@ -75,3 +75,21 @@ foreign import ccall unsafe "af_set_union" af_set_union :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_set_intersect" af_set_intersect :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_sum_by_key" + af_sum_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_sum_by_key_nan" + af_sum_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_product_by_key" + af_product_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_product_by_key_nan" + af_product_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_min_by_key" + af_min_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_max_by_key" + af_max_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_all_true_by_key" + af_all_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_any_true_by_key" + af_any_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_count_by_key" + af_count_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr diff --git a/src/ArrayFire/Internal/BLAS.hsc b/src/ArrayFire/Internal/BLAS.hsc index b3b1788..f75beb2 100644 --- a/src/ArrayFire/Internal/BLAS.hsc +++ b/src/ArrayFire/Internal/BLAS.hsc @@ -17,3 +17,5 @@ foreign import ccall unsafe "af_transpose" af_transpose :: Ptr AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_transpose_inplace" af_transpose_inplace :: AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_gemm" + af_gemm :: Ptr AFArray -> AFMatProp -> AFMatProp -> Ptr () -> AFArray -> AFArray -> Ptr () -> IO AFErr diff --git a/src/ArrayFire/Internal/Defines.hsc b/src/ArrayFire/Internal/Defines.hsc index 9de5f06..2cbdd5e 100644 --- a/src/ArrayFire/Internal/Defines.hsc +++ b/src/ArrayFire/Internal/Defines.hsc @@ -253,7 +253,7 @@ newtype AFBackend = AFBackend CInt #{enum AFBackend, AFBackend , afBackendDefault = AF_BACKEND_DEFAULT - , afBackendCpu = AF_BACKEND_DEFAULT + , afBackendCpu = AF_BACKEND_CPU , afBackendCuda = AF_BACKEND_CUDA , afBackendOpencl = AF_BACKEND_OPENCL } @@ -381,14 +381,14 @@ newtype AFInverseDeconvAlgo = AFInverseDeconvAlgo CInt afInverseDeconvDefault = AF_INVERSE_DECONV_DEFAULT } --- newtype AFVarBias = AFVarBias Int --- deriving (Ord, Show, Eq) +newtype AFVarBias = AFVarBias CInt + deriving (Ord, Show, Eq, Storable) --- #{enum AFVarBias, AFVarBias --- , afVarianceDefault = AF_VARIANCE_DEFAULT --- , afVarianceSample = AF_VARIANCE_SAMPLE --- , afVariancePopulation = AF_VARIANCE_POPULATION --- } +#{enum AFVarBias, AFVarBias + , afVarianceDefault = AF_VARIANCE_DEFAULT + , afVarianceSample = AF_VARIANCE_SAMPLE + , afVariancePopulation = AF_VARIANCE_POPULATION + } newtype DimT = DimT CLLong deriving (Show, Eq, Storable, Num, Integral, Real, Enum, Ord) diff --git a/src/ArrayFire/Internal/Statistics.hsc b/src/ArrayFire/Internal/Statistics.hsc index 744e7b1..1decabc 100644 --- a/src/ArrayFire/Internal/Statistics.hsc +++ b/src/ArrayFire/Internal/Statistics.hsc @@ -36,3 +36,5 @@ foreign import ccall unsafe "af_corrcoef" af_corrcoef :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr foreign import ccall unsafe "af_topk" af_topk :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> CInt -> AFTopkFunction -> IO AFErr +foreign import ccall unsafe "af_meanvar" + af_meanvar :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> AFVarBias -> DimT -> IO AFErr diff --git a/src/ArrayFire/Internal/Types.hsc b/src/ArrayFire/Internal/Types.hsc index 3198d79..4e77df7 100644 --- a/src/ArrayFire/Internal/Types.hsc +++ b/src/ArrayFire/Internal/Types.hsc @@ -17,6 +17,7 @@ import Data.Word import Foreign.C.String import Foreign.C.Types import Foreign.ForeignPtr +import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import Foreign.Storable import GHC.Int @@ -55,8 +56,8 @@ instance Storable AFIndex where afIsBatch <- #{peek af_index_t, isBatch} ptr afIdx <- if afIsSeq - then Left <$> #{peek af_index_t, idx.arr} ptr - else Right <$> #{peek af_index_t, idx.seq} ptr + then Right <$> #{peek af_index_t, idx.seq} ptr + else Left <$> #{peek af_index_t, idx.arr} ptr pure AFIndex{..} poke ptr AFIndex{..} = do case afIdx of @@ -166,9 +167,13 @@ instance AFType Word where -- | ArrayFire backends data Backend = Default + -- ^ Use the default backend (determined by ArrayFire) | CPU + -- ^ CPU backend (always available) | CUDA + -- ^ NVIDIA CUDA GPU backend | OpenCL + -- ^ OpenCL backend (AMD, Intel, NVIDIA) deriving (Show, Eq, Ord) -- | Low-level to high-level Backend conversion @@ -200,17 +205,29 @@ toBackends _ = [] -- | Matrix properties data MatProp = None + -- ^ No property | Trans + -- ^ Data needs to be transposed | CTrans + -- ^ Data needs to be conjugate transposed | Conj + -- ^ Data needs to be conjugated | Upper + -- ^ Matrix is upper triangular | Lower + -- ^ Matrix is lower triangular | DiagUnit + -- ^ Diagonal contains units; used with triangular solvers | Sym + -- ^ Matrix is symmetric | PosDef + -- ^ Matrix is positive definite | Orthog + -- ^ Matrix is orthogonal | TriDiag + -- ^ Matrix is tri-diagonal | BlockDiag + -- ^ Matrix is block diagonal deriving (Show, Eq, Ord) -- | Low-level to High-level 'MatProp' conversion @@ -248,12 +265,16 @@ toMatProp Orthog = (AFMatProp 2048) toMatProp TriDiag = (AFMatProp 4096) toMatProp BlockDiag = (AFMatProp 8192) --- | Binary operation support +-- | Binary operation support (used with scan-by-key and similar operations) data BinaryOp = Add + -- ^ Addition | Mul + -- ^ Multiplication | Min + -- ^ Minimum | Max + -- ^ Maximum deriving (Show, Eq, Ord) -- | High-level to low-level 'MatProp' conversion @@ -274,9 +295,13 @@ fromBinaryOp x = error ("Invalid Binary Op: " <> show x) -- | Storage type used for Sparse arrays data Storage = Dense + -- ^ Dense storage (not sparse) | CSR + -- ^ Compressed Sparse Row format | CSC + -- ^ Compressed Sparse Column format | COO + -- ^ Coordinate list (COO) format deriving (Show, Eq, Ord, Enum) toStorage :: Storage -> AFStorage @@ -309,15 +334,25 @@ fromRandomEngine Mersenne = (AFRandomEngineType 300) -- | Interpolation type data InterpType = Nearest + -- ^ Nearest-neighbor interpolation | Linear + -- ^ Linear interpolation | Bilinear + -- ^ Bilinear interpolation | Cubic + -- ^ Cubic interpolation | LowerInterp + -- ^ Floor interpolation (rounds down to nearest integer) | LinearCosine + -- ^ Cosine-windowed linear interpolation | BilinearCosine + -- ^ Cosine-windowed bilinear interpolation | Bicubic + -- ^ Bicubic interpolation | CubicSpline + -- ^ Cubic spline interpolation | BicubicSpline + -- ^ Bicubic spline interpolation deriving (Show, Eq, Ord, Enum) toInterpType :: AFInterpType -> InterpType @@ -346,7 +381,7 @@ data Connectivity toConnectivity :: AFConnectivity -> Connectivity toConnectivity (AFConnectivity 4) = Conn4 -toConnectivity (AFConnectivity 8) = Conn4 +toConnectivity (AFConnectivity 8) = Conn8 toConnectivity (AFConnectivity x) = error ("Unknown connectivity option: " <> show x) fromConnectivity :: Connectivity -> AFConnectivity @@ -356,9 +391,13 @@ fromConnectivity Conn8 = AFConnectivity 8 -- | Color Space type data CSpace = Gray + -- ^ Grayscale | RGB + -- ^ Red-Green-Blue | HSV + -- ^ Hue-Saturation-Value | YCBCR + -- ^ Luminance + chroma (blue-difference, red-difference) deriving (Show, Eq, Ord, Enum) toCSpace :: AFCSpace -> CSpace @@ -367,11 +406,14 @@ toCSpace (AFCSpace (fromIntegral -> x)) = toEnum x fromCSpace :: CSpace -> AFCSpace fromCSpace = AFCSpace . fromIntegral . fromEnum --- | YccStd type +-- | YCbCr standard data YccStd = Ycc601 + -- ^ ITU-R BT.601 (standard definition) | Ycc709 + -- ^ ITU-R BT.709 (high definition) | Ycc2020 + -- ^ ITU-R BT.2020 (ultra high definition) deriving (Show, Eq, Ord) toAFYccStd :: AFYccStd -> YccStd @@ -385,13 +427,18 @@ fromAFYccStd Ycc601 = afYcc601 fromAFYccStd Ycc709 = afYcc709 fromAFYccStd Ycc2020 = afYcc2020 --- | Moment types +-- | Image moment types data MomentType = M00 + -- ^ Zeroth-order moment (image area / mass) | M01 + -- ^ First-order moment about x-axis | M10 + -- ^ First-order moment about y-axis | M11 + -- ^ Mixed first-order moment | FirstOrder + -- ^ All first-order moments (M00, M01, M10, M11) deriving (Show, Eq, Ord) toMomentType :: AFMomentType -> MomentType @@ -410,10 +457,12 @@ fromMomentType M10 = afMomentM10 fromMomentType M11 = afMomentM11 fromMomentType FirstOrder = afMomentFirstOrder --- | Canny Theshold type +-- | Threshold mode for Canny edge detection data CannyThreshold = Manual + -- ^ User-supplied low and high threshold values | AutoOtsu + -- ^ Thresholds computed automatically via Otsu's method deriving (Show, Eq, Ord, Enum) toCannyThreshold :: AFCannyThreshold -> CannyThreshold @@ -422,11 +471,14 @@ toCannyThreshold (AFCannyThreshold (fromIntegral -> x)) = toEnum x fromCannyThreshold :: CannyThreshold -> AFCannyThreshold fromCannyThreshold = AFCannyThreshold . fromIntegral . fromEnum --- | Flux function type +-- | Flux function for anisotropic diffusion data FluxFunction = FluxDefault + -- ^ Default flux function (same as 'FluxQuadratic') | FluxQuadratic + -- ^ Quadratic flux function (Perona-Malik) | FluxExponential + -- ^ Exponential flux function (Perona-Malik) deriving (Show, Eq, Ord, Enum) toFluxFunction :: AFFluxFunction -> FluxFunction @@ -435,11 +487,14 @@ toFluxFunction (AFFluxFunction (fromIntegral -> x)) = toEnum x fromFluxFunction :: FluxFunction -> AFFluxFunction fromFluxFunction = AFFluxFunction . fromIntegral . fromEnum --- | Diffusion type +-- | Diffusion equation type for anisotropic smoothing data DiffusionEq = DiffusionDefault + -- ^ Default (same as 'DiffusionGrad') | DiffusionGrad + -- ^ Gradient-based diffusion (Perona-Malik) | DiffusionMCDE + -- ^ Mean curvature diffusion equation deriving (Show, Eq, Ord, Enum) toDiffusionEq :: AFDiffusionEq -> DiffusionEq @@ -448,11 +503,14 @@ toDiffusionEq (AFDiffusionEq (fromIntegral -> x)) = toEnum x fromDiffusionEq :: DiffusionEq -> AFDiffusionEq fromDiffusionEq = AFDiffusionEq . fromIntegral . fromEnum --- | Iterative deconvolution algo type +-- | Iterative deconvolution algorithm data IterativeDeconvAlgo = DeconvDefault + -- ^ Default algorithm (same as 'DeconvLandweber') | DeconvLandweber + -- ^ Landweber iteration (gradient descent on least squares) | DeconvRichardsonLucy + -- ^ Richardson-Lucy algorithm (maximum likelihood for Poisson noise) deriving (Show, Eq, Ord, Enum) toIterativeDeconvAlgo :: AFIterativeDeconvAlgo -> IterativeDeconvAlgo @@ -461,10 +519,12 @@ toIterativeDeconvAlgo (AFIterativeDeconvAlgo (fromIntegral -> x)) = toEnum x fromIterativeDeconvAlgo :: IterativeDeconvAlgo -> AFIterativeDeconvAlgo fromIterativeDeconvAlgo = AFIterativeDeconvAlgo . fromIntegral . fromEnum --- | Inverse deconvolution algo type +-- | Inverse (non-iterative) deconvolution algorithm data InverseDeconvAlgo = InverseDeconvDefault + -- ^ Default algorithm (same as 'InverseDeconvTikhonov') | InverseDeconvTikhonov + -- ^ Tikhonov regularized Wiener filter deriving (Show, Eq, Ord, Enum) toInverseDeconvAlgo :: AFInverseDeconvAlgo -> InverseDeconvAlgo @@ -473,13 +533,17 @@ toInverseDeconvAlgo (AFInverseDeconvAlgo (fromIntegral -> x)) = toEnum x fromInverseDeconvAlgo :: InverseDeconvAlgo -> AFInverseDeconvAlgo fromInverseDeconvAlgo = AFInverseDeconvAlgo . fromIntegral . fromEnum --- | Cell type, used in Graphics module +-- | Cell type, used in Graphics module to describe a subplot position data Cell = Cell { cellRow :: Int + -- ^ Row index of the subplot (0-based) , cellCol :: Int + -- ^ Column index of the subplot (0-based) , cellTitle :: String + -- ^ Title string displayed above the plot , cellColorMap :: ColorMap + -- ^ Color map used for rendering } deriving (Show, Eq) cellToAFCell :: Cell -> IO AFCell @@ -491,19 +555,30 @@ cellToAFCell Cell {..} = , afCellColorMap = fromColorMap cellColorMap } --- | ColorMap type +-- | Color map for rendering data ColorMap = ColorMapDefault + -- ^ Default grayscale color map | ColorMapSpectrum + -- ^ Rainbow spectrum (violet to red) | ColorMapColors + -- ^ Distinct colors | ColorMapRed + -- ^ Red gradient | ColorMapMood + -- ^ Mood color map (cool tones) | ColorMapHeat + -- ^ Heat map (black to red to yellow to white) | ColorMapBlue + -- ^ Blue gradient | ColorMapInferno + -- ^ Perceptually uniform: black-purple-orange-yellow | ColorMapMagma + -- ^ Perceptually uniform: black-purple-pink-white | ColorMapPlasma + -- ^ Perceptually uniform: blue-purple-yellow | ColorMapViridis + -- ^ Perceptually uniform: purple-teal-yellow deriving (Show, Eq, Ord, Enum) fromColorMap :: ColorMap -> AFColorMap @@ -512,16 +587,24 @@ fromColorMap = AFColorMap . fromIntegral . fromEnum toColorMap :: AFColorMap -> ColorMap toColorMap (AFColorMap (fromIntegral -> x)) = toEnum x --- | Marker type +-- | Marker shape for scatter plots data MarkerType = MarkerTypeNone + -- ^ No marker | MarkerTypePoint + -- ^ Single pixel point | MarkerTypeCircle + -- ^ Circle | MarkerTypeSquare + -- ^ Square | MarkerTypeTriangle + -- ^ Triangle | MarkerTypeCross + -- ^ X cross | MarkerTypePlus + -- ^ Plus sign | MarkerTypeStar + -- ^ Star deriving (Show, Eq, Ord, Enum) fromMarkerType :: MarkerType -> AFMarkerType @@ -530,17 +613,26 @@ fromMarkerType = AFMarkerType . fromIntegral . fromEnum toMarkerType :: AFMarkerType -> MarkerType toMarkerType (AFMarkerType (fromIntegral -> x)) = toEnum x --- | Match type +-- | Template matching metric type data MatchType = MatchTypeSAD + -- ^ Sum of Absolute Differences | MatchTypeZSAD + -- ^ Zero-mean Sum of Absolute Differences | MatchTypeLSAD + -- ^ Locally scaled Sum of Absolute Differences | MatchTypeSSD + -- ^ Sum of Squared Differences | MatchTypeZSSD + -- ^ Zero-mean Sum of Squared Differences | MatchTypeLSSD + -- ^ Locally scaled Sum of Squared Differences | MatchTypeNCC + -- ^ Normalized Cross Correlation | MatchTypeZNCC + -- ^ Zero-mean Normalized Cross Correlation | MatchTypeSHD + -- ^ Sum of Hamming Distances deriving (Show, Eq, Ord, Enum) fromMatchType :: MatchType -> AFMatchType @@ -549,11 +641,14 @@ fromMatchType = AFMatchType . fromIntegral . fromEnum toMatchType :: AFMatchType -> MatchType toMatchType (AFMatchType (fromIntegral -> x)) = toEnum x --- | TopK type +-- | Order for @topk@ results data TopK = TopKDefault + -- ^ Default order (same as 'TopKMax') | TopKMin + -- ^ Return the k smallest values | TopKMax + -- ^ Return the k largest values deriving (Show, Eq, Ord, Enum) fromTopK :: TopK -> AFTopkFunction @@ -562,10 +657,25 @@ fromTopK = AFTopkFunction . fromIntegral . fromEnum toTopK :: AFTopkFunction -> TopK toTopK (AFTopkFunction (fromIntegral -> x)) = toEnum x --- | Homography Type +-- | Variance bias correction method +data VarBias + = VarianceDefault + -- ^ Default (same as 'VariancePopulation') + | VarianceSample + -- ^ Sample variance (divides by N-1; Bessel's correction) + | VariancePopulation + -- ^ Population variance (divides by N) + deriving (Show, Eq, Ord, Enum) + +fromVarBias :: VarBias -> AFVarBias +fromVarBias = AFVarBias . fromIntegral . fromEnum + +-- | Homography estimation method data HomographyType = RANSAC + -- ^ Random Sample Consensus — robust to outliers | LMEDS + -- ^ Least Median of Squares — robust to up to 50% outliers deriving (Show, Eq, Ord, Enum) fromHomographyType :: HomographyType -> AFHomographyType @@ -586,26 +696,33 @@ toAFSeq :: Seq -> AFSeq toAFSeq (Seq x y z) = (AFSeq x y z) -- | Index Type -data Index a - = Index - { idx :: Either (Array a) Seq - , isSeq :: !Bool - , isBatch :: !Bool - } +data Index + = SeqIndex Bool Seq + | ArrIndex Bool (Array Int) + +seqIdx :: Seq -> Bool -> Index +seqIdx s batch = SeqIndex batch s + +arrIdx :: Array Int -> Bool -> Index +arrIdx a batch = ArrIndex batch a + +-- | Index a contiguous range [begin..end] with step 1. +range :: Int -> Int -> Index +range b e = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) 1) -seqIdx :: Seq -> Bool -> Index a -seqIdx s = Index (Right s) True +-- | Index a range [begin..end] with an explicit step. +rangeStep :: Int -> Int -> Int -> Index +rangeStep b e s = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) (fromIntegral s)) -arrIdx :: Array a -> Bool -> Index a -arrIdx a = Index (Left a) False +-- | Index a single element. +at :: Int -> Index +at n = let d = fromIntegral n in SeqIndex False (Seq d d 1) -toAFIndex :: Index a -> IO AFIndex -toAFIndex (Index a b c) = do - case a of - Right s -> pure $ AFIndex (Right (toAFSeq s)) b c - Left (Array fptr) -> do - withForeignPtr fptr $ \ptr -> - pure $ AFIndex (Left ptr) b c +toAFIndex :: Index -> IO AFIndex +toAFIndex (SeqIndex batch s) = + pure $ AFIndex (Right (toAFSeq s)) True batch +toAFIndex (ArrIndex batch (Array fptr)) = + pure $ AFIndex (Left (unsafeForeignPtrToPtr fptr)) False batch -- | Type alias for ArrayFire API version @@ -669,20 +786,32 @@ fromConvMode (AFConvMode (fromIntegral -> x)) = toEnum x toConvMode :: ConvMode -> AFConvMode toConvMode = AFConvMode . fromIntegral . fromEnum --- | Array Fire types +-- | ArrayFire element types (mirrors @af_dtype@) data AFDType = F32 + -- ^ 32-bit IEEE 754 float | C32 + -- ^ Complex number of two 32-bit floats | F64 + -- ^ 64-bit IEEE 754 double | C64 + -- ^ Complex number of two 64-bit doubles | B8 + -- ^ 8-bit boolean | S32 + -- ^ 32-bit signed integer | U32 + -- ^ 32-bit unsigned integer | U8 + -- ^ 8-bit unsigned integer | S64 + -- ^ 64-bit signed integer | U64 + -- ^ 64-bit unsigned integer | S16 + -- ^ 16-bit signed integer | U16 + -- ^ 16-bit unsigned integer deriving (Show, Eq, Enum) fromAFType :: AFDtype -> AFDType diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index 0d9383a..e9ba80e 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -15,27 +15,32 @@ -------------------------------------------------------------------------------- module ArrayFire.Orphans where -import Prelude +import Prelude hiding (pi) +import qualified Prelude + +import Control.DeepSeq (NFData(..)) import qualified ArrayFire.Arith as A import qualified ArrayFire.Array as A import qualified ArrayFire.Algorithm as A -import qualified ArrayFire.Data as A import ArrayFire.Types import ArrayFire.Util +instance NFData (Array a) where + rnf x = x `seq` () + instance (AFType a, Eq a) => Eq (Array a) where - x == y = A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) - x /= y = A.allTrueAll (A.neqBatched x y False) == (0.0,0.0) + x == y = A.getDims x == A.getDims y + && A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) + x /= y = A.getDims x /= A.getDims y + || A.anyTrueAll (A.neqBatched x y False) /= (0.0,0.0) instance (Num a, AFType a) => Num (Array a) where x + y = A.add x y x * y = A.mul x y abs = A.abs - signum x = A.sign (-x) - A.sign x - negate arr = do - let (w,x,y,z) = A.getDims arr - A.cast (A.constant @a [w,x,y,z] 0) `A.sub` arr + signum x = A.cast (A.gt x 0) - A.cast (A.lt x 0) + negate arr = A.scalar @a (fromInteger (-1)) `A.mul` arr x - y = A.sub x y fromInteger = A.scalar . fromIntegral @@ -47,7 +52,7 @@ instance forall a . (Fractional a, AFType a) => Fractional (Array a) where fromRational n = A.scalar @a (fromRational n) instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where - pi = A.scalar @a 3.14159 + pi = A.scalar @a (realToFrac (Prelude.pi :: Double)) exp = A.exp @a log = A.log @a sqrt = A.sqrt @a diff --git a/src/ArrayFire/Statistics.hs b/src/ArrayFire/Statistics.hs index 8a3db79..d80a63a 100644 --- a/src/ArrayFire/Statistics.hs +++ b/src/ArrayFire/Statistics.hs @@ -33,6 +33,9 @@ -------------------------------------------------------------------------------- module ArrayFire.Statistics where +import Data.Word (Word32) +import Foreign.Ptr (nullPtr) + import ArrayFire.Array import ArrayFire.FFI import ArrayFire.Internal.Statistics @@ -303,8 +306,58 @@ topk -- ^ The number of elements to be retrieved along the dim dimension -> TopK -- ^ If descending, the highest values are returned. Otherwise, the lowest values are returned - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Returns The values of the top k elements along the dim dimension -- along with the indices of the top k elements along the dim dimension topk a (fromIntegral -> x) (fromTopK -> f) = a `op2p` (\b c d -> af_topk b c d x 0 f) + +-- | Simultaneously compute the mean and variance of an 'Array' along a dimension. +-- +-- More efficient than calling 'mean' and 'var' separately. +-- +-- >>> let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +-- >>> v +-- ArrayFire Array +-- [1 1 1 1] +-- 1.2500 +meanVar + :: AFType a + => Array a + -- ^ Input 'Array' + -> VarBias + -- ^ Variance bias correction: 'VariancePopulation' (÷N) or 'VarianceSample' (÷N-1) + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVar arr bias (fromIntegral -> dim) = + arr `op2p` (\pMean pVar aPtr -> + af_meanvar pMean pVar aPtr nullPtr (fromVarBias bias) dim) + +-- | Simultaneously compute the weighted mean and variance of an 'Array' along a dimension. +-- +-- >>> let (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) (vector @Double 4 [1,1,1,1]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +meanVarWeighted + :: AFType a + => Array a + -- ^ Input 'Array' + -> Array a + -- ^ Weights 'Array' + -> VarBias + -- ^ Variance bias correction + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVarWeighted arr weights bias (fromIntegral -> dim) = + op2p2 arr weights $ \pMean pVar aPtr wPtr -> + af_meanvar pMean pVar aPtr wPtr (fromVarBias bias) dim diff --git a/src/ArrayFire/Types.hs b/src/ArrayFire/Types.hs index e63f6c9..5daac3c 100644 --- a/src/ArrayFire/Types.hs +++ b/src/ArrayFire/Types.hs @@ -32,6 +32,7 @@ module ArrayFire.Types , Features , AFType (..) , TopK (..) + , VarBias (..) , Backend (..) , MatchType (..) , BinaryOp (..) @@ -52,6 +53,11 @@ module ArrayFire.Types , InverseDeconvAlgo (..) , Seq (..) , Index (..) + , seqIdx + , arrIdx + , range + , rangeStep + , at , NormType (..) , ConvMode (..) , ConvDomain (..) diff --git a/src/ArrayFire/Util.hs b/src/ArrayFire/Util.hs index d8ba69b..26d0b80 100644 --- a/src/ArrayFire/Util.hs +++ b/src/ArrayFire/Util.hs @@ -258,6 +258,7 @@ arrayToString -- ^ If 'True', performs takes the transpose before rendering to 'String' -> String -- ^ 'Array' rendered to 'String' +{-# NOINLINE arrayToString #-} arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> withCString expr $ \expCstr -> @@ -279,6 +280,7 @@ getSizeOf -- ^ Witness of Haskell type that mirrors ArrayFire type. -> Int -- ^ Size of ArrayFire type +{-# NOINLINE getSizeOf #-} getSizeOf proxy = unsafePerformIO . mask_ . alloca $ \csize -> do throwAFError =<< af_get_size_of csize (afType proxy) diff --git a/src/ArrayFire/Vision.hs b/src/ArrayFire/Vision.hs index 71f3bd7..898ad5a 100644 --- a/src/ArrayFire/Vision.hs +++ b/src/ArrayFire/Vision.hs @@ -50,6 +50,7 @@ fast -- ^ Is the length of the edges in the image to be discarded by FAST (minimum is 3, as the radius of the circle) -> Features -- ^ Struct containing arrays for x and y coordinates and score, while array orientation is set to 0 as FAST does not compute orientation, and size is set to 1 as FAST does not compute multiple scales +{-# NOINLINE fast #-} fast (Array fptr) thr (fromIntegral -> arc) (fromIntegral . fromEnum -> non) ratio (fromIntegral -> edge) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -78,6 +79,7 @@ harris -> Float -- ^ struct containing arrays for x and y coordinates and score (Harris response), while arrays orientation and size are set to 0 and 1, respectively, because Harris does not compute that information -> Features +{-# NOINLINE harris #-} harris (Array fptr) (fromIntegral -> maxc) minresp sigma (fromIntegral -> bs) thr = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -107,6 +109,7 @@ orb -- ^ blur image with a Gaussian filter with sigma=2 before computing descriptors to increase robustness against noise if true -> (Features, Array a) -- ^ 'Features' struct composed of arrays for x and y coordinates, score, orientation and size of selected features +{-# NOINLINE orb #-} orb (Array fptr) thr (fromIntegral -> feat) scl (fromIntegral -> levels) (fromIntegral . fromEnum -> blur) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feature, arr) <- @@ -144,6 +147,7 @@ sift -> (Features, Array a) -- ^ Features object composed of arrays for x and y coordinates, score, orientation and size of selected features -- Nx128 array containing extracted descriptors, where N is the number of features found by SIFT +{-# NOINLINE sift #-} sift (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -181,6 +185,7 @@ gloh -> (Features, Array a) -- ^ 'Features' object composed of arrays for x and y coordinates, score, orientation and size of selected features -- ^ Nx272 array containing extracted GLOH descriptors, where N is the number of features found by SIFT +{-# NOINLINE gloh #-} gloh (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -274,6 +279,7 @@ susan -> Int -- ^ indicates how many pixels width area should be skipped for corner detection -> Features +{-# NOINLINE susan #-} susan (Array fptr) (fromIntegral -> a) b c d (fromIntegral -> e) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do feat <- @@ -329,6 +335,7 @@ homography -> (Int, Array a) -- ^ is a 3x3 array containing the estimated homography. -- is the number of inliers that the homography was estimated to comprise, in the case that htype is AF_HOMOGRAPHY_RANSAC, a higher inlier_thr value will increase the estimated inliers. Note that if the number of inliers is too low, it is likely that a bad homography will be returned. +{-# NOINLINE homography #-} homography (Array a) (Array b) diff --git a/test/ArrayFire/AlgorithmSpec.hs b/test/ArrayFire/AlgorithmSpec.hs index 6e5b4d6..b4d3e0e 100644 --- a/test/ArrayFire/AlgorithmSpec.hs +++ b/test/ArrayFire/AlgorithmSpec.hs @@ -2,7 +2,6 @@ module ArrayFire.AlgorithmSpec where import qualified ArrayFire as A - import Test.Hspec spec :: Spec @@ -79,19 +78,29 @@ spec = A.min (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1 - A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) - A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) - A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 + A.min (A.vector @(A.Complex Double) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (1 A.:+ 0) + A.min (A.vector @(A.Complex Float) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (1 A.:+ 0) A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 + it "Should take the maximum element of a vector" $ do + A.max (A.vector @Int 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @Float 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @Double 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @(A.Complex Double) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (3 A.:+ 4) + A.max (A.vector @(A.Complex Float) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (3 A.:+ 4) + A.max (A.vector @A.CBool 5 [0,1,1,0,1]) 0 `shouldBe` 1 it "Should find if all elements are true along dimension" $ do - A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1 - A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 + A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` A.scalar @A.CBool 1 + A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1 + A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0 it "Should find if any elements are true along dimension" $ do - A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 - A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1 - A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 + A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1 + A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` A.scalar @A.CBool 1 + A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0 it "Should get count of all elements" $ do A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5 @@ -101,12 +110,12 @@ spec = A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` (10,0) A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0) A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` (3800,0) - A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0) - it "Should get sum all elements" $ do + A.sumAll (A.vector @(A.Complex Double) 3 [1 A.:+ 2, 3 A.:+ 4, 5 A.:+ 6]) `shouldBe` (9.0, 12.0) + it "Should sum all elements ignoring NaN" $ do A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0) it "Should product all elements in an Array" $ do A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0) - it "Should product all elements in an Array" $ do + it "Should product all elements ignoring NaN" $ do A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0) it "Should find minimum value of an Array" $ do A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0) @@ -114,4 +123,161 @@ spec = A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0) -- it "Should find if all elements are true" $ do -- A.allTrue (A.vector @A.CBool 5 (repeat 0)) `shouldBe` False + it "Should sum values grouped by key" $ do + let keys = A.vector @Int 5 [1,1,2,2,2] + vals = A.vector @Double 5 [10,20,1,2,3] + (ko, vo) = A.sumByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [30,6] + it "Should take the product of values grouped by key" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [2,3,4,5] + (ko, vo) = A.productByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [6,20] + it "Should find the minimum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.minByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [1,2] + it "Should find the maximum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.maxByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [3,5] + it "Should count non-zero values per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [1,0,1,1] + (ko, vo) = A.countByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [1,2] + it "Should check allTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [1,1,1,0] + (ko, vo) = A.allTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [1,0] + it "Should check anyTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [0,0,0,1] + (ko, vo) = A.anyTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [0,1] + it "Should sum values grouped by key, substituting NaN with 0" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [10, (acos 2), 3, 4] + (ko, vo) = A.sumByKeyNaN keys vals 0 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [10, 7] + it "Should take the product of values grouped by key, substituting NaN with 1" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [2, (acos 2), 4, 5] + (ko, vo) = A.productByKeyNaN keys vals 0 1 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [2, 20] + + describe "accum" $ do + it "computes inclusive cumulative sum along dim 0" $ do + A.accum (A.vector @Double 5 [1,2,3,4,5]) 0 + `shouldBe` A.vector @Double 5 [1,3,6,10,15] + it "computes cumulative sum along dim 1 of a matrix" $ do + A.accum (A.mkArray @Double [2,3] [1,2,3,4,5,6]) 1 + `shouldBe` A.mkArray @Double [2,3] [1,2,4,6,9,12] + + describe "diff1" $ do + it "computes first differences along dim 0" $ do + A.diff1 (A.vector @Double 5 [1,2,4,7,11]) 0 + `shouldBe` A.vector @Double 4 [1,2,3,4] + it "first differences of a constant vector are zero" $ do + A.diff1 (A.vector @Double 4 (repeat 5)) 0 + `shouldBe` A.vector @Double 3 [0,0,0] + + describe "diff2" $ do + it "computes second differences of a quadratic sequence" $ do + A.diff2 (A.vector @Double 5 [0,1,4,9,16]) 0 + `shouldBe` A.vector @Double 3 [2,2,2] + it "second differences of a linear sequence are zero" $ do + A.diff2 (A.vector @Double 5 [1,2,3,4,5]) 0 + `shouldBe` A.vector @Double 3 [0,0,0] + + describe "where'" $ do + it "returns indices of nonzero elements" $ do + A.where' (A.vector @Double 5 [0,1,0,2,0]) + `shouldBe` A.vector @A.Word32 2 [1,3] + it "returns empty array when all elements are zero" $ do + A.getDims (A.where' (A.vector @Double 3 [0,0,0])) + `shouldBe` (0,1,1,1) + + describe "scan" $ do + it "inclusive scan with Add equals accum" $ do + A.scan (A.vector @Double 5 [1..5]) 0 A.Add True + `shouldBe` A.vector @Double 5 [1,3,6,10,15] + it "exclusive scan with Add shifts the prefix sums by one" $ do + A.scan (A.vector @Double 5 [1..5]) 0 A.Add False + `shouldBe` A.vector @Double 5 [0,1,3,6,10] + it "inclusive scan with Mul gives running product" $ do + A.scan (A.vector @Double 4 [1..4]) 0 A.Mul True + `shouldBe` A.vector @Double 4 [1,2,6,24] + + describe "scanByKey" $ do + it "resets prefix sum at each key boundary" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [1,2,3,4] + A.scanByKey keys vals 0 A.Add True + `shouldBe` A.vector @Double 4 [1,3,3,7] + + describe "sort" $ do + it "sorts ascending" $ do + A.sort (A.vector @Double 5 [3,1,4,1,5]) 0 True + `shouldBe` A.vector @Double 5 [1,1,3,4,5] + it "sorts descending" $ do + A.sort (A.vector @Double 5 [3,1,4,1,5]) 0 False + `shouldBe` A.vector @Double 5 [5,4,3,1,1] + + describe "sortIndex" $ do + it "returns sorted values and original indices" $ do + let (vals, idxs) = A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 True + vals `shouldBe` A.vector @Double 4 [1,2,3,4] + idxs `shouldBe` A.vector @A.Word32 4 [2,1,0,3] + + describe "sortByKey" $ do + it "sorts values by key order" $ do + let (ks, vs) = A.sortByKey + (A.vector @Double 4 [2,1,4,3]) + (A.vector @Double 4 [10,9,8,7]) + 0 True + ks `shouldBe` A.vector @Double 4 [1,2,3,4] + vs `shouldBe` A.vector @Double 4 [9,10,7,8] + + describe "setUnique" $ do + it "removes duplicate elements" $ do + A.setUnique (A.vector @Double 4 [1,1,2,2]) True + `shouldBe` A.vector @Double 2 [1,2] + it "returns a single-element array from an all-same vector" $ do + A.setUnique (A.vector @Double 3 [5,5,5]) True + `shouldBe` A.vector @Double 1 [5] + + describe "setUnion" $ do + it "produces the union of two sorted sets" $ do + A.setUnion (A.vector @Double 3 [3,4,5]) (A.vector @Double 3 [1,2,3]) True + `shouldBe` A.vector @Double 5 [1,2,3,4,5] + + describe "setIntersect" $ do + it "produces the intersection of two sorted sets" $ do + A.setIntersect (A.vector @Double 3 [3,4,5]) (A.vector @Double 3 [1,2,3]) True + `shouldBe` A.vector @Double 1 [3] + it "returns empty array for disjoint sets" $ do + A.getDims (A.setIntersect (A.vector @Double 2 [1,2]) (A.vector @Double 2 [3,4]) True) + `shouldBe` (0,1,1,1) + + -- Regression: infoFromArray3 was missing mask_, risking finalizer interference. + -- iminAll and imaxAll are the primary users. + it "iminAll returns correct value and index" $ do + let arr = A.vector @Double 5 [3, 1, 4, 2, 5] + A.iminAll arr `shouldBe` (1.0, 0.0, 1) + it "imaxAll returns correct value and index" $ do + let arr = A.vector @Double 5 [3, 1, 4, 1, 5] + A.imaxAll arr `shouldBe` (5.0, 0.0, 4) diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index 623726f..3686ec5 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -140,15 +140,15 @@ spec = clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3) `shouldBe` 2 it "Should check if an array has positive or negative infinities" $ do - isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1 - isInf (scalar @Double 10) `shouldBe` scalar @Double 0 + isInf (scalar @Double (1 / 0)) `shouldBe` scalar @CBool 1 + isInf (scalar @Double 10) `shouldBe` scalar @CBool 0 it "Should check if an array has any NaN values" $ do - ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 - ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 + ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @CBool 1 + ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @CBool 0 it "Should check if an array has any Zero values" $ do - isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0 - isZero (scalar @Double 0) `shouldBe` scalar @Double 1 - isZero (scalar @Double 1) `shouldBe` scalar @Double 0 + isZero (scalar @Double (acos 2)) `shouldBe` scalar @CBool 0 + isZero (scalar @Double 0) `shouldBe` scalar @CBool 1 + isZero (scalar @Double 1) `shouldBe` scalar @CBool 0 prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x @@ -166,3 +166,41 @@ spec = prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x + + describe "erf" $ do + it "erf 0 = 0" $ + evalf (ArrayFire.erf (scalar @Double 0)) `shouldBeApprox` 0 + it "erf 1 ≈ 0.8427" $ + evalf (ArrayFire.erf (scalar @Double 1)) `shouldBeApprox` 0.8427007929497149 + it "erf is odd: erf(-x) = -erf(x)" $ + evalf (ArrayFire.erf (scalar @Double (-1))) `shouldBeApprox` + negate (evalf (ArrayFire.erf (scalar @Double 1))) + + describe "erfc" $ do + it "erfc 0 = 1" $ + evalf (ArrayFire.erfc (scalar @Double 0)) `shouldBeApprox` 1 + it "erf(x) + erfc(x) = 1" $ do + let x = scalar @Double 1.5 + (evalf (ArrayFire.erf x) + evalf (ArrayFire.erfc x)) `shouldBeApprox` 1 + + describe "sigmoid" $ do + it "sigmoid 0 = 0.5" $ + evalf (ArrayFire.sigmoid (scalar @Double 0)) `shouldBeApprox` 0.5 + it "sigmoid(-x) = 1 - sigmoid(x)" $ do + let x = scalar @Double 2.0 + evalf (ArrayFire.sigmoid (negate x)) + `shouldBeApprox` + (1 - evalf (ArrayFire.sigmoid x)) + + describe "expm1" $ do + it "expm1 0 = 0" $ + evalf (ArrayFire.expm1 (scalar @Double 0)) `shouldBeApprox` 0 + it "expm1 1 = e - 1" $ + evalf (ArrayFire.expm1 (scalar @Double 1)) `shouldBeApprox` (exp 1 - 1) + + describe "clamp (vector)" $ do + it "clamps each element to [lo, hi]" $ + clamp (vector @Int 5 [0,1,5,9,10]) + (scalar @Int 2) + (scalar @Int 8) + `shouldBe` vector @Int 5 [2,2,5,8,8] diff --git a/test/ArrayFire/ArraySpec.hs b/test/ArrayFire/ArraySpec.hs index 1452a00..641caa6 100644 --- a/test/ArrayFire/ArraySpec.hs +++ b/test/ArrayFire/ArraySpec.hs @@ -4,6 +4,7 @@ module ArrayFire.ArraySpec where import Control.Exception import Data.Complex +import qualified Data.Vector.Storable as V import Data.Word import Foreign.C.Types import GHC.Int @@ -14,17 +15,14 @@ import ArrayFire spec :: Spec spec = describe "Array tests" $ do - it "Should perform Array tests" $ do - (1 + 1) `shouldBe` 2 - it "Should fail to create 0 dimension arrays" $ do - let arr = mkArray @Int [0,0,0,0] [1..] - evaluate arr `shouldThrow` anyException - it "Should fail to create 0 length arrays" $ do - let arr = mkArray @Int [0,0,0,1] [] - evaluate arr `shouldThrow` anyException - it "Should fail to create 0 length arrays w/ 0 dimensions" $ do - let arr = mkArray @Int [0,0,0,0] [] - evaluate arr `shouldThrow` anyException + it "Should add two scalar arrays" $ do + (scalar @Int 1 + scalar @Int 1) `shouldBe` scalar @Int 2 + it "Should create a 0 dimension array" $ do + getElements (mkArray @Int [3,0,1,1] []) `shouldBe` 0 + it "Should create a 0 length array" $ do + getElements (mkArray @Int [0,0,0,1] []) `shouldBe` 0 + it "Should create a 0 length array w/ 0 dimensions" $ do + getElements (mkArray @Int [0,0,0,0] []) `shouldBe` 0 it "Should create a column vector" $ do let arr = mkArray @Int [9,1,1,1] (repeat 9) isColumn arr `shouldBe` True @@ -47,10 +45,10 @@ spec = it "Should return the number of elements" $ do let arr = mkArray @Int [9,9,1,1] [1..] getElements arr `shouldBe` 81 --- it "Should give an empty array" $ do --- let arr = mkArray @Int [-1,1,1,1] [] --- getElements arr `shouldBe` 0 --- isEmpty arr `shouldBe` True + it "Should give an empty array" $ do + let arr = mkArray @Int [0,1,1,1] [] + getElements arr `shouldBe` 0 + isEmpty arr `shouldBe` True it "Should create a scalar array" $ do let arr = mkArray @Int [1] [1] isScalar arr `shouldBe` True @@ -154,3 +152,41 @@ spec = let arr = mkArray @Word [10] [1..10] toList arr `shouldBe` [1..10] + + -- Regression: toVector previously allocated len*size bytes instead of size, + -- causing quadratic memory use. These round-trips verify correct element count + -- and values at sizes where the bug was most wasteful. + describe "toVector round-trip" $ do + it "preserves all elements for a 1000-element Double array" $ do + let xs = [1..1000] :: [Double] + arr = mkArray @Double [1000] xs + V.toList (toVector arr) `shouldBe` xs + it "preserves all elements for a 500-element Int array" $ do + let xs = [1..500] :: [Int] + arr = mkArray @Int [500] xs + V.toList (toVector arr) `shouldBe` xs + it "length of toVector matches getElements" $ do + let arr = mkArray @Double [7, 13] (repeat 0) + V.length (toVector arr) `shouldBe` getElements arr + + describe "fromVector" $ do + it "round-trips a Double vector" $ do + let xs = V.fromList [1..10 :: Double] + arr = fromVector @Double [10] xs + toVector arr `shouldBe` xs + it "round-trips an Int vector" $ do + let xs = V.fromList [1..100 :: Int] + arr = fromVector @Int [100] xs + toVector arr `shouldBe` xs + it "round-trips a Complex Double vector" $ do + let xs = V.fromList [1 :+ 2, 3 :+ 4 :: Complex Double] + arr = fromVector @(Complex Double) [2] xs + toVector arr `shouldBe` xs + it "produces the same result as mkArray" $ do + let xs = [1..25 :: Double] + arr1 = mkArray @Double [5,5] xs + arr2 = fromVector @Double [5,5] (V.fromList xs) + arr2 `shouldBe` arr1 + it "throws on dimension mismatch" $ do + let xs = V.fromList [1,2,3 :: Double] + evaluate (fromVector @Double [4] xs) `shouldThrow` anyException diff --git a/test/ArrayFire/BLASSpec.hs b/test/ArrayFire/BLASSpec.hs index 40cbbec..43664b3 100644 --- a/test/ArrayFire/BLASSpec.hs +++ b/test/ArrayFire/BLASSpec.hs @@ -14,22 +14,31 @@ spec = `shouldBe` matrix @Double (2,2) [[8,8],[8,8]] it "Should dot product two vectors" $ do dot (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - scalar @Double 8 + `shouldBe` scalar @Double 8 it "Should produce scalar dot product between two vectors as a Complex number" $ do dotAll (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - 8.0 :+ 0.0 + `shouldBe` 8.0 :+ 0.0 it "Should take the transpose of a matrix" $ do transpose (matrix @Double (2,2) [[1,1],[2,2]]) False - `shouldBe` - matrix @Double (2,2) [[1,2],[1,2]] + `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] it "Should take the transpose of a matrix in place" $ do + -- transposeInPlace is an IO () that mutates the underlying C buffer. + -- All Haskell references sharing the same ForeignPtr see the result. + -- Do not use the original binding after calling this. let m = matrix @Double (2,2) [[1,1],[2,2]] transposeInPlace m False m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] - - - - - + it "Should perform gemm: C = 1*A*B + 0*C (identity scaling)" $ do + let a = matrix @Double (2,2) [[1,2],[3,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm None None 1.0 a b 0.0 `shouldBe` a + it "Should perform gemm: C = alpha*A*B with alpha=2" $ do + -- b is column-major: col0=[3,4], col1=[5,6] → matrix [[3,5],[4,6]] + -- 2 * I * b = 2b → col0=[6,8], col1=[10,12] + let a = matrix @Double (2,2) [[1,0],[0,1]] + b = matrix @Double (2,2) [[3,4],[5,6]] + gemm None None 2.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[6,8],[10,12]] + it "Should perform gemm with transposed A: C = A^T * B" $ do + let a = matrix @Double (2,2) [[1,3],[2,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm Trans None 1.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[1,2],[3,4]] diff --git a/test/ArrayFire/DataSpec.hs b/test/ArrayFire/DataSpec.hs index fcbd53f..855e90e 100644 --- a/test/ArrayFire/DataSpec.hs +++ b/test/ArrayFire/DataSpec.hs @@ -2,14 +2,15 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.DataSpec where -import Control.Exception -import Data.Complex -import Data.Word -import Foreign.C.Types -import GHC.Int -import Test.Hspec +import Control.Exception +import Data.Complex +import Data.Word +import Foreign.C.Types +import GHC.Int +import Prelude hiding (flip) +import Test.Hspec -import ArrayFire +import ArrayFire spec :: Spec spec = @@ -32,6 +33,116 @@ spec = constant @(Complex Float) [1] (1.0 :+ 1.0) `shouldBe` constant @(Complex Float) [1] (1.0 :+ 1.0) + + describe "arange" $ do + it "generates a sequence along dim 0 for a 1D array" $ do + arange @Double [5] (-1) `shouldBe` vector @Double 5 [0,1,2,3,4] + it "generates a sequence along dim 1 for a 2D array" $ do + arange @Double [3,2] 1 `shouldBe` mkArray @Double [3,2] [0,0,0,1,1,1] + + describe "iota" $ do + it "generates a flat sequence without tiling" $ do + iota @Double [5] [] `shouldBe` vector @Double 5 [0,1,2,3,4] + it "tiles the sequence along dim 0" $ do + iota @Double [3] [2] `shouldBe` vector @Double 6 [0,1,2,0,1,2] + + describe "identity" $ do + it "creates a 2x2 identity matrix" $ do + identity @Double [2,2] + `shouldBe` mkArray @Double [2,2] [1,0,0,1] + it "creates a 3x3 identity matrix" $ do + identity @Double [3,3] + `shouldBe` mkArray @Double [3,3] [1,0,0,0,1,0,0,0,1] + + describe "diagCreate" $ do + it "creates a diagonal matrix from a vector (diag 0)" $ do + diagCreate (vector @Double 3 [1,2,3]) 0 + `shouldBe` mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3] + it "creates a superdiagonal matrix (diag 1)" $ do + diagCreate (vector @Double 2 [5,6]) 1 + `shouldBe` mkArray @Double [3,3] [0,0,0,5,0,0,0,6,0] + + describe "diagExtract" $ do + it "extracts the main diagonal of a square matrix" $ do + diagExtract (mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3]) 0 + `shouldBe` vector @Double 3 [1,2,3] + it "is the inverse of diagCreate on the main diagonal" $ do + let v = vector @Double 4 [1,2,3,4] + diagExtract (diagCreate v 0) 0 `shouldBe` v + + describe "lower" $ do + it "extracts the lower triangular part (unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + lower m True + `shouldBe` mkArray @Double [3,3] [1,2,3,0,1,6,0,0,1] + it "extracts the lower triangular part (non-unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + lower m False + `shouldBe` mkArray @Double [3,3] [1,2,3,0,5,6,0,0,9] + + describe "upper" $ do + it "extracts the upper triangular part (unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + upper m True + `shouldBe` mkArray @Double [3,3] [1,0,0,4,1,0,7,8,1] + it "extracts the upper triangular part (non-unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + upper m False + `shouldBe` mkArray @Double [3,3] [1,0,0,4,5,0,7,8,9] + + describe "tile" $ do + it "tiles a scalar into a 3x3 array" $ do + tile (scalar @Int 7) [3,3] + `shouldBe` constant @Int [3,3] 7 + it "tiles a row vector along dim 0" $ do + tile (mkArray @Int [1,3] [1,2,3]) [2,1] + `shouldBe` mkArray @Int [2,3] [1,1,2,2,3,3] + + describe "moddims" $ do + it "reshapes a vector into a matrix" $ do + moddims (vector @Int 6 [1..6]) [2,3] + `shouldBe` mkArray @Int [2,3] [1,2,3,4,5,6] + it "reshapes a matrix back to a vector" $ do + let v = vector @Int 6 [1..6] + moddims (moddims v [2,3]) [6] `shouldBe` v + + describe "flat" $ do + it "flattens a 2x3 matrix to a 6-element vector" $ do + flat (mkArray @Int [2,3] [1,2,3,4,5,6]) + `shouldBe` vector @Int 6 [1,2,3,4,5,6] + + describe "flip" $ do + it "reverses a vector (dim 0)" $ do + flip (vector @Int 4 [1,2,3,4]) 0 + `shouldBe` vector @Int 4 [4,3,2,1] + it "reverses columns of a matrix (dim 1)" $ do + flip (mkArray @Int [2,2] [1,2,3,4]) 1 + `shouldBe` mkArray @Int [2,2] [3,4,1,2] + + describe "shift" $ do + it "shifts a vector by 2 elements (wrapping)" $ do + shift (vector @Double 4 [1,2,3,4]) 2 0 0 0 + `shouldBe` vector @Double 4 [3,4,1,2] + + describe "select" $ do + it "selects elements from two arrays based on a boolean mask" $ do + let cond = vector @CBool 4 [1,0,1,0] + a = vector @Double 4 [10,20,30,40] + b = vector @Double 4 [1,2,3,4] + select cond a b `shouldBe` vector @Double 4 [10,2,30,4] + + describe "selectScalarR" $ do + it "uses scalar for false positions" $ do + let cond = vector @CBool 4 [1,0,1,0] + a = vector @Double 4 [10,20,30,40] + selectScalarR cond a 99 `shouldBe` vector @Double 4 [10,99,30,99] + + describe "selectScalarL" $ do + it "uses scalar for true positions" $ do + let cond = vector @CBool 4 [1,0,1,0] + b = vector @Double 4 [1,2,3,4] + selectScalarL cond 99 b `shouldBe` vector @Double 4 [99,2,99,4] + it "Should join Arrays along the specified dimension" $ do join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2] diff --git a/test/ArrayFire/IndexSpec.hs b/test/ArrayFire/IndexSpec.hs index d709317..8d31e1e 100644 --- a/test/ArrayFire/IndexSpec.hs +++ b/test/ArrayFire/IndexSpec.hs @@ -1,21 +1,111 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE TypeApplications #-} module ArrayFire.IndexSpec where -import qualified ArrayFire as A -import Control.Exception -import Data.Complex -import Data.Int -import Data.Proxy -import Data.Word -import Foreign.C.Types +import qualified ArrayFire as A +import Data.Function ((&)) import Test.Hspec spec :: Spec spec = - describe "Index spec" $ do - it "Should index into an array" $ do - let arr = A.vector @Int 10 [1..] - A.index arr [A.Seq 0 4 1] - `shouldBe` - A.vector @Int 5 [1..] + describe "Index" $ do + + describe "index" $ do + it "indexes a sub-range of a vector" $ do + A.index (A.vector @Int 10 [1..]) [A.Seq 0 4 1] + `shouldBe` A.vector @Int 5 [1..] + it "indexes every other element with step=2" $ do + A.index (A.vector @Int 6 [0,1,2,3,4,5]) [A.Seq 0 4 2] + `shouldBe` A.vector @Int 3 [0,2,4] + it "selects the full vector with afSpan" $ do + let arr = A.vector @Int 5 [1..] + A.index arr [A.afSpan] `shouldBe` arr + + describe "afSpan" $ do + it "equals Seq 1 1 0 (the ArrayFire span sentinel)" $ do + A.afSpan `shouldBe` A.Seq 1 1 0 + + describe "lookup" $ do + it "gathers elements by an index array" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + ixArr = A.vector @Int 3 [0, 2, 4] + A.lookup arr ixArr 0 + `shouldBe` A.vector @Double 3 [10, 30, 50] + it "allows repeated indices" $ do + let arr = A.vector @Int 5 [10, 20, 30, 40, 50] + ixArr = A.vector @Int 4 [0, 0, 4, 4] + A.lookup arr ixArr 0 + `shouldBe` A.vector @Int 4 [10, 10, 50, 50] + + describe "assignSeq" $ do + it "assigns into a middle slice of a vector" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + A.assignSeq arr [A.Seq 1 3 1] src + `shouldBe` A.vector @Double 5 [1, 0, 0, 0, 5] + it "assigns a single element" $ do + let arr = A.vector @Double 5 [1..] + src = A.scalar @Double 99 + A.assignSeq arr [A.Seq 2 2 1] src + `shouldBe` A.vector @Double 5 [1, 2, 99, 4, 5] + it "overwrites the full vector via afSpan" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 5 (repeat 0) + A.assignSeq arr [A.afSpan] src `shouldBe` src + + describe "indexGen" $ do + it "indexes a sub-range of a vector with seqIdx" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + A.indexGen arr [A.seqIdx (A.Seq 0 2 1) False] + `shouldBe` A.vector @Double 3 [10, 20, 30] + it "indexes a 2D sub-matrix with two seqIdx" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + A.indexGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]] + + describe "assignGen" $ do + it "assigns into a vector slice with seqIdx" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + result = A.assignGen arr [A.seqIdx (A.Seq 1 3 1) False] src + A.indexGen result [A.seqIdx (A.Seq 1 3 1) False] `shouldBe` src + it "assigns into a 2D sub-matrix with two seqIdx" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + src = A.matrix @Double (2,2) [[0,0],[0,0]] + result = A.assignGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] src + A.indexGen result [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` src + + describe "(!) operator" $ do + it "indexes a 1D sub-range with range" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + (arr A.! A.range 0 2) + `shouldBe` A.vector @Double 3 [10, 20, 30] + it "indexes a single element with at" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + (arr A.! A.at 2) + `shouldBe` A.scalar @Double 30 + it "indexes a 2D sub-matrix with a tuple" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + (arr A.! (A.range 0 1, A.range 0 1)) + `shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]] + + describe "(.~) operator" $ do + it "assigns into a 1D slice" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + result = arr & A.range 1 3 A..~ src + (result A.! A.range 1 3) `shouldBe` src + it "assigns into a 2D sub-matrix" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + src = A.matrix @Double (2,2) [[0,0],[0,0]] + result = arr & (A.range 0 1, A.range 0 1) A..~ src + (result A.! (A.range 0 1, A.range 0 1)) `shouldBe` src + + describe "rangeStep" $ do + it "selects every other element" $ do + let arr = A.vector @Double 6 [0,1,2,3,4,5] + (arr A.! A.rangeStep 0 4 2) + `shouldBe` A.vector @Double 3 [0,2,4] diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 5c225c7..355cda9 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -4,42 +4,93 @@ module ArrayFire.LAPACKSpec where import qualified ArrayFire as A import Prelude import Test.Hspec -import Test.Hspec.ApproxExpect +import Test.Hspec.ApproxExpect spec :: Spec spec = describe "LAPACK spec" $ do it "Should have LAPACK available" $ do A.isLAPACKAvailable `shouldBe` True + it "Should perform svd" $ do let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1) + it "Should perform svd in place" $ do let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1) + it "Should perform lu" $ do - let (s,v,d) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] - A.getDims s `shouldBe` (2,2,1,1) - A.getDims v `shouldBe` (2,2,1,1) - A.getDims d `shouldBe` (2,1,1,1) + let (l,u,piv) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] + A.getDims l `shouldBe` (2,2,1,1) + A.getDims u `shouldBe` (2,2,1,1) + A.getDims piv `shouldBe` (2,1,1,1) + it "Should perform qr" $ do - let (s,v,d) = A.lu $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] - A.getDims s `shouldBe` (3,3,1,1) - A.getDims v `shouldBe` (3,3,1,1) - A.getDims d `shouldBe` (3,1,1,1) - it "Should get determinant of Double" $ do - let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] - (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles) - x `shouldBeApprox` (-14) - let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] - x `shouldBeApprox` (-14) --- it "Should calculate inverse" $ do --- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] --- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]] --- it "Should calculate psuedo inverse" $ do --- let x = A.pinverse (A.matrix @Double (2,2) [[4,7],[2,6]]) 1.0 A.None --- x `shouldBe` A.matrix @Double (2,2) [[0.6,-0.2],[-0.7,0.4]] + let (q,r,tau) = A.qr $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] + A.getDims q `shouldBe` (3,3,1,1) + A.getDims r `shouldBe` (3,3,1,1) + A.getDims tau `shouldBe` (3,1,1,1) + + it "Should get determinant of a real matrix" $ do + let (re, _im) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] + re `shouldBeApprox` (-14) + + it "Should get determinant of a complex matrix" $ do + -- M = | 3+i 4+i | (column-major: col0=[3+i,8+i], col1=[4+i,6+i]) + -- | 8+i 6+i | + -- det = (3+i)(6+i) - (4+i)(8+i) = -14 - 3i + let (re, im) = A.det $ A.matrix @(A.Complex Double) (2,2) + [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] + re `shouldBeApprox` (-14) + im `shouldBeApprox` (-3) + + it "Should calculate inverse" $ do + -- M = | 4 2 | (column-major: col0=[4,7], col1=[2,6]) + -- | 7 6 | + -- M^-1 = (1/10) * | 6 -2 | = col0=[0.6,-0.7], col1=[-0.2,0.4] + -- | -7 4 | + let result = A.toList $ A.inverse (A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]]) A.None + expected = [0.6, -0.7, -0.2, 0.4] + mapM_ (uncurry shouldBeApprox) (zip result expected) + + it "Should find the rank of a matrix" $ do + A.rank (A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]) 1e-5 `shouldBe` 2 + A.rank (A.identity @Double [3,3]) 1e-5 `shouldBe` 3 + + it "Should compute the norm of a vector" $ do + -- || [3, 4] ||_2 = 5 + A.norm (A.vector @Double 2 [3,4]) A.NormVector2 1 1 `shouldBeApprox` 5 + -- || [3, 4] ||_1 = 7 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorOne 1 1 `shouldBeApprox` 7 + -- || [3, 4] ||_inf = 4 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorInf 1 1 `shouldBeApprox` 4 + + it "Should perform cholesky decomposition" $ do + -- A = | 4 2 | (column-major: [4,2,2,3]) + -- | 2 3 | + -- L = | 2 0 | where L*L^T = A + -- | 1 √2 | + let a = A.mkArray @Double [2,2] [4,2,2,3] + (status, l) = A.cholesky a False + status `shouldBe` 0 + let ls = A.toList @Double l + mapM_ (uncurry shouldBeApprox) (zip ls [2, 1, 0, sqrt 2]) + + it "choleskyInplace returns 0 for a symmetric positive definite matrix" $ do + let a = A.mkArray @Double [2,2] [4,2,2,3] + A.choleskyInplace a False `shouldBe` 0 + + it "Should solve Ax=b using solveLU" $ do + -- A = | 2 1 | b = | 5 | => x = | 2 | + -- | 1 3 | | 10| | 3 | + -- Column-major A: [2,1,1,3], b: [5,10] + let a = A.mkArray @Double [2,2] [2,1,1,3] + b = A.vector @Double 2 [5,10] + piv = A.luInPlace a True + x = A.solveLU a piv b A.None + mapM_ (uncurry shouldBeApprox) (zip (A.toList @Double x) [1,3]) diff --git a/test/ArrayFire/NumericalSpec.hs b/test/ArrayFire/NumericalSpec.hs new file mode 100644 index 0000000..fac01c8 --- /dev/null +++ b/test/ArrayFire/NumericalSpec.hs @@ -0,0 +1,118 @@ +{-# LANGUAGE TypeApplications #-} +-- | Numerical algorithm tests that exercise broad API surface area. +-- Each test has a known exact answer derived from mathematics, so failures +-- indicate either a bug in the library or a precision regression. +module ArrayFire.NumericalSpec where + +import qualified ArrayFire as A +import Data.Function ((&)) +import Test.Hspec + +tol :: Double +tol = 1e-4 + +shouldBeApprox :: Double -> Double -> Expectation +shouldBeApprox x y = abs (x - y) < tol `shouldBe` True + +spec :: Spec +spec = describe "Numerical algorithms" $ do + + -- ∫₀^π sin(x) dx = 2 (midpoint rectangle rule) + -- Exercises: arange, sin, sumAll, scalar, *, + + describe "Rectangle-rule integration" $ do + it "approximates integral of sin over [0,pi] = 2" $ do + let n = 10000 :: Int + h = pi / fromIntegral n + is = A.arange @Double [n] (-1) -- [0,1,...,n-1] + xs = (is + A.scalar 0.5) * A.scalar h -- midpoints + result = h * fst (A.sumAll (sin xs)) + result `shouldBeApprox` 2.0 + + -- Power iteration on A = [[2,1],[1,2]] + -- Exact dominant eigenvalue = 3, eigenvector = [1,1]/√2 + -- Exercises: matrix, matmul, sumAll, *, /, scalar, sqrt, Haskell iterate + describe "Power iteration" $ do + it "converges to dominant eigenvalue 3 of [[2,1],[1,2]]" $ do + let a = A.matrix @Double (2,2) [[2,1],[1,2]] + v0 = A.matrix @Double (2,1) [[1,1]] + norm2 v = sqrt . fst $ A.sumAll (v * v) + norm v = v / A.scalar (norm2 v) + step v = norm (A.matmul a v A.None A.None) + vFinal = iterate step (norm v0) !! 30 + av = A.matmul a vFinal A.None A.None + -- Rayleigh quotient: v^T A v + lambda = fst $ A.sumAll (vFinal * av) + lambda `shouldBeApprox` 3.0 + + -- Geometric series: Σ(k=0..19) 0.5^k = (1 - 0.5^20)/(1 - 0.5) + -- Exercises: arange, (**), sumAll, scalar + describe "Geometric series" $ do + it "sum of 0.5^k for k=0..19 matches closed form" $ do + let n = 20 :: Int + ks = A.arange @Double [n] (-1) + terms = A.scalar 0.5 ** ks + result = fst (A.sumAll terms) + expected = (1.0 - 0.5 ^ n) / (1.0 - 0.5) + result `shouldBeApprox` expected + + -- Centered-difference moving average on u = [1..10]: + -- avg_i = (u[i-1] + u[i+1]) / 2 for i = 1..8 + -- For an arithmetic sequence, this equals u[i] exactly. + -- Exercises: vector, (!), range, +, /, scalar + describe "Slice-based centered differences" $ do + it "moving average of arithmetic sequence equals interior values" $ do + let u = A.vector @Double 10 [1..10] + avg = (u A.! A.range 0 7 + u A.! A.range 2 9) / A.scalar 2.0 + avg `shouldBe` u A.! A.range 1 8 + + -- Slice assignment: overwrite interior of a zero vector. + -- Exercises: vector, &, (.~), !, range, toList + describe "Slice assignment" $ do + it "(.~) writes src into interior slice, leaves boundaries unchanged" $ do + let u = A.vector @Double 6 (repeat 0.0) + src = A.vector @Double 4 [1,2,3,4] + result = u & A.range 1 4 A..~ src + A.toList result `shouldBe` [0,1,2,3,4,0] + + -- Sample statistics of [1..100]. + -- mean([1..100]) = 50.5 (exact by Gauss's formula) + -- sum = n * mean must hold exactly. + -- Exercises: vector, meanAll, sumAll + describe "Statistical identities" $ do + it "mean of [1..100] = 50.5" $ do + let (m, _) = A.meanAll (A.vector @Double 100 [1..100]) + m `shouldBeApprox` 50.5 + it "sumAll = n * meanAll" $ do + let arr = A.vector @Double 100 [1..100] + (m, _) = A.meanAll arr + (s, _) = A.sumAll arr + s `shouldBeApprox` (100 * m) + it "variance of a constant array is 0" $ do + let (v, _) = A.varAll (A.vector @Double 50 (repeat 7.0)) False + v `shouldBeApprox` 0.0 + + -- Sum of first n squares: Σ(k=1..n) k² = n(n+1)(2n+1)/6 + -- Exercises: iota, *, +, scalar, sumAll + describe "Sum of squares" $ do + it "Sigma k^2 for k=1..100 matches closed form n(n+1)(2n+1)/6" $ do + let n = 100 :: Int + ks = A.iota @Double [n] [] + A.scalar 1.0 -- [1,2,...,n] + result = fst $ A.sumAll (ks * ks) + expected = fromIntegral (n * (n+1) * (2*n+1)) / 6.0 + result `shouldBeApprox` expected + + -- Parseval's theorem: ||x||² = (1/N)||X||² where X = FFT(x) + -- Uses a complex Dirac delta: |x|² = 1, FFT is a flat spectrum |X[k]|² = 1 each. + -- Exercises: mkArray, fft, conjg, real, sumAll, * + describe "Parseval's theorem" $ do + it "time-domain and frequency-domain energies agree" $ do + let n = 64 :: Int + -- Dirac delta: all energy in first sample + xs = A.mkArray @(A.Complex Double) [n] (1 : repeat 0) + -- time-domain energy: Σ |x[k]|² = 1 + tEnergy = fst $ A.sumAll (A.real (xs * A.conjg xs) :: A.Array Double) + -- frequency-domain energy: (1/N) Σ |X[k]|² = (1/N)*N = 1 + xf = A.fft xs 1.0 n + fEnergy = (1.0 / fromIntegral n) * fst (A.sumAll (A.real (xf * A.conjg xf) :: A.Array Double)) + tEnergy `shouldBeApprox` 1.0 + tEnergy `shouldBeApprox` fEnergy diff --git a/test/ArrayFire/SignalSpec.hs b/test/ArrayFire/SignalSpec.hs index 06b890e..4a043e6 100644 --- a/test/ArrayFire/SignalSpec.hs +++ b/test/ArrayFire/SignalSpec.hs @@ -2,19 +2,68 @@ module ArrayFire.SignalSpec where import qualified ArrayFire as A -import Data.Int -import Data.Word import Data.Complex -import Data.Proxy -import Foreign.C.Types import Test.Hspec +-- | Check all elements of two Complex Double arrays are within tolerance. +shouldBeApproxC + :: A.Array (Complex Double) + -> A.Array (Complex Double) + -> Expectation +shouldBeApproxC actual expected = + zipWith (\a e -> magnitude (a - e)) + (A.toList @(Complex Double) actual) + (A.toList @(Complex Double) expected) + `shouldSatisfy` all (< 1e-10) + spec :: Spec spec = - describe "Signal spec" $ do - it "Should do FFT in place" $ do - A.fftInPlace (A.matrix @(Complex Double) (1,1) [[1 :+ 1]]) 10.2 - `shouldReturn` () - it "Should do FFT" $ do - A.fft (A.matrix @(Complex Float) (1,1) [[1 :+ 1]]) 1 1 - `shouldBe` A.matrix @(Complex Float) (1,1) [[1 :+ 1]] + describe "Signal" $ do + + describe "fft" $ do + it "fftInPlace runs without error" $ do + A.fftInPlace (A.scalar @(Complex Double) (1 :+ 0)) 1.0 + `shouldReturn` () + + it "transform of a Dirac delta is a flat spectrum" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,0,0,0]) 1.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [1,1,1,1] + + it "transform of all-ones concentrates all energy at DC" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,1,1,1]) 1.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [4,0,0,0] + + it "normalization factor scales the output" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,0,0,0]) 2.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [2,2,2,2] + + it "ifft . fft is the identity" $ do + let n = 8 + input = A.mkArray @(Complex Double) [n] (map (:+ 0) [1..8]) + A.ifft (A.fft input 1.0 n) (1.0 / fromIntegral n) n + `shouldBeApproxC` input + + it "fft output_size pads with zeros when larger than input" $ do + -- 4-point FFT of a 2-point signal padded to 4: input [1,1,0,0] + A.fft (A.mkArray @(Complex Double) [2] [1,1]) 1.0 4 + `shouldBeApproxC` + A.fft (A.mkArray @(Complex Double) [4] [1,1,0,0]) 1.0 4 + + describe "fft2" $ do + it "2D transform of a Dirac delta is a flat spectrum" $ do + A.fft2 (A.mkArray @(Complex Double) [4,4] (1 : replicate 15 0)) 1.0 4 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4,4] (replicate 16 1) + + it "ifft2 . fft2 is the identity" $ do + let input = A.mkArray @(Complex Double) [4,4] (map (:+ 0) [1..16]) + A.ifft2 (A.fft2 input 1.0 4 4) (1.0 / 16) 4 4 + `shouldBeApproxC` input + + it "2D transform of all-ones concentrates all energy at DC" $ do + A.fft2 (A.mkArray @(Complex Double) [4,4] (replicate 16 1)) 1.0 4 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4,4] (16 : replicate 15 0) diff --git a/test/ArrayFire/SparseSpec.hs b/test/ArrayFire/SparseSpec.hs index b90c931..a16569a 100644 --- a/test/ArrayFire/SparseSpec.hs +++ b/test/ArrayFire/SparseSpec.hs @@ -1,19 +1,70 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.SparseSpec where -import qualified ArrayFire as A +import qualified ArrayFire as A import Data.Int -import Data.Word -import Data.Complex -import Data.Proxy -import Foreign.C.Types import Test.Hspec +-- 3×3 diagonal matrix diag(1,2,3), stored column-major: +-- col0=[1,0,0], col1=[0,2,0], col2=[0,0,3] +diag3 :: A.Array Double +diag3 = A.mkArray @Double [3,3] [1,0,0, 0,2,0, 0,0,3] + spec :: Spec spec = - describe "Sparse spec" $ do - it "Should create a sparse array" $ do - (1+1) `shouldBe` 2 - -- A.createSparseArrayFromDense (A.matrix @Double (10,10) [1..]) A.CSR - -- `shouldBe` - -- A.vector @Double 10 [0..] + describe "Sparse" $ do + + describe "createSparseArrayFromDense" $ do + it "NNZ equals number of non-zero elements" $ do + A.sparseGetNNZ (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` 3 + it "all-zero matrix has NNZ 0" $ do + let zeros = A.mkArray @Double [3,3] (repeat 0) + A.sparseGetNNZ (A.createSparseArrayFromDense zeros A.CSR) `shouldBe` 0 + it "fully-dense matrix has NNZ equal to element count" $ do + let full = A.mkArray @Double [2,2] [1,2,3,4] + A.sparseGetNNZ (A.createSparseArrayFromDense full A.CSR) `shouldBe` 4 + it "storage format is preserved" $ do + A.sparseGetStorage (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` A.CSR + it "COO storage format is preserved" $ do + A.sparseGetStorage (A.createSparseArrayFromDense diag3 A.COO) `shouldBe` A.COO + + describe "sparseToDense" $ do + it "CSR round-trip preserves all values" $ do + A.sparseToDense (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` diag3 + it "COO round-trip preserves all values" $ do + A.sparseToDense (A.createSparseArrayFromDense diag3 A.COO) `shouldBe` diag3 + + describe "sparseConvertTo" $ do + it "CSR → COO preserves NNZ" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseGetNNZ coo `shouldBe` 3 + it "CSR → COO storage tag changes" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseGetStorage coo `shouldBe` A.COO + it "CSR → COO → Dense recovers original matrix" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseToDense coo `shouldBe` diag3 + + describe "sparseGetValues" $ do + it "diagonal matrix CSR values are the diagonal entries in row order" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.sparseGetValues sp `shouldBe` A.vector @Double 3 [1,2,3] + + describe "sparseGetRowIdx / sparseGetColIdx" $ do + -- The underlying arrays are s32; we check length, not raw values. + it "CSR row pointer array has nrows+1 elements" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.getElements (A.sparseGetRowIdx sp) `shouldBe` 4 + it "CSR column index array has NNZ elements" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.getElements (A.sparseGetColIdx sp) `shouldBe` 3 + + describe "sparseGetInfo" $ do + it "values component matches sparseGetValues" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + (vals, _, _, _) = A.sparseGetInfo sp + vals `shouldBe` A.sparseGetValues sp + it "storage tag matches sparseGetStorage" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + (_, _, _, storage) = A.sparseGetInfo sp + storage `shouldBe` A.sparseGetStorage sp diff --git a/test/ArrayFire/StatisticsSpec.hs b/test/ArrayFire/StatisticsSpec.hs index c8c6314..50c7bd8 100644 --- a/test/ArrayFire/StatisticsSpec.hs +++ b/test/ArrayFire/StatisticsSpec.hs @@ -1,8 +1,10 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.StatisticsSpec where +import Data.Word (Word32) import ArrayFire hiding (not) +import Data.Maybe import Data.Complex import Test.Hspec import Test.Hspec.ApproxExpect @@ -15,9 +17,9 @@ spec = `shouldBe` 5.5 it "Should find the weighted-mean" $ do - meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 - `shouldBeApprox` - 7.0 + listToMaybe (toList (meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0)) + `shouldBe` + (Just 7.0) it "Should find the variance" $ do var (vector @Double 8 [1..8]) False 0 `shouldBe` @@ -69,4 +71,20 @@ spec = it "Should find the top k elements" $ do let (vals,indexes) = topk ( vector @Double 10 [1..] ) 3 TopKDefault vals `shouldBe` vector @Double 3 [10,9,8] - indexes `shouldBe` vector @Double 3 [9,8,7] + indexes `shouldBe` vector @Word32 3 [9,8,7] + it "Should compute mean and variance together (population)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 + it "Should compute mean and variance together (sample)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VarianceSample 0 + m `shouldBe` scalar @Double 2.5 + -- sample variance of [1,2,3,4] = 5/3 ≈ 1.6667 + case listToMaybe (toList v) of + Just k -> k `shouldBeApprox` (5.0/3.0) + _ -> error "failure" + it "Should compute weighted mean and variance together" $ do + let uniform = vector @Double 4 (repeat 1.0) + (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) uniform VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 diff --git a/test/Main.hs b/test/Main.hs index c949527..598f042 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,9 +1,8 @@ -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE GeneralisedNewtypeDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Main where -import Control.Monad - import Data.Proxy import Spec (spec) import Test.Hspec (hspec) @@ -13,32 +12,76 @@ import Test.QuickCheck.Classes import qualified ArrayFire as A import ArrayFire (Array) -import System.IO.Unsafe +import Foreign.C.Types (CBool (..)) +-- Multi-dimensional arrays: used for eqLaws, so the Eq instance is exercised +-- on matrices and tensors, not just scalars. instance (A.AFType a, Arbitrary a) => Arbitrary (Array a) where - arbitrary = pure $ unsafePerformIO (A.randu [2,2]) + arbitrary = do + ndim <- choose (1, 4) + dims <- vectorOf ndim (choose (1, 4)) + elems <- vectorOf (product dims) arbitrary + pure (A.mkArray dims elems) + shrink arr = + [ A.mkArray dims' (take (product dims') (A.toList arr)) + | dims' <- shrunkDims + , product dims' > 0 + ] + where + (d0, d1, d2, d3) = A.getDims arr + ndim = A.getNumDims arr + currentDims = take ndim [d0, d1, d2, d3] + shrunkDims = + [ [if i == j then d - 1 else d | (j, d) <- zip [0..] currentDims] + | i <- [0 .. ndim - 1] + , currentDims !! i > 1 + ] + ++ [take (ndim - 1) currentDims | ndim > 1] + +-- Scalar wrapper for numLaws. +-- Num laws require: (a) binary ops succeed for any two generated values, and +-- (b) `fromInteger 0` compares equal to `0 * x`. Both hold only when all +-- arrays are the same shape. Scalars ([1 1 1 1]) are the minimal fixed shape +-- that makes every Num law well-typed and exact for integer element types. +newtype Scalar a = Scalar (Array a) + deriving (Show, Eq, Num) + +instance Arbitrary CBool where + arbitrary = CBool <$> arbitrary + +instance (A.AFType a, Arbitrary a) => Arbitrary (Scalar a) where + arbitrary = Scalar . A.scalar <$> arbitrary + shrink (Scalar arr) = Scalar . A.scalar <$> case A.toList arr of + x : _ -> shrink x + [] -> [] main :: IO () main = do - A.setBackend A.CPU --- checks (Proxy :: Proxy (A.Array (A.Complex Float))) --- checks (Proxy :: Proxy (A.Array (A.Complex Double))) --- checks (Proxy :: Proxy (A.Array Double)) --- checks (Proxy :: Proxy (A.Array Float)) --- checks (Proxy :: Proxy (A.Array Double)) --- checks (Proxy :: Proxy (A.Array A.Int16)) --- checks (Proxy :: Proxy (A.Array A.Int32)) - -- checks (Proxy :: Proxy (A.Array A.CBool)) - -- checks (Proxy :: Proxy (A.Array Word)) - -- checks (Proxy :: Proxy (A.Array A.Word8)) - -- checks (Proxy :: Proxy (A.Array A.Word16)) - -- checks (Proxy :: Proxy (A.Array A.Word32)) --- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Double)) --- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Float)) hspec spec + -- IEEE 754 is not an exact ring; only Eq laws for floating-point arrays. + lawsCheck (eqLaws (Proxy :: Proxy (Array Double))) + lawsCheck (eqLaws (Proxy :: Proxy (Array Float))) + lawsCheck (showLaws (Proxy :: Proxy (Array Float))) + lawsCheck (showLaws (Proxy :: Proxy (Array Double))) + -- Complex: Eq only (IEEE 754 + gt/lt undefined for complex numbers). + lawsCheck (eqLaws (Proxy :: Proxy (Array (A.Complex Double)))) + lawsCheck (eqLaws (Proxy :: Proxy (Array (A.Complex Float)))) + lawsCheck (showLaws (Proxy :: Proxy (Array (A.Complex Double)))) + lawsCheck (showLaws (Proxy :: Proxy (Array (A.Complex Float)))) + -- Integral types: exact ring laws via Scalar, Eq laws via multi-dim Array. + intChecks (Proxy :: Proxy Int) + intChecks (Proxy :: Proxy A.Int16) + intChecks (Proxy :: Proxy A.Int32) + intChecks (Proxy :: Proxy A.Int64) + intChecks (Proxy :: Proxy A.Word8) + intChecks (Proxy :: Proxy A.Word16) + intChecks (Proxy :: Proxy A.Word32) + intChecks (Proxy :: Proxy A.Word64) + intChecks (Proxy :: Proxy Word) + intChecks (Proxy :: Proxy A.CBool) -checks proxy = do - lawsCheck (numLaws proxy) - lawsCheck (eqLaws proxy) - lawsCheck (ordLaws proxy) --- lawsCheck (semigroupLaws proxy) +intChecks :: forall a. (A.AFType a, Arbitrary a, Num a, Eq a) => Proxy a -> IO () +intChecks _ = do + lawsCheck (showLaws (Proxy :: Proxy (Array a))) + lawsCheck (numLaws (Proxy :: Proxy (Scalar a))) + lawsCheck (eqLaws (Proxy :: Proxy (Array a))) diff --git a/test/Test/Hspec/ApproxExpect.hs b/test/Test/Hspec/ApproxExpect.hs index 3e9d66b..e1830a9 100644 --- a/test/Test/Hspec/ApproxExpect.hs +++ b/test/Test/Hspec/ApproxExpect.hs @@ -1,19 +1,22 @@ -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ScopedTypeVariables #-} module Test.Hspec.ApproxExpect where import Data.CallStack (HasCallStack) - import Test.Hspec (shouldSatisfy, Expectation) infix 1 `shouldBeApprox` -shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a) - => a -> a -> Expectation -shouldBeApprox actual tgt - -- This is a hackish way of checking, without requiring a specific - -- type or an 'Ord' instance, whether two floating-point values - -- are only some epsilons apart: when the difference is small enough - -- so scaling it down some more makes it a no-op for addition. - = actual `shouldSatisfy` \x -> (x-tgt) * 1e-4 + tgt == tgt - +-- | Assert two floating-point values are within relative + absolute tolerance. +-- +-- Uses the same formula as numpy.testing.assert_allclose: +-- |a - b| <= atol + rtol * max(|a|, |b|) +-- with rtol = 1e-5 and atol = 1e-8, matching numpy defaults. +shouldBeApprox + :: (HasCallStack, Show a, Ord a, Fractional a) + => a -> a -> Expectation +shouldBeApprox actual expected = + actual `shouldSatisfy` \x -> + abs (x - expected) <= atol + rtol * max (abs x) (abs expected) + where + rtol = 1e-5 + atol = 1e-8