Solution
This is a network with one hidden layer: the input \( x \) has shape \( (n, m) \), the hidden layer has size \( h \), and there is a single output node.
Algorithm Steps:
Code
import numpy as np
class NeuralNetwork:
    """Fully connected regression network with one ReLU hidden layer.

    The output layer is linear and the loss is the mean squared error.
    Weights are created by ``initialize_weights`` (which ``train_network``
    calls before its first epoch), so ``forward_pass`` and ``predict`` can
    only be used after one of those two methods has run.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Store the layer sizes; no weights are allocated here."""
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def initialize_weights(self):
        """Draw all weights and biases from a standard normal distribution."""
        self.w1 = np.random.randn(self.input_dim, self.hidden_dim)
        self.b1 = np.random.randn(1, self.hidden_dim)
        self.w2 = np.random.randn(self.hidden_dim, self.output_dim)
        self.b2 = np.random.randn(1, self.output_dim)

    @staticmethod
    def _align_targets(y_true, y_pred):
        """Return ``y_true`` as an array shaped compatibly with ``y_pred``.

        A 1-D target vector of length m combined with an (m, 1) prediction
        would broadcast to an (m, m) matrix and silently produce a wrong loss
        and wrong gradients, so that case is reshaped to a column. Every other
        input is returned unchanged (as an ndarray).
        """
        y_true = np.asarray(y_true)
        pred_shape = np.shape(y_pred)
        if (
            y_true.ndim == 1
            and len(pred_shape) == 2
            and pred_shape[1] == 1
            and y_true.shape[0] == pred_shape[0]
        ):
            return y_true.reshape(-1, 1)
        return y_true

    def predict(self, x):
        """Return the network output (the linear output activation) for ``x``."""
        _, _, _, a2 = self.forward_pass(x)
        return a2

    def train_network(self, x_train, y_train, epochs, learning_rate, x_val=None, y_val=None):
        """Train with full-batch gradient descent for ``epochs`` iterations.

        Every call re-initializes the weights and resets the loss histories.
        The per-epoch training loss is appended to ``self.train_cost``; when
        both ``x_val`` and ``y_val`` are given, the per-epoch validation loss
        is appended to ``self.val_cost``. The training loss is measured before
        the epoch's parameter update, the validation loss after it. One
        progress line is printed per epoch.
        """
        self.train_cost = []
        # Created once, before the loop, so the history covers every epoch
        # instead of being wiped on each iteration.
        self.val_cost = []
        has_validation = x_val is not None and y_val is not None
        self.initialize_weights()
        for i in range(epochs):
            z1, a1, z2, a2 = self.forward_pass(x_train)
            loss = self.compute_loss(y_train, a2)
            self.train_cost.append(loss)
            dw1, db1, dw2, db2 = self.back_ward_pass(z1, a1, a2, x_train, y_train)
            self.update_parameter(dw1, db1, dw2, db2, learning_rate)
            val_loss = None
            if has_validation:
                val_loss = self.compute_loss(y_val, self.predict(x_val))
                self.val_cost.append(val_loss)
            print("epoch {} , train_loss = {} val_loss = {}".format(i, loss, val_loss))

    def relu(self, x):
        """Element-wise ``max(0, x)``."""
        return np.maximum(0, x)

    def relu_derivate(self, x):
        """Element-wise ReLU derivative: 1 where ``x > 0``, otherwise 0."""
        return np.where(x > 0, 1, 0)

    def forward_pass(self, x):
        """Run the network on ``x`` and return ``(z1, a1, z2, a2)``.

        ``z1``/``a1`` are the hidden pre-activation and ReLU activation;
        ``z2`` is the output pre-activation and ``a2`` equals ``z2`` because
        the output layer is linear.
        """
        z1 = np.dot(x, self.w1) + self.b1
        a1 = self.relu(z1)
        z2 = np.dot(a1, self.w2) + self.b2
        a2 = z2
        return z1, a1, z2, a2

    def compute_loss(self, y_true, y_pred):
        """Return the mean squared error between ``y_true`` and ``y_pred``."""
        y_pred = np.asarray(y_pred)
        y_true = self._align_targets(y_true, y_pred)
        return np.mean((y_true - y_pred) ** 2)

    def back_ward_pass(self, z1, a1, a2, x, y_true):
        """Back-propagate the error and return ``(dw1, db1, dw2, db2)``.

        Gradients are averaged over the ``x.shape[0]`` samples. The constant
        factor 2 from differentiating the squared error is left out; it only
        rescales the effective learning rate.
        """
        m = x.shape[0]
        y_true = self._align_targets(y_true, a2)
        dz2 = a2 - y_true
        dw2 = np.dot(a1.T, dz2) / m
        db2 = np.sum(dz2, axis=0, keepdims=True) / m
        # Push the output error back through w2, gated by the ReLU derivative.
        dz1 = np.dot(dz2, self.w2.T) * self.relu_derivate(z1)
        dw1 = np.dot(x.T, dz1) / m
        db1 = np.sum(dz1, axis=0, keepdims=True) / m
        return dw1, db1, dw2, db2

    def update_parameter(self, dw1, db1, dw2, db2, learning_rate):
        """Apply one in-place gradient-descent step to all weights and biases."""
        self.w1 -= learning_rate * dw1
        self.b1 -= learning_rate * db1
        self.w2 -= learning_rate * dw2
        self.b2 -= learning_rate * db2