Table Of Contents
- Description
- How does this sample work?
- Prerequisites
- Running the sample
- Additional resources
- License
- Changelog
- Known issues
This sample, network_api_pytorch_mnist, trains a convolutional model on the MNIST dataset and runs inference with a TensorRT engine.
This end-to-end sample trains a model in PyTorch, recreates the network in TensorRT, imports the weights from the trained model, and finally runs inference with a TensorRT engine. For more information, see Creating A Network Definition In Python.
The sample.py script imports functions from the mnist.py script to train the PyTorch model and to retrieve test cases from the PyTorch Data Loader.
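Recreating a network layer by layer means every layer's output shape has to line up with the next layer's input. The following plain-Python sketch traces those shapes for a 28x28 MNIST image. The layer sizes (5x5 convolutions with 20 and 50 filters, 2x2 pooling, fully connected widths of 500 and 10) are assumptions for a typical LeNet-style model and are not taken from this README; check mnist.py for the actual values.

```python
def window_out(size, kernel, stride=1, padding=0):
    """Output extent of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(input_hw=28):
    """Return (layer name, output shape) pairs for an assumed LeNet-style stack."""
    hw = input_hw
    channels = 1  # MNIST images are single-channel

    # (name, kind, out_channels, kernel, stride): assumed sizes, for illustration only.
    plan = [
        ("conv1", "conv", 20, 5, 1),
        ("pool1", "pool", None, 2, 2),
        ("conv2", "conv", 50, 5, 1),
        ("pool2", "pool", None, 2, 2),
    ]

    shapes = []
    for name, kind, out_channels, kernel, stride in plan:
        hw = window_out(hw, kernel, stride)
        if kind == "conv":
            channels = out_channels  # pooling keeps the channel count
        shapes.append((name, (channels, hw, hw)))

    # The fully connected layers consume the flattened feature map.
    shapes.append(("flatten", (channels * hw * hw,)))
    shapes.append(("fc1", (500,)))  # assumed hidden width
    shapes.append(("fc2", (10,)))   # one score per digit class
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

If a shape in the TensorRT network disagrees with the PyTorch model at any step, the imported weights will not fit, so a trace like this is a quick sanity check before building the engine.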
In this sample, the following layers are used. For more information about these layers, see the TensorRT Developer Guide: Layers documentation.
Activation layer
The Activation layer implements element-wise activation functions. Specifically, this sample uses the Activation layer with the type RELU.
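As a rough illustration of what RELU computes, here is a plain-Python sketch. It is not TensorRT code; the function name `relu` is chosen for this example only.

```python
def relu(values):
    """Element-wise ReLU: negative inputs become 0, everything else passes through."""
    return [max(0.0, v) for v in values]


print(relu([-2.0, -0.5, 0.0, 1.5]))  # [0.0, 0.0, 0.0, 1.5]
```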
Convolution layer
The Convolution layer computes a 2D (channel, height, and width) convolution, with or without bias.
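The arithmetic behind such a layer can be sketched in plain Python. This is an illustration under simplifying assumptions (stride 1, no padding, nested lists instead of tensors), not TensorRT code. As is usual in deep learning frameworks, the kernel is applied without flipping.

```python
def conv2d(image, kernels, bias=None):
    """Multi-channel 2D convolution with stride 1 and no padding.

    image:   [C][H][W] input
    kernels: [K][C][R][S] filters
    bias:    optional list of K values, one per output channel
    """
    in_channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    k_h, k_w = len(kernels[0][0]), len(kernels[0][0][0])
    out_h, out_w = height - k_h + 1, width - k_w + 1

    output = []
    for k, kernel in enumerate(kernels):
        start = bias[k] if bias is not None else 0.0
        plane = []
        for y in range(out_h):
            row = []
            for x in range(out_w):
                acc = start
                # Sum over every input channel and every kernel position.
                for c in range(in_channels):
                    for dy in range(k_h):
                        for dx in range(k_w):
                            acc += image[c][y + dy][x + dx] * kernel[c][dy][dx]
                row.append(acc)
            plane.append(row)
        output.append(plane)
    return output


image = [[[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0],
          [7.0, 8.0, 9.0]]]
kernels = [[[[1.0, 0.0],
             [0.0, 1.0]]]]

print(conv2d(image, kernels))              # [[[6.0, 8.0], [12.0, 14.0]]]
print(conv2d(image, kernels, bias=[0.5]))  # [[[6.5, 8.5], [12.5, 14.5]]]
```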
MatrixMultiply layer
The MatrixMultiply layer implements a matrix multiplication. (The FullyConnected layer has been deprecated since TensorRT 8.4. The bias of a FullyConnected layer can be added with an ElementWise layer using the SUM operation.)
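To make the replacement for a fully connected layer concrete, the sketch below composes a matrix multiplication with an element-wise sum in plain Python. It illustrates the arithmetic only and is not TensorRT code; the helper names are made up for this example. The weight is transposed on the assumption that it comes from a PyTorch `nn.Linear`, which stores it as [out_features][in_features].

```python
def matmul(a, b):
    """Multiply an [M][N] matrix by an [N][P] matrix."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for row in a]


def add_bias(matrix, bias):
    """Element-wise sum of a bias vector onto every row."""
    return [[value + bias[j] for j, value in enumerate(row)] for row in matrix]


def fully_connected(x, weight, bias):
    """Emulate a fully connected layer as matrix multiply followed by a sum."""
    weight_t = [list(column) for column in zip(*weight)]  # [in][out]
    return add_bias(matmul(x, weight_t), bias)


x = [[1.0, 2.0, 3.0]]              # one sample with three features
weight = [[1.0, 0.0, 0.0],
          [0.0, 1.0, 1.0]]         # two outputs, three inputs
bias = [0.5, -1.0]

print(fully_connected(x, weight, bias))  # [[1.5, 4.0]]
```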
Pooling layer
The Pooling layer implements pooling within a channel. Supported pooling types are maximum, average and maximum-average blend.
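The three pooling types can be compared with a small plain-Python sketch over a single channel. This is not TensorRT code. The blend weighting used here, (1 - blend) * max + blend * average, is an assumption about how the blend factor is applied; consult the TensorRT layer documentation for the authoritative definition.

```python
def pool2d(plane, window=2, stride=2, mode="max", blend=0.0):
    """Pool one [H][W] channel. mode is 'max', 'average', or 'blend'."""
    height, width = len(plane), len(plane[0])
    output = []
    for y in range(0, height - window + 1, stride):
        row = []
        for x in range(0, width - window + 1, stride):
            cells = [plane[y + dy][x + dx]
                     for dy in range(window) for dx in range(window)]
            peak = max(cells)
            mean = sum(cells) / len(cells)
            if mode == "max":
                row.append(peak)
            elif mode == "average":
                row.append(mean)
            elif mode == "blend":
                # Assumed weighting between the max and average results.
                row.append((1.0 - blend) * peak + blend * mean)
            else:
                raise ValueError(f"unknown pooling mode: {mode}")
        output.append(row)
    return output


plane = [[1.0, 2.0, 5.0, 6.0],
         [3.0, 4.0, 7.0, 8.0],
         [0.0, 0.0, 1.0, 1.0],
         [0.0, 4.0, 1.0, 1.0]]

print(pool2d(plane, mode="max"))               # [[4.0, 8.0], [4.0, 1.0]]
print(pool2d(plane, mode="average"))           # [[2.5, 6.5], [1.0, 1.0]]
print(pool2d(plane, mode="blend", blend=0.5))  # [[3.25, 7.25], [2.5, 1.0]]
```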
- Upgrade pip version and install the sample dependencies.
pip3 install --upgrade pip
pip3 install -r requirements.txt
To run this sample, you must be using Python 3.10 or newer.
On PowerPC systems, you will need to manually install PyTorch using IBM's PowerAI.
- Preparing sample data
See Preparing sample data in the main samples README.
The MNIST dataset can be found under $TRT_DATADIR/mnist.
- Run the sample to create a TensorRT inference engine and run inference:
python3 sample.py
- Verify that the sample ran successfully. If the sample runs successfully, you should see a match between the test case and the prediction.
Test Case: 0
Prediction: 0
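For a ten-class MNIST classifier, a prediction is typically the index of the highest output score. The sketch below shows that comparison in plain Python; the score values and the function names are hypothetical, and sample.py may implement this step differently.

```python
def predict_digit(scores):
    """Return the index of the largest score (ties resolve to the first)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def report(test_case, scores):
    """Print the comparison in the sample's output format and return whether it matched."""
    prediction = predict_digit(scores)
    print(f"Test Case: {test_case}")
    print(f"Prediction: {prediction}")
    return prediction == test_case


# Hypothetical scores for the digits 0 through 9.
scores = [9.1, -3.2, 0.4, -1.0, -2.5, 0.3, 1.1, -0.7, 0.2, -0.1]
matched = report(0, scores)
```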
To see the full list of available options and their descriptions, use the -h or --help command line option.
The following resources provide a deeper understanding of getting started with TensorRT using Python:
Model
Dataset
Documentation
- Introduction To NVIDIA’s TensorRT Samples
- Working With TensorRT Using The Python API
- NVIDIA’s TensorRT Documentation Library
For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.
October 2025 Migrated to strongly typed APIs.
August 2025 Removed support for Python versions < 3.10.
August 2023 Removed support for Python versions < 3.8.
September 2021 Updated the sample to use explicit batch network definition.
March 2021 Documented the Python version limitations.
February 2019 This README.md file was recreated, updated and reviewed.
This sample only supports Python 3.10+ due to torch and torchvision version requirements.