From ee3817882c752fbab6823c145067d672fa473b97 Mon Sep 17 00:00:00 2001 From: agibsonccc Date: Thu, 26 Feb 2026 15:44:05 +0900 Subject: [PATCH] Snapshot split from ag_new_release_updates_2 (b5893454f0) for pr/java-op-api-surface --- .../ops/org/nd4j/codegen/ops/NeuralNetwork.kt | 534 +++++- .../ops/org/nd4j/codegen/ops/SDBaseOps.kt | 66 + .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 136 ++ .../org/nd4j/autodiff/samediff/ops/SDNN.java | 1669 ++++++++++++++++- .../ops/impl/transforms/custom/AwqMatmul.java | 134 ++ .../impl/transforms/custom/BooleanAnd.java | 66 + .../ops/impl/transforms/custom/BooleanOr.java | 66 + .../impl/transforms/custom/BooleanXor.java | 66 + .../custom/ColumnParallelLinear.java | 153 ++ .../transforms/custom/DecoderMaskedMha.java | 166 ++ .../ops/impl/transforms/custom/Fp8Matmul.java | 155 ++ .../custom/FusedElementwiseChain.java | 233 +++ .../transforms/custom/FusedGemmSwiglu.java | 85 + .../transforms/custom/FusedNormQuantize.java | 163 ++ .../impl/transforms/custom/GpuTopKSample.java | 167 ++ .../impl/transforms/custom/GpuTopPSample.java | 217 +++ .../ops/impl/transforms/custom/MoeGate.java | 146 ++ .../transforms/custom/MultiLoraMatmul.java | 125 ++ .../transforms/custom/RowParallelLinear.java | 153 ++ .../impl/transforms/custom/SelectiveScan.java | 108 ++ .../impl/transforms/custom/SmoothQuant.java | 150 ++ .../org/nd4j/linalg/factory/ops/NDBase.java | 1563 +++++++++++++-- .../org/nd4j/linalg/factory/ops/NDNN.java | 1316 ++++++++++++- 23 files changed, 7394 insertions(+), 243 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanAnd.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanOr.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanXor.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ColumnParallelLinear.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DecoderMaskedMha.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fp8Matmul.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedElementwiseChain.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedGemmSwiglu.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedNormQuantize.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopKSample.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopPSample.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MoeGate.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiLoraMatmul.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RowParallelLinear.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SelectiveScan.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SmoothQuant.java 
diff --git a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt index b7c19c16ac17..076b9269c7c3 100644 --- a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt +++ b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt @@ -445,45 +445,75 @@ fun NN() = Namespace("NN") { } } + Op("rmsNorm") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "RmsNorm" + val input = Input(NUMERIC, "input") { description = "Input variable" } + val gamma = Input(NUMERIC, "gamma") { description = "Scale/gain vector"; defaultValue = null } + val epsilon = Arg(FLOATING_POINT, "epsilon") { defaultValue = 1e-5; description = "Epsilon for numerical stability" } + + Output(NUMERIC, "output") { description = "RMS normalized output" } + + AllParamSignature() + Signature(input, gamma) + Signature(input, epsilon) + Signature(input) + + Doc(Language.ANY, DocScope.ALL) { + """ + Root Mean Square Layer Normalization (RMSNorm): + + output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma + + If gamma is not provided, only RMS normalization is applied. + """.trimIndent() + } + } + Op("dotProductAttentionV2") { javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" - val q = Input(NUMERIC, "queries") { description = "A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim]" } - val v = Input(NUMERIC, "values") { description = "A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim]" } + val q = Input(NUMERIC, "queries") { description = "Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention" } + val v = Input(NUMERIC, "values") { description = "Value tensor. 
Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim]" } + val k = Input(NUMERIC, "keys") { description = "Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim]" } + val queryMask = Input(NUMERIC, "queryMask") { description = "Query mask tensor (optional). Shape: [batchSize, numQueries]"; defaultValue = null } + val valueMask = Input(NUMERIC, "valueMask") { description = "Value mask tensor (optional). Shape: [batchSize, numValues]"; defaultValue = null } + val attentionBias = Input(NUMERIC, "attentionBias") { description = "Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax."; defaultValue = null } - val k = Input(NUMERIC, "keys") { description = "A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim]" } - val queryMask = Input(NUMERIC, "queryMask") { description = "A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries]" } - val valueMask = Input(NUMERIC, "valueMask") { description = "@param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues]" } + val s = Arg(FLOATING_POINT, "scaleFactor") { defaultValue = 0.0; description = "Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim))" } + val dropout = Arg(FLOATING_POINT, "dropoutProbability") { defaultValue = 0.0; description = "Dropout probability applied to attention weights" } + val useCausalMask = Arg(BOOL, "useCausalMask") { defaultValue = false; description = "Whether to apply causal mask for autoregressive tasks" } + val training = Arg(BOOL, "training") { defaultValue = false; description = "Whether in training mode (affects dropout)" } - val s = Arg(FLOATING_POINT, "scaleFactor") { defaultValue = 1.0; description = "@param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys." 
} - val dropout = Arg(FLOATING_POINT, "dropoutProbability") { defaultValue = 0.0; description = "A {@code double} specifying the dropout probability to be applied to attention weights." } - val useCausalMask = Arg(BOOL, "useCausalMask") { defaultValue = false; description = " A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks." } - val training = Arg(BOOL, "training") { defaultValue = false; description = " A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout." } + Output(NUMERIC, "output") { description = "Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim]" } - Output(NUMERIC, "output") { description = " A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim]"} - - Signature(q,v,k,queryMask,valueMask, s,dropout,useCausalMask,training) + // Standard signature without attention bias (backward compatible) + Signature(q, v, k, queryMask, valueMask, s, dropout, useCausalMask, training) + // Full signature with attention bias + Signature(q, v, k, queryMask, valueMask, attentionBias, s, dropout, useCausalMask, training) Doc(Language.ANY, DocScope.ALL) { """ - This operation performs dot product attention on the given timeseries input with the given queries - out = sum(similarity(k_i, q) * v_i) - - similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - - Optionally with normalization step: - similarity(k, q) = softmax(k * q / sqrt(size(q)) - - See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - - Note: This supports multiple queries at once, if only one query is available the queries vector still has to - be 3D but can have queryCount = 1 - - Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - both. 
- - Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The - output rank will depend on the input rank. + Dot product attention operation with flash attention and KV cache support. + + out = softmax(Q * K^T / scale + attentionBias) * V + + For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm. + For 2D/3D inputs, uses standard attention computation. + + Flash attention features: + - O(N) memory complexity instead of O(N^2) + - Tiled computation with online softmax + - Supports grouped query attention (GQA) where numHeads > numKvHeads + - Supports attention bias (relative position bias, ALiBi, etc.) + + KV Cache support for autoregressive generation: + - Pass keyCache and valueCache tensors + - Set kvCachePosition to current generation position + - Cached keys/values are updated in-place + + See "Attention is all you need" (https://arxiv.org/abs/1706.03762) + See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135) """.trimIndent() } } @@ -565,6 +595,139 @@ fun NN() = Namespace("NN") { } } + Op("flashAttention") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" } + val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + + val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 
0 = auto (1/sqrt(headDim))" } + val isCausal = Arg(BOOL, "isCausal") { defaultValue = true; description = "Whether to apply causal masking" } + val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" } + val numKvHeads = Arg(INT, "numKvHeads") { defaultValue = 0; description = "Number of KV heads (0 = same as numHeads, for GQA use smaller value)" } + + Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" } + + Signature(q, k, v, scale, isCausal, numHeads, numKvHeads) + + Doc(Language.ANY, DocScope.ALL) { + """ + Flash Attention - Memory-efficient attention computation. + + Uses tiled computation with online softmax to achieve O(N) memory complexity + instead of O(N^2) for standard attention. + + Supports Grouped Query Attention (GQA) where numHeads > numKvHeads, + allowing multiple query heads to share the same KV heads. + + out = softmax(Q * K^T / scale) * V + + See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135) + """.trimIndent() + } + } + + Op("groupedQueryAttention") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" } + val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + + val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 
0 = auto (1/sqrt(headDim))" } + val isCausal = Arg(BOOL, "isCausal") { defaultValue = true; description = "Whether to apply causal masking" } + val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" } + val numKvHeads = Arg(INT, "numKvHeads") { description = "Number of KV heads (must divide numHeads evenly)" } + + Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" } + + Signature(q, k, v, scale, isCausal, numHeads, numKvHeads) + + Doc(Language.ANY, DocScope.ALL) { + """ + Grouped Query Attention (GQA) - Efficient attention with shared KV heads. + + Multiple query heads share the same key-value heads, reducing memory and + computation while maintaining model quality. Used in LLaMA 2, Mistral, etc. + + numHeads must be divisible by numKvHeads. Each KV head is repeated + (numHeads / numKvHeads) times to match query heads. + + Special cases: + - numKvHeads == numHeads: Standard Multi-Head Attention (MHA) + - numKvHeads == 1: Multi-Query Attention (MQA) + + See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" + """.trimIndent() + } + } + + Op("kvCacheUpdate") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "KVCacheUpdate" + val keyCache = Input(NUMERIC, "keyCache") { description = "Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim]" } + val valueCache = Input(NUMERIC, "valueCache") { description = "Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim]" } + val newKeys = Input(NUMERIC, "newKeys") { description = "New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim]" } + val newValues = Input(NUMERIC, "newValues") { description = "New values to insert. 
Shape: [batch, newSeqLen, numKvHeads, headDim]" } + + val startPosition = Arg(INT, "startPosition") { defaultValue = 0; description = "Position in cache where new keys/values should be inserted" } + + Output(NUMERIC, "updatedKeyCache") { description = "Updated key cache" } + Output(NUMERIC, "updatedValueCache") { description = "Updated value cache" } + + Signature(keyCache, valueCache, newKeys, newValues, startPosition) + + Doc(Language.ANY, DocScope.ALL) { + """ + KV Cache Update - Updates key-value cache for autoregressive generation. + + During LLM inference, past key-value pairs are cached to avoid redundant + computation during token-by-token generation. This operation efficiently + inserts new keys/values at the specified position. + + Usage pattern: + 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim] + 2. For each new token, compute new K/V and update cache + 3. Use full cached K/V for attention computation + + Returns updated keyCache and valueCache tensors. + """.trimIndent() + } + } + + Op("slidingWindowAttention") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim]" } + val k = Input(NUMERIC, "key") { description = "Key tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + val v = Input(NUMERIC, "value") { description = "Value tensor. Shape: [batch, seqLen, numKvHeads, headDim]" } + + val windowSize = Arg(INT, "windowSize") { defaultValue = 4096; description = "Sliding window size - tokens can only attend to this many previous positions" } + val numHeads = Arg(INT, "numHeads") { description = "Number of query attention heads" } + val numKvHeads = Arg(INT, "numKvHeads") { defaultValue = 0; description = "Number of KV heads (0 = same as numHeads)" } + val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Scaling factor. 
0 = auto (1/sqrt(headDim))" } + + Output(NUMERIC, "output") { description = "Attention output. Shape: [batch, seqLen, numHeads, headDim]" } + + Signature(q, k, v, windowSize, numHeads, numKvHeads, scale) + + Doc(Language.ANY, DocScope.ALL) { + """ + Sliding Window Attention - Efficient attention for long sequences. + + Each token only attends to a fixed window of previous tokens, enabling + efficient processing of very long sequences. Used in Mistral and other + modern LLMs for handling long contexts. + + Benefits: + - O(N * windowSize) complexity instead of O(N^2) + - Memory efficient for long sequences + - Supports very long context lengths (e.g., 32K with 4K window) + + The attention mask is automatically applied to restrict each position + to only attend to positions within [pos - windowSize, pos]. + """.trimIndent() + } + } + Op("pad") { javaPackage = "org.nd4j.linalg.api.ops.impl.transforms" Input(NUMERIC, "input") { description = "Input tensor"} @@ -576,7 +739,7 @@ fun NN() = Namespace("NN") { Doc(Language.ANY, DocScope.ALL){ """ - Padding operation + Padding operation """.trimIndent() } } @@ -598,5 +761,314 @@ fun NN() = Namespace("NN") { } } + Op("windowedAttention") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val q = Input(NUMERIC, "query") { description = "Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D" } + val k = Input(NUMERIC, "key") { description = "Key tensor. Same shape as query" } + val v = Input(NUMERIC, "value") { description = "Value tensor. Same shape as query" } + val rpb = Input(NUMERIC, "relativePositionBias") { description = "Optional relative position bias. 
Shape: [numHeads, windowSize, windowSize]"; defaultValue = null } + val mask = Input(NUMERIC, "attentionMask") { description = "Optional attention mask"; defaultValue = null } + + val windowSize = Arg(INT, "windowSize") { description = "Size of attention window" } + val numHeads = Arg(INT, "numHeads") { description = "Number of attention heads" } + val shiftSize = Arg(INT, "shiftSize") { defaultValue = 0; description = "Shift size for shifted window attention (Swin style). 0 = no shift" } + val scale = Arg(FLOATING_POINT, "scale") { defaultValue = 0.0; description = "Attention scale factor. 0 = auto (1/sqrt(headDim))" } + val returnWeights = Arg(BOOL, "returnWeights") { defaultValue = false; description = "Whether to return attention weights" } + + Output(NUMERIC, "output") { description = "Attention output. Same shape as query" } + + Signature(q, k, v, windowSize, numHeads) + Signature(q, k, v, rpb, mask, windowSize, numHeads, shiftSize, scale, returnWeights) + + Doc(Language.ANY, DocScope.ALL) { + """ + Windowed Attention - Local/Sliding Window Attention. + + Implements windowed attention mechanisms used in efficient transformers like + Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model). + + Supports both: + - 1D windowed attention: for sequences [batch, seqLen, heads, dim] + - 2D windowed attention: for images [batch, height, width, heads, dim] + + Shifted window attention (shiftSize > 0) enables cross-window connections + as used in Swin Transformer. + + Benefits: + - O(N * windowSize) complexity instead of O(N^2) + - Efficient for long sequences and high-resolution images + - Supports relative position bias for position-aware attention + """.trimIndent() + } + } + + Op("relativePositionBias") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val biasTable = Input(NUMERIC, "biasTable") { description = "Learned bias table. 
Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode" } + val relPosIndex = Input(NUMERIC, "relativePositionIndex") { description = "Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2]"; defaultValue = null } + + val numHeads = Arg(INT, "numHeads") { description = "Number of attention heads" } + val windowSize = Arg(INT, "windowSize") { defaultValue = 0; description = "Window size for 2D position encoding (used if generating index)" } + val useAlibi = Arg(BOOL, "useAlibi") { defaultValue = false; description = "Use ALiBi (Attention with Linear Biases) instead of learned bias" } + + Output(NUMERIC, "output") { description = "Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen]" } + + Signature(biasTable, numHeads, windowSize) + Signature(biasTable, relPosIndex, numHeads, windowSize) + + Doc(Language.ANY, DocScope.ALL) { + """ + Relative Position Bias - Compute relative position bias for attention. + + Supports two modes: + 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table + based on relative positions between query and key positions. + + 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias + without learned parameters. More efficient for very long sequences. + + For learned bias mode: + - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D + - Output is gathered based on relative position indices + + For ALiBi mode: + - biasTable can be sequence length (scalar) or input tensor + - Computes m_h * |i - j| where m_h = 2^(-8*h/H) + + Reference: "Swin Transformer" (Liu et al., 2021) + "Train Short, Test Long" (Press et al., 2021) for ALiBi + """.trimIndent() + } + } + + Op("mixtureOfExperts") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + val input = Input(NUMERIC, "input") { description = "Input embeddings. 
Shape: [batch, seqLen, hiddenSize]" } + val routerWeights = Input(NUMERIC, "routerWeights") { description = "Router projection weights. Shape: [hiddenSize, numExperts]" } + val expertWeights = Input(NUMERIC, "expertWeights") { description = "Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize]" } + val expertBias = Input(NUMERIC, "expertBias") { description = "Optional expert biases. Shape: [numExperts, expertHiddenSize]"; defaultValue = null } + + val numExperts = Arg(INT, "numExperts") { description = "Total number of experts" } + val topK = Arg(INT, "topK") { defaultValue = 2; description = "Number of experts to route to per token" } + val normalizeProbs = Arg(BOOL, "normalizeProbs") { defaultValue = true; description = "Whether to normalize router probabilities for selected experts" } + val capacityFactor = Arg(FLOATING_POINT, "capacityFactor") { defaultValue = 1.0; description = "Expert capacity factor for load balancing" } + + Output(NUMERIC, "output") { description = "Combined expert outputs. Shape: [batch, seqLen, expertHiddenSize]" } + Output(NUMERIC, "routerProbs") { description = "Router probabilities. Shape: [batch, seqLen, numExperts]" } + Output(NUMERIC, "expertIndices") { description = "Selected expert indices. Shape: [batch, seqLen, topK]" } + + Signature(input, routerWeights, expertWeights, numExperts, topK) + Signature(input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor) + + Doc(Language.ANY, DocScope.ALL) { + """ + Mixture of Experts (MoE) Layer. + + Implements sparse MoE routing where each token is processed by only the top-k + selected experts out of a larger pool. This enables scaling model capacity + without proportionally increasing computation. 
+ + Used in large language models like: + - DeepSeek (DeepSeekMoE) + - Mixtral (Mistral AI) + - Switch Transformer (Google) + - GShard (Google) + + The router computes expert selection probabilities: + router_probs = softmax(input @ routerWeights) + + Top-k experts are selected and their outputs are weighted by normalized probs: + output = sum(normalized_prob[i] * expert[i](input) for i in top_k) + + Benefits: + - Scales model capacity with sublinear compute increase + - Enables very large models with efficient inference + - Supports expert parallelism across devices + """.trimIndent() + } + } + + Op("ctcGreedyDecoder") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "CTCGreedyDecoder" + val logits = Input(NUMERIC, "logits") { description = "Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses]" } + val sequenceLength = Input(NUMERIC, "sequenceLength") { description = "Optional actual sequence lengths. Shape: [batch]"; defaultValue = null } + + val mergeRepeated = Arg(BOOL, "mergeRepeated") { defaultValue = true; description = "Whether to merge repeated characters in output" } + val blankIndex = Arg(INT, "blankIndex") { defaultValue = 0; description = "Index of the blank label in the vocabulary" } + + Output(NUMERIC, "decoded") { description = "Decoded sequences. Shape: [batch, timeSteps] (padded with blank)" } + Output(NUMERIC, "logProbability") { description = "Log probability of decoded sequences. Shape: [batch]" } + + Signature(logits, mergeRepeated, blankIndex) + Signature(logits, sequenceLength, mergeRepeated, blankIndex) + + Doc(Language.ANY, DocScope.ALL) { + """ + CTC Greedy Decoder - Connectionist Temporal Classification decoding. + + Performs greedy (best path) decoding on CTC output. Used in: + - OCR (Optical Character Recognition) - PaddleOCR, CRNN + - Speech recognition - DeepSpeech, Wav2Vec + - Handwriting recognition + + Algorithm: + 1. 
At each timestep, select the class with highest probability + 2. Optionally merge consecutive repeated characters + 3. Remove blank labels from the output + + For example, with mergeRepeated=true and blankIndex=0: + Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b') + Output: [1, 2] -> "ab" + + Note: This is greedy decoding. For better accuracy with language models, + use beam search decoding instead. + """.trimIndent() + } + } + + Op("emaUpdate") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "EmaUpdate" + Input(NUMERIC, "model") { description = "Current model parameters (student)" } + Input(NUMERIC, "shadow") { description = "EMA shadow parameters (teacher)" } + Arg(NUMERIC, "decay") { description = "EMA decay factor (typically 0.996-0.9999)"; defaultValue = 0.999 } + Output(NUMERIC, "output") { description = "Updated shadow parameters" } + Doc(Language.ANY, DocScope.ALL) { + """ + Exponential Moving Average parameter update for DINOv2 teacher networks. + Computes: output = decay * shadow + (1 - decay) * model + Used in self-supervised learning to maintain a slowly-updated teacher model. + """.trimIndent() + } + } + + Op("centerAndSharpen") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "CenterAndSharpen" + Input(NUMERIC, "input") { description = "Teacher output logits [batch, features]" } + Input(NUMERIC, "center") { description = "Running center vector [features]" } + Arg(NUMERIC, "temperature") { description = "Sharpening temperature (typically 0.04-0.07)"; defaultValue = 0.07 } + Output(NUMERIC, "output") { description = "Sharpened probabilities [batch, features]" } + Doc(Language.ANY, DocScope.ALL) { + """ + DINOv2 centering and sharpening operation. 
+ Prevents mode collapse in self-supervised learning by centering the teacher output + and applying temperature-based sharpening: + output = softmax((input - center) / temperature) + """.trimIndent() + } + } + + Op("twoWayCrossAttention") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "TwoWayCrossAttention" + Input(NUMERIC, "tokenQuery") { description = "Token queries [batch, tokenSeqLen, embedDim]" } + Input(NUMERIC, "tokenKey") { description = "Token keys [batch, tokenSeqLen, embedDim]" } + Input(NUMERIC, "tokenValue") { description = "Token values [batch, tokenSeqLen, embedDim]" } + Input(NUMERIC, "imageQuery") { description = "Image queries [batch, imageSeqLen, embedDim]" } + Input(NUMERIC, "imageKey") { description = "Image keys [batch, imageSeqLen, embedDim]" } + Input(NUMERIC, "imageValue") { description = "Image values [batch, imageSeqLen, embedDim]" } + Arg(NUMERIC, "scale") { description = "Attention scale factor (default: 1/sqrt(embedDim))"; defaultValue = 0.0 } + Output(NUMERIC, "tokenOutput") { description = "Attended token embeddings [batch, tokenSeqLen, embedDim]" } + Output(NUMERIC, "imageOutput") { description = "Attended image embeddings [batch, imageSeqLen, embedDim]" } + Doc(Language.ANY, DocScope.ALL) { + """ + SAM-style Two-Way Cross Attention. + Bidirectional cross-attention where tokens attend to image features and + image features attend to tokens simultaneously: + tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV + imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV + """.trimIndent() + } + } + + Op("tokenSample") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "TokenSample" + val logits = Input(NUMERIC, "logits") { description = "Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position." 
} + val temperature = Arg(FLOATING_POINT, "temperature") { defaultValue = 0.0; description = "Temperature for sampling. 0 = greedy (argmax)" } + val topK = Arg(INT, "topK") { defaultValue = 0; description = "Top-K filtering: keep only top K logits. 0 = disabled" } + val topP = Arg(FLOATING_POINT, "topP") { defaultValue = 0.0; description = "Top-P (nucleus) filtering threshold. 0 = disabled" } + val seed = Arg(LONG, "seed") { defaultValue = 0; description = "Random seed for sampling. 0 = random" } + + Output(LONG, "output") { description = "Sampled token indices. Shape: [batch] or scalar" } + + Signature(logits) + Signature(logits, temperature, topK, topP, seed) + + Doc(Language.ANY, DocScope.ALL) { + """ + Token sampling for LLM inference. + + Full sampling pipeline in a single native GPU call: + temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax + + For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax + with shared-memory reduction — avoids transferring the full logits tensor to host. + + Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3 + [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position + is automatically extracted for sampling. + """.trimIndent() + } + } + + Op("kvScatter") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "KvScatter" + val present = Input(NUMERIC, "present") { description = "Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim]" } + val staticBuf = Input(NUMERIC, "staticBuffer") { description = "Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place." } + + val cachePos = Arg(LONG, "cachePos") { description = "Position in static buffer to write the new entry" } + val numPairs = Arg(INT, "numPairs") { defaultValue = 1; description = "Number of present/static KV pairs. 
When > 1, inputs are [present_0..N-1, static_0..N-1]" } + + Output(LONG, "output") { description = "Scalar 0 on success" } + + Signature(present, staticBuf, cachePos) + Signature(present, staticBuf, cachePos, numPairs) + + Doc(Language.ANY, DocScope.ALL) { + """ + Batch KV cache scatter update for LLM autoregressive decoding. + + Copies a single time-step slice from each present KV tensor into the + corresponding static KV buffer at a given cache position. Replaces N + individual Java view+assign calls with a single native kernel launch. + + The present tensor has shape [batch, heads, seqLen, dim] where the new + token's KV entry is at the last sequence position. This entry is extracted + and written into the static buffer at cachePos. + + For multiple pairs, inputs are ordered as: + [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}] + """.trimIndent() + } + } + Alias(Math(), "tanh") + + Op("fusedElementwiseChain") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "FusedElementwiseChain" + + Input(NUMERIC, "input") { description = "Primary input array" } + Input(NUMERIC, "secondaryInputs") { + count = AtLeast(0) + description = "Optional secondary input arrays for binary ops (add, sub, mul, div)" + } + Arg(INT, "opCodes") { + count = AtLeast(1) + description = "Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu" + } + + Output(NUMERIC, "output"){ description = "Result of applying the fused element-wise chain" } + + Doc(Language.ANY, DocScope.ALL){ + """ + Executes a fused chain of element-wise operations in a single kernel pass. + Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1. 
+ """.trimIndent() + } + } } diff --git a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt index 237a9d4a37ed..0d6e1928c244 100644 --- a/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt +++ b/codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt @@ -1713,6 +1713,20 @@ fun SDBaseOps() = Namespace("BaseOps"){ } } + Op("squeezeAll") { + javaPackage = "org.nd4j.linalg.api.ops.impl.shape" + javaOpClass = "Squeeze" + Input(NUMERIC, "x") { description = "Input variable" } + Output(NUMERIC, "output"){ description = "Output variable" } + Doc(Language.ANY, DocScope.ALL){ + """ + Remove all dimensions of size 1 from the input tensor. + For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c] + This is the NumPy-style squeeze with no axis specified. + """.trimIndent() + } + } + Op("stack") { javaPackage = "org.nd4j.linalg.api.ops.impl.shape" argsFirst = true @@ -2179,4 +2193,56 @@ fun SDBaseOps() = Namespace("BaseOps"){ """.trimIndent() } } + + Op("booleanNot") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.bool" + javaOpClass = "BooleanNot" + legacy = true + Input(BOOL, "x") { description = "Input boolean array" } + Output(BOOL, "output"){ description = "Boolean NOT result" } + Doc(Language.ANY, DocScope.ALL){ + """ + Boolean NOT operation: elementwise !x + """.trimIndent() + } + } + + Op("booleanAnd") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "BooleanAnd" + Input(BOOL, "x") { description = "First input boolean array" } + Input(BOOL, "y") { description = "Second input boolean array" } + Output(BOOL, "output"){ description = "Boolean AND result" } + Doc(Language.ANY, DocScope.ALL){ + """ + Boolean AND operation: elementwise x && y. Supports broadcasting. 
+ """.trimIndent() + } + } + + Op("booleanOr") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "BooleanOr" + Input(BOOL, "x") { description = "First input boolean array" } + Input(BOOL, "y") { description = "Second input boolean array" } + Output(BOOL, "output"){ description = "Boolean OR result" } + Doc(Language.ANY, DocScope.ALL){ + """ + Boolean OR operation: elementwise x || y. Supports broadcasting. + """.trimIndent() + } + } + + Op("booleanXor") { + javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" + javaOpClass = "BooleanXor" + Input(BOOL, "x") { description = "First input boolean array" } + Input(BOOL, "y") { description = "Second input boolean array" } + Output(BOOL, "output"){ description = "Boolean XOR result" } + Doc(Language.ANY, DocScope.ALL){ + """ + Boolean XOR operation: elementwise x ^ y. Supports broadcasting. + """.trimIndent() + } + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index e64580b9ff00..a36ea1612f6e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -424,6 +424,114 @@ public SDVariable[] batchMmul(String[] names, SDVariable alphas, SDVariable beta return sd.updateVariableNamesAndReferences(out, names); } + /** + * Boolean AND operation: elementwise x && y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean AND result (BOOL type) + */ + public SDVariable booleanAnd(SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanAnd", "x", x); + SDValidation.validateBool("booleanAnd", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(sd,x, y).outputVariable(); + } + + /** + * Boolean AND operation: elementwise x && y. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean AND result (BOOL type) + */ + public SDVariable booleanAnd(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanAnd", "x", x); + SDValidation.validateBool("booleanAnd", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean NOT operation: elementwise !x
+ * + * @param x Input boolean array (BOOL type) + * @return output Boolean NOT result (BOOL type) + */ + public SDVariable booleanNot(SDVariable x) { + SDValidation.validateBool("booleanNot", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(sd,x).outputVariable(); + } + + /** + * Boolean NOT operation: elementwise !x
+ * + * @param name name May be null. Name for the output variable + * @param x Input boolean array (BOOL type) + * @return output Boolean NOT result (BOOL type) + */ + public SDVariable booleanNot(String name, SDVariable x) { + SDValidation.validateBool("booleanNot", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean OR result (BOOL type) + */ + public SDVariable booleanOr(SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanOr", "x", x); + SDValidation.validateBool("booleanOr", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(sd,x, y).outputVariable(); + } + + /** + * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean OR result (BOOL type) + */ + public SDVariable booleanOr(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanOr", "x", x); + SDValidation.validateBool("booleanOr", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean XOR result (BOOL type) + */ + public SDVariable booleanXor(SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanXor", "x", x); + SDValidation.validateBool("booleanXor", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(sd,x, y).outputVariable(); + } + + /** + * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean XOR result (BOOL type) + */ + public SDVariable booleanXor(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("booleanXor", "x", x); + SDValidation.validateBool("booleanXor", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Cast the array to a new datatype - for example, Integer -> Float
* @@ -4716,6 +4824,34 @@ public SDVariable squeeze(String name, SDVariable x, int axis) { return sd.updateVariableNameAndReference(out, name); } + /** + * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ * This is the NumPy-style squeeze with no axis specified.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeezeAll(SDVariable x) { + SDValidation.validateNumerical("squeezeAll", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x).outputVariable(); + } + + /** + * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ * This is the NumPy-style squeeze with no axis specified.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeezeAll(String name, SDVariable x) { + SDValidation.validateNumerical("squeezeAll", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Stack a set of N INDArray of rank X into one rank X+1 variable.
* If inputs have shape [a,b,c] then output has shape:
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 03150f6189f6..7f56b5fc131e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -145,6 +145,204 @@ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolea return sd.updateVariableNameAndReference(out, name); } + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @param temperature Sharpening temperature (typically 0.04-0.07) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public SDVariable centerAndSharpen(SDVariable input, SDVariable center, double temperature) { + SDValidation.validateNumerical("centerAndSharpen", "input", input); + SDValidation.validateNumerical("centerAndSharpen", "center", center); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, temperature).outputVariable(); + } + + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param name name May be null. Name for the output variable + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @param temperature Sharpening temperature (typically 0.04-0.07) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public SDVariable centerAndSharpen(String name, SDVariable input, SDVariable center, + double temperature) { + SDValidation.validateNumerical("centerAndSharpen", "input", input); + SDValidation.validateNumerical("centerAndSharpen", "center", center); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, temperature).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public SDVariable centerAndSharpen(SDVariable input, SDVariable center) { + SDValidation.validateNumerical("centerAndSharpen", "input", input); + SDValidation.validateNumerical("centerAndSharpen", "center", center); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, 0.07).outputVariable(); + } + + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param name name May be null. Name for the output variable + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public SDVariable centerAndSharpen(String name, SDVariable input, SDVariable center) { + SDValidation.validateNumerical("centerAndSharpen", "input", input); + SDValidation.validateNumerical("centerAndSharpen", "center", center); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(sd,input, center, 0.07).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public SDVariable[] ctcGreedyDecoder(SDVariable logits, boolean mergeRepeated, int blankIndex) { + SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, null, mergeRepeated, blankIndex).outputVariables(); + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public SDVariable[] ctcGreedyDecoder(String[] names, SDVariable logits, boolean mergeRepeated, + int blankIndex) { + SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, null, mergeRepeated, blankIndex).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param sequenceLength Optional actual sequence lengths. Shape: [batch] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public SDVariable[] ctcGreedyDecoder(SDVariable logits, SDVariable sequenceLength, + boolean mergeRepeated, int blankIndex) { + SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + SDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, sequenceLength, mergeRepeated, blankIndex).outputVariables(); + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param sequenceLength Optional actual sequence lengths. Shape: [batch] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public SDVariable[] ctcGreedyDecoder(String[] names, SDVariable logits, SDVariable sequenceLength, + boolean mergeRepeated, int blankIndex) { + SDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + SDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(sd,logits, sequenceLength, mergeRepeated, blankIndex).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * This operation performs dot product attention on the given timeseries input with the given queries
* out = sum(similarity(k_i, q) * v_i)
@@ -228,35 +426,37 @@ public SDVariable dotProductAttention(String name, SDVariable queries, SDVariabl } /** - * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ * out = softmax(Q * K^T * scaleFactor + attentionBias) * V<br>
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
- *
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
- * - * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type) - * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type) - * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type) - * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type) - * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type) - * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys. - * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights. - * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks. - * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout. - * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type) + * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) */ public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, SDVariable keys, SDVariable queryMask, SDVariable valueMask, double scaleFactor, double dropoutProbability, @@ -266,40 +466,42 @@ public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, S SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, 
training).outputVariable(); } /** - * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ * out = softmax(Q * K^T / scale + attentionBias) * V
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
- *
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
* * @param name name May be null. Name for the output variable - * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type) - * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type) - * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type) - * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type) - * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type) - * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys. - * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights. - * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks. - * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout. - * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type) + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). 
Shape: [batchSize, numValues] (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) */ public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVariable values, SDVariable keys, SDVariable queryMask, SDVariable valueMask, double scaleFactor, @@ -309,7 +511,101 @@ public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVaria SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Dot product attention operation with flash attention and KV cache support.
+ *
+ * out = softmax(Q * K^T * scaleFactor + attentionBias) * V<br>
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type) + * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. 
Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) + */ + public SDVariable dotProductAttentionV2(SDVariable queries, SDVariable values, SDVariable keys, + SDVariable queryMask, SDVariable valueMask, SDVariable attentionBias, double scaleFactor, + double dropoutProbability, boolean useCausalMask, boolean training) { + SDValidation.validateNumerical("dotProductAttentionV2", "queries", queries); + SDValidation.validateNumerical("dotProductAttentionV2", "values", values); + SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); + SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); + SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); + SDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable(); + } + + /** + * Dot product attention operation with flash attention and KV cache support.
+ *
+ * out = softmax(Q * K^T / scale + attentionBias) * V
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param name name May be null. Name for the output variable + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type) + * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. 
Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) + */ + public SDVariable dotProductAttentionV2(String name, SDVariable queries, SDVariable values, + SDVariable keys, SDVariable queryMask, SDVariable valueMask, SDVariable attentionBias, + double scaleFactor, double dropoutProbability, boolean useCausalMask, boolean training) { + SDValidation.validateNumerical("dotProductAttentionV2", "queries", queries); + SDValidation.validateNumerical("dotProductAttentionV2", "values", values); + SDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); + SDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); + SDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); + SDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(sd,queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -407,6 +703,172 @@ public SDVariable elu(String name, SDVariable x) { return sd.updateVariableNameAndReference(out, name); } + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @param decay EMA decay factor (typically 0.996-0.9999) + * @return output Updated shadow parameters (NUMERIC type) + */ + public SDVariable emaUpdate(SDVariable model, SDVariable shadow, double decay) { + SDValidation.validateNumerical("emaUpdate", "model", model); + SDValidation.validateNumerical("emaUpdate", "shadow", shadow); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, decay).outputVariable(); + } + + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param name name May be null. Name for the output variable + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @param decay EMA decay factor (typically 0.996-0.9999) + * @return output Updated shadow parameters (NUMERIC type) + */ + public SDVariable emaUpdate(String name, SDVariable model, SDVariable shadow, double decay) { + SDValidation.validateNumerical("emaUpdate", "model", model); + SDValidation.validateNumerical("emaUpdate", "shadow", shadow); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, decay).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @return output Updated shadow parameters (NUMERIC type) + */ + public SDVariable emaUpdate(SDVariable model, SDVariable shadow) { + SDValidation.validateNumerical("emaUpdate", "model", model); + SDValidation.validateNumerical("emaUpdate", "shadow", shadow); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, 0.999).outputVariable(); + } + + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param name name May be null. Name for the output variable + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @return output Updated shadow parameters (NUMERIC type) + */ + public SDVariable emaUpdate(String name, SDVariable model, SDVariable shadow) { + SDValidation.validateNumerical("emaUpdate", "model", model); + SDValidation.validateNumerical("emaUpdate", "shadow", shadow); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(sd,model, shadow, 0.999).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T / scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable flashAttention(SDVariable query, SDVariable key, SDVariable value, double scale, + boolean isCausal, int numHeads, int numKvHeads) { + SDValidation.validateNumerical("flashAttention", "query", query); + SDValidation.validateNumerical("flashAttention", "key", key); + SDValidation.validateNumerical("flashAttention", "value", value); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable(); + } + + /** + * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T / scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param name name May be null. Name for the output variable + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable flashAttention(String name, SDVariable query, SDVariable key, SDVariable value, + double scale, boolean isCausal, int numHeads, int numKvHeads) { + SDValidation.validateNumerical("flashAttention", "query", query); + SDValidation.validateNumerical("flashAttention", "key", key); + SDValidation.validateNumerical("flashAttention", "value", value); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ * + * @param input Primary input array (NUMERIC type) + * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type) + * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1)) + * @return output Result of applying the fused element-wise chain (NUMERIC type) + */ + public SDVariable fusedElementwiseChain(SDVariable input, SDVariable[] secondaryInputs, + int[] opCodes) { + SDValidation.validateNumerical("fusedElementwiseChain", "input", input); + SDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs); + Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length); + Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(sd,input, secondaryInputs, opCodes).outputVariable(); + } + + /** + * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ * + * @param name name May be null. Name for the output variable + * @param input Primary input array (NUMERIC type) + * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type) + * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1)) + * @return output Result of applying the fused element-wise chain (NUMERIC type) + */ + public SDVariable fusedElementwiseChain(String name, SDVariable input, + SDVariable[] secondaryInputs, int[] opCodes) { + SDValidation.validateNumerical("fusedElementwiseChain", "input", input); + SDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs); + Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length); + Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(sd,input, secondaryInputs, opCodes).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * GELU activation function - Gaussian Error Linear Units
* For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
@@ -435,6 +897,72 @@ public SDVariable gelu(String name, SDVariable x) { return sd.updateVariableNameAndReference(out, name); } + /** + * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (must divide numHeads evenly) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable groupedQueryAttention(SDVariable query, SDVariable key, SDVariable value, + double scale, boolean isCausal, int numHeads, int numKvHeads) { + SDValidation.validateNumerical("groupedQueryAttention", "query", query); + SDValidation.validateNumerical("groupedQueryAttention", "key", key); + SDValidation.validateNumerical("groupedQueryAttention", "value", value); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable(); + } + + /** + * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ * + * @param name name May be null. Name for the output variable + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (must divide numHeads evenly) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable groupedQueryAttention(String name, SDVariable query, SDVariable key, + SDVariable value, double scale, boolean isCausal, int numHeads, int numKvHeads) { + SDValidation.validateNumerical("groupedQueryAttention", "query", query); + SDValidation.validateNumerical("groupedQueryAttention", "key", key); + SDValidation.validateNumerical("groupedQueryAttention", "value", value); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(sd,query, key, value, scale, isCausal, numHeads, numKvHeads).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
@@ -519,6 +1047,175 @@ public SDVariable hardTanhDerivative(String name, SDVariable x) { return sd.updateVariableNameAndReference(out, name); } + /** + * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ * + * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param startPosition Position in cache where new keys/values should be inserted + */ + public SDVariable[] kvCacheUpdate(SDVariable keyCache, SDVariable valueCache, SDVariable newKeys, + SDVariable newValues, int startPosition) { + SDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache); + SDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache); + SDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys); + SDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(sd,keyCache, valueCache, newKeys, newValues, startPosition).outputVariables(); + } + + /** + * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param startPosition Position in cache where new keys/values should be inserted + */ + public SDVariable[] kvCacheUpdate(String[] names, SDVariable keyCache, SDVariable valueCache, + SDVariable newKeys, SDVariable newValues, int startPosition) { + SDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache); + SDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache); + SDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys); + SDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(sd,keyCache, valueCache, newKeys, newValues, startPosition).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @return output Scalar 0 on success (LONG type) + */ + public SDVariable kvScatter(SDVariable present, SDVariable staticBuffer, long cachePos) { + SDValidation.validateNumerical("kvScatter", "present", present); + SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, 1).outputVariable(); + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param name name May be null. Name for the output variable + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @return output Scalar 0 on success (LONG type) + */ + public SDVariable kvScatter(String name, SDVariable present, SDVariable staticBuffer, + long cachePos) { + SDValidation.validateNumerical("kvScatter", "present", present); + SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, 1).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1] + * @return output Scalar 0 on success (LONG type) + */ + public SDVariable kvScatter(SDVariable present, SDVariable staticBuffer, long cachePos, + int numPairs) { + SDValidation.validateNumerical("kvScatter", "present", present); + SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, numPairs).outputVariable(); + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param name name May be null. Name for the output variable + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1] + * @return output Scalar 0 on success (LONG type) + */ + public SDVariable kvScatter(String name, SDVariable present, SDVariable staticBuffer, + long cachePos, int numPairs) { + SDValidation.validateNumerical("kvScatter", "present", present); + SDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(sd,present, staticBuffer, cachePos, numPairs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Apply Layer Normalization
*
@@ -811,6 +1508,172 @@ public SDVariable logSoftmax(String name, SDVariable x, int dimension) { return sd.updateVariableNameAndReference(out, name); } + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + */ + public SDVariable[] mixtureOfExperts(SDVariable input, SDVariable routerWeights, + SDVariable expertWeights, int numExperts, int topK) { + SDValidation.validateNumerical("mixtureOfExperts", "input", input); + SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0).outputVariables(); + } + + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + */ + public SDVariable[] mixtureOfExperts(String[] names, SDVariable input, SDVariable routerWeights, + SDVariable expertWeights, int numExperts, int topK) { + SDValidation.validateNumerical("mixtureOfExperts", "input", input); + SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + * @param normalizeProbs Whether to normalize router probabilities for selected experts + * @param capacityFactor Expert capacity factor for load balancing + */ + public SDVariable[] mixtureOfExperts(SDVariable input, SDVariable routerWeights, + SDVariable expertWeights, SDVariable expertBias, int numExperts, int topK, + boolean normalizeProbs, double capacityFactor) { + SDValidation.validateNumerical("mixtureOfExperts", "input", input); + SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor).outputVariables(); + } + + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + * @param normalizeProbs Whether to normalize router probabilities for selected experts + * @param capacityFactor Expert capacity factor for load balancing + */ + public SDVariable[] mixtureOfExperts(String[] names, SDVariable input, SDVariable routerWeights, + SDVariable expertWeights, SDVariable expertBias, int numExperts, int topK, + boolean normalizeProbs, double capacityFactor) { + SDValidation.validateNumerical("mixtureOfExperts", "input", input); + SDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + SDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(sd,input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * This performs multi-headed dot product attention on the given timeseries input
* out = concat(head_1, head_2, ..., head_n) * Wo
@@ -890,7 +1753,7 @@ public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, } /** - * Padding operation
+ * Padding operation
* * @param input Input tensor (NUMERIC type) * @param padding Padding value (NUMERIC type) @@ -905,7 +1768,7 @@ public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, dou } /** - * Padding operation
+ * Padding operation
* * @param name name May be null. Name for the output variable * @param input Input tensor (NUMERIC type) @@ -923,7 +1786,7 @@ public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode } /** - * Padding operation
+ * Padding operation
* * @param input Input tensor (NUMERIC type) * @param padding Padding value (NUMERIC type) @@ -937,7 +1800,7 @@ public SDVariable pad(SDVariable input, SDVariable padding, double constant) { } /** - * Padding operation
+ * Padding operation
* * @param name name May be null. Name for the output variable * @param input Input tensor (NUMERIC type) @@ -1026,6 +1889,141 @@ public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... return sd.updateVariableNameAndReference(out, name); } + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public SDVariable relativePositionBias(SDVariable biasTable, int numHeads, int windowSize) { + SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, null, numHeads, windowSize, false).outputVariable(); + } + + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param name name May be null. Name for the output variable + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public SDVariable relativePositionBias(String name, SDVariable biasTable, int numHeads, + int windowSize) { + SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, null, numHeads, windowSize, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public SDVariable relativePositionBias(SDVariable biasTable, SDVariable relativePositionIndex, + int numHeads, int windowSize) { + SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + SDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, relativePositionIndex, numHeads, windowSize, false).outputVariable(); + } + + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param name name May be null. Name for the output variable + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public SDVariable relativePositionBias(String name, SDVariable biasTable, + SDVariable relativePositionIndex, int numHeads, int windowSize) { + SDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + SDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(sd,biasTable, relativePositionIndex, numHeads, windowSize, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise rectified linear function with specified cutoff:
* out[i] = in[i] if in[i] >= cutoff
@@ -1116,6 +2114,146 @@ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, S return sd.updateVariableNameAndReference(out, name); } + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(SDVariable input, SDVariable gamma, double epsilon) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDValidation.validateNumerical("rmsNorm", "gamma", gamma); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, epsilon).outputVariable(); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(String name, SDVariable input, SDVariable gamma, double epsilon) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDValidation.validateNumerical("rmsNorm", "gamma", gamma); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, epsilon).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(SDVariable input, SDVariable gamma) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDValidation.validateNumerical("rmsNorm", "gamma", gamma); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, 1.0E-5).outputVariable(); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(String name, SDVariable input, SDVariable gamma) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDValidation.validateNumerical("rmsNorm", "gamma", gamma); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, gamma, 1.0E-5).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(SDVariable input, double epsilon) { + SDValidation.validateNumerical("rmsNorm", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, epsilon).outputVariable(); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(String name, SDVariable input, double epsilon) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, epsilon).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(SDVariable input) { + SDValidation.validateNumerical("rmsNorm", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, 1.0E-5).outputVariable(); + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public SDVariable rmsNorm(String name, SDVariable input) { + SDValidation.validateNumerical("rmsNorm", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(sd,input, null, 1.0E-5).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
*
@@ -1198,6 +2336,72 @@ public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { return sd.updateVariableNameAndReference(out, name); } + /** + * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param windowSize Sliding window size - tokens can only attend to this many previous positions + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable slidingWindowAttention(SDVariable query, SDVariable key, SDVariable value, + int windowSize, int numHeads, int numKvHeads, double scale) { + SDValidation.validateNumerical("slidingWindowAttention", "query", query); + SDValidation.validateNumerical("slidingWindowAttention", "key", key); + SDValidation.validateNumerical("slidingWindowAttention", "value", value); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(sd,query, key, value, windowSize, numHeads, numKvHeads, scale).outputVariable(); + } + + /** + * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ * + * @param name name May be null. Name for the output variable + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param windowSize Sliding window size - tokens can only attend to this many previous positions + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public SDVariable slidingWindowAttention(String name, SDVariable query, SDVariable key, + SDVariable value, int windowSize, int numHeads, int numKvHeads, double scale) { + SDValidation.validateNumerical("slidingWindowAttention", "query", query); + SDValidation.validateNumerical("slidingWindowAttention", "key", key); + SDValidation.validateNumerical("slidingWindowAttention", "value", value); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(sd,query, key, value, windowSize, numHeads, numKvHeads, scale).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Softmax activation, along the specified dimension
* @@ -1370,6 +2574,104 @@ public SDVariable tanh(String name, SDVariable x) { return sd.updateVariableNameAndReference(out, name); } + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public SDVariable tokenSample(SDVariable logits) { + SDValidation.validateNumerical("tokenSample", "logits", logits); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, 0.0, 0, 0.0, 0).outputVariable(); + } + + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param name name May be null. Name for the output variable + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public SDVariable tokenSample(String name, SDVariable logits) { + SDValidation.validateNumerical("tokenSample", "logits", logits); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, 0.0, 0, 0.0, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @param temperature Temperature for sampling. 0 = greedy (argmax) + * @param topK Top-K filtering: keep only top K logits. 0 = disabled + * @param topP Top-P (nucleus) filtering threshold. 0 = disabled + * @param seed Random seed for sampling. 0 = random + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public SDVariable tokenSample(SDVariable logits, double temperature, int topK, double topP, + long seed) { + SDValidation.validateNumerical("tokenSample", "logits", logits); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, temperature, topK, topP, seed).outputVariable(); + } + + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param name name May be null. Name for the output variable + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @param temperature Temperature for sampling. 0 = greedy (argmax) + * @param topK Top-K filtering: keep only top K logits. 0 = disabled + * @param topP Top-P (nucleus) filtering threshold. 0 = disabled + * @param seed Random seed for sampling. 0 = random + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public SDVariable tokenSample(String name, SDVariable logits, double temperature, int topK, + double topP, long seed) { + SDValidation.validateNumerical("tokenSample", "logits", logits); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(sd,logits, temperature, topK, topP, seed).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Find values and indices for the largest k entries along the last dimension.
* @@ -1395,4 +2697,265 @@ public SDVariable[] topK(String[] names, SDVariable input, double k, boolean sor SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TopK(sd,input, k, sorted).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param scale Attention scale factor (default: 1/sqrt(embedDim)) + */ + public SDVariable[] twoWayCrossAttention(SDVariable tokenQuery, SDVariable tokenKey, + SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, SDVariable imageValue, + double scale) { + SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale).outputVariables(); + } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param scale Attention scale factor (default: 1/sqrt(embedDim)) + */ + public SDVariable[] twoWayCrossAttention(String[] names, SDVariable tokenQuery, + SDVariable tokenKey, SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, + SDVariable imageValue, double scale) { + SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + */ + public SDVariable[] twoWayCrossAttention(SDVariable tokenQuery, SDVariable tokenKey, + SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, SDVariable imageValue) { + SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0).outputVariables(); + } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + */ + public SDVariable[] twoWayCrossAttention(String[] names, SDVariable tokenQuery, + SDVariable tokenKey, SDVariable tokenValue, SDVariable imageQuery, SDVariable imageKey, + SDVariable imageValue) { + SDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + SDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + SDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + SDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + SDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(sd,tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public SDVariable windowedAttention(SDVariable query, SDVariable key, SDVariable value, + int windowSize, int numHeads) { + SDValidation.validateNumerical("windowedAttention", "query", query); + SDValidation.validateNumerical("windowedAttention", "key", key); + SDValidation.validateNumerical("windowedAttention", "value", value); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, null, null, windowSize, numHeads, 0, 0.0, false).outputVariable(); + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param name name May be null. Name for the output variable + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public SDVariable windowedAttention(String name, SDVariable query, SDVariable key, + SDVariable value, int windowSize, int numHeads) { + SDValidation.validateNumerical("windowedAttention", "query", query); + SDValidation.validateNumerical("windowedAttention", "key", key); + SDValidation.validateNumerical("windowedAttention", "value", value); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, null, null, windowSize, numHeads, 0, 0.0, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type) + * @param attentionMask Optional attention mask (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift + * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim)) + * @param returnWeights Whether to return attention weights + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public SDVariable windowedAttention(SDVariable query, SDVariable key, SDVariable value, + SDVariable relativePositionBias, SDVariable attentionMask, int windowSize, int numHeads, + int shiftSize, double scale, boolean returnWeights) { + SDValidation.validateNumerical("windowedAttention", "query", query); + SDValidation.validateNumerical("windowedAttention", "key", key); + SDValidation.validateNumerical("windowedAttention", "value", value); + SDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias); + SDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights).outputVariable(); + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param name name May be null. Name for the output variable + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type) + * @param attentionMask Optional attention mask (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift + * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim)) + * @param returnWeights Whether to return attention weights + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public SDVariable windowedAttention(String name, SDVariable query, SDVariable key, + SDVariable value, SDVariable relativePositionBias, SDVariable attentionMask, int windowSize, + int numHeads, int shiftSize, double scale, boolean returnWeights) { + SDValidation.validateNumerical("windowedAttention", "query", query); + SDValidation.validateNumerical("windowedAttention", "key", key); + SDValidation.validateNumerical("windowedAttention", "value", value); + SDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias); + SDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(sd,query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java new file mode 100644 index 000000000000..8a2065afb7b2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/AwqMatmul.java @@ -0,0 +1,134 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * AWQ (Activation-aware Weight Quantization) matrix multiplication. + *

+ * Performs dequantization and GEMM with AWQ-packed weights: + *

+ *   output = input @ dequant(weightPacked, scales, zeros) + bias
+ * 
+ *

+ * Inputs: + *

+ *

+ * Integer arguments: + *

+ *

+ * Output: [M, N] same dtype as input + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class AwqMatmul extends DynamicCustomOp { + + @Getter private int groupSize = 128; + @Getter private int numBits = 4; + + /** + * SameDiff constructor with required inputs. + */ + public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked, + SDVariable scales, SDVariable zeros) { + super(null, sameDiff, new SDVariable[]{input, weightPacked, scales, zeros}, false); + addIArgument((long) groupSize, (long) numBits); + } + + /** + * SameDiff constructor with bias. + */ + public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked, + SDVariable scales, SDVariable zeros, SDVariable bias) { + super(null, sameDiff, new SDVariable[]{input, weightPacked, scales, zeros, bias}, false); + addIArgument((long) groupSize, (long) numBits); + } + + /** + * Full SameDiff constructor with all options. + */ + public AwqMatmul(SameDiff sameDiff, SDVariable input, SDVariable weightPacked, + SDVariable scales, SDVariable zeros, SDVariable bias, + int groupSize, int numBits) { + super(null, sameDiff, bias != null ? + new SDVariable[]{input, weightPacked, scales, zeros, bias} : + new SDVariable[]{input, weightPacked, scales, zeros}, false); + this.groupSize = groupSize; + this.numBits = numBits; + addIArgument((long) groupSize, (long) numBits); + } + + /** + * INDArray constructor. + */ + public AwqMatmul(INDArray input, INDArray weightPacked, INDArray scales, INDArray zeros, + INDArray bias, INDArray output, int groupSize, int numBits) { + super(null, bias != null ? + new INDArray[]{input, weightPacked, scales, zeros, bias} : + new INDArray[]{input, weightPacked, scales, zeros}, + output != null ? 
new INDArray[]{output} : null); + this.groupSize = groupSize; + this.numBits = numBits; + addIArgument((long) groupSize, (long) numBits); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.groupSize = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.numBits = iArguments.get(1).intValue(); + } + + @Override + public String opName() { + return "awq_matmul"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanAnd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanAnd.java new file mode 100644 index 000000000000..728bf2337477 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanAnd.java @@ -0,0 +1,66 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class BooleanAnd extends BaseDynamicTransformOp { + public BooleanAnd() {} + + public BooleanAnd(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x, y}, false); + } + + public BooleanAnd(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { + super(sameDiff, args, inPlace); + } + + public BooleanAnd(INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); + } + + public BooleanAnd(INDArray x, INDArray y) { + this(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "boolean_and"; + } + + @Override + public List doDiff(List f1) { + return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + return Collections.singletonList(DataType.BOOL); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanOr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanOr.java new file mode 100644 index 000000000000..73b80708af9c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanOr.java @@ -0,0 +1,66 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made 
available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class BooleanOr extends BaseDynamicTransformOp { + public BooleanOr() {} + + public BooleanOr(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x, y}, false); + } + + public BooleanOr(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { + super(sameDiff, args, inPlace); + } + + public BooleanOr(INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); + } + + public BooleanOr(INDArray x, INDArray y) { + this(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "boolean_or"; + } + + @Override + public List doDiff(List f1) { + return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + return Collections.singletonList(DataType.BOOL); + } +} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanXor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanXor.java new file mode 100644 index 000000000000..80c6dd55c8eb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BooleanXor.java @@ -0,0 +1,66 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class BooleanXor extends BaseDynamicTransformOp { + public BooleanXor() {} + + public BooleanXor(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x, y}, false); + } + + public BooleanXor(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { + super(sameDiff, args, inPlace); + } + + public BooleanXor(INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); + } + + public BooleanXor(INDArray x, INDArray y) { + this(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "boolean_xor"; + } + + @Override + public List doDiff(List f1) { + return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + return Collections.singletonList(DataType.BOOL); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ColumnParallelLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ColumnParallelLinear.java new file mode 100644 index 000000000000..741eccbcf31e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ColumnParallelLinear.java @@ -0,0 +1,153 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and 
the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * Column-parallel linear layer for tensor parallelism. + *

+ * Splits the weight matrix along columns across tensor-parallel ranks. + * Each rank computes a shard of the output, optionally gathering across ranks: + *

+ *   output_shard = input @ weightShard + biasShard
+ *   output = allGather(output_shard) if gatherOutput else output_shard
+ * 
+ *

+ * Inputs: + *

+ *

+ * Integer arguments: + *

+ *

+ * Output: [B, O] if gatherOutput=1, else [B, O/tp] + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class ColumnParallelLinear extends DynamicCustomOp { + + @Getter private int tpSize = 1; + @Getter private int tpRank = 0; + @Getter private int gatherOutput = 1; + + /** + * SameDiff constructor with required inputs. + */ + public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard) { + super(null, sameDiff, new SDVariable[]{input, weightShard}, false); + addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput); + } + + /** + * SameDiff constructor with bias. + */ + public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard, + SDVariable biasShard) { + super(null, sameDiff, new SDVariable[]{input, weightShard, biasShard}, false); + addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput); + } + + /** + * SameDiff constructor with boolean gatherOutput. + */ + public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard, + SDVariable biasShard, int tpSize, int tpRank, boolean gatherOutput) { + this(sameDiff, input, weightShard, biasShard, tpSize, tpRank, gatherOutput ? 1 : 0); + } + + /** + * Full SameDiff constructor with all options. + */ + public ColumnParallelLinear(SameDiff sameDiff, SDVariable input, SDVariable weightShard, + SDVariable biasShard, int tpSize, int tpRank, int gatherOutput) { + super(null, sameDiff, biasShard != null ? + new SDVariable[]{input, weightShard, biasShard} : + new SDVariable[]{input, weightShard}, false); + this.tpSize = tpSize; + this.tpRank = tpRank; + this.gatherOutput = gatherOutput; + addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput); + } + + /** + * INDArray constructor (no output pre-allocation). 
+ */ + public ColumnParallelLinear(INDArray input, INDArray weightShard, INDArray biasShard, + int tpSize, int tpRank, boolean gatherOutput) { + this(input, weightShard, biasShard, null, tpSize, tpRank, gatherOutput ? 1 : 0); + } + + /** + * INDArray constructor. + */ + public ColumnParallelLinear(INDArray input, INDArray weightShard, INDArray biasShard, + INDArray output, int tpSize, int tpRank, int gatherOutput) { + super(null, biasShard != null ? + new INDArray[]{input, weightShard, biasShard} : + new INDArray[]{input, weightShard}, + output != null ? new INDArray[]{output} : null); + this.tpSize = tpSize; + this.tpRank = tpRank; + this.gatherOutput = gatherOutput; + addIArgument((long) tpSize, (long) tpRank, (long) gatherOutput); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.tpSize = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.tpRank = iArguments.get(1).intValue(); + if (iArguments.size() > 2) this.gatherOutput = iArguments.get(2).intValue(); + } + + @Override + public String opName() { + return "column_parallel_linear"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DecoderMaskedMha.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DecoderMaskedMha.java new file mode 100644 index 000000000000..cccbfeaced69 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DecoderMaskedMha.java @@ -0,0 +1,166 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made 
available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +/** + * Decoder masked multi-head attention for autoregressive inference. + *

+ * Performs fused QKV projection, rotary positional encoding, and + * masked multi-head attention in a single op for decoder-only models. + *

+ * Inputs: + *

+ *

+ * Integer arguments: + *

+ *

+ * Float arguments: + *

+ *

+ * Outputs: + *

+ * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class DecoderMaskedMha extends DynamicCustomOp { + + @Getter private int numHeads = 32; + @Getter private int numKvHeads = 0; + @Getter private int headDim = 0; + @Getter private int useRoPE = 1; + @Getter private int ropeBase = 10000; + @Getter private float maskFilterValue = -3.4028235e+38f; + + /** + * SameDiff constructor with required inputs and default options. + */ + public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight, + SDVariable outWeight, SDVariable pastKey, SDVariable pastValue) { + super(null, sameDiff, new SDVariable[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue}, false); + addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase); + addTArgument((double) maskFilterValue); + } + + /** + * SameDiff constructor with boolean useRoPE (uses default ropeBase and maskFilterValue). + */ + public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight, + SDVariable outWeight, SDVariable pastKey, SDVariable pastValue, + int numHeads, int numKvHeads, int headDim, boolean useRoPE) { + this(sameDiff, hiddenStates, qkvWeight, outWeight, pastKey, pastValue, + numHeads, numKvHeads, headDim, useRoPE ? 1 : 0, 10000, -3.4028235e+38f); + } + + /** + * Full SameDiff constructor with all options. 
+ */ + public DecoderMaskedMha(SameDiff sameDiff, SDVariable hiddenStates, SDVariable qkvWeight, + SDVariable outWeight, SDVariable pastKey, SDVariable pastValue, + int numHeads, int numKvHeads, int headDim, + int useRoPE, int ropeBase, float maskFilterValue) { + super(null, sameDiff, new SDVariable[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue}, false); + this.numHeads = numHeads; + this.numKvHeads = numKvHeads; + this.headDim = headDim; + this.useRoPE = useRoPE; + this.ropeBase = ropeBase; + this.maskFilterValue = maskFilterValue; + addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase); + addTArgument((double) maskFilterValue); + } + + /** + * INDArray constructor. + */ + public DecoderMaskedMha(INDArray hiddenStates, INDArray qkvWeight, INDArray outWeight, + INDArray pastKey, INDArray pastValue, + INDArray output, INDArray presentKey, INDArray presentValue, + int numHeads, int numKvHeads, int headDim, + int useRoPE, int ropeBase, float maskFilterValue) { + super(null, new INDArray[]{hiddenStates, qkvWeight, outWeight, pastKey, pastValue}, + output != null ? 
new INDArray[]{output, presentKey, presentValue} : null); + this.numHeads = numHeads; + this.numKvHeads = numKvHeads; + this.headDim = headDim; + this.useRoPE = useRoPE; + this.ropeBase = ropeBase; + this.maskFilterValue = maskFilterValue; + addIArgument((long) numHeads, (long) numKvHeads, (long) headDim, (long) useRoPE, (long) ropeBase); + addTArgument((double) maskFilterValue); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.numHeads = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.numKvHeads = iArguments.get(1).intValue(); + if (iArguments.size() > 2) this.headDim = iArguments.get(2).intValue(); + if (iArguments.size() > 3) this.useRoPE = iArguments.get(3).intValue(); + if (iArguments.size() > 4) this.ropeBase = iArguments.get(4).intValue(); + if (tArguments.size() > 0) this.maskFilterValue = tArguments.get(0).floatValue(); + } + + @Override + public String opName() { + return "decoder_masked_mha"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + DataType dt = inputDataTypes.get(0); + return Arrays.asList(dt, dt, dt); + } + + @Override + public int getNumOutputs() { + return 3; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fp8Matmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fp8Matmul.java new file mode 100644 index 000000000000..78fbcd040493 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fp8Matmul.java @@ -0,0 +1,155 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * FP8 Matrix Multiplication operation (SM 89+ / Ada Lovelace and newer). + *

+ * Performs GEMM with FP8 quantized inputs and FP16/FP32 output, using + * CUTLASS FP8 GEMM with per-tensor dequantization scales: + *

+ *   C = dequant(A @ B) + bias
+ *     = (scale_A * A_fp8) @ (scale_B * B_fp8) + bias
+ * 
+ *

+ * Inputs: + *

+ * <ul>
+ *   <li>0: A tensor (FP8 E4M3 quantized, stored as INT8) [M, K]</li>
+ *   <li>1: B tensor (FP8 E4M3 quantized, stored as INT8) [K, N]</li>
+ *   <li>2: scale_A (float scalar) - per-tensor dequantization scale for A</li>
+ *   <li>3: scale_B (float scalar) - per-tensor dequantization scale for B</li>
+ *   <li>4: bias (optional) [N]</li>
+ * </ul>
+ *

+ * Output: C tensor (FLOAT16) [M, N] + *

+ * Integer arguments: + *

+ * <ul>
+ *   <li>0: fp8_format (0=E4M3, 1=E5M2, default: 0)</li>
+ *   <li>1: transpose_a (0=no, 1=yes, default: 0)</li>
+ *   <li>2: transpose_b (0=no, 1=yes, default: 0)</li>
+ * </ul>
+ * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class Fp8Matmul extends DynamicCustomOp { + + /** FP8 E4M3 format (4-bit exponent, 3-bit mantissa) - higher precision */ + public static final int FP8_E4M3 = 0; + /** FP8 E5M2 format (5-bit exponent, 2-bit mantissa) - wider range */ + public static final int FP8_E5M2 = 1; + + @Getter private int fp8Format = FP8_E4M3; + @Getter private boolean transposeA = false; + @Getter private boolean transposeB = false; + + /** + * SameDiff constructor with required inputs. + * + * @param sameDiff the SameDiff instance + * @param a FP8 quantized A matrix [M, K] + * @param b FP8 quantized B matrix [K, N] + * @param scaleA per-tensor scale for A + * @param scaleB per-tensor scale for B + */ + public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b, + SDVariable scaleA, SDVariable scaleB) { + super(null, sameDiff, new SDVariable[]{a, b, scaleA, scaleB}, false); + addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L); + } + + /** + * SameDiff constructor with bias. + */ + public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b, + SDVariable scaleA, SDVariable scaleB, SDVariable bias) { + super(null, sameDiff, new SDVariable[]{a, b, scaleA, scaleB, bias}, false); + addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L); + } + + /** + * Full SameDiff constructor with all options. + */ + public Fp8Matmul(SameDiff sameDiff, SDVariable a, SDVariable b, + SDVariable scaleA, SDVariable scaleB, SDVariable bias, + int fp8Format, boolean transposeA, boolean transposeB) { + super(null, sameDiff, bias != null ? + new SDVariable[]{a, b, scaleA, scaleB, bias} : + new SDVariable[]{a, b, scaleA, scaleB}, false); + this.fp8Format = fp8Format; + this.transposeA = transposeA; + this.transposeB = transposeB; + addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L); + } + + /** + * INDArray constructor. 
+ */ + public Fp8Matmul(INDArray a, INDArray b, INDArray scaleA, INDArray scaleB, + INDArray bias, INDArray output, + int fp8Format, boolean transposeA, boolean transposeB) { + super(null, bias != null ? + new INDArray[]{a, b, scaleA, scaleB, bias} : + new INDArray[]{a, b, scaleA, scaleB}, + output != null ? new INDArray[]{output} : null); + this.fp8Format = fp8Format; + this.transposeA = transposeA; + this.transposeB = transposeB; + addIArgument((long) fp8Format, transposeA ? 1L : 0L, transposeB ? 1L : 0L); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.fp8Format = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.transposeA = iArguments.get(1) != 0; + if (iArguments.size() > 2) this.transposeB = iArguments.get(2) != 0; + } + + @Override + public String opName() { + return "fp8_matmul"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + // FP8 matmul always outputs FLOAT16 (or FLOAT for accumulation) + return Collections.singletonList(DataType.HALF); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedElementwiseChain.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedElementwiseChain.java new file mode 100644 index 000000000000..74d68334ef98 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedElementwiseChain.java @@ -0,0 +1,233 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Fused element-wise chain operation. + *

+ * Executes a sequence of element-wise ops in a single kernel pass, keeping + * intermediate values in registers instead of global memory. This replaces + * N separate kernel launches with 1. + *

+ * Op codes (iArgs): + *

+ *   0=add, 1=sub, 2=mul, 3=div,
+ *   10=relu, 11=sigmoid, 12=tanh, 13=gelu,
+ *   14=exp, 15=log, 16=abs, 17=neg,
+ *   18=square, 19=sqrt, 20=swish, 21=silu, 22=mish,
+ *   30=clip, 31=leaky_relu
+ * 
+ *

+ * Usage: + *

+ *   // multiply(x, y) -> sigmoid
+ *   FusedElementwiseChain.builder()
+ *       .input(x)
+ *       .multiply(y)
+ *       .sigmoid()
+ *       .build()
+ *       .exec();
+ * 
+ * + * @author Adam Gibson + */ +@NoArgsConstructor +public class FusedElementwiseChain extends DynamicCustomOp { + + // Op codes matching FusedElemOp enum in C++ + public static final int OP_ADD = 0; + public static final int OP_SUB = 1; + public static final int OP_MUL = 2; + public static final int OP_DIV = 3; + public static final int OP_RELU = 10; + public static final int OP_SIGMOID = 11; + public static final int OP_TANH = 12; + public static final int OP_GELU = 13; + public static final int OP_EXP = 14; + public static final int OP_LOG = 15; + public static final int OP_ABS = 16; + public static final int OP_NEG = 17; + public static final int OP_SQUARE = 18; + public static final int OP_SQRT = 19; + public static final int OP_SWISH = 20; + public static final int OP_SILU = 21; + public static final int OP_MISH = 22; + public static final int OP_CLIP = 30; + public static final int OP_LEAKY_RELU = 31; + + /** + * Create a fused elementwise chain. + * + * @param inputs Primary input at index 0, secondary inputs for binary ops at indices 1+ + * @param output Pre-allocated output (or null) + * @param opCodes Sequence of FusedElemOp codes + */ + public FusedElementwiseChain(INDArray[] inputs, INDArray output, int... opCodes) { + super(inputs, output != null ? new INDArray[]{output} : null); + for (int code : opCodes) { + addIArgument(code); + } + } + + /** + * Constructor for codegen (NDNN): single primary input, unary ops only. + */ + public FusedElementwiseChain(INDArray input, int... opCodes) { + this(new INDArray[]{input}, null, opCodes); + } + + /** + * Constructor for codegen (NDNN): single primary input + optional secondary inputs array. + */ + public FusedElementwiseChain(INDArray input, INDArray[] secondaryInputs, int[] opCodes) { + this(buildInputs(input, secondaryInputs), null, opCodes); + } + + /** + * Constructor for codegen (SDNN): SameDiff graph mode, unary ops only. + */ + public FusedElementwiseChain(SameDiff sd, SDVariable input, int... 
opCodes) { + this(sd, input, (SDVariable[]) null, opCodes); + } + + /** + * Constructor for codegen (SDNN): SameDiff graph mode with secondary inputs. + */ + public FusedElementwiseChain(SameDiff sd, SDVariable input, SDVariable[] secondaryInputs, int[] opCodes) { + super(null, sd, buildSdInputs(input, secondaryInputs), false); + for (int code : opCodes) { + addIArgument(code); + } + } + + private static INDArray[] buildInputs(INDArray input, INDArray[] secondaryInputs) { + if (secondaryInputs == null || secondaryInputs.length == 0) { + return new INDArray[]{input}; + } + INDArray[] all = new INDArray[1 + secondaryInputs.length]; + all[0] = input; + System.arraycopy(secondaryInputs, 0, all, 1, secondaryInputs.length); + return all; + } + + private static SDVariable[] buildSdInputs(SDVariable input, SDVariable[] secondaryInputs) { + if (secondaryInputs == null || secondaryInputs.length == 0) { + return new SDVariable[]{input}; + } + SDVariable[] all = new SDVariable[1 + secondaryInputs.length]; + all[0] = input; + System.arraycopy(secondaryInputs, 0, all, 1, secondaryInputs.length); + return all; + } + + @Override + public String opName() { + return "fused_elementwise_chain"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } + + /** + * Builder for constructing fused chains fluently. 
+ */ + public static ChainBuilder builder() { + return new ChainBuilder(); + } + + public static class ChainBuilder { + private INDArray primaryInput; + private INDArray output; + private java.util.List secondaryInputs = new java.util.ArrayList<>(); + private java.util.List opCodes = new java.util.ArrayList<>(); + + public ChainBuilder input(INDArray input) { + this.primaryInput = input; + return this; + } + + public ChainBuilder output(INDArray output) { + this.output = output; + return this; + } + + // Binary ops (need secondary input) + public ChainBuilder add(INDArray secondary) { opCodes.add(OP_ADD); secondaryInputs.add(secondary); return this; } + public ChainBuilder subtract(INDArray secondary) { opCodes.add(OP_SUB); secondaryInputs.add(secondary); return this; } + public ChainBuilder multiply(INDArray secondary) { opCodes.add(OP_MUL); secondaryInputs.add(secondary); return this; } + public ChainBuilder divide(INDArray secondary) { opCodes.add(OP_DIV); secondaryInputs.add(secondary); return this; } + + // Unary ops + public ChainBuilder relu() { opCodes.add(OP_RELU); return this; } + public ChainBuilder sigmoid() { opCodes.add(OP_SIGMOID); return this; } + public ChainBuilder tanh() { opCodes.add(OP_TANH); return this; } + public ChainBuilder gelu() { opCodes.add(OP_GELU); return this; } + public ChainBuilder exp() { opCodes.add(OP_EXP); return this; } + public ChainBuilder log() { opCodes.add(OP_LOG); return this; } + public ChainBuilder abs() { opCodes.add(OP_ABS); return this; } + public ChainBuilder neg() { opCodes.add(OP_NEG); return this; } + public ChainBuilder square() { opCodes.add(OP_SQUARE); return this; } + public ChainBuilder sqrt() { opCodes.add(OP_SQRT); return this; } + public ChainBuilder swish() { opCodes.add(OP_SWISH); return this; } + public ChainBuilder silu() { opCodes.add(OP_SILU); return this; } + public ChainBuilder mish() { opCodes.add(OP_MISH); return this; } + + public ChainBuilder addOp(int opCode) { opCodes.add(opCode); return 
this; } + public ChainBuilder addOp(int opCode, INDArray secondary) { opCodes.add(opCode); secondaryInputs.add(secondary); return this; } + + public FusedElementwiseChain build() { + if (primaryInput == null) { + throw new IllegalStateException("Primary input is required"); + } + if (opCodes.isEmpty()) { + throw new IllegalStateException("At least one op is required"); + } + + // Build inputs array: primary first, then secondaries + INDArray[] inputs = new INDArray[1 + secondaryInputs.size()]; + inputs[0] = primaryInput; + for (int i = 0; i < secondaryInputs.size(); i++) { + inputs[i + 1] = secondaryInputs.get(i); + } + + int[] codes = opCodes.stream().mapToInt(Integer::intValue).toArray(); + return new FusedElementwiseChain(inputs, output, codes); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedGemmSwiglu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedGemmSwiglu.java new file mode 100644 index 000000000000..0f3349af36a3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedGemmSwiglu.java @@ -0,0 +1,85 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * Fused GEMM + SwiGLU activation. + *

+ * Computes the gated linear unit with SiLU (Swish) activation in a single + * fused kernel, commonly used in LLM feed-forward blocks: + *

+ *   output = (input @ wGate) * silu(input @ wUp)
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: input [M, K]
  • + *
  • 1: wGate [K, N] - gate projection weight
  • + *
  • 2: wUp [K, N] - up projection weight
  • + *
+ *

+ * Output: [M, N] + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class FusedGemmSwiglu extends DynamicCustomOp { + + /** + * SameDiff constructor. + */ + public FusedGemmSwiglu(SameDiff sameDiff, SDVariable input, SDVariable wGate, SDVariable wUp) { + super(null, sameDiff, new SDVariable[]{input, wGate, wUp}, false); + } + + /** + * INDArray constructor. + */ + public FusedGemmSwiglu(INDArray input, INDArray wGate, INDArray wUp, INDArray output) { + super(null, new INDArray[]{input, wGate, wUp}, + output != null ? new INDArray[]{output} : null); + } + + @Override + public String opName() { + return "fused_gemm_swiglu"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedNormQuantize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedNormQuantize.java new file mode 100644 index 000000000000..66be2dea58fd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FusedNormQuantize.java @@ -0,0 +1,163 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +/** + * Fused normalization and quantization. + *

+ * Applies normalization (RMSNorm or LayerNorm) followed by quantization + * (INT8 or FP8) in a single fused kernel: + *

+ *   normalized = norm(input, weight, bias)
+ *   quantized  = quantize(normalized)
+ *   scale      = compute_scale(normalized)
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: input [B, H]
  • + *
  • 1: weight [H] - normalization weight/gamma
  • + *
  • 2: bias [H] (optional) - normalization bias/beta
  • + *
+ *

+ * Integer arguments: + *

    + *
  • 0: normType (0=RMSNorm, 1=LayerNorm)
  • + *
  • 1: quantType (0=INT8, 1=FP8)
  • + *
+ *

+ * Float arguments: + *

    + *
  • 0: epsilon (default 1e-5)
  • + *
+ *

+ * Outputs: + *

+ * <ul>
+ *   <li>0: quantized [B, H] INT8</li>
+ *   <li>1: scale [B] FLOAT32</li>
+ * </ul>
+ * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class FusedNormQuantize extends DynamicCustomOp { + + public static final int NORM_RMSNORM = 0; + public static final int NORM_LAYERNORM = 1; + public static final int QUANT_INT8 = 0; + public static final int QUANT_FP8 = 1; + + @Getter private int normType = NORM_RMSNORM; + @Getter private int quantType = QUANT_INT8; + @Getter private float epsilon = 1e-5f; + + /** + * SameDiff constructor with required inputs. + */ + public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight) { + super(null, sameDiff, new SDVariable[]{input, weight}, false); + addIArgument((long) normType, (long) quantType); + addTArgument((double) epsilon); + } + + /** + * SameDiff constructor with bias. + */ + public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias) { + super(null, sameDiff, new SDVariable[]{input, weight, bias}, false); + addIArgument((long) normType, (long) quantType); + addTArgument((double) epsilon); + } + + /** + * SameDiff constructor with double epsilon. + */ + public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias, + int normType, int quantType, double epsilon) { + this(sameDiff, input, weight, bias, normType, quantType, (float) epsilon); + } + + /** + * Full SameDiff constructor with all options. + */ + public FusedNormQuantize(SameDiff sameDiff, SDVariable input, SDVariable weight, SDVariable bias, + int normType, int quantType, float epsilon) { + super(null, sameDiff, bias != null ? + new SDVariable[]{input, weight, bias} : + new SDVariable[]{input, weight}, false); + this.normType = normType; + this.quantType = quantType; + this.epsilon = epsilon; + addIArgument((long) normType, (long) quantType); + addTArgument((double) epsilon); + } + + /** + * INDArray constructor. 
+ */ + public FusedNormQuantize(INDArray input, INDArray weight, INDArray bias, + INDArray quantized, INDArray scale, + int normType, int quantType, float epsilon) { + super(null, bias != null ? + new INDArray[]{input, weight, bias} : + new INDArray[]{input, weight}, + quantized != null ? new INDArray[]{quantized, scale} : null); + this.normType = normType; + this.quantType = quantType; + this.epsilon = epsilon; + addIArgument((long) normType, (long) quantType); + addTArgument((double) epsilon); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.normType = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.quantType = iArguments.get(1).intValue(); + if (tArguments.size() > 0) this.epsilon = tArguments.get(0).floatValue(); + } + + @Override + public String opName() { + return "fused_norm_quantize"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Arrays.asList(DataType.INT8, DataType.FLOAT); + } + + @Override + public int getNumOutputs() { + return 2; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopKSample.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopKSample.java new file mode 100644 index 000000000000..12709266ded5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopKSample.java @@ -0,0 +1,167 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * GPU-accelerated Top-K sampling for LLM token generation. + *

+ * Performs top-K filtering, softmax, and multinomial sampling entirely on the GPU, + * eliminating the device-to-host transfer overhead for logits. Uses CUB radix sort + * for efficient top-K selection. + *

+ * Pipeline: logits -> temperature scaling -> top-K filter -> softmax -> multinomial sample + *

+ * Inputs: + *

    + *
  • 0: logits [batch, vocab_size]
  • + *
  • 1: random values [batch] (uniform [0,1) for multinomial, optional)
  • + *
+ *

+ * Outputs: + *

    + *
  • 0: sampled token IDs [batch] (INT64)
  • + *
  • 1: sampled probabilities [batch] (FLOAT, optional)
  • + *
+ *

+ * Integer arguments: + *

    + *
  • 0: k (number of top tokens to consider, default: 50)
  • + *
  • 1: seed (RNG seed, used if no random values input, default: 0)
  • + *
+ *

+ * Float arguments: + *

    + *
  • 0: temperature (default: 1.0)
  • + *
+ * + * @author Eclipse Deeplearning4j Contributors + * @see GpuTopPSample + * @see TokenSample + */ +@NoArgsConstructor +public class GpuTopKSample extends DynamicCustomOp { + + @Getter private int k = 50; + @Getter private long seed = 0; + @Getter private double temperature = 1.0; + + /** + * INDArray constructor with logits only (default k=50, temperature=1.0). + */ + public GpuTopKSample(INDArray logits, int k) { + this(logits, k, 1.0, 0); + } + + /** + * Full INDArray constructor. + * + * @param logits input logits [batch, vocab_size] + * @param k number of top tokens to consider + * @param temperature temperature for scaling + * @param seed RNG seed (0 for random) + */ + public GpuTopKSample(INDArray logits, int k, double temperature, long seed) { + super(new INDArray[]{logits}, null); + this.k = k; + this.temperature = temperature; + this.seed = seed; + addIArgument((long) k, seed); + addTArgument(temperature); + } + + /** + * INDArray constructor with explicit random values. + */ + public GpuTopKSample(INDArray logits, INDArray randomValues, int k, double temperature) { + super(new INDArray[]{logits, randomValues}, null); + this.k = k; + this.temperature = temperature; + addIArgument((long) k, seed); + addTArgument(temperature); + } + + /** + * SameDiff constructor with int seed and reordered parameters (k, seed, temperature). + */ + public GpuTopKSample(SameDiff sameDiff, SDVariable logits, int k, int seed, double temperature) { + this(sameDiff, logits, k, temperature, (long) seed); + } + + /** + * SameDiff constructor. + */ + public GpuTopKSample(SameDiff sameDiff, SDVariable logits, int k, double temperature, long seed) { + super(null, sameDiff, new SDVariable[]{logits}, false); + this.k = k; + this.temperature = temperature; + this.seed = seed; + addIArgument((long) k, seed); + addTArgument(temperature); + } + + /** + * SameDiff constructor with random values input. 
+ */ + public GpuTopKSample(SameDiff sameDiff, SDVariable logits, SDVariable randomValues, + int k, double temperature) { + super(null, sameDiff, new SDVariable[]{logits, randomValues}, false); + this.k = k; + this.temperature = temperature; + addIArgument((long) k, seed); + addTArgument(temperature); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.k = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.seed = iArguments.get(1); + if (tArguments.size() > 0) this.temperature = tArguments.get(0); + } + + @Override + public String opName() { + return "gpu_top_k_sample"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + // Output 0: token IDs (INT64), Output 1: probabilities (FLOAT) + return Arrays.asList(DataType.INT64, DataType.FLOAT); + } + + @Override + public int getNumOutputs() { + return 2; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopPSample.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopPSample.java new file mode 100644 index 000000000000..5c9506ec0c3b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GpuTopPSample.java @@ -0,0 +1,217 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +/** + * GPU-accelerated Top-P (nucleus) sampling for LLM token generation. + *

+ * Performs nucleus sampling entirely on the GPU: sorts tokens by probability, + * computes cumulative sums to find the smallest set of tokens whose cumulative + * probability exceeds p, then samples from that nucleus. + *

+ * Pipeline: logits -> temperature scaling -> softmax -> sort -> cumsum -> nucleus filter -> sample + *

+ * Supports additional penalty parameters for controlling repetition: + *

    + *
  • repetition_penalty: multiplicative penalty for previously generated tokens
  • + *
  • frequency_penalty: additive penalty proportional to token frequency
  • + *
  • presence_penalty: flat additive penalty for any previously seen token
  • + *
+ *

+ * Inputs: + *

    + *
  • 0: logits [batch, vocab_size]
  • + *
  • 1: random values [batch] (uniform [0,1), optional)
  • + *
+ *

+ * Outputs: + *

    + *
  • 0: sampled token IDs [batch] (INT64)
  • + *
  • 1: sampled probabilities [batch] (FLOAT, optional)
  • + *
+ *

+ * Integer arguments: + *

    + *
  • 0: seed (RNG seed, default: 0)
  • + *
+ *

+ * Float arguments: + *

    + *
  • 0: p (nucleus probability threshold, default: 0.9)
  • + *
  • 1: temperature (default: 1.0)
  • + *
  • 2: repetition_penalty (default: 1.0, no penalty)
  • + *
  • 3: frequency_penalty (default: 0.0)
  • + *
  • 4: presence_penalty (default: 0.0)
  • + *
+ * + * @author Eclipse Deeplearning4j Contributors + * @see GpuTopKSample + * @see TokenSample + */ +@NoArgsConstructor +public class GpuTopPSample extends DynamicCustomOp { + + @Getter private double p = 0.9; + @Getter private double temperature = 1.0; + @Getter private double repetitionPenalty = 1.0; + @Getter private double frequencyPenalty = 0.0; + @Getter private double presencePenalty = 0.0; + @Getter private long seed = 0; + + /** + * INDArray constructor with logits only (default p=0.9, temperature=1.0). + */ + public GpuTopPSample(INDArray logits, double p) { + this(logits, p, 1.0, 0); + } + + /** + * INDArray constructor with core parameters. + * + * @param logits input logits [batch, vocab_size] + * @param p nucleus probability threshold (0.0 to 1.0) + * @param temperature temperature for scaling + * @param seed RNG seed (0 for random) + */ + public GpuTopPSample(INDArray logits, double p, double temperature, long seed) { + super(new INDArray[]{logits}, null); + this.p = p; + this.temperature = temperature; + this.seed = seed; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * Full INDArray constructor with all penalty parameters. + */ + public GpuTopPSample(INDArray logits, double p, double temperature, long seed, + double repetitionPenalty, double frequencyPenalty, double presencePenalty) { + super(new INDArray[]{logits}, null); + this.p = p; + this.temperature = temperature; + this.seed = seed; + this.repetitionPenalty = repetitionPenalty; + this.frequencyPenalty = frequencyPenalty; + this.presencePenalty = presencePenalty; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * INDArray constructor with explicit random values. 
+ */ + public GpuTopPSample(INDArray logits, INDArray randomValues, double p, double temperature) { + super(new INDArray[]{logits, randomValues}, null); + this.p = p; + this.temperature = temperature; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * SameDiff constructor. + */ + public GpuTopPSample(SameDiff sameDiff, SDVariable logits, double p, double temperature, long seed) { + super(null, sameDiff, new SDVariable[]{logits}, false); + this.p = p; + this.temperature = temperature; + this.seed = seed; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * SameDiff constructor with random values input. + */ + public GpuTopPSample(SameDiff sameDiff, SDVariable logits, SDVariable randomValues, + double p, double temperature) { + super(null, sameDiff, new SDVariable[]{logits, randomValues}, false); + this.p = p; + this.temperature = temperature; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * SameDiff constructor with int seed and reordered parameters (seed, p, temperature, ...). + */ + public GpuTopPSample(SameDiff sameDiff, SDVariable logits, int seed, double p, + double temperature, double repetitionPenalty, double frequencyPenalty, + double presencePenalty) { + this(sameDiff, logits, p, temperature, (long) seed, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + /** + * Full SameDiff constructor with all parameters. 
+ */ + public GpuTopPSample(SameDiff sameDiff, SDVariable logits, double p, double temperature, + long seed, double repetitionPenalty, double frequencyPenalty, + double presencePenalty) { + super(null, sameDiff, new SDVariable[]{logits}, false); + this.p = p; + this.temperature = temperature; + this.seed = seed; + this.repetitionPenalty = repetitionPenalty; + this.frequencyPenalty = frequencyPenalty; + this.presencePenalty = presencePenalty; + addIArgument(seed); + addTArgument(p, temperature, repetitionPenalty, frequencyPenalty, presencePenalty); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.seed = iArguments.get(0); + if (tArguments.size() > 0) this.p = tArguments.get(0); + if (tArguments.size() > 1) this.temperature = tArguments.get(1); + if (tArguments.size() > 2) this.repetitionPenalty = tArguments.get(2); + if (tArguments.size() > 3) this.frequencyPenalty = tArguments.get(3); + if (tArguments.size() > 4) this.presencePenalty = tArguments.get(4); + } + + @Override + public String opName() { + return "gpu_top_p_sample"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + // Output 0: token IDs (INT64), Output 1: probabilities (FLOAT) + return Arrays.asList(DataType.INT64, DataType.FLOAT); + } + + @Override + public int getNumOutputs() { + return 2; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MoeGate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MoeGate.java new file mode 100644 index 000000000000..c614ca842619 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MoeGate.java @@ -0,0 +1,146 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying 
materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +/** + * Mixture-of-Experts (MoE) gating operation. + *

+ * Computes top-K expert routing with load-balancing auxiliary loss: + *

+ *   logits         = hiddenStates @ gateWeight
+ *   expertIndices  = topK(softmax(logits))
+ *   gateWeights    = normalized weights for selected experts
+ *   auxLoss        = load balancing loss
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: hiddenStates [T, H] - token hidden states
  • + *
  • 1: gateWeight [H, E] - gating projection (H=hidden, E=numExperts)
  • + *
+ *

+ * Integer arguments: + *

    + *
  • 0: topK (default 2) - number of experts per token
  • + *
  • 1: numExperts (default 8) - total number of experts
  • + *
+ *

+ * Float arguments: + *

    + *
  • 0: auxLossCoeff (default 0.01) - load balancing loss coefficient
  • + *
+ *

+ * Outputs: + *

    + *
  • 0: expertIndices [T, K] INT64 - selected expert indices
  • + *
  • 1: gateWeights [T, K] - routing weights for selected experts
  • + *
  • 2: auxLoss [1] - auxiliary load balancing loss scalar
  • + *
+ * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class MoeGate extends DynamicCustomOp { + + @Getter private int topK = 2; + @Getter private int numExperts = 8; + @Getter private float auxLossCoeff = 0.01f; + + /** + * SameDiff constructor with default options. + */ + public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight) { + super(null, sameDiff, new SDVariable[]{hiddenStates, gateWeight}, false); + addIArgument((long) topK, (long) numExperts); + addTArgument((double) auxLossCoeff); + } + + /** + * SameDiff constructor with double auxLossCoeff. + */ + public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight, + int topK, int numExperts, double auxLossCoeff) { + this(sameDiff, hiddenStates, gateWeight, topK, numExperts, (float) auxLossCoeff); + } + + /** + * Full SameDiff constructor with all options. + */ + public MoeGate(SameDiff sameDiff, SDVariable hiddenStates, SDVariable gateWeight, + int topK, int numExperts, float auxLossCoeff) { + super(null, sameDiff, new SDVariable[]{hiddenStates, gateWeight}, false); + this.topK = topK; + this.numExperts = numExperts; + this.auxLossCoeff = auxLossCoeff; + addIArgument((long) topK, (long) numExperts); + addTArgument((double) auxLossCoeff); + } + + /** + * INDArray constructor. + */ + public MoeGate(INDArray hiddenStates, INDArray gateWeight, + INDArray expertIndices, INDArray gateWeights, INDArray auxLoss, + int topK, int numExperts, float auxLossCoeff) { + super(null, new INDArray[]{hiddenStates, gateWeight}, + expertIndices != null ? 
new INDArray[]{expertIndices, gateWeights, auxLoss} : null); + this.topK = topK; + this.numExperts = numExperts; + this.auxLossCoeff = auxLossCoeff; + addIArgument((long) topK, (long) numExperts); + addTArgument((double) auxLossCoeff); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.topK = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.numExperts = iArguments.get(1).intValue(); + if (tArguments.size() > 0) this.auxLossCoeff = tArguments.get(0).floatValue(); + } + + @Override + public String opName() { + return "moe_gate"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + DataType dt = inputDataTypes.get(0); + return Arrays.asList(DataType.INT64, dt, dt); + } + + @Override + public int getNumOutputs() { + return 3; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiLoraMatmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiLoraMatmul.java new file mode 100644 index 000000000000..99f742953d99 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiLoraMatmul.java @@ -0,0 +1,125 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * Multi-LoRA matrix multiplication for batched adapter inference. + *

+ * Applies per-sample LoRA adapters during batched inference: + *

+ *   output[i] = input[i] @ baseWeight + alpha * input[i] @ loraA[adapterIds[i]] @ loraB[adapterIds[i]]
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: input [B, I] - input activations
  • + *
  • 1: baseWeight [I, O] - base model weight
  • + *
  • 2: loraA [A, I, R] - LoRA down-projection weights (A adapters)
  • + *
  • 3: loraB [A, R, O] - LoRA up-projection weights (A adapters)
  • + *
  • 4: adapterIds [B] INT64 - per-sample adapter index
  • + *
+ *

+ * Float arguments: + *

    + *
  • 0: alpha (default 1.0) - LoRA scaling factor
  • + *
+ *

+ * Output: [B, O] + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class MultiLoraMatmul extends DynamicCustomOp { + + @Getter private float alpha = 1.0f; + + /** + * SameDiff constructor with default alpha. + */ + public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight, + SDVariable loraA, SDVariable loraB, SDVariable adapterIds) { + super(null, sameDiff, new SDVariable[]{input, baseWeight, loraA, loraB, adapterIds}, false); + addTArgument((double) alpha); + } + + /** + * SameDiff constructor with double alpha. + */ + public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight, + SDVariable loraA, SDVariable loraB, SDVariable adapterIds, + double alpha) { + this(sameDiff, input, baseWeight, loraA, loraB, adapterIds, (float) alpha); + } + + /** + * Full SameDiff constructor with alpha. + */ + public MultiLoraMatmul(SameDiff sameDiff, SDVariable input, SDVariable baseWeight, + SDVariable loraA, SDVariable loraB, SDVariable adapterIds, + float alpha) { + super(null, sameDiff, new SDVariable[]{input, baseWeight, loraA, loraB, adapterIds}, false); + this.alpha = alpha; + addTArgument((double) alpha); + } + + /** + * INDArray constructor. + */ + public MultiLoraMatmul(INDArray input, INDArray baseWeight, INDArray loraA, INDArray loraB, + INDArray adapterIds, INDArray output, float alpha) { + super(null, new INDArray[]{input, baseWeight, loraA, loraB, adapterIds}, + output != null ? 
new INDArray[]{output} : null); + this.alpha = alpha; + addTArgument((double) alpha); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (tArguments.size() > 0) this.alpha = tArguments.get(0).floatValue(); + } + + @Override + public String opName() { + return "multi_lora_matmul"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RowParallelLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RowParallelLinear.java new file mode 100644 index 000000000000..1023165c4571 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RowParallelLinear.java @@ -0,0 +1,153 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * Row-parallel linear layer for tensor parallelism. + *

+ * Splits the weight matrix along rows across tensor-parallel ranks. + * Each rank computes a partial result, optionally reducing across ranks: + *

+ *   partial = inputShard @ weightShard
+ *   output = allReduce(partial) + bias if reduceOutput else partial
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: inputShard [B, I/tp] - row-sharded input for this rank
  • + *
  • 1: weightShard [I/tp, O] - row-sharded weight for this rank
  • + *
  • 2: bias [O] (optional) - full bias (only added after reduce)
  • + *
+ *

+ * Integer arguments: + *

    + *
  • 0: tpSize (default 1) - tensor parallel world size
  • + *
  • 1: tpRank (default 0) - this rank's index
  • + *
  • 2: reduceOutput (0=no, 1=yes, default 1) - whether to all-reduce output
  • + *
+ *

+ * Output: [B, O] + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class RowParallelLinear extends DynamicCustomOp { + + @Getter private int tpSize = 1; + @Getter private int tpRank = 0; + @Getter private int reduceOutput = 1; + + /** + * SameDiff constructor with required inputs. + */ + public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard) { + super(null, sameDiff, new SDVariable[]{inputShard, weightShard}, false); + addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput); + } + + /** + * SameDiff constructor with bias. + */ + public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard, + SDVariable bias) { + super(null, sameDiff, new SDVariable[]{inputShard, weightShard, bias}, false); + addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput); + } + + /** + * SameDiff constructor with boolean reduceOutput. + */ + public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard, + SDVariable bias, int tpSize, int tpRank, boolean reduceOutput) { + this(sameDiff, inputShard, weightShard, bias, tpSize, tpRank, reduceOutput ? 1 : 0); + } + + /** + * Full SameDiff constructor with all options. + */ + public RowParallelLinear(SameDiff sameDiff, SDVariable inputShard, SDVariable weightShard, + SDVariable bias, int tpSize, int tpRank, int reduceOutput) { + super(null, sameDiff, bias != null ? + new SDVariable[]{inputShard, weightShard, bias} : + new SDVariable[]{inputShard, weightShard}, false); + this.tpSize = tpSize; + this.tpRank = tpRank; + this.reduceOutput = reduceOutput; + addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput); + } + + /** + * INDArray constructor (no output pre-allocation). + */ + public RowParallelLinear(INDArray inputShard, INDArray weightShard, INDArray bias, + int tpSize, int tpRank, boolean reduceOutput) { + this(inputShard, weightShard, bias, null, tpSize, tpRank, reduceOutput ? 
1 : 0); + } + + /** + * INDArray constructor. + */ + public RowParallelLinear(INDArray inputShard, INDArray weightShard, INDArray bias, + INDArray output, int tpSize, int tpRank, int reduceOutput) { + super(null, bias != null ? + new INDArray[]{inputShard, weightShard, bias} : + new INDArray[]{inputShard, weightShard}, + output != null ? new INDArray[]{output} : null); + this.tpSize = tpSize; + this.tpRank = tpRank; + this.reduceOutput = reduceOutput; + addIArgument((long) tpSize, (long) tpRank, (long) reduceOutput); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.tpSize = iArguments.get(0).intValue(); + if (iArguments.size() > 1) this.tpRank = iArguments.get(1).intValue(); + if (iArguments.size() > 2) this.reduceOutput = iArguments.get(2).intValue(); + } + + @Override + public String opName() { + return "row_parallel_linear"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SelectiveScan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SelectiveScan.java new file mode 100644 index 000000000000..7eb9ef50fcd8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SelectiveScan.java @@ -0,0 +1,108 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * Selective scan (S6/Mamba) operation. + *

+ * Implements the selective state space model scan used in Mamba architectures: + *

+ *   h_t = A_t * h_{t-1} + B_t * x_t
+ *   y_t = C_t * h_t + D * x_t
+ * 
+ *

+ * Inputs: + *

    + *
  • 0: x [B, L, D] - input sequence
  • + *
  • 1: A [B, L, S] - state transition (discretized)
  • + *
  • 2: B [B, L, S] - input projection
  • + *
  • 3: C [B, L, S] - output projection
  • + *
  • 4: D [D] - skip connection weight
  • + *
  • 5: h0 (optional) - initial hidden state [B, D, S]
  • + *
+ *

+ * Output: [B, L, D] + * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class SelectiveScan extends DynamicCustomOp { + + /** + * SameDiff constructor with required inputs. + */ + public SelectiveScan(SameDiff sameDiff, SDVariable x, SDVariable a, SDVariable b, + SDVariable c, SDVariable d) { + super(null, sameDiff, new SDVariable[]{x, a, b, c, d}, false); + } + + /** + * SameDiff constructor with initial hidden state. + */ + public SelectiveScan(SameDiff sameDiff, SDVariable x, SDVariable a, SDVariable b, + SDVariable c, SDVariable d, SDVariable h0) { + super(null, sameDiff, new SDVariable[]{x, a, b, c, d, h0}, false); + } + + /** + * INDArray constructor without output pre-allocation. + */ + public SelectiveScan(INDArray x, INDArray a, INDArray b, INDArray c, INDArray d, + INDArray h0) { + this(x, a, b, c, d, h0, null); + } + + /** + * INDArray constructor. + */ + public SelectiveScan(INDArray x, INDArray a, INDArray b, INDArray c, INDArray d, + INDArray h0, INDArray output) { + super(null, h0 != null ? + new INDArray[]{x, a, b, c, d, h0} : + new INDArray[]{x, a, b, c, d}, + output != null ? 
new INDArray[]{output} : null); + } + + @Override + public String opName() { + return "selective_scan"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SmoothQuant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SmoothQuant.java new file mode 100644 index 000000000000..0b1daeafa846 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SmoothQuant.java @@ -0,0 +1,150 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +/** + * SmoothQuant W8A8 Quantized Matrix Multiplication. + *

+ * Implements the SmoothQuant algorithm which migrates quantization difficulty + * from activations to weights by applying a per-channel smoothing factor: + *

+ *   Y = deq( quant(X * diag(s)^{-1}) @ quant(diag(s) * W) )
+ * 
+ * where {@code s} is a per-channel smoothing factor that balances the dynamic + * ranges of activations and weights. + *

+ * Inputs: + *

    + *
  • 0: X (FLOAT) - input activations [batch, in_features]
  • + *
  • 1: W_quantized (INT8) - pre-quantized smoothed weights [out_features, in_features]
  • + *
  • 2: smooth_scale (FLOAT) - per-channel smoothing factors [in_features]
  • + *
  • 3: act_scale (FLOAT) - activation quantization scale (scalar or per-channel)
  • + *
  • 4: weight_scale (FLOAT) - weight quantization scale (per-channel) [out_features]
  • + *
  • 5: bias (optional) [out_features]
  • + *
+ *

+ * Output: Y (FLOAT) - dequantized output [batch, out_features] + *

+ * Integer arguments: + *

    + *
  • 0: transpose_weight (0=no, 1=yes, default: 0)
  • + *
+ * + * @author Eclipse Deeplearning4j Contributors + */ +@NoArgsConstructor +public class SmoothQuant extends DynamicCustomOp { + + @Getter private boolean transposeWeight = false; + + /** + * SameDiff constructor with required inputs (no bias). + * + * @param sameDiff the SameDiff instance + * @param x input activations [batch, in_features] + * @param wQuantized pre-quantized smoothed weights [out_features, in_features] + * @param smoothScale per-channel smoothing factors [in_features] + * @param actScale activation quantization scale + * @param weightScale weight quantization scale [out_features] + */ + public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized, + SDVariable smoothScale, SDVariable actScale, SDVariable weightScale) { + super(null, sameDiff, new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale}, false); + addIArgument(transposeWeight ? 1L : 0L); + } + + /** + * SameDiff constructor with bias. + */ + public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized, + SDVariable smoothScale, SDVariable actScale, SDVariable weightScale, + SDVariable bias) { + super(null, sameDiff, + new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale, bias}, false); + addIArgument(transposeWeight ? 1L : 0L); + } + + /** + * SameDiff constructor with all options. + */ + public SmoothQuant(SameDiff sameDiff, SDVariable x, SDVariable wQuantized, + SDVariable smoothScale, SDVariable actScale, SDVariable weightScale, + SDVariable bias, boolean transposeWeight) { + super(null, sameDiff, + bias != null ? + new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale, bias} : + new SDVariable[]{x, wQuantized, smoothScale, actScale, weightScale}, + false); + this.transposeWeight = transposeWeight; + addIArgument(transposeWeight ? 1L : 0L); + } + + /** + * INDArray constructor. 
+ */ + public SmoothQuant(INDArray x, INDArray wQuantized, INDArray smoothScale, + INDArray actScale, INDArray weightScale, INDArray bias, + INDArray output, boolean transposeWeight) { + super(null, + bias != null ? + new INDArray[]{x, wQuantized, smoothScale, actScale, weightScale, bias} : + new INDArray[]{x, wQuantized, smoothScale, actScale, weightScale}, + output != null ? new INDArray[]{output} : null); + this.transposeWeight = transposeWeight; + addIArgument(transposeWeight ? 1L : 0L); + } + + @Override + public void configureFromArguments() { + super.configureFromArguments(); + if (iArguments.size() > 0) this.transposeWeight = iArguments.get(0) != 0; + } + + @Override + public String opName() { + return "smooth_quant"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 5, + "Expected at least 5 input data types for smooth_quant, got %s", inputDataTypes); + // Output is always FLOAT (dequantized result) + return Collections.singletonList(DataType.FLOAT); + } + + @Override + public int getNumOutputs() { + return 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index 14dbb96adb29..277d2cebc505 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -80,7 +80,18 @@ public INDArray any(INDArray x, long... dimensions) { public INDArray argmax(INDArray in, boolean keepDims, long... dimensions) { NDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -102,7 +113,18 @@ public INDArray argmax(INDArray in, boolean keepDims, long... dimensions) { public INDArray argmax(INDArray in, long... dimensions) { NDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -128,7 +150,18 @@ public INDArray argmax(INDArray in, long... dimensions) { public INDArray argmin(INDArray in, boolean keepDims, long... dimensions) { NDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -153,7 +186,18 @@ public INDArray argmin(INDArray in, boolean keepDims, long... dimensions) { public INDArray argmin(INDArray in, long... dimensions) { NDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -165,7 +209,18 @@ public INDArray argmin(INDArray in, long... 
dimensions) { * @return output The newly assigned output (NUMERIC type) */ public INDArray assign(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Assign(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Assign(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -220,6 +275,89 @@ public INDArray[] batchMmul(INDArray alphas, INDArray betas, INDArray[] inputsA, return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(alphas, betas, inputsA, inputsB, false, false)); } + /** + * Boolean AND operation: elementwise x && y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean AND result (BOOL type) + */ + public INDArray booleanAnd(INDArray x, INDArray y) { + NDValidation.validateBool("booleanAnd", "x", x); + NDValidation.validateBool("booleanAnd", "y", y); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanAnd(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Boolean NOT operation: elementwise !x
+ * + * @param x Input boolean array (BOOL type) + * @return output Boolean NOT result (BOOL type) + */ + public INDArray booleanNot(INDArray x) { + NDValidation.validateBool("booleanNot", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot(x)); + } + + /** + * Boolean OR operation: elementwise x || y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean OR result (BOOL type) + */ + public INDArray booleanOr(INDArray x, INDArray y) { + NDValidation.validateBool("booleanOr", "x", x); + NDValidation.validateBool("booleanOr", "y", y); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanOr(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Boolean XOR operation: elementwise x ^ y. Supports broadcasting.
+ * + * @param x First input boolean array (BOOL type) + * @param y Second input boolean array (BOOL type) + * @return output Boolean XOR result (BOOL type) + */ + public INDArray booleanXor(INDArray x, INDArray y) { + NDValidation.validateBool("booleanXor", "x", x); + NDValidation.validateBool("booleanXor", "y", y); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BooleanXor(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + /** * Cast the array to a new datatype - for example, Integer -> Float
* @@ -228,7 +366,18 @@ public INDArray[] batchMmul(INDArray alphas, INDArray betas, INDArray[] inputsA, * @return output Output array (after casting) (NDARRAY type) */ public INDArray castTo(INDArray arg, DataType datatype) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -241,7 +390,18 @@ public INDArray castTo(INDArray arg, DataType datatype) { */ public INDArray clipByNorm(INDArray x, double clipValue) { NDValidation.validateNumerical("clipByNorm", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -257,7 +417,18 @@ public INDArray clipByNorm(INDArray x, INDArray clipValue, INDArray dimensions) NDValidation.validateNumerical("clipByNorm", "x", x); NDValidation.validateNumerical("clipByNorm", "clipValue", clipValue); NDValidation.validateNumerical("clipByNorm", "dimensions", dimensions); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -270,7 +441,18 @@ public INDArray clipByNorm(INDArray x, INDArray clipValue, INDArray dimensions) */ 
public INDArray clipByValue(INDArray x, double clipValueMin, double clipValueMax) { NDValidation.validateNumerical("clipByValue", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -285,7 +467,18 @@ public INDArray clipByValue(INDArray x, INDArray clipValueMin, INDArray clipValu NDValidation.validateNumerical("clipByValue", "x", x); NDValidation.validateNumerical("clipByValue", "clipValueMin", clipValueMin); NDValidation.validateNumerical("clipByValue", "clipValueMax", clipValueMax); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -304,7 +497,18 @@ public INDArray concat(int dimension, INDArray... inputs) { NDValidation.validateNumerical("concat", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. 
Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -318,7 +522,18 @@ public INDArray concat(int dimension, INDArray... inputs) { */ public INDArray create(INDArray shape, DataType dataType, String order, boolean initialize) { NDValidation.validateNumerical("create", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, order, initialize))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, order, initialize)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -330,7 +545,18 @@ public INDArray create(INDArray shape, DataType dataType, String order, boolean */ public INDArray create(INDArray shape, DataType dataType) { NDValidation.validateNumerical("create", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, "c", false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Create(shape, dataType, "c", false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -342,7 +568,18 @@ public INDArray create(INDArray shape, DataType dataType) { */ public INDArray createView(INDArray input, INDArray... 
indices) { Preconditions.checkArgument(indices.length >= 0, "indices has incorrect size/length. Expected: indices.length >= 0, got %s", indices.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.CreateView(input, indices))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.CreateView(input, indices)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -362,7 +599,18 @@ public INDArray createView(INDArray input, INDArray... indices) { public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, long... axis) { NDValidation.validateNumerical("cumprod", "in", in); Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -380,7 +628,18 @@ public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, long... public INDArray cumprod(INDArray in, long... axis) { NDValidation.validateNumerical("cumprod", "in", in); Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. 
Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -400,7 +659,18 @@ public INDArray cumprod(INDArray in, long... axis) { public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, long... axis) { NDValidation.validateNumerical("cumsum", "in", in); Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -418,7 +688,18 @@ public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, long... public INDArray cumsum(INDArray in, long... axis) { NDValidation.validateNumerical("cumsum", "in", in); Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. 
Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -470,7 +751,18 @@ public INDArray dynamicStitch(INDArray[] indices, INDArray... x) { Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); NDValidation.validateNumerical("dynamicStitch", "x", x); Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -502,7 +794,18 @@ public INDArray eq(INDArray x, double y) { * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) */ public INDArray eq(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -517,7 +820,18 @@ public INDArray eq(INDArray x, INDArray y) { * @return output Output variable (NUMERIC type) */ public INDArray 
expandDims(INDArray x, int axis) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -530,7 +844,18 @@ public INDArray expandDims(INDArray x, int axis) { */ public INDArray fill(INDArray shape, DataType dataType, double value) { NDValidation.validateInteger("fill", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -542,7 +867,18 @@ public INDArray fill(INDArray shape, DataType dataType, double value) { */ public INDArray flatten(INDArray[] inputs, String order) { Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, order))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, order)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -553,7 +889,18 @@ public INDArray flatten(INDArray[] inputs, String order) { */ public INDArray flatten(INDArray... inputs) { Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. 
Expected: inputs.length >= 1, got %s", inputs.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, "c"))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Flatten(inputs, "c")); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -567,7 +914,18 @@ public INDArray flatten(INDArray... inputs) { */ public INDArray gather(INDArray df, int[] indices, int axis) { Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -581,7 +939,18 @@ public INDArray gather(INDArray df, int[] indices, int axis) { */ public INDArray gather(INDArray df, INDArray indices, int axis) { NDValidation.validateInteger("gather", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -593,7 +962,18 @@ public INDArray gather(INDArray df, INDArray indices, int axis) { */ public INDArray gatherNd(INDArray df, INDArray indices) { NDValidation.validateNumerical("gatherNd", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; + INDArray[] __tmp = Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -624,7 +1004,18 @@ public INDArray gt(INDArray x, double y) { * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) */ public INDArray gt(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -655,7 +1046,18 @@ public INDArray gte(INDArray x, double y) { * @return output (NDARRAY type) */ public INDArray gte(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -665,7 +1067,18 @@ public INDArray gte(INDArray x, INDArray y) { * @return output Output variable (NDARRAY type) */ public INDArray identity(INDArray input) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -678,7 +1091,18 @@ public INDArray identity(INDArray input) { */ public 
INDArray invertPermutation(INDArray input) { NDValidation.validateInteger("invertPermutation", "input", input); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -689,7 +1113,18 @@ public INDArray invertPermutation(INDArray input) { */ public INDArray isNumericTensor(INDArray x) { NDValidation.validateNumerical("isNumericTensor", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -703,7 +1138,18 @@ public INDArray isNumericTensor(INDArray x) { * @return output INDArray with linearly spaced elements (NUMERIC type) */ public INDArray linspace(DataType dataType, double start, double stop, long number) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -720,7 +1166,18 @@ public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataTyp NDValidation.validateNumerical("linspace", "start", start); NDValidation.validateNumerical("linspace", "stop", stop); NDValidation.validateInteger("linspace", "number", number); - return Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -751,7 +1208,18 @@ public INDArray lt(INDArray x, double y) { * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NDARRAY type) */ public INDArray lt(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -784,7 +1252,18 @@ public INDArray lte(INDArray x, double y) { public INDArray lte(INDArray x, INDArray y) { NDValidation.validateNumerical("lte", "x", x); NDValidation.validateNumerical("lte", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -910,7 +1389,18 @@ public INDArray max(INDArray x, long... 
dimensions) { public INDArray max(INDArray first, INDArray second) { NDValidation.validateNumerical("max", "first", first); NDValidation.validateNumerical("max", "second", second); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1005,7 +1495,18 @@ public INDArray mean(INDArray x, INDArray dimensions) { * @return output Output (NDARRAY type) */ public INDArray merge(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1063,7 +1564,18 @@ public INDArray min(INDArray x, long... 
dimensions) { public INDArray min(INDArray first, INDArray second) { NDValidation.validateNumerical("min", "first", first); NDValidation.validateNumerical("min", "second", second); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1074,7 +1586,18 @@ public INDArray min(INDArray first, INDArray second) { * @return output Output array (after casting) (NDARRAY type) */ public INDArray minMax(int datatype, int minOrMax) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.MinMaxDataType(datatype, minOrMax))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.MinMaxDataType(datatype, minOrMax)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1092,7 +1615,18 @@ public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transpo boolean transposeZ) { NDValidation.validateNumerical("mmul", "x", x); NDValidation.validateNumerical("mmul", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1106,7 +1640,18 @@ public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transpo public INDArray mmul(INDArray x, INDArray y) { NDValidation.validateNumerical("mmul", "x", x); 
NDValidation.validateNumerical("mmul", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1137,7 +1682,18 @@ public INDArray neq(INDArray x, double y) { * @return output Boolean array out, with values true/false based on where the condition is satisfied (NDARRAY type) */ public INDArray neq(INDArray x, INDArray y) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1287,7 +1843,18 @@ public INDArray normmax(INDArray x, long... 
dimensions) { public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1304,7 +1871,18 @@ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double */ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1319,7 +1897,18 @@ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double */ public INDArray oneHot(INDArray indices, int depth) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1330,7 +1919,18 @@ public INDArray oneHot(INDArray indices, int depth) { * @return output A new INDArray with the same (dynamic) shape 
as the input (NUMERIC type) */ public INDArray onesLike(INDArray input) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1341,7 +1941,18 @@ public INDArray onesLike(INDArray input) { * @return output (NUMERIC type) */ public INDArray onesLike(INDArray input, DataType dataType) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1354,7 +1965,18 @@ public INDArray onesLike(INDArray input, DataType dataType) { */ public INDArray permute(INDArray x, INDArray dimensions) { NDValidation.validateInteger("permute", "dimensions", dimensions); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1367,7 +1989,18 @@ public INDArray permute(INDArray x, INDArray dimensions) { */ public INDArray permute(INDArray x, long... dimensions) { Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1464,7 +2097,18 @@ public INDArray prod(INDArray x, INDArray dimensions) { * @return output INDArray with the specified values (NUMERIC type) */ public INDArray range(double from, double to, double step, DataType dataType) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1482,7 +2126,18 @@ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataTy NDValidation.validateNumerical("range", "from", from); NDValidation.validateNumerical("range", "to", to); NDValidation.validateNumerical("range", "step", step); - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1492,7 +2147,18 @@ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataTy * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) */ public INDArray rank(INDArray in) { - return Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.shape.Rank(in))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Rank(in)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1506,7 +2172,18 @@ public INDArray rank(INDArray in) { public INDArray repeat(INDArray input, INDArray repeats, int axis) { NDValidation.validateNumerical("repeat", "input", input); NDValidation.validateNumerical("repeat", "repeats", repeats); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Repeat(input, repeats, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Repeat(input, repeats, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1551,7 +2228,18 @@ public INDArray replaceWhere(INDArray update, double value, Condition condition) */ public INDArray reshape(INDArray x, INDArray shape) { NDValidation.validateNumerical("reshape", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1565,7 +2253,18 @@ public INDArray reshape(INDArray x, INDArray shape) { */ public INDArray reshape(INDArray x, long... shape) { Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. 
Expected: shape.length >= 0, got %s", shape.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1587,7 +2286,18 @@ public INDArray reshape(INDArray x, long... shape) { */ public INDArray reverse(INDArray x, long... dimensions) { Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1601,7 +2311,18 @@ public INDArray reverse(INDArray x, long... 
dimensions) { */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1613,7 +2334,18 @@ public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, in */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths) { NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1682,7 +2414,18 @@ public INDArray scatterAdd(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterAdd", "ref", ref); NDValidation.validateNumerical("scatterAdd", "indices", indices); NDValidation.validateNumerical("scatterAdd", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1702,7 
+2445,18 @@ public INDArray scatterDiv(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterDiv", "ref", ref); NDValidation.validateNumerical("scatterDiv", "indices", indices); NDValidation.validateNumerical("scatterDiv", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1724,7 +2478,18 @@ public INDArray scatterMax(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterMax", "ref", ref); NDValidation.validateNumerical("scatterMax", "indices", indices); NDValidation.validateNumerical("scatterMax", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1746,7 +2511,18 @@ public INDArray scatterMin(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterMin", "ref", ref); NDValidation.validateNumerical("scatterMin", "indices", indices); NDValidation.validateNumerical("scatterMin", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + 
__tmp[__i].close(); + } + } + } + } } /** @@ -1768,7 +2544,18 @@ public INDArray scatterMul(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterMul", "ref", ref); NDValidation.validateNumerical("scatterMul", "indices", indices); NDValidation.validateNumerical("scatterMul", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1791,7 +2578,18 @@ public INDArray scatterNdAdd(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterNdAdd", "ref", ref); NDValidation.validateNumerical("scatterNdAdd", "indices", indices); NDValidation.validateNumerical("scatterNdAdd", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdAdd(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdAdd(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1814,7 +2612,18 @@ public INDArray scatterNdSub(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterNdSub", "ref", ref); NDValidation.validateNumerical("scatterNdSub", "indices", indices); NDValidation.validateNumerical("scatterNdSub", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdSub(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdSub(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int 
__i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1837,7 +2646,18 @@ public INDArray scatterNdUpdate(INDArray ref, INDArray indices, INDArray updates NDValidation.validateNumerical("scatterNdUpdate", "ref", ref); NDValidation.validateNumerical("scatterNdUpdate", "indices", indices); NDValidation.validateNumerical("scatterNdUpdate", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdUpdate(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterNdUpdate(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1859,7 +2679,18 @@ public INDArray scatterSub(INDArray ref, INDArray indices, INDArray updates) { NDValidation.validateNumerical("scatterSub", "ref", ref); NDValidation.validateNumerical("scatterSub", "indices", indices); NDValidation.validateNumerical("scatterSub", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1881,7 +2712,18 @@ public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates) NDValidation.validateNumerical("scatterUpdate", "ref", ref); NDValidation.validateNumerical("scatterUpdate", "indices", indices); NDValidation.validateNumerical("scatterUpdate", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, 
indices, updates)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1900,7 +2742,18 @@ public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates) */ public INDArray segmentMax(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1919,7 +2772,18 @@ public INDArray segmentMax(INDArray data, INDArray segmentIds) { */ public INDArray segmentMean(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1938,7 +2802,18 @@ public INDArray segmentMean(INDArray data, INDArray segmentIds) { */ public INDArray segmentMin(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds)); + try { + return __tmp[0]; + } 
finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1957,7 +2832,18 @@ public INDArray segmentMin(INDArray data, INDArray segmentIds) { */ public INDArray segmentProd(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1976,7 +2862,18 @@ public INDArray segmentProd(INDArray data, INDArray segmentIds) { */ public INDArray segmentSum(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -1990,7 +2887,18 @@ public INDArray segmentSum(INDArray data, INDArray segmentIds) { */ public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; 
__i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2005,7 +2913,18 @@ public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); NDValidation.validateInteger("sequenceMask", "maxLen", maxLen); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2017,7 +2936,18 @@ public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataTyp */ public INDArray sequenceMask(INDArray lengths, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2038,7 +2968,18 @@ public INDArray[] setShape(INDArray input, INDArray shape) { * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) */ public INDArray shape(INDArray input) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2048,7 
+2989,18 @@ public INDArray shape(INDArray input) { * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) */ public INDArray size(INDArray in) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2060,7 +3012,18 @@ public INDArray size(INDArray in) { * @return output Scalar INDArray for size at specified variable (NUMERIC type) */ public INDArray sizeAt(INDArray in, int dimension) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2081,7 +3044,18 @@ public INDArray sizeAt(INDArray in, int dimension) { public INDArray slice(INDArray input, int[] begin, int... size) { Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2102,7 +3076,18 @@ public INDArray slice(INDArray input, int[] begin, int... 
size) { public INDArray slice(INDArray input, INDArray begin, INDArray size) { NDValidation.validateInteger("slice", "begin", begin); NDValidation.validateInteger("slice", "size", size); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2117,7 +3102,18 @@ public INDArray sparseToDense(INDArray indices, INDArray shape, INDArray values) NDValidation.validateNumerical("sparseToDense", "indices", indices); NDValidation.validateNumerical("sparseToDense", "shape", shape); NDValidation.validateNumerical("sparseToDense", "values", values); - return Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2135,7 +3131,18 @@ public INDArray sparseToDense(INDArray indices, INDArray shape, INDArray values, NDValidation.validateNumerical("sparseToDense", "shape", shape); NDValidation.validateNumerical("sparseToDense", "values", values); NDValidation.validateNumerical("sparseToDense", "defaultValue", defaultValue); - return Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values, defaultValue))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.compat.CompatSparseToDense(indices, shape, values, defaultValue)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** 
@@ -2226,7 +3233,42 @@ public INDArray squaredNorm(INDArray x, long... dimensions) { */ public INDArray squeeze(INDArray x, int axis) { NDValidation.validateNumerical("squeeze", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Remove all dimensions of size 1 from the input tensor.
+ * For example, if input has shape [a,1,b,1,c] then squeezeAll(input) returns an array of shape [a,b,c]
+ * This is the NumPy-style squeeze with no axis specified.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray squeezeAll(INDArray x) { + NDValidation.validateNumerical("squeezeAll", "x", x); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2244,7 +3286,18 @@ public INDArray squeeze(INDArray x, int axis) { */ public INDArray stack(int axis, INDArray... values) { Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2317,7 +3370,18 @@ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strid Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. 
Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2340,7 +3404,18 @@ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... stri Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. 
Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2369,7 +3444,18 @@ public INDArray stridedSlice(INDArray in, INDArray begin, INDArray end, INDArray NDValidation.validateNumerical("stridedSlice", "begin", begin); NDValidation.validateNumerical("stridedSlice", "end", end); NDValidation.validateNumerical("stridedSlice", "strides", strides); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2392,7 +3478,18 @@ public INDArray stridedSlice(INDArray in, INDArray begin, INDArray end, INDArray NDValidation.validateNumerical("stridedSlice", "begin", begin); NDValidation.validateNumerical("stridedSlice", "end", end); NDValidation.validateNumerical("stridedSlice", "strides", strides); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } 
+ } + } } /** @@ -2466,7 +3563,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dime NDValidation.validateNumerical("tensorMmul", "y", y); Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2483,7 +3591,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dim NDValidation.validateNumerical("tensorMmul", "y", y); Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2504,7 +3623,18 @@ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... 
dim */ public INDArray tile(INDArray x, INDArray repeat) { NDValidation.validateInteger("tile", "repeat", repeat); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2516,7 +3646,18 @@ public INDArray tile(INDArray x, INDArray repeat) { */ public INDArray tile(INDArray x, int... repeat) { Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2526,7 +3667,18 @@ public INDArray tile(INDArray x, int... 
repeat) { * @return output transposed input (NDARRAY type) */ public INDArray transpose(INDArray x) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2544,7 +3696,18 @@ public INDArray transpose(INDArray x) { public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMax", "data", data); NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2563,7 +3726,18 @@ public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, INDArray NDValidation.validateNumerical("unsortedSegmentMax", "data", data); NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentMax", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2581,7 +3755,18 @@ public 
INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, INDArray public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMean", "data", data); NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2600,7 +3785,18 @@ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, INDArray NDValidation.validateNumerical("unsortedSegmentMean", "data", data); NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentMean", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2618,7 +3814,18 @@ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, INDArray public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMin", "data", data); NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, 
numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2637,7 +3844,18 @@ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, INDArray NDValidation.validateNumerical("unsortedSegmentMin", "data", data); NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentMin", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2655,7 +3873,18 @@ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, INDArray public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentProd", "data", data); NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2674,7 +3903,18 @@ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, 
INDArray NDValidation.validateNumerical("unsortedSegmentProd", "data", data); NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentProd", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2691,7 +3931,18 @@ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, INDArray public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2709,7 +3960,18 @@ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, INDArra NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentSqrtN", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2727,7 +3989,18 @@ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, INDArra public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSum", "data", data); NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2746,7 +4019,18 @@ public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, INDArray NDValidation.validateNumerical("unsortedSegmentSum", "data", data); NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); NDValidation.validateInteger("unsortedSegmentSum", "numSegments", numSegments); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2825,7 +4109,18 @@ public INDArray variance(INDArray x, boolean biasCorrected, long... 
dimensions) */ public INDArray where(INDArray x, INDArray y, INDArray condition) { NDValidation.validateBool("where", "condition", condition); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, y, condition))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, y, condition)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2845,7 +4140,18 @@ public INDArray where(INDArray x, INDArray y, INDArray condition) { public INDArray where(INDArray x, INDArray condition) { NDValidation.validateNumerical("where", "x", x); NDValidation.validateBool("where", "condition", condition); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, condition))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(x, condition)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2863,7 +4169,18 @@ public INDArray where(INDArray x, INDArray condition) { */ public INDArray where(INDArray condition) { NDValidation.validateBool("where", "condition", condition); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(condition))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.Where(condition)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2883,7 +4200,18 @@ public INDArray where(INDArray condition) { */ public INDArray whereNumpy(INDArray x, INDArray y, INDArray condition) { NDValidation.validateNumerical("whereNumpy", "condition", condition); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy(x, y, 
condition))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy(x, y, condition)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -2894,6 +4222,17 @@ public INDArray whereNumpy(INDArray x, INDArray y, INDArray condition) { * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) */ public INDArray zerosLike(INDArray input) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 1a211b65d47d..a0165ed8bda5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -42,7 +42,18 @@ public NDNN() { */ public INDArray cReLU(INDArray x) { NDValidation.validateNumerical("CReLU", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -68,7 +79,18 @@ public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDA NDValidation.validateNumerical("batchNorm", "gamma", gamma); 
NDValidation.validateNumerical("batchNorm", "beta", beta); Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -83,7 +105,134 @@ public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDA public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) { NDValidation.validateNumerical("biasAdd", "input", input); NDValidation.validateNumerical("biasAdd", "bias", bias); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @param temperature Sharpening temperature (typically 0.04-0.07) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public INDArray centerAndSharpen(INDArray input, INDArray center, double temperature) { + NDValidation.validateNumerical("centerAndSharpen", "input", input); + NDValidation.validateNumerical("centerAndSharpen", "center", center); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(input, center, temperature)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * DINOv2 centering and sharpening operation.
+ * Prevents mode collapse in self-supervised learning by centering the teacher output
+ * and applying temperature-based sharpening:
+ * output = softmax((input - center) / temperature)
+ * + * @param input Teacher output logits [batch, features] (NUMERIC type) + * @param center Running center vector [features] (NUMERIC type) + * @return output Sharpened probabilities [batch, features] (NUMERIC type) + */ + public INDArray centerAndSharpen(INDArray input, INDArray center) { + NDValidation.validateNumerical("centerAndSharpen", "input", input); + NDValidation.validateNumerical("centerAndSharpen", "center", center); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CenterAndSharpen(input, center, 0.07)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public INDArray[] ctcGreedyDecoder(INDArray logits, boolean mergeRepeated, int blankIndex) { + NDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(logits, null, mergeRepeated, blankIndex)); + } + + /** + * CTC Greedy Decoder - Connectionist Temporal Classification decoding.
+ *
+ * Performs greedy (best path) decoding on CTC output. Used in:
+ * - OCR (Optical Character Recognition) - PaddleOCR, CRNN
+ * - Speech recognition - DeepSpeech, Wav2Vec
+ * - Handwriting recognition
+ *
+ * Algorithm:
+ * 1. At each timestep, select the class with highest probability
+ * 2. Optionally merge consecutive repeated characters
+ * 3. Remove blank labels from the output
+ *
+ * For example, with mergeRepeated=true and blankIndex=0:
+ * Input: [0, 1, 1, 0, 2, 2, 2, 0] (0=blank, 1='a', 2='b')
+ * Output: [1, 2] -> "ab"
+ *
+ * Note: This is greedy decoding. For better accuracy with language models,
+ * use beam search decoding instead.
+ * + * @param logits Log probabilities from CTC output. Shape: [batch, timeSteps, numClasses] (NUMERIC type) + * @param sequenceLength Optional actual sequence lengths. Shape: [batch] (NUMERIC type) + * @param mergeRepeated Whether to merge repeated characters in output + * @param blankIndex Index of the blank label in the vocabulary + */ + public INDArray[] ctcGreedyDecoder(INDArray logits, INDArray sequenceLength, + boolean mergeRepeated, int blankIndex) { + NDValidation.validateNumerical("ctcGreedyDecoder", "logits", logits); + NDValidation.validateNumerical("ctcGreedyDecoder", "sequenceLength", sequenceLength); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CTCGreedyDecoder(logits, sequenceLength, mergeRepeated, blankIndex)); } /** @@ -123,39 +272,52 @@ public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray va NDValidation.validateNumerical("dotProductAttention", "keys", keys); NDValidation.validateNumerical("dotProductAttention", "values", values); NDValidation.validateNumerical("dotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** - * This operation performs dot product attention on the given timeseries input with the given queries
- * out = sum(similarity(k_i, q) * v_i)
+ * Dot product attention operation with flash attention and KV cache support.
*
- * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ * out = softmax(Q * K^T / scale + attentionBias) * V
*
- * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
*
- * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
*
- * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
- * be 3D but can have queryCount = 1
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
*
- * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
- * both.
- *
- * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
- * output rank will depend on the input rank.
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
* - * @param queries A {@link SDVariable} representing the query tensor. Shape: [batchSize, numQueries, queryDim] (NUMERIC type) - * @param values A {@link SDVariable} representing the value tensor. Shape: [batchSize, numValues, valueDim] (NUMERIC type) - * @param keys A {@link SDVariable} representing the key tensor. Shape: [batchSize, numValues, keyDim] (NUMERIC type) - * @param queryMask A {@link SDVariable} representing the query mask tensor. Shape: [batchSize, numQueries] (NUMERIC type) - * @param valueMask @param valueMask A {@link SDVariable} representing the value mask tensor. Shape: [batchSize, numValues] (NUMERIC type) - * @param scaleFactor @param scaleFactor A {@code double} scaling factor applied to the dot product between queries and keys. - * @param dropoutProbability A {@code double} specifying the dropout probability to be applied to attention weights. - * @param useCausalMask A {@code boolean} flag to indicate whether to apply a causal mask to the attention scores, for autoregressive tasks. - * @param training A {@code boolean} flag to indicate whether the layer is in training mode or inference mode, affecting dropout. - * @return output A {@link SDVariable} representing the output tensor of the dot product attention operation. Shape: [batchSize, numQueries, valueDim] (NUMERIC type) + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 
0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) */ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArray keys, INDArray queryMask, INDArray valueMask, double scaleFactor, double dropoutProbability, @@ -165,7 +327,75 @@ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArra NDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); NDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); NDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, scaleFactor, dropoutProbability, useCausalMask, training))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, null, scaleFactor, dropoutProbability, useCausalMask, training)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Dot product attention operation with flash attention and KV cache support.
+ *
+ * out = softmax(Q * K^T * scale + attentionBias) * V
+ *
+ * For 4D inputs [batch, seq, heads, dim], uses memory-efficient flash attention algorithm.
+ * For 2D/3D inputs, uses standard attention computation.
+ *
+ * Flash attention features:
+ * - O(N) memory complexity instead of O(N^2)
+ * - Tiled computation with online softmax
+ * - Supports grouped query attention (GQA) where numHeads > numKvHeads
+ * - Supports attention bias (relative position bias, ALiBi, etc.)
+ *
+ * KV Cache support for autoregressive generation:
+ * - Pass keyCache and valueCache tensors
+ * - Set kvCachePosition to current generation position
+ * - Cached keys/values are updated in-place
+ *
+ * See "Attention is all you need" (https://arxiv.org/abs/1706.03762)
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param queries Query tensor. Shape: [batchSize, numQueries, queryDim] or [batchSize, numQueries, numHeads, headDim] for flash attention (NUMERIC type) + * @param values Value tensor. Shape: [batchSize, numValues, valueDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param keys Key tensor. Shape: [batchSize, numValues, keyDim] or [batchSize, numValues, numHeads, headDim] (NUMERIC type) + * @param queryMask Query mask tensor (optional). Shape: [batchSize, numQueries] (NUMERIC type) + * @param valueMask Value mask tensor (optional). Shape: [batchSize, numValues] (NUMERIC type) + * @param attentionBias Attention bias tensor (optional). Shape: [batchSize, numHeads, numQueries, numKeys] or broadcastable. Added to attention scores before softmax. (NUMERIC type) + * @param scaleFactor Scaling factor applied to attention scores. 0 = auto (1/sqrt(headDim)) + * @param dropoutProbability Dropout probability applied to attention weights + * @param useCausalMask Whether to apply causal mask for autoregressive tasks + * @param training Whether in training mode (affects dropout) + * @return output Output tensor. 
Shape: [batchSize, numQueries, valueDim] or [batchSize, numQueries, numHeads, headDim] (NUMERIC type) + */ + public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArray keys, + INDArray queryMask, INDArray valueMask, INDArray attentionBias, double scaleFactor, + double dropoutProbability, boolean useCausalMask, boolean training) { + NDValidation.validateNumerical("dotProductAttentionV2", "queries", queries); + NDValidation.validateNumerical("dotProductAttentionV2", "values", values); + NDValidation.validateNumerical("dotProductAttentionV2", "keys", keys); + NDValidation.validateNumerical("dotProductAttentionV2", "queryMask", queryMask); + NDValidation.validateNumerical("dotProductAttentionV2", "valueMask", valueMask); + NDValidation.validateNumerical("dotProductAttentionV2", "attentionBias", attentionBias); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionV2(queries, values, keys, queryMask, valueMask, attentionBias, scaleFactor, dropoutProbability, useCausalMask, training)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -179,7 +409,18 @@ public INDArray dotProductAttentionV2(INDArray queries, INDArray values, INDArra */ public INDArray dropout(INDArray input, boolean inverted, int seed, double probabilityValue) { NDValidation.validateNumerical("dropout", "input", input); - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, seed, probabilityValue))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, seed, probabilityValue)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -192,7 +433,18 @@ public INDArray dropout(INDArray input, 
boolean inverted, int seed, double proba */ public INDArray dropout(INDArray input, boolean inverted, double probabilityValue) { NDValidation.validateNumerical("dropout", "input", input); - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, 0, probabilityValue))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.CustomDropOut(input, inverted, 0, probabilityValue)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -208,7 +460,140 @@ public INDArray dropout(INDArray input, boolean inverted, double probabilityValu */ public INDArray elu(INDArray x) { NDValidation.validateNumerical("elu", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @param decay EMA decay factor (typically 0.996-0.9999) + * @return output Updated shadow parameters (NUMERIC type) + */ + public INDArray emaUpdate(INDArray model, INDArray shadow, double decay) { + NDValidation.validateNumerical("emaUpdate", "model", model); + NDValidation.validateNumerical("emaUpdate", "shadow", shadow); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(model, shadow, decay)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Exponential Moving Average parameter update for DINOv2 teacher networks.
+ * Computes: output = decay * shadow + (1 - decay) * model
+ * Used in self-supervised learning to maintain a slowly-updated teacher model.
+ * + * @param model Current model parameters (student) (NUMERIC type) + * @param shadow EMA shadow parameters (teacher) (NUMERIC type) + * @return output Updated shadow parameters (NUMERIC type) + */ + public INDArray emaUpdate(INDArray model, INDArray shadow) { + NDValidation.validateNumerical("emaUpdate", "model", model); + NDValidation.validateNumerical("emaUpdate", "shadow", shadow); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EmaUpdate(model, shadow, 0.999)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Flash Attention - Memory-efficient attention computation.
+ *
+ * Uses tiled computation with online softmax to achieve O(N) memory complexity
+ * instead of O(N^2) for standard attention.
+ *
+ * Supports Grouped Query Attention (GQA) where numHeads > numKvHeads,
+ * allowing multiple query heads to share the same KV heads.
+ *
+ * out = softmax(Q * K^T * scale) * V
+ *
+ * See "FlashAttention: Fast and Memory-Efficient Exact Attention" (https://arxiv.org/abs/2205.14135)
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads, for GQA use smaller value) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public INDArray flashAttention(INDArray query, INDArray key, INDArray value, double scale, + boolean isCausal, int numHeads, int numKvHeads) { + NDValidation.validateNumerical("flashAttention", "query", query); + NDValidation.validateNumerical("flashAttention", "key", key); + NDValidation.validateNumerical("flashAttention", "value", value); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.FlashAttention(query, key, value, scale, isCausal, numHeads, numKvHeads)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Executes a fused chain of element-wise operations in a single kernel pass.
+ * Intermediate values stay in registers instead of global memory. Replaces N separate kernel launches with 1.
+ * + * @param input Primary input array (NUMERIC type) + * @param secondaryInputs Optional secondary input arrays for binary ops (add, sub, mul, div) (NUMERIC type) + * @param opCodes Op codes: 0=add, 1=sub, 2=mul, 3=div, 10=relu, 11=sigmoid, 12=tanh, 13=gelu, 14=exp, 15=log, 16=abs, 17=neg, 18=square, 19=sqrt, 20=swish, 21=silu, 22=mish, 30=clip, 31=leaky_relu (Size: AtLeast(min=1)) + * @return output Result of applying the fused element-wise chain (NUMERIC type) + */ + public INDArray fusedElementwiseChain(INDArray input, INDArray[] secondaryInputs, int[] opCodes) { + NDValidation.validateNumerical("fusedElementwiseChain", "input", input); + NDValidation.validateNumerical("fusedElementwiseChain", "secondaryInputs", secondaryInputs); + Preconditions.checkArgument(secondaryInputs.length >= 0, "secondaryInputs has incorrect size/length. Expected: secondaryInputs.length >= 0, got %s", secondaryInputs.length); + Preconditions.checkArgument(opCodes.length >= 1, "opCodes has incorrect size/length. Expected: opCodes.length >= 1, got %s", opCodes.length); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.FusedElementwiseChain(input, secondaryInputs, opCodes)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -224,6 +609,49 @@ public INDArray gelu(INDArray x) { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x)); } + /** + * Grouped Query Attention (GQA) - Efficient attention with shared KV heads.
+ *
+ * Multiple query heads share the same key-value heads, reducing memory and
+ * computation while maintaining model quality. Used in LLaMA 2, Mistral, etc.
+ *
+ * numHeads must be divisible by numKvHeads. Each KV head is repeated
+ * (numHeads / numKvHeads) times to match query heads.
+ *
+ * Special cases:
+ * - numKvHeads == numHeads: Standard Multi-Head Attention (MHA)
+ * - numKvHeads == 1: Multi-Query Attention (MQA)
+ *
+ * See "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @param isCausal Whether to apply causal masking + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (must divide numHeads evenly) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public INDArray groupedQueryAttention(INDArray query, INDArray key, INDArray value, double scale, + boolean isCausal, int numHeads, int numKvHeads) { + NDValidation.validateNumerical("groupedQueryAttention", "query", query); + NDValidation.validateNumerical("groupedQueryAttention", "key", key); + NDValidation.validateNumerical("groupedQueryAttention", "value", value); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GroupedQueryAttention(query, key, value, scale, isCausal, numHeads, numKvHeads)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + /** * Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
@@ -263,6 +691,108 @@ public INDArray hardTanhDerivative(INDArray x) { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x)); } + /** + * KV Cache Update - Updates key-value cache for autoregressive generation.
+ *
+ * During LLM inference, past key-value pairs are cached to avoid redundant
+ * computation during token-by-token generation. This operation efficiently
+ * inserts new keys/values at the specified position.
+ *
+ * Usage pattern:
+ * 1. Initialize cache with zeros: [batch, maxSeqLen, numKvHeads, headDim]
+ * 2. For each new token, compute new K/V and update cache
+ * 3. Use full cached K/V for attention computation
+ *
+ * Returns updated keyCache and valueCache tensors.
+ * + * @param keyCache Existing key cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param valueCache Existing value cache. Shape: [batch, maxSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newKeys New keys to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param newValues New values to insert. Shape: [batch, newSeqLen, numKvHeads, headDim] (NUMERIC type) + * @param startPosition Position in cache where new keys/values should be inserted + */ + public INDArray[] kvCacheUpdate(INDArray keyCache, INDArray valueCache, INDArray newKeys, + INDArray newValues, int startPosition) { + NDValidation.validateNumerical("kvCacheUpdate", "keyCache", keyCache); + NDValidation.validateNumerical("kvCacheUpdate", "valueCache", valueCache); + NDValidation.validateNumerical("kvCacheUpdate", "newKeys", newKeys); + NDValidation.validateNumerical("kvCacheUpdate", "newValues", newValues); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.KVCacheUpdate(keyCache, valueCache, newKeys, newValues, startPosition)); + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @return output Scalar 0 on success (LONG type) + */ + public INDArray kvScatter(INDArray present, INDArray staticBuffer, long cachePos) { + NDValidation.validateNumerical("kvScatter", "present", present); + NDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(present, staticBuffer, cachePos, 1)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Batch KV cache scatter update for LLM autoregressive decoding.
+ *
+ * Copies a single time-step slice from each present KV tensor into the
+ * corresponding static KV buffer at a given cache position. Replaces N
+ * individual Java view+assign calls with a single native kernel launch.
+ *
+ * The present tensor has shape [batch, heads, seqLen, dim] where the new
+ * token's KV entry is at the last sequence position. This entry is extracted
+ * and written into the static buffer at cachePos.
+ *
+ * For multiple pairs, inputs are ordered as:
+ * [present_0, ..., present_{N-1}, static_0, ..., static_{N-1}]
+ * + * @param present Present KV tensor from decoder output. Shape: [batch, heads, seqLen, dim] (NUMERIC type) + * @param staticBuffer Static KV cache buffer. Shape: [batch, heads, maxKvLen, dim]. Updated in-place. (NUMERIC type) + * @param cachePos Position in static buffer to write the new entry + * @param numPairs Number of present/static KV pairs. When > 1, inputs are [present_0..N-1, static_0..N-1] + * @return output Scalar 0 on success (LONG type) + */ + public INDArray kvScatter(INDArray present, INDArray staticBuffer, long cachePos, int numPairs) { + NDValidation.validateNumerical("kvScatter", "present", present); + NDValidation.validateNumerical("kvScatter", "staticBuffer", staticBuffer); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.KvScatter(present, staticBuffer, cachePos, numPairs)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + /** * Apply Layer Normalization
*
@@ -281,7 +811,18 @@ public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean NDValidation.validateNumerical("layerNorm", "gain", gain); NDValidation.validateNumerical("layerNorm", "bias", bias); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -300,7 +841,18 @@ public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst, NDValidation.validateNumerical("layerNorm", "input", input); NDValidation.validateNumerical("layerNorm", "gain", gain); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -347,7 +899,18 @@ public INDArray linear(INDArray input, INDArray weights, INDArray bias, boolean NDValidation.validateNumerical("linear", "input", input); NDValidation.validateNumerical("linear", "weights", weights); NDValidation.validateNumerical("linear", "bias", bias); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, transposeA, transposeB, transposeC))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, transposeA, transposeB, transposeC)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -363,7 +926,18 @@ public INDArray linear(INDArray input, INDArray weights, INDArray bias) { NDValidation.validateNumerical("linear", "input", input); NDValidation.validateNumerical("linear", "weights", weights); NDValidation.validateNumerical("linear", "bias", bias); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, false, false, false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias, false, false, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -385,7 +959,18 
@@ public INDArray logSigmoid(INDArray x) { */ public INDArray logSoftmax(INDArray x) { NDValidation.validateNumerical("logSoftmax", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -397,7 +982,99 @@ public INDArray logSoftmax(INDArray x) { */ public INDArray logSoftmax(INDArray x, int dimension) { NDValidation.validateNumerical("logSoftmax", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + */ + public INDArray[] mixtureOfExperts(INDArray input, INDArray routerWeights, INDArray expertWeights, + int numExperts, int topK) { + NDValidation.validateNumerical("mixtureOfExperts", "input", input); + NDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + NDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(input, routerWeights, expertWeights, null, numExperts, topK, true, 1.0)); + } + + /** + * Mixture of Experts (MoE) Layer.
+ *
+ * Implements sparse MoE routing where each token is processed by only the top-k
+ * selected experts out of a larger pool. This enables scaling model capacity
+ * without proportionally increasing computation.
+ *
+ * Used in large language models like:
+ * - DeepSeek (DeepSeekMoE)
+ * - Mixtral (Mistral AI)
+ * - Switch Transformer (Google)
+ * - GShard (Google)
+ *
+ * The router computes expert selection probabilities:
+ * router_probs = softmax(input @ routerWeights)
+ *
+ * Top-k experts are selected and their outputs are weighted by normalized probs:
+ * output = sum(normalized_prob[i] * expert[i](input) for i in top_k)
+ *
+ * Benefits:
+ * - Scales model capacity with sublinear compute increase
+ * - Enables very large models with efficient inference
+ * - Supports expert parallelism across devices
+ * + * @param input Input embeddings. Shape: [batch, seqLen, hiddenSize] (NUMERIC type) + * @param routerWeights Router projection weights. Shape: [hiddenSize, numExperts] (NUMERIC type) + * @param expertWeights Expert weight matrices. Shape: [numExperts, hiddenSize, expertHiddenSize] (NUMERIC type) + * @param expertBias Optional expert biases. Shape: [numExperts, expertHiddenSize] (NUMERIC type) + * @param numExperts Total number of experts + * @param topK Number of experts to route to per token + * @param normalizeProbs Whether to normalize router probabilities for selected experts + * @param capacityFactor Expert capacity factor for load balancing + */ + public INDArray[] mixtureOfExperts(INDArray input, INDArray routerWeights, INDArray expertWeights, + INDArray expertBias, int numExperts, int topK, boolean normalizeProbs, + double capacityFactor) { + NDValidation.validateNumerical("mixtureOfExperts", "input", input); + NDValidation.validateNumerical("mixtureOfExperts", "routerWeights", routerWeights); + NDValidation.validateNumerical("mixtureOfExperts", "expertWeights", expertWeights); + NDValidation.validateNumerical("mixtureOfExperts", "expertBias", expertBias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MixtureOfExperts(input, routerWeights, expertWeights, expertBias, numExperts, topK, normalizeProbs, capacityFactor)); } /** @@ -434,11 +1111,22 @@ public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, IN NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, 
keys, values, Wq, Wk, Wv, Wo, mask, scaled, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** - * Padding operation
+ * Padding operation
* * @param input Input tensor (NUMERIC type) * @param padding Padding value (NUMERIC type) @@ -449,11 +1137,22 @@ public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, IN public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double constant) { NDValidation.validateNumerical("pad", "input", input); NDValidation.validateNumerical("pad", "padding", padding); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** - * Padding operation
+ * Padding operation
* * @param input Input tensor (NUMERIC type) * @param padding Padding value (NUMERIC type) @@ -463,7 +1162,18 @@ public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double co public INDArray pad(INDArray input, INDArray padding, double constant) { NDValidation.validateNumerical("pad", "input", input); NDValidation.validateNumerical("pad", "padding", padding); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -498,7 +1208,105 @@ public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) { NDValidation.validateNumerical("prelu", "input", input); NDValidation.validateNumerical("prelu", "alpha", alpha); Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes bias -m_h * |i - j| (ALiBi penalizes distance) where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public INDArray relativePositionBias(INDArray biasTable, int numHeads, int windowSize) { + NDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(biasTable, null, numHeads, windowSize, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Relative Position Bias - Compute relative position bias for attention.
+ *
+ * Supports two modes:
+ * 1. Learned bias (Swin/SAM style): Looks up bias values from a learned table
+ * based on relative positions between query and key positions.
+ *
+ * 2. ALiBi (Attention with Linear Biases): Computes linear position-based bias
+ * without learned parameters. More efficient for very long sequences.
+ *
+ * For learned bias mode:
+ * - biasTable shape: [(2*windowSize-1)^2, numHeads] for 2D
+ * - Output is gathered based on relative position indices
+ *
+ * For ALiBi mode:
+ * - biasTable can be sequence length (scalar) or input tensor
+ * - Computes m_h * |i - j| where m_h = 2^(-8*h/H)
+ *
+ * Reference: "Swin Transformer" (Liu et al., 2021)
+ * "Train Short, Test Long" (Press et al., 2021) for ALiBi
+ * + * @param biasTable Learned bias table. Shape: [numRelativePositions, numHeads] for learned mode, or scalar/tensor for ALiBi mode (NUMERIC type) + * @param relativePositionIndex Optional precomputed relative position index. Shape: [windowSize^2, windowSize^2] (NUMERIC type) + * @param numHeads Number of attention heads + * @param windowSize Window size for 2D position encoding (used if generating index) + * @return output Position bias. Shape: [numHeads, windowSize^2, windowSize^2] or [numHeads, seqLen, seqLen] (NUMERIC type) + */ + public INDArray relativePositionBias(INDArray biasTable, INDArray relativePositionIndex, + int numHeads, int windowSize) { + NDValidation.validateNumerical("relativePositionBias", "biasTable", biasTable); + NDValidation.validateNumerical("relativePositionBias", "relativePositionIndex", relativePositionIndex); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RelativePositionBias(biasTable, relativePositionIndex, numHeads, windowSize, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -540,7 +1348,128 @@ public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) { NDValidation.validateNumerical("reluLayer", "input", input); NDValidation.validateNumerical("reluLayer", "weights", weights); NDValidation.validateNumerical("reluLayer", "bias", bias); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public INDArray rmsNorm(INDArray input, INDArray gamma, double epsilon) { + NDValidation.validateNumerical("rmsNorm", "input", input); + NDValidation.validateNumerical("rmsNorm", "gamma", gamma); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, gamma, epsilon)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param gamma Scale/gain vector (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public INDArray rmsNorm(INDArray input, INDArray gamma) { + NDValidation.validateNumerical("rmsNorm", "input", input); + NDValidation.validateNumerical("rmsNorm", "gamma", gamma); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, gamma, 1.0E-5)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @param epsilon Epsilon for numerical stability + * @return output RMS normalized output (NUMERIC type) + */ + public INDArray rmsNorm(INDArray input, double epsilon) { + NDValidation.validateNumerical("rmsNorm", "input", input); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, null, epsilon)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Root Mean Square Layer Normalization (RMSNorm):
+ *
+ * output = input * rsqrt(mean(input^2, axis=-1) + epsilon) * gamma
+ *
+ * If gamma is not provided, only RMS normalization is applied.
+ * + * @param input Input variable (NUMERIC type) + * @return output RMS normalized output (NUMERIC type) + */ + public INDArray rmsNorm(INDArray input) { + NDValidation.validateNumerical("rmsNorm", "input", input); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RmsNorm(input, null, 1.0E-5)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -578,7 +1507,61 @@ public INDArray sigmoid(INDArray x) { public INDArray sigmoidDerivative(INDArray x, INDArray wrt) { NDValidation.validateNumerical("sigmoidDerivative", "x", x); NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Sliding Window Attention - Efficient attention for long sequences.
+ *
+ * Each token only attends to a fixed window of previous tokens, enabling
+ * efficient processing of very long sequences. Used in Mistral and other
+ * modern LLMs for handling long contexts.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Memory efficient for long sequences
+ * - Supports very long context lengths (e.g., 32K with 4K window)
+ *
+ * The attention mask is automatically applied to restrict each position
+ * to only attend to positions within [pos - windowSize, pos].
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + * @param key Key tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param value Value tensor. Shape: [batch, seqLen, numKvHeads, headDim] (NUMERIC type) + * @param windowSize Sliding window size - tokens can only attend to this many previous positions + * @param numHeads Number of query attention heads + * @param numKvHeads Number of KV heads (0 = same as numHeads) + * @param scale Scaling factor. 0 = auto (1/sqrt(headDim)) + * @return output Attention output. Shape: [batch, seqLen, numHeads, headDim] (NUMERIC type) + */ + public INDArray slidingWindowAttention(INDArray query, INDArray key, INDArray value, + int windowSize, int numHeads, int numKvHeads, double scale) { + NDValidation.validateNumerical("slidingWindowAttention", "query", query); + NDValidation.validateNumerical("slidingWindowAttention", "key", key); + NDValidation.validateNumerical("slidingWindowAttention", "value", value); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SlidingWindowAttention(query, key, value, windowSize, numHeads, numKvHeads, scale)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -590,7 +1573,18 @@ public INDArray sigmoidDerivative(INDArray x, INDArray wrt) { */ public INDArray softmax(INDArray x, int dimension) { NDValidation.validateNumerical("softmax", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -601,7 +1595,18 @@ public INDArray softmax(INDArray x, 
int dimension) { */ public INDArray softmax(INDArray x) { NDValidation.validateNumerical("softmax", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0]; + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } } /** @@ -660,6 +1665,75 @@ public INDArray tanh(INDArray x) { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x)); } + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public INDArray tokenSample(INDArray logits) { + NDValidation.validateNumerical("tokenSample", "logits", logits); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(logits, 0.0, 0, 0.0, 0)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Token sampling for LLM inference.
+ *
+ * Full sampling pipeline in a single native GPU call:
+ * temperature scaling -> top-K filtering -> softmax -> top-P filtering -> sample/argmax
+ *
+ * For greedy decoding (temperature=0 or no top-k/top-p), performs GPU-side argmax
+ * with shared-memory reduction — avoids transferring the full logits tensor to host.
+ *
+ * Supports rank 1 [vocabSize], rank 2 [batch, vocabSize], and rank 3
+ * [batch, seqLen, vocabSize] inputs. For rank 3, the last sequence position
+ * is automatically extracted for sampling.
+ * + * @param logits Logits tensor. Shape: [vocabSize], [batch, vocabSize], or [batch, seqLen, vocabSize]. For rank-3, samples from the last sequence position. (NUMERIC type) + * @param temperature Temperature for sampling. 0 = greedy (argmax) + * @param topK Top-K filtering: keep only top K logits. 0 = disabled + * @param topP Top-P (nucleus) filtering threshold. 0 = disabled + * @param seed Random seed for sampling. 0 = random + * @return output Sampled token indices. Shape: [batch] or scalar (LONG type) + */ + public INDArray tokenSample(INDArray logits, double temperature, int topK, double topP, + long seed) { + NDValidation.validateNumerical("tokenSample", "logits", logits); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TokenSample(logits, temperature, topK, topP, seed)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + /** * Find values and indices for the largest k entries along the last dimension.
* @@ -671,4 +1745,152 @@ public INDArray[] topK(INDArray input, double k, boolean sorted) { NDValidation.validateNumerical("topK", "input", input); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TopK(input, k, sorted)); } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param scale Attention scale factor (default: 1/sqrt(embedDim)) + */ + public INDArray[] twoWayCrossAttention(INDArray tokenQuery, INDArray tokenKey, + INDArray tokenValue, INDArray imageQuery, INDArray imageKey, INDArray imageValue, + double scale) { + NDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + NDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + NDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + NDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + NDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + NDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, scale)); + } + + /** + * SAM-style Two-Way Cross Attention.
+ * Bidirectional cross-attention where tokens attend to image features and
+ * image features attend to tokens simultaneously:
+ * tokenOutput = softmax(tokenQ @ imageK^T * scale) @ imageV
+ * imageOutput = softmax(imageQ @ tokenK^T * scale) @ tokenV
+ * + * @param tokenQuery Token queries [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenKey Token keys [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param tokenValue Token values [batch, tokenSeqLen, embedDim] (NUMERIC type) + * @param imageQuery Image queries [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageKey Image keys [batch, imageSeqLen, embedDim] (NUMERIC type) + * @param imageValue Image values [batch, imageSeqLen, embedDim] (NUMERIC type) + */ + public INDArray[] twoWayCrossAttention(INDArray tokenQuery, INDArray tokenKey, + INDArray tokenValue, INDArray imageQuery, INDArray imageKey, INDArray imageValue) { + NDValidation.validateNumerical("twoWayCrossAttention", "tokenQuery", tokenQuery); + NDValidation.validateNumerical("twoWayCrossAttention", "tokenKey", tokenKey); + NDValidation.validateNumerical("twoWayCrossAttention", "tokenValue", tokenValue); + NDValidation.validateNumerical("twoWayCrossAttention", "imageQuery", imageQuery); + NDValidation.validateNumerical("twoWayCrossAttention", "imageKey", imageKey); + NDValidation.validateNumerical("twoWayCrossAttention", "imageValue", imageValue); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.TwoWayCrossAttention(tokenQuery, tokenKey, tokenValue, imageQuery, imageKey, imageValue, 0.0)); + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public INDArray windowedAttention(INDArray query, INDArray key, INDArray value, int windowSize, + int numHeads) { + NDValidation.validateNumerical("windowedAttention", "query", query); + NDValidation.validateNumerical("windowedAttention", "key", key); + NDValidation.validateNumerical("windowedAttention", "value", value); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(query, key, value, null, null, windowSize, numHeads, 0, 0.0, false)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } + + /** + * Windowed Attention - Local/Sliding Window Attention.
+ *
+ * Implements windowed attention mechanisms used in efficient transformers like
+ * Longformer, BigBird, Swin Transformer, and SAM (Segment Anything Model).
+ *
+ * Supports both:
+ * - 1D windowed attention: for sequences [batch, seqLen, heads, dim]
+ * - 2D windowed attention: for images [batch, height, width, heads, dim]
+ *
+ * Shifted window attention (shiftSize > 0) enables cross-window connections
+ * as used in Swin Transformer.
+ *
+ * Benefits:
+ * - O(N * windowSize) complexity instead of O(N^2)
+ * - Efficient for long sequences and high-resolution images
+ * - Supports relative position bias for position-aware attention
+ * + * @param query Query tensor. Shape: [batch, seqLen, numHeads, headDim] for 1D or [batch, height, width, numHeads, headDim] for 2D (NUMERIC type) + * @param key Key tensor. Same shape as query (NUMERIC type) + * @param value Value tensor. Same shape as query (NUMERIC type) + * @param relativePositionBias Optional relative position bias. Shape: [numHeads, windowSize, windowSize] (NUMERIC type) + * @param attentionMask Optional attention mask (NUMERIC type) + * @param windowSize Size of attention window + * @param numHeads Number of attention heads + * @param shiftSize Shift size for shifted window attention (Swin style). 0 = no shift + * @param scale Attention scale factor. 0 = auto (1/sqrt(headDim)) + * @param returnWeights Whether to return attention weights + * @return output Attention output. Same shape as query (NUMERIC type) + */ + public INDArray windowedAttention(INDArray query, INDArray key, INDArray value, + INDArray relativePositionBias, INDArray attentionMask, int windowSize, int numHeads, + int shiftSize, double scale, boolean returnWeights) { + NDValidation.validateNumerical("windowedAttention", "query", query); + NDValidation.validateNumerical("windowedAttention", "key", key); + NDValidation.validateNumerical("windowedAttention", "value", value); + NDValidation.validateNumerical("windowedAttention", "relativePositionBias", relativePositionBias); + NDValidation.validateNumerical("windowedAttention", "attentionMask", attentionMask); + INDArray[] __tmp = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.WindowedAttention(query, key, value, relativePositionBias, attentionMask, windowSize, numHeads, shiftSize, scale, returnWeights)); + try { + return __tmp[0]; + } finally { + if(__tmp != null) { + for(int __i = 1; __i < __tmp.length; __i++) { + if(__tmp[__i] != null) { + __tmp[__i].close(); + } + } + } + } + } }