Kernel programming

This section lists the package's public functionality that corresponds to special CUDA functions for use in device code. It is loosely organized according to the C language extensions appendix from the CUDA C programming guide. For more information about certain intrinsics, refer to the aforementioned NVIDIA documentation.

Indexing and dimensions

CUDA.warpsizeFunction
warpsize(dev::CuDevice)

Returns the warp size (in threads) of the device.

source
warpsize()::Int32

Returns the warp size (in threads).

source
CUDA.laneidFunction
laneid()::Int32

Returns the thread's lane within the warp.

source
CUDA.active_maskFunction
active_mask()

Returns a 32-bit mask indicating which threads in the warp are active, from the point of view of the calling thread.

source
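
A minimal sketch combining these intrinsics in a kernel (the kernel name and launch configuration are illustrative):

    function warp_info_kernel()
        if threadIdx().x == 1
            @cuprintln("lane ", laneid(), ", warpsize ", warpsize(),
                       ", active_mask ", active_mask())
        end
        return
    end

    @cuda threads=8 warp_info_kernel()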

Device arrays

CUDA.jl provides a primitive, lightweight array type to manage GPU data organized in a plain, dense fashion. This is the device counterpart of CuArray, and implements (part of) the array interface as well as other functionality for use on the GPU:

CUDA.CuDeviceArrayType
CuDeviceArray{T,N,A}(ptr, dims, [maxsize])

Construct an N-dimensional dense CUDA device array with element type T wrapping a pointer, where N is determined from the length of dims and T is determined from the type of ptr. dims may be a single scalar, or a tuple of integers corresponding to the lengths in each dimension. If the rank N is supplied explicitly, as in CuDeviceArray{T,N}(ptr, dims), then it must match the length of dims. The same applies to the element type T, which should match the type of the pointer ptr.

source
CUDA.ConstType
Const(A::CuDeviceArray)

Mark a CuDeviceArray as constant/read-only. The invariant guaranteed is that you will not modify the CuDeviceArray for the duration of the current kernel.

This API can only be used on devices with compute capability 3.5 or higher.

Warning

Experimental API. Subject to change without deprecation.

source
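
For example, a kernel that only reads one of its arguments could wrap it in Const (a sketch; add_constant! is an illustrative name):

    function add_constant!(a, b)
        i = threadIdx().x
        bc = CUDA.Const(b)   # promise not to write to b during this kernel
        a[i] += bc[i]
        return
    end

    a, b = CUDA.zeros(Float32, 32), CUDA.rand(Float32, 32)
    @cuda threads=32 add_constant!(a, b)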

Memory types

Shared memory

CUDA.CuStaticSharedArrayFunction
CuStaticSharedArray(T::Type, dims) -> CuDeviceArray{T,N,AS.Shared}

Get an array of type T and dimensions dims (either an integer length or tuple shape) pointing to a statically-allocated piece of shared memory. The type should be statically inferable and the dimensions should be constant, or an error will be thrown and the generator function will be called dynamically.

source
CUDA.CuDynamicSharedArrayFunction
CuDynamicSharedArray(T::Type, dims, offset::Integer=0) -> CuDeviceArray{T,N,AS.Shared}

Get an array of type T and dimensions dims (either an integer length or tuple shape) pointing to a dynamically-allocated piece of shared memory. The type should be statically inferable or an error will be thrown and the generator function will be called dynamically.

Note that the amount of dynamic shared memory needs to be specified when launching the kernel.

Optionally, an offset parameter indicating how many bytes to add to the base shared memory pointer can be specified. This is useful when dealing with a heterogeneous buffer of dynamic shared memory; in the case of a homogeneous multi-part buffer it is preferred to use view.

source
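
As an illustration, both flavours can be used to reverse a block of values through shared memory (a sketch; the kernel names are illustrative):

    # statically-sized shared memory: the dimensions are compile-time constants
    function reverse_static!(a)
        i = threadIdx().x
        shared = CuStaticSharedArray(Float32, 1024)
        shared[i] = a[i]
        sync_threads()
        a[i] = shared[blockDim().x - i + 1]
        return
    end

    # dynamically-sized shared memory: the size is provided at launch time via shmem
    function reverse_dynamic!(a)
        i = threadIdx().x
        shared = CuDynamicSharedArray(Float32, blockDim().x)
        shared[i] = a[i]
        sync_threads()
        a[i] = shared[blockDim().x - i + 1]
        return
    end

    a = CUDA.rand(Float32, 1024)
    @cuda threads=1024 reverse_static!(a)
    @cuda threads=1024 shmem=1024*sizeof(Float32) reverse_dynamic!(a)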

Texture memory

CUDA.CuDeviceTextureType
CuDeviceTexture{T,N,M,NC,I}

N-dimensional device texture with elements of type T. This type is the device-side counterpart of CuTexture{T,N,P}, and can be used to access textures using regular indexing notation. If NC is true, indices used by these accesses should be normalized, i.e., fall into the [0,1) domain. The I type parameter indicates the kind of interpolation that happens when indexing into this texture. The source memory of the texture is specified by the M parameter, either linear memory or a texture array.

Device-side texture objects cannot be created directly, but should be created host-side using CuTexture{T,N,P} and passed to the kernel as an argument.

Warning

Experimental API. Subject to change without deprecation.

source

Synchronization

CUDA.sync_threadsFunction
sync_threads()

Waits until all threads in the thread block have reached this point and all global and shared memory accesses made by these threads prior to sync_threads() are visible to all threads in the block.

source
CUDA.sync_threads_countFunction
sync_threads_count(predicate)

Identical to sync_threads() with the additional feature that it evaluates predicate for all threads of the block and returns the number of threads for which predicate evaluates to true.

source
CUDA.sync_threads_andFunction
sync_threads_and(predicate)

Identical to sync_threads() with the additional feature that it evaluates predicate for all threads of the block and returns true if and only if predicate evaluates to true for all of them.

source
CUDA.sync_threads_orFunction
sync_threads_or(predicate)

Identical to sync_threads() with the additional feature that it evaluates predicate for all threads of the block and returns true if and only if predicate evaluates to true for any of them.

source
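
For example, the counting variant can tally a per-block predicate (a sketch; count_positive! is an illustrative name):

    function count_positive!(counts, a)
        i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
        n = sync_threads_count(a[i] > 0)
        if threadIdx().x == 1
            counts[blockIdx().x] = n
        end
        return
    end

    a = CUDA.randn(Float32, 256)
    counts = CUDA.zeros(Int32, 2)
    @cuda threads=128 blocks=2 count_positive!(counts, a)
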
CUDA.sync_warpFunction
sync_warp(mask::Integer=FULL_MASK)

Waits until threads in the warp, selected by means of the bitmask mask, have reached this point, and ensures that all global and shared memory accesses made by these threads prior to sync_warp() are visible to those threads in the warp. The default value for mask selects all threads in the warp.

Note

Requires CUDA >= 9.0 and sm_6.2

source
CUDA.threadfence_blockFunction
threadfence_block()

A memory fence that ensures that:

  • All writes to all memory made by the calling thread before the call to threadfence_block() are observed by all threads in the block of the calling thread as occurring before all writes to all memory made by the calling thread after the call to threadfence_block()
  • All reads from all memory made by the calling thread before the call to threadfence_block() are ordered before all reads from all memory made by the calling thread after the call to threadfence_block().
source
CUDA.threadfenceFunction
threadfence()

A memory fence that acts as threadfence_block for all threads in the block of the calling thread and also ensures that no writes to all memory made by the calling thread after the call to threadfence() are observed by any thread in the device as occurring before any write to all memory made by the calling thread before the call to threadfence().

Note that for this ordering guarantee to be true, the observing threads must truly observe the memory and not cached versions of it; this requires the use of volatile loads and stores, which are not currently available from Julia.

source
CUDA.threadfence_systemFunction
threadfence_system()

A memory fence that acts as threadfence_block for all threads in the block of the calling thread and also ensures that all writes to all memory made by the calling thread before the call to threadfence_system() are observed by all threads in the device, host threads, and all threads in peer devices as occurring before all writes to all memory made by the calling thread after the call to threadfence_system().

source

Time functions

CUDA.clockFunction
clock(UInt32)

Returns the value of a per-multiprocessor counter that is incremented every clock cycle.

source
clock(UInt64)

Returns the value of a per-multiprocessor counter that is incremented every clock cycle.

source
CUDA.nanosleepFunction
nanosleep(t)

Puts a thread to sleep for a given amount of time t (in nanoseconds).

Note

Requires CUDA >= 10.0 and sm_6.2

source
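
A sketch of timing a delay with the per-multiprocessor clock (the elapsed cycle count depends on the device's clock rate):

    function timing_kernel()
        t0 = clock(UInt64)
        nanosleep(UInt32(1_000_000))   # pause for roughly one millisecond
        t1 = clock(UInt64)
        @cuprintln("elapsed cycles: ", t1 - t0)
        return
    end

    @cuda timing_kernel()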

Warp-level functions

Voting

The warp vote functions allow the threads of a given warp to perform a reduction-and-broadcast operation. These functions take as input a boolean predicate from each thread in the warp and evaluate it. The results of that evaluation are combined (reduced) across the active threads of the warp in one of several ways, broadcasting a single return value to each participating thread.

CUDA.vote_all_syncFunction
vote_all_sync(mask::UInt32, predicate::Bool)

Evaluate predicate for all active threads of the warp and return whether predicate is true for all of them.

source
CUDA.vote_any_syncFunction
vote_any_sync(mask::UInt32, predicate::Bool)

Evaluate predicate for all active threads of the warp and return whether predicate is true for any of them.

source
CUDA.vote_uni_syncFunction
vote_uni_sync(mask::UInt32, predicate::Bool)

Evaluate predicate for all active threads of the warp and return whether predicate is the same for all of them.

source
CUDA.vote_ballot_syncFunction
vote_ballot_sync(mask::UInt32, predicate::Bool)

Evaluate predicate for all active threads of the warp and return an integer whose Nth bit is set if and only if predicate is true for the Nth thread of the warp and the Nth thread is active.

source
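
For instance, ballot voting can build a bitmap of the lanes satisfying a condition (a sketch; the 0xffffffff mask selects all lanes of the warp):

    function ballot_kernel(a)
        i = threadIdx().x
        bitmap = vote_ballot_sync(0xffffffff, a[i] == 0)
        if i == 1
            @cuprintln("lanes holding a zero: ", bitmap)
        end
        return
    end

    a = CuArray(Int32.(0:31) .% 2)   # alternating 0/1 pattern across one warp
    @cuda threads=32 ballot_kernel(a)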

Shuffle

CUDA.shfl_syncFunction
shfl_sync(threadmask::UInt32, val, lane::Integer, width::Integer=32)

Shuffle a value from a directly indexed lane lane, and synchronize threads according to threadmask.

source
CUDA.shfl_up_syncFunction
shfl_up_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)

Shuffle a value from a lane with lower ID relative to caller, and synchronize threads according to threadmask.

source
CUDA.shfl_down_syncFunction
shfl_down_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)

Shuffle a value from a lane with higher ID relative to caller, and synchronize threads according to threadmask.

source
CUDA.shfl_xor_syncFunction
shfl_xor_sync(threadmask::UInt32, val, mask::Integer, width::Integer=32)

Shuffle a value from a lane based on bitwise XOR of own lane ID with mask, and synchronize threads according to threadmask.

source
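
A common pattern is a warp-level sum reduction with shfl_down_sync (a sketch, assuming a full warp of 32 active threads):

    # after the loop, the first lane holds the sum of the values of all 32 lanes
    function warp_reduce(val)
        offset = 16
        while offset > 0
            val += shfl_down_sync(0xffffffff, val, offset)
            offset >>= 1
        end
        return val
    end

    function reduce_kernel!(out, a)
        total = warp_reduce(a[threadIdx().x])
        if threadIdx().x == 1
            out[1] = total
        end
        return
    end

    a = CUDA.ones(Float32, 32)
    out = CUDA.zeros(Float32, 1)
    @cuda threads=32 reduce_kernel!(out, a)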

Formatted Output

CUDA.@cushowMacro
@cushow(ex)

GPU analog of Base.@show. It comes with the same type restrictions as @cuprintf.

    @cushow threadIdx().x
source
CUDA.@cuprintMacro
@cuprint(xs...)
@cuprintln(xs...)

Print a textual representation of values xs to standard output from the GPU. The functionality builds on @cuprintf, and is intended as a more user-friendly alternative to that API. However, that also means there's only limited support for argument types, handling 16/32/64-bit signed and unsigned integers, 32 and 64-bit floating point numbers, Cchars and pointers. For more complex output, use @cuprintf directly.

Limited string interpolation is also possible:

    @cuprint("Hello, World ", 42, "\n")
    @cuprint "Hello, World $(42)\n"
source
CUDA.@cuprintlnMacro
@cuprint(xs...)
@cuprintln(xs...)

Print a textual representation of values xs to standard output from the GPU. The functionality builds on @cuprintf, and is intended as a more user-friendly alternative to that API. However, that also means there's only limited support for argument types, handling 16/32/64-bit signed and unsigned integers, 32 and 64-bit floating point numbers, Cchars and pointers. For more complex output, use @cuprintf directly.

Limited string interpolation is also possible:

    @cuprint("Hello, World ", 42, "\n")
    @cuprint "Hello, World $(42)\n"
source
CUDA.@cuprintfMacro
@cuprintf("%Fmt", args...)

Print a formatted string in device context on the host standard output.

Note that this is not a fully C-compliant printf implementation; see the CUDA documentation for supported options and inputs.

Also beware that it is an untyped and unforgiving printf implementation. Type widths need to match, e.g., printing a 64-bit Julia integer requires the %ld formatting string.

source
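
A sketch showing format specifiers matched to argument widths:

    function printf_kernel()
        @cuprintf("thread %d has value %f\n", Int32(threadIdx().x), 3.14)
        return
    end

    @cuda threads=4 printf_kernel()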

Assertions

CUDA.@cuassertMacro
@assert cond [text]

Signal assertion failure to the CUDA driver if cond is false. Preferred syntax for writing assertions, mimicking Base.@assert. Message text is optionally displayed upon assertion failure.

Warning

A failed assertion will crash the GPU, so use sparingly as a debugging tool. Furthermore, the assertion might be disabled at various optimization levels, so the asserted condition should not have any side effects.

source
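
A sketch guarding an indexing assumption (keeping in mind that a failed assertion crashes the GPU):

    function assert_kernel(a)
        i = threadIdx().x
        @cuassert i <= length(a) "thread index exceeds array length"
        a[i] = 0f0
        return
    end

    a = CUDA.zeros(Float32, 32)
    @cuda threads=32 assert_kernel(a)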

Atomics

A high-level macro is available to annotate expressions with:

CUDA.@atomicMacro
@atomic a[I] = op(a[I], val)
@atomic a[I] ...= val

Atomically perform a sequence of operations that loads an array element a[I], performs the operation op on that value and a second value val, and writes the result back to the array. This sequence can be written out as a regular assignment, in which case the same array element should be used in the left and right hand side of the assignment, or as an in-place application of a known operator. In both cases, the array reference should be pure and not induce any side-effects.

Warning

This interface is experimental, and might change without warning. Use the lower-level atomic_...! functions for a stable API, albeit one limited to natively-supported ops.

source
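
For example, building a histogram where multiple threads may update the same bin (a sketch; histogram_kernel! is an illustrative name):

    function histogram_kernel!(bins, data)
        i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
        if i <= length(data)
            CUDA.@atomic bins[data[i]] += 1   # values double as 1-based bin indices
        end
        return
    end

    data = CuArray(rand(1:16, 1024))
    bins = CUDA.zeros(Int, 16)
    @cuda threads=256 blocks=4 histogram_kernel!(bins, data)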

If your expression is not recognized, or you need more control, use the underlying functions:

CUDA.atomic_cas!Function
atomic_cas!(ptr::LLVMPtr{T}, cmp::T, val::T)

Reads the value old located at address ptr and compares it with cmp. If old equals cmp, stores val at the same address; otherwise, the value at ptr is left unchanged. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64. Additionally, on GPU hardware with compute capability 7.0+, values of type UInt16 are supported.

source
CUDA.atomic_xchg!Function
atomic_xchg!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr and stores val at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_add!Function
atomic_add!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes old + val, and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32, UInt64, and Float32. Additionally, on GPU hardware with compute capability 6.0+, values of type Float64 are supported.

source
CUDA.atomic_sub!Function
atomic_sub!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes old - val, and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_and!Function
atomic_and!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes old & val, and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_or!Function
atomic_or!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes old | val, and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_xor!Function
atomic_xor!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes old ⊻ val, and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_min!Function
atomic_min!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes min(old, val), and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_max!Function
atomic_max!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes max(old, val), and stores the result back to memory at the same address. These operations are performed in one atomic transaction. The function returns old.

This operation is supported for values of type Int32, Int64, UInt32 and UInt64.

source
CUDA.atomic_inc!Function
atomic_inc!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes ((old >= val) ? 0 : (old+1)), and stores the result back to memory at the same address. These three operations are performed in one atomic transaction. The function returns old.

This operation is only supported for values of type Int32.

source
CUDA.atomic_dec!Function
atomic_dec!(ptr::LLVMPtr{T}, val::T)

Reads the value old located at address ptr, computes (((old == 0) | (old > val)) ? val : (old-1) ), and stores the result back to memory at the same address. These three operations are performed in one atomic transaction. The function returns old.

This operation is only supported for values of type Int32.

source
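
The low-level functions operate on raw pointers; a sketch of incrementing the first element of an array from every thread:

    function atomic_add_kernel!(a)
        ptr = pointer(a, 1)        # LLVMPtr to the first element of the device array
        CUDA.atomic_add!(ptr, 1f0)
        return
    end

    a = CUDA.zeros(Float32, 1)
    @cuda threads=32 atomic_add_kernel!(a)   # all 32 threads increment a[1]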

Dynamic parallelism

Similarly to launching kernels from the host, you can use @cuda while passing dynamic=true for launching kernels from the device. A lower-level API is available as well:

CUDA.dynamic_cufunctionFunction
dynamic_cufunction(f, tt=Tuple{})

Low-level interface to compile a function invocation for the currently-active GPU, returning a callable kernel object. Device-side equivalent of CUDA.cufunction.

No keyword arguments are supported.

source
CUDA.DeviceKernelType
(::HostKernel)(args...; kwargs...)
(::DeviceKernel)(args...; kwargs...)

Low-level interface to call a compiled kernel, passing GPU-compatible arguments in args. For a higher-level interface, use @cuda.

A HostKernel is callable on the host, and a DeviceKernel is callable on the device (created by @cuda with dynamic=true).

The following keyword arguments are supported:

  • threads (default: 1): Number of threads per block, or a 1-, 2- or 3-tuple of dimensions (e.g. threads=(32, 32) for a 2D block of 32×32 threads). Use threadIdx() and blockDim() to query from within the kernel.
  • blocks (default: 1): Number of thread blocks to launch, or a 1-, 2- or 3-tuple of dimensions (e.g. blocks=(2, 4, 2) for a 3D grid of blocks). Use blockIdx() and gridDim() to query from within the kernel.
  • shmem (default: 0): Amount of dynamic shared memory in bytes to allocate per thread block; used by CuDynamicSharedArray.
  • stream (default: stream()): CuStream to launch the kernel on.
  • cooperative (default: false): whether to launch a cooperative kernel that supports grid synchronization (see CG.this_grid and CG.sync). Note that this requires care with respect to the number of blocks launched.
source
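
A sketch of a parent kernel launching a child kernel from the device (kernel names are illustrative):

    function child(a)
        a[threadIdx().x] += 1
        return
    end

    function parent(a)
        if threadIdx().x == 1
            @cuda dynamic=true threads=length(a) child(a)
        end
        return
    end

    a = CUDA.zeros(Int, 4)
    @cuda threads=1 parent(a)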

Cooperative groups

CUDA.CGModule

CUDA.jl's cooperative groups implementation.

Cooperative groups in CUDA offer a structured approach to synchronize and communicate among threads. They allow developers to define specific groups of threads, providing a means to fine-tune inter-thread communication granularity. By offering a more nuanced alternative to traditional CUDA synchronization methods, cooperative groups enable a more controlled and efficient parallel decomposition in kernel design.

The following functionality is available in CUDA.jl:

  • implicit groups: thread blocks, grid groups, and coalesced groups.
  • synchronization: sync, barrier_arrive, barrier_wait
  • warp collectives for coalesced groups: shuffle and voting
  • data transfer: memcpy_async, wait and wait_prior

Noteworthy missing functionality:

  • implicit groups: clusters, and multi-grid groups (which are deprecated)
  • explicit groups: tiling and partitioning
source

Group construction and properties

CUDA.CG.thread_rankFunction
thread_rank(group)

Returns the linearized rank of the calling thread along the interval [1, num_threads()].

source
CUDA.CG.thread_blockType
thread_block <: thread_group

Every GPU kernel is executed by a grid of thread blocks, and threads within each block are guaranteed to reside on the same streaming multiprocessor. A thread_block represents a thread block whose dimensions are not known until runtime.

Constructed via this_thread_block.

source
CUDA.CG.group_indexFunction
group_index(tb::thread_block)

3-Dimensional index of the block within the launched grid.

source
CUDA.CG.grid_groupType
grid_group <: thread_group

Threads within this group are guaranteed to be co-resident on the same device within the same launched kernel. To use this group, the kernel must have been launched with @cuda cooperative=true, and the device must support it (queryable via a device attribute).

Constructed via this_grid.

source
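
A sketch of grid-wide synchronization (requires a cooperative launch and device support):

    function grid_kernel(a)
        grid = CUDA.CG.this_grid()
        i = CUDA.CG.thread_rank(grid)
        a[i] += 1
        CUDA.CG.sync(grid)   # wait until every block in the grid has performed its update
        return
    end

    a = CUDA.zeros(Int, 64)
    @cuda threads=32 blocks=2 cooperative=true grid_kernel(a)
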
CUDA.CG.coalesced_groupType
coalesced_group <: thread_group

A group representing the current set of converged threads in a warp. The size of the group is not guaranteed; it may be as small as a single thread (the caller itself).

This group exposes warp-synchronous builtins. Constructed via coalesced_threads.

source
CUDA.CG.meta_group_sizeFunction
meta_group_size(cg::coalesced_group)

Total number of partitions created out of all CTAs when the group was created.

source

Synchronization

Data transfer

CUDA.CG.memcpy_asyncFunction
memcpy_async(group, dst, src, bytes)

Perform a group-wide collective memory copy from src to dst of bytes bytes. This operation may be performed asynchronously, so you should wait or wait_prior before using the data. It is only supported by thread blocks and coalesced groups.

For this operation to be performed asynchronously, the following conditions must be met:

  • the source and destination memory should be aligned to 4, 8 or 16 bytes. This will be deduced from the datatype, but can also be specified explicitly using CUDA.align.
  • the source should be global memory, and the destination should be shared memory.
  • the device should have compute capability 8.0 or higher.
source
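
A speculative sketch of staging global memory into dynamic shared memory (the argument types are assumed to follow the signature above; the exact requirements are listed in the bullet points):

    function stage_kernel(gmem)
        block = CUDA.CG.this_thread_block()
        smem = CuDynamicSharedArray(Float32, length(gmem))
        CUDA.CG.memcpy_async(block, pointer(smem), pointer(gmem),
                             length(gmem) * sizeof(Float32))
        CUDA.CG.wait(block)   # make sure the copy has completed before reading smem
        # ... use smem ...
        return
    end

    a = CUDA.rand(Float32, 256)
    @cuda threads=128 shmem=sizeof(a) stage_kernel(a)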

Math

Many mathematical functions are provided by the libdevice library, and are wrapped by CUDA.jl. These functions are used to implement well-known functions from the Julia standard library and packages like SpecialFunctions.jl, e.g., calling the cos function will automatically use __nv_cos from libdevice if possible.

Some functions do not have a counterpart in the Julia ecosystem; these have to be called directly. For example, to call __nv_logb or __nv_logbf you use CUDA.logb in a kernel.

For a list of available functions, look at src/device/intrinsics/math.jl.
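
For example (a sketch): cos lowers to the libdevice version automatically, while logb is called through the CUDA module:

    function math_kernel!(a)
        i = threadIdx().x
        a[i] = cos(a[i]) + CUDA.logb(a[i])
        return
    end

    a = CUDA.rand(Float32, 32)
    @cuda threads=32 math_kernel!(a)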

WMMA

Warp matrix multiply-accumulate (WMMA) is a CUDA API to access Tensor Cores, a hardware feature introduced with Volta GPUs to perform mixed-precision matrix multiply-accumulate operations. The interface is split into two levels, both available in the WMMA submodule: low-level wrappers around the LLVM intrinsics, and a higher-level API similar to that of CUDA C.

LLVM Intrinsics

Load matrix

CUDA.WMMA.llvm_wmma_loadFunction
WMMA.llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}(src_addr, stride)

Wrapper around the LLVM intrinsic @llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}.

Arguments

  • src_addr: The memory address to load from.
  • stride: The leading dimension of the matrix, in numbers of elements.

Placeholders

  • {matrix}: The matrix to load. Can be a, b or c.
  • {layout}: The storage layout for the matrix. Can be row or col, for row major (C style) or column major (Julia style), respectively.
  • {shape}: The overall shape of the MAC operation. Valid values are m16n16k16, m32n8k16, and m8n32k16.
  • {addr_space}: The address space of src_addr. Can be empty (generic addressing), shared or global.
  • {elem_type}: The type of each element in the matrix. For a and b matrices, valid values are u8 (byte unsigned integer), s8 (byte signed integer), and f16 (half precision floating point). For c and d matrices, valid values are s32 (32-bit signed integer), f16 (half precision floating point), and f32 (full precision floating point).
source

Perform multiply-accumulate

CUDA.WMMA.llvm_wmma_mmaFunction
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)

For floating point operations: wrapper around the LLVM intrinsic @llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}. For all other operations: wrapper around the LLVM intrinsic @llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}.

Arguments

  • a: The WMMA fragment corresponding to the matrix $A$.
  • b: The WMMA fragment corresponding to the matrix $B$.
  • c: The WMMA fragment corresponding to the matrix $C$.

Placeholders

  • {a_layout}: The storage layout for matrix $A$. Can be row or col, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
  • {b_layout}: The storage layout for matrix $B$. Can be row or col, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
  • {shape}: The overall shape of the MAC operation. Valid values are m16n16k16, m32n8k16, and m8n32k16.
  • {a_elem_type}: The type of each element in the $A$ matrix. Valid values are u8 (byte unsigned integer), s8 (byte signed integer), and f16 (half precision floating point).
  • {d_elem_type}: The type of each element in the resultant $D$ matrix. Valid values are s32 (32-bit signed integer), f16 (half precision floating point), and f32 (full precision floating point).
  • {c_elem_type}: The type of each element in the $C$ matrix. Valid values are s32 (32-bit signed integer), f16 (half precision floating point), and f32 (full precision floating point).
Warning

Remember that the shape, type and layout of all operations (be it MMA, load or store) MUST match. Otherwise, the behaviour is undefined!

source

Store matrix

CUDA.WMMA.llvm_wmma_storeFunction
WMMA.llvm_wmma_store_d_{layout}_{shape}_{addr_space}_stride_{elem_type}(dst_addr, data, stride)

Wrapper around the LLVM intrinsic @llvm.nvvm.wmma.store.d.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}.

Arguments

  • dst_addr: The memory address to store to.
  • data: The $D$ fragment to store.
  • stride: The leading dimension of the matrix, in numbers of elements.

Placeholders

  • {layout}: The storage layout for the matrix. Can be row or col, for row major (C style) or column major (Julia style), respectively.
  • {shape}: The overall shape of the MAC operation. Valid values are m16n16k16, m32n8k16, and m8n32k16.
  • {addr_space}: The address space of src_addr. Can be empty (generic addressing), shared or global.
  • {elem_type}: The type of each element in the matrix. For a and b matrices, valid values are u8 (byte unsigned integer), s8 (byte signed integer), and f16 (half precision floating point). For c and d matrices, valid values are s32 (32-bit signed integer), f16 (half precision floating point), and f32 (full precision floating point).
source

CUDA C-like API

Fragment

CUDA.WMMA.UnspecifiedType
WMMA.Unspecified

Type that represents a matrix stored in an unspecified order.

Warning

This storage format is not valid for all WMMA operations!

source
CUDA.WMMA.FragmentType
WMMA.Fragment

Type that represents per-thread intermediate results of WMMA operations.

You can access individual elements using the x member or [] operator, but beware that the exact ordering of elements is unspecified.

source

WMMA configuration

CUDA.WMMA.ConfigType
WMMA.Config{M, N, K, d_type}

Type that contains all information for WMMA operations that cannot be inferred from the arguments' types.

WMMA instructions calculate the matrix multiply-accumulate operation $D = A \cdot B + C$, where $A$ is a $M \times K$ matrix, $B$ a $K \times N$ matrix, and $C$ and $D$ are $M \times N$ matrices.

d_type refers to the type of the elements of matrix $D$, and can be either Float16 or Float32.

All WMMA operations take a Config as their final argument.

Examples

julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
source

Load matrix

CUDA.WMMA.load_aFunction
WMMA.load_a(addr, stride, layout, config)
WMMA.load_b(addr, stride, layout, config)
WMMA.load_c(addr, stride, layout, config)

Load the matrix a, b or c from the memory location indicated by addr, and return the resulting WMMA.Fragment.

Arguments

  • addr: The address to load the matrix from.
  • stride: The leading dimension of the matrix pointed to by addr, specified in number of elements.
  • layout: The storage layout of the matrix. Possible values are WMMA.RowMajor and WMMA.ColMajor.
  • config: The WMMA configuration that should be used for loading this matrix. See WMMA.Config.

See also: WMMA.Fragment, WMMA.FragmentLayout, WMMA.Config

Warning

All threads in a warp MUST execute the load operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour.

source

WMMA.load_b and WMMA.load_c have the same signature.

Perform multiply-accumulate

CUDA.WMMA.mmaFunction
WMMA.mma(a, b, c, conf)

Perform the matrix multiply-accumulate operation $D = A \cdot B + C$.

Arguments

  • a: The WMMA.Fragment corresponding to the matrix $A$.
  • b: The WMMA.Fragment corresponding to the matrix $B$.
  • c: The WMMA.Fragment corresponding to the matrix $C$.
  • conf: The WMMA configuration that should be used for this operation. See WMMA.Config.

Warning

All threads in a warp MUST execute the mma operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour.

source

Store matrix

CUDA.WMMA.store_dFunction
WMMA.store_d(addr, d, stride, layout, config)

Store the result matrix d to the memory location indicated by addr.

Arguments

  • addr: The address to store the matrix to.
  • d: The WMMA.Fragment corresponding to the d matrix.
  • stride: The leading dimension of the matrix pointed to by addr, specified in number of elements.
  • layout: The storage layout of the matrix. Possible values are WMMA.RowMajor and WMMA.ColMajor.
  • config: The WMMA configuration that should be used for storing this matrix. See WMMA.Config.

See also: WMMA.Fragment, WMMA.FragmentLayout, WMMA.Config

Warning

All threads in a warp MUST execute the store operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour.

source

Fill fragment

CUDA.WMMA.fill_cFunction
WMMA.fill_c(value, config)

Return a WMMA.Fragment filled with the value value.

This operation is useful if you want to implement a matrix multiplication (and thus want to set $C = 0$).

Arguments

  • value: The value used to fill the fragment. Can be a Float16 or Float32.
  • config: The WMMA configuration that should be used for this WMMA operation. See WMMA.Config.
source
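
Putting the pieces together, a sketch of a single 16×16×16 multiply-accumulate executed by one warp (the data layout and array setup are illustrative):

    function wmma_kernel(a, b, c, d)
        conf = WMMA.Config{16, 16, 16, Float32}
        a_frag = WMMA.load_a(pointer(a), 16, WMMA.ColMajor, conf)
        b_frag = WMMA.load_b(pointer(b), 16, WMMA.ColMajor, conf)
        c_frag = WMMA.load_c(pointer(c), 16, WMMA.ColMajor, conf)
        d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
        WMMA.store_d(pointer(d), d_frag, 16, WMMA.ColMajor, conf)
        return
    end

    a = CUDA.rand(Float16, 16, 16); b = CUDA.rand(Float16, 16, 16)
    c = CUDA.rand(Float32, 16, 16); d = CUDA.zeros(Float32, 16, 16)
    @cuda threads=32 wmma_kernel(a, b, c, d)   # one full warp performs the operation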

Other

CUDA.alignType
CUDA.align{N}(obj)

Construct an aligned object, providing alignment information to APIs that require it.

source