We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 93983b6 commit 2c25b98Copy full SHA for 2c25b98
1 file changed
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -260,7 +260,6 @@ def encode_prompt(
260
padding="max_length",
261
return_tensors="pt",
262
)
263
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
264
text_input_ids = text_inputs["input_ids"]
265
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
266
@@ -273,6 +272,7 @@ def encode_prompt(
273
272
f" {max_length} tokens: {removed_text}"
274
275
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
276
prompt_embeds = self.text_encoder(**text_inputs)[0]
277
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
278
prompt_embeds = prompt_embeds * prompt_attention_mask
0 commit comments