forked from evolvingstuff/RecurrentJava
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLossSoftmax.java
More file actions
executable file
·103 lines (93 loc) · 3.13 KB
/
Copy pathLossSoftmax.java
File metadata and controls
executable file
·103 lines (93 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package loss;
import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataStep;
import matrix.Matrix;
import model.Model;
import util.Util;
import java.util.ArrayList;
import java.util.List;
/**
 * Softmax cross-entropy loss over a vector of unnormalized scores.
 *
 * <p>The target is given as a vector in which the entry for the correct class
 * is exactly {@code 1.0}; the first such entry is taken as the target index.
 */
public class LossSoftmax
        implements Loss
{
    /**
     * Writes the loss gradient with respect to the scores into {@code logprobs.dw}.
     *
     * <p>Each entry of {@code logprobs.dw} is overwritten (not accumulated) with the
     * softmax probability of that class, and 1 is then subtracted at the target index.
     *
     * @param logprobs     unnormalized scores; its {@code dw} array is overwritten
     * @param targetOutput target vector containing a {@code 1.0} at the correct class
     * @throws Exception if {@code targetOutput} contains no entry equal to {@code 1.0}
     */
    @Override
    public void backward(Matrix logprobs, Matrix targetOutput)
            throws Exception
    {
        int targetIndex = getTargetIndex(targetOutput);
        Matrix probs = getSoftmaxProbs(logprobs, 1.0);
        for (int i = 0; i < probs.w.length; i++) {
            logprobs.dw[i] = probs.w[i];
        }
        // d(loss)/d(score) = softmax(score) - oneHot(target)
        logprobs.dw[targetIndex] -= 1;
    }

    /**
     * Computes the negative natural log of the softmax probability assigned to the target class.
     *
     * @param logprobs     unnormalized scores
     * @param targetOutput target vector containing a {@code 1.0} at the correct class
     * @return {@code -ln(p[target])}
     * @throws Exception if {@code targetOutput} contains no entry equal to {@code 1.0}
     */
    @Override
    public double measure(Matrix logprobs, Matrix targetOutput)
            throws Exception
    {
        int targetIndex = getTargetIndex(targetOutput);
        Matrix probs = getSoftmaxProbs(logprobs, 1.0);
        return -Math.log(probs.w[targetIndex]);
    }

    /**
     * Computes a per-sequence perplexity for every sequence and returns the median.
     *
     * <p>For each sequence the model state is reset, every step is run forward, and the
     * negative base-2 log probability of the correct class is summed over all steps.
     *
     * @param model     model to evaluate; its state is reset before each sequence
     * @param sequences sequences to score
     * @return the median of the per-sequence perplexities, as computed by {@code Util.median}
     * @throws Exception if a step's target has no entry equal to {@code 1.0}, or if the
     *                   model's forward pass throws
     */
    public static double calculateMedianPerplexity(Model model, List<DataSequence> sequences)
            throws Exception
    {
        double temperature = 1.0;
        List<Double> ppls = new ArrayList<>();
        for (DataSequence seq : sequences) {
            double n = 0;
            double neglog2ppl = 0;
            // NOTE(review): 'false' presumably disables backprop bookkeeping in Graph -- confirm.
            Graph g = new Graph(false);
            model.resetState();
            for (DataStep step : seq.steps) {
                Matrix logprobs = model.forward(step.input, g);
                Matrix probs = getSoftmaxProbs(logprobs, temperature);
                int targetIndex = getTargetIndex(step.targetOutput);
                double probOfCorrect = probs.w[targetIndex];
                double log2prob = Math.log(probOfCorrect) / Math.log(2); //change-of-base
                neglog2ppl += -log2prob;
                n += 1;
            }
            n -= 1; //don't count first symbol of sentence
            // NOTE(review): n was already decremented above, so the divisor here is
            // (number of steps - 2) while the sum covers every step. This looks like a
            // double decrement; it is kept as-is so reported values do not change.
            // Confirm the intended normalisation before relying on absolute numbers.
            double ppl = Math.pow(2, (neglog2ppl / (n - 1)));
            ppls.add(ppl);
        }
        return Util.median(ppls);
    }

    /**
     * Computes softmax probabilities of {@code logprobs.w} after dividing by {@code temperature}.
     *
     * <p>The input matrix is not modified: scaled scores are staged in the freshly
     * allocated result. The maximum scaled score is subtracted before exponentiation so
     * that every argument to {@code Math.exp} is non-positive.
     *
     * @param logprobs    unnormalized scores; read-only
     * @param temperature divisor applied to every score; {@code 1.0} leaves scores unchanged
     * @return a new matrix whose {@code w} entries are the normalized probabilities
     * @throws Exception declared for interface compatibility with existing callers
     */
    public static Matrix getSoftmaxProbs(Matrix logprobs, double temperature)
            throws Exception
    {
        int size = logprobs.w.length;
        Matrix probs = new Matrix(size);

        // Stage the temperature-scaled scores in the result instead of writing them
        // back into logprobs.w, which would corrupt the caller's matrix and compound
        // the scaling on repeated calls.
        double maxval = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < size; i++) {
            double scaled = logprobs.w[i];
            if (temperature != 1.0) {
                scaled /= temperature;
            }
            probs.w[i] = scaled;
            if (scaled > maxval) {
                maxval = scaled;
            }
        }

        double sum = 0;
        for (int i = 0; i < size; i++) {
            probs.w[i] = Math.exp(probs.w[i] - maxval); //all inputs to exp() are non-positive
            sum += probs.w[i];
        }
        for (int i = 0; i < size; i++) {
            probs.w[i] /= sum;
        }
        return probs;
    }

    /**
     * Returns the index of the first entry of {@code targetOutput.w} that equals {@code 1.0}.
     *
     * @param targetOutput target vector to scan
     * @return index of the first entry exactly equal to {@code 1.0}
     * @throws Exception if no entry equals {@code 1.0}
     */
    private static int getTargetIndex(Matrix targetOutput)
            throws Exception
    {
        for (int i = 0; i < targetOutput.w.length; i++) {
            if (targetOutput.w[i] == 1.0) {
                return i;
            }
        }
        throw new Exception("no target index selected");
    }
}