Skip to content

Commit 052bf32

Browse files
authored
Fix AutoencoderTiny encoder scaling convention (huggingface#4682)
* Fix AutoencoderTiny encoder scaling convention * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny (i.e. immediately after the final conv, as early as possible) * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward * Update AutoencoderTinyIntegrationTests to protect against scaling issues. The new test constructs a simple image, round-trips it through AutoencoderTiny, and confirms the decoded result is approximately equal to the source image. This test checks behavior with and without tiling enabled. This test will fail if new AutoencoderTiny scaling issues are introduced. * Context: Raw TAESD weights expect images in [0, 1], but diffusers' convention represents images with zero-centered values in [-1, 1], so AutoencoderTiny needs to scale / unscale images at the start of encoding and at the end of decoding in order to work with diffusers. * Re-add existing AutoencoderTiny test, update golden values * Add comments to AutoencoderTiny.forward
1 parent 80871ac commit 052bf32

3 files changed

Lines changed: 35 additions & 7 deletions

File tree

src/diffusers/models/autoencoder_tiny.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,6 @@ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
312312
output = torch.cat(output)
313313
else:
314314
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
315-
# Refer to the following discussion to know why this is needed.
316-
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
317-
output = output.mul_(2).sub_(1)
318315

319316
if not return_dict:
320317
return (output,)
@@ -333,8 +330,15 @@ def forward(
333330
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
334331
"""
335332
enc = self.encode(sample).latents
333+
334+
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
335+
# as if we were storing the latents in an RGBA uint8 image.
336336
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
337-
unscaled_enc = self.unscale_latents(scaled_enc)
337+
338+
# unquantize latents back into [0, 1], then unscale latents back to their original range,
339+
# as if we were loading the latents from an RGBA uint8 image.
340+
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
341+
338342
dec = self.decode(unscaled_enc)
339343

340344
if not return_dict:

src/diffusers/models/vae.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,8 @@ def custom_forward(*inputs):
732732
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
733733

734734
else:
735-
x = self.layers(x)
735+
# scale image from [-1, 1] to [0, 1] to match TAESD convention
736+
x = self.layers(x.add(1).div(2))
736737

737738
return x
738739

@@ -790,4 +791,5 @@ def custom_forward(*inputs):
790791
else:
791792
x = self.layers(x)
792793

793-
return x
794+
# scale image from [0, 1] to [-1, 1] to match diffusers convention
795+
return x.mul(2).sub(1)

tests/models/test_models_vae.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,32 @@ def test_stable_diffusion(self):
312312
assert sample.shape == image.shape
313313

314314
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
315-
expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
315+
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
316316

317317
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
318318

319+
@parameterized.expand([(True,), (False,)])
320+
def test_tae_roundtrip(self, enable_tiling):
321+
# load the autoencoder
322+
model = self.get_sd_vae_model()
323+
if enable_tiling:
324+
model.enable_tiling()
325+
326+
# make a black image with a white square in the middle,
327+
# which is large enough to split across multiple tiles
328+
image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
329+
image[..., 256:768, 256:768] = 1.0
330+
331+
# round-trip the image through the autoencoder
332+
with torch.no_grad():
333+
sample = model(image).sample
334+
335+
# the autoencoder reconstruction should match original image, sorta
336+
def downscale(x):
337+
return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
338+
339+
assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
340+
319341

320342
@slow
321343
class AutoencoderKLIntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)