package model;

import java.util.ArrayList;
import java.util.List;

import matrix.Matrix;
import autodiff.Graph;

/**
 * A {@link Model} built from an ordered sequence of sub-models ("layers").
 *
 * <p>Every operation is delegated to the layers in list order: {@link #forward} pipes the
 * output of each layer into the next, {@link #resetState} resets every layer, and
 * {@link #getParameters} concatenates the parameters of every layer.
 */
public class NeuralNetwork implements Model {
  private static final long serialVersionUID = 1L;

  /**
   * Layers applied in order by {@link #forward}. This is the exact list instance passed to the
   * constructor (stored by reference, not copied), so later changes to that list are visible
   * here.
   */
  List<Model> layers;

  /**
   * Creates a network that applies {@code layers} in list order.
   *
   * @param layers the layers to chain; stored by reference, not copied
   */
  public NeuralNetwork(List<Model> layers) {
    this.layers = layers;
  }

  /**
   * Runs {@code input} through every layer in order, feeding each layer's output into the next.
   *
   * @param input the input handed to the first layer
   * @param g the graph passed unchanged to every layer's {@code forward}
   * @return the output of the last layer, or {@code input} itself when there are no layers
   * @throws Exception propagated unchanged from any layer's {@code forward}
   */
  @Override
  public Matrix forward(Matrix input, Graph g) throws Exception {
    Matrix prev = input;
    for (Model layer : layers) {
      prev = layer.forward(prev, g);
    }
    return prev;
  }

  /** Calls {@code resetState()} on every layer, in list order. */
  @Override
  public void resetState() {
    for (Model layer : layers) {
      layer.resetState();
    }
  }

  /**
   * Collects the parameters of all layers into a new list.
   *
   * @return a freshly allocated list holding each layer's parameters, concatenated in layer order
   */
  @Override
  public List<Matrix> getParameters() {
    List<Matrix> result = new ArrayList<>();
    for (Model layer : layers) {
      result.addAll(layer.getParameters());
    }
    return result;
  }
}