cuDNN


cuDNN.cudnnDropoutSeed (Constant)

The user can set the global seed cudnnDropoutSeed[] to a positive number to always drop the same values deterministically for debugging. Note that this slows down the operation by about 40x. See cudnnDropoutForward.

cuDNN.cudnnDropoutState (Constant)

The global constant cudnnDropoutState::Dict holds the random number generator state for each cuDNN handle. See cudnnDropoutForward.

cuDNN.cudnnAttnDescriptor (Type)
cudnnAttnDescriptor(attnMode::Cuint,
                    nHeads::Cint,
                    smScaler::Cdouble,
                    dataType::cudnnDataType_t,
                    computePrec::cudnnDataType_t,
                    mathType::cudnnMathType_t,
                    attnDropoutDesc::cudnnDropoutDescriptor_t,
                    postDropoutDesc::cudnnDropoutDescriptor_t,
                    qSize::Cint,
                    kSize::Cint,
                    vSize::Cint,
                    qProjSize::Cint,
                    kProjSize::Cint,
                    vProjSize::Cint,
                    oProjSize::Cint,
                    qoMaxSeqLength::Cint,
                    kvMaxSeqLength::Cint,
                    maxBatchSize::Cint,
                    maxBeamSize::Cint)

cuDNN.cudnnCTCLossDescriptor (Type)
cudnnCTCLossDescriptor(compType::cudnnDataType_t,
                       normMode::cudnnLossNormalizationMode_t,
                       gradMode::cudnnNanPropagation_t,
                       maxLabelLength::Cint)

cuDNN.cudnnConvolutionDescriptor (Type)

cudnnConvolutionDescriptor(pad::Vector{Cint},
                           stride::Vector{Cint},
                           dilation::Vector{Cint},
                           mode::cudnnConvolutionMode_t,
                           dataType::cudnnDataType_t,
                           groupCount::Cint,
                           mathType::cudnnMathType_t,
                           reorderType::cudnnReorderType_t)

cuDNN.cudnnPoolingDescriptor (Type)
cudnnPoolingDescriptor(mode::cudnnPoolingMode_t,
                       maxpoolingNanOpt::cudnnNanPropagation_t,
                       nbDims::Cint,
                       windowDimA::Vector{Cint},
                       paddingA::Vector{Cint},
                       strideA::Vector{Cint})

cuDNN.cudnnRNNDataDescriptor (Type)
cudnnRNNDataDescriptor(dataType::cudnnDataType_t,
                       layout::cudnnRNNDataLayout_t,
                       maxSeqLength::Cint,
                       batchSize::Cint,
                       vectorSize::Cint,
                       seqLengthArray::Vector{Cint},
                       paddingFill::Ptr{Cvoid})

cuDNN.cudnnRNNDescriptor (Type)
cudnnRNNDescriptor(algo::cudnnRNNAlgo_t,
                   cellMode::cudnnRNNMode_t,
                   biasMode::cudnnRNNBiasMode_t,
                   dirMode::cudnnDirectionMode_t,
                   inputMode::cudnnRNNInputMode_t,
                   dataType::cudnnDataType_t,
                   mathPrec::cudnnDataType_t,
                   mathType::cudnnMathType_t,
                   inputSize::Int32,
                   hiddenSize::Int32,
                   projSize::Int32,
                   numLayers::Int32,
                   dropoutDesc::cudnnDropoutDescriptor_t,
                   auxFlags::UInt32)

cuDNN.cudnnReduceTensorDescriptor (Type)
cudnnReduceTensorDescriptor(reduceTensorOp::cudnnReduceTensorOp_t,
                            reduceTensorCompType::cudnnDataType_t,
                            reduceTensorNanOpt::cudnnNanPropagation_t,
                            reduceTensorIndices::cudnnReduceTensorIndices_t,
                            reduceTensorIndicesType::cudnnIndicesType_t)

cuDNN.cudnnSeqDataDescriptor (Type)
cudnnSeqDataDescriptor(dataType::cudnnDataType_t,
                       nbDims::Cint,
                       dimA::Vector{Cint},
                       axes::Vector{cudnnSeqDataAxis_t},
                       seqLengthArraySize::Csize_t,
                       seqLengthArray::Vector{Cint},
                       paddingFill::Ptr{Cvoid})

cuDNN.cudnnTensorTransformDescriptor (Type)
cudnnTensorTransformDescriptor(nbDims::UInt32,
                               destFormat::cudnnTensorFormat_t,
                               padBeforeA::Vector{Int32},
                               padAfterA::Vector{Int32},
                               foldA::Vector{UInt32},
                               direction::cudnnFoldingDirection_t)

cuDNN.cudnnActivationForward (Function)
cudnnActivationForward(x; mode, nanOpt, coef, alpha)
cudnnActivationForward(x, d::cudnnActivationDescriptor; alpha)
cudnnActivationForward!(y, x; mode, nanOpt, coef, alpha, beta)
cudnnActivationForward!(y, x, d::cudnnActivationDescriptor; alpha, beta)

Return the result of the specified elementwise activation operation applied to x. Optionally y holds the result and d specifies the operation. y should be similar to x if specified. Keyword arguments alpha=1, beta=0 can be used for scaling, i.e. y .= alpha * op.(x) .+ beta * y. The following keyword arguments specify the operation if d is not given:

  • mode = CUDNN_ACTIVATION_RELU: Options are SIGMOID, RELU, TANH, CLIPPED_RELU, ELU, IDENTITY
  • nanOpt = CUDNN_NOT_PROPAGATE_NAN: NaN propagation policy, the other option is CUDNN_PROPAGATE_NAN
  • coef=1: When the activation mode is set to CUDNN_ACTIVATION_CLIPPED_RELU, this input specifies the clipping threshold; and when the activation mode is set to CUDNN_ACTIVATION_ELU, this input specifies the alpha parameter.
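
A minimal usage sketch (array sizes are made up; assumes a functional CUDA device with CUDA.jl and the cuDNN wrapper loaded):

using CUDA, cuDNN
using cuDNN: cudnnActivationForward, CUDNN_ACTIVATION_SIGMOID

x = CUDA.randn(Float32, 8, 8, 4, 2)   # a 4-D WHCN tensor
y = cudnnActivationForward(x)         # defaults to CUDNN_ACTIVATION_RELU
z = cudnnActivationForward(x; mode = CUDNN_ACTIVATION_SIGMOID)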

cuDNN.cudnnActivationForward! (Function)

See cudnnActivationForward.

cuDNN.cudnnAddTensor (Function)
cudnnAddTensor(x, b; alpha)
cudnnAddTensor!(y, x, b; alpha, beta)

Broadcast-add tensor b to tensor x. alpha=1, beta=1 are used for scaling, i.e. y .= alpha * b .+ beta * x. cudnnAddTensor allocates a new array for the answer, cudnnAddTensor! overwrites y. Does not support all valid broadcasting dimensions. For more flexible broadcast operations see cudnnOpTensor.

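A minimal sketch of per-channel bias addition (array sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnAddTensor

x = CUDA.randn(Float32, 8, 8, 4, 2)   # (W,H,C,N)
b = CUDA.randn(Float32, 1, 1, 4, 1)   # one value per channel
y = cudnnAddTensor(x, b)              # y[w,h,c,n] == x[w,h,c,n] + b[1,1,c,1]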

cuDNN.cudnnAddTensor! (Function)

See cudnnAddTensor.

cuDNN.cudnnConvolutionBwdDataAlgoPerf (Function)
cudnnConvolutionBwdDataAlgoPerf(wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, allocateTmpBuf=true)

allocateTmpBuf controls whether a temporary buffer is allocated for the input gradient dx. It can be set to false when beta is zero to save an allocation and must otherwise be set to true.

cuDNN.cudnnConvolutionBwdFilterAlgoPerf (Function)
cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, dyDesc, dy, convDesc, dwDesc, dw, allocateTmpBuf=true)

allocateTmpBuf controls whether a temporary buffer is allocated for the weight gradient dw. It can be set to false when beta is zero to save an allocation and must otherwise be set to true.

cuDNN.cudnnConvolutionForward (Function)
cudnnConvolutionForward(w, x; bias, activation, mode, padding, stride, dilation, group, mathType, reorderType, alpha, beta, z, format)
cudnnConvolutionForward(w, x, d::cudnnConvolutionDescriptor; bias, activation, alpha, beta, z, format)
cudnnConvolutionForward!(y, w, x; bias, activation, mode, padding, stride, dilation, group, mathType, reorderType, alpha, beta, z, format)
cudnnConvolutionForward!(y, w, x, d::cudnnConvolutionDescriptor; bias, activation, alpha, beta, z, format)

Return the convolution of filter w with tensor x, overwriting y if provided, according to keyword arguments or the convolution descriptor d. Optionally perform bias addition, activation and/or scaling:

y .= activation.(alpha * conv(w,x) + beta * z .+ bias)

All tensors should have the same number of dimensions. If they are less than 4-D their dimensions are assumed to be padded on the left with ones. x has size (X...,Cx,N) where (X...) are the spatial dimensions, Cx is the number of input channels, and N is the number of instances. y,z have size (Y...,Cy,N) where (Y...) are the spatial dimensions and Cy is the number of output channels (y and z can be the same array). Both Cx and Cy have to be an exact multiple of group. w has size (W...,Cx÷group,Cy) where (W...) are the filter dimensions. bias has size (1...,Cy,1).

The arguments padding, stride and dilation can be specified as n-2 dimensional vectors, tuples or a single integer which is assumed to be repeated n-2 times. If any of the entries is larger than the corresponding x dimension, the x dimension is used instead. For a description of different types of convolution see: https://towardsdatascience.com/a-comprehensive-introduction-to-different-types-of-convolutions-in-deep-learning-669281e58215

Keyword arguments:

  • activation = CUDNN_ACTIVATION_IDENTITY: the only other supported option is CUDNN_ACTIVATION_RELU
  • bias = nothing: add bias if provided
  • z = nothing: add beta*z, z can be nothing, y or another array similar to y
  • alpha = 1, beta = 0: scaling parameters
  • format = CUDNN_TENSOR_NCHW: order of tensor dimensions, the other alternative is CUDNN_TENSOR_NHWC. Note that Julia dimensions will have the opposite order, i.e. WHCN or CWHN.

Keyword arguments describing the convolution when d is not given:

  • mode = CUDNN_CONVOLUTION: alternatively CUDNN_CROSS_CORRELATION
  • padding = 0: padding assumed around x
  • stride = 1: how far to shift the convolution window at each step
  • dilation = 1: dilation factor
  • group = 1: number of groups to be used
  • mathType = cuDNN.math_mode(): whether or not the use of tensor op is permitted
  • reorderType = CUDNN_DEFAULT_REORDER: convolution reorder type
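
A minimal sketch (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnConvolutionForward

x = CUDA.randn(Float32, 28, 28, 3, 8)    # (W,H,Cx,N)
w = CUDA.randn(Float32, 5, 5, 3, 16)     # (W,H,Cx÷group,Cy)
y = cudnnConvolutionForward(w, x)                # no padding: size (24,24,16,8)
z = cudnnConvolutionForward(w, x; padding = 2)   # same-size output: (28,28,16,8)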

cuDNN.cudnnConvolutionForward! (Function)

See cudnnConvolutionForward.

cuDNN.cudnnConvolutionFwdAlgoPerf (Function)
cudnnConvolutionFwdAlgoPerf(xDesc, x, wDesc, w, convDesc, yDesc, y, biasDesc, activation, allocateTmpBuf=true)

allocateTmpBuf controls whether a temporary buffer is allocated for the output y. It can be set to false when beta is zero to save an allocation and must otherwise be set to true.

cuDNN.cudnnDropoutForward (Function)
cudnnDropoutForward(x; dropout=0.5)
cudnnDropoutForward(x, d::cudnnDropoutDescriptor)
cudnnDropoutForward!(y, x; dropout=0.5)
cudnnDropoutForward!(y, x, d::cudnnDropoutDescriptor)

Return a new array similar to x where approximately dropout fraction of the values are replaced by a 0, and the rest are scaled by 1/(1-dropout). Optionally y holds the result and d specifies the operation. y should be similar to x if specified.

The user can set the global seed cudnnDropoutSeed[] to a positive number to always drop the same values deterministically for debugging. Note that this slows down the operation by about 40x.

The global constant cudnnDropoutState::Dict holds the random number generator state for each cuDNN handle.

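A minimal sketch (assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnDropoutForward

x = CUDA.ones(Float32, 4, 4)
y = cudnnDropoutForward(x; dropout = 0.5)   # roughly half zeros, the rest scaled to 2.0f0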

cuDNN.cudnnDropoutForward! (Function)

See cudnnDropoutForward.

cuDNN.cudnnGetRNNWeightParams (Method)
cudnnGetRNNWeightParams(w, d::cudnnRNNDescriptor)
cudnnGetRNNWeightParams(w; hiddenSize, o...)

Return an array of weight matrices and bias vectors of an RNN specified by d or keyword options as views into w. The keyword arguments and defaults in the second form are the same as those in cudnnRNNForward specifying the RNN.

In the returned array a[1,l,p] and a[2,l,p] give the weight matrix and bias vector for the l'th layer and p'th parameter, or nothing if the specified matrix/vector does not exist. Note that the matrices should be transposed for left multiplication, e.g. a[1,l,p]' * x.

The l index refers to the pseudo-layer number. In uni-directional RNNs, a pseudo-layer is the same as a physical layer (pseudoLayer=1 is the RNN input layer, pseudoLayer=2 is the first hidden layer). In bi-directional RNNs, there are twice as many pseudo-layers as physical layers:

  • pseudoLayer=1 refers to the forward direction sub-layer of the physical input layer
  • pseudoLayer=2 refers to the backward direction sub-layer of the physical input layer
  • pseudoLayer=3 is the forward direction sub-layer of the first hidden layer, and so on

The p index refers to the weight matrix or bias vector linear ID index.

If cellMode in rnnDesc was set to CUDNN_RNN_RELU or CUDNN_RNN_TANH:

  • Value 1 references the weight matrix or bias vector used in conjunction with the input from the previous layer or input to the RNN model.
  • Value 2 references the weight matrix or bias vector used in conjunction with the hidden state from the previous time step or the initial hidden state.

If cellMode in rnnDesc was set to CUDNN_LSTM:

  • Values 1, 2, 3 and 4 reference weight matrices or bias vectors used in conjunction with the input from the previous layer or input to the RNN model.
  • Values 5, 6, 7 and 8 reference weight matrices or bias vectors used in conjunction with the hidden state from the previous time step or the initial hidden state.
  • Value 9 corresponds to the projection matrix, if enabled (there is no bias in this operation).

Values and their LSTM gates:

  • Values 1 and 5 correspond to the input gate.
  • Values 2 and 6 correspond to the forget gate.
  • Values 3 and 7 correspond to the new cell state calculations with hyperbolic tangent.
  • Values 4 and 8 correspond to the output gate.

If cellMode in rnnDesc was set to CUDNN_GRU:

  • Values 1, 2 and 3 reference weight matrices or bias vectors used in conjunction with the input from the previous layer or input to the RNN model.
  • Values 4, 5 and 6 reference weight matrices or bias vectors used in conjunction with the hidden state from the previous time step or the initial hidden state.

Values and their GRU gates:

  • Values 1 and 4 correspond to the reset gate.
  • Values 2 and 5 correspond to the update gate.
  • Values 3 and 6 correspond to the new hidden state calculations with hyperbolic tangent.
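
A hedged sketch (the hiddenSize value is made up, and w is assumed to be the packed weight buffer used with cudnnRNNForward for the same configuration; it is not constructed here):

using CUDA, cuDNN
using cuDNN: cudnnGetRNNWeightParams

a = cudnnGetRNNWeightParams(w; hiddenSize = 64)   # w: RNN weight buffer, assumed given
Wx = a[1,1,1]   # weight matrix for layer 1, parameter 1 (transpose for left multiplication)
bx = a[2,1,1]   # corresponding bias vector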

cuDNN.cudnnMultiHeadAttnForward (Function)
cudnnMultiHeadAttnForward(weights, queries, keys, values; o...)
cudnnMultiHeadAttnForward!(out, weights, queries, keys, values; o...)
cudnnMultiHeadAttnForward(weights, queries, keys, values, d::cudnnAttnDescriptor; o...)
cudnnMultiHeadAttnForward!(out, weights, queries, keys, values, d::cudnnAttnDescriptor; o...)

Return the multi-head attention result with weights, queries, keys, and values, overwriting out if provided, according to keyword arguments or the attention descriptor d. The multi-head attention model can be described by the following equations:

\[\begin{aligned} h_i &= (W_{V,i} V) \operatorname{softmax}\left(\mathrm{smScaler}\,(K^T W_{K,i}^T)(W_{Q,i}\, q)\right) \\ \operatorname{MultiHeadAttn}(q,K,V,W_Q,W_K,W_V,W_O) &= \sum_{i=1}^{\mathrm{nHeads}} W_{O,i}\, h_i \end{aligned}\]

The input arguments are:

  • out: Optional output tensor.
  • weights: A weight buffer that contains $W_Q, W_K, W_V, W_O$.
  • queries: A query tensor $Q$ which may contain a batch of queries (the above equations were for a single query vector $q$ for simplicity).
  • keys: The keys tensor $K$.
  • values: The values tensor $V$.

Keyword arguments describing the tensors:

  • axes::Vector{cudnnSeqDataAxis_t} = [CUDNN_SEQDATA_VECT_DIM, CUDNN_SEQDATA_BATCH_DIM, CUDNN_SEQDATA_TIME_DIM, CUDNN_SEQDATA_BEAM_DIM]: an array of length 4 that specifies the role of (Julia) dimensions. VECT has to be the first dimension, all 6 permutations of the remaining three are supported.
  • seqLengthsQO::Vector{<:Integer}: sequence lengths in the queries and out containers. By default sequences are assumed to be full length of the TIME dimension.
  • seqLengthsKV::Vector{<:Integer}: sequence lengths in the keys and values containers. By default sequences are assumed to be full length of the TIME dimension.

Keyword arguments describing the attention operation when d is not given:

  • attnMode::Unsigned = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE | CUDNN_ATTN_DISABLE_PROJ_BIASES: bitwise flags indicating various attention options. See cudnn docs for details.
  • nHeads::Integer = 1: number of attention heads.
  • smScaler::Real = 1: softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient. Negative values are not accepted.
  • mathType::cudnnMathType_t = math_mode(): NVIDIA Tensor Core settings.
  • qProjSize, kProjSize, vProjSize, oProjSize: vector lengths after projections, set to 0 by default which disables projections.
  • qoMaxSeqLength::Integer: largest sequence length expected in queries and out, set to their TIME dim by default.
  • kvMaxSeqLength::Integer: largest sequence length expected in keys and values, set to their TIME dim by default.
  • maxBatchSize::Integer: largest batch size expected in any container, set to the BATCH dim of queries by default.
  • maxBeamSize::Integer: largest beam size expected in any container, set to the BEAM dim of queries by default.

Other keyword arguments:

  • residuals = nothing: optional tensor with the same size as queries that can be used to implement residual connections (see figure in cudnn docs). When residual connections are enabled, the vector length in queries should match the vector length in out, so that a vector addition is feasible.
  • currIdx::Integer = -1: Time-step (0-based) in queries to process. When the currIdx argument is negative, all $Q$ time-steps are processed. When currIdx is zero or positive, the forward response is computed for the selected time-step only. The latter input can be used in inference mode only, to process one time-step while updating the next attention window and $Q$, $K$, $V$ inputs in-between calls.
  • loWinIdx, hiWinIdx::Array{Cint}: Two host integer arrays specifying the start and end (0-based) indices of the attention window for each $Q$ time-step. The start index in $K$, $V$ sets is inclusive, and the end index is exclusive. By default set at 0 and kvMaxSeqLength respectively.

cuDNN.cudnnMultiHeadAttnForward! (Function)

See cudnnMultiHeadAttnForward.

cuDNN.cudnnNormalizationForward (Function)
cudnnNormalizationForward(x, xmean, xvar, bias, scale; o...)
cudnnNormalizationForward!(y, x, xmean, xvar, bias, scale; o...)

Return batch normalization applied to x:

y .= ((x .- mean(x; dims)) ./ sqrt.(epsilon .+ var(x; dims))) .* scale .+ bias  # training
y .= ((x .- xmean) ./ sqrt.(epsilon .+ xvar)) .* scale .+ bias                  # inference

bias and scale are trainable parameters; xmean and xvar are modified to collect statistics during training and treated as constants during inference. Note that during inference the values given by the xmean and xvar arguments are used in the formula, whereas during training the actual mean and variance of the minibatch are used and the xmean/xvar arguments only collect statistics. In the original paper bias is referred to as beta and scale as gamma (Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, S. Ioffe, C. Szegedy, 2015).

Keyword arguments:

  • epsilon = 1e-5: epsilon value used in the normalization formula
  • exponentialAverageFactor = 0.1: factor used in running mean/variance calculation: runningMean = runningMean*(1-factor) + newMean*factor
  • training = false: boolean indicating training vs inference mode
  • mode::cudnnNormMode_t = CUDNN_NORM_PER_CHANNEL: per-channel normalization as described in the paper; in this mode scale etc. have dimensions (1,1,C,1). The other alternative is CUDNN_NORM_PER_ACTIVATION where scale etc. have dimensions (W,H,C,1).
  • algo::cudnnNormAlgo_t = CUDNN_NORM_ALGO_STANDARD: The other alternative, CUDNN_NORM_ALGO_PERSIST, triggers the new semi-persistent NHWC kernel when certain conditions are met (see cudnn docs).
  • normOps::cudnnNormOps_t = CUDNN_NORM_OPS_NORM: Currently the other alternatives, CUDNN_NORM_OPS_NORM_ACTIVATION and CUDNN_NORM_OPS_NORM_ADD_ACTIVATION are not supported.
  • z = nothing: for residual addition to the result of the normalization operation, prior to the activation (will be supported when CUDNN_NORM_OPS_NORM_ADD_ACTIVATION is supported)
  • groupCnt = 1: Placeholder for future work, should be set to 1 for now
  • alpha = 1; beta = 0: scaling parameters: return alpha * new_y + beta * old_y
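
A minimal batch normalization sketch in training mode (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnNormalizationForward

x     = CUDA.randn(Float32, 8, 8, 4, 16)   # (W,H,C,N)
scale = CUDA.ones(Float32, 1, 1, 4, 1)
bias  = CUDA.zeros(Float32, 1, 1, 4, 1)
xmean = CUDA.zeros(Float32, 1, 1, 4, 1)    # running mean, updated when training=true
xvar  = CUDA.ones(Float32, 1, 1, 4, 1)     # running variance
y = cudnnNormalizationForward(x, xmean, xvar, bias, scale; training = true)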

cuDNN.cudnnNormalizationForward! (Function)

See cudnnNormalizationForward.

cuDNN.cudnnOpTensor (Function)
cudnnOpTensor(x1, x2; op, compType, nanOpt, alpha1, alpha2)
cudnnOpTensor(x1, x2, d::cudnnOpTensorDescriptor; alpha1, alpha2)
cudnnOpTensor!(y, x1, x2; op, compType, nanOpt, alpha1, alpha2, beta)
cudnnOpTensor!(y, x1, x2, d::cudnnOpTensorDescriptor; alpha1, alpha2, beta)

Return the result of the specified broadcasting operation applied to x1 and x2. Optionally y holds the result and d specifies the operation. Each dimension of the input tensor x1 must match the corresponding dimension of the destination tensor y, and each dimension of the input tensor x2 must match the corresponding dimension of the destination tensor y or must be equal to 1. Keyword arguments:

  • alpha1=1, alpha2=1, beta=0 are used for scaling, i.e. y .= beta*y .+ op.(alpha1*x1, alpha2*x2)

Keyword arguments used when cudnnOpTensorDescriptor is not specified:

  • op = CUDNN_OP_TENSOR_ADD: ADD can be replaced with MUL, MIN, MAX, SQRT, NOT; SQRT and NOT are performed only on x1; NOT computes 1-x1
  • compType = (eltype(x1) <: Float64 ? Float64 : Float32): Computation datatype (see cudnn docs for available options)
  • nanOpt = CUDNN_NOT_PROPAGATE_NAN: NaN propagation policy. The other option is CUDNN_PROPAGATE_NAN.
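
A minimal sketch (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnOpTensor, CUDNN_OP_TENSOR_MUL

x1 = CUDA.randn(Float32, 8, 8, 4, 2)
x2 = CUDA.randn(Float32, 1, 1, 4, 1)   # broadcast over W, H and N
y  = cudnnOpTensor(x1, x2)                             # broadcast addition
z  = cudnnOpTensor(x1, x2; op = CUDNN_OP_TENSOR_MUL)   # broadcast multiplication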

cuDNN.cudnnOpTensor! (Function)

See cudnnOpTensor.

cuDNN.cudnnPoolingForward (Function)
cudnnPoolingForward(x; mode, nanOpt, window, padding, stride, alpha)
cudnnPoolingForward(x, d::cudnnPoolingDescriptor; alpha)
cudnnPoolingForward!(y, x; mode, nanOpt, window, padding, stride, alpha, beta)
cudnnPoolingForward!(y, x, d::cudnnPoolingDescriptor; alpha, beta)

Return pooled x, overwriting y if provided, according to keyword arguments or the pooling descriptor d. Please see the cuDNN docs for details.

The dimensions of x,y tensors that are less than 4-D are assumed to be padded on the left with 1's. The first n-2 are spatial dimensions, the last two are always assumed to be channel and batch.

The arguments window, padding, and stride can be specified as n-2 dimensional vectors, tuples or a single integer which is assumed to be repeated n-2 times. If any of the entries is larger than the corresponding x dimension, the x dimension is used instead.

Arguments:

  • mode = CUDNN_POOLING_MAX: Pooling method, other options are CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING, CUDNN_POOLING_MAX_DETERMINISTIC
  • nanOpt = CUDNN_NOT_PROPAGATE_NAN: NaN propagation policy, the other option is CUDNN_PROPAGATE_NAN
  • window = 2: Pooling window size
  • padding = 0: Padding assumed around x
  • stride = window: How far to shift pooling window at each step
  • alpha=1, beta=0 can be used for scaling, i.e. y .= alpha * op(x) .+ beta * y
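
A minimal sketch (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnPoolingForward, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING

x = CUDA.randn(Float32, 8, 8, 4, 2)
y = cudnnPoolingForward(x; window = 2)   # 2x2 max pooling, stride 2: size (4,4,4,2)
z = cudnnPoolingForward(x; window = 2, mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)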

cuDNN.cudnnPoolingForward! (Function)

See cudnnPoolingForward.

cuDNN.cudnnRNNForward (Function)
cudnnRNNForward(w, x; hiddenSize, o...)
cudnnRNNForward!(y, w, x; hiddenSize, o...)
cudnnRNNForward(w, x, d::cudnnRNNDescriptor; o...)
cudnnRNNForward!(y, w, x, d::cudnnRNNDescriptor; o...)

Apply the RNN specified with weights w and configuration given by d or keyword options to input x.

Keyword arguments for hidden input/output:

  • hx=nothing: initialize the hidden vector if specified (by default initialized to 0).
  • cx=nothing: initialize the cell vector (only in LSTMs) if specified (by default initialized to 0).
  • hy=nothing: return the final hidden vector in hy if set to Ref{Any}().
  • cy=nothing: return the final cell vector in cy (only in LSTMs) if set to Ref{Any}().

Keyword arguments specifying the RNN when d::cudnnRNNDescriptor is not given:

  • hiddenSize::Integer: hidden vector size, which must be supplied when d is not given
  • algo::cudnnRNNAlgo_t = CUDNN_RNN_ALGO_STANDARD: RNN algo (CUDNN_RNN_ALGO_STANDARD, CUDNN_RNN_ALGO_PERSIST_STATIC, or CUDNN_RNN_ALGO_PERSIST_DYNAMIC).
  • cellMode::cudnnRNNMode_t = CUDNN_LSTM: Specifies the RNN cell type in the entire model (CUDNN_RNN_RELU, CUDNN_RNN_TANH, CUDNN_LSTM, CUDNN_GRU).
  • biasMode::cudnnRNNBiasMode_t = CUDNN_RNN_DOUBLE_BIAS: Sets the number of bias vectors (CUDNN_RNN_NO_BIAS, CUDNN_RNN_SINGLE_INP_BIAS, CUDNN_RNN_SINGLE_REC_BIAS, CUDNN_RNN_DOUBLE_BIAS). The two single bias settings are functionally the same for RELU, TANH and LSTM cell types. For differences in GRU cells, see the description of CUDNN_GRU in cudnn docs.
  • dirMode::cudnnDirectionMode_t = CUDNN_UNIDIRECTIONAL: Specifies the recurrence pattern: CUDNN_UNIDIRECTIONAL or CUDNN_BIDIRECTIONAL. In bidirectional RNNs, the hidden states passed between physical layers are concatenations of forward and backward hidden states.
  • inputMode::cudnnRNNInputMode_t = CUDNN_LINEAR_INPUT: Specifies how the input to the RNN model is processed by the first layer. When inputMode is CUDNN_LINEAR_INPUT, original input vectors of size inputSize are multiplied by the weight matrix to obtain vectors of hiddenSize. When inputMode is CUDNN_SKIP_INPUT, the original input vectors to the first layer are used as is without multiplying them by the weight matrix.
  • mathPrec::DataType = eltype(x): This parameter is used to control the compute math precision in the RNN model. For Float16 input/output can be Float16 or Float32, for Float32 or Float64 input/output, must match the input/output type.
  • mathType::cudnnMathType_t = math_mode(): Sets the preferred option to use NVIDIA Tensor Cores accelerators on Volta (SM 7.0) or higher GPUs. When dataType is CUDNN_DATA_HALF, the mathType parameter can be CUDNN_DEFAULT_MATH or CUDNN_TENSOR_OP_MATH; the ALLOW_CONVERSION setting is treated the same as CUDNN_TENSOR_OP_MATH for this data type. When dataType is CUDNN_DATA_FLOAT, the mathType parameter can be CUDNN_DEFAULT_MATH or CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; when the latter setting is used, original weights and intermediate results will be down-converted to CUDNN_DATA_HALF before they are used in another recursive iteration. When dataType is CUDNN_DATA_DOUBLE, the mathType parameter can be CUDNN_DEFAULT_MATH.
  • inputSize::Integer = size(x,1): Size of the input vector in the RNN model. When the inputMode=CUDNN_SKIP_INPUT, the inputSize should match the hiddenSize value.
  • projSize::Integer = hiddenSize: The size of the LSTM cell output after the recurrent projection. This value should not be larger than hiddenSize. It is legal to set projSize equal to hiddenSize, however, in this case, the recurrent projection feature is disabled. The recurrent projection is an additional matrix multiplication in the LSTM cell to project hidden state vectors ht into smaller vectors rt = Wr * ht, where Wr is a rectangular matrix with projSize rows and hiddenSize columns. When the recurrent projection is enabled, the output of the LSTM cell (both to the next layer and unrolled in-time) is rt instead of ht. The recurrent projection can be enabled for LSTM cells and CUDNN_RNN_ALGO_STANDARD only.
  • numLayers::Integer = 1: Number of stacked, physical layers in the deep RNN model. When dirMode= CUDNN_BIDIRECTIONAL, the physical layer consists of two pseudo-layers corresponding to forward and backward directions.
  • dropout::Real = 0: When non-zero, dropout operation will be applied between physical layers. A single layer network will have no dropout applied. Dropout is used in the training mode only.
  • auxFlags::Integer = CUDNN_RNN_PADDED_IO_ENABLED: Miscellaneous switches that do not require additional numerical values to configure the corresponding feature. In future cuDNN releases, this parameter will be used to extend the RNN functionality without adding new API functions (applicable options should be bitwise OR-ed). Currently, this parameter is used to enable or disable padded input/output (CUDNN_RNN_PADDED_IO_DISABLED, CUDNN_RNN_PADDED_IO_ENABLED). When the padded I/O is enabled, layouts CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED and CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED are permitted in RNN data descriptors.

Other keyword arguments:

  • layout::cudnnRNNDataLayout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED: The memory layout of the RNN data tensor. Options are CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED: Data layout is padded, with outer stride from one time-step to the next; CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED: The sequence length is sorted and packed as in the basic RNN API; CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED: Data layout is padded, with outer stride from one batch to the next.
  • seqLengthArray::Vector{Cint} = nothing: An integer array with batchSize number of elements. Describes the length (number of time-steps) of each sequence. Each element in seqLengthArray must be greater than or equal to 0 but less than or equal to maxSeqLength. In the packed layout, the elements should be sorted in descending order, similar to the layout required by the non-extended RNN compute functions. The default value nothing assumes uniform seqLengths, no padding.
  • devSeqLengths::CuVector{Cint} = nothing: Device copy of seqLengthArray
  • fwdMode::cudnnForwardMode_t = CUDNN_FWD_MODE_INFERENCE: set to CUDNN_FWD_MODE_TRAINING when training
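
A hedged sketch of the descriptor-free form (all sizes are made up; w must be a packed weight buffer of the exact length cuDNN expects for this configuration, which is not constructed here):

using CUDA, cuDNN
using cuDNN: cudnnRNNForward

x = CUDA.randn(Float32, 32, 16, 20)          # (inputSize, batchSize, seqLength)
# w: packed weight buffer, assumed pre-sized for an LSTM with hiddenSize=64
y = cudnnRNNForward(w, x; hiddenSize = 64)   # (hiddenSize, batchSize, seqLength)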

cuDNN.cudnnRNNForward! (Function)

See cudnnRNNForward.

cuDNN.cudnnReduceTensor (Function)
cudnnReduceTensor(x; dims, op, compType, nanOpt, indices, alpha)
cudnnReduceTensor(x, d::cudnnReduceTensorDescriptor; dims, indices, alpha)
cudnnReduceTensor!(y, x; op, compType, nanOpt, indices, alpha, beta)
cudnnReduceTensor!(y, x, d::cudnnReduceTensorDescriptor; indices, alpha, beta)

Return the result of the specified reduction operation applied to x. Optionally y holds the result and d specifies the operation. Each dimension of the output tensor y must match the corresponding dimension of the input tensor x or must be equal to 1. The dimensions equal to 1 indicate the dimensions of x to be reduced. Keyword arguments:

  • dims = ntuple(i->1,ndims(x)): specifies the shape of the output when y is not given
  • indices = nothing: previously allocated space for writing indices which can be generated for min and max ops only, can be a CuArray of UInt8, UInt16, UInt32 or UInt64
  • alpha=1, beta=0 are used for scaling, i.e. y .= alpha * op.(x) .+ beta * y

Keyword arguments that can be used when reduceTensorDesc is not specified:

  • op = CUDNN_REDUCE_TENSOR_ADD: Reduction operation, ADD can be replaced with MUL, MIN, MAX, AMAX, AVG, NORM1, NORM2, MUL_NO_ZEROS
  • compType = (eltype(x) <: Float64 ? Float64 : Float32): Computation datatype
  • nanOpt = CUDNN_NOT_PROPAGATE_NAN: NaN propagation policy, the other option is CUDNN_PROPAGATE_NAN
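
A minimal sketch (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded; dims gives the output shape, so dimensions set to 1 are reduced):

using CUDA, cuDNN
using cuDNN: cudnnReduceTensor, CUDNN_REDUCE_TENSOR_MAX

x = CUDA.randn(Float32, 8, 8, 4, 2)
s = cudnnReduceTensor(x; dims = (1, 1, 4, 2))                                # sum over W,H
m = cudnnReduceTensor(x; dims = (1, 1, 1, 1), op = CUDNN_REDUCE_TENSOR_MAX)  # global max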

cuDNN.cudnnReduceTensor! (Function)

See cudnnReduceTensor.

cuDNN.cudnnScaleTensor (Function)
cudnnScaleTensor(x, s)
cudnnScaleTensor!(y, x, s)

Scale all elements of tensor x with scale s and return the result. cudnnScaleTensor allocates a new array for the answer, cudnnScaleTensor! overwrites y.

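A minimal sketch (assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnScaleTensor

x = CUDA.ones(Float32, 4, 4)
y = cudnnScaleTensor(x, 2)   # all elements become 2.0f0; x is unchanged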

cuDNN.cudnnScaleTensor! (Function)

See cudnnScaleTensor.

cuDNN.cudnnSoftmaxForward (Function)
cudnnSoftmaxForward(x; algo, mode, alpha)
cudnnSoftmaxForward!(y, x; algo, mode, alpha, beta)

Return the softmax or logsoftmax of the input x depending on the algo keyword argument. The y argument holds the result and it should be similar to x if specified. Keyword arguments:

  • algo = (CUDA.math_mode()===CUDA.FAST_MATH ? CUDNN_SOFTMAX_FAST : CUDNN_SOFTMAX_ACCURATE): Options are CUDNN_SOFTMAX_ACCURATE which subtracts max from every point to avoid overflow, CUDNN_SOFTMAX_FAST which doesn't and CUDNN_SOFTMAX_LOG which returns logsoftmax.
  • mode = CUDNN_SOFTMAX_MODE_INSTANCE: Compute softmax per image (N) across the dimensions C,H,W. CUDNN_SOFTMAX_MODE_CHANNEL computes softmax per spatial location (H,W) per image (N) across the dimension C.
  • alpha=1, beta=0 can be used for scaling, i.e. y .= alpha * op(x) .+ beta * y
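
A minimal sketch (sizes are made up; assumes CUDA.jl and the cuDNN wrapper are loaded):

using CUDA, cuDNN
using cuDNN: cudnnSoftmaxForward, CUDNN_SOFTMAX_LOG

x = CUDA.randn(Float32, 1, 1, 10, 32)   # 10 classes, 32 instances, WHCN layout
p    = cudnnSoftmaxForward(x)                             # probabilities per instance
logp = cudnnSoftmaxForward(x; algo = CUDNN_SOFTMAX_LOG)   # log-probabilities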

cuDNN.cudnnSoftmaxForward! (Function)

See cudnnSoftmaxForward.

cuDNN.sdim (Method)
sdim(x,axes,dim)
sdim(x,axes)

The first form returns the size of x in the dimension specified with dim::cudnnSeqDataAxis_t (e.g. CUDNN_SEQDATA_TIME_DIM), i.e. it returns size(x,i) such that axes[i]==dim.

The second form returns an array of length 4, dims::Vector{Cint}, such that dims[1+dim] == sdim(x,axes,dim), where dim::cudnnSeqDataAxis_t specifies the role of the dimension (e.g. dims[1+CUDNN_SEQDATA_TIME_DIM]==5).

The axes::Vector{cudnnSeqDataAxis_t} argument is an array of length 4 that specifies the role of Julia dimensions, e.g. axes[3]=CUDNN_SEQDATA_TIME_DIM.

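A small sketch following the description above (sizes are made up; sdim is a private helper, hence the explicit import):

using cuDNN: sdim, CUDNN_SEQDATA_VECT_DIM, CUDNN_SEQDATA_BATCH_DIM,
             CUDNN_SEQDATA_TIME_DIM, CUDNN_SEQDATA_BEAM_DIM

axes = [CUDNN_SEQDATA_VECT_DIM, CUDNN_SEQDATA_BATCH_DIM,
        CUDNN_SEQDATA_TIME_DIM, CUDNN_SEQDATA_BEAM_DIM]
x = rand(Float32, 64, 16, 5, 1)         # (VECT, BATCH, TIME, BEAM)
sdim(x, axes, CUDNN_SEQDATA_TIME_DIM)   # == 5, since axes[3] is the TIME dimension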

cuDNN.@cudnnDescriptor (Macro)
@cudnnDescriptor(XXX, setter=cudnnSetXXXDescriptor)

Defines a new type cudnnXXXDescriptor with a single field ptr::cudnnXXXDescriptor_t and its constructor. The second optional argument is the function that sets the descriptor fields and defaults to cudnnSetXXXDescriptor. The constructor is memoized, i.e. when called with the same arguments it returns the same object rather than creating a new one.

The arguments of the constructor and thus the keys to the memoization cache depend on the setter: If the setter has arguments cudnnSetXXXDescriptor(ptr::cudnnXXXDescriptor_t, args...), then the constructor has cudnnXXXDescriptor(args...). The user can control these arguments by defining a custom setter.

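A sketch of the memoization behavior using the generated cudnnPoolingDescriptor constructor (arguments follow the cudnnPoolingDescriptor signature listed above; the specific values are made up):

using cuDNN
using cuDNN: cudnnPoolingDescriptor, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN

d1 = cudnnPoolingDescriptor(CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, Cint(2),
                            Cint[2,2], Cint[0,0], Cint[2,2])
d2 = cudnnPoolingDescriptor(CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, Cint(2),
                            Cint[2,2], Cint[0,0], Cint[2,2])
d1 === d2   # true: equal arguments return the same memoized descriptor object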