forked from evolvingstuff/RecurrentJava
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEmbeddedReberGrammar.java
More file actions
executable file
·117 lines (101 loc) · 3.94 KB
/
Copy pathEmbeddedReberGrammar.java
File metadata and controls
executable file
·117 lines (101 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package datasets;
import java.util.*;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import loss.LossMultiDimensionalBinary;
import loss.LossSumOfSquares;
import model.Model;
import model.Nonlinearity;
import model.SigmoidUnit;
/**
 * Synthetic sequence-prediction dataset sampled from the embedded Reber grammar.
 *
 * <p>Each input step is a one-hot encoding of the symbol just emitted. The matching target is a
 * multi-hot vector that marks every symbol the grammar allows next. The inner Reber string is
 * wrapped between a {@code T ... T} or {@code P ... P} pair. Predicting the symbol just before the
 * final {@code E} therefore requires remembering the second symbol of the sequence.
 *
 * <p>Symbol indices: B=0, T=1, P=2, S=3, X=4, V=5, E=6.
 */
public class EmbeddedReberGrammar extends DataSet {

  /** A node of the grammar automaton: the transitions that may be taken from it. */
  public static class State {
    public State(Transition[] transitions) {
      this.transitions = transitions;
    }

    /** Outgoing edges; during generation one is chosen uniformly at random. */
    public Transition[] transitions;
  }

  /** An edge of the grammar automaton: emit {@code token}, then move to {@code next_state_id}. */
  public static class Transition {
    public Transition(int next_state_id, int token) {
      this.next_state_id = next_state_id;
      this.token = token;
    }

    public int next_state_id;
    public int token;
  }

  /** Number of distinct symbols; also the width of every input and target vector. */
  private static final int NUM_SYMBOLS = 7;

  private static final int B = 0;
  private static final int T = 1;
  private static final int P = 2;
  private static final int S = 3;
  private static final int X = 4;
  private static final int V = 5;
  private static final int E = 6;

  /** Sequences generated per split (training/validation/testing) by the one-arg constructor. */
  private static final int DEFAULT_SEQUENCES_PER_SPLIT = 1000;

  /**
   * Automaton for the embedded grammar, built once and shared by all generation calls.
   *
   * <p>State 0 is both the start state and the terminal state: returning to it ends a sequence.
   * States 2-7 and 9-10 produce the inner Reber string and the closing T of the T branch. States
   * 11-18 do the same for the P branch. Both branches rejoin at state 8, which emits the final E
   * and returns to state 0.
   */
  private static final State[] GRAMMAR = buildGrammar();

  /**
   * Creates the dataset with {@value #DEFAULT_SEQUENCES_PER_SPLIT} sequences in each of the
   * training, validation and testing splits.
   *
   * @param r source of randomness for sequence generation
   */
  public EmbeddedReberGrammar(Random r) throws Exception {
    this(r, DEFAULT_SEQUENCES_PER_SPLIT);
  }

  /**
   * Creates the dataset with a caller-chosen number of sequences per split.
   *
   * @param r source of randomness for sequence generation
   * @param sequencesPerSplit sequences generated for each of training, validation and testing
   */
  public EmbeddedReberGrammar(Random r, int sequencesPerSplit) throws Exception {
    inputDimension = NUM_SYMBOLS;
    outputDimension = NUM_SYMBOLS;
    lossTraining = new LossSumOfSquares();
    lossReporting = new LossMultiDimensionalBinary();
    training = generateSequences(r, sequencesPerSplit);
    validation = generateSequences(r, sequencesPerSplit);
    testing = generateSequences(r, sequencesPerSplit);
  }

  /**
   * Samples {@code sequences} random strings from the embedded Reber grammar.
   *
   * <p>Each emitted symbol becomes one {@link DataStep}, except the final E. The step's input is
   * the one-hot symbol just emitted, and its target holds 1.0 for every symbol the grammar permits
   * next. The final E returns the automaton to state 0, so it has no successor and produces no
   * step.
   *
   * @param r source of randomness for the branch choices
   * @param sequences number of sequences to generate; zero or negative yields an empty list
   * @return a new mutable list containing the generated sequences
   */
  public static List<DataSequence> generateSequences(Random r, int sequences) {
    List<DataSequence> result = new ArrayList<>(Math.max(sequences, 0));
    for (int i = 0; i < sequences; i++) {
      result.add(generateSequence(r));
    }
    return result;
  }

  /** Walks the automaton once from state 0 until it returns there, recording each step. */
  private static DataSequence generateSequence(Random r) {
    List<DataStep> steps = new ArrayList<>();
    int stateId = 0;
    while (true) {
      Transition taken = chooseTransition(GRAMMAR[stateId], r);
      stateId = taken.next_state_id;
      if (stateId == 0) { // back at the start: the closing E was just emitted
        break;
      }
      double[] observation = new double[NUM_SYMBOLS];
      observation[taken.token] = 1.0;
      // Multi-hot target: every symbol that can legally follow from the new state.
      double[] targetOutput = new double[NUM_SYMBOLS];
      for (Transition next : GRAMMAR[stateId].transitions) {
        targetOutput[next.token] = 1.0;
      }
      steps.add(new DataStep(observation, targetOutput));
    }
    return new DataSequence(steps);
  }

  /**
   * Picks one outgoing transition uniformly at random. A state with a single transition draws
   * nothing from {@code r}, so only genuine branch points consume randomness.
   */
  private static Transition chooseTransition(State from, Random r) {
    Transition[] options = from.transitions;
    int index = options.length == 1 ? 0 : r.nextInt(options.length);
    return options[index];
  }

  /** Builds the 19-state automaton described on {@link #GRAMMAR}. */
  private static State[] buildGrammar() {
    State[] states = new State[19];
    states[0] = state(edge(1, B));
    states[1] = state(edge(2, T), edge(11, P));
    // T branch: inner Reber string, then the matching T.
    states[2] = state(edge(3, B));
    states[3] = state(edge(4, T), edge(9, P));
    states[4] = state(edge(4, S), edge(5, X));
    states[5] = state(edge(6, S), edge(9, X));
    states[6] = state(edge(7, E));
    states[7] = state(edge(8, T));
    // Shared final E that returns to the start state.
    states[8] = state(edge(0, E));
    states[9] = state(edge(9, T), edge(10, V));
    states[10] = state(edge(5, P), edge(6, V));
    // P branch: inner Reber string, then the matching P.
    states[11] = state(edge(12, B));
    states[12] = state(edge(13, T), edge(17, P));
    states[13] = state(edge(13, S), edge(14, X));
    states[14] = state(edge(15, S), edge(17, X));
    states[15] = state(edge(16, E));
    states[16] = state(edge(8, P));
    states[17] = state(edge(17, T), edge(18, V));
    states[18] = state(edge(14, P), edge(15, V));
    return states;
  }

  /** Shorthand for a state with the given outgoing transitions. */
  private static State state(Transition... transitions) {
    return new State(transitions);
  }

  /** Shorthand for a transition that emits {@code token} and moves to {@code next}. */
  private static Transition edge(int next, int token) {
    return new Transition(next, token);
  }

  /** No dataset-specific report is produced; this override currently does nothing. */
  @Override
  public void DisplayReport(Model model, Random rng) throws Exception {
    // TODO: no report implemented for this dataset yet.
  }

  /**
   * Targets are multi-hot (several next symbols can be legal at once), so a per-output sigmoid
   * unit is used rather than a jointly normalized output.
   */
  @Override
  public Nonlinearity getModelOutputUnitToUse() {
    return new SigmoidUnit();
  }
}