{-# LANGUAGE TypeFamilies #-}

-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts.  We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram.  Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are also branches that efficiently handle the case where we
-- use only a single subhistogram (because it is large), so that we
-- respect the asymptotics and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- shared memory when possible.  Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: block size
--
-- T: number of bytes of shared memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of shared memory per threadblock (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of shared memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use shared memory
-- else
--   use global memory
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where

import Control.Monad
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Transform.Substitute
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The intermediate subhistogram array created for one destination
-- array of a histogram operator.
data SubhistosInfo = SubhistosInfo
  { -- | Name of the array holding the subhistograms.
    subhistosArray :: VName,
    -- | Action that sets up the memory of 'subhistosArray'.  When only
    -- one subhistogram is used it aliases the destination's memory;
    -- otherwise it allocates fresh memory, fills it with the neutral
    -- element, and copies the destination into the first subhistogram.
    subhistosAlloc :: CallKernelGen ()
  }

-- | Everything needed to compile one histogram operator of a 'SegHist'.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Variable holding the number of subhistograms per result
    -- histogram; it is assigned once that number has been decided.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of the operator.
    slugSubhistos :: [SubhistosInfo],
    -- | How a single histogram bin is updated atomically.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }

-- | The number of bytes occupied by one full set of histograms for this
-- operator: one array per result of the combining operator, each with
-- 'histShape' followed by 'histOpShape' as its outer dimensions.
-- Segment dimensions are not included.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum [typeSize (t `arrayOfShape` shape) | t <- lambdaReturnType (histOp op)]
  where
    shape = histShape op <> histOpShape op

-- | The number of bins in a single histogram of this operator, i.e. the
-- product of the histogram dimensions.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape

-- | The number of histogram dimensions of this operator.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- Returns, in order:
--
-- * the size in bytes of the histograms for a single segment
--   (see 'histSpaceUsage');
--
-- * that size multiplied by the product of the segment dimensions;
--
-- * a 'SegHistSlug' whose 'slugNumSubhistos' is a freshly declared
--   variable that the caller is expected to assign.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All but the innermost dimension of the space are segment
  -- dimensions.  NOTE(review): 'init' assumes a non-empty space.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  num_subhistos <- dPrim "num_subhistos"
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Segment dimensions, then one dimension indexing the
    -- subhistograms, then the remaining dimensions of the destination.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        $ LMAD.iota 0
        $ map pe64
        $ shapeDims subhistos_shape

    pure $
      SubhistosInfo subhistos $ do
        -- A single subhistogram: reuse the destination's memory
        -- directly instead of allocating and copying.
        let unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            -- Several subhistograms: allocate fresh memory, fill it with
            -- the neutral element, and copy the prior contents of the
            -- destination into the first subhistogram.
            multiHistoCase = do
              let num_elems = product $ map pe64 $ shapeDims subhistos_shape
                  -- 'Imp.withElemType' already produces a byte count.
                  subhistos_mem_size =
                    Imp.elements num_elems `Imp.withElemType` elemType dest_t

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histSpaceUsage op
      segmented_h =
        h * product (map (Imp.bytes . pe64) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $
          histOp op
    )

-- | Prepare the atomic update function for histograms that live in
-- global memory.
--
-- If the operator can only be implemented with locking and no lock
-- array is passed in, a fresh zero-initialised lock array is created
-- and returned; otherwise the given lock array (if any) is returned
-- unchanged.  The second component performs the update at a given
-- bin index.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  Shape ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l segments dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f global dests)
    (_, AtomicCAS f) -> pure (l, f global dests)
    (Just l', AtomicLocking f) -> pure (l, f l' global dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now.  Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks :: Int
          num_locks = 100151
          dims =
            map pe64 $
              shapeDims segments
                ++ shapeDims (histOpShape (slugOp slug))
                ++ [tvSize (slugNumSubhistos slug)]
                ++ shapeDims (histShape (slugOp slug))

      locks <- genZeroes "hist_locks" num_locks
      -- A bin index is mapped to a lock by flattening it and reducing
      -- modulo the number of locks.  NOTE(review): the three Int32
      -- arguments are presumably the unlocked/locked/unlock values;
      -- confirm against the 'Locking' definition.
      let l' =
            Locking locks 0 1 0 $
              pure . (`rem` fromIntegral num_locks) . flattenIndex dims
      pure (Just l', f l' global dests)
  where
    global = Space "global"

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.
data Passage
  = -- | The body must be executed exactly once.
    MustBeSinglePass
  | -- | The body may be executed several times.
    MayBeMultiPass
  deriving (Eq, Ord)

-- | Decide whether a kernel body may be executed more than once.
-- A body that consumes nothing (according to alias analysis) may be
-- run in several passes; any consumption forces a single pass.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass

prepareIntermediateArraysGlobal ::
  Passage ->
  Shape ->
  Imp.TExp Int32 ->
  Imp.TExp Int64 ->
  [SegHistSlug] ->
  CallKernelGen
    ( Imp.TExp Int32,
      [[Imp.TExp Int64] -> InKernelGen ()]
    )
prepareIntermediateArraysGlobal :: Passage
-> Shape
-> TExp Int32
-> TExp Int64
-> [SegHistSlug]
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage Shape
segments TExp Int32
hist_T TExp Int64
hist_N [SegHistSlug]
slugs = do
  -- The paper formulae assume there is only one histogram, but in our
  -- implementation there can be multiple that have been horisontally
  -- fused.  We do a bit of trickery with summings and averages to
  -- pretend there is really only one.  For the case of a single
  -- histogram, the actual calculations should be the same as in the
  -- paper.

  -- The sum of all Hs.
  hist_H <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TExp Int64
histSize (HistOp GPUMem -> TExp Int64)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs

  hist_RF <-
    dPrimVE "hist_RF" $
      sum (map (r64 . pe64 . histRaceFactor . slugOp) slugs)
        / L.genericLength slugs

  hist_el_size <- dPrimVE "hist_el_size" $ sum $ map slugElAvgSize slugs

  hist_C_max <-
    dPrimVE "hist_C_max" $
      fMin64 (r64 hist_T) $
        r64 hist_H / hist_k_ct_min

  hist_M_min <-
    dPrimVE "hist_M_min" $
      sMax32 1 $
        sExt32 $
          t64 $
            r64 hist_T / hist_C_max

  -- Equivalent to F_L2*L2 in paper.
  hist_L2 <- getSize "hist_L2" Imp.SizeCache

  let hist_L2_ln_sz = TExp Double
16 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
4 -- L2 cache line size approximation
  hist_RACE_exp <-
    dPrimVE "hist_RACE_exp" $
      fMax64 1 $
        (hist_k_RF * hist_RF)
          / (hist_L2_ln_sz / r64 hist_el_size)

  hist_S <- dPrim "hist_S"

  -- For sparse histograms (H exceeds N) we only want a single chunk.
  sIf
    (hist_N .<. hist_H)
    (hist_S <-- (1 :: Imp.TExp Int32))
    $ hist_S
      <-- case passage of
        Passage
MayBeMultiPass ->
          TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
            (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
              TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L2) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
        Passage
MustBeSinglePass ->
          TExp Int32
1

  emit $ Imp.DebugPrint "Race expansion factor (RACE^exp)" $ Just $ untyped hist_RACE_exp
  emit $ Imp.DebugPrint "Number of chunks (S)" $ Just $ untyped $ tvExp hist_S

  histograms <-
    snd
      <$> mapAccumLM
        (onOp (tvExp hist_L2) hist_M_min (tvExp hist_S) hist_RACE_exp)
        Nothing
        slugs

  pure (tvExp hist_S, histograms)
  where
    hist_k_ct_min :: TExp Double
hist_k_ct_min = TExp Double
2 -- Chosen experimentally
    hist_k_RF :: TExp Double
hist_k_RF = TExp Double
0.75 -- Chosen experimentally
    hist_F_L2 :: TExp Double
hist_F_L2 = TExp Double
0.4 -- Chosen experimentally
    r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
    t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped

    -- "Average element size" as computed by a formula that also takes
    -- locking into account.
    slugElAvgSize :: SegHistSlug -> TExp Int32
slugElAvgSize slug :: SegHistSlug
slug@(SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
      case AtomicUpdate GPUMem KernelEnv
do_op of
        AtomicLocking {} ->
          SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ [Type] -> TExp Int32
forall i a. Num i => [a] -> i
L.genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)))
        AtomicUpdate GPUMem KernelEnv
_ ->
          SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` [Type] -> TExp Int32
forall i a. Num i => [a] -> i
L.genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op))

    -- "Average element size" as computed by a formula that also takes
    -- locking into account.
    slugElSize :: SegHistSlug -> TExp Int32
slugElSize (SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
      TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32)
-> ([Count Bytes (TExp Int64)] -> TExp Int64)
-> [Count Bytes (TExp Int64)]
-> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Count Bytes (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> ([Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64))
-> [Count Bytes (TExp Int64)]
-> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TExp Int64)] -> TExp Int32)
-> [Count Bytes (TExp Int64)] -> TExp Int32
forall a b. (a -> b) -> a -> b
$
        case AtomicUpdate GPUMem KernelEnv
do_op of
          AtomicLocking {} ->
            (Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
              PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)
          AtomicUpdate GPUMem KernelEnv
_ ->
            (Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
              Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)

    onOp :: TExp Int64
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp TExp Int64
hist_L2 TExp Int32
hist_M_min TExp Int32
hist_S TExp Double
hist_RACE_exp Maybe Locking
l SegHistSlug
slug = do
      let SegHistSlug HistOp GPUMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate GPUMem KernelEnv
do_op = SegHistSlug
slug
          hist_H :: TExp Int64
hist_H = HistOp GPUMem -> TExp Int64
histSize HistOp GPUMem
op

      hist_H_chk <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

      emit $ Imp.DebugPrint "Chunk size (H_chk)" $ Just $ untyped hist_H_chk

      hist_k_max <-
        dPrimVE "hist_k_max" $
          fMin64
            (hist_F_L2 * (r64 hist_L2 / r64 (slugElSize slug)) * hist_RACE_exp)
            (r64 hist_N)
            / r64 hist_T

      hist_u <- dPrimVE "hist_u" $
        case do_op of
          AtomicPrim {} -> TExp Int64
2
          AtomicUpdate GPUMem KernelEnv
_ -> TExp Int64
1

      hist_C <-
        dPrimVE "hist_C" $
          fMin64 (r64 hist_T) $
            r64 (hist_u * hist_H_chk) / hist_k_max

      -- Number of subhistograms per result histogram.
      hist_M <- dPrimVE "hist_M" $
        case slugAtomicUpdate slug of
          AtomicPrim {} -> TExp Int32
1
          AtomicUpdate GPUMem KernelEnv
_ -> TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
hist_M_min (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C

      emit $ Imp.DebugPrint "Elements/thread in L2 cache (k_max)" $ Just $ untyped hist_k_max
      emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just $ untyped hist_M
      emit $ Imp.DebugPrint "Cooperation level (C)" $ Just $ untyped hist_C

      -- num_subhistos is the variable we use to communicate back.
      num_subhistos <-- sExt64 hist_M

      -- Initialise sub-histograms.
      --
      -- If hist_M is 1, then we just reuse the original
      -- destination.  The idea is to avoid a copy if we are writing a
      -- small number of values into a very large prior histogram.
      dests <- forM (zip (histDest op) subhisto_info) $ \(VName
dest, SubhistosInfo
info) -> do
        dest_mem <- ArrayEntry -> MemLoc
entryArrayLoc (ArrayEntry -> MemLoc)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp MemLoc
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
dest

        sub_mem <-
          fmap memLocName $
            entryArrayLoc
              <$> lookupArray (subhistosArray info)

        let unitHistoCase =
              Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
                VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
sub_mem (MemLoc -> VName
memLocName MemLoc
dest_mem) (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$
                  [Char] -> Space
Space [Char]
"device"

            multiHistoCase = SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info

        sIf (hist_M .==. 1) unitHistoCase multiHistoCase

        pure $ subhistosArray info

      (l', do_op') <- prepareAtomicUpdateGlobal l segments dests slug

      pure (l', do_op')

-- | Launch a single @seghist_global@ kernel that handles chunk @chk_i@
-- of every histogram in @slugs@.
--
-- Each histogram of size @w@ is split into @hist_S@ chunks of
-- @w \`divUp\` hist_S@ buckets.  This pass only performs updates whose
-- flattened bucket index lies in @[chk_i * H_chk, (chk_i + 1) * H_chk)@.
-- For each flat offset handed out by 'kernelLoop', a thread evaluates
-- @kbody@, writes the map-out results to @map_pes@, and then applies the
-- matching update function from @histograms@ for every histogram whose
-- bucket is in range.
--
-- Note that the kernel body, including the map-out writes, is executed
-- regardless of @chk_i@, i.e. once per pass.
histKernelGlobalPass ::
  -- | Map-out pattern elements; they receive the trailing results of @kbody@.
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  -- | One update function per slug, applied to full destination indexes.
  [[Imp.TExp Int64] -> InKernelGen ()] ->
  -- | Number of chunks each histogram is split into.
  Imp.TExp Int32 ->
  -- | Index of the chunk handled by this pass.
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelGlobalPass map_pes num_tblocks tblock_size space slugs kbody histograms hist_S chk_i = do
  let (space_is, space_sizes) = unzip $ unSegSpace space
      space_sizes_64 = map (sExt64 . pe64) space_sizes
      total_w_64 = product space_sizes_64
      ops = map slugOp slugs

  -- Number of buckets in one chunk of each histogram.
  hist_H_chks <- forM ops $ \op ->
    dPrimVE "hist_H_chk" $ histSize op `divUp` sExt64 hist_S

  sKernelThread "seghist_global" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do
    constants <- kernelConstants <$> askEnv

    -- Compute subhistogram index for each thread, per histogram.  The
    -- threads are divided into consecutive groups of
    -- ceil(num_threads / num_subhistos) threads, and each group shares
    -- one subhistogram.
    subhisto_inds <- forM slugs $ \slug ->
      dPrimVE "subhisto_ind" $
        sExt32 (kernelGlobalThreadId constants)
          `quot` ( kernelNumThreads constants
                     `divUp` sExt32 (tvExp (slugNumSubhistos slug))
                 )

    -- Loop over flat offsets into the iteration space.  Both the flat
    -- offsets and the unflattened segment indexes are 64-bit, to avoid
    -- overflow on large inputs.
    let gtid = sExt64 $ kernelGlobalThreadId constants
        num_threads = sExt64 $ kernelNumThreads constants
    kernelLoop gtid num_threads total_w_64 $ \offset -> do
      -- Construct segment indices.
      dIndexSpace (zip space_is space_sizes_64) offset

      -- Evaluate the kernel body once for this element (which also
      -- writes the map-out arrays), then update each histogram
      -- serially.
      let input_in_bounds = offset .<. total_w_64

      sWhen input_in_bounds $
        compileStms mempty (kernelBodyStms kbody) $ do
          let (red_res, map_res) =
                splitFromEnd (length map_pes) $ kernelBodyResult kbody

          sComment "save map-out results" $
            forM_ (zip map_pes map_res) $ \(pe, res) ->
              copyDWIMFix
                (patElemName pe)
                (map Imp.le64 space_is)
                (kernelResultSubExp res)
                []

          let red_res_split =
                splitHistResults ops $ map kernelResultSubExp red_res

          sComment "perform atomic updates" $
            forM_ (L.zip5 ops histograms red_res_split subhisto_inds hist_H_chks) $
              \( HistOp dest_shape _ _ _ shape lam,
                 do_op,
                 (bucket, vs'),
                 subhisto_ind,
                 hist_H_chk
                 ) -> do
                  let chk_beg = sExt64 chk_i * hist_H_chk
                      bucket' = map pe64 bucket
                      dest_shape' = map pe64 $ shapeDims dest_shape
                      flat_bucket = flattenIndex dest_shape' bucket'
                      -- Only buckets that are valid for the destination
                      -- and that fall in the chunk handled by this pass
                      -- are updated.
                      bucket_in_bounds =
                        chk_beg
                          .<=. flat_bucket
                          .&&. flat_bucket
                          .<. (chk_beg + hist_H_chk)
                          .&&. inBounds (Slice (map DimFix bucket')) dest_shape'
                      vs_params = takeLast (length vs') $ lambdaParams lam

                  sWhen bucket_in_bounds $ do
                    -- Destination index: the outer segment indexes, then
                    -- the subhistogram, then the bucket itself.
                    let bucket_is =
                          map Imp.le64 (init space_is)
                            ++ [sExt64 subhisto_ind]
                            ++ unflattenIndex dest_shape' flat_bucket
                    dLParams $ lambdaParams lam
                    sLoopNest shape $ \is -> do
                      forM_ (zip vs_params vs') $ \(p, res) ->
                        copyDWIMFix (paramName p) [] res is
                      do_op (bucket_is ++ is)

-- | Compute all histograms in @slugs@ using subhistograms in global
-- memory.
--
-- The intermediate arrays are set up by
-- 'prepareIntermediateArraysGlobal', which also returns the number of
-- chunks @hist_S@ that each histogram is split into.  One
-- 'histKernelGlobalPass' kernel is then launched per chunk.
histKernelGlobal ::
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_tblocks tblock_size space slugs kbody = do
  emit $ Imp.DebugPrint "## Using global memory" Nothing

  let sizes = map snd $ unSegSpace space
      -- The outer dimensions of the space; the innermost one is
      -- passed separately (presumably the per-segment input width).
      outer_shape = Shape $ init sizes
      inner_w = pe64 $ last sizes
      -- Total number of launched threads, narrowed to 32 bits.
      total_threads =
        sExt32 $ pe64 (unCount num_tblocks) * pe64 (unCount tblock_size)

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal
      (bodyPassage kbody)
      outer_shape
      total_threads
      inner_w
      slugs

  -- One kernel launch per histogram chunk.
  sFor "chk_i" hist_S $
    histKernelGlobalPass map_pes num_tblocks tblock_size space slugs kbody histograms hist_S

-- | Per-histogram setup returned by 'prepareIntermediateArraysLocal',
-- with one element per 'SegHistSlug'.
--
-- Each element pairs a list of array names with a kernel-side action.
-- Given a 'SubExp', that action yields a further list of names together
-- with an update function that takes a list of 64-bit indexes.
--
-- NOTE(review): the 'SubExp' argument is presumably the histogram chunk
-- size, and the two name lists are presumably the global and local
-- subhistogram arrays.  Confirm against 'prepareIntermediateArraysLocal'.
type InitLocalHistograms =
  [([VName], SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ()))]

prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumBlocks (Imp.TExp Int64) ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal :: TV Int32
-> Count NumBlocks (TExp Int64)
-> [SegHistSlug]
-> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal TV Int32
num_subhistos_per_block Count NumBlocks (TExp Int64)
blocks_per_segment =
  (SegHistSlug
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      ([VName],
       SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())))
-> [SegHistSlug] -> CallKernelGen InitLocalHistograms
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SegHistSlug
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([VName],
      SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ()))
onOp
  where
    onOp :: SegHistSlug
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([VName],
      SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ()))
onOp (SegHistSlug HistOp GPUMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate GPUMem KernelEnv
do_op) = do
      TV Int64
num_subhistos TV Int64 -> TExp Int64 -> CallKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment)

      Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of subhistograms in global memory per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
          Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
            TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$
              TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
num_subhistos

      mk_op <-
        case AtomicUpdate GPUMem KernelEnv
do_op of
          AtomicPrim DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. a -> b -> a
const (ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
 -> SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure DoAtomicUpdate GPUMem KernelEnv
f
          AtomicCAS DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. a -> b -> a
const (ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
 -> SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
-> SubExp
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate GPUMem KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure DoAtomicUpdate GPUMem KernelEnv
f
          AtomicLocking Locking -> DoAtomicUpdate GPUMem KernelEnv
f -> (SubExp
 -> ImpM
      GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((SubExp
  -> ImpM
       GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
 -> ImpM
      GPUMem
      HostEnv
      HostOp
      (SubExp
       -> ImpM
            GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv)))
-> (SubExp
    -> ImpM
         GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (SubExp
      -> ImpM
           GPUMem KernelEnv KernelOp (DoAtomicUpdate GPUMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ \SubExp
hist_H_chk -> do
            let lock_shape :: Shape
lock_shape =
                  [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_subhistos_per_block, SubExp
hist_H_chk]

            let dims :: [TExp Int64]
dims = [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_subhistos_per_block), SubExp -> TExp Int64
pe64 SubExp
hist_H_chk]

            locks <- [Char]
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
[Char] -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray [Char]
"locks" PrimType
int32 Shape
lock_shape (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"shared"

            sComment "All locks start out unlocked" $
              blockCoverSpace dims $ \[TExp Int64]
is ->
                VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
locks [TExp Int64]
is (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []

            pure $ f $ Locking locks 0 1 0 id

      -- Initialise local-memory sub-histograms.  These are
      -- represented as two-dimensional arrays.
      let init_local_subhistos SubExp
hist_H_chk = do
            local_subhistos <- [Type]
-> (Type -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp GPUMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp GPUMem
op) ((Type -> ImpM GPUMem KernelEnv KernelOp VName)
 -> ImpM GPUMem KernelEnv KernelOp [VName])
-> (Type -> ImpM GPUMem KernelEnv KernelOp VName)
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
              let subhisto_shape :: Shape
subhisto_shape =
                    Shape -> Int -> Shape -> Shape
forall d. ShapeBase d -> Int -> ShapeBase d -> ShapeBase d
setOuterDims
                      (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t)
                      (HistOp GPUMem -> Int
histRank HistOp GPUMem
op)
                      ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
hist_H_chk])
              [Char]
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
[Char] -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray
                [Char]
"subhistogram_local"
                (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)
                ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_subhistos_per_block] Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Shape
subhisto_shape)
                ([Char] -> Space
Space [Char]
"shared")

            do_op' <- mk_op hist_H_chk

            pure (local_subhistos, do_op' (Space "shared") local_subhistos)

      -- Initialise global-memory sub-histograms.
      glob_subhistos <- forM subhisto_info $ \SubhistosInfo
info -> do
        SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
        VName -> ImpM GPUMem HostEnv HostOp VName
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> ImpM GPUMem HostEnv HostOp VName)
-> VName -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info

      pure (glob_subhistos, init_local_subhistos)

histKernelLocalPass ::
  TV Int32 ->
  Count NumBlocks (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  InitLocalHistograms ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumBlocks (TExp Int64)
-> [PatElem LParamMem]
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
  TV Int32
num_subhistos_per_block_var
  Count NumBlocks (TExp Int64)
blocks_per_segment
  [PatElem LParamMem]
map_pes
  Count NumBlocks SubExp
num_tblocks
  Count BlockSize SubExp
tblock_size
  SegSpace
space
  [SegHistSlug]
slugs
  KernelBody GPUMem
kbody
  InitLocalHistograms
init_histograms
  TExp Int32
hist_S
  TExp Int32
chk_i = do
    let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
space_is
        segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init [SubExp]
space_sizes
        (VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. HasCallStack => [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        num_subhistos_per_block :: TExp Int32
num_subhistos_per_block = TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_subhistos_per_block_var
        segment_size' :: TExp Int64
segment_size' = SubExp -> TExp Int64
pe64 SubExp
segment_size

    num_segments <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_segments" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims

    hist_H_chks <- forM (map slugOp slugs) $ \HistOp GPUMem
op ->
      [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> TExp Int64
histSize HistOp GPUMem
op TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

    histo_sizes <- forM (zip slugs hist_H_chks) $ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
      let histo_dims :: [TExp Int64]
histo_dims =
            TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
      histo_size <-
        [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
histo_dims
      let block_hists_size = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_block TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
histo_size
      init_per_thread <-
        dPrimVE "init_per_thread" $ sExt32 $ block_hists_size `divUp` pe64 (unCount tblock_size)
      pure (histo_dims, histo_size, init_per_thread)

    let attrs = (Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelAttrs
defKernelAttrs Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size) {kAttrCheckSharedMemory = False}
    sKernelThread "seghist_local" (segFlat space) attrs $
      virtualiseBlocks SegVirt (sExt32 $ unCount blocks_per_segment * num_segments) $ \TExp Int32
tblock_id -> do
        constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

        flat_segment_id <- dPrimVE "flat_segment_id" $ tblock_id `quot` sExt32 (unCount blocks_per_segment)
        gid_in_segment <- dPrimVE "gid_in_segment" $ tblock_id `rem` sExt32 (unCount blocks_per_segment)
        -- This pgtid is kind of a "virtualised physical" gtid - not the
        -- same thing as the gtid used for the SegHist itself.
        pgtid_in_segment <-
          dPrimVE "pgtid_in_segment" $
            gid_in_segment * sExt32 (kernelBlockSize constants)
              + kernelLocalThreadId constants
        threads_per_segment <-
          dPrimVE "threads_per_segment" $
            sExt32 $
              unCount blocks_per_segment * kernelBlockSize constants

        -- Set segment indices.
        zipWithM_ dPrimV_ segment_is $
          unflattenIndex (map pe64 segment_dims) $
            sExt64 flat_segment_id

        histograms <- forM (zip init_histograms hist_H_chks) $
          \(([VName]
glob_subhistos, SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
            (local_subhistos, do_op) <- SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos (SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ()))
-> SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
hist_H_chk
            pure (zip glob_subhistos local_subhistos, hist_H_chk, do_op)

        -- Find index of local subhistograms updated by this thread.  We
        -- try to ensure, as much as possible, that threads in the same
        -- warp use different subhistograms, to avoid conflicts.
        thread_local_subhisto_i <-
          dPrimVE "thread_local_subhisto_i" $
            kernelLocalThreadId constants `rem` num_subhistos_per_block

        let onSlugs SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f =
              [(SegHistSlug,
  ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
  ([TExp Int64], TExp Int64, TExp Int32))]
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([TExp Int64], TExp Int64, TExp Int32)]
-> [(SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes) (((SegHistSlug,
   ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
   ([TExp Int64], TExp Int64, TExp Int32))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                \(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
_), ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)) ->
                  SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread

        let onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f =
              (SegHistSlug
 -> [(VName, VName)]
 -> TExp Int64
 -> [TExp Int64]
 -> TExp Int64
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)]
  -> TExp Int64
  -> [TExp Int64]
  -> TExp Int64
  -> TExp Int32
  -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)]
    -> TExp Int64
    -> [TExp Int64]
    -> TExp Int64
    -> TExp Int32
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread -> do
                let block_hists_size :: TExp Int32
block_hists_size = TExp Int32
num_subhistos_per_block TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size

                [((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  \((VName
dest_global, VName
dest_local), SubExp
ne) ->
                    [Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
                      j <-
                        [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                          TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants)
                            TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                      j_offset <-
                        dPrimVE "j_offset" $
                          num_subhistos_per_block * sExt32 histo_size * gid_in_segment + j

                      local_subhisto_i <- dPrimVE "local_subhisto_i" $ j `quot` sExt32 histo_size
                      let local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
                          nested_hist_size =
                            (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                          global_bucket_is =
                            [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                              [TExp Int64]
nested_hist_size
                              ([TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
                              [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
tail [TExp Int64]
local_bucket_is
                      global_subhisto_i <- dPrimVE "global_subhisto_i" $ j_offset `quot` sExt32 histo_size

                      sWhen (j .<. block_hists_size) $
                        f
                          dest_local
                          dest_global
                          (slugOp slug)
                          ne
                          local_subhisto_i
                          global_subhisto_i
                          local_bucket_is
                          global_bucket_is

        sComment "initialize histograms in shared memory" $
          onAllHistograms $ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TExp Int64]
local_bucket_is [TExp Int64]
global_bucket_is ->
            Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              dest_global_shape <- (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64])
-> (Type -> [SubExp]) -> Type -> [TExp Int64]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [TExp Int64])
-> ImpM GPUMem KernelEnv KernelOp Type
-> ImpM GPUMem KernelEnv KernelOp [TExp Int64]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem KernelEnv KernelOp Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dest_global
              let global_is = (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64
0] [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
                  local_is = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is
                  global_in_bounds =
                    Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
global_is)) [TExp Int64]
dest_global_shape
              sIf
                (global_subhisto_i .==. 0 .&&. global_in_bounds)
                (copyDWIMFix dest_local local_is (Var dest_global) global_is)
                ( sLoopNest (histOpShape op) $ \[TExp Int64]
is ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TExp Int64]
local_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is) SubExp
ne []
                )

        sOp $ Imp.Barrier Imp.FenceLocal

        kernelLoop (sExt64 pgtid_in_segment) (sExt64 threads_per_segment) segment_size' $ \TExp Int64
ie -> do
          VName -> TExp Int64 -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment TExp Int64
ie

          -- We execute the bucket function once and update each histogram
          -- serially.  This also involves writing to the mapout arrays if
          -- this is the first chunk.

          Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let ([SubExp]
red_res, [SubExp]
map_res) =
                  Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                    (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
                      KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
se) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                    (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                    ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
space_is)
                    SubExp
se
                    []

            let red_res_split :: [([SubExp], [SubExp])]
red_res_split = [HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [SubExp]
red_res
            [(HistOp GPUMem,
  ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
  ([SubExp], [SubExp]))]
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([SubExp], [SubExp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([SubExp], [SubExp])]
-> [(HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([SubExp], [SubExp])]
red_res_split) (((HistOp GPUMem,
   ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
   ([SubExp], [SubExp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([SubExp], [SubExp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
                 ([(VName, VName)]
_, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op),
                 ([SubExp]
bucket, [SubExp]
vs')
                 ) -> do
                  let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk
                      bucket' :: [TExp Int64]
bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
                      dest_shape' :: [TExp Int64]
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
                      flat_bucket :: TExp Int64
flat_bucket = [TExp Int64] -> [TExp Int64] -> TExp Int64
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TExp Int64]
dest_shape' [TExp Int64]
bucket'
                      bucket_in_bounds :: TExp Bool
bucket_in_bounds =
                        Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
                          TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
chk_beg
                          TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
flat_bucket
                          TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
flat_bucket
                          TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk)
                      bucket_is :: [TExp Int64]
bucket_is =
                        [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TExp Int64
flat_bucket TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
chk_beg]
                      vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

                  Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                      [LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
                      Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
                        [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
                          VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TExp Int64]
is
                        [TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)

        sOp $ Imp.ErrorSync Imp.FenceGlobal

        sComment "Compact the multiple shared memory subhistograms to result in global memory" $
          onSlugs $ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
_histo_size TExp Int32
bins_per_thread -> do
            trunc_H <-
              [Char] -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"trunc_H" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> (TExp Int64 -> TExp Int64)
-> TExp Int64
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_H_chk (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
                HistOp GPUMem -> TExp Int64
histSize (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
histo_dims
            let trunc_histo_dims =
                  TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
trunc_H
                    TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
            trunc_histo_size <- dPrimVE "histo_size" $ sExt32 $ product trunc_histo_dims

            sFor "local_i" bins_per_thread $ \TExp Int32
i -> do
              j <-
                [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                  TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants)
                    TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
              sWhen (j .<. trunc_histo_size) $ do
                -- We are responsible for compacting the flat bin 'j', which
                -- we immediately unflatten.
                let local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
                    nested_hist_size =
                      (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
                    global_bucket_is =
                      [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                        [TExp Int64]
nested_hist_size
                        ([TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
                        [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
tail [TExp Int64]
local_bucket_is
                dLParams $ lambdaParams $ histOp $ slugOp slug
                let (global_dests, local_dests) = unzip dests
                    (xparams, yparams) =
                      splitAt (length local_dests) $
                        lambdaParams $
                          histOp $
                            slugOp slug

                sComment "Read values from subhistogram 0." $
                  forM_ (zip xparams local_dests) $ \(Param LParamMem
xp, VName
subhisto) ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                      (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp)
                      []
                      (VName -> SubExp
Var VName
subhisto)
                      (TExp Int64
0 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)

                sComment "Accumulate based on values in other subhistograms." $
                  sFor "subhisto_id" (num_subhistos_per_block - 1) $ \TExp Int32
subhisto_id -> do
                    [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
yparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
yp, VName
subhisto) ->
                      VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                        (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
yp)
                        []
                        (VName -> SubExp
Var VName
subhisto)
                        (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
                    [Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
xparams (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                sComment "Put final bucket value in global memory." $ do
                  let global_is =
                        (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is
                          [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
tblock_id TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment]
                          [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
                  forM_ (zip xparams global_dests) $ \(Param LParamMem
xp, VName
global_dest) ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global_dest [TExp Int64]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp) []

-- | Generate the shared-memory histogram implementation.
--
-- The intermediate subhistogram arrays are prepared once (via
-- 'prepareIntermediateArraysLocal'), after which one
-- 'histKernelLocalPass' is emitted for every chunk index @chk_i@ in
-- @[0, hist_S)@, where @hist_S@ is the number of chunks.  The number of
-- shared-memory subhistograms per block is also emitted as a debug
-- print.
histKernelLocal ::
  TV Int32 ->
  Count NumBlocks (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal subhistos_per_block blocks_per_segment map_pes num_tblocks tblock_size space hist_S slugs kbody = do
  emit . Imp.DebugPrint "Number of local subhistograms per block" . Just . untyped $
    tvExp subhistos_per_block

  init_histograms <-
    prepareIntermediateArraysLocal subhistos_per_block blocks_per_segment slugs

  -- The loop supplies the chunk index as the final argument of the pass.
  sFor "chk_i" hist_S $
    histKernelLocalPass
      subhistos_per_block
      blocks_per_segment
      map_pes
      num_tblocks
      tblock_size
      space
      slugs
      kbody
      init_histograms
      hist_S

-- | Upper bound on the number of shared-memory passes we accept for a
-- slug.  The bound depends only on how the slug's atomic update is
-- implemented: 3 for primitive atomics, 4 for compare-and-swap, and 6
-- for lock-based updates.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses = passesFor . slugAtomicUpdate
  where
    passesFor (AtomicPrim _) = 3
    passesFor (AtomicCAS _) = 4
    passesFor (AtomicLocking _) = 6

-- | Decide whether the shared-memory strategy is applicable, and build
-- the code generator that runs it.
--
-- The result is a pair.  The first component is a host-side condition
-- that is true when shared memory should be used.  The second is an
-- action that emits the shared-memory implementation through
-- 'histKernelLocal'.  The action does not test the condition itself.
--
-- The condition holds when all of the following are true:
--
-- * the innermost dimension is at least as large as the histogram
--   (@hist_Nin >= hist_H@);
-- * one bucket across all subhistograms fits in shared memory;
-- * the number of passes @hist_S@ is at most the allowed maximum.  That
--   maximum is 1 for single-pass bodies and otherwise the largest
--   'slugMaxLocalMemPasses';
-- * the cooperation level @hist_C@ does not exceed the block size; and
-- * at least one subhistogram per block is used (@hist_M > 0@).
--
-- Arguments, in order:
--
-- * the map result pattern elements;
-- * @hist_T@, which is divided by the block size to get the number of
--   blocks (presumably a total thread count; confirm against the
--   caller);
-- * the segment space;
-- * @hist_H@, the histogram size, counted in elements (it is multiplied
--   by @hist_el_size@ to obtain bytes);
-- * @hist_el_size@, the per-bucket element size;
-- * @hist_N@, presumably the number of input elements;
-- * an ignored 'Int32';
-- * the slugs;
-- * the kernel body.
localMemoryCase ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  SegSpace ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N _ slugs kbody = do
  let dims = segSpaceDims space
      outer_dims = init dims
      is_segmented = not (null outer_dims)
      -- Emit one labelled debug print per entry, in list order.
      debugPrints stats =
        forM_ stats $ \(label, e) -> emit $ Imp.DebugPrint label $ Just e

  hist_L <- getSize "hist_L" Imp.SizeSharedMemory

  max_tblock_size :: TV Int64 <- dPrim "max_tblock_size"
  sOp $ Imp.GetSizeMax (tvVar max_tblock_size) Imp.SizeThreadBlock

  -- XXX: we need to record for later use that max_tblock_size is the
  -- result of GetSizeMax.  This is an ugly hack that reflects our
  -- inability to track which variables are actually constants.
  let withSizeMax vt
        | Just (ScalarVar _ entry) <- M.lookup (tvVar max_tblock_size) vt =
            M.insert
              (tvVar max_tblock_size)
              (ScalarVar (Just (Op (Inner (SizeOp (GetSizeMax SizeThreadBlock))))) entry)
              vt
        | otherwise = vt

  -- The block size is the maximum block size; hist_B is its expression form.
  let tblock_size = Imp.Count $ Var $ tvVar max_tblock_size
      hist_B = pe64 $ unCount tblock_size
  num_tblocks <-
    Imp.Count . tvSize
      <$> dPrimV "num_tblocks" (sExt64 hist_T `divUp` hist_B)

  let num_tblocks' = pe64 <$> num_tblocks
      toF64 = isF64 . ConvOpExp (SIToFP Int64 Float64) . untyped
      toI64 = isInt64 . ConvOpExp (FPToSI Float64 Int64) . untyped

  -- M approximation.
  hist_m' <-
    dPrimVE "hist_m_prime" $
      toF64
        ( sMin64
            (sExt64 (tvExp hist_L `quot` hist_el_size))
            (hist_N `divUp` sExt64 (unCount num_tblocks'))
        )
        / toF64 hist_H

  -- M in the paper, but not adjusted for asymptotic efficiency.
  hist_M0 <- dPrimVE "hist_M0" . sMax64 1 $ sMin64 (toI64 hist_m') hist_B

  -- Minimal sequential chunking factor.
  let q_small = 2

  -- The number of segments/histograms produced.
  hist_Nout <- dPrimVE "hist_Nout" . product $ map pe64 outer_dims

  hist_Nin <- dPrimVE "hist_Nin" . pe64 $ last dims

  -- Maximum M for work efficiency.
  work_asymp_M_max <-
    if is_segmented
      then do
        hist_T_hist_min <-
          dPrimVE "hist_T_hist_min" . sExt32 $
            sMin64 (sExt64 hist_Nin * sExt64 hist_Nout) (sExt64 hist_T)
              `divUp` sExt64 hist_Nout

        -- Number of blocks, rounded up.
        let blocks_needed = hist_T_hist_min `divUp` sExt32 hist_B

        dPrimVE "work_asymp_M_max" $
          hist_Nin `quot` (sExt64 blocks_needed * hist_H)
      else do
        let per_slug_work =
              (q_small * unCount num_tblocks' * hist_H)
                `quot` L.genericLength slugs
        dPrimVE "work_asymp_M_max" $ (hist_Nout * hist_N) `quot` per_slug_work

  -- Number of subhistograms per result histogram.
  hist_M <- dPrimV "hist_M" . sExt32 $ sMin64 hist_M0 work_asymp_M_max

  let hist_M_exp = tvExp hist_M
      -- hist_M may be zero (which we'll check for below), but we need it
      -- for some divisions first, so crudely make a nonzero form.
      hist_M_nonzero = sMax32 1 hist_M_exp

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  hist_C <- dPrimVE "hist_C" $ hist_B `divUp` sExt64 hist_M_nonzero

  debugPrints
    [ ("local hist_M0", untyped hist_M0),
      ("local work asymp M max", untyped work_asymp_M_max),
      ("local C", untyped hist_C),
      ("local B", untyped hist_B),
      ("local M", untyped hist_M_exp),
      ("shared memory needed", untyped $ hist_H * hist_el_size * sExt64 hist_M_exp)
    ]

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  local_mem_needed <-
    dPrimVE "local_mem_needed" $ hist_el_size * sExt64 hist_M_exp
  -- We add one to the memory requirement because if the chunk
  -- otherwise *exactly* fits, it might actually *not* fit in the case
  -- of a multi-value operator, as we individually round up the sizes
  -- of the component arrays. (Very rare edge case.)
  hist_S <-
    dPrimVE "hist_S" . sExt32 $
      (hist_H * local_mem_needed + 1) `divUp` tvExp hist_L
  let max_S =
        case bodyPassage kbody of
          MustBeSinglePass -> 1
          MayBeMultiPass ->
            fromIntegral . maxinum $ map slugMaxLocalMemPasses slugs

  blocks_per_segment <-
    if is_segmented
      then Count <$> dPrimVE "blocks_per_segment" (unCount num_tblocks' `divUp` hist_Nout)
      else pure num_tblocks'

  -- We only use shared memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local =
        (hist_Nin .>=. hist_H)
          .&&. (local_mem_needed .<=. tvExp hist_L)
          .&&. (hist_S .<=. max_S)
          .&&. (hist_C .<=. hist_B)
          .&&. (hist_M_exp .>. 0)

      run = do
        emit $ Imp.DebugPrint "## Using shared memory" Nothing
        debugPrints
          [ ("Histogram size (H)", untyped hist_H),
            ("Multiplication degree (M)", untyped hist_M_exp),
            ("Cooperation level (C)", untyped hist_C),
            ("Number of chunks (S)", untyped hist_S)
          ]
        when is_segmented $
          emit $
            Imp.DebugPrint "Blocks per segment" $
              Just $
                untyped $
                  unCount blocks_per_segment
        localVTable withSizeMax $
          histKernelLocal
            hist_M
            blocks_per_segment
            map_pes
            num_tblocks
            tblock_size
            space
            hist_S
            slugs
            kbody

  pure (pick_local, run)

-- | Generate code for a segmented histogram called from the host.
--
-- The pattern elements are split into the histogram results (one
-- chunk of 'histDest'-many elements per operator) followed by any
-- remaining map results.  The histogram is computed into
-- subhistograms, either by the kernel returned from 'localMemoryCase'
-- (when its condition holds) or by 'histKernelGlobal'.  Afterwards,
-- for each operator, the subhistograms are combined: if there is only
-- one, the result simply aliases its memory; otherwise they are
-- collapsed with a segmented reduction ('compileSegRed'').
compileSegHist ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [HistOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegHist (Pat pes) lvl space ops kbody = do
  KernelAttrs {kAttrNumBlocks = num_tblocks, kAttrBlockSize = tblock_size} <-
    lvlKernelAttrs lvl
  -- Most of this function is not the histogram part itself, but
  -- rather figuring out whether to use a local or global memory
  -- strategy, as well as collapsing the subhistograms produced (which
  -- are always in global memory, but their number may vary).
  let num_tblocks' = fmap pe64 num_tblocks
      tblock_size' = fmap pe64 tblock_size
      dims = map pe64 $ segSpaceDims space

      -- Number of result pattern elements belonging to each operator.
      -- The same counts are used below to chunk 'all_red_pes', so the
      -- split point between histogram and map results must agree with
      -- them (previously it also counted one extra element per
      -- operator, which 'chunks' then silently dropped).
      red_pes_per_op = map (length . histDest) ops
      num_red_res = sum red_pes_per_op
      (all_red_pes, map_pes) = splitAt num_red_res pes
      -- The innermost dimension holds the inputs of each histogram;
      -- the outer dimensions are segments.  Assumes a non-empty space.
      segment_size = last dims

  (op_hs, op_seg_hs, slugs) <- unzip3 <$> mapM (computeHistoUsage space) ops
  h <- dPrimVE "h" $ Imp.unCount $ sum op_hs
  seg_h <- dPrimVE "seg_h" $ Imp.unCount $ sum op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  sUnless (seg_h .==. 0) $ do
    -- Maximum block size (or actual, in this case).
    let hist_B = unCount tblock_size'

    -- Size of a histogram.
    hist_H <- dPrimVE "hist_H" $ sum $ map histSize ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize slug = case slugAtomicUpdate slug of
          AtomicLocking {} -> Just $ primByteSize int32
          _ -> Nothing
    hist_el_size <-
      dPrimVE "hist_el_size" $
        L.foldl' (+) (h `divUp` hist_H) $
          mapMaybe lockSize slugs

    -- Input elements contributing to each histogram.
    hist_N <- dPrimVE "hist_N" segment_size

    -- Compute RF as the average RF over all the histograms.
    hist_RF <-
      dPrimVE "hist_RF" $
        sExt32 $
          sum (map (pe64 . histRaceFactor . slugOp) slugs)
            `quot` L.genericLength slugs

    -- Total number of threads: blocks times block size.
    let hist_T = sExt32 $ unCount num_tblocks' * unCount tblock_size'
    emit $ Imp.DebugPrint "\n# SegHist" Nothing
    emit $ Imp.DebugPrint "Number of threads (T)" $ Just $ untyped hist_T
    emit $ Imp.DebugPrint "Desired block size (B)" $ Just $ untyped hist_B
    emit $ Imp.DebugPrint "Histogram size (H)" $ Just $ untyped hist_H
    emit $ Imp.DebugPrint "Input elements per histogram (N)" $ Just $ untyped hist_N
    emit $
      Imp.DebugPrint "Number of segments" $
        Just $
          untyped $
            product $
              map (pe64 . snd) segment_dims
    emit $ Imp.DebugPrint "Histogram element size (el_size)" $ Just $ untyped hist_el_size
    emit $ Imp.DebugPrint "Race factor (RF)" $ Just $ untyped hist_RF
    emit $ Imp.DebugPrint "Memory per set of subhistograms per segment" $ Just $ untyped h
    emit $ Imp.DebugPrint "Memory per set of subhistograms times segments" $ Just $ untyped seg_h

    (use_shared_memory, run_in_shared_memory) <-
      localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N hist_RF slugs kbody

    sIf use_shared_memory run_in_shared_memory $
      histKernelGlobal map_pes num_tblocks tblock_size space slugs kbody

    let pes_per_op = chunks red_pes_per_op all_red_pes

    forM_ (zip3 slugs pes_per_op ops) $ \(slug, red_pes, op) -> do
      let num_histos :: TV Int64
          num_histos = slugNumSubhistos slug
          subhistos :: [VName]
          subhistos = map subhistosArray $ slugSubhistos slug

      let unitHistoCase :: CallKernelGen ()
          unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            forM_ (zip red_pes subhistos) $ \(pe, subhisto) -> do
              pe_mem <-
                memLocName . entryArrayLoc
                  <$> lookupArray (patElemName pe)
              subhisto_mem <-
                memLocName . entryArrayLoc
                  <$> lookupArray subhisto
              emit $ Imp.SetMem pe_mem subhisto_mem $ Space "device"

      sIf (tvExp num_histos .==. 1) unitHistoCase $ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To this, we add two dimensions: one over the number
        -- of buckets, and one over the number of subhistograms.  This
        -- inner dimension is the one that is collapsed in the reduction.
        bucket_ids <-
          replicateM (shapeRank (histShape op)) (newVName "bucket_id")
        subhistogram_id <- newVName "subhistogram_id"
        vector_ids <-
          replicateM (shapeRank (histOpShape op)) (newVName "vector_id")

        flat_gtid <- newVName "flat_gtid"

        let grid = KernelGrid num_tblocks tblock_size
            segred_space =
              SegSpace flat_gtid $
                segment_dims
                  ++ zip bucket_ids (shapeDims (histShape op))
                  ++ zip vector_ids (shapeDims $ histOpShape op)
                  ++ [(subhistogram_id, Var $ tvVar num_histos)]
            -- The operator may have references to the old flat thread
            -- ID, which we must update to point at the new one.
            subst = M.singleton (segFlat space) flat_gtid
            -- Index into a subhistogram array: segment indices, then the
            -- subhistogram, then bucket and vector indices.  This is the
            -- same for every subhistogram, so it is computed once.
            subhisto_is =
              map Imp.le64 $
                map fst segment_dims
                  ++ [subhistogram_id]
                  ++ bucket_ids
                  ++ vector_ids

        let segred_op =
              SegBinOp
                Commutative
                (substituteNames subst $ histOp op)
                (histNeutral op)
                mempty
        compileSegRed' (Pat red_pes) grid segred_space [segred_op] $ \red_cont ->
          red_cont [(Var subhisto, subhisto_is) | subhisto <- subhistos]

  emit $ Imp.DebugPrint "" Nothing
  where
    -- Every dimension of the space except the innermost one.
    segment_dims :: [(VName, SubExp)]
    segment_dims = init $ unSegSpace space