forked from jindongwang/transferlearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
executable file
·44 lines (42 loc) · 1.66 KB
/
dataset.py
File metadata and controls
executable file
·44 lines (42 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import os
def loader(path, batch_size=16, num_workers=1, pin_memory=True):
    """Build a shuffled ``DataLoader`` over an image folder with augmentation.

    Images are read through ``datasets.ImageFolder`` rooted at ``path`` and
    pass through: resize to 256, random resized crop to 224, random
    horizontal flip, tensor conversion, then per-channel normalization.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the ``DataLoader``.
        pin_memory: Forwarded unchanged to the ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` that reshuffles every epoch.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics
    # expected by torchvision's pretrained models -- confirm against the model.
    channel_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Stochastic pipeline: the random crop and flip differ on every access.
    augmentation = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])

    image_folder = datasets.ImageFolder(path, transform=augmentation)

    return data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
def test_loader(path, batch_size=16, num_workers=1, pin_memory=True):
    """Build an unshuffled ``DataLoader`` over an image folder, no augmentation.

    Images are read through ``datasets.ImageFolder`` rooted at ``path`` and
    pass through a fixed pipeline: resize to 256, center crop to 224, tensor
    conversion, then per-channel normalization.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the ``DataLoader``.
        pin_memory: Forwarded unchanged to the ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates in dataset order.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics
    # expected by torchvision's pretrained models -- confirm against the model.
    channel_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # No random transforms here, so a given image always yields the same tensor.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        channel_norm,
    ])

    image_folder = datasets.ImageFolder(path, transform=preprocessing)

    return data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )