FlexNN 1
Fully connected neural network built from scratch with flexible n-layer design and multiple activations.
Class representing a neural network.
#include <FlexNN.h>
Public Member Functions
| Return type | Member | Description |
| | NeuralNetwork (const std::vector< Layer > &layers) | Constructor for the NeuralNetwork class. |
| void | train (const Eigen::MatrixXd &input, const Eigen::MatrixXd &target, double learningRate, int epochs) | Train the neural network. |
| double | accuracy (const Eigen::MatrixXd &X, const Eigen::MatrixXd &Y) | Calculate the accuracy of the neural network. |
| Eigen::MatrixXd | predict (const Eigen::MatrixXd &input) | Predict the output for given input data. |
This class encapsulates training, prediction, and accuracy calculation for a neural network. The structure of the network is given by a vector of Layer objects.
FlexNN::NeuralNetwork::NeuralNetwork (const std::vector< Layer > &layers) [inline]
Constructor for the NeuralNetwork class.
| layers | A vector of Layer objects representing the layers of the neural network. |
double FlexNN::NeuralNetwork::accuracy (const Eigen::MatrixXd &X, const Eigen::MatrixXd &Y)
Calculate the accuracy of the neural network.
This method computes the accuracy of the neural network's predictions against the target data.
| X | The input data for prediction. |
| Y | The target output data for comparison. |
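This page does not state how a prediction is judged correct. One common convention for classifiers, assumed here, compares the index of the largest score in each prediction with the index of the largest entry in the matching one-hot target. The sketch below illustrates that idea with plain std::vector rows (one sample per row) rather than Eigen matrices. The names argmaxIndex and argmaxAccuracy are hypothetical, and this is not FlexNN's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper, not part of FlexNN: index of the largest entry in a row.
std::size_t argmaxIndex(const std::vector<double>& row) {
    assert(!row.empty());
    return static_cast<std::size_t>(
        std::distance(row.begin(), std::max_element(row.begin(), row.end())));
}

// Fraction of samples whose predicted class (argmax of the scores) matches
// the target class (argmax of the one-hot target). Assumes one sample per row.
double argmaxAccuracy(const std::vector<std::vector<double>>& predicted,
                      const std::vector<std::vector<double>>& target) {
    assert(predicted.size() == target.size() && !predicted.empty());
    std::size_t correct = 0;
    for (std::size_t i = 0; i < predicted.size(); ++i) {
        if (argmaxIndex(predicted[i]) == argmaxIndex(target[i])) {
            ++correct;
        }
    }
    return static_cast<double>(correct) / static_cast<double>(predicted.size());
}
```

The result lies in the range 0 to 1. Whether FlexNN reports a fraction or a percentage is not specified on this page.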
Eigen::MatrixXd FlexNN::NeuralNetwork::predict (const Eigen::MatrixXd &input) [inline]
Predict the output for given input data.
This method performs a forward pass through the neural network to predict the output for the provided input data.
| input | The input data for prediction. |
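As a rough illustration of what a forward pass through a fully connected network involves, the sketch below pushes one sample through a chain of dense layers, each computing an activation of W * x + b. It assumes a sigmoid activation in every layer and uses std::vector instead of Eigen. DenseSketch, logisticSigmoid, and forwardPass are hypothetical names; FlexNN's Layer class and its predict method may be organised differently.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical stand-in for a layer; FlexNN's own Layer class may differ.
struct DenseSketch {
    std::vector<Vec> weights;  // weights[j][k]: weight from input k to output j
    Vec bias;                  // one bias per output
};

double logisticSigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// Forward pass for one sample: each layer computes sigmoid(W * x + b)
// and hands its output to the next layer.
Vec forwardPass(const std::vector<DenseSketch>& layers, Vec activations) {
    for (const DenseSketch& layer : layers) {
        assert(layer.weights.size() == layer.bias.size());
        Vec next(layer.bias.size());
        for (std::size_t j = 0; j < next.size(); ++j) {
            assert(layer.weights[j].size() == activations.size());
            double z = layer.bias[j];
            for (std::size_t k = 0; k < activations.size(); ++k) {
                z += layer.weights[j][k] * activations[k];
            }
            next[j] = logisticSigmoid(z);
        }
        activations = next;
    }
    return activations;
}
```

The output vector has one entry per unit in the last layer, so its length is set by the final layer rather than by the input.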
void FlexNN::NeuralNetwork::train (const Eigen::MatrixXd &input, const Eigen::MatrixXd &target, double learningRate, int epochs)
Train the neural network.
This method trains the neural network using the provided input and target data. It performs forward and backward passes, updating weights based on the gradients.
| input | The input data for training. |
| target | The target output data for training. |
| learningRate | The learning rate for weight updates. |
| epochs | The number of training epochs. |
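To show how learningRate and epochs typically interact in a loop of forward pass, backward pass, and weight update, here is a deliberately small sketch: full-batch gradient descent on a single linear neuron with a mean squared error loss. The model, the loss, and the names LinearNeuron and trainSketch are assumptions made for illustration. FlexNN's train works on whole layers and Eigen matrices, and its loss function is not documented on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical one-neuron model y = w * x + b; not FlexNN's data structures.
struct LinearNeuron {
    double w = 0.0;
    double b = 0.0;
};

// Full-batch gradient descent on the loss 0.5 * mean((y - t)^2).
void trainSketch(LinearNeuron& model, const std::vector<double>& input,
                 const std::vector<double>& target, double learningRate,
                 int epochs) {
    assert(input.size() == target.size() && !input.empty());
    const double n = static_cast<double>(input.size());
    for (int epoch = 0; epoch < epochs; ++epoch) {
        double gradW = 0.0;
        double gradB = 0.0;
        for (std::size_t i = 0; i < input.size(); ++i) {
            const double y = model.w * input[i] + model.b;  // forward pass
            const double error = y - target[i];             // dLoss/dy, before averaging
            gradW += error * input[i] / n;                  // backward pass: dLoss/dw
            gradB += error / n;                             // backward pass: dLoss/db
        }
        model.w -= learningRate * gradW;  // step against the gradient
        model.b -= learningRate * gradB;
    }
}
```

A larger learningRate takes bigger steps per epoch and can diverge, while a smaller one needs more epochs to reach the same fit.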