Compiler

Execution

The main entry-point to the compiler is the @cuda macro:

CUDACore.@cuda - Macro
@cuda [kwargs...] func(args...)

High-level interface for executing code on a GPU. The @cuda macro should prefix a call, with func a callable function or object that should return nothing. It will be compiled to a CUDA function upon first use, and to a certain extent arguments will be converted and managed automatically using cudaconvert. Finally, a call to cudacall is performed, scheduling a kernel launch on the current CUDA context.

Several keyword arguments are supported that influence the behavior of @cuda.

  • launch: whether to launch this kernel, defaults to true. If false, the kernel is only compiled; launch the returned kernel object by calling it and passing the arguments again.
  • dynamic: use dynamic parallelism to launch device-side kernels, defaults to false.
  • backend: which compiler backend to use, defaults to LLVMBackend. Either an AbstractBackend instance or a module that defines DefaultBackend() (e.g. backend=CUDA resolves to CUDA.DefaultBackend()). Backend-specific compiler kwargs not recognized by @cuda itself are forwarded to kernel_compile.
  • arguments that influence kernel compilation: see cufunction and dynamic_cufunction
  • arguments that influence kernel launch: see CUDACore.HostKernel and CUDACore.DeviceKernel
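
As a minimal sketch of the two launch styles described above, the following compiles and launches a kernel in one step, then repeats the launch through launch=false. The kernel vadd! and the array sizes are hypothetical and chosen only for illustration; a functional CUDA device is assumed.

```julia
using CUDA

# Hypothetical kernel for illustration: element-wise addition, returns nothing.
function vadd!(c, a, b)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return nothing
end

a = CUDA.rand(Float32, 1024)
b = CUDA.rand(Float32, 1024)
c = similar(a)

# Compile (on first use) and launch in one step: 4 blocks of 256 threads.
@cuda threads=256 blocks=4 vadd!(c, a, b)

# Compile only, then launch by calling the kernel object with the arguments again.
kernel = @cuda launch=false vadd!(c, a, b)
kernel(c, a, b; threads=256, blocks=4)
```

The launch=false form is useful when the launch configuration should depend on properties of the compiled kernel.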

If needed, you can use a lower-level API that lets you inspect the compiled kernel:

CUDACore.cudaconvert - Function
cudaconvert(x)

This function is called for every argument to be passed to a kernel, allowing it to be converted to a GPU-friendly format. By default, the function does nothing and returns the input object x as-is.

Do not add methods to this function; instead, extend the underlying Adapt.jl package and register methods for the CUDA.KernelAdaptor type.
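
One way to follow this advice is sketched below, assuming Adapt.jl is installed in the active environment. The wrapper type Scaled and the kernel scale! are hypothetical. The generic Adapt.adapt_structure method also applies when the adaptor is the kernel adaptor, so cudaconvert can rebuild the wrapper around a device-side array.

```julia
using CUDA, Adapt

# Hypothetical wrapper type for illustration: an array plus a scale factor.
struct Scaled{A}
    data::A
    factor::Float32
end

# Tell Adapt.jl how to rebuild the wrapper around its converted fields.
Adapt.adapt_structure(to, s::Scaled) = Scaled(adapt(to, s.data), s.factor)

# Hypothetical kernel that scales the wrapped array in place.
function scale!(s)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(s.data)
        @inbounds s.data[i] *= s.factor
    end
    return nothing
end

s = Scaled(CUDA.ones(Float32, 256), 2f0)

# @cuda converts `s` through cudaconvert, which now knows how to handle Scaled.
@cuda threads=256 scale!(s)
```

Without the adapt_structure method, the wrapped CuArray would reach the kernel unconverted and compilation would be expected to fail, since it is not a GPU-compatible argument.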

CUDACore.cufunction - Function
cufunction(f, tt=Tuple{}; kwargs...)

Low-level interface to compile a function invocation for the currently-active GPU, returning a callable kernel object. For a higher-level interface, use @cuda.

The following keyword arguments are supported:

  • minthreads: the required number of threads in a thread block
  • maxthreads: the maximum number of threads in a thread block
  • blocks_per_sm: a minimum number of thread blocks to be scheduled on a single multiprocessor
  • maxregs: the maximum number of registers to be allocated to a single thread (only supported on LLVM 4.0+)
  • name: override the name that the kernel will have in the generated code
  • always_inline: inline all function calls in the kernel
  • fastmath: use less precise square roots and flush denormals
  • arch and ptx: override the GPU architecture (matching nvcc/ptxas -arch) and the PTX ISA version to compile for. arch accepts either an SMVersion via the sm"..." string macro (e.g. arch=sm"103a" for architecture-accelerated codegen on CC 10.3, or arch=sm"100f" for family-portable Blackwell codegen) or a VersionNumber (e.g. arch=v"10.3", treated as baseline / forward-compatible). The old kwarg name cap= is accepted as a deprecated alias.

The output of this function is automatically cached, i.e. you can simply call cufunction in a hot path without degrading performance. New code will be generated automatically when the function changes, or when different types or keyword arguments are provided.
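
The low-level path can be sketched roughly as follows. The kernel fill_one! is hypothetical, and the way the argument type tuple is derived (converting the arguments with cudaconvert and taking their types) is an assumption about typical usage rather than a requirement stated here.

```julia
using CUDA

# Hypothetical kernel for illustration: each thread writes one element.
function fill_one!(a)
    i = threadIdx().x
    @inbounds a[i] = 1f0
    return nothing
end

a = CUDA.zeros(Float32, 32)

# Derive the device-side argument types by converting the host arguments.
kernel_args = map(cudaconvert, (a,))
kernel_tt = Tuple{map(Core.Typeof, kernel_args)...}

# Compile for the active device; maxthreads is one of the keyword arguments listed above.
kernel = cufunction(fill_one!, kernel_tt; maxthreads=32)

# Launch the compiled kernel object.
kernel(a; threads=32)
```

Because the result is cached, repeating the cufunction call with the same function, types, and keyword arguments should not trigger a recompilation.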

CUDACore.AbstractKernel - Type
(::HostKernel)(args...; kwargs...)
(::DeviceKernel)(args...; kwargs...)

Low-level interface to call a compiled kernel, passing GPU-compatible arguments in args. For a higher-level interface, use @cuda.

A HostKernel is callable on the host, and a DeviceKernel is callable on the device (created by @cuda with dynamic=true).

The following keyword arguments are supported:

  • threads (default: 1): Number of threads per block, or a 1-, 2- or 3-tuple of dimensions (e.g. threads=(32, 32) for a 2D block of 32×32 threads). Use threadIdx() and blockDim() to query from within the kernel.
  • blocks (default: 1): Number of thread blocks to launch, or a 1-, 2- or 3-tuple of dimensions (e.g. blocks=(2, 4, 2) for a 3D grid of blocks). Use blockIdx() and gridDim() to query from within the kernel.
  • clustersize (default: 1): Number of thread blocks to launch as a cooperative cluster, or a 1-, 2- or 3-tuple of dimensions (e.g. clustersize=(2, 2, 2) for a 3D cluster of 2×2×2 blocks). Use clusterIdx() and clusterDim() to query from within the kernel. Only supported on compute capability 9.0 and above. If clustersize=1, no clusters are launched.
  • shmem (default: 0): Amount of dynamic shared memory in bytes to allocate per thread block; used by CuDynamicSharedArray.
  • stream (default: stream()): CuStream to launch the kernel on.
  • cooperative (default: false): whether to launch a cooperative kernel that supports grid synchronization (see CG.this_grid and CG.sync). Note that this requires care with respect to the number of blocks launched.
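
To illustrate the tuple forms of threads and blocks, here is a small sketch of a 2D launch. The kernel fill_idx! and the matrix size are hypothetical; the grid is sized with cld so that it covers the whole matrix, and the kernel guards against out-of-range indices.

```julia
using CUDA

# Hypothetical kernel for illustration: encode each element's (row, column) position.
function fill_idx!(A)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    if i <= size(A, 1) && j <= size(A, 2)
        @inbounds A[i, j] = i + 100j
    end
    return nothing
end

A = CUDA.zeros(Int, 64, 48)

threads = (16, 16)                # 2D block of 16×16 threads
blocks = cld.(size(A), threads)   # enough blocks to cover every element

@cuda threads=threads blocks=blocks fill_idx!(A)
```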

CUDACore.version - Function
version(k::HostKernel)

Queries the PTX and SM versions a kernel was compiled for. Returns a named tuple.

CUDACore.maxthreads - Function
maxthreads(k::HostKernel)

Queries the maximum number of threads a kernel can use in a single block.

CUDACore.memory - Function
memory(k::HostKernel)

Queries the local, shared and constant memory usage of a compiled kernel in bytes. Returns a named tuple.
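
These queries combine naturally with launch=false. The sketch below assumes the functions are reachable through the CUDA module and uses a hypothetical kernel vadd!; launch_configuration is the occupancy helper from CUDA.jl and is used here to pick a block size.

```julia
using CUDA

# Hypothetical kernel for illustration: element-wise addition.
function vadd!(c, a, b)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return nothing
end

a = CUDA.rand(Float32, 10_000)
b = CUDA.rand(Float32, 10_000)
c = similar(a)

# Compile without launching, then inspect the compiled kernel.
kernel = @cuda launch=false vadd!(c, a, b)
ver = CUDA.version(kernel)            # PTX and SM versions, as a named tuple
limit = CUDA.maxthreads(kernel)       # maximum threads per block for this kernel
mem = CUDA.memory(kernel)             # local, shared and constant memory in bytes

# Pick a launch configuration from the occupancy API and launch.
config = launch_configuration(kernel.fun)
threads = min(length(c), config.threads)
blocks = cld(length(c), threads)
kernel(c, a, b; threads, blocks)
```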


The PTX compilation target is identified by an SMVersion, constructed via the sm"..." string macro:

CUDACore.SMVersion - Type
SMVersion(major, minor, [feature_set])
SMVersion(s::AbstractString)
SMVersion(v::VersionNumber)
SMVersion(sm::SMVersion)

A PTX compilation target, identifying a CUDA compute capability together with the subtarget feature set selected by the suffix on its .target directive. Printed and parsed in NVIDIA's compact form – sm"90" for compute capability 9.0, sm"103a" for 10.3 architecture-accelerated, etc. – to mirror the .target sm_NN[a|f] notation in the PTX ISA reference and to distinguish visually from a device-level VersionNumber like v"9.0".

The single-argument constructors normalize various inputs to an SMVersion:

  • SMVersion(::AbstractString) parses the compact form, with or without the sm_ prefix (so e.g. SMVersion("sm_103a") and SMVersion("103a") both work).
  • SMVersion(::VersionNumber) promotes a plain compute-capability version to a baseline SMVersion (SMVersion(v"10.3") == SMVersion(10, 3, :baseline)).
  • SMVersion(::SMVersion) is the identity (idempotent).

This is what lets @cuda arch=... accept v"10.3", sm"103a", "sm_103a", or an already-constructed SMVersion interchangeably.
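
The normalization rules above can be checked directly, as in this small sketch. It assumes SMVersion and the sm"..." string macro are available after using CUDA (the unqualified REPL examples in this docstring suggest they are; otherwise qualify them with CUDACore).

```julia
using CUDA  # assumption: SMVersion and @sm_str are exported

# The string constructor accepts the compact form with or without the sm_ prefix.
same_parse = SMVersion("sm_103a") == SMVersion("103a") == sm"103a"

# A plain VersionNumber is promoted to a baseline target.
promoted = SMVersion(v"10.3") == SMVersion(10, 3, :baseline)

# Passing an SMVersion returns an equal value.
identity_ok = SMVersion(sm"90a") == sm"90a"

# The public fields expose the components of the target.
target = sm"100f"
parts = (target.major, target.minor, target.feature_set)
```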

feature_set is one of:

  • :baseline (no suffix, e.g. sm_90) — forward-compatible (the "onion model"): PTX compiled for sm_X runs on any sm_Y with Y >= X.
  • :family (f suffix, e.g. sm_100f) — same-major-family-portable: PTX runs on any device in the same architecture family (currently == same major version) at or above this CC.
  • :arch (a suffix, e.g. sm_90a) — locked to one exact CC: PTX runs only on devices with exactly this compute capability, but in exchange gets access to architecture-accelerated features.

See NVIDIA's PTX ISA reference under .target for the full compatibility rules, and lib/Target/NVPTX/NVPTX.td in LLVM for the corresponding subtarget feature definitions.

Public fields:

  • sm.major::Int
  • sm.minor::Int
  • sm.feature_set::Symbol

See also @sm_str for an ergonomic string-macro constructor.

Examples

julia> SMVersion(9, 0)            # baseline
sm"90"

julia> SMVersion(9, 0, :arch)
sm"90a"

julia> sm"100f" == SMVersion(10, 0, :family)
true

CUDACore.@sm_str - Macro
@sm_str

String macro used to parse a string to an SMVersion. Accepts NVIDIA's compact sm_NN[a|f] notation (with or without the sm_ prefix): sm"90" for baseline, sm"90a" for architecture-accelerated, sm"100f" for family-specific. Equivalent to calling SMVersion(str); parses at macro-expansion time, so the resulting SMVersion is a compile-time constant in the surrounding expression.

Examples

julia> sm"103a"
sm"103a"

julia> sm"100f" == SMVersion(10, 0, :family)
true

To plug in alternative compiler back-ends (e.g. cuTile.jl), @cuda dispatches through a small protocol:

CUDACore.AbstractBackend - Type
AbstractBackend

Abstract supertype for @cuda backend dispatch. The default backend is LLVMBackend, which compiles SIMT/PTX kernels via cufunction. Other backends (e.g. Tile IR via cuTile.jl) register a subtype and define methods for kernel_convert and kernel_compile; @cuda backend=... then routes through them.

@cuda backend=... accepts either an AbstractBackend instance or a module that defines DefaultBackend() returning one (e.g. @cuda backend=cuTile ... resolves to cuTile.DefaultBackend()).

CUDACore.DefaultBackend - Function
DefaultBackend()

Returns the default @cuda backend for this module (LLVMBackend). This makes @cuda backend=CUDA ... (or backend=CUDACore) resolve to LLVMBackend, mirroring the convention used by other backend packages (e.g. @cuda backend=cuTile ... resolves to cuTile.DefaultBackend()).

CUDACore.kernel_convert - Function
kernel_convert(backend, x)

Convert a host-side launch argument to its kernel-side form. The default implementation for LLVMBackend forwards to cudaconvert; other backends override to produce backend-specific argument types.


Reflection

If you want to inspect generated code, you can use macros that resemble functionality from the InteractiveUtils standard library:

@device_code_lowered
@device_code_typed
@device_code_warntype
@device_code_llvm
@device_code_ptx
@device_code_sass
@device_code

These macros are also available in function form:

CUDA.code_typed
CUDA.code_warntype
CUDA.code_llvm
CUDA.code_ptx
CUDA.code_sass
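
As a rough sketch of both styles, the macro form below wraps a kernel launch and prints the PTX generated while it runs, whereas the function form takes the function and its device-side argument types. The kernel vadd! is hypothetical, and deriving the type tuple via cudaconvert is an assumption about typical usage.

```julia
using CUDA

# Hypothetical kernel for illustration: element-wise addition.
function vadd!(c, a, b)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return nothing
end

a = CUDA.rand(Float32, 1024)
b = CUDA.rand(Float32, 1024)
c = similar(a)

# Macro form: prints the PTX for kernels compiled while the expression executes.
@device_code_ptx @cuda threads=256 blocks=4 vadd!(c, a, b)

# Function form: pass the function and the device-side argument types.
tt = Tuple{map(Core.Typeof, map(cudaconvert, (c, a, b)))...}
io = IOBuffer()
CUDA.code_ptx(io, vadd!, tt)
ptx = String(take!(io))
```

Writing to an IOBuffer, as in the function form here, makes it possible to search or save the generated code instead of printing it.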

For more information, please consult the GPUCompiler.jl documentation. Only the code_sass functionality is actually defined in CUDA.jl:

CUDATools.code_sass - Function
code_sass([io], f, types; raw=false)
code_sass(f, [io]; raw=false)

Prints the SASS code corresponding to one or more CUDA modules to io, which defaults to stdout.

If providing both f and types, it is assumed that this uniquely identifies a kernel function, for which SASS code will be generated, and printed to io.

If only providing a callable function f, typically specified using the do syntax, the SASS code for all modules executed during evaluation of f will be printed. This can be convenient to display the SASS code for functions whose source code is not available.

  • raw: dump the assembly like nvdisasm reports it, without post-processing;
  • in the case of specifying f and types: all keyword arguments from cufunction

See also: @device_code_sass
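
The do-block form might be used as in the following sketch, which captures the SASS into a buffer instead of printing it. The kernel fill_one! is hypothetical, and the example assumes that the profiling support this functionality relies on is available on the system.

```julia
using CUDA

# Hypothetical kernel for illustration: each thread writes one element.
function fill_one!(a)
    i = threadIdx().x
    @inbounds a[i] = 1f0
    return nothing
end

a = CUDA.zeros(Float32, 32)

# Collect the SASS of all modules executed inside the do block.
io = IOBuffer()
CUDA.code_sass(io) do
    @cuda threads=32 fill_one!(a)
end
sass = String(take!(io))
```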
