Now we will build a matching network in TensorFlow step by step. The complete code is given at the end.
First, we import the libraries. Note that tf.contrib and tf.placeholder, used below, are available only in TensorFlow 1.x:
import tensorflow as tf
slim = tf.contrib.slim
rnn = tf.contrib.rnn
Next, we define a class called Matching_network, which holds our network:
class Matching_network():
We define the __init__ method, where we initialize all of the variables:
    def __init__(self, lr, n_way, k_shot, batch_size=32):
        # placeholders for the support set
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot, ])
        # placeholders for the query set
        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.query_label = tf.placeholder(tf.int32, [None, ])
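To make the placeholder shapes concrete, here is a minimal NumPy sketch that builds random arrays with the same shapes and dtypes as the four placeholders above. The helper name make_dummy_episode_batch and the values n_way=5, k_shot=1, and batch_size=32 are assumptions chosen for illustration; they are not part of the original code, and the random pixels merely stand in for real 28 x 28 grayscale images.

```python
import numpy as np


def make_dummy_episode_batch(n_way, k_shot, batch_size, seed=0):
    """Hypothetical helper: random arrays shaped like the four placeholders."""
    rng = np.random.RandomState(seed)
    n_support = n_way * k_shot  # images per support set

    # [batch, n_way * k_shot, 28, 28, 1], matches support_set_image
    support_images = rng.rand(batch_size, n_support, 28, 28, 1).astype(np.float32)

    # [batch, n_way * k_shot], matches support_set_label.
    # Each class index 0..n_way-1 appears k_shot times in every support set.
    one_episode = np.repeat(np.arange(n_way), k_shot)
    support_labels = np.tile(one_episode, (batch_size, 1)).astype(np.int32)

    # [batch, 28, 28, 1], matches query_image (one query per episode)
    query_images = rng.rand(batch_size, 28, 28, 1).astype(np.float32)

    # [batch], matches query_label (class index of each query)
    query_labels = rng.randint(0, n_way, size=batch_size).astype(np.int32)

    return support_images, support_labels, query_images, query_labels


support_x, support_y, query_x, query_y = make_dummy_episode_batch(
    n_way=5, k_shot=1, batch_size=32)
print(support_x.shape, support_y.shape, query_x.shape, query_y.shape)
```

In a TensorFlow 1.x session, arrays like these could be passed through feed_dict, keyed on self.support_set_image, self.support_set_label, self.query_image, and self.query_label. The leading None dimension of each placeholder is what lets the batch size vary between runs.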
Let's say our...