Skip to content

Commit 5da71f8

Browse files
fix generator 2
1 parent 8edf598 commit 5da71f8

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

tests/test_modeling_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from distutils.util import strtobool
2323

2424
import torch
25+
import numpy as np
2526

2627
from diffusers import GaussianDDPMScheduler, UNetModel
2728
from diffusers.pipeline_utils import DiffusionPipeline
@@ -35,7 +36,7 @@
3536
def get_random_generator(seed):
3637
seed = 1234
3738
random.seed(seed)
38-
os.environ[PYTHONHASHSEED] = str(seed)
39+
os.environ["PYTHONHASHSEED"] = str(seed)
3940
np.random.seed(seed)
4041
torch.manual_seed(seed)
4142
torch.cuda.manual_seed(seed)
@@ -176,6 +177,7 @@ def test_sample(self):
176177

177178
assert image.shape == (1, 3, 256, 256)
178179
image_slice = image[0, -1, -3:, -3:].cpu()
180+
import ipdb; ipdb.set_trace()
179181
assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3
180182

181183
def test_sample_fast(self):
@@ -216,6 +218,7 @@ def test_sample_fast(self):
216218

217219
assert image.shape == (1, 3, 256, 256)
218220
image_slice = image[0, -1, -3:, -3:].cpu()
221+
import ipdb; ipdb.set_trace()
219222
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
220223

221224

0 commit comments

Comments
 (0)