
“Hello World” For TensorRT Using PyTorch And Python


Description

This sample, network_api_pytorch_mnist, trains a convolutional model on the MNIST dataset and runs inference with a TensorRT engine.

How does this sample work?

This is an end-to-end sample that trains a model in PyTorch, recreates the network in TensorRT, imports the weights from the trained model, and finally runs inference with a TensorRT engine. For more information, see Creating A Network Definition In Python.
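When you recreate the network in TensorRT, every layer must produce the same shapes as the PyTorch model, or the imported weights will not fit. The sketch below traces shapes through a LeNet-style stack for a 1x28x28 MNIST image. The layer sizes (two 5x5 convolutions with 20 and 50 output maps, 2x2 max pooling, then fully connected layers of 500 and 10 units) are an assumption for illustration, and `conv_out`, `trace_shapes`, `FEATURE_LAYERS`, and `FC_UNITS` are hypothetical names. Check mnist.py for the model the sample actually trains.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis for a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical LeNet-style stack: (name, kind, out_channels, kernel, stride).
# This is an assumption; check mnist.py for the layers the sample really uses.
FEATURE_LAYERS = [
    ("conv1", "conv", 20, 5, 1),
    ("pool1", "pool", None, 2, 2),
    ("conv2", "conv", 50, 5, 1),
    ("pool2", "pool", None, 2, 2),
]
FC_UNITS = [("fc1", 500), ("fc2", 10)]  # fc2 emits one score per digit class


def trace_shapes(channels: int = 1, height: int = 28,
                 width: int = 28) -> Tuple[List[Tuple[str, tuple]], int]:
    """Return (name, output shape) for every layer plus the flattened feature count."""
    c, h, w = channels, height, width
    shapes = [("input", (c, h, w))]
    for name, kind, out_channels, kernel, stride in FEATURE_LAYERS:
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
        if kind == "conv":
            c = out_channels  # only convolutions change the channel count
        shapes.append((name, (c, h, w)))
    flat = c * h * w  # inputs the first fully connected layer must accept
    for name, units in FC_UNITS:
        shapes.append((name, (units,)))
    return shapes, flat


if __name__ == "__main__":
    shapes, flat = trace_shapes()
    for name, shape in shapes:
        print(f"{name:>6}: {shape}")
    print("flattened features:", flat)
```

With the default 1x28x28 input, this stack flattens to 50 x 4 x 4 = 800 features, which is the input width the first fully connected weights would need.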

The sample.py script imports functions from the mnist.py script to train the PyTorch model and to retrieve test cases from the PyTorch DataLoader.

TensorRT API layers and ops

In this sample, the following layers are used. For more information about these layers, see the TensorRT Developer Guide: Layers documentation.

Activation layer: The Activation layer implements element-wise activation functions. Specifically, this sample uses the Activation layer with the type RELU.
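"Element-wise" means the function is applied to every value independently, so the output has the same shape as the input. In the TensorRT Python API this layer is added with `network.add_activation(tensor, trt.ActivationType.RELU)`. The pure-Python sketch below only illustrates what RELU computes; `relu` is a hypothetical helper, not part of the sample.

```python
from typing import Union

Nested = Union[float, list]  # a scalar or an arbitrarily nested list of scalars


def relu(x: Nested) -> Nested:
    """Apply max(0, x) to every element of a (possibly nested) list, keeping its shape."""
    if isinstance(x, list):
        return [relu(v) for v in x]
    return max(0.0, float(x))


if __name__ == "__main__":
    print(relu([[-1.5, 2.0], [0.0, -3.0]]))
```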

Convolution layer: The Convolution layer computes a 2D (channel, height, and width) convolution, with or without bias.
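To make the "2D convolution with or without bias" semantics concrete, here is a naive reference sketch over a (C, H, W) input with stride 1 and no padding. Like PyTorch, it computes cross-correlation (the kernel is not flipped), which is why trained PyTorch weights can be copied into a TensorRT convolution without rearranging them. `conv2d` is a hypothetical helper for checking results by hand, not the sample's code.

```python
from typing import List, Optional

Tensor3 = List[List[List[float]]]         # (C, H, W)
Weights4 = List[List[List[List[float]]]]  # (K, C, kh, kw)


def conv2d(x: Tensor3, weights: Weights4, bias: Optional[List[float]] = None) -> Tensor3:
    """Valid 2D convolution (stride 1, no padding) producing K output maps.

    If bias is given, bias[k] is added to every value of output map k.
    """
    in_c, in_h, in_w = len(x), len(x[0]), len(x[0][0])
    out_c, kh, kw = len(weights), len(weights[0][0]), len(weights[0][0][0])
    out_h, out_w = in_h - kh + 1, in_w - kw + 1
    out: Tensor3 = []
    for k in range(out_c):
        b = bias[k] if bias is not None else 0.0
        plane = []
        for i in range(out_h):
            row = []
            for j in range(out_w):
                acc = b
                # Sum over all input channels and kernel positions.
                for c in range(in_c):
                    for di in range(kh):
                        for dj in range(kw):
                            acc += x[c][i + di][j + dj] * weights[k][c][di][dj]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out


if __name__ == "__main__":
    image = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]
    kernel = [[[[1.0, 0.0], [0.0, 1.0]]]]  # one output map, one input channel
    print(conv2d(image, kernel))          # without bias
    print(conv2d(image, kernel, [1.0]))   # with bias
```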

MatrixMultiply layer: The MatrixMultiply layer implements a matrix multiplication. (The FullyConnected layer has been deprecated since TensorRT 8.4. To reproduce its bias term, add an ElementWise layer with the SUM operation after the MatrixMultiply layer.)
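The replacement for a FullyConnected layer is y = x · Wᵀ + b: a matrix multiply followed by an element-wise sum that broadcasts the bias over each row. The sketch below emulates both steps in pure Python, assuming the weight uses PyTorch's `nn.Linear` layout (out_features x in_features), hence the transpose. `matmul`, `transpose`, `elementwise_sum`, and `fully_connected` are hypothetical helpers. In the TensorRT Python API, the corresponding calls are `network.add_matrix_multiply` (which can transpose an operand via `trt.MatrixOperation.TRANSPOSE`) and `network.add_elementwise` with `trt.ElementWiseOperation.SUM`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(m x n) @ (n x p) -> (m x p), the operation a MatrixMultiply layer performs."""
    n, p = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(n)) for j in range(p)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def elementwise_sum(a: Matrix, bias: List[float]) -> Matrix:
    """Add a length-p bias to every row of an (m x p) matrix, broadcasting over rows."""
    return [[v + b for v, b in zip(row, bias)] for row in a]


def fully_connected(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Emulate the deprecated FullyConnected layer: matrix multiply, then SUM with bias."""
    return elementwise_sum(matmul(x, transpose(weight)), bias)


if __name__ == "__main__":
    x = [[1.0, 2.0]]                               # batch of 1, 2 input features
    weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 outputs x 2 inputs
    print(fully_connected(x, weight, [0.5, 0.5, 0.5]))
```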

Pooling layer: The Pooling layer implements pooling within a channel. Supported pooling types are maximum, average, and maximum-average blend.
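The three pooling types differ only in how each window is reduced. The sketch below pools one channel with a square window. The blend formula, (1 - blend_factor) * max + blend_factor * average, is our assumption based on the `blend_factor` attribute described for TensorRT's pooling layer (`trt.PoolingType.MAX_AVERAGE_BLEND`); verify it against the documentation for your TensorRT version. `pool2d` is a hypothetical helper.

```python
from typing import List

Plane = List[List[float]]  # one channel, H x W


def pool2d(plane: Plane, window: int, stride: int,
           mode: str = "max", blend_factor: float = 0.0) -> Plane:
    """Pool a single channel; mode is "max", "average", or "max_average_blend"."""
    h, w = len(plane), len(plane[0])
    out: Plane = []
    for i in range(0, h - window + 1, stride):
        row = []
        for j in range(0, w - window + 1, stride):
            vals = [plane[i + di][j + dj] for di in range(window) for dj in range(window)]
            mx, avg = max(vals), sum(vals) / len(vals)
            if mode == "max":
                row.append(mx)
            elif mode == "average":
                row.append(avg)
            elif mode == "max_average_blend":
                # Assumed blend: weight the average by blend_factor, the max by the rest.
                row.append((1.0 - blend_factor) * mx + blend_factor * avg)
            else:
                raise ValueError(f"unknown pooling mode: {mode}")
        out.append(row)
    return out


if __name__ == "__main__":
    plane = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    for mode in ("max", "average", "max_average_blend"):
        print(mode, pool2d(plane, window=2, stride=2, mode=mode, blend_factor=0.5))
```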

Prerequisites

  1. Upgrade pip and install the sample dependencies.
    pip3 install --upgrade pip
    pip3 install -r requirements.txt

To run this sample, you must use Python 3.10 or newer.

On PowerPC systems, you will need to manually install PyTorch using IBM's PowerAI.

  2. Prepare the sample data.

See Preparing sample data in the main samples README.

The MNIST dataset can be found under $TRT_DATADIR/mnist.
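If you script around the sample, the dataset path can be derived from the same environment variable. This is a minimal sketch: `mnist_data_dir` is a hypothetical helper, and the fallback `DEFAULT_DATA_ROOT` is an assumed install location used only when `TRT_DATADIR` is unset, so adjust it to your setup.

```python
import os
from pathlib import Path
from typing import Mapping, Optional

# Assumed fallback location; only used when TRT_DATADIR is not set.
DEFAULT_DATA_ROOT = "/usr/src/tensorrt/data"


def mnist_data_dir(env: Optional[Mapping[str, str]] = None) -> Path:
    """Return $TRT_DATADIR/mnist, falling back to DEFAULT_DATA_ROOT/mnist."""
    env = os.environ if env is None else env
    return Path(env.get("TRT_DATADIR", DEFAULT_DATA_ROOT)) / "mnist"


if __name__ == "__main__":
    path = mnist_data_dir()
    print(path, "(exists)" if path.is_dir() else "(not found)")
```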

Running the sample

  1. Run the sample to create a TensorRT inference engine and run inference:
    python3 sample.py

  2. Verify that the sample ran successfully. If it did, the prediction matches the test case, for example:

    Test Case: 0
    Prediction: 0
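The network produces one score per digit class, and the prediction is the index of the highest score. The sketch below shows that comparison in the output format above, assuming the engine's output has already been copied into a flat list of 10 scores. `predict_digit` and `check_prediction` are hypothetical helpers, not functions from sample.py.

```python
from typing import Sequence


def predict_digit(scores: Sequence[float]) -> int:
    """Return the index of the highest score (argmax); ties go to the lowest index."""
    if not scores:
        raise ValueError("empty output")
    return max(range(len(scores)), key=lambda i: scores[i])


def check_prediction(test_case: int, scores: Sequence[float]) -> bool:
    """Print the test case and prediction, and report whether they match."""
    prediction = predict_digit(scores)
    print(f"Test Case: {test_case}")
    print(f"Prediction: {prediction}")
    return prediction == test_case


if __name__ == "__main__":
    scores = [9.1, 0.2, 0.3, 0.1, 0.0, 0.4, 0.2, 0.1, 0.3, 0.2]
    check_prediction(0, scores)
```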
    

Sample --help options

To see the full list of available options and their descriptions, use the -h or --help command-line option (for example, python3 sample.py --help).

Additional resources

The following resources provide a deeper understanding of getting started with TensorRT using Python:

Model

Dataset

Documentation

License

For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.

Changelog

October 2025: Migrated to strongly typed APIs.

August 2025: Removed support for Python versions < 3.10.

August 2023: Removed support for Python versions < 3.8.

September 2021: Updated the sample to use an explicit batch network definition.

March 2021: Documented the Python version limitations.

February 2019: This README.md file was recreated, updated, and reviewed.

Known issues

This sample only supports Python 3.10 and newer. The torch and torchvision version requirements also constrain which Python versions work.