Skip to content

Commit a844065

Browse files
authored
replace references to deprecated KeyArray & PRNGKeyArray (huggingface#5324)
1 parent 35952e6 commit a844065

15 files changed

Lines changed: 28 additions & 26 deletions

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@
102102
"importlib_metadata",
103103
"invisible-watermark>=0.2.0",
104104
"isort>=5.5.4",
105-
"jax>=0.2.8,!=0.3.2",
106-
"jaxlib>=0.1.65",
105+
"jax>=0.4.1",
106+
"jaxlib>=0.4.1",
107107
"Jinja2",
108108
"k-diffusion>=0.0.12",
109109
"torchsde",

src/diffusers/dependency_versions_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
"importlib_metadata": "importlib_metadata",
1616
"invisible-watermark": "invisible-watermark>=0.2.0",
1717
"isort": "isort>=5.5.4",
18-
"jax": "jax>=0.2.8,!=0.3.2",
19-
"jaxlib": "jaxlib>=0.1.65",
18+
"jax": "jax>=0.4.1",
19+
"jaxlib": "jaxlib>=0.4.1",
2020
"Jinja2": "Jinja2",
2121
"k-diffusion": "k-diffusion>=0.0.12",
2222
"torchsde": "torchsde",

src/diffusers/models/controlnet_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
168168
controlnet_conditioning_channel_order: str = "rgb"
169169
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170170

171-
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
171+
def init_weights(self, rng: jax.Array) -> FrozenDict:
172172
# init input tensors
173173
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174174
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

src/diffusers/models/modeling_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
192192
```"""
193193
return self._cast_floating_to(params, jnp.float16, mask)
194194

195-
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
195+
def init_weights(self, rng: jax.Array) -> Dict:
196196
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
197197

198198
@classmethod

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
126126
addition_embed_type_num_heads: int = 64
127127
projection_class_embeddings_input_dim: Optional[int] = None
128128

129-
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
129+
def init_weights(self, rng: jax.Array) -> FrozenDict:
130130
# init input tensors
131131
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
132132
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

src/diffusers/models/vae_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def setup(self):
817817
dtype=self.dtype,
818818
)
819819

820-
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
820+
def init_weights(self, rng: jax.Array) -> FrozenDict:
821821
# init input tensors
822822
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
823823
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _generate(
241241
prompt_ids: jnp.array,
242242
image: jnp.array,
243243
params: Union[Dict, FrozenDict],
244-
prng_seed: jax.random.KeyArray,
244+
prng_seed: jax.Array,
245245
num_inference_steps: int,
246246
guidance_scale: float,
247247
latents: Optional[jnp.array] = None,
@@ -351,7 +351,7 @@ def __call__(
351351
prompt_ids: jnp.array,
352352
image: jnp.array,
353353
params: Union[Dict, FrozenDict],
354-
prng_seed: jax.random.KeyArray,
354+
prng_seed: jax.Array,
355355
num_inference_steps: int = 50,
356356
guidance_scale: Union[float, jnp.array] = 7.5,
357357
latents: jnp.array = None,
@@ -370,7 +370,7 @@ def __call__(
370370
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
371371
params (`Dict` or `FrozenDict`):
372372
Dictionary containing the model parameters/weights.
373-
prng_seed (`jax.random.KeyArray` or `jax.Array`):
373+
prng_seed (`jax.Array` or `jax.Array`):
374374
Array containing random number generator key.
375375
num_inference_steps (`int`, *optional*, defaults to 50):
376376
The number of denoising steps. More denoising steps usually lead to a higher quality image at the

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _generate(
215215
self,
216216
prompt_ids: jnp.array,
217217
params: Union[Dict, FrozenDict],
218-
prng_seed: jax.random.KeyArray,
218+
prng_seed: jax.Array,
219219
num_inference_steps: int,
220220
height: int,
221221
width: int,
@@ -312,7 +312,7 @@ def __call__(
312312
self,
313313
prompt_ids: jnp.array,
314314
params: Union[Dict, FrozenDict],
315-
prng_seed: jax.random.KeyArray,
315+
prng_seed: jax.Array,
316316
num_inference_steps: int = 50,
317317
height: Optional[int] = None,
318318
width: Optional[int] = None,

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def _generate(
235235
prompt_ids: jnp.array,
236236
image: jnp.array,
237237
params: Union[Dict, FrozenDict],
238-
prng_seed: jax.random.KeyArray,
238+
prng_seed: jax.Array,
239239
start_timestep: int,
240240
num_inference_steps: int,
241241
height: int,
@@ -340,7 +340,7 @@ def __call__(
340340
prompt_ids: jnp.array,
341341
image: jnp.array,
342342
params: Union[Dict, FrozenDict],
343-
prng_seed: jax.random.KeyArray,
343+
prng_seed: jax.Array,
344344
strength: float = 0.8,
345345
num_inference_steps: int = 50,
346346
height: Optional[int] = None,
@@ -361,7 +361,7 @@ def __call__(
361361
Array representing an image batch to be used as the starting point.
362362
params (`Dict` or `FrozenDict`):
363363
Dictionary containing the model parameters/weights.
364-
prng_seed (`jax.random.KeyArray` or `jax.Array`):
364+
prng_seed (`jax.Array` or `jax.Array`):
365365
Array containing random number generator key.
366366
strength (`float`, *optional*, defaults to 0.8):
367367
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def _generate(
270270
mask: jnp.array,
271271
masked_image: jnp.array,
272272
params: Union[Dict, FrozenDict],
273-
prng_seed: jax.random.KeyArray,
273+
prng_seed: jax.Array,
274274
num_inference_steps: int,
275275
height: int,
276276
width: int,
@@ -398,7 +398,7 @@ def __call__(
398398
mask: jnp.array,
399399
masked_image: jnp.array,
400400
params: Union[Dict, FrozenDict],
401-
prng_seed: jax.random.KeyArray,
401+
prng_seed: jax.Array,
402402
num_inference_steps: int = 50,
403403
height: Optional[int] = None,
404404
width: Optional[int] = None,

0 commit comments

Comments
 (0)