diff --git a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt
index b7c19c16ac17..076b9269c7c3 100644
--- a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt
+++ b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt
@@ -445,45 +445,75 @@ fun NN() = Namespace("NN") {
}
}
+ Op("rmsNorm") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "RmsNorm"
+ val input = Input(NUMERIC, "input") { description = "Input variable" }
+ val gamma = Input(NUMERIC, "gamma") { description = "Scale/gain vector"; defaultValue = null }
+ val epsilon = Arg(FLOATING_POINT, "epsilon") { defaultValue = 1e-5; description = "Epsilon for numerical stability" }
+
+ Output(NUMERIC, "output") { description = "RMS normalized output" }
+
+ AllParamSignature()
+ Signature(input, gamma)
+ Signature(input, epsilon)
+ Signature(input)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Root Mean Square Layer Normalization (RMSNorm):
+
+ output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+
+ If gamma is not provided, only RMS normalization is applied.
+ """.trimIndent()
+ }
+ }
+
Op("dotProductAttentionV2") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
- val q = Input(NUMERIC, "queries") { description = "A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim]" }
- val v = Input(NUMERIC, "values") { description = "A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim]" }
+ val q = Input(NUMERIC, "queries") { description = "Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention" }
+ val v = Input(NUMERIC, "values") { description = "Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim]" }
+ val k = Input(NUMERIC, "keys") { description = "Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim]" }
+ val queryMask = Input(NUMERIC, "queryMask") { description = "Query mask tensor (optional). Shape: [batchSize, numQueries]"; defaultValue = null }
+ val valueMask = Input(NUMERIC, "valueMask") { description = "Value mask tensor (optional). Shape: [batchSize, numValues]"; defaultValue = null }
+ val attentionBias = Input(NUMERIC, "attentionBias") { description = "Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax."; defaultValue = null }
- val k = Input(NUMERIC, "keys") { description = "A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim]" }
- val queryMask = Input(NUMERIC, "queryMask") { description = "A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries]" }
- val valueMask = Input(NUMERIC, "valueMask") { description = "@param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues]" }
+ val s = Arg(FLOATING_POINT, "scaleFactor") { defaultValue = 0.0; description = "Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))" }
+ val dropout = Arg(FLOATING_POINT, "dropoutProbability") { defaultValue = 0.0; description = "Dropout probability applied to attention weights" }
+ val useCausalMask = Arg(BOOL, "useCausalMask") { defaultValue = false; description = "Whether to apply causal mask for autoregressive tasks" }
+ val training = Arg(BOOL, "training") { defaultValue = false; description = "Whether in training mode (affects dropout)" }
- val s = Arg(FLOATING_POINT, "scaleFactor") { defaultValue = 1.0; description = "@param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys." }
- val dropout = Arg(FLOATING_POINT, "dropoutProbability") { defaultValue = 0.0; description = "A {@code double} specifying the dropout probability to be applied to attention weights." }
- val useCausalMask = Arg(BOOL, "useCausalMask") { defaultValue = false; description = " A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks." }
- val training = Arg(BOOL, "training") { defaultValue = false; description = " A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout." }
+ Output(NUMERIC, "output") { description = "Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim]" }
- Output(NUMERIC, "output") { description = " A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim]"}
-
- Signature(q,v,k,queryMask,valueMask, s,dropout,useCausalMask,training)
+ // Standard signature without attention bias (backward compatible)
+ Signature(q, v, k, queryMask, valueMask, s, dropout, useCausalMask, training)
+ // Full signature with attention bias
+ Signature(q, v, k, queryMask, valueMask, attentionBias, s, dropout, useCausalMask, training)
Doc(Language.ANY, DocScope.ALL) {
"""
- This operation performs dot product attention on the given timeseries input with the given queries
- out = sum(similarity(k_i, q) * v_i)
-
- similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
-
- Optionally with normalization step:
- similarity(k, q) = softmax(k * q / sqrt(size(q))
-
- See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
-
- Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- be 3D but can have queryCount = 1
-
- Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- both.
-
- Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- output rank will depend on the input rank.
+ Dot product attention operation with flash attention and KV cache support.
+
+            out = softmax(Q * K^T * scale + attentionBias) * V
+
+            For 4D inputs [batch, seq, heads, dim], uses a memory-efficient flash attention algorithm.
+ For 2D/3D inputs, uses standard attention computation.
+
+ Flash attention features:
+ - O(N) memory complexity instead of O(N^2)
+ - Tiled computation with online softmax
+ - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ - Supports attention bias (relative position bias, ALiBi, etc.)
+
+ KV Cache support for autoregressive generation:
+ - Pass keyCache and valueCache tensors
+ - Set kvCachePosition to current generation position
+ - Cached keys/values are updated in-place
+
+ See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
""".trimIndent()
}
}
@@ -565,6 +595,139 @@ fun NN() = Namespace("NN") {
}
}
+ Op("flashAttention") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" }
+ val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+ val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+
+ val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 0 = auto (1/sqrt(headDim))" }
+ val isCausal = Arg(BOOL, "isCausal") { defaultValue = true; description = "Whether to apply causal masking" }
+ val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" }
+ val numKvHeads = Arg(INT, "numKvHeads") { defaultValue = 0; description = "Number of KV heads (0 = same as numHeads, for GQA use smaller value)" }
+
+ Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" }
+
+ Signature(q, k, v, scale, isCausal, numHeads, numKvHeads)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Flash Attention - Memory-efficient attention computation.
+
+ Uses tiled computation with online softmax to achieve O(N) memory complexity
+ instead of O(N^2) for standard attention.
+
+ Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ allowing multiple query heads to share the same KV heads.
+
+                out = softmax(Q * K^T * scale) * V
+
+ See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ """.trimIndent()
+ }
+ }
+
+ Op("groupedQueryAttention") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" }
+ val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+ val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+
+ val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 0 = auto (1/sqrt(headDim))" }
+ val isCausal = Arg(BOOL, "isCausal") { defaultValue = true; description = "Whether to apply causal masking" }
+ val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" }
+ val numKvHeads = Arg(INT, "numKvHeads") { description = "Number of KV heads (must divide numHeads evenly)" }
+
+ Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" }
+
+ Signature(q, k, v, scale, isCausal, numHeads, numKvHeads)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+
+ Multiple query heads share the same key-value heads, reducing memory and
+ computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+
+ numHeads must be divisible by numKvHeads. Each KV head is repeated
+ (numHeads / numKvHeads) times to match query heads.
+
+ Special cases:
+ - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ - numKvHeads == 1: Multi-Query Attention (MQA)
+
+ See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ """.trimIndent()
+ }
+ }
+
+ Op("kvCacheUpdate") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "KVCacheUpdate"
+ val keyCache = Input(NUMERIC, "keyCache") { description = "Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim]" }
+ val valueCache = Input(NUMERIC, "valueCache") { description = "Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim]" }
+ val newKeys = Input(NUMERIC, "newKeys") { description = "New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim]" }
+ val newValues = Input(NUMERIC, "newValues") { description = "New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim]" }
+
+ val startPosition = Arg(INT, "startPosition") { defaultValue = 0; description = "Position in cache where new keys/values should be inserted" }
+
+ Output(NUMERIC, "updatedKeyCache") { description = "Updated key cache" }
+ Output(NUMERIC, "updatedValueCache") { description = "Updated value cache" }
+
+ Signature(keyCache, valueCache, newKeys, newValues, startPosition)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ KV Cache Update - Updates key-value cache for autoregressive generation.
+
+ During LLM inference, past key-value pairs are cached to avoid redundant
+ computation during token-by-token generation. This operation efficiently
+ inserts new keys/values at the specified position.
+
+ Usage pattern:
+ 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ 2. For each new token, compute new K/V and update cache
+ 3. Use full cached K/V for attention computation
+
+ Returns updated keyCache and valueCache tensors.
+ """.trimIndent()
+ }
+ }
+
+ Op("slidingWindowAttention") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" }
+ val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+ val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" }
+
+ val windowSize = Arg(INT, "windowSize") { defaultValue = 4096; description = "Sliding window size - tokens can only attend to this many previous positions" }
+ val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" }
+ val numKvHeads = Arg(INT, "numKvHeads") { defaultValue = 0; description = "Number of KV heads (0 = same as numHeads)" }
+ val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 0 = auto (1/sqrt(headDim))" }
+
+ Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" }
+
+ Signature(q, k, v, windowSize, numHeads, numKvHeads, scale)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Sliding Window Attention - Efficient attention for long sequences.
+
+ Each token only attends to a fixed window of previous tokens, enabling
+ efficient processing of very long sequences. Used in Mistral and other
+ modern LLMs for handling long contexts.
+
+ Benefits:
+ - O(N * windowSize) complexity instead of O(N^2)
+ - Memory efficient for long sequences
+ - Supports very long context lengths (e.g., 32K with 4K window)
+
+ The attention mask is automatically applied to restrict each position
+ to only attend to positions within [pos - windowSize, pos].
+ """.trimIndent()
+ }
+ }
+
Op("pad") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms"
Input(NUMERIC, "input") { description = "Input tensor"}
@@ -576,7 +739,7 @@ fun NN() = Namespace("NN") {
Doc(Language.ANY, DocScope.ALL){
"""
- Padding operation
+ Padding operation
""".trimIndent()
}
}
@@ -598,5 +761,314 @@ fun NN() = Namespace("NN") {
}
}
+ Op("windowedAttention") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D" }
+ val k = Input(NUMERIC, "key") { description = "Key tensor. Same shape as query" }
+ val v = Input(NUMERIC, "value") { description = "Value tensor. Same shape as query" }
+ val rpb = Input(NUMERIC, "relativePositionBias") { description = "Optional relative position bias. Shape: [numHeads, windowSize, windowSize]"; defaultValue = null }
+ val mask = Input(NUMERIC, "attentionMask") { description = "Optional attention mask"; defaultValue = null }
+
+ val windowSize = Arg(INT, "windowSize") { description = "Size of attention window" }
+ val numHeads = Arg(INT, "numHeads") { description = "Number of attention heads" }
+ val shiftSize = Arg(INT, "shiftSize") { defaultValue = 0; description = "Shift size for shifted window attention (Swin style). 0 = no shift" }
+ val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Attention scale factor. 0 = auto (1/sqrt(headDim))" }
+ val returnWeights = Arg(BOOL, "returnWeights") { defaultValue = false; description = "Whether to return attention weights" }
+
+ Output(NUMERIC, "output") { description = "Attention output. Same shape as query" }
+
+ Signature(q, k, v, windowSize, numHeads)
+ Signature(q, k, v, rpb, mask, windowSize, numHeads, shiftSize, scale, returnWeights)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Windowed Attention - Local/Sliding Window Attention.
+
+ Implements windowed attention mechanisms used in efficient transformers like
+ Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+
+ Supports both:
+ - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ - 2D windowed attention: for images [batch, height, width, heads, dim]
+
+ Shifted window attention (shiftSize > 0) enables cross-window connections
+ as used in Swin Transformer.
+
+ Benefits:
+ - O(N * windowSize) complexity instead of O(N^2)
+ - Efficient for long sequences and high-resolution images
+ - Supports relative position bias for position-aware attention
+ """.trimIndent()
+ }
+ }
+
+ Op("relativePositionBias") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val biasTable = Input(NUMERIC, "biasTable") { description = "Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode" }
+ val relPosIndex = Input(NUMERIC, "relativePositionIndex") { description = "Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2]"; defaultValue = null }
+
+ val numHeads = Arg(INT, "numHeads") { description = "Number of attention heads" }
+ val windowSize = Arg(INT, "windowSize") { defaultValue = 0; description = "Window size for 2D position encoding (used if generating index)" }
+ val useAlibi = Arg(BOOL, "useAlibi") { defaultValue = false; description = "Use ALiBi (Attention with Linear Biases) instead of learned bias" }
+
+ Output(NUMERIC, "output") { description = "Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen]" }
+
+ Signature(biasTable, numHeads, windowSize)
+ Signature(biasTable, relPosIndex, numHeads, windowSize)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Relative Position Bias - Compute relative position bias for attention.
+
+ Supports two modes:
+ 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ based on relative positions between query and key positions.
+
+ 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ without learned parameters. More efficient for very long sequences.
+
+ For learned bias mode:
+ - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ - Output is gathered based on relative position indices
+
+ For ALiBi mode:
+ - biasTable can be sequence length (scalar) or input tensor
+ - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+
+ Reference: "Swin Transformer" (Liu et al., 2021)
+ "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ """.trimIndent()
+ }
+ }
+
+ Op("mixtureOfExperts") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ val input = Input(NUMERIC, "input") { description = "Input embeddings. Shape: [batch, seqLen, hiddenSize]" }
+ val routerWeights = Input(NUMERIC, "routerWeights") { description = "Router projection weights. Shape: [hiddenSize, numExperts]" }
+ val expertWeights = Input(NUMERIC, "expertWeights") { description = "Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize]" }
+ val expertBias = Input(NUMERIC, "expertBias") { description = "Optional expert biases. Shape: [numExperts, expertHiddenSize]"; defaultValue = null }
+
+ val numExperts = Arg(INT, "numExperts") { description = "Total number of experts" }
+ val topK = Arg(INT, "topK") { defaultValue = 2; description = "Number of experts to route to per token" }
+ val normalizeProbs = Arg(BOOL, "normalizeProbs") { defaultValue = true; description = "Whether to normalize router probabilities for selected experts" }
+ val capacityFactor = Arg(FLOATING_POINT, "capacityFactor") { defaultValue = 1.0; description = "Expert capacity factor for load balancing" }
+
+ Output(NUMERIC, "output") { description = "Combined expert outputs. Shape: [batch, seqLen, expertHiddenSize]" }
+ Output(NUMERIC, "routerProbs") { description = "Router probabilities. Shape: [batch, seqLen, numExperts]" }
+ Output(NUMERIC, "expertIndices") { description = "Selected expert indices. Shape: [batch, seqLen, topK]" }
+
+ Signature(input, routerWeights, expertWeights, numExperts, topK)
+ Signature(input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Mixture of Experts (MoE) Layer.
+
+ Implements sparse MoE routing where each token is processed by only the top-k
+ selected experts out of a larger pool. This enables scaling model capacity
+ without proportionally increasing computation.
+
+ Used in large language models like:
+ - DeepSeek (DeepSeekMoE)
+ - Mixtral (Mistral AI)
+ - Switch Transformer (Google)
+ - GShard (Google)
+
+ The router computes expert selection probabilities:
+ router_probs = softmax(input @ routerWeights)
+
+ Top-k experts are selected and their outputs are weighted by normalized probs:
+ output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+
+ Benefits:
+ - Scales model capacity with sublinear compute increase
+ - Enables very large models with efficient inference
+ - Supports expert parallelism across devices
+ """.trimIndent()
+ }
+ }
+
+ Op("ctcGreedyDecoder") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "CTCGreedyDecoder"
+ val logits = Input(NUMERIC, "logits") { description = "Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses]" }
+ val sequenceLength = Input(NUMERIC, "sequenceLength") { description = "Optional actual sequence lengths. Shape: [batch]"; defaultValue = null }
+
+ val mergeRepeated = Arg(BOOL, "mergeRepeated") { defaultValue = true; description = "Whether to merge repeated characters in output" }
+ val blankIndex = Arg(INT, "blankIndex") { defaultValue = 0; description = "Index of the blank label in the vocabulary" }
+
+ Output(NUMERIC, "decoded") { description = "Decoded sequences. Shape: [batch, timeSteps] (padded with blank)" }
+ Output(NUMERIC, "logProbability") { description = "Log probability of decoded sequences. Shape: [batch]" }
+
+ Signature(logits, mergeRepeated, blankIndex)
+ Signature(logits, sequenceLength, mergeRepeated, blankIndex)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+
+ Performs greedy (best path) decoding on CTC output. Used in:
+ - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ - Speech recognition - DeepSpeech, Wav2Vec
+ - Handwriting recognition
+
+ Algorithm:
+ 1. At each timestep, select the class with highest probability
+ 2. Optionally merge consecutive repeated characters
+ 3. Remove blank labels from the output
+
+ For example, with mergeRepeated=true and blankIndex=0:
+ Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ Output: [1, 2] -> "ab"
+
+ Note: This is greedy decoding. For better accuracy with language models,
+ use beam search decoding instead.
+ """.trimIndent()
+ }
+ }
+
+ Op("emaUpdate") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "EmaUpdate"
+ Input(NUMERIC, "model") { description = "Current model parameters (student)" }
+ Input(NUMERIC, "shadow") { description = "EMA shadow parameters (teacher)" }
+        Arg(FLOATING_POINT, "decay") { description = "EMA decay factor (typically 0.996-0.9999)"; defaultValue = 0.999 }
+ Output(NUMERIC, "output") { description = "Updated shadow parameters" }
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Exponential Moving Average parameter update for DINOv2 teacher networks.
+ Computes: output = decay * shadow + (1 - decay) * model
+ Used in self-supervised learning to maintain a slowly-updated teacher model.
+ """.trimIndent()
+ }
+ }
+
+ Op("centerAndSharpen") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "CenterAndSharpen"
+ Input(NUMERIC, "input") { description = "Teacher output logits [batch, features]" }
+ Input(NUMERIC, "center") { description = "Running center vector [features]" }
+        Arg(FLOATING_POINT, "temperature") { description = "Sharpening temperature (typically 0.04-0.07)"; defaultValue = 0.07 }
+ Output(NUMERIC, "output") { description = "Sharpened probabilities [batch, features]" }
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ DINOv2 centering and sharpening operation.
+ Prevents mode collapse in self-supervised learning by centering the teacher output
+ and applying temperature-based sharpening:
+ output = softmax((input - center) / temperature)
+ """.trimIndent()
+ }
+ }
+
+ Op("twoWayCrossAttention") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "TwoWayCrossAttention"
+ Input(NUMERIC, "tokenQuery") { description = "Token queries [batch, tokenSeqLen, embedDim]" }
+ Input(NUMERIC, "tokenKey") { description = "Token keys [batch, tokenSeqLen, embedDim]" }
+ Input(NUMERIC, "tokenValue") { description = "Token values [batch, tokenSeqLen, embedDim]" }
+ Input(NUMERIC, "imageQuery") { description = "Image queries [batch, imageSeqLen, embedDim]" }
+ Input(NUMERIC, "imageKey") { description = "Image keys [batch, imageSeqLen, embedDim]" }
+ Input(NUMERIC, "imageValue") { description = "Image values [batch, imageSeqLen, embedDim]" }
+        Arg(FLOATING_POINT, "scale") { description = "Attention scale factor. 0 = auto (1/sqrt(embedDim))"; defaultValue = 0.0 }
+ Output(NUMERIC, "tokenOutput") { description = "Attended token embeddings [batch, tokenSeqLen, embedDim]" }
+ Output(NUMERIC, "imageOutput") { description = "Attended image embeddings [batch, imageSeqLen, embedDim]" }
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ SAM-style Two-Way Cross Attention.
+ Bidirectional cross-attention where tokens attend to image features and
+ image features attend to tokens simultaneously:
+ tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ """.trimIndent()
+ }
+ }
+
+ Op("tokenSample") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "TokenSample"
+ val logits = Input(NUMERIC, "logits") { description = "Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position." }
+ val temperature = Arg(FLOATING_POINT, "temperature") { defaultValue = 0.0; description = "Temperature for sampling. 0 = greedy (argmax)" }
+ val topK = Arg(INT, "topK") { defaultValue = 0; description = "Top-K filtering: keep only top K logits. 0 = disabled" }
+ val topP = Arg(FLOATING_POINT, "topP") { defaultValue = 0.0; description = "Top-P (nucleus) filtering threshold. 0 = disabled" }
+ val seed = Arg(LONG, "seed") { defaultValue = 0; description = "Random seed for sampling. 0 = random" }
+
+ Output(LONG, "output") { description = "Sampled token indices. Shape: [batch] or scalar" }
+
+ Signature(logits)
+ Signature(logits, temperature, topK, topP, seed)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Token sampling for LLM inference.
+
+ Full sampling pipeline in a single native GPU call:
+ temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+
+ For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ with shared-memory reduction — avoids transferring the full logits tensor to host.
+
+ Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ is automatically extracted for sampling.
+ """.trimIndent()
+ }
+ }
+
+ Op("kvScatter") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "KvScatter"
+ val present = Input(NUMERIC, "present") { description = "Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim]" }
+ val staticBuf = Input(NUMERIC, "staticBuffer") { description = "Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place." }
+
+ val cachePos = Arg(LONG, "cachePos") { description = "Position in static buffer to write the new entry" }
+ val numPairs = Arg(INT, "numPairs") { defaultValue = 1; description = "Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1]" }
+
+ Output(LONG, "output") { description = "Scalar 0 on success" }
+
+ Signature(present, staticBuf, cachePos)
+ Signature(present, staticBuf, cachePos, numPairs)
+
+ Doc(Language.ANY, DocScope.ALL) {
+ """
+ Batch KV cache scatter update for LLM autoregressive decoding.
+
+ Copies a single time-step slice from each present KV tensor into the
+ corresponding static KV buffer at a given cache position. Replaces N
+ individual Java view+assign calls with a single native kernel launch.
+
+ The present tensor has shape [batch, heads, seqLen, dim] where the new
+ token's KV entry is at the last sequence position. This entry is extracted
+ and written into the static buffer at cachePos.
+
+ For multiple pairs, inputs are ordered as:
+ [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ """.trimIndent()
+ }
+ }
+
Alias(Math(), "tanh")
+
+ Op("fusedElementwiseChain") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "FusedElementwiseChain"
+
+ Input(NUMERIC, "input") { description = "Primary input array" }
+ Input(NUMERIC, "secondaryInputs") {
+ count = AtLeast(0)
+ description = "Optional secondary input arrays for binary ops (add, sub, mul, div)"
+ }
+ Arg(INT, "opCodes") {
+ count = AtLeast(1)
+ description = "Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu"
+ }
+
+ Output(NUMERIC, "output"){ description = "Result of applying the fused element-wise chain" }
+
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Executes a fused chain of element-wise operations in a single kernel pass.
+ Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ """.trimIndent()
+ }
+ }
}
diff --git a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt
index 237a9d4a37ed..0d6e1928c244 100644
--- a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt
+++ b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt
@@ -1713,6 +1713,20 @@ fun SDBaseOps() = Namespace("BaseOps"){
}
}
+ Op("squeezeAll") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.shape"
+ javaOpClass = "Squeeze"
+ Input(NUMERIC, "x") { description = "Input variable" }
+ Output(NUMERIC, "output"){ description = "Output variable" }
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Remove all dimensions of size 1 from the input tensor.
+ For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ This is the NumPy-style squeeze with no axis specified.
+ """.trimIndent()
+ }
+ }
+
Op("stack") {
javaPackage = "org.nd4j.linalg.api.ops.impl.shape"
argsFirst = true
@@ -2179,4 +2193,56 @@ fun SDBaseOps() = Namespace("BaseOps"){
""".trimIndent()
}
}
+
+ Op("booleanNot") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.bool"
+ javaOpClass = "BooleanNot"
+ legacy = true
+ Input(BOOL, "x") { description = "Input boolean array" }
+ Output(BOOL, "output"){ description = "Boolean NOT result" }
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Boolean NOT operation: elementwise !x
+ """.trimIndent()
+ }
+ }
+
+ Op("booleanAnd") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "BooleanAnd"
+ Input(BOOL, "x") { description = "First input boolean array" }
+ Input(BOOL, "y") { description = "Second input boolean array" }
+ Output(BOOL, "output"){ description = "Boolean AND result" }
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Boolean AND operation: elementwise x && y. Supports broadcasting.
+ """.trimIndent()
+ }
+ }
+
+ Op("booleanOr") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "BooleanOr"
+ Input(BOOL, "x") { description = "First input boolean array" }
+ Input(BOOL, "y") { description = "Second input boolean array" }
+ Output(BOOL, "output"){ description = "Boolean OR result" }
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Boolean OR operation: elementwise x || y. Supports broadcasting.
+ """.trimIndent()
+ }
+ }
+
+ Op("booleanXor") {
+ javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
+ javaOpClass = "BooleanXor"
+ Input(BOOL, "x") { description = "First input boolean array" }
+ Input(BOOL, "y") { description = "Second input boolean array" }
+ Output(BOOL, "output"){ description = "Boolean XOR result" }
+ Doc(Language.ANY, DocScope.ALL){
+ """
+ Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ """.trimIndent()
+ }
+ }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java
index e64580b9ff00..a36ea1612f6e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java
@@ -424,6 +424,114 @@ public SDVariable[] batchMmul(String[] names, SDVariable alphas, SDVariable beta
return sd.updateVariableNamesAndReferences(out, names);
}
+ /**
+ * Boolean AND operation: elementwise x && y. Supports broadcasting.
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean AND result (BOOL type)
+ */
+ public SDVariable booleanAnd(SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanAnd", "x", x);
+ SDValidation.validateBool("booleanAnd", "y", y);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(sd,x, y).outputVariable();
+ }
+
+ /**
+ * Boolean AND operation: elementwise x && y. Supports broadcasting.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean AND result (BOOL type)
+ */
+ public SDVariable booleanAnd(String name, SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanAnd", "x", x);
+ SDValidation.validateBool("booleanAnd", "y", y);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(sd,x, y).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Boolean NOT operation: elementwise !x
+ *
+ * @param x Input boolean array (BOOL type)
+ * @return output Boolean NOT result (BOOL type)
+ */
+ public SDVariable booleanNot(SDVariable x) {
+ SDValidation.validateBool("booleanNot", "x", x);
+ return new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(sd,x).outputVariable();
+ }
+
+ /**
+ * Boolean NOT operation: elementwise !x
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input boolean array (BOOL type)
+ * @return output Boolean NOT result (BOOL type)
+ */
+ public SDVariable booleanNot(String name, SDVariable x) {
+ SDValidation.validateBool("booleanNot", "x", x);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(sd,x).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean OR result (BOOL type)
+ */
+ public SDVariable booleanOr(SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanOr", "x", x);
+ SDValidation.validateBool("booleanOr", "y", y);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(sd,x, y).outputVariable();
+ }
+
+ /**
+ * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean OR result (BOOL type)
+ */
+ public SDVariable booleanOr(String name, SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanOr", "x", x);
+ SDValidation.validateBool("booleanOr", "y", y);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(sd,x, y).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean XOR result (BOOL type)
+ */
+ public SDVariable booleanXor(SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanXor", "x", x);
+ SDValidation.validateBool("booleanXor", "y", y);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(sd,x, y).outputVariable();
+ }
+
+ /**
+ * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean XOR result (BOOL type)
+ */
+ public SDVariable booleanXor(String name, SDVariable x, SDVariable y) {
+ SDValidation.validateBool("booleanXor", "x", x);
+ SDValidation.validateBool("booleanXor", "y", y);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(sd,x, y).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Cast the array to a new datatype - for example, Integer -> Float
*
@@ -4716,6 +4824,34 @@ public SDVariable squeeze(String name, SDVariable x, int axis) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ * This is the NumPy-style squeeze with no axis specified.
+ *
+ * @param x Input variable (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable squeezeAll(SDVariable x) {
+ SDValidation.validateNumerical("squeezeAll", "x", x);
+ return new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x).outputVariable();
+ }
+
+ /**
+ * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ * This is the NumPy-style squeeze with no axis specified.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input variable (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable squeezeAll(String name, SDVariable x) {
+ SDValidation.validateNumerical("squeezeAll", "x", x);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Stack a set of N INDArray of rank X into one rank X+1 variable.
* If inputs have shape [a,b,c] then output has shape:
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
index 03150f6189f6..7f56b5fc131e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
@@ -145,6 +145,204 @@ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolea
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @param temperature Sharpening temperature (typically 0.04-0.07)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public SDVariable centerAndSharpen(SDVariable input, SDVariable center, double temperature) {
+ SDValidation.validateNumerical("centerAndSharpen", "input", input);
+ SDValidation.validateNumerical("centerAndSharpen", "center", center);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, temperature).outputVariable();
+ }
+
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @param temperature Sharpening temperature (typically 0.04-0.07)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public SDVariable centerAndSharpen(String name, SDVariable input, SDVariable center,
+ double temperature) {
+ SDValidation.validateNumerical("centerAndSharpen", "input", input);
+ SDValidation.validateNumerical("centerAndSharpen", "center", center);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, temperature).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public SDVariable centerAndSharpen(SDVariable input, SDVariable center) {
+ SDValidation.validateNumerical("centerAndSharpen", "input", input);
+ SDValidation.validateNumerical("centerAndSharpen", "center", center);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, 0.07).outputVariable();
+ }
+
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public SDVariable centerAndSharpen(String name, SDVariable input, SDVariable center) {
+ SDValidation.validateNumerical("centerAndSharpen", "input", input);
+ SDValidation.validateNumerical("centerAndSharpen", "center", center);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, 0.07).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public SDVariable[] ctcGreedyDecoder(SDVariable logits, boolean mergeRepeated, int blankIndex) {
+ SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, null, mergeRepeated, blankIndex).outputVariables();
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public SDVariable[] ctcGreedyDecoder(String[] names, SDVariable logits, boolean mergeRepeated,
+ int blankIndex) {
+ SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, null, mergeRepeated, blankIndex).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+   * @param sequenceLength Actual sequence lengths, one per batch element (required in this overload; use the overload without this parameter to omit). Shape: [batch] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public SDVariable[] ctcGreedyDecoder(SDVariable logits, SDVariable sequenceLength,
+ boolean mergeRepeated, int blankIndex) {
+ SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ SDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, sequenceLength, mergeRepeated, blankIndex).outputVariables();
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+   * @param sequenceLength Actual sequence lengths, one per batch element (required in this overload; use the overload without this parameter to omit). Shape: [batch] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public SDVariable[] ctcGreedyDecoder(String[] names, SDVariable logits, SDVariable sequenceLength,
+ boolean mergeRepeated, int blankIndex) {
+ SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ SDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, sequenceLength, mergeRepeated, blankIndex).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
/**
* This operation performs dot product attention on the given timeseries input with the given queries
* out = sum(similarity(k_i, q) * v_i)
@@ -228,35 +426,37 @@ public SDVariable dotProductAttention(String name, SDVariable queries, SDVariabl
}
/**
- * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+   * out = softmax(Q * K^T * scale + attentionBias) * V
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
- *
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
- *
- * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type)
- * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type)
- * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type)
- * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type)
- * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type)
- * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys.
- * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights.
- * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks.
- * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout.
- * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type)
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
*/
public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, SDVariable keys,
SDVariable queryMask, SDVariable valueMask, double scaleFactor, double dropoutProbability,
@@ -266,40 +466,42 @@ public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, S
SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
- return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
}
/**
- * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+   * out = softmax(Q * K^T * scale + attentionBias) * V
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
- *
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
*
* @param name name May be null. Name for the output variable
- * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type)
- * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type)
- * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type)
- * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type)
- * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type)
- * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys.
- * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights.
- * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks.
- * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout.
- * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type)
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
*/
public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVariable values,
SDVariable keys, SDVariable queryMask, SDVariable valueMask, double scaleFactor,
@@ -309,7 +511,101 @@ public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVaria
SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
- SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Dot product attention operation with flash attention and KV cache support.
+ *
+   * out = softmax(Q * K^T * scale + attentionBias) * V
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, SDVariable keys,
+ SDVariable queryMask, SDVariable valueMask, SDVariable attentionBias, double scaleFactor,
+ double dropoutProbability, boolean useCausalMask, boolean training) {
+ SDValidation.validateNumerical("dotProductAttentionV2", "queries", queries);
+ SDValidation.validateNumerical("dotProductAttentionV2", "values", values);
+ SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
+ SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
+ SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
+ SDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
+ }
+
+ /**
+ * Dot product attention operation with flash attention and KV cache support.
+ *
+   * out = softmax(Q * K^T * scale + attentionBias) * V
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVariable values,
+ SDVariable keys, SDVariable queryMask, SDVariable valueMask, SDVariable attentionBias,
+ double scaleFactor, double dropoutProbability, boolean useCausalMask, boolean training) {
+ SDValidation.validateNumerical("dotProductAttentionV2", "queries", queries);
+ SDValidation.validateNumerical("dotProductAttentionV2", "values", values);
+ SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
+ SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
+ SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
+ SDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@@ -407,6 +703,172 @@ public SDVariable elu(String name, SDVariable x) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ *
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @param decay EMA decay factor (typically 0.996-0.9999)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public SDVariable emaUpdate(SDVariable model, SDVariable shadow, double decay) {
+ SDValidation.validateNumerical("emaUpdate", "model", model);
+ SDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ // Unnamed variant: the output variable keeps its auto-generated name.
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, decay).outputVariable();
+ }
+
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @param decay EMA decay factor (typically 0.996-0.9999)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public SDVariable emaUpdate(String name, SDVariable model, SDVariable shadow, double decay) {
+ SDValidation.validateNumerical("emaUpdate", "model", model);
+ SDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, decay).outputVariable();
+ // Named variant: rebinds the op output to the supplied name (name may be null).
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * This overload uses a default decay of 0.999.
+ *
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public SDVariable emaUpdate(SDVariable model, SDVariable shadow) {
+ SDValidation.validateNumerical("emaUpdate", "model", model);
+ SDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, 0.999).outputVariable();
+ }
+
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * This overload uses a default decay of 0.999.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public SDVariable emaUpdate(String name, SDVariable model, SDVariable shadow) {
+ SDValidation.validateNumerical("emaUpdate", "model", model);
+ SDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, 0.999).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T / scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable flashAttention(SDVariable query, SDVariable key, SDVariable value, double scale,
+ boolean isCausal, int numHeads, int numKvHeads) {
+ SDValidation.validateNumerical("flashAttention", "query", query);
+ SDValidation.validateNumerical("flashAttention", "key", key);
+ SDValidation.validateNumerical("flashAttention", "value", value);
+ // NOTE(review): the Javadoc formula divides by scale while the param is described as a
+ // multiplicative scaling factor — confirm the native op's convention before relying on it.
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable();
+ }
+
+ /**
+ * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T / scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable flashAttention(String name, SDVariable query, SDVariable key, SDVariable value,
+ double scale, boolean isCausal, int numHeads, int numKvHeads) {
+ SDValidation.validateNumerical("flashAttention", "query", query);
+ SDValidation.validateNumerical("flashAttention", "key", key);
+ SDValidation.validateNumerical("flashAttention", "value", value);
+ // NOTE(review): the Javadoc formula divides by scale while the param is described as a
+ // multiplicative scaling factor — confirm the native op's convention before relying on it.
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ *
+ * @param input Primary input array (NUMERIC type)
+ * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type)
+ * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1))
+ * @return output Result of applying the fused element-wise chain (NUMERIC type)
+ */
+ public SDVariable fusedElementwiseChain(SDVariable input, SDVariable[] secondaryInputs,
+ int[] opCodes) {
+ SDValidation.validateNumerical("fusedElementwiseChain", "input", input);
+ SDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs);
+ // NOTE(review): secondaryInputs is documented as optional, but the length check below
+ // dereferences it, so passing null throws NPE. The ">= 0" bound is vacuous for any
+ // non-null array — it only rejects null, and only via NullPointerException.
+ Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length);
+ Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(sd,input, secondaryInputs, opCodes).outputVariable();
+ }
+
+ /**
+ * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Primary input array (NUMERIC type)
+ * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type)
+ * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1))
+ * @return output Result of applying the fused element-wise chain (NUMERIC type)
+ */
+ public SDVariable fusedElementwiseChain(String name, SDVariable input,
+ SDVariable[] secondaryInputs, int[] opCodes) {
+ SDValidation.validateNumerical("fusedElementwiseChain", "input", input);
+ SDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs);
+ // NOTE(review): secondaryInputs is documented as optional, but the length check below
+ // dereferences it, so passing null throws NPE. The ">= 0" bound is vacuous for any
+ // non-null array — it only rejects null, and only via NullPointerException.
+ Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length);
+ Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(sd,input, secondaryInputs, opCodes).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* GELU activation function - Gaussian Error Linear Units
* For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
@@ -435,6 +897,72 @@ public SDVariable gelu(String name, SDVariable x) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (must divide numHeads evenly)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable groupedQueryAttention(SDVariable query, SDVariable key, SDVariable value,
+ double scale, boolean isCausal, int numHeads, int numKvHeads) {
+ SDValidation.validateNumerical("groupedQueryAttention", "query", query);
+ SDValidation.validateNumerical("groupedQueryAttention", "key", key);
+ SDValidation.validateNumerical("groupedQueryAttention", "value", value);
+ // NOTE(review): the documented constraint numHeads % numKvHeads == 0 is not validated
+ // here — presumably enforced by the native op; confirm the failure mode on violation.
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable();
+ }
+
+ /**
+ * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ *
+ * @param name name May be null. Name for the output variable
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (must divide numHeads evenly)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable groupedQueryAttention(String name, SDVariable query, SDVariable key,
+ SDVariable value, double scale, boolean isCausal, int numHeads, int numKvHeads) {
+ SDValidation.validateNumerical("groupedQueryAttention", "query", query);
+ SDValidation.validateNumerical("groupedQueryAttention", "key", key);
+ SDValidation.validateNumerical("groupedQueryAttention", "value", value);
+ // NOTE(review): the documented constraint numHeads % numKvHeads == 0 is not validated
+ // here — presumably enforced by the native op; confirm the failure mode on violation.
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
@@ -519,6 +1047,175 @@ public SDVariable hardTanhDerivative(String name, SDVariable x) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ *
+ * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param startPosition Position in cache where new keys/values should be inserted
+ * @return Updated [keyCache, valueCache] tensors (NUMERIC type)
+ */
+ public SDVariable[] kvCacheUpdate(SDVariable keyCache, SDVariable valueCache, SDVariable newKeys,
+ SDVariable newValues, int startPosition) {
+ SDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache);
+ SDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache);
+ SDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys);
+ SDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(sd,keyCache, valueCache, newKeys, newValues, startPosition).outputVariables();
+ }
+
+ /**
+ * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param startPosition Position in cache where new keys/values should be inserted
+ * @return Updated [keyCache, valueCache] tensors (NUMERIC type)
+ */
+ public SDVariable[] kvCacheUpdate(String[] names, SDVariable keyCache, SDVariable valueCache,
+ SDVariable newKeys, SDVariable newValues, int startPosition) {
+ SDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache);
+ SDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache);
+ SDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys);
+ SDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(sd,keyCache, valueCache, newKeys, newValues, startPosition).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ *
+ * Single-pair convenience overload: numPairs is fixed to 1.
+ *
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public SDVariable kvScatter(SDVariable present, SDVariable staticBuffer, long cachePos) {
+ SDValidation.validateNumerical("kvScatter", "present", present);
+ SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, 1).outputVariable();
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ *
+ * Single-pair convenience overload: numPairs is fixed to 1.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public SDVariable kvScatter(String name, SDVariable present, SDVariable staticBuffer,
+ long cachePos) {
+ SDValidation.validateNumerical("kvScatter", "present", present);
+ SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, 1).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ *
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1]
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public SDVariable kvScatter(SDVariable present, SDVariable staticBuffer, long cachePos,
+ int numPairs) {
+ SDValidation.validateNumerical("kvScatter", "present", present);
+ SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer);
+ // staticBuffer is mutated in place (per op contract); the returned scalar only signals completion.
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, numPairs).outputVariable();
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ *
+ * @param name name May be null. Name for the output variable
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1]
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public SDVariable kvScatter(String name, SDVariable present, SDVariable staticBuffer,
+ long cachePos, int numPairs) {
+ SDValidation.validateNumerical("kvScatter", "present", present);
+ SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer);
+ // staticBuffer is mutated in place (per op contract); the returned scalar only signals completion.
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, numPairs).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Apply Layer Normalization
*
@@ -811,6 +1508,172 @@ public SDVariable logSoftmax(String name, SDVariable x, int dimension) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * This overload uses defaults: no expert bias, normalizeProbs = true, capacityFactor = 1.0.
+ *
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @return Output variables produced by the MoE op (NUMERIC type)
+ */
+ public SDVariable[] mixtureOfExperts(SDVariable input, SDVariable routerWeights,
+ SDVariable expertWeights, int numExperts, int topK) {
+ SDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0).outputVariables();
+ }
+
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * This overload uses defaults: no expert bias, normalizeProbs = true, capacityFactor = 1.0.
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @return Output variables produced by the MoE op (NUMERIC type)
+ */
+ public SDVariable[] mixtureOfExperts(String[] names, SDVariable input, SDVariable routerWeights,
+ SDVariable expertWeights, int numExperts, int topK) {
+ SDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @param normalizeProbs Whether to normalize router probabilities for selected experts
+ * @param capacityFactor Expert capacity factor for load balancing
+ * @return Output variables produced by the MoE op (NUMERIC type)
+ */
+ public SDVariable[] mixtureOfExperts(SDVariable input, SDVariable routerWeights,
+ SDVariable expertWeights, SDVariable expertBias, int numExperts, int topK,
+ boolean normalizeProbs, double capacityFactor) {
+ SDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor).outputVariables();
+ }
+
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @param normalizeProbs Whether to normalize router probabilities for selected experts
+ * @param capacityFactor Expert capacity factor for load balancing
+ * @return Output variables produced by the MoE op (NUMERIC type)
+ */
+ public SDVariable[] mixtureOfExperts(String[] names, SDVariable input, SDVariable routerWeights,
+ SDVariable expertWeights, SDVariable expertBias, int numExperts, int topK,
+ boolean normalizeProbs, double capacityFactor) {
+ SDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ SDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
/**
* This performs multi-headed dot product attention on the given timeseries input
* out = concat(head_1, head_2, ..., head_n) * Wo
@@ -890,7 +1753,7 @@ public SDVariable multiHeadDotProductAttention(String name, SDVariable queries,
}
/**
- * Padding operation
+ * Padding operation
*
* @param input Input tensor (NUMERIC type)
* @param padding Padding value (NUMERIC type)
@@ -905,7 +1768,7 @@ public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, dou
}
/**
- * Padding operation
+ * Padding operation
*
* @param name name May be null. Name for the output variable
* @param input Input tensor (NUMERIC type)
@@ -923,7 +1786,7 @@ public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode
}
/**
- * Padding operation
+ * Padding operation
*
* @param input Input tensor (NUMERIC type)
* @param padding Padding value (NUMERIC type)
@@ -937,7 +1800,7 @@ public SDVariable pad(SDVariable input, SDVariable padding, double constant) {
}
/**
- * Padding operation
+ * Padding operation
*
* @param name name May be null. Name for the output variable
* @param input Input tensor (NUMERIC type)
@@ -1026,6 +1889,141 @@ public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int...
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * This overload supplies no precomputed relative position index; the index is
+ * generated internally from windowSize (see the overload that accepts one).
+ *
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public SDVariable relativePositionBias(SDVariable biasTable, int numHeads, int windowSize) {
+ SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ // NOTE(review): the final constructor arg (false) presumably selects learned-bias mode rather than ALiBi — confirm against RelativePositionBias
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, null, numHeads, windowSize, false).outputVariable();
+ }
+
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * This overload supplies no precomputed relative position index; the index is
+ * generated internally from windowSize (see the overload that accepts one).
+ *
+ * @param name name May be null. Name for the output variable
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public SDVariable relativePositionBias(String name, SDVariable biasTable, int numHeads,
+ int windowSize) {
+ SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ // NOTE(review): the final constructor arg (false) presumably selects learned-bias mode rather than ALiBi — confirm against RelativePositionBias
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, null, numHeads, windowSize, false).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * This overload supplies a precomputed relative position index.
+ *
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public SDVariable relativePositionBias(SDVariable biasTable, SDVariable relativePositionIndex,
+ int numHeads, int windowSize) {
+ SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ // NOTE(review): relativePositionIndex is documented optional but is validated here — use the index-free overload when it is absent
+ SDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, relativePositionIndex, numHeads, windowSize, false).outputVariable();
+ }
+
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * This overload supplies a precomputed relative position index.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public SDVariable relativePositionBias(String name, SDVariable biasTable,
+ SDVariable relativePositionIndex, int numHeads, int windowSize) {
+ SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ // NOTE(review): relativePositionIndex is documented optional but is validated here — use the index-free overload when it is absent
+ SDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, relativePositionIndex, numHeads, windowSize, false).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Element-wise rectified linear function with specified cutoff:
* out[i] = in[i] if in[i] >= cutoff
@@ -1116,6 +2114,146 @@ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, S
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * (assumes gamma matches the size of the last input dimension — TODO confirm against RmsNorm)
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(SDVariable input, SDVariable gamma, double epsilon) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDValidation.validateNumerical("rmsNorm", "gamma", gamma);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, epsilon).outputVariable();
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * (assumes gamma matches the size of the last input dimension — TODO confirm against RmsNorm)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(String name, SDVariable input, SDVariable gamma, double epsilon) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDValidation.validateNumerical("rmsNorm", "gamma", gamma);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, epsilon).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload uses the default epsilon of 1e-5.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(SDVariable input, SDVariable gamma) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDValidation.validateNumerical("rmsNorm", "gamma", gamma);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, 1.0E-5).outputVariable();
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload uses the default epsilon of 1e-5.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(String name, SDVariable input, SDVariable gamma) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDValidation.validateNumerical("rmsNorm", "gamma", gamma);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, 1.0E-5).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload omits gamma, so no learned scaling is applied.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(SDVariable input, double epsilon) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, epsilon).outputVariable();
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload omits gamma, so no learned scaling is applied.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Input variable (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(String name, SDVariable input, double epsilon) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, epsilon).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload uses the default epsilon of 1e-5 and applies no gamma scaling.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(SDVariable input) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, 1.0E-5).outputVariable();
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * This overload uses the default epsilon of 1e-5 and applies no gamma scaling.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param input Input variable (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public SDVariable rmsNorm(String name, SDVariable input) {
+ SDValidation.validateNumerical("rmsNorm", "input", input);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, 1.0E-5).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
*
@@ -1198,6 +2336,72 @@ public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ *
+ * NOTE(review): when numKvHeads is smaller than numHeads this presumably
+ * performs grouped-query attention — confirm against the SlidingWindowAttention op.
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param windowSize Sliding window size - tokens can only attend to this many previous positions
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable slidingWindowAttention(SDVariable query, SDVariable key, SDVariable value,
+ int windowSize, int numHeads, int numKvHeads, double scale) {
+ SDValidation.validateNumerical("slidingWindowAttention", "query", query);
+ SDValidation.validateNumerical("slidingWindowAttention", "key", key);
+ SDValidation.validateNumerical("slidingWindowAttention", "value", value);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(sd,query, key, value, windowSize, numHeads, numKvHeads, scale).outputVariable();
+ }
+
+ /**
+ * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ *
+ * NOTE(review): when numKvHeads is smaller than numHeads this presumably
+ * performs grouped-query attention — confirm against the SlidingWindowAttention op.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param windowSize Sliding window size - tokens can only attend to this many previous positions
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public SDVariable slidingWindowAttention(String name, SDVariable query, SDVariable key,
+ SDVariable value, int windowSize, int numHeads, int numKvHeads, double scale) {
+ SDValidation.validateNumerical("slidingWindowAttention", "query", query);
+ SDValidation.validateNumerical("slidingWindowAttention", "key", key);
+ SDValidation.validateNumerical("slidingWindowAttention", "value", value);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(sd,query, key, value, windowSize, numHeads, numKvHeads, scale).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Softmax activation, along the specified dimension
*
@@ -1370,6 +2574,104 @@ public SDVariable tanh(String name, SDVariable x) {
return sd.updateVariableNameAndReference(out, name);
}
+ /**
+ * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * This overload uses all-default parameters (temperature=0, topK=0, topP=0,
+ * seed=0), i.e. greedy argmax decoding.
+ *
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public SDVariable tokenSample(SDVariable logits) {
+ SDValidation.validateNumerical("tokenSample", "logits", logits);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, 0.0, 0, 0.0, 0).outputVariable();
+ }
+
+ /**
+ * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * This overload uses all-default parameters (temperature=0, topK=0, topP=0,
+ * seed=0), i.e. greedy argmax decoding.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public SDVariable tokenSample(String name, SDVariable logits) {
+ SDValidation.validateNumerical("tokenSample", "logits", logits);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, 0.0, 0, 0.0, 0).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * Supply a non-zero seed for reproducible sampling (seed 0 draws a random seed).
+ *
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @param temperature Temperature for sampling. 0 = greedy (argmax)
+ * @param topK Top-K filtering: keep only top K logits. 0 = disabled
+ * @param topP Top-P (nucleus) filtering threshold. 0 = disabled
+ * @param seed Random seed for sampling. 0 = random
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public SDVariable tokenSample(SDVariable logits, double temperature, int topK, double topP,
+ long seed) {
+ SDValidation.validateNumerical("tokenSample", "logits", logits);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, temperature, topK, topP, seed).outputVariable();
+ }
+
+ /**
+ * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * Supply a non-zero seed for reproducible sampling (seed 0 draws a random seed).
+ *
+ * @param name name May be null. Name for the output variable
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @param temperature Temperature for sampling. 0 = greedy (argmax)
+ * @param topK Top-K filtering: keep only top K logits. 0 = disabled
+ * @param topP Top-P (nucleus) filtering threshold. 0 = disabled
+ * @param seed Random seed for sampling. 0 = random
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public SDVariable tokenSample(String name, SDVariable logits, double temperature, int topK,
+ double topP, long seed) {
+ SDValidation.validateNumerical("tokenSample", "logits", logits);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, temperature, topK, topP, seed).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
/**
* Find values and indices for the largest k entries along the last dimension.
*
@@ -1395,4 +2697,265 @@ public SDVariable[] topK(String[] names, SDVariable input, double k, boolean sor
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TopK(sd,input, k, sorted).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ *
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param scale Attention scale factor (default: 1/sqrt(embedDim))
+ * @return Two outputs: [tokenOutput, imageOutput] — per the formulas above, tokenOutput is [batch, tokenSeqLen, embedDim] and imageOutput is [batch, imageSeqLen, embedDim]
+ */
+ public SDVariable[] twoWayCrossAttention(SDVariable tokenQuery, SDVariable tokenKey,
+ SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, SDVariable imageValue,
+ double scale) {
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale).outputVariables();
+ }
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ *
+ * @param names May be null. Arrays of names for the output variables.
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param scale Attention scale factor (default: 1/sqrt(embedDim))
+ * @return Two outputs: [tokenOutput, imageOutput] — per the formulas above, tokenOutput is [batch, tokenSeqLen, embedDim] and imageOutput is [batch, imageSeqLen, embedDim]
+ */
+ public SDVariable[] twoWayCrossAttention(String[] names, SDVariable tokenQuery,
+ SDVariable tokenKey, SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey,
+ SDVariable imageValue, double scale) {
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ *
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @return Two outputs: [tokenOutput, imageOutput]
+ */
+ public SDVariable[] twoWayCrossAttention(SDVariable tokenQuery, SDVariable tokenKey,
+ SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, SDVariable imageValue) {
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ // scale fixed at 0.0 — presumably the native op treats 0.0 as "use default 1/sqrt(embedDim)" (per the scale overload's docs); TODO confirm against the op implementation
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0).outputVariables();
+ }
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ *
+ * @param names May be null. Arrays of names for the output variables.
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @return Two outputs: [tokenOutput, imageOutput]
+ */
+ public SDVariable[] twoWayCrossAttention(String[] names, SDVariable tokenQuery,
+ SDVariable tokenKey, SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey,
+ SDVariable imageValue) {
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ // scale fixed at 0.0 — presumably the native op treats 0.0 as "use default 1/sqrt(embedDim)" (per the scale overload's docs); TODO confirm against the op implementation
+ SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0).outputVariables();
+ return sd.updateVariableNamesAndReferences(out, names);
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public SDVariable windowedAttention(SDVariable query, SDVariable key, SDVariable value,
+ int windowSize, int numHeads) {
+ SDValidation.validateNumerical("windowedAttention", "query", query);
+ SDValidation.validateNumerical("windowedAttention", "key", key);
+ SDValidation.validateNumerical("windowedAttention", "value", value);
+ // Convenience overload: no position bias / mask (nulls), shiftSize=0, scale=0.0 (auto per the full overload's docs), returnWeights=false.
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, null, null, windowSize, numHeads, 0, 0.0, false).outputVariable();
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ *
+ * @param name May be null. Name for the output variable
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public SDVariable windowedAttention(String name, SDVariable query, SDVariable key,
+ SDVariable value, int windowSize, int numHeads) {
+ SDValidation.validateNumerical("windowedAttention", "query", query);
+ SDValidation.validateNumerical("windowedAttention", "key", key);
+ SDValidation.validateNumerical("windowedAttention", "value", value);
+ // Convenience overload: no position bias / mask (nulls), shiftSize=0, scale=0.0 (auto per the full overload's docs), returnWeights=false.
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, null, null, windowSize, numHeads, 0, 0.0, false).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type)
+ * @param attentionMask Optional attention mask (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift
+ * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim))
+ * @param returnWeights Whether to return attention weights
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public SDVariable windowedAttention(SDVariable query, SDVariable key, SDVariable value,
+ SDVariable relativePositionBias, SDVariable attentionMask, int windowSize, int numHeads,
+ int shiftSize, double scale, boolean returnWeights) {
+ SDValidation.validateNumerical("windowedAttention", "query", query);
+ SDValidation.validateNumerical("windowedAttention", "key", key);
+ SDValidation.validateNumerical("windowedAttention", "value", value);
+ // NOTE(review): relativePositionBias/attentionMask are documented as "Optional" yet validated unconditionally below; confirm SDValidation.validateNumerical tolerates null — otherwise callers must use the 5-arg overload to omit them.
+ SDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias);
+ SDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights).outputVariable();
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ *
+ * @param name May be null. Name for the output variable
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type)
+ * @param attentionMask Optional attention mask (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift
+ * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim))
+ * @param returnWeights Whether to return attention weights
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public SDVariable windowedAttention(String name, SDVariable query, SDVariable key,
+ SDVariable value, SDVariable relativePositionBias, SDVariable attentionMask, int windowSize,
+ int numHeads, int shiftSize, double scale, boolean returnWeights) {
+ SDValidation.validateNumerical("windowedAttention", "query", query);
+ SDValidation.validateNumerical("windowedAttention", "key", key);
+ SDValidation.validateNumerical("windowedAttention", "value", value);
+ // NOTE(review): relativePositionBias/attentionMask are documented as "Optional" yet validated unconditionally below; confirm SDValidation.validateNumerical tolerates null — otherwise callers must use the 5-arg overload to omit them.
+ SDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias);
+ SDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java
new file mode 100644
index 000000000000..8a2065afb7b2
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java
@@ -0,0 +1,134 @@
+/*
+ * ******************************************************************************
+ * *
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * AWQ (Activation-aware Weight Quantization) matrix multiplication.
+ *
+ * Performs dequantization and GEMM with AWQ-packed weights:
+ *
+ *     output = input @ dequant(weightPacked, scales, zeros) + bias
+ *
+ * Inputs: input, weightPacked, scales, zeros, and an optional bias (see constructors).
+ * Integer arguments: groupSize, numBits.
+ * Output: [M, N] same dtype as input
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class AwqMatmul extends DynamicCustomOp {
+
+ @Getter private int groupSize = 128;
+ @Getter private int numBits = 4;
+
+ /**
+ * SameDiff constructor with required inputs.
+ */
+ public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked,
+ SDVariable scales, SDVariable zeros) {
+ super(null, sameDiff, new SDVariable[]{input, weightPacked, scales, zeros}, false);
+ addIArgument((long) groupSize, (long) numBits);
+ }
+
+ /**
+ * SameDiff constructor with bias.
+ */
+ public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked,
+ SDVariable scales, SDVariable zeros, SDVariable bias) {
+ super(null, sameDiff, new SDVariable[]{input, weightPacked, scales, zeros, bias}, false);
+ addIArgument((long) groupSize, (long) numBits);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ */
+ public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked,
+ SDVariable scales, SDVariable zeros, SDVariable bias,
+ int groupSize, int numBits) {
+ super(null, sameDiff, bias != null ?
+ new SDVariable[]{input, weightPacked, scales, zeros, bias} :
+ new SDVariable[]{input, weightPacked, scales, zeros}, false);
+ this.groupSize = groupSize;
+ this.numBits = numBits;
+ addIArgument((long) groupSize, (long) numBits);
+ }
+
+ /**
+ * INDArray constructor.
+ */
+ public AwqMatmul(INDArray input, INDArray weightPacked, INDArray scales, INDArray zeros,
+ INDArray bias, INDArray output, int groupSize, int numBits) {
+ super(null, bias != null ?
+ new INDArray[]{input, weightPacked, scales, zeros, bias} :
+ new INDArray[]{input, weightPacked, scales, zeros},
+ output != null ? new INDArray[]{output} : null);
+ this.groupSize = groupSize;
+ this.numBits = numBits;
+ addIArgument((long) groupSize, (long) numBits);
+ }
+
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.groupSize = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.numBits = iArguments.get(1).intValue();
+ }
+
+ @Override
+ public String opName() {
+ return "awq_matmul";
+ }
+
+ @Override
+ public List
+ * Splits the weight matrix along columns across tensor-parallel ranks.
+ * Each rank computes a shard of the output, optionally gathering across ranks.
+ *
+ * Inputs: input, weightShard, and an optional biasShard (see constructors).
+ *
+ * Integer arguments: tpSize, tpRank, gatherOutput.
+ *
+ * Output: [B, O] if gatherOutput=1, else [B, O/tp]
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class ColumnParallelLinear extends DynamicCustomOp {
+
+ @Getter private int tpSize = 1;
+ @Getter private int tpRank = 0;
+ @Getter private int gatherOutput = 1;
+
+ /**
+ * SameDiff constructor with required inputs.
+ */
+ public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard) {
+ super(null, sameDiff, new SDVariable[]{input, weightShard}, false);
+ addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput);
+ }
+
+ /**
+ * SameDiff constructor with bias.
+ */
+ public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard,
+ SDVariable biasShard) {
+ super(null, sameDiff, new SDVariable[]{input, weightShard, biasShard}, false);
+ addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput);
+ }
+
+ /**
+ * SameDiff constructor with boolean gatherOutput.
+ */
+ public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard,
+ SDVariable biasShard, int tpSize, int tpRank, boolean gatherOutput) {
+ this(sameDiff, input, weightShard, biasShard, tpSize, tpRank, gatherOutput ? 1 : 0);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ */
+ public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard,
+ SDVariable biasShard, int tpSize, int tpRank, int gatherOutput) {
+ super(null, sameDiff, biasShard != null ?
+ new SDVariable[]{input, weightShard, biasShard} :
+ new SDVariable[]{input, weightShard}, false);
+ this.tpSize = tpSize;
+ this.tpRank = tpRank;
+ this.gatherOutput = gatherOutput;
+ addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput);
+ }
+
+ /**
+ * INDArray constructor (no output pre-allocation).
+ */
+ public ColumnParallelLinear(INDArray input, INDArray weightShard, INDArray biasShard,
+ int tpSize, int tpRank, boolean gatherOutput) {
+ this(input, weightShard, biasShard, null, tpSize, tpRank, gatherOutput ? 1 : 0);
+ }
+
+ /**
+ * INDArray constructor.
+ */
+ public ColumnParallelLinear(INDArray input, INDArray weightShard, INDArray biasShard,
+ INDArray output, int tpSize, int tpRank, int gatherOutput) {
+ super(null, biasShard != null ?
+ new INDArray[]{input, weightShard, biasShard} :
+ new INDArray[]{input, weightShard},
+ output != null ? new INDArray[]{output} : null);
+ this.tpSize = tpSize;
+ this.tpRank = tpRank;
+ this.gatherOutput = gatherOutput;
+ addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput);
+ }
+
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.tpSize = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.tpRank = iArguments.get(1).intValue();
+ if (iArguments.size() > 2) this.gatherOutput = iArguments.get(2).intValue();
+ }
+
+ @Override
+ public String opName() {
+ return "column_parallel_linear";
+ }
+
+ @Override
+ public List
+ * Performs fused QKV projection, rotary positional encoding, and
+ * masked multi-head attention in a single op for decoder-only models.
+ *
+ * Inputs:
+ *
+ * Integer arguments:
+ *
+ * Float arguments:
+ *
+ * Outputs:
+ *
+ * Performs GEMM with FP8 quantized inputs and FP16/FP32 output, using
+ * CUTLASS FP8 GEMM with per-tensor dequantization scales:
+ *
+ * Inputs:
+ *
+ * Output: C tensor (FLOAT16) [M, N]
+ *
+ * Integer arguments:
+ *
+ * Executes a sequence of element-wise ops in a single kernel pass, keeping
+ * intermediate values in registers instead of global memory. This replaces
+ * N separate kernel launches with 1.
+ *
+ * Op codes (iArgs):
+ *
+ * Usage:
+ *
+ * Computes the gated linear unit with SiLU (Swish) activation in a single
+ * fused kernel, commonly used in LLM feed-forward blocks:
+ *
+ * Inputs:
+ *
+ * Output: [M, N]
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class FusedGemmSwiglu extends DynamicCustomOp {
+
+ /**
+ * SameDiff constructor.
+ */
+ public FusedGemmSwiglu(SameDiff sameDiff, SDVariable input, SDVariable wGate, SDVariable wUp) {
+ super(null, sameDiff, new SDVariable[]{input, wGate, wUp}, false);
+ }
+
+ /**
+ * INDArray constructor.
+ */
+ public FusedGemmSwiglu(INDArray input, INDArray wGate, INDArray wUp, INDArray output) {
+ super(null, new INDArray[]{input, wGate, wUp},
+ output != null ? new INDArray[]{output} : null);
+ }
+
+ @Override
+ public String opName() {
+ return "fused_gemm_swiglu";
+ }
+
+ @Override
+ public List
+ * Applies normalization (RMSNorm or LayerNorm) followed by quantization
+ * (INT8 or FP8) in a single fused kernel:
+ *
+ * Inputs:
+ *
+ * Integer arguments:
+ *
+ * Float arguments:
+ *
+ * Outputs:
+ *
+ * Performs top-K filtering, softmax, and multinomial sampling entirely on the GPU,
+ * eliminating the device-to-host transfer overhead for logits. Uses CUB radix sort
+ * for efficient top-K selection.
+ *
+ * Pipeline: logits -> temperature scaling -> top-K filter -> softmax -> multinomial sample
+ *
+ * Inputs:
+ *
+ * Outputs:
+ *
+ * Integer arguments:
+ *
+ * Float arguments:
+ *
+ * Performs nucleus sampling entirely on the GPU: sorts tokens by probability,
+ * computes cumulative sums to find the smallest set of tokens whose cumulative
+ * probability exceeds p, then samples from that nucleus.
+ *
+ * Pipeline: logits -> temperature scaling -> softmax -> sort -> cumsum -> nucleus filter -> sample
+ *
+ * Supports additional penalty parameters for controlling repetition:
+ *
+ * Inputs:
+ *
+ * Outputs:
+ *
+ * Integer arguments:
+ *
+ * Float arguments:
+ *
+ * Computes top-K expert routing with load-balancing auxiliary loss:
+ *
+ * Inputs:
+ *
+ * Integer arguments:
+ *
+ * Float arguments:
+ *
+ * Outputs:
+ *
+ * Applies per-sample LoRA adapters during batched inference:
+ *
+ * Inputs:
+ *
+ * Float arguments:
+ *
+ * Output: [B, O]
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class MultiLoraMatmul extends DynamicCustomOp {
+
+ @Getter private float alpha = 1.0f;
+
+ /**
+ * SameDiff constructor with default alpha.
+ */
+ public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight,
+ SDVariable loraA, SDVariable loraB, SDVariable adapterIds) {
+ super(null, sameDiff, new SDVariable[]{input, baseWeight, loraA, loraB, adapterIds}, false);
+ addTArgument((double) alpha);
+ }
+
+ /**
+ * SameDiff constructor with double alpha.
+ */
+ public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight,
+ SDVariable loraA, SDVariable loraB, SDVariable adapterIds,
+ double alpha) {
+ this(sameDiff, input, baseWeight, loraA, loraB, adapterIds, (float) alpha);
+ }
+
+ /**
+ * Full SameDiff constructor with alpha.
+ */
+ public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight,
+ SDVariable loraA, SDVariable loraB, SDVariable adapterIds,
+ float alpha) {
+ super(null, sameDiff, new SDVariable[]{input, baseWeight, loraA, loraB, adapterIds}, false);
+ this.alpha = alpha;
+ addTArgument((double) alpha);
+ }
+
+ /**
+ * INDArray constructor.
+ */
+ public MultiLoraMatmul(INDArray input, INDArray baseWeight, INDArray loraA, INDArray loraB,
+ INDArray adapterIds, INDArray output, float alpha) {
+ super(null, new INDArray[]{input, baseWeight, loraA, loraB, adapterIds},
+ output != null ? new INDArray[]{output} : null);
+ this.alpha = alpha;
+ addTArgument((double) alpha);
+ }
+
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (tArguments.size() > 0) this.alpha = tArguments.get(0).floatValue();
+ }
+
+ @Override
+ public String opName() {
+ return "multi_lora_matmul";
+ }
+
+ @Override
+ public List
+ * Splits the weight matrix along rows across tensor-parallel ranks.
+ * Each rank computes a partial result, optionally reducing across ranks.
+ *
+ * Inputs: inputShard, weightShard, and an optional bias (see constructors).
+ *
+ * Integer arguments: tpSize, tpRank, reduceOutput.
+ *
+ * Output: [B, O]
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class RowParallelLinear extends DynamicCustomOp {
+
+ @Getter private int tpSize = 1;
+ @Getter private int tpRank = 0;
+ @Getter private int reduceOutput = 1;
+
+ /**
+ * SameDiff constructor with required inputs.
+ */
+ public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard) {
+ super(null, sameDiff, new SDVariable[]{inputShard, weightShard}, false);
+ addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput);
+ }
+
+ /**
+ * SameDiff constructor with bias.
+ */
+ public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard,
+ SDVariable bias) {
+ super(null, sameDiff, new SDVariable[]{inputShard, weightShard, bias}, false);
+ addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput);
+ }
+
+ /**
+ * SameDiff constructor with boolean reduceOutput.
+ */
+ public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard,
+ SDVariable bias, int tpSize, int tpRank, boolean reduceOutput) {
+ this(sameDiff, inputShard, weightShard, bias, tpSize, tpRank, reduceOutput ? 1 : 0);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ */
+ public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard,
+ SDVariable bias, int tpSize, int tpRank, int reduceOutput) {
+ super(null, sameDiff, bias != null ?
+ new SDVariable[]{inputShard, weightShard, bias} :
+ new SDVariable[]{inputShard, weightShard}, false);
+ this.tpSize = tpSize;
+ this.tpRank = tpRank;
+ this.reduceOutput = reduceOutput;
+ addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput);
+ }
+
+ /**
+ * INDArray constructor (no output pre-allocation).
+ */
+ public RowParallelLinear(INDArray inputShard, INDArray weightShard, INDArray bias,
+ int tpSize, int tpRank, boolean reduceOutput) {
+ this(inputShard, weightShard, bias, null, tpSize, tpRank, reduceOutput ? 1 : 0);
+ }
+
+ /**
+ * INDArray constructor.
+ */
+ public RowParallelLinear(INDArray inputShard, INDArray weightShard, INDArray bias,
+ INDArray output, int tpSize, int tpRank, int reduceOutput) {
+ super(null, bias != null ?
+ new INDArray[]{inputShard, weightShard, bias} :
+ new INDArray[]{inputShard, weightShard},
+ output != null ? new INDArray[]{output} : null);
+ this.tpSize = tpSize;
+ this.tpRank = tpRank;
+ this.reduceOutput = reduceOutput;
+ addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput);
+ }
+
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.tpSize = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.tpRank = iArguments.get(1).intValue();
+ if (iArguments.size() > 2) this.reduceOutput = iArguments.get(2).intValue();
+ }
+
+ @Override
+ public String opName() {
+ return "row_parallel_linear";
+ }
+
+ @Override
+ public List
+ * Implements the selective state space model scan used in Mamba architectures.
+ *
+ * Inputs: x, a, b, c, d, and an optional initial hidden state h0 (see constructors).
+ *
+ * Output: [B, L, D]
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class SelectiveScan extends DynamicCustomOp {
+
+ /**
+ * SameDiff constructor with required inputs.
+ * Input order is fixed as (x, a, b, c, d), matching the recurrence
+ * h_t = A_t * h_{t-1} + B_t * x_t; y_t = C_t * h_t + D * x_t.
+ */
+ public SelectiveScan(SameDiff sameDiff, SDVariable x, SDVariable a, SDVariable b,
+ SDVariable c, SDVariable d) {
+ super(null, sameDiff, new SDVariable[]{x, a, b, c, d}, false);
+ }
+
+ /**
+ * SameDiff constructor with initial hidden state h0 appended as a sixth input.
+ */
+ public SelectiveScan(SameDiff sameDiff, SDVariable x, SDVariable a, SDVariable b,
+ SDVariable c, SDVariable d, SDVariable h0) {
+ super(null, sameDiff, new SDVariable[]{x, a, b, c, d, h0}, false);
+ }
+
+ /**
+ * INDArray constructor without output pre-allocation (op allocates the output).
+ */
+ public SelectiveScan(INDArray x, INDArray a, INDArray b, INDArray c, INDArray d,
+ INDArray h0) {
+ this(x, a, b, c, d, h0, null);
+ }
+
+ /**
+ * INDArray constructor. h0 is optional: when null it is simply omitted from
+ * the op's input array rather than passed as a placeholder.
+ */
+ public SelectiveScan(INDArray x, INDArray a, INDArray b, INDArray c, INDArray d,
+ INDArray h0, INDArray output) {
+ super(null, h0 != null ?
+ new INDArray[]{x, a, b, c, d, h0} :
+ new INDArray[]{x, a, b, c, d},
+ output != null ? new INDArray[]{output} : null);
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "selective_scan";
+ }
+
+ @Override
+ public List
+ * Implements the SmoothQuant algorithm which migrates quantization difficulty
+ * from activations to weights by applying a per-channel smoothing factor:
+ *
+ * Inputs:
+ *
+ * Output: Y (FLOAT) - dequantized output [batch, out_features]
+ *
+ * Integer arguments:
+ *
+ * output_shard = input @ weightShard + biasShard
+ * output = allGather(output_shard) if gatherOutput else output_shard
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class DecoderMaskedMha extends DynamicCustomOp {
+
+ @Getter private int numHeads = 32; // number of attention heads
+ @Getter private int numKvHeads = 0; // KV heads for GQA/MQA; 0 presumably means "same as numHeads" — TODO confirm against the C++ op
+ @Getter private int headDim = 0; // per-head dimension; 0 presumably means "derive from hidden size" — TODO confirm
+ @Getter private int useRoPE = 1; // 1 = apply rotary position embeddings
+ @Getter private int ropeBase = 10000; // RoPE frequency base
+ @Getter private float maskFilterValue = -3.4028235e+38f; // most-negative float; fills masked attention scores
+
+ /**
+ * SameDiff constructor with required inputs and default options.
+ */
+ public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight,
+ SDVariable outWeight, SDVariable pastKey, SDVariable pastValue) {
+ super(null, sameDiff, new SDVariable[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue}, false);
+ addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase);
+ addTArgument((double) maskFilterValue);
+ }
+
+ /**
+ * SameDiff constructor with boolean useRoPE (uses default ropeBase and maskFilterValue).
+ */
+ public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight,
+ SDVariable outWeight, SDVariable pastKey, SDVariable pastValue,
+ int numHeads, int numKvHeads, int headDim, boolean useRoPE) {
+ this(sameDiff, hiddenStates, qkvWeight, outWeight, pastKey, pastValue,
+ numHeads, numKvHeads, headDim, useRoPE ? 1 : 0, 10000, -3.4028235e+38f);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ * IArg order: numHeads, numKvHeads, headDim, useRoPE, ropeBase; TArg 0: maskFilterValue.
+ */
+ public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight,
+ SDVariable outWeight, SDVariable pastKey, SDVariable pastValue,
+ int numHeads, int numKvHeads, int headDim,
+ int useRoPE, int ropeBase, float maskFilterValue) {
+ super(null, sameDiff, new SDVariable[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue}, false);
+ this.numHeads = numHeads;
+ this.numKvHeads = numKvHeads;
+ this.headDim = headDim;
+ this.useRoPE = useRoPE;
+ this.ropeBase = ropeBase;
+ this.maskFilterValue = maskFilterValue;
+ addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase);
+ addTArgument((double) maskFilterValue);
+ }
+
+ /**
+ * INDArray constructor.
+ * NOTE(review): when output is non-null, presentKey and presentValue are assumed
+ * non-null as well (all three are placed in the output array together) — confirm callers.
+ */
+ public DecoderMaskedMha(INDArray hiddenStates, INDArray qkvWeight, INDArray outWeight,
+ INDArray pastKey, INDArray pastValue,
+ INDArray output, INDArray presentKey, INDArray presentValue,
+ int numHeads, int numKvHeads, int headDim,
+ int useRoPE, int ropeBase, float maskFilterValue) {
+ super(null, new INDArray[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue},
+ output != null ? new INDArray[]{output, presentKey, presentValue} : null);
+ this.numHeads = numHeads;
+ this.numKvHeads = numKvHeads;
+ this.headDim = headDim;
+ this.useRoPE = useRoPE;
+ this.ropeBase = ropeBase;
+ this.maskFilterValue = maskFilterValue;
+ addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase);
+ addTArgument((double) maskFilterValue);
+ }
+
+ // Restores Java-side fields from deserialized IArgs/TArgs; order must match the constructors above.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.numHeads = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.numKvHeads = iArguments.get(1).intValue();
+ if (iArguments.size() > 2) this.headDim = iArguments.get(2).intValue();
+ if (iArguments.size() > 3) this.useRoPE = iArguments.get(3).intValue();
+ if (iArguments.size() > 4) this.ropeBase = iArguments.get(4).intValue();
+ if (tArguments.size() > 0) this.maskFilterValue = tArguments.get(0).floatValue();
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "decoder_masked_mha";
+ }
+
+ @Override
+ public List
+ * C = dequant(A @ B) + bias
+ * = (scale_A * A_fp8) @ (scale_B * B_fp8) + bias
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class Fp8Matmul extends DynamicCustomOp {
+
+ /** FP8 E4M3 format (4-bit exponent, 3-bit mantissa) - higher precision */
+ public static final int FP8_E4M3 = 0;
+ /** FP8 E5M2 format (5-bit exponent, 2-bit mantissa) - wider range */
+ public static final int FP8_E5M2 = 1;
+
+ @Getter private int fp8Format = FP8_E4M3; // one of FP8_E4M3 / FP8_E5M2
+ @Getter private boolean transposeA = false;
+ @Getter private boolean transposeB = false;
+
+ /**
+ * SameDiff constructor with required inputs.
+ *
+ * @param sameDiff the SameDiff instance
+ * @param a FP8 quantized A matrix [M, K]
+ * @param b FP8 quantized B matrix [K, N]
+ * @param scaleA per-tensor scale for A
+ * @param scaleB per-tensor scale for B
+ */
+ public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b,
+ SDVariable scaleA, SDVariable scaleB) {
+ super(null, sameDiff, new SDVariable[]{a, b, scaleA, scaleB}, false);
+ addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L);
+ }
+
+ /**
+ * SameDiff constructor with bias (appended as a fifth input).
+ */
+ public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b,
+ SDVariable scaleA, SDVariable scaleB, SDVariable bias) {
+ super(null, sameDiff, new SDVariable[]{a, b, scaleA, scaleB, bias}, false);
+ addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ * IArg order: fp8Format, transposeA, transposeB.
+ */
+ public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b,
+ SDVariable scaleA, SDVariable scaleB, SDVariable bias,
+ int fp8Format, boolean transposeA, boolean transposeB) {
+ super(null, sameDiff, bias != null ?
+ new SDVariable[]{a, b, scaleA, scaleB, bias} :
+ new SDVariable[]{a, b, scaleA, scaleB}, false);
+ this.fp8Format = fp8Format;
+ this.transposeA = transposeA;
+ this.transposeB = transposeB;
+ addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L);
+ }
+
+ /**
+ * INDArray constructor. bias may be null (omitted from inputs); output may be
+ * null to let the op allocate its result.
+ */
+ public Fp8Matmul(INDArray a, INDArray b, INDArray scaleA, INDArray scaleB,
+ INDArray bias, INDArray output,
+ int fp8Format, boolean transposeA, boolean transposeB) {
+ super(null, bias != null ?
+ new INDArray[]{a, b, scaleA, scaleB, bias} :
+ new INDArray[]{a, b, scaleA, scaleB},
+ output != null ? new INDArray[]{output} : null);
+ this.fp8Format = fp8Format;
+ this.transposeA = transposeA;
+ this.transposeB = transposeB;
+ addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L);
+ }
+
+ // Restores Java-side fields from deserialized IArgs.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.fp8Format = iArguments.get(0).intValue();
+ // The boxed Long is unboxed for the numeric comparison with 0.
+ if (iArguments.size() > 1) this.transposeA = iArguments.get(1) != 0;
+ if (iArguments.size() > 2) this.transposeB = iArguments.get(2) != 0;
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "fp8_matmul";
+ }
+
+ @Override
+ public List
+ * 0=add, 1=sub, 2=mul, 3=div,
+ * 10=relu, 11=sigmoid, 12=tanh, 13=gelu,
+ * 14=exp, 15=log, 16=abs, 17=neg,
+ * 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish,
+ * 30=clip, 31=leaky_relu
+ *
+ *
+ * // multiply(x, y) -> sigmoid
+ * FusedElementwiseChain.builder()
+ * .input(x)
+ * .multiply(y)
+ * .sigmoid()
+ * .build()
+ * .exec();
+ *
+ *
+ * @author Adam Gibson
+ */
+@NoArgsConstructor
+public class FusedElementwiseChain extends DynamicCustomOp {
+
+ // Op codes matching FusedElemOp enum in C++
+ public static final int OP_ADD = 0;
+ public static final int OP_SUB = 1;
+ public static final int OP_MUL = 2;
+ public static final int OP_DIV = 3;
+ public static final int OP_RELU = 10;
+ public static final int OP_SIGMOID = 11;
+ public static final int OP_TANH = 12;
+ public static final int OP_GELU = 13;
+ public static final int OP_EXP = 14;
+ public static final int OP_LOG = 15;
+ public static final int OP_ABS = 16;
+ public static final int OP_NEG = 17;
+ public static final int OP_SQUARE = 18;
+ public static final int OP_SQRT = 19;
+ public static final int OP_SWISH = 20;
+ public static final int OP_SILU = 21;
+ public static final int OP_MISH = 22;
+ public static final int OP_CLIP = 30;
+ public static final int OP_LEAKY_RELU = 31;
+
+ /**
+ * Create a fused elementwise chain.
+ *
+ * @param inputs Primary input at index 0, secondary inputs for binary ops at indices 1+
+ * @param output Pre-allocated output (or null)
+ * @param opCodes Sequence of FusedElemOp codes
+ */
+ public FusedElementwiseChain(INDArray[] inputs, INDArray output, int... opCodes) {
+ super(inputs, output != null ? new INDArray[]{output} : null);
+ // Each op code becomes one IArg; the C++ side replays them in order.
+ for (int code : opCodes) {
+ addIArgument(code);
+ }
+ }
+
+ /**
+ * Constructor for codegen (NDNN): single primary input, unary ops only.
+ */
+ public FusedElementwiseChain(INDArray input, int... opCodes) {
+ this(new INDArray[]{input}, null, opCodes);
+ }
+
+ /**
+ * Constructor for codegen (NDNN): single primary input + optional secondary inputs array.
+ */
+ public FusedElementwiseChain(INDArray input, INDArray[] secondaryInputs, int[] opCodes) {
+ this(buildInputs(input, secondaryInputs), null, opCodes);
+ }
+
+ /**
+ * Constructor for codegen (SDNN): SameDiff graph mode, unary ops only.
+ */
+ public FusedElementwiseChain(SameDiff sd, SDVariable input, int... opCodes) {
+ this(sd, input, (SDVariable[]) null, opCodes);
+ }
+
+ /**
+ * Constructor for codegen (SDNN): SameDiff graph mode with secondary inputs.
+ */
+ public FusedElementwiseChain(SameDiff sd, SDVariable input, SDVariable[] secondaryInputs, int[] opCodes) {
+ super(null, sd, buildSdInputs(input, secondaryInputs), false);
+ for (int code : opCodes) {
+ addIArgument(code);
+ }
+ }
+
+ // Prepends the primary input to the secondary inputs: [input, secondary...].
+ private static INDArray[] buildInputs(INDArray input, INDArray[] secondaryInputs) {
+ if (secondaryInputs == null || secondaryInputs.length == 0) {
+ return new INDArray[]{input};
+ }
+ INDArray[] all = new INDArray[1 + secondaryInputs.length];
+ all[0] = input;
+ System.arraycopy(secondaryInputs, 0, all, 1, secondaryInputs.length);
+ return all;
+ }
+
+ // SDVariable twin of buildInputs: [input, secondary...].
+ private static SDVariable[] buildSdInputs(SDVariable input, SDVariable[] secondaryInputs) {
+ if (secondaryInputs == null || secondaryInputs.length == 0) {
+ return new SDVariable[]{input};
+ }
+ SDVariable[] all = new SDVariable[1 + secondaryInputs.length];
+ all[0] = input;
+ System.arraycopy(secondaryInputs, 0, all, 1, secondaryInputs.length);
+ return all;
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "fused_elementwise_chain";
+ }
+
+ @Override
+ public List
+ * output = (input @ wGate) * silu(input @ wUp)
+ *
+ *
+ *
+ *
+ * normalized = norm(input, weight, bias)
+ * quantized = quantize(normalized)
+ * scale = compute_scale(normalized)
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class FusedNormQuantize extends DynamicCustomOp {
+
+ public static final int NORM_RMSNORM = 0; // RMSNorm normalization
+ public static final int NORM_LAYERNORM = 1; // LayerNorm normalization
+ public static final int QUANT_INT8 = 0; // quantize to INT8
+ public static final int QUANT_FP8 = 1; // quantize to FP8
+
+ @Getter private int normType = NORM_RMSNORM;
+ @Getter private int quantType = QUANT_INT8;
+ @Getter private float epsilon = 1e-5f; // numerical-stability epsilon for the norm
+
+ /**
+ * SameDiff constructor with required inputs (defaults: RMSNorm, INT8, eps=1e-5).
+ */
+ public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight) {
+ super(null, sameDiff, new SDVariable[]{input, weight}, false);
+ addIArgument((long) normType, (long) quantType);
+ addTArgument((double) epsilon);
+ }
+
+ /**
+ * SameDiff constructor with bias (appended as a third input).
+ */
+ public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias) {
+ super(null, sameDiff, new SDVariable[]{input, weight, bias}, false);
+ addIArgument((long) normType, (long) quantType);
+ addTArgument((double) epsilon);
+ }
+
+ /**
+ * SameDiff constructor with double epsilon (narrowed to float).
+ */
+ public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias,
+ int normType, int quantType, double epsilon) {
+ this(sameDiff, input, weight, bias, normType, quantType, (float) epsilon);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ * IArg order: normType, quantType; TArg 0: epsilon.
+ */
+ public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias,
+ int normType, int quantType, float epsilon) {
+ super(null, sameDiff, bias != null ?
+ new SDVariable[]{input, weight, bias} :
+ new SDVariable[]{input, weight}, false);
+ this.normType = normType;
+ this.quantType = quantType;
+ this.epsilon = epsilon;
+ addIArgument((long) normType, (long) quantType);
+ addTArgument((double) epsilon);
+ }
+
+ /**
+ * INDArray constructor.
+ * NOTE(review): when quantized is non-null, scale is assumed non-null as well
+ * (both are placed in the output array together) — confirm callers.
+ */
+ public FusedNormQuantize(INDArray input, INDArray weight, INDArray bias,
+ INDArray quantized, INDArray scale,
+ int normType, int quantType, float epsilon) {
+ super(null, bias != null ?
+ new INDArray[]{input, weight, bias} :
+ new INDArray[]{input, weight},
+ quantized != null ? new INDArray[]{quantized, scale} : null);
+ this.normType = normType;
+ this.quantType = quantType;
+ this.epsilon = epsilon;
+ addIArgument((long) normType, (long) quantType);
+ addTArgument((double) epsilon);
+ }
+
+ // Restores Java-side fields from deserialized IArgs/TArgs.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.normType = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.quantType = iArguments.get(1).intValue();
+ if (tArguments.size() > 0) this.epsilon = tArguments.get(0).floatValue();
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "fused_norm_quantize";
+ }
+
+ @Override
+ public List
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ * @see GpuTopPSample
+ * @see TokenSample
+ */
+@NoArgsConstructor
+public class GpuTopKSample extends DynamicCustomOp {
+
+ @Getter private int k = 50; // number of top tokens to consider
+ @Getter private long seed = 0; // RNG seed; 0 = random
+ @Getter private double temperature = 1.0; // softmax temperature
+
+ /**
+ * INDArray constructor with logits and k (default temperature=1.0, seed=0 i.e. random).
+ */
+ public GpuTopKSample(INDArray logits, int k) {
+ this(logits, k, 1.0, 0);
+ }
+
+ /**
+ * Full INDArray constructor.
+ *
+ * @param logits input logits [batch, vocab_size]
+ * @param k number of top tokens to consider
+ * @param temperature temperature for scaling
+ * @param seed RNG seed (0 for random)
+ */
+ public GpuTopKSample(INDArray logits, int k, double temperature, long seed) {
+ super(new INDArray[]{logits}, null);
+ this.k = k;
+ this.temperature = temperature;
+ this.seed = seed;
+ // IArg order: k, seed; TArg 0: temperature.
+ addIArgument((long) k, seed);
+ addTArgument(temperature);
+ }
+
+ /**
+ * INDArray constructor with explicit random values (seed is left at its default of 0).
+ */
+ public GpuTopKSample(INDArray logits, INDArray randomValues, int k, double temperature) {
+ super(new INDArray[]{logits, randomValues}, null);
+ this.k = k;
+ this.temperature = temperature;
+ addIArgument((long) k, seed);
+ addTArgument(temperature);
+ }
+
+ /**
+ * SameDiff constructor with int seed and reordered parameters (k, seed, temperature).
+ */
+ public GpuTopKSample(SameDiff sameDiff, SDVariable logits, int k, int seed, double temperature) {
+ this(sameDiff, logits, k, temperature, (long) seed);
+ }
+
+ /**
+ * SameDiff constructor.
+ */
+ public GpuTopKSample(SameDiff sameDiff, SDVariable logits, int k, double temperature, long seed) {
+ super(null, sameDiff, new SDVariable[]{logits}, false);
+ this.k = k;
+ this.temperature = temperature;
+ this.seed = seed;
+ addIArgument((long) k, seed);
+ addTArgument(temperature);
+ }
+
+ /**
+ * SameDiff constructor with random values input (seed is left at its default of 0).
+ */
+ public GpuTopKSample(SameDiff sameDiff, SDVariable logits, SDVariable randomValues,
+ int k, double temperature) {
+ super(null, sameDiff, new SDVariable[]{logits, randomValues}, false);
+ this.k = k;
+ this.temperature = temperature;
+ addIArgument((long) k, seed);
+ addTArgument(temperature);
+ }
+
+ // Restores Java-side fields from deserialized IArgs/TArgs.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.k = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.seed = iArguments.get(1);
+ if (tArguments.size() > 0) this.temperature = tArguments.get(0);
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "gpu_top_k_sample";
+ }
+
+ @Override
+ public List
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ * @see GpuTopKSample
+ * @see TokenSample
+ */
+@NoArgsConstructor
+public class GpuTopPSample extends DynamicCustomOp {
+
+ @Getter private double p = 0.9; // nucleus probability threshold
+ @Getter private double temperature = 1.0; // softmax temperature
+ @Getter private double repetitionPenalty = 1.0; // 1.0 = disabled
+ @Getter private double frequencyPenalty = 0.0; // 0.0 = disabled
+ @Getter private double presencePenalty = 0.0; // 0.0 = disabled
+ @Getter private long seed = 0; // RNG seed; 0 = random
+
+ /**
+ * INDArray constructor with logits and p (default temperature=1.0, seed=0 i.e. random).
+ */
+ public GpuTopPSample(INDArray logits, double p) {
+ this(logits, p, 1.0, 0);
+ }
+
+ /**
+ * INDArray constructor with core parameters.
+ *
+ * @param logits input logits [batch, vocab_size]
+ * @param p nucleus probability threshold (0.0 to 1.0)
+ * @param temperature temperature for scaling
+ * @param seed RNG seed (0 for random)
+ */
+ public GpuTopPSample(INDArray logits, double p, double temperature, long seed) {
+ super(new INDArray[]{logits}, null);
+ this.p = p;
+ this.temperature = temperature;
+ this.seed = seed;
+ // IArg 0: seed; TArg order: p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty.
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * Full INDArray constructor with all penalty parameters.
+ */
+ public GpuTopPSample(INDArray logits, double p, double temperature, long seed,
+ double repetitionPenalty, double frequencyPenalty, double presencePenalty) {
+ super(new INDArray[]{logits}, null);
+ this.p = p;
+ this.temperature = temperature;
+ this.seed = seed;
+ this.repetitionPenalty = repetitionPenalty;
+ this.frequencyPenalty = frequencyPenalty;
+ this.presencePenalty = presencePenalty;
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * INDArray constructor with explicit random values (seed and penalties keep their defaults).
+ */
+ public GpuTopPSample(INDArray logits, INDArray randomValues, double p, double temperature) {
+ super(new INDArray[]{logits, randomValues}, null);
+ this.p = p;
+ this.temperature = temperature;
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * SameDiff constructor (penalties keep their defaults).
+ */
+ public GpuTopPSample(SameDiff sameDiff, SDVariable logits, double p, double temperature, long seed) {
+ super(null, sameDiff, new SDVariable[]{logits}, false);
+ this.p = p;
+ this.temperature = temperature;
+ this.seed = seed;
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * SameDiff constructor with random values input (seed and penalties keep their defaults).
+ */
+ public GpuTopPSample(SameDiff sameDiff, SDVariable logits, SDVariable randomValues,
+ double p, double temperature) {
+ super(null, sameDiff, new SDVariable[]{logits, randomValues}, false);
+ this.p = p;
+ this.temperature = temperature;
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * SameDiff constructor with int seed and reordered parameters (seed, p, temperature, ...).
+ */
+ public GpuTopPSample(SameDiff sameDiff, SDVariable logits, int seed, double p,
+ double temperature, double repetitionPenalty, double frequencyPenalty,
+ double presencePenalty) {
+ this(sameDiff, logits, p, temperature, (long) seed, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ /**
+ * Full SameDiff constructor with all parameters.
+ */
+ public GpuTopPSample(SameDiff sameDiff, SDVariable logits, double p, double temperature,
+ long seed, double repetitionPenalty, double frequencyPenalty,
+ double presencePenalty) {
+ super(null, sameDiff, new SDVariable[]{logits}, false);
+ this.p = p;
+ this.temperature = temperature;
+ this.seed = seed;
+ this.repetitionPenalty = repetitionPenalty;
+ this.frequencyPenalty = frequencyPenalty;
+ this.presencePenalty = presencePenalty;
+ addIArgument(seed);
+ addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty);
+ }
+
+ // Restores Java-side fields from deserialized IArgs/TArgs; order must match the constructors above.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.seed = iArguments.get(0);
+ if (tArguments.size() > 0) this.p = tArguments.get(0);
+ if (tArguments.size() > 1) this.temperature = tArguments.get(1);
+ if (tArguments.size() > 2) this.repetitionPenalty = tArguments.get(2);
+ if (tArguments.size() > 3) this.frequencyPenalty = tArguments.get(3);
+ if (tArguments.size() > 4) this.presencePenalty = tArguments.get(4);
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "gpu_top_p_sample";
+ }
+
+ @Override
+ public List
+ * logits = hiddenStates @ gateWeight
+ * expertIndices = topK(softmax(logits))
+ * gateWeights = normalized weights for selected experts
+ * auxLoss = load balancing loss
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class MoeGate extends DynamicCustomOp {
+
+ @Getter private int topK = 2; // experts selected per token
+ @Getter private int numExperts = 8; // total number of experts
+ @Getter private float auxLossCoeff = 0.01f; // load-balancing auxiliary loss coefficient
+
+ /**
+ * SameDiff constructor with default options (topK=2, numExperts=8, auxLossCoeff=0.01).
+ */
+ public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight) {
+ super(null, sameDiff, new SDVariable[]{hiddenStates, gateWeight}, false);
+ addIArgument((long) topK, (long) numExperts);
+ addTArgument((double) auxLossCoeff);
+ }
+
+ /**
+ * SameDiff constructor with double auxLossCoeff (narrowed to float).
+ */
+ public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight,
+ int topK, int numExperts, double auxLossCoeff) {
+ this(sameDiff, hiddenStates, gateWeight, topK, numExperts, (float) auxLossCoeff);
+ }
+
+ /**
+ * Full SameDiff constructor with all options.
+ * IArg order: topK, numExperts; TArg 0: auxLossCoeff.
+ */
+ public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight,
+ int topK, int numExperts, float auxLossCoeff) {
+ super(null, sameDiff, new SDVariable[]{hiddenStates, gateWeight}, false);
+ this.topK = topK;
+ this.numExperts = numExperts;
+ this.auxLossCoeff = auxLossCoeff;
+ addIArgument((long) topK, (long) numExperts);
+ addTArgument((double) auxLossCoeff);
+ }
+
+ /**
+ * INDArray constructor.
+ * NOTE(review): when expertIndices is non-null, gateWeights and auxLoss are assumed
+ * non-null as well (all three are placed in the output array together) — confirm callers.
+ */
+ public MoeGate(INDArray hiddenStates, INDArray gateWeight,
+ INDArray expertIndices, INDArray gateWeights, INDArray auxLoss,
+ int topK, int numExperts, float auxLossCoeff) {
+ super(null, new INDArray[]{hiddenStates, gateWeight},
+ expertIndices != null ? new INDArray[]{expertIndices, gateWeights, auxLoss} : null);
+ this.topK = topK;
+ this.numExperts = numExperts;
+ this.auxLossCoeff = auxLossCoeff;
+ addIArgument((long) topK, (long) numExperts);
+ addTArgument((double) auxLossCoeff);
+ }
+
+ // Restores Java-side fields from deserialized IArgs/TArgs.
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.topK = iArguments.get(0).intValue();
+ if (iArguments.size() > 1) this.numExperts = iArguments.get(1).intValue();
+ if (tArguments.size() > 0) this.auxLossCoeff = tArguments.get(0).floatValue();
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "moe_gate";
+ }
+
+ @Override
+ public List
+ * output[i] = input[i] @ baseWeight + alpha * input[i] @ loraA[adapterIds[i]] @ loraB[adapterIds[i]]
+ *
+ *
+ *
+ *
+ *
+ *
+ * partial = inputShard @ weightShard
+ * output = allReduce(partial) + bias if reduceOutput else partial
+ *
+ *
+ *
+ *
+ *
+ *
+ * h_t = A_t * h_{t-1} + B_t * x_t
+ * y_t = C_t * h_t + D * x_t
+ *
+ *
+ *
+ *
+ * Y = deq( quant(X * diag(s)^{-1}) @ quant(diag(s) * W) )
+ *
+ * where {@code s} is a per-channel smoothing factor that balances the dynamic
+ * ranges of activations and weights.
+ *
+ *
+ *
+ *
+ *
+ * @author Eclipse Deeplearning4j Contributors
+ */
+@NoArgsConstructor
+public class SmoothQuant extends DynamicCustomOp {
+
+ @Getter private boolean transposeWeight = false; // whether wQuantized is stored transposed
+
+ /**
+ * SameDiff constructor with required inputs (no bias).
+ *
+ * @param sameDiff the SameDiff instance
+ * @param x input activations [batch, in_features]
+ * @param wQuantized pre-quantized smoothed weights [out_features, in_features]
+ * @param smoothScale per-channel smoothing factors [in_features]
+ * @param actScale activation quantization scale
+ * @param weightScale weight quantization scale [out_features]
+ */
+ public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized,
+ SDVariable smoothScale, SDVariable actScale, SDVariable weightScale) {
+ super(null, sameDiff, new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale}, false);
+ addIArgument(transposeWeight ? 1L : 0L);
+ }
+
+ /**
+ * SameDiff constructor with bias (appended as a sixth input).
+ */
+ public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized,
+ SDVariable smoothScale, SDVariable actScale, SDVariable weightScale,
+ SDVariable bias) {
+ super(null, sameDiff,
+ new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale, bias}, false);
+ addIArgument(transposeWeight ? 1L : 0L);
+ }
+
+ /**
+ * SameDiff constructor with all options; bias may be null (omitted from inputs).
+ */
+ public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized,
+ SDVariable smoothScale, SDVariable actScale, SDVariable weightScale,
+ SDVariable bias, boolean transposeWeight) {
+ super(null, sameDiff,
+ bias != null ?
+ new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale, bias} :
+ new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale},
+ false);
+ this.transposeWeight = transposeWeight;
+ addIArgument(transposeWeight ? 1L : 0L);
+ }
+
+ /**
+ * INDArray constructor; bias and output may each be null.
+ */
+ public SmoothQuant(INDArray x, INDArray wQuantized, INDArray smoothScale,
+ INDArray actScale, INDArray weightScale, INDArray bias,
+ INDArray output, boolean transposeWeight) {
+ super(null,
+ bias != null ?
+ new INDArray[]{x, wQuantized, smoothScale, actScale, weightScale, bias} :
+ new INDArray[]{x, wQuantized, smoothScale, actScale, weightScale},
+ output != null ? new INDArray[]{output} : null);
+ this.transposeWeight = transposeWeight;
+ addIArgument(transposeWeight ? 1L : 0L);
+ }
+
+ // Restores the transposeWeight flag from deserialized IArgs (boxed Long unboxed for comparison).
+ @Override
+ public void configureFromArguments() {
+ super.configureFromArguments();
+ if (iArguments.size() > 0) this.transposeWeight = iArguments.get(0) != 0;
+ }
+
+ // Name of the corresponding libnd4j custom op.
+ @Override
+ public String opName() {
+ return "smooth_quant";
+ }
+
+ @Override
+ public List
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean AND result (BOOL type)
+ */
+ public INDArray booleanAnd(INDArray x, INDArray y) {
+ NDValidation.validateBool("booleanAnd", "x", x);
+ NDValidation.validateBool("booleanAnd", "y", y);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(x, y));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Boolean NOT operation: elementwise !x
+ *
+ * @param x Input boolean array (BOOL type)
+ * @return output Boolean NOT result (BOOL type)
+ */
+ public INDArray booleanNot(INDArray x) {
+ NDValidation.validateBool("booleanNot", "x", x);
+ // Single-output legacy op: exec returns the result directly, so no extra outputs to close.
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(x));
+ }
+
+ /**
+ * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean OR result (BOOL type)
+ */
+ public INDArray booleanOr(INDArray x, INDArray y) {
+ NDValidation.validateBool("booleanOr", "x", x);
+ NDValidation.validateBool("booleanOr", "y", y);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(x, y));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ *
+ * @param x First input boolean array (BOOL type)
+ * @param y Second input boolean array (BOOL type)
+ * @return output Boolean XOR result (BOOL type)
+ */
+ public INDArray booleanXor(INDArray x, INDArray y) {
+ NDValidation.validateBool("booleanXor", "x", x);
+ NDValidation.validateBool("booleanXor", "y", y);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(x, y));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
/**
* Cast the array to a new datatype - for example, Integer -> Float
*
@@ -228,7 +366,18 @@ public INDArray[] batchMmul(INDArray alphas, INDArray betas, INDArray[] inputsA,
* @return output Output array (after casting) (NDARRAY type)
*/
public INDArray castTo(INDArray arg, DataType datatype) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -241,7 +390,18 @@ public INDArray castTo(INDArray arg, DataType datatype) {
*/
public INDArray clipByNorm(INDArray x, double clipValue) {
NDValidation.validateNumerical("clipByNorm", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -257,7 +417,18 @@ public INDArray clipByNorm(INDArray x, INDArray clipValue, INDArray dimensions)
NDValidation.validateNumerical("clipByNorm", "x", x);
NDValidation.validateNumerical("clipByNorm", "clipValue", clipValue);
NDValidation.validateNumerical("clipByNorm", "dimensions", dimensions);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -270,7 +441,18 @@ public INDArray clipByNorm(INDArray x, INDArray clipValue, INDArray dimensions)
*/
public INDArray clipByValue(INDArray x, double clipValueMin, double clipValueMax) {
NDValidation.validateNumerical("clipByValue", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax));
+ // Return the first output; close any additional outputs so their buffers are not leaked.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -285,7 +467,18 @@ public INDArray clipByValue(INDArray x, INDArray clipValueMin, INDArray clipValu
NDValidation.validateNumerical("clipByValue", "x", x);
NDValidation.validateNumerical("clipByValue", "clipValueMin", clipValueMin);
NDValidation.validateNumerical("clipByValue", "clipValueMax", clipValueMax);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -304,7 +497,18 @@ public INDArray concat(int dimension, INDArray... inputs) {
NDValidation.validateNumerical("concat", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype");
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -318,7 +522,18 @@ public INDArray concat(int dimension, INDArray... inputs) {
*/
public INDArray create(INDArray shape, DataType dataType, String order, boolean initialize) {
NDValidation.validateNumerical("create", "shape", shape);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, order, initialize))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, order, initialize));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -330,7 +545,18 @@ public INDArray create(INDArray shape, DataType dataType, String order, boolean
*/
public INDArray create(INDArray shape, DataType dataType) {
NDValidation.validateNumerical("create", "shape", shape);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, "c", false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, "c", false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -342,7 +568,18 @@ public INDArray create(INDArray shape, DataType dataType) {
*/
public INDArray createView(INDArray input, INDArray... indices) {
Preconditions.checkArgument(indices.length >= 0, "indices has incorrect size/length. Expected: indices.length >= 0, got %s", indices.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.CreateView(input, indices))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.CreateView(input, indices));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -362,7 +599,18 @@ public INDArray createView(INDArray input, INDArray... indices) {
public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, long... axis) {
NDValidation.validateNumerical("cumprod", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -380,7 +628,18 @@ public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, long...
public INDArray cumprod(INDArray in, long... axis) {
NDValidation.validateNumerical("cumprod", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -400,7 +659,18 @@ public INDArray cumprod(INDArray in, long... axis) {
public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, long... axis) {
NDValidation.validateNumerical("cumsum", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -418,7 +688,18 @@ public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, long...
public INDArray cumsum(INDArray in, long... axis) {
NDValidation.validateNumerical("cumsum", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -470,7 +751,18 @@ public INDArray dynamicStitch(INDArray[] indices, INDArray... x) {
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
NDValidation.validateNumerical("dynamicStitch", "x", x);
Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -502,7 +794,18 @@ public INDArray eq(INDArray x, double y) {
* @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type)
*/
public INDArray eq(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -517,7 +820,18 @@ public INDArray eq(INDArray x, INDArray y) {
* @return output Output variable (NUMERIC type)
*/
public INDArray expandDims(INDArray x, int axis) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -530,7 +844,18 @@ public INDArray expandDims(INDArray x, int axis) {
*/
public INDArray fill(INDArray shape, DataType dataType, double value) {
NDValidation.validateInteger("fill", "shape", shape);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -542,7 +867,18 @@ public INDArray fill(INDArray shape, DataType dataType, double value) {
*/
public INDArray flatten(INDArray[] inputs, String order) {
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, order))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, order));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -553,7 +889,18 @@ public INDArray flatten(INDArray[] inputs, String order) {
*/
public INDArray flatten(INDArray... inputs) {
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, "c"))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, "c"));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -567,7 +914,18 @@ public INDArray flatten(INDArray... inputs) {
*/
public INDArray gather(INDArray df, int[] indices, int axis) {
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -581,7 +939,18 @@ public INDArray gather(INDArray df, int[] indices, int axis) {
*/
public INDArray gather(INDArray df, INDArray indices, int axis) {
NDValidation.validateInteger("gather", "indices", indices);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -593,7 +962,18 @@ public INDArray gather(INDArray df, INDArray indices, int axis) {
*/
public INDArray gatherNd(INDArray df, INDArray indices) {
NDValidation.validateNumerical("gatherNd", "indices", indices);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -624,7 +1004,18 @@ public INDArray gt(INDArray x, double y) {
* @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type)
*/
public INDArray gt(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -655,7 +1046,18 @@ public INDArray gte(INDArray x, double y) {
* @return output (NDARRAY type)
*/
public INDArray gte(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -665,7 +1067,18 @@ public INDArray gte(INDArray x, INDArray y) {
* @return output Output variable (NDARRAY type)
*/
public INDArray identity(INDArray input) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -678,7 +1091,18 @@ public INDArray identity(INDArray input) {
*/
public INDArray invertPermutation(INDArray input) {
NDValidation.validateInteger("invertPermutation", "input", input);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -689,7 +1113,18 @@ public INDArray invertPermutation(INDArray input) {
*/
public INDArray isNumericTensor(INDArray x) {
NDValidation.validateNumerical("isNumericTensor", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -703,7 +1138,18 @@ public INDArray isNumericTensor(INDArray x) {
* @return output INDArray with linearly spaced elements (NUMERIC type)
*/
public INDArray linspace(DataType dataType, double start, double stop, long number) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -720,7 +1166,18 @@ public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataTyp
NDValidation.validateNumerical("linspace", "start", start);
NDValidation.validateNumerical("linspace", "stop", stop);
NDValidation.validateInteger("linspace", "number", number);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -751,7 +1208,18 @@ public INDArray lt(INDArray x, double y) {
* @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NDARRAY type)
*/
public INDArray lt(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -784,7 +1252,18 @@ public INDArray lte(INDArray x, double y) {
public INDArray lte(INDArray x, INDArray y) {
NDValidation.validateNumerical("lte", "x", x);
NDValidation.validateNumerical("lte", "y", y);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -910,7 +1389,18 @@ public INDArray max(INDArray x, long... dimensions) {
public INDArray max(INDArray first, INDArray second) {
NDValidation.validateNumerical("max", "first", first);
NDValidation.validateNumerical("max", "second", second);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1005,7 +1495,18 @@ public INDArray mean(INDArray x, INDArray dimensions) {
* @return output Output (NDARRAY type)
*/
public INDArray merge(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1063,7 +1564,18 @@ public INDArray min(INDArray x, long... dimensions) {
public INDArray min(INDArray first, INDArray second) {
NDValidation.validateNumerical("min", "first", first);
NDValidation.validateNumerical("min", "second", second);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1074,7 +1586,18 @@ public INDArray min(INDArray first, INDArray second) {
* @return output Output array (after casting) (NDARRAY type)
*/
public INDArray minMax(int datatype, int minOrMax) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.MinMaxDataType(datatype, minOrMax))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.MinMaxDataType(datatype, minOrMax));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1092,7 +1615,18 @@ public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transpo
boolean transposeZ) {
NDValidation.validateNumerical("mmul", "x", x);
NDValidation.validateNumerical("mmul", "y", y);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1106,7 +1640,18 @@ public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transpo
public INDArray mmul(INDArray x, INDArray y) {
NDValidation.validateNumerical("mmul", "x", x);
NDValidation.validateNumerical("mmul", "y", y);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1137,7 +1682,18 @@ public INDArray neq(INDArray x, double y) {
* @return output Boolean array out, with values true/false based on where the condition is satisfied (NDARRAY type)
*/
public INDArray neq(INDArray x, INDArray y) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1287,7 +1843,18 @@ public INDArray normmax(INDArray x, long... dimensions) {
public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off,
DataType dataType) {
NDValidation.validateNumerical("oneHot", "indices", indices);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1304,7 +1871,18 @@ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double
*/
public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) {
NDValidation.validateNumerical("oneHot", "indices", indices);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1319,7 +1897,18 @@ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double
*/
public INDArray oneHot(INDArray indices, int depth) {
NDValidation.validateNumerical("oneHot", "indices", indices);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1330,7 +1919,18 @@ public INDArray oneHot(INDArray indices, int depth) {
* @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type)
*/
public INDArray onesLike(INDArray input) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1341,7 +1941,18 @@ public INDArray onesLike(INDArray input) {
* @return output (NUMERIC type)
*/
public INDArray onesLike(INDArray input, DataType dataType) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1354,7 +1965,18 @@ public INDArray onesLike(INDArray input, DataType dataType) {
*/
public INDArray permute(INDArray x, INDArray dimensions) {
NDValidation.validateInteger("permute", "dimensions", dimensions);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1367,7 +1989,18 @@ public INDArray permute(INDArray x, INDArray dimensions) {
*/
public INDArray permute(INDArray x, long... dimensions) {
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1464,7 +2097,18 @@ public INDArray prod(INDArray x, INDArray dimensions) {
* @return output INDArray with the specified values (NUMERIC type)
*/
public INDArray range(double from, double to, double step, DataType dataType) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1482,7 +2126,18 @@ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataTy
NDValidation.validateNumerical("range", "from", from);
NDValidation.validateNumerical("range", "to", to);
NDValidation.validateNumerical("range", "step", step);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1492,7 +2147,18 @@ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataTy
* @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type)
*/
public INDArray rank(INDArray in) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Rank(in))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Rank(in));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1506,7 +2172,18 @@ public INDArray rank(INDArray in) {
public INDArray repeat(INDArray input, INDArray repeats, int axis) {
NDValidation.validateNumerical("repeat", "input", input);
NDValidation.validateNumerical("repeat", "repeats", repeats);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Repeat(input, repeats, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Repeat(input, repeats, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1551,7 +2228,18 @@ public INDArray replaceWhere(INDArray update, double value, Condition condition)
*/
public INDArray reshape(INDArray x, INDArray shape) {
NDValidation.validateNumerical("reshape", "shape", shape);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1565,7 +2253,18 @@ public INDArray reshape(INDArray x, INDArray shape) {
*/
public INDArray reshape(INDArray x, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1587,7 +2286,18 @@ public INDArray reshape(INDArray x, long... shape) {
*/
public INDArray reverse(INDArray x, long... dimensions) {
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1601,7 +2311,18 @@ public INDArray reverse(INDArray x, long... dimensions) {
*/
public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) {
NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1613,7 +2334,18 @@ public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, in
*/
public INDArray reverseSequence(INDArray x, INDArray seq_lengths) {
NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1682,7 +2414,18 @@ public INDArray scatterAdd(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterAdd", "ref", ref);
NDValidation.validateNumerical("scatterAdd", "indices", indices);
NDValidation.validateNumerical("scatterAdd", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1702,7 +2445,18 @@ public INDArray scatterDiv(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterDiv", "ref", ref);
NDValidation.validateNumerical("scatterDiv", "indices", indices);
NDValidation.validateNumerical("scatterDiv", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1724,7 +2478,18 @@ public INDArray scatterMax(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterMax", "ref", ref);
NDValidation.validateNumerical("scatterMax", "indices", indices);
NDValidation.validateNumerical("scatterMax", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1746,7 +2511,18 @@ public INDArray scatterMin(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterMin", "ref", ref);
NDValidation.validateNumerical("scatterMin", "indices", indices);
NDValidation.validateNumerical("scatterMin", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1768,7 +2544,18 @@ public INDArray scatterMul(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterMul", "ref", ref);
NDValidation.validateNumerical("scatterMul", "indices", indices);
NDValidation.validateNumerical("scatterMul", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1791,7 +2578,18 @@ public INDArray scatterNdAdd(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterNdAdd", "ref", ref);
NDValidation.validateNumerical("scatterNdAdd", "indices", indices);
NDValidation.validateNumerical("scatterNdAdd", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdAdd(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdAdd(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1814,7 +2612,18 @@ public INDArray scatterNdSub(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterNdSub", "ref", ref);
NDValidation.validateNumerical("scatterNdSub", "indices", indices);
NDValidation.validateNumerical("scatterNdSub", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdSub(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdSub(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1837,7 +2646,18 @@ public INDArray scatterNdUpdate(INDArray ref, INDArray indices, INDArray updates
NDValidation.validateNumerical("scatterNdUpdate", "ref", ref);
NDValidation.validateNumerical("scatterNdUpdate", "indices", indices);
NDValidation.validateNumerical("scatterNdUpdate", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdUpdate(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdUpdate(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1859,7 +2679,18 @@ public INDArray scatterSub(INDArray ref, INDArray indices, INDArray updates) {
NDValidation.validateNumerical("scatterSub", "ref", ref);
NDValidation.validateNumerical("scatterSub", "indices", indices);
NDValidation.validateNumerical("scatterSub", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1881,7 +2712,18 @@ public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates)
NDValidation.validateNumerical("scatterUpdate", "ref", ref);
NDValidation.validateNumerical("scatterUpdate", "indices", indices);
NDValidation.validateNumerical("scatterUpdate", "updates", updates);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1900,7 +2742,18 @@ public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates)
*/
public INDArray segmentMax(INDArray data, INDArray segmentIds) {
NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1919,7 +2772,18 @@ public INDArray segmentMax(INDArray data, INDArray segmentIds) {
*/
public INDArray segmentMean(INDArray data, INDArray segmentIds) {
NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1938,7 +2802,18 @@ public INDArray segmentMean(INDArray data, INDArray segmentIds) {
*/
public INDArray segmentMin(INDArray data, INDArray segmentIds) {
NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1957,7 +2832,18 @@ public INDArray segmentMin(INDArray data, INDArray segmentIds) {
*/
public INDArray segmentProd(INDArray data, INDArray segmentIds) {
NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1976,7 +2862,18 @@ public INDArray segmentProd(INDArray data, INDArray segmentIds) {
*/
public INDArray segmentSum(INDArray data, INDArray segmentIds) {
NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -1990,7 +2887,18 @@ public INDArray segmentSum(INDArray data, INDArray segmentIds) {
*/
public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) {
NDValidation.validateNumerical("sequenceMask", "lengths", lengths);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2005,7 +2913,18 @@ public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) {
public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) {
NDValidation.validateNumerical("sequenceMask", "lengths", lengths);
NDValidation.validateInteger("sequenceMask", "maxLen", maxLen);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2017,7 +2936,18 @@ public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataTyp
*/
public INDArray sequenceMask(INDArray lengths, DataType dataType) {
NDValidation.validateNumerical("sequenceMask", "lengths", lengths);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2038,7 +2968,18 @@ public INDArray[] setShape(INDArray input, INDArray shape) {
* @return output 1D output variable with contents equal to the shape of the input (NUMERIC type)
*/
public INDArray shape(INDArray input) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2048,7 +2989,18 @@ public INDArray shape(INDArray input) {
* @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type)
*/
public INDArray size(INDArray in) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2060,7 +3012,18 @@ public INDArray size(INDArray in) {
* @return output Scalar INDArray for size at specified variable (NUMERIC type)
*/
public INDArray sizeAt(INDArray in, int dimension) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2081,7 +3044,18 @@ public INDArray sizeAt(INDArray in, int dimension) {
public INDArray slice(INDArray input, int[] begin, int... size) {
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2102,7 +3076,18 @@ public INDArray slice(INDArray input, int[] begin, int... size) {
public INDArray slice(INDArray input, INDArray begin, INDArray size) {
NDValidation.validateInteger("slice", "begin", begin);
NDValidation.validateInteger("slice", "size", size);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2117,7 +3102,18 @@ public INDArray sparseToDense(INDArray indices, INDArray shape, INDArray values)
NDValidation.validateNumerical("sparseToDense", "indices", indices);
NDValidation.validateNumerical("sparseToDense", "shape", shape);
NDValidation.validateNumerical("sparseToDense", "values", values);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2135,7 +3131,18 @@ public INDArray sparseToDense(INDArray indices, INDArray shape, INDArray values,
NDValidation.validateNumerical("sparseToDense", "shape", shape);
NDValidation.validateNumerical("sparseToDense", "values", values);
NDValidation.validateNumerical("sparseToDense", "defaultValue", defaultValue);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values, defaultValue))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values, defaultValue));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2226,7 +3233,42 @@ public INDArray squaredNorm(INDArray x, long... dimensions) {
*/
public INDArray squeeze(INDArray x, int axis) {
NDValidation.validateNumerical("squeeze", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c].
+ * This is the NumPy-style squeeze with no axis specified.
+ *
+ * @param x Input variable (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public INDArray squeezeAll(INDArray x) {
+ NDValidation.validateNumerical("squeezeAll", "x", x);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2244,7 +3286,18 @@ public INDArray squeeze(INDArray x, int axis) {
*/
public INDArray stack(int axis, INDArray... values) {
Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2317,7 +3370,18 @@ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strid
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2340,7 +3404,18 @@ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... stri
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2369,7 +3444,18 @@ public INDArray stridedSlice(INDArray in, INDArray begin, INDArray end, INDArray
NDValidation.validateNumerical("stridedSlice", "begin", begin);
NDValidation.validateNumerical("stridedSlice", "end", end);
NDValidation.validateNumerical("stridedSlice", "strides", strides);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2392,7 +3478,18 @@ public INDArray stridedSlice(INDArray in, INDArray begin, INDArray end, INDArray
NDValidation.validateNumerical("stridedSlice", "begin", begin);
NDValidation.validateNumerical("stridedSlice", "end", end);
NDValidation.validateNumerical("stridedSlice", "strides", strides);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2466,7 +3563,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dime
NDValidation.validateNumerical("tensorMmul", "y", y);
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2483,7 +3591,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dim
NDValidation.validateNumerical("tensorMmul", "y", y);
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2504,7 +3623,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dim
*/
public INDArray tile(INDArray x, INDArray repeat) {
NDValidation.validateInteger("tile", "repeat", repeat);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2516,7 +3646,18 @@ public INDArray tile(INDArray x, INDArray repeat) {
*/
public INDArray tile(INDArray x, int... repeat) {
Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2526,7 +3667,18 @@ public INDArray tile(INDArray x, int... repeat) {
* @return output transposed input (NDARRAY type)
*/
public INDArray transpose(INDArray x) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2544,7 +3696,18 @@ public INDArray transpose(INDArray x) {
public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentMax", "data", data);
NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2563,7 +3726,18 @@ public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, INDArray
NDValidation.validateNumerical("unsortedSegmentMax", "data", data);
NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentMax", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2581,7 +3755,18 @@ public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, INDArray
public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentMean", "data", data);
NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2600,7 +3785,18 @@ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, INDArray
NDValidation.validateNumerical("unsortedSegmentMean", "data", data);
NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentMean", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2618,7 +3814,18 @@ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, INDArray
public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentMin", "data", data);
NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2637,7 +3844,18 @@ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, INDArray
NDValidation.validateNumerical("unsortedSegmentMin", "data", data);
NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentMin", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2655,7 +3873,18 @@ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, INDArray
public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentProd", "data", data);
NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2674,7 +3903,18 @@ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, INDArray
NDValidation.validateNumerical("unsortedSegmentProd", "data", data);
NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentProd", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2691,7 +3931,18 @@ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, INDArray
public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data);
NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2709,7 +3960,18 @@ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, INDArra
NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data);
NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentSqrtN", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2727,7 +3989,18 @@ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, INDArra
public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) {
NDValidation.validateNumerical("unsortedSegmentSum", "data", data);
NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2746,7 +4019,18 @@ public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, INDArray
NDValidation.validateNumerical("unsortedSegmentSum", "data", data);
NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds);
NDValidation.validateInteger("unsortedSegmentSum", "numSegments", numSegments);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2825,7 +4109,18 @@ public INDArray variance(INDArray x, boolean biasCorrected, long... dimensions)
*/
public INDArray where(INDArray x, INDArray y, INDArray condition) {
NDValidation.validateBool("where", "condition", condition);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, y, condition))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, y, condition));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2845,7 +4140,18 @@ public INDArray where(INDArray x, INDArray y, INDArray condition) {
public INDArray where(INDArray x, INDArray condition) {
NDValidation.validateNumerical("where", "x", x);
NDValidation.validateBool("where", "condition", condition);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, condition))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, condition));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2863,7 +4169,18 @@ public INDArray where(INDArray x, INDArray condition) {
*/
public INDArray where(INDArray condition) {
NDValidation.validateBool("where", "condition", condition);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(condition))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(condition));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2883,7 +4200,18 @@ public INDArray where(INDArray condition) {
*/
public INDArray whereNumpy(INDArray x, INDArray y, INDArray condition) {
NDValidation.validateNumerical("whereNumpy", "condition", condition);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy(x, y, condition))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy(x, y, condition));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -2894,6 +4222,17 @@ public INDArray whereNumpy(INDArray x, INDArray y, INDArray condition) {
* @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type)
*/
public INDArray zerosLike(INDArray input) {
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java
index 1a211b65d47d..a0165ed8bda5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java
@@ -42,7 +42,18 @@ public NDNN() {
*/
public INDArray cReLU(INDArray x) {
NDValidation.validateNumerical("CReLU", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -68,7 +79,18 @@ public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDA
NDValidation.validateNumerical("batchNorm", "gamma", gamma);
NDValidation.validateNumerical("batchNorm", "beta", beta);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -83,7 +105,134 @@ public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDA
public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) {
NDValidation.validateNumerical("biasAdd", "input", input);
NDValidation.validateNumerical("biasAdd", "bias", bias);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @param temperature Sharpening temperature (typically 0.04-0.07)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public INDArray centerAndSharpen(INDArray input, INDArray center, double temperature) {
+ NDValidation.validateNumerical("centerAndSharpen", "input", input);
+ NDValidation.validateNumerical("centerAndSharpen", "center", center);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(input, center, temperature));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ *
+ * @param input Teacher output logits [batch, features] (NUMERIC type)
+ * @param center Running center vector [features] (NUMERIC type)
+ * @return output Sharpened probabilities [batch, features] (NUMERIC type)
+ */
+ public INDArray centerAndSharpen(INDArray input, INDArray center) {
+ NDValidation.validateNumerical("centerAndSharpen", "input", input);
+ NDValidation.validateNumerical("centerAndSharpen", "center", center);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(input, center, 0.07));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public INDArray[] ctcGreedyDecoder(INDArray logits, boolean mergeRepeated, int blankIndex) {
+ NDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(logits, null, mergeRepeated, blankIndex));
+ }
+
+ /**
+ * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ *
+ * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type)
+ * @param sequenceLength Optional actual sequence lengths. Shape: [batch] (NUMERIC type)
+ * @param mergeRepeated Whether to merge repeated characters in output
+ * @param blankIndex Index of the blank label in the vocabulary
+ */
+ public INDArray[] ctcGreedyDecoder(INDArray logits, INDArray sequenceLength,
+ boolean mergeRepeated, int blankIndex) {
+ NDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits);
+ NDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(logits, sequenceLength, mergeRepeated, blankIndex));
}
/**
@@ -123,39 +272,52 @@ public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray va
NDValidation.validateNumerical("dotProductAttention", "keys", keys);
NDValidation.validateNumerical("dotProductAttention", "values", values);
NDValidation.validateNumerical("dotProductAttention", "mask", mask);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
- * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ * out = softmax(Q * K^T / scale + attentionBias) * V
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
- *
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
*
- * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type)
- * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type)
- * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type)
- * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type)
- * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type)
- * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys.
- * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights.
- * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks.
- * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout.
- * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type)
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
*/
public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArray keys,
INDArray queryMask, INDArray valueMask, double scaleFactor, double dropoutProbability,
@@ -165,7 +327,75 @@ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArra
NDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
NDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
NDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, training));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Dot product attention operation with flash attention and KV cache support.
+ *
+ * out = softmax(Q * K^T / scale + attentionBias) * V
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type)
+ * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type)
+ * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type)
+ * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type)
+ * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type)
+ * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))
+ * @param dropoutProbability Dropout probability applied to attention weights
+ * @param useCausalMask Whether to apply causal mask for autoregressive tasks
+ * @param training Whether in training mode (affects dropout)
+ * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type)
+ */
+ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArray keys,
+ INDArray queryMask, INDArray valueMask, INDArray attentionBias, double scaleFactor,
+ double dropoutProbability, boolean useCausalMask, boolean training) {
+ NDValidation.validateNumerical("dotProductAttentionV2", "queries", queries);
+ NDValidation.validateNumerical("dotProductAttentionV2", "values", values);
+ NDValidation.validateNumerical("dotProductAttentionV2", "keys", keys);
+ NDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask);
+ NDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask);
+ NDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -179,7 +409,18 @@ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArra
*/
public INDArray dropout(INDArray input, boolean inverted, int seed, double probabilityValue) {
NDValidation.validateNumerical("dropout", "input", input);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, seed, probabilityValue))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, seed, probabilityValue));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -192,7 +433,18 @@ public INDArray dropout(INDArray input, boolean inverted, int seed, double proba
*/
public INDArray dropout(INDArray input, boolean inverted, double probabilityValue) {
NDValidation.validateNumerical("dropout", "input", input);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, 0, probabilityValue))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, 0, probabilityValue));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -208,7 +460,140 @@ public INDArray dropout(INDArray input, boolean inverted, double probabilityValu
*/
public INDArray elu(INDArray x) {
NDValidation.validateNumerical("elu", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ *
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @param decay EMA decay factor (typically 0.996-0.9999)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public INDArray emaUpdate(INDArray model, INDArray shadow, double decay) {
+ NDValidation.validateNumerical("emaUpdate", "model", model);
+ NDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(model, shadow, decay));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ *
+ * @param model Current model parameters (student) (NUMERIC type)
+ * @param shadow EMA shadow parameters (teacher) (NUMERIC type)
+ * @return output Updated shadow parameters (NUMERIC type)
+ */
+ public INDArray emaUpdate(INDArray model, INDArray shadow) {
+ NDValidation.validateNumerical("emaUpdate", "model", model);
+ NDValidation.validateNumerical("emaUpdate", "shadow", shadow);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(model, shadow, 0.999));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T / scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public INDArray flashAttention(INDArray query, INDArray key, INDArray value, double scale,
+ boolean isCausal, int numHeads, int numKvHeads) {
+ NDValidation.validateNumerical("flashAttention", "query", query);
+ NDValidation.validateNumerical("flashAttention", "key", key);
+ NDValidation.validateNumerical("flashAttention", "value", value);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(query, key, value, scale, isCausal, numHeads, numKvHeads));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ *
+ * @param input Primary input array (NUMERIC type)
+ * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type)
+ * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1))
+ * @return output Result of applying the fused element-wise chain (NUMERIC type)
+ */
+ public INDArray fusedElementwiseChain(INDArray input, INDArray[] secondaryInputs, int[] opCodes) {
+ NDValidation.validateNumerical("fusedElementwiseChain", "input", input);
+ NDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs);
+ Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length);
+ Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(input, secondaryInputs, opCodes));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -224,6 +609,49 @@ public INDArray gelu(INDArray x) {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x));
}
+ /**
+ * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @param isCausal Whether to apply causal masking
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (must divide numHeads evenly)
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public INDArray groupedQueryAttention(INDArray query, INDArray key, INDArray value, double scale,
+ boolean isCausal, int numHeads, int numKvHeads) {
+ NDValidation.validateNumerical("groupedQueryAttention", "query", query);
+ NDValidation.validateNumerical("groupedQueryAttention", "key", key);
+ NDValidation.validateNumerical("groupedQueryAttention", "value", value);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(query, key, value, scale, isCausal, numHeads, numKvHeads));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
/**
* Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
@@ -263,6 +691,108 @@ public INDArray hardTanhDerivative(INDArray x) {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x));
}
+ /**
+ * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ *
+ * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param startPosition Position in cache where new keys/values should be inserted
+ * @return Updated caches, presumably ordered [updatedKeyCache, updatedValueCache] — confirm order against the KVCacheUpdate op outputs
+ */
+ public INDArray[] kvCacheUpdate(INDArray keyCache, INDArray valueCache, INDArray newKeys,
+ INDArray newValues, int startPosition) {
+ NDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache);
+ NDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache);
+ NDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys);
+ NDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues);
+ // All op outputs are returned to the caller, so none are closed here.
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(keyCache, valueCache, newKeys, newValues, startPosition));
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Convenience overload for a single present/static KV pair (numPairs = 1).
+ * Copies the new token's KV entry (the last sequence position of the present
+ * tensor) into the static KV buffer at the given cache position, using a
+ * single native kernel launch instead of Java view+assign calls.
+ *
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public INDArray kvScatter(INDArray present, INDArray staticBuffer, long cachePos) {
+ // Delegate to the general overload with numPairs = 1; the validation calls,
+ // op construction, and output-closing behavior there are identical to what
+ // was previously duplicated here.
+ return kvScatter(present, staticBuffer, cachePos, 1);
+ }
+
+ /**
+ * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * NOTE(review): this wrapper only accepts a single present/static pair even
+ * when numPairs > 1 — confirm how multi-pair inputs are meant to be supplied
+ * through this API.
+ *
+ * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type)
+ * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type)
+ * @param cachePos Position in static buffer to write the new entry
+ * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1]
+ * @return output Scalar 0 on success (LONG type)
+ */
+ public INDArray kvScatter(INDArray present, INDArray staticBuffer, long cachePos, int numPairs) {
+ NDValidation.validateNumerical("kvScatter", "present", present);
+ NDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(present, staticBuffer, cachePos, numPairs));
+ // Return the status scalar; close any additional op outputs.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
/**
* Apply Layer Normalization
*
@@ -281,7 +811,18 @@ public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean
NDValidation.validateNumerical("layerNorm", "gain", gain);
NDValidation.validateNumerical("layerNorm", "bias", bias);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -300,7 +841,18 @@ public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst,
NDValidation.validateNumerical("layerNorm", "input", input);
NDValidation.validateNumerical("layerNorm", "gain", gain);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -347,7 +899,18 @@ public INDArray linear(INDArray input, INDArray weights, INDArray bias, boolean
NDValidation.validateNumerical("linear", "input", input);
NDValidation.validateNumerical("linear", "weights", weights);
NDValidation.validateNumerical("linear", "bias", bias);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, transposeA, transposeB, transposeC))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, transposeA, transposeB, transposeC));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -363,7 +926,18 @@ public INDArray linear(INDArray input, INDArray weights, INDArray bias) {
NDValidation.validateNumerical("linear", "input", input);
NDValidation.validateNumerical("linear", "weights", weights);
NDValidation.validateNumerical("linear", "bias", bias);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, false, false, false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, false, false, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -385,7 +959,18 @@ public INDArray logSigmoid(INDArray x) {
*/
public INDArray logSoftmax(INDArray x) {
NDValidation.validateNumerical("logSoftmax", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -397,7 +982,99 @@ public INDArray logSoftmax(INDArray x) {
*/
public INDArray logSoftmax(INDArray x, int dimension) {
NDValidation.validateNumerical("logSoftmax", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * This overload uses no expert bias, normalizes the selected experts' router
+ * probabilities, and applies a capacity factor of 1.0.
+ *
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @return All outputs of the MixtureOfExperts op — output order/meaning not documented here; confirm against the op implementation
+ */
+ public INDArray[] mixtureOfExperts(INDArray input, INDArray routerWeights, INDArray expertWeights,
+ int numExperts, int topK) {
+ NDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ NDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ NDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ // Defaults: expertBias = null, normalizeProbs = true, capacityFactor = 1.0.
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0));
+ }
+
+ /**
+ * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ *
+ * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type)
+ * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type)
+ * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type)
+ * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type)
+ * @param numExperts Total number of experts
+ * @param topK Number of experts to route to per token
+ * @param normalizeProbs Whether to normalize router probabilities for selected experts
+ * @param capacityFactor Expert capacity factor for load balancing
+ * @return All outputs of the MixtureOfExperts op — output order/meaning not documented here; confirm against the op implementation
+ */
+ public INDArray[] mixtureOfExperts(INDArray input, INDArray routerWeights, INDArray expertWeights,
+ INDArray expertBias, int numExperts, int topK, boolean normalizeProbs,
+ double capacityFactor) {
+ NDValidation.validateNumerical("mixtureOfExperts", "input", input);
+ NDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights);
+ NDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights);
+ // NOTE(review): expertBias is documented as optional but is validated here —
+ // confirm that NDValidation.validateNumerical accepts a null bias.
+ NDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor));
+ }
/**
@@ -434,11 +1111,22 @@ public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, IN
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv);
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo);
NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
- * Padding operation
+ * Padding operation
*
* @param input Input tensor (NUMERIC type)
* @param padding Padding value (NUMERIC type)
@@ -449,11 +1137,22 @@ public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, IN
public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double constant) {
NDValidation.validateNumerical("pad", "input", input);
NDValidation.validateNumerical("pad", "padding", padding);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
- * Padding operation
+ * Padding operation
*
* @param input Input tensor (NUMERIC type)
* @param padding Padding value (NUMERIC type)
@@ -463,7 +1162,18 @@ public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double co
public INDArray pad(INDArray input, INDArray padding, double constant) {
NDValidation.validateNumerical("pad", "input", input);
NDValidation.validateNumerical("pad", "padding", padding);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -498,7 +1208,105 @@ public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) {
NDValidation.validateNumerical("prelu", "input", input);
NDValidation.validateNumerical("prelu", "alpha", alpha);
Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * NOTE(review): this overload passes no precomputed index and hardcodes the
+ * ALiBi flag to false, so only the learned-bias path is reachable from here.
+ *
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public INDArray relativePositionBias(INDArray biasTable, int numHeads, int windowSize) {
+ NDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ // No precomputed index (null) and ALiBi mode disabled (false).
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(biasTable, null, numHeads, windowSize, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ *
+ * NOTE(review): this wrapper hardcodes the ALiBi flag to false — the ALiBi
+ * mode described above is not reachable from either Java overload; confirm
+ * whether an ALiBi-enabled overload is intended.
+ *
+ * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type)
+ * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type)
+ * @param numHeads Number of attention heads
+ * @param windowSize Window size for 2D position encoding (used if generating index)
+ * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type)
+ */
+ public INDArray relativePositionBias(INDArray biasTable, INDArray relativePositionIndex,
+ int numHeads, int windowSize) {
+ NDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable);
+ NDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(biasTable, relativePositionIndex, numHeads, windowSize, false));
+ // Return the bias tensor; close any extra op outputs.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
/**
@@ -540,7 +1348,128 @@ public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) {
NDValidation.validateNumerical("reluLayer", "input", input);
NDValidation.validateNumerical("reluLayer", "weights", weights);
NDValidation.validateNumerical("reluLayer", "bias", bias);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * NOTE(review): in this overload gamma is numerically validated — confirm
+ * whether null is rejected; the (input, epsilon) overload omits gamma.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public INDArray rmsNorm(INDArray input, INDArray gamma, double epsilon) {
+ NDValidation.validateNumerical("rmsNorm", "input", input);
+ NDValidation.validateNumerical("rmsNorm", "gamma", gamma);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, gamma, epsilon));
+ // Return the normalized tensor; close any extra op outputs.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * This overload uses the default epsilon of 1e-5.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param gamma Scale/gain vector (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public INDArray rmsNorm(INDArray input, INDArray gamma) {
+ // Delegate to the full overload with the default epsilon (1e-5); the
+ // validation calls and op construction there match what was duplicated here.
+ return rmsNorm(input, gamma, 1.0E-5);
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @param epsilon Epsilon for numerical stability
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public INDArray rmsNorm(INDArray input, double epsilon) {
+ NDValidation.validateNumerical("rmsNorm", "input", input);
+ // gamma is null here: plain RMS normalization with no learned gain.
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, null, epsilon));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * This overload applies plain RMS normalization (no gamma) with the default
+ * epsilon of 1e-5.
+ *
+ * @param input Input variable (NUMERIC type)
+ * @return output RMS normalized output (NUMERIC type)
+ */
+ public INDArray rmsNorm(INDArray input) {
+ // Delegate to the (input, epsilon) overload: same input-only validation,
+ // null gamma, and default epsilon 1e-5 as the previously duplicated body.
+ return rmsNorm(input, 1.0E-5);
+ }
/**
@@ -578,7 +1507,61 @@ public INDArray sigmoid(INDArray x) {
public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
NDValidation.validateNumerical("sigmoidDerivative", "x", x);
NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type)
+ * @param windowSize Sliding window size - tokens can only attend to this many previous positions
+ * @param numHeads Number of query attention heads
+ * @param numKvHeads Number of KV heads (0 = same as numHeads)
+ * @param scale Scaling factor. 0 = auto (1/sqrt(headDim))
+ * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type)
+ */
+ public INDArray slidingWindowAttention(INDArray query, INDArray key, INDArray value,
+ int windowSize, int numHeads, int numKvHeads, double scale) {
+ NDValidation.validateNumerical("slidingWindowAttention", "query", query);
+ NDValidation.validateNumerical("slidingWindowAttention", "key", key);
+ NDValidation.validateNumerical("slidingWindowAttention", "value", value);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(query, key, value, windowSize, numHeads, numKvHeads, scale));
+ // Only the first output is returned; close the rest to release native buffers.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
/**
@@ -590,7 +1573,18 @@ public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
*/
public INDArray softmax(INDArray x, int dimension) {
NDValidation.validateNumerical("softmax", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -601,7 +1595,18 @@ public INDArray softmax(INDArray x, int dimension) {
*/
public INDArray softmax(INDArray x) {
NDValidation.validateNumerical("softmax", "x", x);
- return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0];
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
}
/**
@@ -660,6 +1665,75 @@ public INDArray tanh(INDArray x) {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x));
}
+ /**
+ * Token sampling for LLM inference — greedy decoding convenience overload.
+ *
+ * Equivalent to tokenSample(logits, 0.0, 0, 0.0, 0): temperature 0 (argmax),
+ * no top-K or top-P filtering, default seed. Performs GPU-side argmax with
+ * shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public INDArray tokenSample(INDArray logits) {
+ // Delegate to the full overload with the greedy-decoding defaults that were
+ // previously duplicated here (temperature=0, topK=0, topP=0, seed=0).
+ return tokenSample(logits, 0.0, 0, 0.0, 0);
+ }
+
+ /**
+ * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ *
+ * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type)
+ * @param temperature Temperature for sampling. 0 = greedy (argmax)
+ * @param topK Top-K filtering: keep only top K logits. 0 = disabled
+ * @param topP Top-P (nucleus) filtering threshold. 0 = disabled
+ * @param seed Random seed for sampling. 0 = random
+ * @return output Sampled token indices. Shape: [batch] or scalar (LONG type)
+ */
+ public INDArray tokenSample(INDArray logits, double temperature, int topK, double topP,
+ long seed) {
+ NDValidation.validateNumerical("tokenSample", "logits", logits);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(logits, temperature, topK, topP, seed));
+ // Return the sampled indices; close any extra op outputs.
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
/**
* Find values and indices for the largest k entries along the last dimension.
*
@@ -671,4 +1745,152 @@ public INDArray[] topK(INDArray input, double k, boolean sorted) {
NDValidation.validateNumerical("topK", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TopK(input, k, sorted));
}
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * Returns both attention results — presumably {tokenOutput, imageOutput} in op output order; confirm against the op.
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param scale Attention scale factor. 0 = auto — the scale-less overload passes 0.0; presumably resolved to 1/sqrt(embedDim) by the op
+ */
+ public INDArray[] twoWayCrossAttention(INDArray tokenQuery, INDArray tokenKey,
+ INDArray tokenValue, INDArray imageQuery, INDArray imageKey, INDArray imageValue,
+ double scale) {
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale));
+ }
+
+ /**
+ * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * This overload passes scale = 0.0 to the op — presumably auto scale (1/sqrt(embedDim)); confirm against the op.
+ * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type)
+ * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type)
+ * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type)
+ */
+ public INDArray[] twoWayCrossAttention(INDArray tokenQuery, INDArray tokenKey,
+ INDArray tokenValue, INDArray imageQuery, INDArray imageKey, INDArray imageValue) {
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery);
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey);
+ NDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey);
+ NDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0));
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * This overload passes no relative position bias, no mask, shiftSize = 0, auto scale (0.0), and returnWeights = false.
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public INDArray windowedAttention(INDArray query, INDArray key, INDArray value, int windowSize,
+ int numHeads) {
+ NDValidation.validateNumerical("windowedAttention", "query", query);
+ NDValidation.validateNumerical("windowedAttention", "key", key);
+ NDValidation.validateNumerical("windowedAttention", "value", value);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(query, key, value, null, null, windowSize, numHeads, 0, 0.0, false));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ *
+ * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type)
+ * @param key Key tensor. Same shape as query (NUMERIC type)
+ * @param value Value tensor. Same shape as query (NUMERIC type)
+ * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type)
+ * @param attentionMask Optional attention mask (NUMERIC type)
+ * @param windowSize Size of attention window
+ * @param numHeads Number of attention heads
+ * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift
+ * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim))
+ * @param returnWeights Whether to return attention weights. NOTE(review): this wrapper returns only the first op output and closes all extra outputs, so requested weights are never surfaced — confirm intent
+ * @return output Attention output. Same shape as query (NUMERIC type)
+ */
+ public INDArray windowedAttention(INDArray query, INDArray key, INDArray value,
+ INDArray relativePositionBias, INDArray attentionMask, int windowSize, int numHeads,
+ int shiftSize, double scale, boolean returnWeights) {
+ NDValidation.validateNumerical("windowedAttention", "query", query);
+ NDValidation.validateNumerical("windowedAttention", "key", key);
+ NDValidation.validateNumerical("windowedAttention", "value", value);
+ NDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias);
+ NDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask);
+ INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights));
+ try {
+ return __tmp[0];
+ } finally {
+ if(__tmp != null) {
+ for(int __i = 1; __i < __tmp.length; __i++) {
+ if(__tmp[__i] != null) {
+ __tmp[__i].close();
+ }
+ }
+ }
+ }
+ }
}