forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimage_processor.py
More file actions
185 lines (162 loc) · 8.21 KB
/
image_processor.py
File metadata and controls
185 lines (162 loc) · 8.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from ...configuration_utils import register_to_config
from ...image_processor import VaeImageProcessor
from ...utils import PIL_INTERPOLATION
class WanAnimateImageProcessor(VaeImageProcessor):
r"""
Image processor to preprocess the reference (character) image for the Wan Animate model.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
this factor.
vae_latent_channels (`int`, *optional*, defaults to `16`):
VAE latent channels.
spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`):
The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2).
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to grayscale format.
fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `None`):
An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that
color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid;
if `None`, will default to filling with data from `image`.
"""
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
vae_latent_channels: int = 16,
spatial_patch_size: Tuple[int, int] = (2, 2),
resample: str = "lanczos",
reducing_gap: int = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0,
):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
def _resize_and_fill(
self,
image: PIL.Image.Image,
width: int,
height: int,
) -> PIL.Image.Image:
r"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling empty with data from image.
Args:
image (`PIL.Image.Image`):
The image to resize and fill.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and filled image.
"""
ratio = width / height
src_ratio = image.width / image.height
fill_with_image_data = self.config.fill_color is None
fill_color = self.config.fill_color or 0
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = PIL.Image.new("RGB", (width, height), color=fill_color)
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if fill_with_image_data:
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
if fill_height > 0:
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
if fill_width > 0:
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
box=(fill_width + src_w, 0),
)
return res
def get_default_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor * spatial_patch_size`.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
max_area = width * height
aspect_ratio = height / width
mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0]
mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1]
# Try to preserve the aspect ratio
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w
return height, width