
StableDiffusionInpaintPipeline should handle batches #1000

@vict0rsch

Description


I'd like to run batched inference with StableDiffusionInpaintPipeline. Most of the code in pipeline_stable_diffusion_inpaint.py already handles batches. The exception is prepare_mask_and_masked_image(image, mask), which only accepts a single PIL.Image for both the image and the mask.
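With the current function, batching means repeating its per-image preprocessing yourself and stacking the results along the batch dimension. Here is a minimal sketch of that. It assumes every image and mask has the same size, and prepare_batch is a hypothetical helper, not part of diffusers:

```python
import numpy as np
import torch
from PIL import Image


def prepare_batch(images, masks):
    """Hypothetical helper (not part of diffusers): apply the same per-image
    preprocessing as prepare_mask_and_masked_image to lists of PIL images and
    masks, stacking them into batched tensors. Assumes all inputs share one size."""
    if len(images) != len(masks):
        raise ValueError(f"got {len(images)} images but {len(masks)} masks")

    # (B, H, W, 3) uint8 -> (B, 3, H, W) float32 scaled to [-1, 1]
    image = np.stack([np.array(im.convert("RGB")) for im in images])
    image = image.transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # (B, H, W) uint8 -> (B, 1, H, W) float32, binarized at 0.5
    mask = np.stack([np.array(m.convert("L")) for m in masks])
    mask = mask.astype(np.float32) / 255.0
    mask = (mask[:, None] >= 0.5).astype(np.float32)
    mask = torch.from_numpy(mask)

    # Keep pixels where mask == 0; pixels to be inpainted become 0
    masked_image = image * (mask < 0.5)
    return mask, masked_image
```

Under the proposal, the end user would be expected to supply tensors with this layout.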

Describe the solution you'd like
A simple fix would be to leave tensors untouched and assume the end user has already preprocessed them:

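import numpy as np
import torch


def prepare_mask_and_masked_image(image, mask):
    # Tensors are used as-is: the caller is responsible for preprocessing them like
    # the PIL path below, i.e. image of shape (B, 3, H, W) scaled to [-1, 1] and
    # mask of shape (B, 1, H, W) with values in {0, 1}.
    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
        # Return (mask, masked_image) in the same order as the PIL path
        masked_image = image * (mask < 0.5)
        return mask, masked_image
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")

    # PIL path (single image): HWC uint8 -> (1, 3, H, W) float32 in [-1, 1]
    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # (1, 1, H, W) float32 mask, binarized at 0.5
    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Keep pixels where mask == 0; pixels to be inpainted become 0
    masked_image = image * (mask < 0.5)

    return mask, masked_image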

We could also add checks on shape and dtype, but those are "nice to have" rather than "need to have".
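If those checks were added, they could look something like the sketch below. It assumes the layout the PIL branch produces: an image of shape (B, 3, H, W) and a mask of shape (B, 1, H, W), both floating point. check_image_and_mask is a hypothetical name, and the sketch covers only shape and dtype, not value ranges:

```python
import torch


def check_image_and_mask(image: torch.Tensor, mask: torch.Tensor) -> None:
    """Hypothetical validation for the tensor branch (not part of diffusers).

    Expects the same layout the PIL branch produces:
    image (B, 3, H, W) and mask (B, 1, H, W), both floating point.
    """
    # Rank and channel count
    if image.ndim != 4 or image.shape[1] != 3:
        raise ValueError(f"image must have shape (B, 3, H, W), got {tuple(image.shape)}")
    if mask.ndim != 4 or mask.shape[1] != 1:
        raise ValueError(f"mask must have shape (B, 1, H, W), got {tuple(mask.shape)}")
    # Batch and spatial sizes must agree so image * (mask < 0.5) pairs each image with its mask
    if image.shape[0] != mask.shape[0] or image.shape[2:] != mask.shape[2:]:
        raise ValueError(
            f"batch/spatial size mismatch: image {tuple(image.shape)}, mask {tuple(mask.shape)}"
        )
    # dtype: the PIL branch yields float32, so reject integer or bool tensors
    if not image.is_floating_point() or not mask.is_floating_point():
        raise TypeError(f"expected floating-point tensors, got {image.dtype} and {mask.dtype}")
```

A call like this would go right after the isinstance checks in the tensor branch.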

Metadata

Assignees: no one assigned
Labels: stale (issues that haven't received updates)
Type: no type
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests