Skip to content

Commit 18b7264

Browse files
[Utils] Correct custom init sort (huggingface#4967)
* [Utils] Correct custom init sort * [Utils] Correct custom init sort * [Utils] Correct custom init sort * add type checking * fix custom init sort * fix test * fix tests --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent d82157b commit 18b7264

39 files changed

Lines changed: 1371 additions & 498 deletions

File tree

src/diffusers/models/__init__.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import TYPE_CHECKING
16+
1517
from ..utils import _LazyModule, is_flax_available, is_torch_available
1618

1719

@@ -40,7 +42,32 @@
4042
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
4143
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
4244

43-
import sys
4445

46+
if TYPE_CHECKING:
47+
if is_torch_available():
48+
from .adapter import MultiAdapter, T2IAdapter
49+
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
50+
from .autoencoder_kl import AutoencoderKL
51+
from .autoencoder_tiny import AutoencoderTiny
52+
from .controlnet import ControlNetModel
53+
from .dual_transformer_2d import DualTransformer2DModel
54+
from .modeling_utils import ModelMixin
55+
from .prior_transformer import PriorTransformer
56+
from .t5_film_transformer import T5FilmDecoder
57+
from .transformer_2d import Transformer2DModel
58+
from .transformer_temporal import TransformerTemporalModel
59+
from .unet_1d import UNet1DModel
60+
from .unet_2d import UNet2DModel
61+
from .unet_2d_condition import UNet2DConditionModel
62+
from .unet_3d_condition import UNet3DConditionModel
63+
from .vq_model import VQModel
64+
65+
if is_flax_available():
66+
from .controlnet_flax import FlaxControlNetModel
67+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
68+
from .vae_flax import FlaxAutoencoderKL
69+
70+
else:
71+
import sys
4572

46-
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
73+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

0 commit comments

Comments
 (0)