diff --git a/train.py b/train.py index 1bc0f6d..0efa836 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ import mlx.core as mx import numpy as np import math +import os # Functions shared by both the training and the testing scripts. from common import * @@ -186,6 +187,10 @@ for epoch in range(EPOCHS): LEARNING_RATE *= 0.99 print("\nTraining complete.") +if not os.path.exists("weights"): + os.makedirs("weights") +if not os.path.exists("biases"): + os.makedirs("biases") np.savetxt("weights/weights1.txt", np.array(W1)) np.savetxt("biases/biases1.txt", np.array(b1)) for i in range(HIDDEN_LAYERS):