We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dd42b18 commit 20a78b3Copy full SHA for 20a78b3
1 file changed
examples/pytorch_image_search.py
@@ -25,13 +25,15 @@
25
26
27
# load pretrained model
28
+device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
29
model = torchvision.models.resnet18(weights='DEFAULT')
30
model.fc = torch.nn.Identity()
31
+model.to(device)
32
model.eval()
33
34
35
def generate_embeddings(inputs):
- return model(inputs).detach().numpy()
36
+ return model(inputs.to(device)).detach().cpu().numpy()
37
38
39
# generate, save, and index embeddings
@@ -53,7 +55,8 @@ def show_images(dataset_images):
53
55
grid = torchvision.utils.make_grid(dataset_images)
54
56
img = (grid / 2 + 0.5).permute(1, 2, 0).numpy()
57
plt.imshow(img)
- plt.waitforbuttonpress()
58
+ plt.draw()
59
+ plt.waitforbuttonpress(timeout=3)
60
61
62
# load 5 random unseen images
0 commit comments