summaryrefslogtreecommitdiff
path: root/test/unit_tests/activationFunctions.test.js
diff options
context:
space:
mode:
authorChristian Hodgden <chrhodgden@gmail.com>2024-02-15 13:09:21 +0000
committerChristian Hodgden <chrhodgden@gmail.com>2024-02-15 13:09:21 +0000
commit2d83cbf2238065ed24f3e6ca8fec65f32cfc16f4 (patch)
tree49f01c12caee00ea37cdb830d7fc16e059dd5a18 /test/unit_tests/activationFunctions.test.js
parentff0b69860bd619680436d736e4d693efcaf5ef97 (diff)
adding activation function tests
added sigmoid and relu. Relu currently fails because of -0. need to update src code.
Diffstat (limited to 'test/unit_tests/activationFunctions.test.js')
-rw-r--r--test/unit_tests/activationFunctions.test.js35
1 files changed, 25 insertions, 10 deletions
diff --git a/test/unit_tests/activationFunctions.test.js b/test/unit_tests/activationFunctions.test.js
index e00140f..231544b 100644
--- a/test/unit_tests/activationFunctions.test.js
+++ b/test/unit_tests/activationFunctions.test.js
@@ -4,23 +4,38 @@ const math = require('mathjs');
describe('Test Activation Function module', () => {
let result;
+ let targetResult;
let testVector;
let testMatrix;
beforeEach(() => {
testVector = [-2, -1, 0, 1, 2];
testVector = math.matrix(testVector);
- testMatrix = math.matrix([testVector, testVector]);
+ testMatrix = math.matrix([testVector, testVector]);
+ });
+
+ test('Sigmoid Function', () => {
+ result = activationFunctionList['sigmoid'].gx(testVector);
+ expect(result._data[2]).toEqual(0.5);
+
+ result = activationFunctionList['sigmoid'].dg_dx(testVector);
+ expect(result._data[2]).toEqual(0.25);
});
-
- test.todo('Sigmoid Function');
- // result = activationFunctionList['sigmoid'].gx(testVector);
- // result = activationFunctionList['sigmoid'].dg_dx(testVector);
-
- test.todo('RELU Function');
- // result = activationFunctionList['relu'].gx(testMatrix);
- // result = activationFunctionList['relu'].dg_dx(testMatrix);
-
+
+ test('RELU Function', () => {
+ targetResult = [0, 0, 0, 1, 2];
+ targetResult = math.matrix(targetResult);
+ result = activationFunctionList['relu'].gx(testVector);
+ console.log(result);
+ console.log(targetResult);
+ expect(result).toEqual(targetResult);
+
+ targetResult = [0, 0, 0, 1, 1];
+ targetResult = math.matrix(targetResult);
+ result = activationFunctionList['relu'].dg_dx(testVector);
+ expect(result).toEqual(targetResult);
+ });
+
test.todo('Identity Function');
// result = activationFunctionList['identity'].gx(testMatrix);
// result = activationFunctionList['identity'].dg_dx(testMatrix);