@@ -241,7 +241,7 @@ def _generate(
241241 prompt_ids : jnp .array ,
242242 image : jnp .array ,
243243 params : Union [Dict , FrozenDict ],
244- prng_seed : jax .random . KeyArray ,
244+ prng_seed : jax .Array ,
245245 num_inference_steps : int ,
246246 guidance_scale : float ,
247247 latents : Optional [jnp .array ] = None ,
@@ -351,7 +351,7 @@ def __call__(
351351 prompt_ids : jnp .array ,
352352 image : jnp .array ,
353353 params : Union [Dict , FrozenDict ],
354- prng_seed : jax .random . KeyArray ,
354+ prng_seed : jax .Array ,
355355 num_inference_steps : int = 50 ,
356356 guidance_scale : Union [float , jnp .array ] = 7.5 ,
357357 latents : jnp .array = None ,
@@ -370,7 +370,7 @@ def __call__(
370370 Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
371371 params (`Dict` or `FrozenDict`):
372372 Dictionary containing the model parameters/weights.
373- prng_seed (`jax.random.KeyArray ` or `jax.Array`):
373+ prng_seed (`jax.Array ` or `jax.Array`):
374374 Array containing random number generator key.
375375 num_inference_steps (`int`, *optional*, defaults to 50):
376376 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
0 commit comments