Skip to content

Commit 4fb44bd

Browse files
Promiseryyiyixuxu
andauthored
Fix wrong param types, docs, and handles noise=None in scale_noise of FlowMatching schedulers (huggingface#11669)
* Bug: Fix wrong params, docs, and handles noise=None * make noise a required arg --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent b7a8158 commit 4fb44bd

3 files changed

Lines changed: 16 additions & 9 deletions

File tree

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,19 @@ def set_shift(self, shift: float):
171171
def scale_noise(
172172
self,
173173
sample: torch.FloatTensor,
174-
timestep: Union[float, torch.FloatTensor],
175-
noise: Optional[torch.FloatTensor] = None,
174+
timestep: torch.FloatTensor,
175+
noise: torch.FloatTensor,
176176
) -> torch.FloatTensor:
177177
"""
178178
Forward process in flow-matching
179179
180180
Args:
181181
sample (`torch.FloatTensor`):
182182
The input sample.
183-
timestep (`int`, *optional*):
183+
timestep (`torch.FloatTensor`):
184184
The current timestep in the diffusion chain.
185+
noise (`torch.FloatTensor`):
186+
The noise tensor.
185187
186188
Returns:
187189
`torch.FloatTensor`:

src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,19 @@ def set_begin_index(self, begin_index: int = 0):
110110
def scale_noise(
111111
self,
112112
sample: torch.FloatTensor,
113-
timestep: Union[float, torch.FloatTensor],
114-
noise: Optional[torch.FloatTensor] = None,
113+
timestep: torch.FloatTensor,
114+
noise: torch.FloatTensor,
115115
) -> torch.FloatTensor:
116116
"""
117117
Forward process in flow-matching
118118
119119
Args:
120120
sample (`torch.FloatTensor`):
121121
The input sample.
122-
timestep (`int`, *optional*):
122+
timestep (`torch.FloatTensor`):
123123
The current timestep in the diffusion chain.
124+
noise (`torch.FloatTensor`):
125+
The noise tensor.
124126
125127
Returns:
126128
`torch.FloatTensor`:
@@ -130,6 +132,7 @@ def scale_noise(
130132
self._init_step_index(timestep)
131133

132134
sigma = self.sigmas[self.step_index]
135+
133136
sample = sigma * noise + (1.0 - sigma) * sample
134137

135138
return sample

src/diffusers/schedulers/scheduling_flow_match_lcm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,19 @@ def set_scale_factors(self, scale_factors: list, upscale_mode):
192192
def scale_noise(
193193
self,
194194
sample: torch.FloatTensor,
195-
timestep: Union[float, torch.FloatTensor],
196-
noise: Optional[torch.FloatTensor] = None,
195+
timestep: torch.FloatTensor,
196+
noise: torch.FloatTensor,
197197
) -> torch.FloatTensor:
198198
"""
199199
Forward process in flow-matching
200200
201201
Args:
202202
sample (`torch.FloatTensor`):
203203
The input sample.
204-
timestep (`int`, *optional*):
204+
timestep (`torch.FloatTensor`):
205205
The current timestep in the diffusion chain.
206+
noise (`torch.FloatTensor`):
207+
The noise tensor.
206208
207209
Returns:
208210
`torch.FloatTensor`:

0 commit comments

Comments
 (0)