forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnormalization.py
More file actions
115 lines (92 loc) · 4.05 KB
/
normalization.py
File metadata and controls
115 lines (92 loc) · 4.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The timestep index is looked up in a learned embedding table, passed through SiLU and a linear projection of
    width `2 * embedding_dim`, and split into a scale and a shift. These modulate the output of a LayerNorm that
    has no learnable affine parameters of its own.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize `x` and modulate it with a timestep-dependent scale and shift.

        Args:
            x (`torch.Tensor`): Input whose last dimension is `embedding_dim`.
            timestep (`torch.Tensor`):
                Integer index into the embedding table. Either a 0-dim tensor (one timestep shared by the whole
                input) or a 1-D tensor of shape `(batch,)` with one timestep per sample of `x`.

        Returns:
            `torch.Tensor`: Tensor with the same shape as `x`.
        """
        emb = self.linear(self.silu(self.emb(timestep)))

        # Always split the feature axis. The default `dim=0` only coincides with the feature axis for a 0-dim
        # timestep; with a batched timestep it would split the batch instead.
        scale, shift = torch.chunk(emb, 2, dim=-1)

        if emb.dim() == 2 and x.dim() > 2:
            # Batched timesteps: (batch, dim) -> (batch, 1, ..., 1, dim) so each sample's modulation broadcasts
            # over the intermediate axes of `x`.
            broadcast_shape = (emb.shape[0],) + (1,) * (x.dim() - 2) + (-1,)
            scale = scale.reshape(broadcast_shape)
            shift = shift.reshape(broadcast_shape)

        return self.norm(x) * (1 + scale) + shift
class AdaLayerNormZero(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero) layer.

    A combined timestep / class-label embedding is passed through SiLU and projected to `6 * embedding_dim`
    features, which are split into six modulation tensors. The `_msa` shift and scale are applied to the
    layer-normalized input here; the `_msa` gate and the three `_mlp` tensors are handed back to the caller.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()

        modulation_dim = 6 * embedding_dim

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, modulation_dim, bias=True)
        # The affine terms come from the conditioning, so the norm itself carries no learnable parameters.
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Normalize and modulate `x`, returning it together with the remaining modulation tensors.

        Returns:
            A 5-tuple `(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`.
        """
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(conditioning))

        # Six equal slices along the feature axis (dim 1).
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(modulation, 6, dim=1)

        # Insert an axis at position 1 so the per-sample shift/scale broadcast against `x`.
        normed = self.norm(x)
        x = normed * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)

        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    The embedding is (optionally) passed through an activation, projected to `2 * out_dim` features and split
    into a per-channel scale and shift that modulate the group-normalized input.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        out_dim (`int`):
            The number of scale (and shift) values produced per sample; expected to match the channel dimension
            of the tensor being normalized.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        self.act = None if act_fn is None else get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        r"""
        Group-normalize `x` and modulate it with a scale and shift derived from `emb`.

        Args:
            x (`torch.Tensor`):
                Input to normalize. The modulation is broadcast as `(batch, out_dim, 1, 1)`, so a 4-D
                channels-first tensor is expected.
            emb (`torch.Tensor`): Conditioning embedding of shape `(batch, embedding_dim)`.

        Returns:
            `torch.Tensor`: The normalized and modulated tensor.
        """
        # Explicit `None` check: module truthiness can be surprising (e.g. an empty container module is falsy).
        if self.act is not None:
            emb = self.act(emb)

        # (batch, 2 * out_dim) -> (batch, 2 * out_dim, 1, 1), then split the channel axis into scale and shift.
        emb = self.linear(emb)[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        # No affine weight/bias is passed: the learned modulation above takes their place.
        x = F.group_norm(x, self.num_groups, eps=self.eps)
        return x * (1 + scale) + shift