diff --git a/.gitignore b/.gitignore
index cf8183463613..0dfcf97ce70e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,4 +163,7 @@ tags
*.lock
# DS_Store (MacOS)
-.DS_Store
\ No newline at end of file
+.DS_Store
+
+log
+/*.png
diff --git a/setup.py b/setup.py
index 20c9ea61f5f2..b59f65787eab 100644
--- a/setup.py
+++ b/setup.py
@@ -186,6 +186,12 @@ def run(self):
"torchvision",
"transformers"
)
+extras["oneflow"] = deps_list(
+ "torch",
+ "scipy",
+ "transformers"
+)
+
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index acdddaac4d26..d0bd9ede87d0 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -82,3 +82,16 @@
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
+
+from .models.unet_2d_condition_oneflow import OneFlowUNet2DConditionModel
+from .models.vae_oneflow import OneFlowAutoencoderKL
+
+from .schedulers import (
+ OneFlowDDIMScheduler,
+ OneFlowPNDMScheduler,
+ OneFlowSchedulerMixin
+)
+
+from .pipelines import OneFlowStableDiffusionPipeline
+from .pipeline_oneflow_utils import OneFlowDiffusionPipeline
+from .modeling_oneflow_utils import OneFlowModelMixin
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 19f58fd8165d..865a211d80de 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -331,6 +331,9 @@ def __repr__(self):
def config(self) -> Dict[str, Any]:
return self._internal_dict
+ def config_dict(self) -> Dict[str, Any]:
+ return self._internal_dict
+
def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
diff --git a/src/diffusers/modeling_oneflow_utils.py b/src/diffusers/modeling_oneflow_utils.py
new file mode 100644
index 000000000000..81874543ae37
--- /dev/null
+++ b/src/diffusers/modeling_oneflow_utils.py
@@ -0,0 +1,622 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from genericpath import isdir
+import os
+from functools import partial
+from typing import Callable, List, Optional, Tuple, Union
+
+import oneflow as torch
+from oneflow import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
+import numpy as np
+
+
+logger = logging.get_logger(__name__)
+
+
+def index_cast(indices):
+ if isinstance(indices, torch.Tensor):
+ return indices.to(dtype=torch.int32)
+ else:
+ return indices
+
+def extract_scalar(t):
+ if isinstance(t, torch.Tensor) and t.size() == torch.Size([]) and t.device == torch.device("cpu"):
+ return t.item()
+ return t
+
+def from_numpy_if_needed(*args):
+ if len(args) == 1:
+ if isinstance(args[0], np.ndarray):
+ return torch.from_numpy(args[0])
+ else:
+ return args[0]
+ return [torch.from_numpy(a) if isinstance(a, np.ndarray) else a for a in args]
+
+def print_dtype(*args):
+ for a in args:
+ if isinstance(a, torch.Tensor):
+ print(a.dtype)
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ # this is oneflow saved model, a dir
+ if os.path.isdir(checkpoint_file):
+ return torch.load(checkpoint_file, map_location="cpu")
+ else:
+ import torch as og_torch
+ torch_parameters = og_torch.load(checkpoint_file, map_location="cpu")
+ oneflow_parameters = dict()
+ for key,value in torch_parameters.items():
+ if value.is_cuda:
+ raise ValueError(f"torch model is not on cpu, it is on {value.device}")
+ val = value.detach().cpu().numpy()
+ oneflow_parameters[key] = val
+ return oneflow_parameters
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class OneFlowModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~modeling_utils.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _supports_gradient_checkpointing = False
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def is_gradient_checkpointing(self) -> bool:
+ """
+ Whether gradient checkpointing is activated for this model or not.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+ def enable_gradient_checkpointing(self):
+ """
+ Activates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if not self._supports_gradient_checkpointing:
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+ def disable_gradient_checkpointing(self):
+ """
+ Deactivates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if self._supports_gradient_checkpointing:
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = torch.save,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+
+
+ Passing `use_auth_token=True`` is required when you want to use a private model.
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ **kwargs,
+ )
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+ # Load model
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ # model saved by oneflow, a directory
+ elif os.path.isdir(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif subfolder is not None and os.path.isdir(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {WEIGHTS_NAME} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {WEIGHTS_NAME}"
+ )
+
+ # restore default dtype
+ state_dict = load_state_dict(model_file)
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = [k for k in state_dict.keys()]
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 1242ad6fca7f..b8fb56e9186b 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -23,3 +23,6 @@
if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
+
+from .vae_oneflow import OneFlowAutoencoderKL
+from .unet_2d_condition_oneflow import OneFlowUNet2DConditionModel
diff --git a/src/diffusers/models/attention_oneflow.py b/src/diffusers/models/attention_oneflow.py
new file mode 100644
index 000000000000..443133ffc124
--- /dev/null
+++ b/src/diffusers/models/attention_oneflow.py
@@ -0,0 +1,361 @@
+import math
+from typing import Optional
+
+import oneflow as torch
+import oneflow.nn.functional as F
+from oneflow import nn
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (:obj:`int`): The number of channels in the input and output.
+ num_head_channels (:obj:`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ num_groups: int = 32,
+ rescale_output_factor: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels, 1)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ '''
+
+ # transpose
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ # get scores
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+ '''
+ hidden_states = torch._C.fused_multi_head_attention_inference(query_proj, key_proj, value_proj, self.num_heads)
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+ standard transformer action. Finally, reshape to image.
+
+ Parameters:
+ in_channels (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ n_heads: int,
+ d_head: int,
+ depth: int = 1,
+ dropout: float = 0.0,
+ num_groups: int = 32,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.d_head = d_head
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def _set_attention_slice(self, slice_size):
+ for block in self.transformer_blocks:
+ block._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, context=context)
+ hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+ hidden_states = self.proj_out(hidden_states)
+ return hidden_states + residual
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+ gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
+ checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ d_head: int,
+ dropout=0.0,
+ context_dim: Optional[int] = None,
+ gated_ff: bool = True,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def _set_attention_slice(self, slice_size):
+ self.attn1._slice_size = slice_size
+ self.attn2._slice_size = slice_size
+
+ def forward(self, hidden_states, context=None):
+ hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+ hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+ hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+ return hidden_states
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (:obj:`int`): The number of channels in the query.
+ context_dim (:obj:`int`, *optional*):
+ The number of channels in the context. If not given, defaults to `query_dim`.
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+ self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = context_dim if context_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self._slice_size = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, hidden_states, context=None, mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ query = self.to_q(hidden_states)
+ context = context if context is not None else hidden_states
+ key = self.to_k(context)
+ value = self.to_v(context)
+
+ '''
+ dim = query.shape[-1]
+
+ query = self.reshape_heads_to_batch_dim(query)
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+ # attention, what we cannot get enough of
+
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+ '''
+ hidden_states = torch._C.fused_multi_head_attention_inference(query, key, value, self.heads)
+
+ return self.to_out(hidden_states)
+
+ def _attention(self, query, key, value):
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+ attention_probs = attention_scores.softmax(dim=-1)
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value)
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+ attn_slice = attn_slice.softmax(dim=-1)
+ attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+ self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ project_in = GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, hidden_states):
+ return self.net(hidden_states)
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, hidden_states):
+ x_shape = hidden_states.shape
+ if len(x_shape) != 2:
+ hidden_states = hidden_states.reshape(-1, x_shape[-1])
+ out = torch._C.fused_geglu(hidden_states, self.proj.weight, self.proj.bias)
+ if len(x_shape) != 2:
+ out_shape = x_shape[0:len(x_shape) -1 ] + (-1, )
+ out = out.reshape(out_shape)
+ return out
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * F.gelu(gate)
diff --git a/src/diffusers/models/embeddings_oneflow.py b/src/diffusers/models/embeddings_oneflow.py
new file mode 100644
index 000000000000..944aefcee277
--- /dev/null
+++ b/src/diffusers/models/embeddings_oneflow.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import oneflow as torch
+from oneflow import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ x = torch.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
diff --git a/src/diffusers/models/resnet_oneflow.py b/src/diffusers/models/resnet_oneflow.py
new file mode 100644
index 000000000000..a48cfd2b62c1
--- /dev/null
+++ b/src/diffusers/models/resnet_oneflow.py
@@ -0,0 +1,479 @@
+from functools import partial
+
+import oneflow as torch
+import oneflow.nn as nn
+import oneflow.nn.functional as F
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name_ = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name_ == "conv":
+ x = self.conv(x)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+
+ assert x.shape[1] == self.channels
+ x = self.conv(x)
+
+ return x
+
+
+class FirUpsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Args:
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
+ order.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+ `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ p = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+ output_padding = (
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ inC = weight.shape[1]
+ num_groups = x.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+
+ Args:
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
+ order.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
+ filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
+ numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
+ Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+ datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ p = (kernel.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, weight, stride=s, padding=0)
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return x
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ hidden_states = x
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+
+ out = (x + hidden_states) / self.output_scale_factor
+
+ return out
+
+
+class Mish(torch.nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+def upsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Upsample2D a batch of 2D images with the given filter.
+
+ Args:
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
+ multiple of the upsampling factor.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+
+
+def downsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Downsample2D a batch of 2D images with the given filter.
+
+ Args:
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
+ if input.device.type == "mps":
+ out = out.to("cpu")
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(input.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/src/diffusers/models/unet_2d_condition_oneflow.py b/src/diffusers/models/unet_2d_condition_oneflow.py
new file mode 100644
index 000000000000..57e34c7d8bf9
--- /dev/null
+++ b/src/diffusers/models/unet_2d_condition_oneflow.py
@@ -0,0 +1,289 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import oneflow as torch
+import oneflow.nn as nn
+import oneflow.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_oneflow_utils import OneFlowModelMixin as ModelMixin
+from ..utils import BaseOutput
+from .embeddings_oneflow import TimestepEmbedding, Timesteps
+from .unet_blocks_oneflow import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UpBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+@dataclass
+class OneFlowUNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class OneFlowUNet2DConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the models (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: int = 8,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[OneFlowUNet2DConditionOutput, Tuple]:
+ """r
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.OneFlowUNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.OneFlowUNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.OneFlowUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config_dict().center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps.to(dtype=torch.float32)
+ timesteps = timesteps[None].to(device=sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+ # 6. post-process
+ # make sure hidden states is in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return OneFlowUNet2DConditionOutput(sample=sample)
diff --git a/src/diffusers/models/unet_blocks_oneflow.py b/src/diffusers/models/unet_blocks_oneflow.py
new file mode 100644
index 000000000000..af4d29824667
--- /dev/null
+++ b/src/diffusers/models/unet_blocks_oneflow.py
@@ -0,0 +1,1557 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import numpy as np
+
+# limitations under the License.
+import oneflow as torch
+from oneflow import nn
+
+from .attention_oneflow import AttentionBlock, SpatialTransformer
+from .resnet_oneflow import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ SpatialTransformer(
+ in_channels,
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn), hidden_states, encoder_hidden_states
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn), hidden_states, encoder_hidden_states
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(resnet_in_channels + res_skip_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
diff --git a/src/diffusers/models/vae_oneflow.py b/src/diffusers/models/vae_oneflow.py
new file mode 100644
index 000000000000..41c615cc284a
--- /dev/null
+++ b/src/diffusers/models/vae_oneflow.py
@@ -0,0 +1,593 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import oneflow as torch
+import oneflow.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_oneflow_utils import OneFlowModelMixin as ModelMixin
+from ..utils import BaseOutput
+from .unet_blocks_oneflow import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Encoded output sample of the model. Output of the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
+ multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ device = self.parameters.device
+ sample_device = "cpu" if device.type == "mps" else device
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+ Kavukcuoglu.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ norm_num_groups: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=False,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+ self.quantize = VectorQuantizer(
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+ )
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ )
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+
+class OneFlowAutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/pipeline_oneflow_utils.py b/src/diffusers/pipeline_oneflow_utils.py
new file mode 100644
index 000000000000..1d8808ea0adb
--- /dev/null
+++ b/src/diffusers/pipeline_oneflow_utils.py
@@ -0,0 +1,457 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import oneflow as torch
+import torch as og_torch
+
+import diffusers
+import PIL
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "OneFlowModelMixin": ["save_pretrained", "from_pretrained"],
+ "OneFlowSchedulerMixin": ["save_config", "from_config"],
+ "OneFlowDiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ "OneFlowCLIPTextModel": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class OneFlowDiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all models.
+
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
+
+ - move all PyTorch modules to the device of your choice
+ - enabling/disabling the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve library
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+ """
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+ method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class)
+ if issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ if torch_device is None:
+ return self
+
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ module.to(torch_device)
+ if isinstance(module, og_torch.nn.Module):
+ print(f"moving pytorch model to cuda {type(module)}: {module.device} => cuda" )
+ if isinstance(torch_device, torch.device):
+ torch_device = og_torch.device(str(torch_device))
+ else:
+ assert isinstance(torch_device, str)
+ module.to(torch_device)
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ return module.device
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information. specify the folder name here.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
+ specific pipeline class. The overwritten components are then directly passed to the pipelines
+ `__init__` method. See example below for more information.
+
+
+
+ Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
+ `"CompVis/stable-diffusion-v1-4"`
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to this section
+ >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+
+ >>> # Download pipeline, but overwrite scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
+ ... )
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ provider = kwargs.pop("provider", None)
+ sess_options = kwargs.pop("sess_options", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ config_dict = cls.get_config_dict(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ # make sure we only download sub-folders and `diffusers` filenames
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+ # download all allow_patterns
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.get_config_dict(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != OneFlowDiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ if name in ["scheduler", "unet", "vae", "text_encoder", "safety_checker"]:
+ class_name = "OneFlow" + class_name
+ print(f"[oneflow]", f"[{name}]", f"{library_name}.{class_name}")
+ else:
+ print(f"[diffusers]", f"[{name}]", f"{library_name}.{class_name}")
+ # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+ if class_name.startswith("Flax"):
+ class_name = class_name[4:]
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warn(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ try:
+ load_method = getattr(class_obj, load_method_name)
+ except TypeError as e:
+ print(f"fail to load {library_name}.{class_name}, class obj: {class_obj}, maybe it is not allowed?")
+ raise e
+ loading_kwargs = {}
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+ loading_kwargs["sess_options"] = sess_options
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+ return model
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8e3c8592a258..e49346b3fd51 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -20,3 +20,5 @@
if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
+
+from .stable_diffusion import OneFlowStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 1016ce69e450..b6824aa64088 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -58,3 +58,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+from .pipeline_stable_diffusion_oneflow import OneFlowStableDiffusionPipeline
+from .safety_checker_oneflow import OneFlowStableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py
new file mode 100644
index 000000000000..f3776207e9ff
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py
@@ -0,0 +1,345 @@
+import inspect
+import warnings
+from typing import List, Optional, Union
+
+import oneflow as torch
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+from transformers import OneFlowCLIPTextModel as CLIPTextModel
+
+from ...configuration_utils import FrozenDict
+from ...models import OneFlowAutoencoderKL as AutoencoderKL
+from ...models import OneFlowUNet2DConditionModel as UNet2DConditionModel
+from ...pipeline_oneflow_utils import OneFlowDiffusionPipeline as DiffusionPipeline
+from ...schedulers import OneFlowDDIMScheduler as DDIMScheduler
+from ...schedulers import OneFlowPNDMScheduler as PNDMScheduler
+from ...schedulers import LMSDiscreteScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker_oneflow import OneFlowStableDiffusionSafetyChecker as StableDiffusionSafetyChecker
+
+import os
+os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
+os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
+os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"
+os.environ["ONEFLOW_KERNEL_ENABLE_CUDNN_FUSED_CONV_BIAS"] = "1"
+os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
+
+import oneflow as flow
+class UNetGraph(flow.nn.Graph):
+ def __init__(self, unet):
+ super().__init__()
+ self.unet = unet
+ self.config.enable_cudnn_conv_heuristic_search_algo(False)
+
+ def build(self, latent_model_input, t, text_embeddings):
+ text_embeddings = torch._C.amp_white_identity(text_embeddings)
+ return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+class OneFlowStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ warnings.warn(
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file",
+ DeprecationWarning,
+ )
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.unet_graph = UNetGraph(self.unet)
+ self.unet_compiled = False
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ from timeit import default_timer as timer
+ start = timer()
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input.input_ids = torch.from_numpy(text_input.input_ids)
+ torch._oneflow_internal.profiler.RangePush(f"text-encoder")
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ )
+ uncond_input.input_ids = torch.from_numpy(uncond_input.input_ids)
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ torch._oneflow_internal.profiler.RangePop()
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_device = "cpu" if self.device.type == "mps" else self.device
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ if latents is None:
+ latents = torch.randn(
+ latents_shape,
+ generator=generator,
+ device=latents_device,
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ compilation_start = timer()
+ compilation_time = 0
+ if self.unet_compiled == False:
+ print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
+ i, t = list(enumerate(self.scheduler.timesteps))[0]
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ self.unet_graph._compile(latent_model_input, t, text_embeddings)
+ self.unet_compiled = True
+ self.unet_graph(latent_model_input, t, text_embeddings) # warmup
+ compilation_time = timer() - compilation_start
+ print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ torch._oneflow_internal.profiler.RangePush(f"denoise-{i}")
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
+ noise_pred = self.unet_graph(latent_model_input, t, text_embeddings)
+ torch._oneflow_internal.profiler.RangePop()
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ torch._oneflow_internal.profiler.RangePop()
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ import numpy as np
+ if isinstance(latents, np.ndarray):
+ latents = torch.from_numpy(latents)
+ image = self.vae.decode(latents).sample
+ print("[oneflow]", "[elapsed(s)]", "[image]", timer() - start - compilation_time)
+ post_process_start = timer()
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+ safety_checker_input.pixel_values = torch.from_numpy(safety_checker_input.pixel_values).to(self.device)
+ torch._oneflow_internal.profiler.RangePush(f"safety-checker")
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+ torch._oneflow_internal.profiler.RangePop()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+ import torch as og_torch
+ assert og_torch.cuda.is_initialized() is False
+
+ print("[oneflow]", "[elapsed(s)]", "[post-process]", timer() - post_process_start)
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_oneflow.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_oneflow.py
new file mode 100644
index 000000000000..e5cd2bef36d3
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_oneflow.py
@@ -0,0 +1,108 @@
+import numpy as np
+import oneflow as torch
+import oneflow.nn as nn
+
+from transformers import CLIPConfig
+from transformers.modeling_oneflow_utils import OneFlowPreTrainedModel as PreTrainedModel
+from transformers import OneFlowCLIPVisionModel as CLIPVisionModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class OneFlowStableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.register_buffer("concept_embeds_weights", torch.ones(17))
+ self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+ # increase this value to create a stronger `nfsw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ images[idx] = np.zeros(images[idx].shape) # black image
+
+ if any(has_nsfw_concepts):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ @torch.inference_mode()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 495f30d9fabd..711826364a25 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -41,3 +41,8 @@
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
+
+
+from .scheduling_oneflow_utils import OneFlowSchedulerMixin
+from .scheduling_ddim_oneflow import OneFlowDDIMScheduler
+from .scheduling_pndm_oneflow import OneFlowPNDMScheduler
diff --git a/src/diffusers/schedulers/scheduling_ddim_oneflow.py b/src/diffusers/schedulers/scheduling_ddim_oneflow.py
new file mode 100644
index 000000000000..a1f9df82ea2a
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_ddim_oneflow.py
@@ -0,0 +1,312 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import oneflow as torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_oneflow_utils import OneFlowSchedulerMixin as SchedulerMixin
+from ..modeling_oneflow_utils import print_dtype
+
+
+@dataclass
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
+
+class OneFlowDDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ if beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if beta_prod_t_prev.dtype == torch.float64:
+ beta_prod_t_prev = beta_prod_t_prev.to(dtype=torch.float32)
+ if alpha_prod_t_prev.dtype == torch.float64:
+ alpha_prod_t_prev = alpha_prod_t_prev.to(dtype=torch.float32)
+ if isinstance(beta_prod_t_prev, np.float64):
+ beta_prod_t_prev = beta_prod_t_prev.item()
+ if isinstance(alpha_prod_t_prev, np.float32):
+ alpha_prod_t_prev = alpha_prod_t_prev.item()
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps: int, **kwargs):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ offset = self.config.steps_offset
+
+ if "offset" in kwargs:
+ warnings.warn(
+ "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+ " Please pass `steps_offset` to `__init__` instead.",
+ DeprecationWarning,
+ )
+
+ offset = kwargs["offset"]
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+ self.timesteps += offset
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): TODO
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ if not torch.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+ if self.tensor_format == "pt":
+ timesteps = timesteps.to(self.alphas_cumprod.device)
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_oneflow_utils.py b/src/diffusers/schedulers/scheduling_oneflow_utils.py
new file mode 100644
index 000000000000..0effe944f46a
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_oneflow_utils.py
@@ -0,0 +1,125 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import oneflow as torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class OneFlowSchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config = ["tensor_format"]
+
+ def set_format(self, tensor_format="pt"):
+ self.tensor_format = tensor_format
+ if tensor_format == "pt":
+ for key, value in vars(self).items():
+ if isinstance(value, np.ndarray):
+ setattr(self, key, torch.from_numpy(value))
+
+ return self
+
+ def clip(self, tensor, min_value=None, max_value=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.clip(tensor, min_value, max_value)
+ elif tensor_format == "pt":
+ return torch.clamp(tensor, min_value, max_value)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def log(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.log(tensor)
+ elif tensor_format == "pt":
+ return torch.log(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
+ """
+ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+ Args:
+ values: an array or tensor of values to extract.
+ broadcast_array: an array with a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ Returns:
+ a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ values = values.flatten()
+
+ while len(values.shape) < len(broadcast_array.shape):
+ values = values[..., None]
+ if tensor_format == "pt":
+ values = values.to(broadcast_array.device)
+
+ return values
+
+ def norm(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.linalg.norm(tensor)
+ elif tensor_format == "pt":
+ return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def randn_like(self, tensor, generator=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.random.randn(*np.shape(tensor))
+ elif tensor_format == "pt":
+ # return torch.randn_like(tensor)
+ return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def zeros_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.zeros_like(tensor)
+ elif tensor_format == "pt":
+ return torch.zeros_like(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
diff --git a/src/diffusers/schedulers/scheduling_pndm_oneflow.py b/src/diffusers/schedulers/scheduling_pndm_oneflow.py
new file mode 100644
index 000000000000..350d323d8859
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_pndm_oneflow.py
@@ -0,0 +1,411 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+import warnings
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import oneflow as torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_oneflow_utils import OneFlowSchedulerMixin, SchedulerOutput
+from ..modeling_oneflow_utils import extract_scalar
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
+
+class OneFlowPNDMScheduler(OneFlowSchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+ set_alpha_to_one (`bool`, default `False`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ steps_offset: int = 0,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ if beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ offset = self.config.steps_offset
+
+ if "offset" in kwargs:
+ warnings.warn(
+ "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+ " Please pass `steps_offset` to `__init__` instead."
+ )
+
+ offset = kwargs["offset"]
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+ self._timesteps += offset
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+
+ self.ets = []
+ self.counter = 0
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = timestep - diff_to_prev
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def step_plms(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+ times to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ if self.counter != 1:
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation ( ->
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ if self.tensor_format == "pt":
+ if (alpha_prod_t_prev.dtype == torch.float64):
+ alpha_prod_t_prev = alpha_prod_t_prev.to(dtype=torch.float32)
+ elif isinstance(alpha_prod_t_prev, np.float32):
+ alpha_prod_t_prev = alpha_prod_t_prev.item()
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
+ # sqrt(α_(t−δ)) / sqrt(α_t))
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # TODO(oneflow), oneflow's size [] tensor can't be used as a scalar
+ timestep = extract_scalar(timestep)
+ sample_coeff = extract_scalar(sample_coeff)
+ alpha_prod_t_prev = extract_scalar(alpha_prod_t_prev)
+ alpha_prod_t = extract_scalar(alpha_prod_t)
+ model_output_denom_coeff = extract_scalar(model_output_denom_coeff)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> torch.Tensor:
+ if self.tensor_format == "pt":
+ timesteps = timesteps.to(self.alphas_cumprod.device)
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/testing_oneflow_utils.py b/src/diffusers/testing_oneflow_utils.py
new file mode 100644
index 000000000000..f975b2f34547
--- /dev/null
+++ b/src/diffusers/testing_oneflow_utils.py
@@ -0,0 +1,250 @@
+import os
+import random
+import re
+import unittest
+from distutils.util import strtobool
+from pathlib import Path
+from typing import Union
+
+import oneflow as torch
+
+import PIL.Image
+import PIL.ImageOps
+import requests
+from packaging import version
+
+
+global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+
+if is_torch_higher_equal_than_1_12:
+ torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Args:
+ Loads `image` to a PIL Image.
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`: A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+ "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this functions taps into a private _pytest API and while unlikely, it could break should
+ pytest do internal changes - also it calls default internal methods of terminalreporter which
+ can be hijacked by various `pytest-` plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = "reports"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{id}_{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+ with open(report_files["passes"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
diff --git a/src/diffusers/training_oneflow_utils.py b/src/diffusers/training_oneflow_utils.py
new file mode 100644
index 000000000000..2765f2f70247
--- /dev/null
+++ b/src/diffusers/training_oneflow_utils.py
@@ -0,0 +1,125 @@
+import copy
+import os
+import random
+
+import numpy as np
+import oneflow as torch
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # set seed first
+ set_seed(seed)
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+ """
+ Args:
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+ seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999,
+ device=None,
+ ):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = {}
+ ema_params = self.averaged_model.state_dict()
+
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_params[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_params[key] = ema_param
+
+ if not param.requires_grad:
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+ ema_param = ema_params[key]
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+ ema_state_dict[key] = ema_param
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ self.optimization_step += 1
diff --git a/tests/test_modeling_common_oneflow.py b/tests/test_modeling_common_oneflow.py
new file mode 100644
index 000000000000..fa2a5d210c1f
--- /dev/null
+++ b/tests/test_modeling_common_oneflow.py
@@ -0,0 +1,266 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import tempfile
+import unittest
+from typing import Dict, List, Tuple
+
+import numpy as np
+import oneflow as torch
+
+from diffusers.modeling_oneflow_utils import OneFlowModelMixin as ModelMixin
+from diffusers.testing_oneflow_utils import torch_device
+from diffusers.training_oneflow_utils import EMAModel
+
+
+class ModelTesterMixin:
+ def test_from_pretrained_save_pretrained(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ new_model = self.model_class.from_pretrained(tmpdirname)
+ new_model.to(torch_device)
+
+ with torch.no_grad():
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps" and isinstance(model, ModelMixin):
+ _ = model(**self.dummy_input)
+ _ = new_model(**self.dummy_input)
+
+ image = model(**inputs_dict)
+ if isinstance(image, dict):
+ image = image.sample
+
+ new_image = new_model(**inputs_dict)
+
+ if isinstance(new_image, dict):
+ new_image = new_image.sample
+
+ max_diff = (image - new_image).abs().sum().item()
+ self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
+
+ def test_determinism(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps" and isinstance(model, ModelMixin):
+ model(**self.dummy_input)
+
+ first = model(**inputs_dict)
+ if isinstance(first, dict):
+ first = first.sample
+
+ second = model(**inputs_dict)
+ if isinstance(second, dict):
+ second = second.sample
+
+ out_1 = first.cpu().numpy()
+ out_2 = second.cpu().numpy()
+ out_1 = out_1[~np.isnan(out_1)]
+ out_2 = out_2[~np.isnan(out_2)]
+ max_diff = np.amax(np.abs(out_1 - out_2))
+ self.assertLessEqual(max_diff, 1e-5)
+
+ def test_output(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_forward_signature(self):
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["sample", "timestep"]
+ self.assertListEqual(arg_names[:2], expected_arg_names)
+
+ def test_model_from_config(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ # test if the model can be loaded from the config
+ # and has all the expected shape
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_config(tmpdirname)
+ new_model = self.model_class.from_config(tmpdirname)
+ new_model.to(torch_device)
+ new_model.eval()
+
+ # check if all parameters shape are the same
+ for param_name in model.state_dict().keys():
+ param_1 = model.state_dict()[param_name]
+ param_2 = new_model.state_dict()[param_name]
+ self.assertEqual(param_1.shape, param_2.shape)
+
+ with torch.no_grad():
+ output_1 = model(**inputs_dict)
+
+ if isinstance(output_1, dict):
+ output_1 = output_1.sample
+
+ output_2 = new_model(**inputs_dict)
+
+ if isinstance(output_2, dict):
+ output_2 = output_2.sample
+
+ self.assertEqual(output_1.shape, output_2.shape)
+
+ @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+ def test_training(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.train()
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+ loss = torch.nn.functional.mse_loss(output, noise)
+ loss.backward()
+
+ @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+ def test_ema_training(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.train()
+ ema_model = EMAModel(model, device=torch_device)
+
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
+ loss = torch.nn.functional.mse_loss(output, noise)
+ loss.backward()
+ ema_model.step(model)
+
+ def test_outputs_equivalence(self):
+ def set_nan_tensor_to_zero(t):
+ # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
+ # Track progress in https://github.com/pytorch/pytorch/issues/77764
+ device = t.device
+ if device.type == "mps":
+ t = t.to("cpu")
+ t[t != t] = 0
+ return t.to(device)
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif isinstance(tuple_object, Dict):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ np.allclose(
+ set_nan_tensor_to_zero(tuple_object).numpy(), set_nan_tensor_to_zero(dict_object).numpy(), atol=1e-5
+ ),
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
+ )
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps" and isinstance(model, ModelMixin):
+ model(**self.dummy_input)
+
+ outputs_dict = model(**inputs_dict)
+ outputs_tuple = model(**inputs_dict, return_dict=False)
+
+ recursive_check(outputs_tuple, outputs_dict)
+
+ def test_enable_disable_gradient_checkpointing(self):
+ if not self.model_class._supports_gradient_checkpointing:
+ return # Skip test if model does not support gradient checkpointing
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+ # at init model should have gradient checkpointing disabled
+ model = self.model_class(**init_dict)
+ self.assertFalse(model.is_gradient_checkpointing)
+
+ # check enable works
+ model.enable_gradient_checkpointing()
+ self.assertTrue(model.is_gradient_checkpointing)
+
+ # check disable works
+ model.disable_gradient_checkpointing()
+ self.assertFalse(model.is_gradient_checkpointing)
diff --git a/tests/test_models_unet_oneflow.py b/tests/test_models_unet_oneflow.py
new file mode 100644
index 000000000000..a591f47295ec
--- /dev/null
+++ b/tests/test_models_unet_oneflow.py
@@ -0,0 +1,63 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import unittest
+
+import oneflow as torch
+
+from diffusers import OneFlowUNet2DConditionModel
+from diffusers.testing_oneflow_utils import floats_tensor, slow, torch_device
+
+from .test_modeling_common_oneflow import ModelTesterMixin
+
+
+class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+ model_class = OneFlowUNet2DConditionModel
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 4
+ sizes = (32, 32)
+
+ noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+
+ return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": (32, 64),
+ "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+ "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+ "cross_attention_dim": 32,
+ "attention_head_dim": 8,
+ "out_channels": 4,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "sample_size": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
diff --git a/tests/test_models_vae_oneflow.py b/tests/test_models_vae_oneflow.py
new file mode 100644
index 000000000000..3c16fe90cabf
--- /dev/null
+++ b/tests/test_models_vae_oneflow.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import oneflow as torch
+
+from diffusers import OneFlowAutoencoderKL as AutoencoderKL
+from diffusers.modeling_oneflow_utils import OneFlowModelMixin as ModelMixin
+from diffusers.testing_oneflow_utils import floats_tensor, torch_device
+
+from .test_modeling_common_oneflow import ModelTesterMixin
+
+
+
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKL
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "latent_channels": 4,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_forward_signature(self):
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_from_pretrained_hub(self):
+ model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
+ self.assertIsNotNone(model)
+ self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+ model.to(torch_device)
+ image = model(**self.dummy_input)
+
+ assert image is not None, "Make sure output is not None"
+
+ def test_output_pretrained(self):
+ model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+ model = model.to(torch_device)
+ model.eval()
+
+ # One-time warmup pass (see #372)
+ if torch_device == "mps" and isinstance(model, ModelMixin):
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ image = image.to(torch_device)
+ with torch.no_grad():
+ _ = model(image, sample_posterior=True).sample
+ generator = torch.manual_seed(0)
+ else:
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+
+ image = torch.randn(
+ 1,
+ model.config.in_channels,
+ model.config.sample_size,
+ model.config.sample_size,
+ generator=torch.manual_seed(0),
+ )
+ image = image.to(torch_device)
+ with torch.no_grad():
+ output = model(image, sample_posterior=True, generator=generator).sample
+
+ output_slice = output[0, -1, -3:, -3:].flatten().cpu()
+
+ # NOTE: oneflow's random generator is not aligned with pytorch's
+ # TODO(oneflow): check if oneflow has identical result to pytorch
+ expected_output_slice = torch.tensor(
+ [-0.1307, 0.1102, 0.3255, -0.2596, -0.0746, -0.1416, -0.2858, -0.3020, -0.1785]
+ )
+ self.assertTrue(np.allclose(output_slice.numpy(), expected_output_slice.numpy(), rtol=1e-2))
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index dddf42bd03f2..c371c7de269f 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -701,8 +701,6 @@ def test_stable_diffusion_inpaint(self):
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
diff --git a/tests/test_pipelines_oneflow.py b/tests/test_pipelines_oneflow.py
new file mode 100644
index 000000000000..914e36abcd58
--- /dev/null
+++ b/tests/test_pipelines_oneflow.py
@@ -0,0 +1,1423 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import random
+import tempfile
+import unittest
+
+import numpy as np
+import oneflow as torch
+import torch as og_torch
+
+import PIL
+from diffusers import (
+ # AutoencoderKL,
+ DDIMPipeline,
+ # DDIMScheduler,
+ DDPMPipeline,
+ DDPMScheduler,
+ KarrasVePipeline,
+ KarrasVeScheduler,
+ LDMPipeline,
+ LDMTextToImagePipeline,
+ LMSDiscreteScheduler,
+ PNDMPipeline,
+ # PNDMScheduler,
+ ScoreSdeVePipeline,
+ ScoreSdeVeScheduler,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionOnnxPipeline,
+ # StableDiffusionPipeline,
+ # UNet2DConditionModel,
+ UNet2DModel,
+ VQModel,
+)
+from diffusers import OneFlowAutoencoderKL as AutoencoderKL
+from diffusers import OneFlowStableDiffusionPipeline as StableDiffusionPipeline
+from diffusers import OneFlowDDIMScheduler as DDIMScheduler
+from diffusers import OneFlowPNDMScheduler as PNDMScheduler
+from diffusers import OneFlowUNet2DConditionModel as UNet2DConditionModel
+from diffusers.pipeline_oneflow_utils import OneFlowDiffusionPipeline as DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+@unittest.skip("not implemented in oneflow")
+def test_progress_bar(capsys):
+ model = UNet2DModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+ up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+ )
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+
+ ddpm = DDPMPipeline(model, scheduler).to(torch_device)
+ ddpm(output_type="numpy").images
+ captured = capsys.readouterr()
+ assert "10/10" in captured.err, "Progress bar has to be displayed"
+
+ ddpm.set_progress_bar_config(disable=True)
+ ddpm(output_type="numpy").images
+ captured = capsys.readouterr()
+ assert captured.err == "", "Progress bar should be disabled"
+
+
+class PipelineFastTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @property
+ def dummy_image(self):
+ batch_size = 1
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+ return image
+
+ @property
+ def dummy_uncond_unet(self):
+ torch.manual_seed(0)
+ model = UNet2DModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+ up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+ )
+ return model
+
+ @property
+ def dummy_cond_unet(self):
+ torch.manual_seed(0)
+ model = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ return model
+
+ @property
+ def dummy_vq_model(self):
+ torch.manual_seed(0)
+ model = VQModel(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=3,
+ )
+ return model
+
+ @property
+ def dummy_vae(self):
+ torch.manual_seed(0)
+ model = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ return model
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ return CLIPTextModel(config)
+
+ @property
+ def dummy_safety_checker(self):
+ def check(images, *args, **kwargs):
+ return images, [False] * len(images)
+
+ return check
+
+ @property
+ def dummy_extractor(self):
+ def extract(*args, **kwargs):
+ class Out:
+ def __init__(self):
+ self.pixel_values = torch.ones([0])
+
+ def to(self, device):
+ self.pixel_values.to(device)
+ return self
+
+ return Out()
+
+ return extract
+
+ @unittest.skip("not implemented in oneflow")
+ def test_ddim(self):
+ unet = self.dummy_uncond_unet
+ scheduler = DDIMScheduler(tensor_format="pt")
+
+ ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps":
+ _ = ddpm(num_inference_steps=1)
+
+ generator = torch.manual_seed(0)
+ image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array(
+ [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
+ )
+ tolerance = 1e-2 if torch_device != "mps" else 3e-2
+ assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+
+ @unittest.skip("not implemented in oneflow")
+ def test_pndm_cifar10(self):
+ unet = self.dummy_uncond_unet
+ scheduler = PNDMScheduler(tensor_format="pt")
+
+ pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
+ pndm.to(torch_device)
+ pndm.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_ldm_text2img(self):
+ unet = self.dummy_cond_unet
+ scheduler = DDIMScheduler(tensor_format="pt")
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+ ldm.to(torch_device)
+ ldm.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps":
+ generator = torch.manual_seed(0)
+ _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
+ "sample"
+ ]
+
+ generator = torch.manual_seed(0)
+ image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
+ "sample"
+ ]
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = ldm(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="numpy",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_ddim(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+ expected_slice = np.array([0.65244484, 0.50245994, 0.546379, 0.5757261, 0.5937552, 0.5248434, 0.56001717, 0.5617137, 0.4641921])
+ print(image_slice.flatten())
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_pndm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+ expected_slice = np.array([0.4574893, 0.46907586, 0.47894412, 0.5719864, 0.6205503, 0.59339994, 0.5450494, 0.47354442, 0.4824788])
+ print(image_slice.flatten())
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_stable_diffusion_k_lms(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+ expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_stable_diffusion_attention_chunk(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+ # make sure chunking the attention yields the same result
+ sd_pipe.enable_attention_slicing(slice_size=1)
+ generator = torch.Generator(device=device).manual_seed(0)
+ output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+ assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
+
+ @unittest.skip("not implemented in oneflow")
+ def test_score_sde_ve_pipeline(self):
+ unet = self.dummy_uncond_unet
+ scheduler = ScoreSdeVeScheduler(tensor_format="pt")
+
+ sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
+ sde_ve.to(torch_device)
+ sde_ve.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
+ 0
+ ]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_ldm_uncond(self):
+ unet = self.dummy_uncond_unet
+ scheduler = DDIMScheduler(tensor_format="pt")
+ vae = self.dummy_vq_model
+
+ ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
+ ldm.to(torch_device)
+ ldm.set_progress_bar_config(disable=None)
+
+ # Warmup pass when using mps (see #372)
+ if torch_device == "mps":
+ generator = torch.manual_seed(0)
+ _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_karras_ve_pipeline(self):
+ unet = self.dummy_uncond_unet
+ scheduler = KarrasVeScheduler(tensor_format="pt")
+
+ pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_stable_diffusion_img2img(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ init_image = self.dummy_image.to(device)
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ )
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_stable_diffusion_img2img_k_lms(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ init_image = self.dummy_image.to(device)
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ )
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ return_dict=False,
+ )
+ image_from_tuple = output[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ def test_stable_diffusion_inpaint(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+ mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionInpaintPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ )
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+class PipelineTesterMixin(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @unittest.skip("OneFlowDDPMScheduler not implemented in oneflow")
+ def test_smart_download(self):
+ model_id = "hf-internal-testing/unet-pipeline-dummy"
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
+ local_repo_name = "--".join(["models"] + model_id.split("/"))
+ snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
+ snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
+
+ # inspect all downloaded files to make sure that everything is included
+ assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
+ assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
+ assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
+ assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
+ assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
+ assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+ assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+ # let's make sure the super large numpy file:
+ # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
+ # is not downloaded, but all the expected ones
+ assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
+
+ @property
+ def dummy_safety_checker(self):
+ def check(images, *args, **kwargs):
+ return images, [False] * len(images)
+
+ return check
+
+ @unittest.skip("not implemented in oneflow")
+ def test_from_pretrained_save_pretrained(self):
+ # 1. Load models
+ model = UNet2DModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+ up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+ )
+ schedular = DDPMScheduler(num_train_timesteps=10)
+
+ ddpm = DDPMPipeline(model, schedular)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ddpm.save_pretrained(tmpdirname)
+ new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
+ new_ddpm.to(torch_device)
+
+ generator = torch.manual_seed(0)
+ image = ddpm(generator=generator, output_type="numpy").images
+
+ generator = generator.manual_seed(0)
+ new_image = new_ddpm(generator=generator, output_type="numpy").images
+
+ assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_from_pretrained_hub(self):
+ model_path = "google/ddpm-cifar10-32"
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+
+ ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+ ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+ ddpm_from_hub.to(torch_device)
+ ddpm_from_hub.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ddpm(generator=generator, output_type="numpy").images
+
+ generator = generator.manual_seed(0)
+ new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
+
+ assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_from_pretrained_hub_pass_model(self):
+ model_path = "google/ddpm-cifar10-32"
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+
+ # pass unet into DiffusionPipeline
+ unet = UNet2DModel.from_pretrained(model_path)
+ ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
+ ddpm_from_hub_custom_model.to(torch_device)
+ ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
+
+ ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+ ddpm_from_hub.to(torch_device)
+ ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images
+
+ generator = generator.manual_seed(0)
+ new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
+
+ assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_output_format(self):
+ model_path = "google/ddpm-cifar10-32"
+
+ pipe = DDIMPipeline.from_pretrained(model_path)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ images = pipe(generator=generator, output_type="numpy").images
+ assert images.shape == (1, 32, 32, 3)
+ assert isinstance(images, np.ndarray)
+
+ images = pipe(generator=generator, output_type="pil").images
+ assert isinstance(images, list)
+ assert len(images) == 1
+ assert isinstance(images[0], PIL.Image.Image)
+
+ # use PIL by default
+ images = pipe(generator=generator).images
+ assert isinstance(images, list)
+ assert isinstance(images[0], PIL.Image.Image)
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ddpm_cifar10(self):
+ model_id = "google/ddpm-cifar10-32"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ scheduler = DDPMScheduler.from_config(model_id)
+ scheduler = scheduler.set_format("pt")
+
+ ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ddpm(generator=generator, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ddim_lsun(self):
+ model_id = "google/ddpm-ema-bedroom-256"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ scheduler = DDIMScheduler.from_config(model_id)
+
+ ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ddpm(generator=generator, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ddim_cifar10(self):
+ model_id = "google/ddpm-cifar10-32"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ scheduler = DDIMScheduler(tensor_format="pt")
+
+ ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
+ ddim.to(torch_device)
+ ddim.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ddim(generator=generator, eta=0.0, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_pndm_cifar10(self):
+ model_id = "google/ddpm-cifar10-32"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ scheduler = PNDMScheduler(tensor_format="pt")
+
+ pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
+ pndm.to(torch_device)
+ pndm.set_progress_bar_config(disable=None)
+ generator = torch.manual_seed(0)
+ image = pndm(generator=generator, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ldm_text2img(self):
+ ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+ ldm.to(torch_device)
+ ldm.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.manual_seed(0)
+ image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+ "sample"
+ ]
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ldm_text2img_fast(self):
+ ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+ ldm.to(torch_device)
+ ldm.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.manual_seed(0)
+ image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion(self):
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast("cuda"):
+ output = sd_pipe(
+ [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
+ )
+
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_fast_ddim(self):
+ sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ sd_pipe.scheduler = scheduler
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+
+ with torch.autocast("cuda"):
+ output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.9326, 0.923, 0.951, 0.9365, 0.9214, 0.951, 0.9365, 0.9414, 0.918])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_score_sde_ve_pipeline(self):
+ model_id = "google/ncsnpp-church-256"
+ model = UNet2DModel.from_pretrained(model_id)
+
+ scheduler = ScoreSdeVeScheduler.from_config(model_id)
+
+ sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
+ sde_ve.to(torch_device)
+ sde_ve.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 256, 256, 3)
+
+ expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ldm_uncond(self):
+ ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+ ldm.to(torch_device)
+ ldm.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_ddpm_ddim_equality(self):
+ model_id = "google/ddpm-cifar10-32"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ ddpm_scheduler = DDPMScheduler(tensor_format="pt")
+ ddim_scheduler = DDIMScheduler(tensor_format="pt")
+
+ ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+ ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+ ddim.to(torch_device)
+ ddim.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ ddpm_image = ddpm(generator=generator, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
+
+ # the values aren't exactly equal, but the images look the same visually
+ assert np.abs(ddpm_image - ddim_image).max() < 1e-1
+
+ @unittest.skip("not implemented in oneflow")
+ @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
+ def test_ddpm_ddim_equality_batched(self):
+ model_id = "google/ddpm-cifar10-32"
+
+ unet = UNet2DModel.from_pretrained(model_id)
+ ddpm_scheduler = DDPMScheduler(tensor_format="pt")
+ ddim_scheduler = DDIMScheduler(tensor_format="pt")
+
+ ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+ ddpm.to(torch_device)
+ ddpm.set_progress_bar_config(disable=None)
+
+ ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+ ddim.to(torch_device)
+ ddim.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
+ "sample"
+ ]
+
+ # the values aren't exactly equal, but the images look the same visually
+ assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_karras_ve_pipeline(self):
+ model_id = "google/ncsnpp-celebahq-256"
+ model = UNet2DModel.from_pretrained(model_id)
+ scheduler = KarrasVeScheduler(tensor_format="pt")
+
+ pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_lms_stable_diffusion_pipeline(self):
+ model_id = "CompVis/stable-diffusion-v1-1"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
+ pipe.scheduler = scheduler
+
+ prompt = "a photograph of an astronaut riding a horse"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
+ "sample"
+ ]
+
+ image_slice = image[0, -3:, -3:, -1]
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_memory_chunking(self):
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionPipeline.from_pretrained(
+ model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ ).to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "a photograph of an astronaut riding a horse"
+
+ # make attention efficient
+ pipe.enable_attention_slicing()
+ generator = torch.Generator(device=torch_device)
+ generator.manual_seed(0)
+ with og_torch.autocast(torch_device):
+ with torch.autocast(torch_device):
+ output_chunked = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image_chunked = output_chunked.images
+
+ # disable chunking
+ pipe.disable_attention_slicing()
+ generator = torch.Generator(device=torch_device)
+ generator.manual_seed(0)
+ with og_torch.autocast(torch_device):
+ with torch.autocast(torch_device):
+ output = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image = output.images
+
+ # make sure that more than 3.75 GB is allocated
+ mem_bytes = torch.cuda.max_memory_allocated()
+ assert mem_bytes > 3.75 * 10**9
+ assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_text2img_pipeline(self):
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/text2img/astronaut_riding_a_horse.png"
+ )
+ expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionPipeline.from_pretrained(
+ model_id,
+ safety_checker=self.dummy_safety_checker,
+ use_auth_token=True,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "astronaut riding a horse"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+ assert np.abs(expected_image - image).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_img2img_pipeline(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ )
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/fantasy_landscape.png"
+ )
+ init_image = init_image.resize((768, 512))
+ expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ model_id,
+ safety_checker=self.dummy_safety_checker,
+ use_auth_token=True,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A fantasy landscape, trending on artstation"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 768, 3)
+ # img2img is flaky across GPUs even in fp32, so using MAE here
+ assert np.abs(expected_image - image).mean() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_img2img_pipeline_k_lms(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ )
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/fantasy_landscape_k_lms.png"
+ )
+ init_image = init_image.resize((768, 512))
+ expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+ lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ model_id,
+ scheduler=lms,
+ safety_checker=self.dummy_safety_checker,
+ use_auth_token=True,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A fantasy landscape, trending on artstation"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 768, 3)
+ # img2img is flaky across GPUs even in fp32, so using MAE here
+ assert np.abs(expected_image - image).mean() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_inpaint_pipeline(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/red_cat_sitting_on_a_park_bench.png"
+ )
+ expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ model_id,
+ safety_checker=self.dummy_safety_checker,
+ use_auth_token=True,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A red cat sitting on a park bench"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ mask_image=mask_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+ assert np.abs(expected_image - image).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_inpaint_pipeline_k_lms(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png"
+ )
+ expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+ lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ model_id,
+ scheduler=lms,
+ safety_checker=self.dummy_safety_checker,
+ use_auth_token=True,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A red cat sitting on a park bench"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ mask_image=mask_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+ assert np.abs(expected_image - image).max() < 1e-2
+
+ @unittest.skip("not implemented in oneflow")
+ @slow
+ def test_stable_diffusion_onnx(self):
+ sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
+ )
+
+ prompt = "A painting of a squirrel eating a burger"
+ np.random.seed(0)
+ output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=20, output_type="np")
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 7377797bebfa..7185f7d6adda 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -723,6 +723,8 @@ def test_full_loop_no_noise(self):
assert abs(result_sum.item() - 198.1318) < 1e-2
assert abs(result_mean.item() - 0.2580) < 1e-3
+ raise
+
def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
diff --git a/tests/test_scheduler_oneflow.py b/tests/test_scheduler_oneflow.py
new file mode 100755
index 000000000000..9e9c49995afe
--- /dev/null
+++ b/tests/test_scheduler_oneflow.py
@@ -0,0 +1,661 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tempfile
+import unittest
+from typing import Dict, List, Tuple
+
+import numpy as np
+import oneflow as torch
+
+from diffusers import OneFlowDDIMScheduler as DDIMScheduler
+from diffusers import OneFlowPNDMScheduler as PNDMScheduler
+from diffusers.modeling_oneflow_utils import from_numpy_if_needed
+
+class SchedulerCommonTest(unittest.TestCase):
+ scheduler_classes = ()
+ forward_default_kwargs = ()
+
+ @property
+ def dummy_sample(self):
+ batch_size = 4
+ num_channels = 3
+ height = 8
+ width = 8
+
+ sample = torch.rand((batch_size, num_channels, height, width))
+
+ return sample
+
+ @property
+ def dummy_sample_deter(self):
+ batch_size = 4
+ num_channels = 3
+ height = 8
+ width = 8
+
+ num_elems = batch_size * num_channels * height * width
+ sample = torch.arange(num_elems)
+ # TODO(oneflow): in pytorch, no need for this cast
+ sample = sample.to(torch.float32)
+ sample = sample.reshape(num_channels, height, width, batch_size)
+ sample = sample / num_elems
+ sample = sample.permute(3, 0, 1, 2)
+
+ return sample
+
+ def get_scheduler_config(self):
+ raise NotImplementedError
+
+ def dummy_model(self):
+ def model(sample, t, *args):
+ sample = sample.to(dtype=torch.float32)
+ t = t.to(dtype=torch.float32)
+ return sample * t / (t + 1)
+
+ return model
+
+ def check_over_configs(self, time_step=0, **config):
+ kwargs = dict(self.forward_default_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ scheduler_config = self.get_scheduler_config(**config)
+ scheduler = scheduler_class(**scheduler_config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ new_scheduler = scheduler_class.from_config(tmpdirname)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ new_scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+ new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ def check_over_forward(self, time_step=0, **forward_kwargs):
+ kwargs = dict(self.forward_default_kwargs)
+ kwargs.update(forward_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ new_scheduler = scheduler_class.from_config(tmpdirname)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ new_scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ torch.manual_seed(0)
+ output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+ torch.manual_seed(0)
+ new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ def test_from_pretrained_save_pretrained(self):
+ kwargs = dict(self.forward_default_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ new_scheduler = scheduler_class.from_config(tmpdirname)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ new_scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ torch.manual_seed(0)
+ output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
+ torch.manual_seed(0)
+ new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
+
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ def test_step_shape(self):
+ kwargs = dict(self.forward_default_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
+ output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
+
+ self.assertEqual(output_0.shape, sample.shape)
+ self.assertEqual(output_0.shape, output_1.shape)
+
+ def test_pytorch_equal_numpy(self):
+ kwargs = dict(self.forward_default_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ sample_pt = self.dummy_sample
+ residual_pt = 0.1 * sample_pt
+
+ sample = sample_pt.numpy()
+ residual = 0.1 * sample
+
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(tensor_format="np", **scheduler_config)
+
+ scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ scheduler_pt.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
+ output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
+
+ assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+ def test_scheduler_outputs_equivalence(self):
+ def set_nan_tensor_to_zero(t):
+ t[t != t] = 0
+ return t
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif isinstance(tuple_object, Dict):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ np.allclose(
+ set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+ )
+ )
+
+ kwargs = dict(self.forward_default_kwargs)
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
+
+ recursive_check(outputs_tuple, outputs_dict)
+
+
+class DDIMSchedulerTest(SchedulerCommonTest):
+ scheduler_classes = (DDIMScheduler,)
+ forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
+
+ def get_scheduler_config(self, **kwargs):
+ config = {
+ "num_train_timesteps": 1000,
+ "beta_start": 0.0001,
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ "clip_sample": True,
+ }
+
+ config.update(**kwargs)
+ return config
+
+ def full_loop(self, **config):
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config(**config)
+ scheduler = scheduler_class(**scheduler_config)
+
+ num_inference_steps, eta = 10, 0.0
+
+ model = self.dummy_model()
+ sample = self.dummy_sample_deter
+
+ scheduler.set_timesteps(num_inference_steps)
+
+ for t in scheduler.timesteps:
+ residual = model(sample, t)
+ sample = scheduler.step(residual, t, sample, eta).prev_sample
+
+ return sample
+
+ def test_timesteps(self):
+ for timesteps in [100, 500, 1000]:
+ self.check_over_configs(num_train_timesteps=timesteps)
+
+ def test_steps_offset(self):
+ for steps_offset in [0, 1]:
+ self.check_over_configs(steps_offset=steps_offset)
+
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config(steps_offset=1)
+ scheduler = scheduler_class(**scheduler_config)
+ scheduler.set_timesteps(5)
+ # TODO(oneflow) pytorch don't need the `torch.all` here
+ assert torch.all(torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])))
+
+ def test_betas(self):
+ for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+ self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+ def test_schedules(self):
+ for schedule in ["linear", "squaredcos_cap_v2"]:
+ self.check_over_configs(beta_schedule=schedule)
+
+ def test_clip_sample(self):
+ for clip_sample in [True, False]:
+ self.check_over_configs(clip_sample=clip_sample)
+
+ def test_time_indices(self):
+ for t in [1, 10, 49]:
+ self.check_over_forward(time_step=t)
+
+ def test_inference_steps(self):
+ for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
+ self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
+
+ def test_eta(self):
+ for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
+ self.check_over_forward(time_step=t, eta=eta)
+
+ def test_variance(self):
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
+ assert torch.sum(torch.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
+ assert torch.sum(torch.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
+ assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
+ assert torch.sum(torch.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
+ assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
+
+ def test_full_loop_no_noise(self):
+ sample = self.full_loop()
+
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 172.0067) < 1e-2
+ assert abs(result_mean.item() - 0.223967) < 1e-3
+
+ def test_full_loop_with_set_alpha_to_one(self):
+ # We specify different beta, so that the first alpha is 0.99
+ sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 149.8295) < 1e-2
+ assert abs(result_mean.item() - 0.1951) < 1e-3
+
+ def test_full_loop_with_no_set_alpha_to_one(self):
+ # We specify different beta, so that the first alpha is 0.99
+ sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
+ sample = from_numpy_if_needed(sample)
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 149.0784) < 1e-2
+ assert abs(result_mean.item() - 0.1941) < 1e-3
+
+
+class PNDMSchedulerTest(SchedulerCommonTest):
+ scheduler_classes = (PNDMScheduler,)
+ forward_default_kwargs = (("num_inference_steps", 50),)
+
+ def get_scheduler_config(self, **kwargs):
+ config = {
+ "num_train_timesteps": 1000,
+ "beta_start": 0.0001,
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ }
+
+ config.update(**kwargs)
+ return config
+
+ def check_over_configs(self, time_step=0, **config):
+ kwargs = dict(self.forward_default_kwargs)
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+ dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config(**config)
+ scheduler = scheduler_class(**scheduler_config)
+ scheduler.set_timesteps(num_inference_steps)
+ # copy over dummy past residuals
+ scheduler.ets = dummy_past_residuals[:]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler.set_timesteps(num_inference_steps)
+ # copy over dummy past residuals
+ new_scheduler.ets = dummy_past_residuals[:]
+
+ output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+ new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+
+ output, new_output = from_numpy_if_needed(output, new_output)
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ output, new_output = from_numpy_if_needed(output, new_output)
+ output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+ new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ def test_from_pretrained_save_pretrained(self):
+ pass
+
+ def check_over_forward(self, time_step=0, **forward_kwargs):
+ kwargs = dict(self.forward_default_kwargs)
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+ dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+ scheduler.set_timesteps(num_inference_steps)
+
+ # copy over dummy past residuals (must be after setting timesteps)
+ scheduler.ets = dummy_past_residuals[:]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ new_scheduler = scheduler_class.from_config(tmpdirname)
+ # copy over dummy past residuals
+ new_scheduler.set_timesteps(num_inference_steps)
+
+ # copy over dummy past residual (must be after setting timesteps)
+ new_scheduler.ets = dummy_past_residuals[:]
+
+ output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+ new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+
+ output, new_output = from_numpy_if_needed(output, new_output)
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+ new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+
+ output, new_output = from_numpy_if_needed(output, new_output)
+ assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+ def full_loop(self, **config):
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config(**config)
+ scheduler = scheduler_class(**scheduler_config)
+
+ num_inference_steps = 10
+ model = self.dummy_model()
+ sample = self.dummy_sample_deter
+ scheduler.set_timesteps(num_inference_steps)
+
+ for i, t in enumerate(scheduler.prk_timesteps):
+ residual = model(sample, t)
+ sample = scheduler.step_prk(residual, t, sample).prev_sample
+
+ for i, t in enumerate(scheduler.plms_timesteps):
+ residual = model(sample, t)
+ sample = scheduler.step_plms(residual, t, sample).prev_sample
+
+ return sample
+
+ def test_pytorch_equal_numpy(self):
+ kwargs = dict(self.forward_default_kwargs)
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ sample_pt = self.dummy_sample
+ residual_pt = 0.1 * sample_pt
+ dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
+
+ sample = sample_pt.numpy()
+ residual = 0.1 * sample
+ dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(tensor_format="np", **scheduler_config)
+
+ scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ scheduler_pt.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ # copy over dummy past residuals (must be done after set_timesteps)
+ scheduler.ets = dummy_past_residuals[:]
+ scheduler_pt.ets = dummy_past_residuals_pt[:]
+
+ output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
+ output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
+ # TODO(oneflow): investigate why the difference is so large
+ assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+ output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
+ output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
+
+ assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+ def test_set_format(self):
+ kwargs = dict(self.forward_default_kwargs)
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(tensor_format="np", **scheduler_config)
+ scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ scheduler_pt.set_timesteps(num_inference_steps)
+
+ for key, value in vars(scheduler).items():
+ # we only allow `ets` attr to be a list
+ assert not isinstance(value, list) or key in [
+ "ets"
+ ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
+
+ # check if `scheduler.set_format` does convert correctly attrs to pt format
+ for key, value in vars(scheduler_pt).items():
+ # we only allow `ets` attr to be a list
+ assert not isinstance(value, list) or key in [
+ "ets"
+ ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
+ assert not isinstance(
+ value, np.ndarray
+ ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
+
+ def test_step_shape(self):
+ kwargs = dict(self.forward_default_kwargs)
+
+ num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+ scheduler.set_timesteps(num_inference_steps)
+ elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+ kwargs["num_inference_steps"] = num_inference_steps
+
+ # copy over dummy past residuals (must be done after set_timesteps)
+ dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+ scheduler.ets = dummy_past_residuals[:]
+
+ output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
+ output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
+
+ self.assertEqual(output_0.shape, sample.shape)
+ self.assertEqual(output_0.shape, output_1.shape)
+
+ output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
+ output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
+
+ self.assertEqual(output_0.shape, sample.shape)
+ self.assertEqual(output_0.shape, output_1.shape)
+
+ def test_timesteps(self):
+ for timesteps in [100, 1000]:
+ self.check_over_configs(num_train_timesteps=timesteps)
+
+ def test_steps_offset(self):
+ for steps_offset in [0, 1]:
+ self.check_over_configs(steps_offset=steps_offset)
+
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config(steps_offset=1)
+ scheduler = scheduler_class(**scheduler_config)
+ scheduler.set_timesteps(10)
+ assert torch.all(torch.equal(
+ scheduler.timesteps,
+ torch.tensor(
+ [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+ ),
+ ))
+
+ def test_betas(self):
+ for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
+ self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+ def test_schedules(self):
+ for schedule in ["linear", "squaredcos_cap_v2"]:
+ self.check_over_configs(beta_schedule=schedule)
+
+ def test_time_indices(self):
+ for t in [1, 5, 10]:
+ self.check_over_forward(time_step=t)
+
+ def test_inference_steps(self):
+ for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
+ self.check_over_forward(num_inference_steps=num_inference_steps)
+
+ def test_pow_of_3_inference_steps(self):
+ # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
+ num_inference_steps = 27
+
+ for scheduler_class in self.scheduler_classes:
+ sample = self.dummy_sample
+ residual = 0.1 * sample
+
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ scheduler.set_timesteps(num_inference_steps)
+
+ # before power of 3 fix, would error on first step, so we only need to do two
+ for i, t in enumerate(scheduler.prk_timesteps[:2]):
+ sample = scheduler.step_prk(residual, t, sample).prev_sample
+
+ def test_inference_plms_no_past_residuals(self):
+ with self.assertRaises(ValueError):
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
+
+ def test_full_loop_no_noise(self):
+ sample = self.full_loop()
+ sample = from_numpy_if_needed(sample)
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 198.1318) < 1e-2
+ assert abs(result_mean.item() - 0.2580) < 1e-3
+
+ def test_full_loop_with_set_alpha_to_one(self):
+ # We specify different beta, so that the first alpha is 0.99
+ sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 230.0399) < 1e-2
+ assert abs(result_mean.item() - 0.2995) < 1e-3
+
+ def test_full_loop_with_no_set_alpha_to_one(self):
+ # We specify different beta, so that the first alpha is 0.99
+ sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
+ sample = from_numpy_if_needed(sample)
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 186.9482) < 1e-2
+ assert abs(result_mean.item() - 0.2434) < 1e-3