forked from tonybeltramelli/pix2code
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·66 lines (48 loc) · 2.26 KB
/
Copy pathtrain.py
File metadata and controls
executable file
·66 lines (48 loc) · 2.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python
# Training entry point for pix2code: loads the dataset, saves its metadata and
# vocabulary, then fits the model (see run() below).
from __future__ import print_function
from __future__ import absolute_import
__author__ = 'Tony Beltramelli - www.tonybeltramelli.com'
import tensorflow as tf
# Created at import time and not referenced anywhere else in this file; its only
# visible effect is the ConfigProto(log_device_placement=True) setting.
# NOTE(review): presumably added to log which device each op is placed on
# (TF1 Session API) — confirm it is still needed.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
import sys
# NOTE(review): star imports; these presumably provide np, Dataset, Vocabulary,
# Generator, BATCH_SIZE and pix2code used below — verify against the modules.
from classes.dataset.Generator import *
from classes.model.pix2code import *
def run(input_path, output_path, is_memory_intensive=False, pretrained_model=None):
    """Train a pix2code model on the samples found under *input_path*.

    The dataset is loaded, its metadata and vocabulary are saved to
    *output_path*, and a ``pix2code`` model (constructed with *output_path*)
    is then fit either on in-memory arrays or on batches produced by
    ``Generator.data_generator``.

    Args:
        input_path: dataset location passed to ``Dataset.load`` and, in
            memory-intensive mode, to ``Dataset.load_paths_only``.
        output_path: directory that receives the dataset metadata and the
            vocabulary; also passed to the ``pix2code`` constructor.
        is_memory_intensive: if true, train with ``model.fit_generator`` on
            batches built from the GUI/image file paths instead of calling
            ``dataset.convert_arrays`` and ``model.fit`` on the arrays.
        pretrained_model: optional weights file loaded into ``model.model``
            before training starts; ``None`` trains from scratch.
    """
    # Fixed seed so numpy-driven randomness is the same on every run.
    np.random.seed(1234)

    dataset = Dataset()
    # NOTE(review): the full dataset is loaded even in memory-intensive mode;
    # it is needed here for size/input_shape/output_size — confirm intended.
    dataset.load(input_path, generate_binary_sequences=True)
    dataset.save_metadata(output_path)
    dataset.voc.save(output_path)

    generator = None
    steps_per_epoch = None
    if is_memory_intensive:
        gui_paths, img_paths = Dataset.load_paths_only(input_path)
        # Whole batches per epoch. Plain "/" is true division on Python 3 and
        # would hand Keras a float; clamp to 1 so a dataset smaller than one
        # batch still trains for a step instead of zero.
        steps_per_epoch = max(1, dataset.size // BATCH_SIZE)

        # Reload the vocabulary that was just saved to output_path.
        voc = Vocabulary()
        voc.retrieve(output_path)
        generator = Generator.data_generator(voc, gui_paths, img_paths, batch_size=BATCH_SIZE, generate_binary_sequences=True)
    else:
        dataset.convert_arrays()
        print(len(dataset.input_images), len(dataset.partial_sequences), len(dataset.next_words))
        print(dataset.input_images.shape, dataset.partial_sequences.shape, dataset.next_words.shape)

    input_shape = dataset.input_shape
    output_size = dataset.output_size

    model = pix2code(input_shape, output_size, output_path)

    if pretrained_model is not None:
        model.model.load_weights(pretrained_model)

    if is_memory_intensive:
        model.fit_generator(generator, steps_per_epoch=steps_per_epoch)
    else:
        model.fit(dataset.input_images, dataset.partial_sequences, dataset.next_words)
if __name__ == "__main__":
    # Usage: train.py <input path> <output path> [<memory intensive: 0|1>] [<pretrained weights>]
    argv = sys.argv[1:]

    if len(argv) < 2:
        print("Error: not enough argument supplied:", file=sys.stderr)
        print("train.py <input path> <output path> <is memory intensive (default: 0)> <pretrained weights (optional)>", file=sys.stderr)
        # Report failure with a non-zero status; sys.exit does not rely on the
        # interactive `exit` helper that the site module may not install.
        sys.exit(1)

    input_path = argv[0]
    output_path = argv[1]
    # Only the integer 1 enables generator-based (memory-intensive) training;
    # a non-integer third argument raises ValueError from int().
    use_generator = len(argv) >= 3 and int(argv[2]) == 1
    pretrained_weights = argv[3] if len(argv) >= 4 else None

    run(input_path, output_path, is_memory_intensive=use_generator, pretrained_model=pretrained_weights)