forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
43 lines (35 loc) · 1.61 KB
/
utils.py
File metadata and controls
43 lines (35 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.utils import log_dist
# helper function to map between DS policies and DS containers
def policy_to_ds_container(**kwargs):
    """Build the DeepSpeed container that corresponds to the given policy.

    The lookup is keyed on the *exact* type of ``policy`` (``type(policy)``),
    so an instance of a subclass of a registered policy is not matched.

    Args:
        **kwargs: Must include ``policy``, the policy instance to look up.
            Every keyword argument (``policy`` included) is forwarded
            unchanged to the container constructor.

    Returns:
        A new container instance for the policy's type, or ``None`` when the
        type has no registered container (a message is logged in that case).

    Raises:
        KeyError: If ``policy`` is not supplied.
        AssertionError: If ``policy`` is ``None``.
    """
    # Validate the input before doing anything else so bad calls fail fast.
    policy = kwargs['policy']
    if policy is None:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type and message are unchanged.
        raise AssertionError("Policy cannot be None")

    # Imported at call time rather than module load.
    # NOTE(review): presumably to avoid a circular import with .containers - confirm.
    from .containers import HFGPT2LayerPolicy, DS_GPT2Container
    from .containers import HFBertLayerPolicy, DS_BERTContainer
    from .containers import BLOOMLayerPolicy, DS_BloomContainer
    from .containers import HFGPTJLayerPolicy, DS_GPTJContainer
    from .containers import HFGPTNEOLayerPolicy, DS_GPTNEOContainer
    from .containers import GPTNEOXLayerPolicy, DS_GPTNEOXContainer
    from .containers import HFOPTLayerPolicy, DS_OPTContainer
    from .containers import MegatronLayerPolicy, DS_MegatronGPTContainer
    from .containers import HFDistilBertLayerPolicy, DS_DistilBERTContainer

    policy_to_container = {
        HFGPT2LayerPolicy: DS_GPT2Container,
        HFBertLayerPolicy: DS_BERTContainer,
        BLOOMLayerPolicy: DS_BloomContainer,
        HFGPTJLayerPolicy: DS_GPTJContainer,
        HFGPTNEOLayerPolicy: DS_GPTNEOContainer,
        GPTNEOXLayerPolicy: DS_GPTNEOXContainer,
        HFOPTLayerPolicy: DS_OPTContainer,
        MegatronLayerPolicy: DS_MegatronGPTContainer,
        HFDistilBertLayerPolicy: DS_DistilBERTContainer,
    }

    policy_type = type(policy)
    container_cls = policy_to_container.get(policy_type)
    if container_cls is None:
        # Unknown policy types are reported, not raised; the caller gets None.
        # NOTE(review): the [0] argument is presumably the list of ranks that log - confirm.
        log_dist(f"Policy type {policy_type} not supported", [0])
        return None

    return container_cls(**kwargs)