-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathautocast.py
More file actions
252 lines (208 loc) · 8.11 KB
/
autocast.py
File metadata and controls
252 lines (208 loc) · 8.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
"""This module implements utilities and abstractions for use with
`torch.autocast`, i.e. Automatic Mixed Precision.
Authors
* Sylvain de Langen 2023
* Adel Moumen 2025
"""
import functools
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class AMPConfig:
    """Configuration for automatic mixed precision (AMP).

    Arguments
    ---------
    dtype : torch.dtype
        The dtype to use for AMP.
    """

    dtype: torch.dtype

    @classmethod
    def from_name(cls, name):
        """Create an AMPConfig from a string name.

        Arguments
        ---------
        name : Optional[str]
            The name of the AMPConfig to create. Must be one of `fp32`,
            `fp16`, or `bf16`. `None` is treated the same as `fp32`.

        Returns
        -------
        AMPConfig
            The AMPConfig corresponding to the name.

        Raises
        ------
        ValueError
            If `name` is not one of the accepted precision names.
        """
        # NOTE: first parameter renamed `self` -> `cls` (it is a
        # classmethod) and the constructor call routed through `cls`
        # so that subclasses of AMPConfig get instances of themselves.
        if name is None or name == "fp32":
            return cls(torch.float32)
        elif name == "fp16":
            return cls(torch.float16)
        elif name == "bf16":
            return cls(torch.bfloat16)
        else:
            raise ValueError(
                f"Specified autocast mode ({name}) incorrect, expected one of `fp32`, `fp16`, `bf16`."
            )
class TorchAutocast:
    """Context manager that turns ``torch.autocast`` on only when needed.

    When the requested ``dtype`` is anything other than ``torch.float32``,
    entering this manager enters ``torch.autocast`` with the given
    arguments. For float32 (or when no ``dtype`` keyword is supplied) it
    degrades to a no-op context, since full precision needs no autocasting.

    Parameters
    ----------
    *args : tuple
        Positional arguments forwarded to ``torch.autocast``.
        See the PyTorch documentation:
        https://pytorch.org/docs/stable/amp.html#torch.autocast
    **kwargs : dict
        Keyword arguments forwarded to ``torch.autocast``; usually this
        includes the ``dtype`` argument selecting the target precision.
        See the PyTorch documentation for more details.
    """

    def __init__(self, *args, **kwargs):
        requested = kwargs.get("dtype", torch.float32)
        if requested == torch.float32:
            # Full precision requested: behave as a no-op context manager.
            self.context = nullcontext()
        else:
            self.context = torch.autocast(*args, **kwargs)

    def __enter__(self):
        """Enter the wrapped context.

        Returns
        -------
        context
            Whatever the underlying context manager's ``__enter__`` returns.

        Raises
        ------
        RuntimeError
            Re-raised with dtype/device diagnostics appended when the
            wrapped autocast context exposes ``device`` and ``fast_dtype``
            attributes; propagated unchanged otherwise.
        """
        try:
            return self.context.__enter__()
        except RuntimeError as err:
            ctx = self.context
            # Only autocast contexts carry these diagnostic attributes;
            # a nullcontext (or anything else) re-raises untouched.
            if not (hasattr(ctx, "device") and hasattr(ctx, "fast_dtype")):
                raise
            raise RuntimeError(
                f"Error during autocasting with dtype={ctx.fast_dtype} on device={ctx.device}.\n"
            ) from err

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped context.

        Parameters
        ----------
        exc_type : type
            Exception type if an exception occurred, otherwise None.
        exc_val : Exception
            Exception instance if an exception occurred, otherwise None.
        exc_tb : traceback
            Traceback object if an exception occurred, otherwise None.

        Returns
        -------
        bool or None
            Whatever the underlying context manager's ``__exit__`` returns.
        """
        return self.context.__exit__(exc_type, exc_val, exc_tb)
def _infer_device_type(*args, **kwargs):
"""Infer device type from the input tensors.
This function returns the device type of the first tensor found in the
arguments or keyword arguments. It assumes all tensors are on the same
device, which is typically the case in PyTorch operations.
Arguments
---------
*args: tuple
Arguments that may contain tensors
**kwargs: dict
Keyword arguments that may contain tensors
Returns
-------
str
Device type ('cuda', 'cpu', 'mps', etc.)
"""
# Check args for tensors
for arg in args:
if isinstance(arg, torch.Tensor):
return arg.device.type
# Check kwargs for tensors
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device.type
# Default to cpu if no tensors found
return "cpu"
def fwd_default_precision(
    fwd: Optional[Callable] = None,
    cast_inputs: Optional[torch.dtype] = torch.float32,
) -> Callable:
    """Decorator for forward methods which, by default, *disables* autocast
    and casts any floating-point tensor parameters into the specified dtype
    (much like `torch.amp.custom_fwd`).

    The *wrapped forward* will gain an additional `force_allow_autocast` keyword
    parameter.
    When set to `True`, the function will ignore `cast_inputs` and will not
    disable autocast, as if this decorator was not specified.
    (Thus, modules can specify a default recommended precision, and users can
    override that behavior when desired.)

    This decorator now supports both CPU and CUDA by using `torch.amp.custom_fwd`
    with the device_type inferred from input tensors at runtime.

    When autocast is *not* active, this decorator does not change any behavior.

    Arguments
    ---------
    fwd: Optional[Callable]
        The function to wrap. If omitted, returns a partial application of the
        decorator, e.g. allowing
        `new_decorator = fwd_default_precision(cast_inputs=torch.float32)`.
        Reminder: If you are decorating a function directly, this argument is
        already specified implicitly.
    cast_inputs: Optional[torch.dtype]
        If not `None` (the default being `torch.float32`), then any
        floating-point inputs to the wrapped function will be cast to the
        specified type.
        Note: When autocasting is enabled, output tensors of autocast-compatible
        operations may be of the autocast data type.
        Disabling autocast *without* casting inputs will not change this fact,
        so lower precision operations can happen even inside of an
        autocast-disabled region, which this argument helps avoid if desired.

    Returns
    -------
    The wrapped function
    """
    # Bare-decorator support: `@fwd_default_precision(cast_inputs=...)`
    # calls us without `fwd`, so return a partial awaiting the function.
    if fwd is None:
        return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)

    # Cache of `torch.amp.custom_fwd`-wrapped variants, keyed by device
    # type ('cpu', 'cuda', ...). Built lazily because the device is only
    # known at call time, from the actual input tensors.
    wrapped_cache = {}

    def get_wrapped_fwd(device_type: str) -> Callable:
        """Get or create a wrapped function for the given device type."""
        if device_type not in wrapped_cache:
            wrapped_cache[device_type] = torch.amp.custom_fwd(
                fwd, device_type=device_type, cast_inputs=cast_inputs
            )
        return wrapped_cache[device_type]

    @functools.wraps(fwd)
    def wrapper(*args, force_allow_autocast: bool = False, **kwargs):
        """Wrapped forward function from fwd_default_precision.

        Arguments
        ---------
        *args: tuple
            Arguments to be forwarded to the unwrapped function.
        force_allow_autocast: bool
            When `True`, the wrapped function will be executed directly with no
            change to the autocast context and no input casting.
        **kwargs: dict
            Arguments to be forwarded to the unwrapped function.

        Returns
        -------
        The result of the unwrapped `fwd` if force_allow_autocast, else the
        result of the `custom_fwd`-wrapped variant for the inferred device.
        """
        if force_allow_autocast:
            # Bypass entirely: run under whatever autocast state is active.
            return fwd(*args, **kwargs)
        else:
            # Infer device type from input tensors, then dispatch to the
            # cached `custom_fwd` wrapper for that device.
            device_type = _infer_device_type(*args, **kwargs)
            wrapped_fwd = get_wrapped_fwd(device_type)
            return wrapped_fwd(*args, **kwargs)

    return wrapper