summaryrefslogtreecommitdiff
path: root/train_loop.r
diff options
context:
space:
mode:
Diffstat (limited to 'train_loop.r')
-rw-r--r--train_loop.r38
1 files changed, 38 insertions, 0 deletions
diff --git a/train_loop.r b/train_loop.r
new file mode 100644
index 0000000..48fa0c1
--- /dev/null
+++ b/train_loop.r
@@ -0,0 +1,38 @@
+source(paste(getwd(), "NNetwork_refCls.r", sep = "/"))
+source(paste(getwd(), "test_gen.r", sep = "/"))
+
+ni <- 5
+no <- packBits(as.integer(c(rep(1, ni), rep(0, 32 - ni))), "integer") + 1
+
+nnetwork_1 <- NNetwork(
+ node_counts = c(ni, 32, 32, no),
+ l_rate = 0.01,
+ act_func_names = c("no_func", "relu", "relu", "no_func")
+)
+
+cost_chk <- rep(0, no)
+res_chk <- rep(0, no)
+
+for (i in 0:10000) {
+
+ ti <- test_input(ni)
+ tar_out <- test_output(ti)
+
+ ao <- nnetwork_1$fw_prop(ti)
+ cost <- nnetwork_1$back_prop(ti, tar_out)
+
+ tar_int <- read_input(ti)
+ act_int <- clean_output(ao)
+
+ if (tar_int == 3) {print(c(cost, ao[4], act_int))}
+
+ cost_chk[tar_int + 1] <- cost
+ res_chk[tar_int + 1] <- act_int
+
+ if (i %% 1 == 0) {
+ nnetwork_1$grad_desc()
+ }
+}
+
+cost_chk
+res_chk