Back Propagation of a Deep Neural Network
Problem of a Multi-Layer Neural Network
- A multi-layer neural network is a powerful tool for machine learning, but training it, which means updating the weights and biases of every layer, is difficult.
- To update the weights and biases, it is necessary to know how much each weight and bias affects the cost, and this is done with derivatives, which the gradient descent algorithm then uses.
- However, computing derivatives naively is very expensive. If a multi-layer neural network has too many weights and biases, the problem could become impossible to solve in a reasonable time (see the sketch after this list).
- In the previous post, the MNIST data set was trained with a single-layer neural network, and that example took almost 12 hours.
- A multi-layer neural network has even more weights and biases because of its additional layers, so it would take much longer.
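- The sketch below is not from the original post; it only illustrates the cost problem. The forward function is a hypothetical stand-in for a full network pass: with numerical differentiation, every single weight needs its own extra forward pass, so the cost grows with the number of parameters, while back propagation reuses one forward and one backward pass for all of them.

import numpy as np

def forward(weights, x):
    # Hypothetical stand-in for a full network forward pass
    return float(np.dot(weights, x))

def numerical_gradient(weights, x, eps=1e-6):
    # One perturbed forward pass per weight: N weights cost N extra passes
    grad = np.zeros_like(weights)
    base = forward(weights, x)
    for i in range(weights.size):
        w = weights.copy()
        w[i] += eps
        grad[i] = (forward(w, x) - base) / eps
    return grad

weights = np.random.randn(1000)
x = np.random.randn(1000)
grad = numerical_gradient(weights, x)  # 1000 extra forward passes for a single update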
Back Propagation
- A method to calculate the gradient of the loss function with respect to the weights in an artificial neural network. - Wiki
- Back propagation is a fast way to compute these partial derivatives by using the chain rule.
- To understand back propagation, only basic knowledge of partial derivatives and the chain rule (stated below) is necessary, and it won't be explained in detail here.
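- The chain rule is the only piece of mathematics required: if COST depends on a variable X only through an intermediate value Y, then

$$ \frac{\partial{COST}}{\partial{X}} = \frac{\partial{COST}}{\partial{Y}} \cdot \frac{\partial{Y}}{\partial{X}} $$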
Back Propagation with Graph
- Back propagation is based on the chain rule of calculus, so it is usually explained with complicated mathematical equations.
- However, it can also be explained with a computational graph and simple partial derivatives, and for a neural network, back propagation can then be applied to the network graph directly.
Simple Example of Back Propagation
- Let's use a simple equation for the logit:
$$ Y = X1 \cdot X2 + C $$
- Back propagation calculates how sensitive the final node Y is to every node: Y, X1, X2, and C.
- As the name suggests, back propagation starts from the last node and moves backward through the graph.
- The first target is Y itself, and the derivative of Y with respect to itself is \( \frac{\partial{Y}}{\partial{Y}} = 1 \).
- The second node is ADD. Its inputs are S1 (the output of the MUL node, \( X1 \cdot X2 \)) and S2 (the constant C), and its output is Y.
$$ \frac {\partial{Y}} {\partial{S1}} = \frac {\partial{S1 + S2}} {\partial{S1}} = 1 $$
$$ \frac {\partial{Y}} {\partial{S2}} = \frac {\partial{S1 + S2}} {\partial{S2}} = 1 $$
- The next node is MUL. Its inputs are X1 and X2, and its output is S1.
$$ \frac{\partial{S1}}{\partial{X1}} = \frac{\partial{(X1 \cdot X2)}}{\partial{X1}} = X2 $$
$$ \frac{\partial{S1}}{\partial{X2}} = \frac{\partial{(X1 \cdot X2)}}{\partial{X2}} = X1 $$
- Now, the graph is:
- The black lines are the forward propagation, and the red lines are the back propagation.
- The value on each red line is how much the output is sensitive to that input.
- With the chain rule, and assuming Y feeds into a final cost node COST, the partial derivative values for COST are as follows (a numeric check comes right after the equations):
$$ \frac{\partial{COST}}{\partial{S1}} = \frac{\partial{COST}}{\partial{Y}} \cdot \frac{\partial{Y}}{\partial{S1}} = \frac{\partial{COST}}{\partial{Y}} $$
$$ \frac{\partial{COST}}{\partial{S2}} = \frac{\partial{COST}}{\partial{Y}} \cdot \frac{\partial{Y}}{\partial{S2}} = \frac{\partial{COST}}{\partial{Y}} $$
$$ \frac{\partial{COST}}{\partial{X1}} = \frac{\partial{COST}}{\partial{Y}} \cdot \frac{\partial{Y}}{\partial{S1}} \cdot \frac{\partial{S1}}{\partial{X1}} = \frac{\partial{COST}}{\partial{Y}} \cdot X2 $$
$$ \frac{\partial{COST}}{\partial{X2}} = \frac{\partial{COST}}{\partial{Y}} \cdot \frac{\partial{Y}}{\partial{S1}} \cdot \frac{\partial{S1}}{\partial{X2}} = \frac{\partial{COST}}{\partial{Y}} \cdot X1 $$
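- As a quick check, substituting the input values used in the code example below (X1 = 2, X2 = 5, C = 3) and taking \( \frac{\partial{COST}}{\partial{Y}} = 1 \) gives

$$ \frac{\partial{COST}}{\partial{X1}} = 1 \cdot X2 = 5, \quad \frac{\partial{COST}}{\partial{X2}} = 1 \cdot X1 = 2, \quad \frac{\partial{COST}}{\partial{C}} = \frac{\partial{COST}}{\partial{S2}} = 1 $$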
- From this work, it is possible to design Python ADD and MUL nodes which each have a forward and a backward path.
- ADD
class ADD():
    def __init__(self):
        pass

    def forward(self, x1, x2):
        # Forward pass: simply add the two inputs
        return x1 + x2

    def backward(self, d):
        # Backward pass: the incoming gradient flows through unchanged,
        # because d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1
        dx1 = d * 1
        dx2 = d * 1
        return dx1, dx2
- MUL
class MUL():
    def __init__(self):
        self.x1 = None
        self.x2 = None

    def forward(self, x1, x2):
        # Cache the inputs: the backward pass needs them,
        # because d(x1 * x2)/dx1 = x2 and d(x1 * x2)/dx2 = x1
        self.x1 = x1
        self.x2 = x2
        return x1 * x2

    def backward(self, d):
        # Backward pass: each input's gradient is the incoming
        # gradient scaled by the other input
        dx1 = d * self.x2
        dx2 = d * self.x1
        return dx1, dx2
- With these nodes, the code for the example is
# Inputs
x1 = 2
x2 = 5
C = 3
# Nodes
mul = MUL()
add = ADD()
# Forward
s1 = mul.forward(x1, x2)
s2 = C
y = add.forward(s1, s2)
# Backward
# Let's assume dCOST/dY is 1
ds1, ds2 = add.backward(1)
dc = ds2
dx1, dx2 = mul.backward(ds1)
print("Y: {0}".format(y))
print("dCOST/dX1: {0}".format(dx1))
print("dCOST/dX2: {0}".format(dx2))
print("dCOST/dC: {0}".format(dc))
Y: 13
dCOST/dX1: 5
dCOST/dX2: 2
dCOST/dC: 1
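- As a sanity check, which is not part of the original example, the back propagation results above can be compared against a simple numerical gradient. The names f and numerical_grad below are hypothetical helpers introduced only for this check.

def f(x1, x2, c):
    # The same computation as the graph: Y = X1 * X2 + C
    return x1 * x2 + c

def numerical_grad(func, args, i, eps=1e-6):
    # Central-difference approximation of the partial derivative
    # with respect to the i-th argument
    plus = list(args)
    minus = list(args)
    plus[i] += eps
    minus[i] -= eps
    return (func(*plus) - func(*minus)) / (2 * eps)

args = (2, 5, 3)  # x1, x2, C from the example above
for i, name in enumerate(["X1", "X2", "C"]):
    print("dCOST/d{0}: {1:.4f}".format(name, numerical_grad(f, args, i)))

- The printed values are close to 5, 2, and 1, matching the back propagation results, since \( \frac{\partial{COST}}{\partial{Y}} \) was assumed to be 1.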