summaryrefslogtreecommitdiff
path: root/src/nnetwork.js
blob: dd214918443cc17360edc94ef059a301ef60db12 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// Load-time trace: prints a greeting when this script is evaluated.
{
	const loadMessage = 'Hello from nnetwork.js';
	console.log(loadMessage);
}

/**
 * Feed-forward neural network built as a chain of `Layer` objects.
 *
 * Layer k (0-based) maps nodeCounts[k] nodes to nodeCounts[k + 1] nodes and is
 * constructed with activationFunctionNames[k + 1]; activationFunctionNames[0]
 * is therefore never used (it lines up with the input layer).
 *
 * Vector arithmetic goes through a global `math` object (presumably math.js —
 * TODO confirm it is loaded before this script), and `Layer` must be defined
 * elsewhere.
 */
class NNetwork {
	/**
	 * @param {number[]} nodeCounts - Node count of every layer, input layer first.
	 * @param {string[]} activationFunctionNames - Activation name per entry of
	 *   nodeCounts; index 0 is ignored.
	 * @param {number} learningRate - Default step size handed to each layer's
	 *   gradientDescent.
	 */
	constructor(nodeCounts, activationFunctionNames, learningRate) {
		this.learningRate = learningRate;
		this.layers = [];
		for (let i = 1; i < nodeCounts.length; i++) {
			this.layers.push(new Layer(nodeCounts[i - 1], nodeCounts[i], activationFunctionNames[i]));
		}
	}

	/**
	 * Feeds an input through every layer in order.
	 * @param {*} activationInput - Input activations (representation is whatever
	 *   Layer.forwardPropogation accepts — TODO confirm).
	 * @returns {*} The last layer's output, or the input unchanged when the
	 *   network has no layers.
	 */
	forwardPropogation(activationInput) {
		let activation = activationInput;
		for (const layer of this.layers) {
			activation = layer.forwardPropogation(activation);
		}
		return activation;
	}

	/**
	 * Runs a forward pass, then propagates the cost gradient from the last
	 * layer back to the first.
	 * @param {*} activationInput - Network input.
	 * @param {*} targetOutput - Expected network output.
	 * @returns {number} Sum of squared errors between the output and the target.
	 */
	backPropogation(activationInput, targetOutput) {
		const output = this.forwardPropogation(activationInput);
		const error = math.subtract(output, targetOutput);
		// d/d(output) of sum((output - target)^2) is 2 * (output - target).
		let dcDa = math.multiply(error, 2);
		const cost = math.sum(math.map(error, (element) => element ** 2));
		// Walk last-to-first by index rather than reversing this.layers in place,
		// so an exception thrown by a layer cannot leave the array reversed.
		for (let i = this.layers.length - 1; i >= 0; i--) {
			dcDa = this.layers[i].backPropogation(dcDa);
		}
		return cost;
	}

	/**
	 * Asks every layer, last to first, to apply a gradient-descent step.
	 * @param {number} [learningRate=this.learningRate] - Step size for this
	 *   update; defaults to the rate given to the constructor.
	 */
	gradientDescent(learningRate = this.learningRate) {
		for (let i = this.layers.length - 1; i >= 0; i--) {
			this.layers[i].gradientDescent(learningRate);
		}
	}
}