Skip to content

Commit c33fc7e

Browse files
authored
Merge branch 'huggingface:main' into main
2 parents dcf3b13 + 4737806 commit c33fc7e

3 files changed

Lines changed: 24 additions & 12 deletions

File tree

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@
162162
"default_subfolder": "transformer",
163163
},
164164
"QwenImageTransformer2DModel": {
165-
"checkpoint_mapping_fn": lambda x: x,
165+
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
166166
"default_subfolder": "transformer",
167167
},
168168
"Flux2Transformer2DModel": {

src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@
120120
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
121121
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
122122
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
123-
"z-image-turbo": "cap_embedder.0.weight",
123+
"z-image-turbo": [
124+
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
125+
"layers.0.adaLN_modulation.0.weight",
126+
],
124127
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
125128
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
126129
"sana": [
@@ -727,10 +730,7 @@ def infer_diffusers_model_type(checkpoint):
727730
):
728731
model_type = "instruct-pix2pix"
729732

730-
elif (
731-
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
732-
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
733-
):
733+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
734734
model_type = "z-image-turbo"
735735

736736
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -3852,6 +3852,7 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
38523852
".attention.k_norm.weight": ".attention.norm_k.weight",
38533853
".attention.q_norm.weight": ".attention.norm_q.weight",
38543854
".attention.out.weight": ".attention.to_out.0.weight",
3855+
"model.diffusion_model.": "",
38553856
}
38563857

38573858
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3886,6 +3887,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
38863887

38873888
update_state_dict(converted_state_dict, key, new_key)
38883889

3890+
if "norm_final.weight" in converted_state_dict.keys():
3891+
_ = converted_state_dict.pop("norm_final.weight")
3892+
38893893
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
38903894
# special_keys_remap
38913895
for key in list(converted_state_dict.keys()):

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from ..base import DiffusersQuantizer
3737

3838

39+
logger = logging.get_logger(__name__)
40+
41+
3942
if TYPE_CHECKING:
4043
from ...models.modeling_utils import ModelMixin
4144

@@ -83,11 +86,19 @@ def _update_torch_safe_globals():
8386
]
8487
try:
8588
from torchao.dtypes import NF4Tensor
86-
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
87-
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
8889
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
8990

90-
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
91+
safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
92+
93+
# note: is_torchao_version(">=", "0.16.0") does not work correctly
94+
# with torchao nightly, so using a ">" check which does work correctly
95+
if is_torchao_version(">", "0.15.0"):
96+
pass
97+
else:
98+
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
99+
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
100+
101+
safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
91102

92103
except (ImportError, ModuleNotFoundError) as e:
93104
logger.warning(
@@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
123134
return None
124135

125136

126-
logger = logging.get_logger(__name__)
127-
128-
129137
def _quantization_type(weight):
130138
from torchao.dtypes import AffineQuantizedTensor
131139
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

0 commit comments

Comments
 (0)