forked from tensorflow/tensor2tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalgorithmic.py
More file actions
178 lines (141 loc) · 6.74 KB
/
algorithmic.py
File metadata and controls
178 lines (141 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algorithmic data generators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
def identity_generator(nbr_symbols, max_length, nbr_cases):
  """Generator for the identity (copy) task on sequences of symbols.

  Each case has a length drawn uniformly at random from [1, max_length];
  every symbol is then drawn uniformly at random from [1, nbr_symbols], so
  the value 0 never appears. Exactly nbr_cases cases are yielded.

  Args:
    nbr_symbols: number of symbols to use in each sequence.
    max_length: integer, maximum length of sequences to generate.
    nbr_cases: the number of cases to generate.

  Yields:
    A dictionary {"inputs": input-list, "targets": target-list} where
    input-list and target-list are the very same list object.
  """
  for _ in xrange(nbr_cases):
    length = 1 + np.random.randint(max_length)
    sequence = []
    for _ in xrange(length):
      sequence.append(1 + np.random.randint(nbr_symbols))
    # Copy task: the target is the input itself (one shared list).
    yield {"inputs": sequence, "targets": sequence}
def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
  """Generator for the shift task on sequences of symbols.

  Each case has a length drawn uniformly at random from [1, max_length].
  Input symbols are drawn uniformly at random from [1, nbr_symbols - shift]
  and every target symbol equals the matching input symbol plus `shift`.
  For 0 <= shift < nbr_symbols, all targets therefore lie in
  [1 + shift, nbr_symbols]. Exactly nbr_cases cases are yielded.

  Args:
    nbr_symbols: number of symbols to use in each sequence (input + output).
    shift: by how much to shift the input.
    max_length: integer, maximum length of sequences to generate.
    nbr_cases: the number of cases to generate.

  Yields:
    A dictionary {"inputs": input-list, "targets": target-list} where
    target-list[i] = input-list[i] + shift.
  """
  for _ in xrange(nbr_cases):
    length = 1 + np.random.randint(max_length)
    source, shifted = [], []
    for _ in xrange(length):
      symbol = 1 + np.random.randint(nbr_symbols - shift)
      source.append(symbol)
      shifted.append(symbol + shift)
    yield {"inputs": source, "targets": shifted}
def reverse_generator(nbr_symbols, max_length, nbr_cases):
  """Generator for the reversing task on sequences of symbols.

  Each case has a length drawn uniformly at random from [1, max_length];
  every symbol is then drawn uniformly at random from [1, nbr_symbols].
  Exactly nbr_cases cases are yielded.

  Args:
    nbr_symbols: number of symbols to use in each sequence.
    max_length: integer, maximum length of sequences to generate.
    nbr_cases: the number of cases to generate.

  Yields:
    A dictionary {"inputs": input-list, "targets": target-list} where
    target-list is a new list holding input-list in reverse order.
  """
  for _ in xrange(nbr_cases):
    length = 1 + np.random.randint(max_length)
    forward = [1 + np.random.randint(nbr_symbols) for _ in xrange(length)]
    yield {"inputs": forward, "targets": forward[::-1]}
def lower_endian_to_number(l, base):
  """Helper function: convert a list of digits in the given base to a number.

  `l` holds digits least-significant first, so the digit at position i
  contributes digit * base**i. An empty sequence converts to 0. Any iterable
  of digits is accepted.
  """
  total = 0
  for position, digit in enumerate(l):
    total += digit * base**position
  return total
def number_to_lower_endian(n, base):
  """Helper function: convert a number to a list of digits in the given base.

  Digits are returned least-significant first ("lower endian"). Any n < base,
  including 0, is returned as the single-element list [n]. For non-negative n
  and base >= 2 every digit lies in [0, base).

  Args:
    n: integer to convert.
    base: integer base; must be at least 2 whenever n >= base.

  Returns:
    The list of digits of n, least-significant first.

  Raises:
    ValueError: if n >= base and base < 2, where repeated division by base
      would never terminate (or would divide by zero).
  """
  if n >= base and base < 2:
    raise ValueError("Base must be at least 2.")
  # Iterative rather than recursive: one stack frame per digit would hit
  # Python's recursion limit for numbers with thousands of digits, which
  # products of long numbers in small bases (e.g. base 2) easily reach.
  digits = []
  while n >= base:
    n, digit = divmod(n, base)
    digits.append(digit)
  digits.append(n)
  return digits
def random_number_lower_endian(length, base):
  """Helper function: generate a random number as a lower-endian digits list.

  Digits are listed least-significant first and drawn uniformly from
  [0, base). When length is 1 the single digit may be 0. Otherwise the
  length - 1 lower digits come first and the final (most-significant) digit
  is drawn from [1, base), so the number has no leading zero.
  """
  if length != 1:
    lower_digits = [np.random.randint(base) for _ in xrange(length - 1)]
    top_digit = 1 + np.random.randint(base - 1)  # Never 0.
    return lower_digits + [top_digit]
  return [np.random.randint(base)]
def addition_generator(base, max_length, nbr_cases):
  """Generator for the addition task.

  The first number's length is drawn uniformly at random from
  [1, max_length // 2] and the second number's length from
  [1, max_length - first_length - 1]. The encoded input (first number,
  separator, second number) is therefore never longer than max_length.
  Digits are drawn uniformly at random, except that the most significant
  digit of a multi-digit number is non-zero. The two numbers are added.

  Numbers are written least-significant digit first. Every digit is shifted
  by +1 in both inputs and targets so that 0 stays free for padding, and the
  two input numbers are separated by the symbol base + 1. Stops after
  nbr_cases cases.

  Args:
    base: in which base are the numbers (at least 2).
    max_length: integer, maximum length of sequences to generate.
    nbr_cases: the number of cases to generate.

  Yields:
    A dictionary {"inputs": input-list, "targets": target-list} where
    input-list are the 2 numbers and target-list is the result of adding them.

  Raises:
    ValueError: if max_length is lower than 3 or base is lower than 2.
      As with any generator, this is raised on the first next() call.
  """
  if max_length < 3:
    raise ValueError("Maximum length must be at least 3.")
  # With base < 2 there is no non-zero leading digit for multi-digit numbers:
  # np.random.randint(base - 1) would get a non-positive bound and fail
  # part-way through the stream. Fail up front instead.
  if base < 2:
    raise ValueError("Base must be at least 2.")
  for _ in xrange(nbr_cases):
    first_length = np.random.randint(max_length // 2) + 1
    # Leave room for the first number and the separator symbol.
    second_length = np.random.randint(max_length - first_length - 1) + 1
    n1 = random_number_lower_endian(first_length, base)
    n2 = random_number_lower_endian(second_length, base)
    result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base)
    # We shift digits by 1 on input and output to leave 0 for padding.
    inputs = [i + 1 for i in n1] + [base + 1] + [i + 1 for i in n2]
    targets = [i + 1 for i in number_to_lower_endian(result, base)]
    yield {"inputs": inputs, "targets": targets}
def multiplication_generator(base, max_length, nbr_cases):
  """Generator for the multiplication task.

  The first number's length is drawn uniformly at random from
  [1, max_length // 2] and the second number's length from
  [1, max_length - first_length - 1]. The encoded input (first number,
  separator, second number) is therefore never longer than max_length.
  Digits are drawn uniformly at random, except that the most significant
  digit of a multi-digit number is non-zero. The two numbers are multiplied.

  Numbers are written least-significant digit first. Every digit is shifted
  by +1 in both inputs and targets so that 0 stays free for padding, and the
  two input numbers are separated by the symbol base + 1. Stops after
  nbr_cases cases.

  Args:
    base: in which base are the numbers (at least 2).
    max_length: integer, maximum length of sequences to generate.
    nbr_cases: the number of cases to generate.

  Yields:
    A dictionary {"inputs": input-list, "targets": target-list} where
    input-list are the 2 numbers and target-list is the result of multiplying
    them.

  Raises:
    ValueError: if max_length is lower than 3 or base is lower than 2.
      As with any generator, this is raised on the first next() call.
  """
  if max_length < 3:
    raise ValueError("Maximum length must be at least 3.")
  # With base < 2 there is no non-zero leading digit for multi-digit numbers:
  # np.random.randint(base - 1) would get a non-positive bound and fail
  # part-way through the stream. Fail up front instead.
  if base < 2:
    raise ValueError("Base must be at least 2.")
  for _ in xrange(nbr_cases):
    first_length = np.random.randint(max_length // 2) + 1
    # Leave room for the first number and the separator symbol.
    second_length = np.random.randint(max_length - first_length - 1) + 1
    n1 = random_number_lower_endian(first_length, base)
    n2 = random_number_lower_endian(second_length, base)
    result = lower_endian_to_number(n1, base) * lower_endian_to_number(n2, base)
    # We shift digits by 1 on input and output to leave 0 for padding.
    inputs = [i + 1 for i in n1] + [base + 1] + [i + 1 for i in n2]
    targets = [i + 1 for i in number_to_lower_endian(result, base)]
    yield {"inputs": inputs, "targets": targets}