We know that, in few-shot learning, we learn from fewer data points, but how can we apply gradient descent in a few-shot learning setting? In a few-shot learning setting, gradient descent struggles because it needs many data points and many update steps to reach convergence and minimize the loss. So, we need a better optimization technique in the few-shot regime. Let's say we have a model parameterized by some parameter $\theta$. We initialize this parameter $\theta$ with some random values and try to find the optimal value using gradient descent. Let's recall the update equation of our gradient descent:

$$\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} L_t$$
In the previous equation, the following applies:
- $\theta_t$ is the updated parameter
- $\theta_{t-1}$ is the parameter value at the previous time step
- $\alpha_t$ is the learning rate
- $\nabla_{\theta_{t-1}} L_t$ is the gradient of the loss function with respect to $\theta_{t-1}$
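To make this concrete, here is a minimal sketch of the update rule applied to a toy quadratic loss (the function name, loss, and learning rate are purely illustrative):

```python
def gradient_descent_step(theta_prev, grad, lr):
    # theta_t = theta_{t-1} - alpha_t * gradient of the loss
    return theta_prev - lr * grad

# toy loss L(theta) = theta^2, whose gradient is 2 * theta
theta = 5.0
for _ in range(100):
    grad = 2 * theta
    theta = gradient_descent_step(theta, grad, lr=0.1)

print(theta)  # converges towards the minimum at theta = 0
```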
Doesn't the update equation of gradient descent look familiar? Yes, you guessed it right: it resembles the cell state update equation of the LSTM, and it can be written as follows:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
We can totally relate our LSTM cell update equation with gradient descent: let's say $f_t = 1$, $c_{t-1} = \theta_{t-1}$, $i_t = \alpha_t$, and $\tilde{c}_t = -\nabla_{\theta_{t-1}} L_t$; then the following applies:

$$c_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} L_t$$
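As a quick sanity check of this correspondence, the following sketch (with illustrative values) substitutes $f_t = 1$, $c_{t-1} = \theta_{t-1}$, $i_t = \alpha_t$, and $\tilde{c}_t = -\nabla_{\theta_{t-1}} L_t$ into the cell state update and confirms that it matches the gradient descent update:

```python
import numpy as np

theta_prev = np.array([0.5, -1.2])   # theta_{t-1}
grad = np.array([0.3, 0.1])          # gradient of the loss w.r.t. theta_{t-1}
lr = 0.01                            # alpha_t

# plain gradient descent update
theta_gd = theta_prev - lr * grad

# LSTM cell state update with f_t = 1, c_{t-1} = theta_{t-1},
# i_t = alpha_t, and c_tilde_t = -gradient
f_t, c_prev, i_t, c_tilde = 1.0, theta_prev, lr, -grad
c_t = f_t * c_prev + i_t * c_tilde

print(np.allclose(theta_gd, c_t))    # True -- the two updates coincide
```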
So, instead of using gradient descent as an optimizer in the few-shot learning regime, we can use LSTM as an optimizer. Our meta learner is the LSTM, which learns the update rule for training our model. So we use two networks: one, our base learner, which learns to perform a task, and the other, the meta learner, which tries to find the optimal parameter. But how does this work?
We know that, in an LSTM, we use a forget gate for discarding information that is not required in the memory, and it can be represented as follows:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$
How can this forget gate be useful in our optimization setting? Let's say we are in a position where the loss is high and the gradient is close to zero. How can we escape from this position? In this case, we can shrink the parameters of our model and forget some part of their previous value. So, we can use our forget gate to do that; it takes the current parameter value $\theta_{t-1}$, the current loss $L_t$, the current gradient $\nabla_{\theta_{t-1}} L_t$, and the previous forget gate $f_{t-1}$ as inputs, and it can be represented as follows:

$$f_t = \sigma\left(W_f \cdot [\nabla_{\theta_{t-1}} L_t, L_t, \theta_{t-1}, f_{t-1}] + b_f\right)$$
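Below is a minimal sketch of such a forget gate in PyTorch, assuming the four inputs are stacked per parameter coordinate and passed through a learned linear layer followed by a sigmoid; the class and variable names are illustrative, not the chapter's exact implementation:

```python
import torch
import torch.nn as nn

class ForgetGate(nn.Module):
    """Computes f_t = sigmoid(W_f . [grad, loss, theta_prev, f_prev] + b_f)."""
    def __init__(self):
        super().__init__()
        # 4 input features per parameter coordinate: gradient, loss,
        # previous parameter value, and previous forget-gate value
        self.linear = nn.Linear(4, 1)

    def forward(self, grad, loss, theta_prev, f_prev):
        # broadcast the scalar loss to every parameter coordinate
        loss = loss.expand_as(grad)
        features = torch.stack([grad, loss, theta_prev, f_prev], dim=-1)
        return torch.sigmoid(self.linear(features)).squeeze(-1)

# example usage with a 3-parameter base model
gate = ForgetGate()
f_t = gate(grad=torch.randn(3), loss=torch.tensor(1.5),
           theta_prev=torch.randn(3), f_prev=torch.ones(3))
print(f_t.shape)  # torch.Size([3]): one gate value per parameter
```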
Now let's come to the input gate. We know that the input gate in an LSTM is used for deciding what value to update, and it can be represented as follows:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
In our few-shot learning setting, we can use this input gate to tune our learning rate so that we learn quickly while still preventing divergence:

$$i_t = \sigma\left(W_i \cdot [\nabla_{\theta_{t-1}} L_t, L_t, \theta_{t-1}, i_{t-1}] + b_i\right)$$
So, our meta learner learns the optimal values of $f_t$ and $i_t$ after several updates.
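Putting the two gates together, a meta learner cell could produce the updated parameters as its cell state. The following is a rough sketch under the same assumptions as before (illustrative names, per-coordinate linear gates, and $\tilde{c}_t = -\nabla_{\theta_{t-1}} L_t$):

```python
import torch
import torch.nn as nn

class MetaLearnerCell(nn.Module):
    """One meta learner step: theta_t = f_t * theta_{t-1} + i_t * (-grad)."""
    def __init__(self):
        super().__init__()
        self.forget_linear = nn.Linear(4, 1)  # produces f_t
        self.input_linear = nn.Linear(4, 1)   # produces i_t, a learned learning rate

    def forward(self, grad, loss, theta_prev, f_prev, i_prev):
        loss = loss.expand_as(grad)
        f_t = torch.sigmoid(self.forget_linear(
            torch.stack([grad, loss, theta_prev, f_prev], dim=-1))).squeeze(-1)
        i_t = torch.sigmoid(self.input_linear(
            torch.stack([grad, loss, theta_prev, i_prev], dim=-1))).squeeze(-1)
        # cell state update with c_{t-1} = theta_{t-1} and c_tilde_t = -grad;
        # the new cell state is the updated parameter vector
        theta_t = f_t * theta_prev + i_t * (-grad)
        return theta_t, f_t, i_t
```

Here, $i_t$ acts as a learned, per-parameter learning rate, while $f_t$ decides how much of the previous parameter value to keep.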
But still, how does this work?
Let's say we have a base network parameterized by $\theta$ and our LSTM meta learner parameterized by $\phi$. Assume that we have a dataset $D$. We split our dataset into $D_{train}$ and $D_{test}$ for training and testing, respectively. First, we randomly initialize our meta learner parameter $\phi$.
For some $T$ number of iterations, we randomly sample data points $(x, y)$ from $D_{train}$, calculate the loss, and then calculate the gradients of the loss with respect to our model parameter $\theta$. Now we feed this gradient, the loss, and the meta learner parameter $\phi_{t-1}$ to our meta learner. Our meta learner will return a cell state $c_t$, and then we update our base network parameter at time $t$ as $\theta_t = c_t$. We repeat this for $T$ steps, as shown in the following diagram:
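As a rough sketch of this inner loop, suppose the base learner is a toy linear model trained with a mean squared error loss and the meta learner is a gate cell like the MetaLearnerCell sketched earlier (all names and choices here are illustrative, not the actual implementation):

```python
import torch

def inner_loop(cell, x_train, y_train, T=5):
    """Train the base learner for T steps; the meta learner's cell state
    at each step becomes the new base parameters (theta_t = c_t)."""
    theta = torch.zeros(x_train.shape[1], requires_grad=True)
    f_prev, i_prev = torch.ones_like(theta), torch.zeros_like(theta)
    for _ in range(T):
        loss = ((x_train @ theta - y_train) ** 2).mean()
        # gradient of the loss w.r.t. the current base parameters;
        # create_graph=True lets the test loss later backpropagate into the meta learner
        (grad,) = torch.autograd.grad(loss, theta, create_graph=True)
        theta, f_prev, i_prev = cell(grad, loss.detach(), theta, f_prev, i_prev)
    return theta
```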
So, after $T$ iterations, we will have an optimal parameter $\theta_T$. But how can we check the performance of $\theta_T$, and how can we update our meta learner parameter? We take the test set and compute the loss on our test set with the parameter $\theta_T$. Then, we calculate the gradients of this test loss with respect to our meta learner parameter $\phi$ and update $\phi$ (where $\beta$ is the learning rate of the meta learner), as shown here:

$$\phi_n = \phi_{n-1} - \beta \nabla_{\phi_{n-1}} L_{test}$$
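Continuing the same sketch, the outer loop computes the loss on the test set with $\theta_T$ and takes a gradient step on the meta learner parameter $\phi$, here with plain SGD and an illustrative meta learning rate $\beta$ (the data tensors are toy placeholders):

```python
# cell and inner_loop come from the sketches above; the data here is random toy data
x_train, y_train = torch.randn(20, 3), torch.randn(20)
x_test, y_test = torch.randn(20, 3), torch.randn(20)

cell = MetaLearnerCell()
meta_optimizer = torch.optim.SGD(cell.parameters(), lr=0.01)  # beta

for episode in range(100):                    # n outer iterations
    theta_T = inner_loop(cell, x_train, y_train, T=5)
    test_loss = ((x_test @ theta_T - y_test) ** 2).mean()
    meta_optimizer.zero_grad()
    test_loss.backward()                      # gradients of the test loss w.r.t. phi
    meta_optimizer.step()                     # phi <- phi - beta * gradient
```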
We do this for some $n$ number of iterations and update our meta learner. The overall algorithm is shown here: