A pure F# Qwen3 NVFP4 training/inference project.
This repo is built to help F# developers learn and run end-to-end LLM workflows in one codebase: weight loading, forward pass, loss, backward pass, optimizer step, .dat export, and inference validation.
Traditional Chinese version: README.zh-TW.md.
- F# developers entering LLM engineering
- Engineers who want to understand NVFP4 data/compute paths
- Contributors who want both training and inference paths in F#
- Pure F# core (`Types`/`Cli`/`Nvfp4State`/`Qwen3Model`/`InferenceBridge`/`Trainer`/`Program`)
- Uses `FAkka.TorchSharp.DGX 26.1.0-py3.9`
- Uses `TorchSharp.Q4.Extension` for NVFP4 quant/dequant + kernel paths
- Qwen3 block wiring (`q/k/v/o` + norm + attn + mlp + residual)
- GQA-aware config (`num_attention_heads`, `num_key_value_heads`)
- Training loss modes:
  - `ce`: token-level cross entropy (main LM path)
  - `scalar`: hidden-state L1 (debug/baseline)
- Chunked + streaming optimizer step support to reduce update-time memory spikes
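The chunked optimizer-step idea above can be illustrated with a minimal Python sketch (the real implementation lives in `Nvfp4Optimizer.fs`; names and structure here are illustrative only):

```python
# Illustrative sketch of a chunked SGD-style update: instead of updating a
# whole weight matrix at once, process `step_chunk_rows` rows at a time so
# the transient buffers needed per update stay bounded.
def chunked_sgd_step(weights, grads, lr, step_chunk_rows):
    """In-place SGD update over row chunks; weights/grads are lists of rows."""
    for start in range(0, len(weights), step_chunk_rows):
        for r in range(start, min(start + step_chunk_rows, len(weights))):
            weights[r] = [w - lr * g for w, g in zip(weights[r], grads[r])]

w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
g = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
chunked_sgd_step(w, g, lr=1.0, step_chunk_rows=2)  # two chunks: rows 0-1, then row 2
```

Smaller chunks trade update speed for a lower peak in step-time memory, which is the same trade-off `--step-chunk-rows` controls.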
```bash
cd /workspace/Qwen3-4B-Instruct-2507-TorchSharp.fs
dotnet build -c Release
dotnet run -c Release -- --help
dotnet run -c Release -- \
  --synthetic true \
  --device cuda \
  --epochs 1 \
  --steps-per-epoch 1 \
  --batch-size 1 \
  --in-features 64 \
  --out-features 64 \
  --model-dir /workspace/models/qwen3-4b-instruct-2507-torchsharp \
  --config /workspace/models/qwen3-4b-instruct-2507-torchsharp/config.json \
  --tokenizer /workspace/models/qwen3-4b-instruct-2507-torchsharp/tokenizer.json \
  --weight /workspace/models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-nvfp4.dat
```

`TorchSharp.Q4.Extension` now resolves `libNVFP4.so` in this order:

1. `NVFP4_LIB_PATH` (if set)
2. Same output directory as the running app (`Qwen3-4B-Instruct-2507-TorchSharp.fs.dll`)
3. `runtimes/linux-arm64/native/libNVFP4.so` under the app output directory
4. `/workspace/nvfp4_native/libNVFP4.so`
5. System loader path (`libNVFP4.so`)

For direct local builds, the most reliable setup is placing `libNVFP4.so` next to the generated app DLL in `bin/Release/net10.0/`.
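The resolution order above amounts to a first-match search over candidate paths. A Python sketch of the documented order (not the extension's actual loader code; the `exists` parameter is just a test seam):

```python
import os

def resolve_nvfp4_lib(app_dir, exists=os.path.exists):
    """Return the first existing candidate, following the documented order."""
    candidates = []
    env = os.environ.get("NVFP4_LIB_PATH")
    if env:
        candidates.append(env)                                # 1. env override
    candidates.append(os.path.join(app_dir, "libNVFP4.so"))   # 2. next to the app DLL
    candidates.append(os.path.join(app_dir, "runtimes/linux-arm64/native/libNVFP4.so"))
    candidates.append("/workspace/nvfp4_native/libNVFP4.so")  # 4. workspace native build
    for c in candidates:
        if exists(c):
            return c
    return "libNVFP4.so"  # 5. fall back to the system loader path
```

For example, with `NVFP4_LIB_PATH` unset and the library copied next to the DLL, the second candidate wins.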
```bash
cd /workspace/Qwen3-4B-Instruct-2507-TorchSharp.fs
dotnet fsi scripts/Train.OneStep.fsx \
  --device cuda \
  --loss ce \
  --seq-len 8 \
  --step-chunk-rows 16 \
  --train-data TrainData/train-inputs.txt \
  --vram-report doc/train-step-vram-onestep.json
```

```bash
cd /workspace/Qwen3-4B-Instruct-2507-TorchSharp.fs
dotnet fsi scripts/Train.WhoAmI.AndExportDat.fsx \
  --input-dat /models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-nvfp4.dat \
  --output-dat artifacts/whoami-trained.dat \
  --train-data TrainData/stageL-1percent-mix.tsv \
  --loss ce \
  --steps 10 \
  --lr 5e-5 \
  --seq-len 96 \
  --step-chunk-rows 32 \
  --compute-dtype float16
```

Notes:

- `--train-last-layers <= 0` means full-parameter mode (default).
- `--train-last-layers N` is debug mode (train only the last N layers).
- Export writes a new `.dat` file unless you explicitly reuse an existing path.
For validation with the local `runner-arm64-fp4` scripts, use the guard wrapper:
```bash
cd /workspace/home/qwen3.fs.experiments/Qwen3-4B-Instruct-2507-TorchSharp.fs/runner-arm64-fp4
dotnet fsi run-script-with-guard.fsx \
  --gpu-limit-gb 108 \
  --gpu-over-secs 0 \
  --gpu-poll-secs 0.5 \
  script run-training2.fsx \
  --model-dir /workspace/models/qwen3-4b-instruct-2507-torchsharp \
  --config /workspace/models/qwen3-4b-instruct-2507-torchsharp/config.json \
  --tokenizer /workspace/models/qwen3-4b-instruct-2507-torchsharp/tokenizer.json \
  --weight /workspace/models/Qwen3-4B-Instruct-2507-TorchSharp.fs/artifacts/whoami-1000-seq192-r8-s10-lr1e3.dat \
  --prompt 你是誰 \
  --max-tokens 24 \
  --check-logits false \
  --timing true \
  --stop-here true \
  --KVCacheOut true
```

Note: `--stop-here true` is a smoke-test switch. The script finishes its work and then throws `stop here`, so a non-zero exit is expected.
```bash
cd /workspace/home/qwen3.fs.experiments/Qwen3-4B-Instruct-2507-TorchSharp.fs/runner-arm64-fp4
dotnet fsi run-script-with-guard.fsx \
  --gpu-limit-gb 108 \
  --gpu-over-secs 0 \
  --gpu-poll-secs 0.5 \
  script /workspace/home/qwen3.fs.experiments/Qwen3-4B-Instruct-2507-TorchSharp.fs/scripts/Train.WhoAmI.AndExportDat.fsx \
  --model-dir /workspace/models/qwen3-4b-instruct-2507-torchsharp \
  --input-dat /workspace/models/Qwen3-4B-Instruct-2507-TorchSharp.fs/artifacts/whoami-1000-seq192-r8-s10-lr1e3.dat \
  --output-dat /workspace/home/qwen3.fs.experiments/Qwen3-4B-Instruct-2507-TorchSharp.fs/artifacts/whoami-trained.dat \
  --train-data /workspace/home/qwen3.fs.experiments/Qwen3-4B-Instruct-2507-TorchSharp.fs/TrainData/whoami-1000-natural.tsv \
  --loss ce \
  --steps 1 \
  --lr 1e-4 \
  --seq-len 64 \
  --step-chunk-rows 16 \
  --compute-dtype float16 \
  --device cuda
```

See `artifacts/BASELINE_BRIDGE_SUCCESS.md` for a fixed baseline command/dat.
For FP2 training-path inference (`run-training-fp2.fsx`) usage and its differences vs `run-training2.fsx`, see `runner-arm64-fp4/README.md`.
If you want to study this stack from top-level F# scripts down to native CUDA symbols, start from these repositories:
| Project | Link | Role in this project | Lineage / relationship |
|---|---|---|---|
| TorchSharp.Fun | https://github.com/ingted/TorchSharp.Fun | Functional composition helpers used by the training/inference workflow style | Base functional model-composition layer. In this workspace it is consumed via the DGX-focused variant (/workspace/TorchSharp.Fun.DGX). |
| nvfp4_native | https://github.com/ingted/nvfp4_native | Native CUDA bridge exposing NVFP4 ops (scaled_mm, quantize/dequantize, cache ops) | Lowest-level native backend used by the FP4 path. This project links to libNVFP4.so in runtime environments. |
| TorchSharp_In_DGX_Spark | https://github.com/ingted/TorchSharp_In_DGX_Spark | ARM64 DGX-oriented TorchSharp distribution and native wrapper build path | Infrastructure/base for DGX TorchSharp runtime. In this workspace, FP4 work continues in /workspace/TorchSharp_In_DGX_Spark_fp4 (with TorchSharp.Q4.Extension). |
Suggested reading order:
1. `TorchSharp_In_DGX_Spark` (runtime/native foundation)
2. `nvfp4_native` (NVFP4 native ops and exported symbols)
3. `TorchSharp.Fun` (F# functional composition style)
4. This repository (`Qwen3-4B-Instruct-2507-TorchSharp.fs`) for end-to-end training/inference integration
If you are an F# developer evaluating DGX Spark and worried about TorchSharp support: this stack is already proven in this workspace.
Build from source (base runtime):

```bash
cd /workspace/TorchSharp_In_DGX_Spark
bash build_TorchSharp.Native.sh
bash build_TorchSharp.net.sh
```

Build the FP4 branch used by this project:

```bash
cd /workspace/TorchSharp_In_DGX_Spark_fp4
bash build_TorchSharp.Native.sh
bash build_TorchSharp.net.sh
dotnet build TorchSharp.Q4.Extension/TorchSharp.Q4.Extension.fsproj -c Release
```

Then rebuild this repository:

```bash
cd /workspace/Qwen3-4B-Instruct-2507-TorchSharp.fs
dotnet build -c Release
```

This project (`Qwen3-4B-Instruct-2507-TorchSharp.fs`) depends on that DGX Spark TorchSharp foundation.
```mermaid
graph TD
  HF[Official Qwen3 HF weights] --> EXP[export_qwen3.py]
  EXP --> FP16[Qwen3-4B-Instruct-2507-fp16.dat]
  FP16 --> QNV[quantize_qwen3_to_nvfp4.py]
  QNV --> NV[Qwen3-4B-Instruct-2507-nvfp4.dat]
  NV --> PROJ[Qwen3-4B-Instruct-2507-TorchSharp.fs]
  NV --> RUN[runner-arm64-fp4 scripts]
  NATIVE[nvfp4_native: libNVFP4.so] --> Q4EXT[TorchSharp.Q4.Extension]
  Q4EXT --> PROJ
  DGX[TorchSharp_In_DGX_Spark / _fp4] --> Q4EXT
  FUN[TorchSharp.Fun / TorchSharp.Fun.DGX] --> PROJ
```
This is the canonical path used in this workspace.
- Companion scripts are in:
  - `/workspace/fsann/Qwen3-4B-Instruct-2507-TorchSharp-mod/Qwen3/export_qwen3.py`
  - `/workspace/fsann/Qwen3-4B-Instruct-2507-TorchSharp-mod/Qwen3/quantize_qwen3_to_nvfp4.py`
  - `/workspace/fsann/Qwen3-4B-Instruct-2507-TorchSharp-mod/Qwen3/dat_reader.py`
If you already have `Qwen3-4B-Instruct-2507-fp16.dat`, skip this step.
```bash
cd /workspace/fsann/Qwen3-4B-Instruct-2507-TorchSharp-mod/Qwen3
python export_qwen3.py \
  --model /path/to/official_qwen3_hf_model \
  --dtype float16 \
  --quant none \
  --out /models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-fp16.dat
```

```bash
cd /workspace/fsann/Qwen3-4B-Instruct-2507-TorchSharp-mod/Qwen3
python quantize_qwen3_to_nvfp4.py
```

By default that script uses:

- input: `/models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-fp16.dat`
- output: `/models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-nvfp4.dat`
Important behavior of the quantization script:

- It quantizes selected 2D projection weights only: `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, `lm_head`
- It keeps `embed_tokens` and `norm` as non-NVFP4 raw tensors.
- Quantized tensors are stored as paired keys: `*.qdata`, `*.scale`
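Based on the paired-key convention above, a quick pairing consistency check could look like the following (an illustrative Python sketch; the repo's real inspector is `dat_reader.py`, and the key names here are made up for the example):

```python
# Illustrative check of the *.qdata / *.scale pairing convention:
# every quantized tensor base name should contribute both suffixes.
def check_qdata_scale_pairs(keys):
    """Return base names that have only one half of the qdata/scale pair."""
    qdata = {k[: -len(".qdata")] for k in keys if k.endswith(".qdata")}
    scale = {k[: -len(".scale")] for k in keys if k.endswith(".scale")}
    return qdata ^ scale  # symmetric difference = unpaired bases

keys = [
    "layers.0.q_proj.qdata", "layers.0.q_proj.scale",  # properly paired
    "layers.0.k_proj.qdata",                           # missing its .scale
    "embed_tokens.weight",                             # raw tensor, expected
]
missing = check_qdata_scale_pairs(keys)  # {"layers.0.k_proj"}
```

An empty result means every quantized tensor has both halves of its pair; raw tensors such as `embed_tokens` are simply ignored.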
Quick runtime verification:

```bash
cd /workspace/Qwen3-4B-Instruct-2507-TorchSharp.fs/runner-arm64-fp4
dotnet fsi run-script-with-guard.fsx \
  --gpu-limit-gb 108 --gpu-over-secs 0 --gpu-poll-secs 0.5 \
  script run-training2.fsx \
  --weight /models/qwen3-4b-instruct-2507-torchsharp/Qwen3-4B-Instruct-2507-nvfp4.dat \
  --prompt 你是誰 --max-tokens 24 --check-logits false --timing true --stop-here true --KVCacheOut true
```

Optional structure check (Python):

- Use `dat_reader.py` to inspect whether projection families have `qdata`/`scale` pairs and whether the required raw tensors still exist.
- `Types.fs`: training config, defaults, Q4 session/schema defaults
- `Cli.fs`: command-line parsing
- `Nvfp4State.fs`: streaming `.dat` loader (`qdata`/`scale` pairs)
- `Qwen3Model.fs`: model construction, trainable params, forward/KV forward
- `Qwen3Core.fs`: block graph (attention, RoPE, MLP, residual)
- `InferenceBridge.fs`: inference wiring + tokenizer flow
- `Nvfp4Optimizer.fs`: packed optimizer + chunked step implementation
- `Trainer.fs`: train loop, loss, checkpoint, VRAM profiling
- `Program.fs`: app entrypoint
- `TrainData/`: training datasets (`prompt<TAB>target` TSV)
- `scripts/Train.OneStep.fsx`: one-step training + VRAM JSON output
- `scripts/Train.WhoAmI.AndExportDat.fsx`: training + `.dat` export + quick self-test
- `scripts/Generate.WhoAmINaturalData.fsx`: natural-style dataset generation
- `scripts/Tests.Parity.fsx`: runner readability/stability spot-check
- `scripts/Tests.KVCStress.fsx`: KV-cache stress matrix
- `runner-arm64-fp4/`: cleaned local copy of runner scripts with postmortem and dtype/VRAM notes (`runner-arm64-fp4/README.md`)
- `models/qwen3-4b-instruct-2507-torchsharp/`: local copy of non-`.dat` model metadata (`config.json`, `tokenizer*.json`) for reproducible script dependencies
- `doc/Architecture.md`: architecture overview
- `doc/NVFP4_DataPath.md`: NVFP4 storage/compute path details
- `doc/SA.md`: system analysis (risks/strategy)
- `doc/SD.md`: system design
- `doc/WBS.md`: work breakdown and tracking
- `doc/DevLog.md`: experiment log (including failures and fixes)
- `notes/`: raw notes and investigation references
- `artifacts/檔案說明清單.md`: artifact usage inventory
- `step`: one parameter update (forward + loss + backward + optimizer step)
- `seq-len`: max token length used per training sample window
- KV cache: inference acceleration cache; usually off in the training path
- CE loss: cross entropy between logits and target token IDs
- `step-chunk-rows`: rows processed per optimizer chunk; smaller is usually safer but slower
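The CE loss in the glossary uses the usual next-token alignment: logits at position t are scored against the token at t+1. A pure-Python sketch of that objective (illustrative only, not `tokenCrossEntropyLoss` from `Trainer.fs`):

```python
import math

def token_cross_entropy(logits, token_ids):
    """Mean next-token CE: logits[t] predicts token_ids[t + 1].

    logits: per-position lists of raw scores (vocab-sized rows).
    token_ids: list of int token ids, same length as logits.
    """
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):        # shift: the last position has no target
        row = logits[t]
        target = token_ids[t + 1]
        log_z = math.log(sum(math.exp(x) for x in row))
        total += log_z - row[target]           # -log softmax(row)[target]
        count += 1
    return total / count

# Uniform logits over a 4-token vocab give CE = ln(4) per position.
loss = token_cross_entropy([[0.0] * 4] * 3, [1, 2, 3])
```

The same shift is why a shape/target mismatch usually points at window alignment rather than the model itself.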
- Offload is disabled by default for the DGX Spark profile in this project.
- Large `.dat` files are ignored by git (`artifacts/**/*.dat`).
- For stability, run `Train.OneStep.fsx` before long runs.
- Validate quality with fixed A/B prompts (e.g., 你是誰 "who are you" and unrelated prompts like 談談 UFO "talk about UFOs").
- This is an engineering-focused, rapidly evolving trainer, not a full production trainer.
- Full-parameter long-context training still faces high peak memory pressure; tune with guard.
- Different inference paths (`fp2-model` / bridge / runner) can diverge; run parity checks regularly.
Short answer: usually no for this repo's default profile.
This project keeps offload disabled by default for DGX Spark-oriented runs to avoid extra copy/management overhead.
Unified memory does not remove peak-allocation risk.
Transient buffers, allocator fragmentation, and step-time spikes can still kill the run. Guard is a safety rail.
Training window length (tokens), not generated output length.
Larger seq-len can improve learning context but raises activation memory pressure.
How many parameter rows are updated per optimizer chunk.
Smaller values reduce step-time memory peaks but are slower.
No.
`--train-last-layers <= 0` means full-parameter mode (default).
`--train-last-layers N` is a debug mode for tuning only the last N layers.
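The rule stated above can be sketched as a selection function (illustrative Python mirroring the documented semantics, not code from this repo):

```python
def trainable_layer_indices(train_last_layers, num_layers):
    """n <= 0 -> full-parameter mode (all layers); n > 0 -> last n layers only."""
    if train_last_layers <= 0:
        return list(range(num_layers))          # default: train everything
    n = min(train_last_layers, num_layers)      # clamp oversized n
    return list(range(num_layers - n, num_layers))
```

So `--train-last-layers 2` on a 4-layer stack would update only layers 2 and 3, while `0` or any negative value updates all of them.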
Inference routes can differ (bridge, `fp2-model`, runner scripts/options).
Always validate with fixed prompts and the parity script (`scripts/Tests.Parity.fsx`) before drawing conclusions about model quality.
This is usually dataset/objective imbalance (over-targeted data, too few contrastive samples, or aggressive LR/steps).
Use mixed data, keep unrelated prompts in validation, and tune steps/lr/seq-len conservatively.
Not unless you set `--output-dat` to the original path.
The training export script writes a new `.dat` by default.
| Symptom | Likely Cause | Recommended Action |
|---|---|---|
| Process crashes / host becomes unstable during training | Peak memory spike in backward or optimizer step | Use guard runner, lower seq-len, lower --step-chunk-rows (for example 32 -> 16), reduce steps for sanity first |
| Guard kills process quickly at high memory | Step-time temporary tensor burst | Keep --gpu-poll-secs small, reduce chunk rows, run one-step script first to profile phases |
| Output collapses to repeated phrase (for example repeated target tokens) | Over-targeted dataset or too aggressive LR/steps | Use mixed dataset, reduce LR, reduce steps, keep unrelated prompts in validation A/B |
| 你是誰 works but unrelated prompts are degraded | Catastrophic forgetting from a narrow objective | Add balanced/mixed training data and run short staged training instead of a single overfit stage |
| Same `.dat` behaves differently between paths | Different inference path/options (bridge, fp2 model, runner flags) | Lock prompt + flags and run parity checks (`scripts/Tests.Parity.fsx`) before judging quality |
| CE training fails with shape/target mismatch | Input/target window alignment issue | Verify tokenized lengths, ensure next-token target shift, validate response span slicing |
| Training seems to run but model does not improve | Wrong trainable subset or too weak update | Confirm full-parameter mode (--train-last-layers <= 0), check loss trend, verify export target path |
| Exported `.dat` loads but quality is corrupted | Incorrect export replacement scope or broken source checkpoint | Re-test with a known baseline `.dat`, compare against `artifacts/BASELINE_BRIDGE_SUCCESS.md`, keep a backup before export |
| VRAM unexpectedly high at model load | Multiple materialized tensors / duplicated parameter residency | Run one-step VRAM report, inspect dtype/path choices, keep compute dtype conservative (float16 on CUDA) |
| Interactive runner exits early | `--stop-here` or script flow exits before the loop | Ensure `--stop-here false` and verify the script supports interactive mode in the current branch |
- Build and inspect `--help` to understand the full config surface.
- Run `scripts/Train.OneStep.fsx` to understand one train step and VRAM profiling.
- Read `Trainer.fs` (`tokenCrossEntropyLoss`, training loop).
- Run `scripts/Train.WhoAmI.AndExportDat.fsx` end-to-end (train -> export -> self-test).
- Validate the exported `.dat` with the runner.
- Deep dive into `doc/NVFP4_DataPath.md`, `doc/SA.md`, `doc/SD.md`.
LICENSE