Like Siamese networks, prototypical networks learn a metric space in which to perform classification, and they generalize to classes that never appear in the training set. The basic idea is to build a prototype representation of each class and to classify a query point (a new point) according to the distance between the class prototypes and the query point.
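In symbols (following the prototypical networks paper), the prototype $c_k$ of class $k$ is the mean of the embedded support points of that class, and a query point $x$ is assigned a class by a softmax over its negative distances to the prototypes, where $f_\phi$ is the embedding function and $S_k$ is the support set of class $k$:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i,\, y_i) \in S_k} f_\phi(x_i), \qquad p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x),\, c_k))}{\sum_{k'} \exp(-d(f_\phi(x),\, c_{k'}))}$$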
Classifying the Omniglot character set with a prototypical network
import os
import glob
from PIL import Image
import numpy as np
import tensorflow as tf
root_dir = 'data/'
# this file holds the alphabet (language) names, character numbers, and rotation information
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
    train_classes = [line.rstrip() for line in train_split.readlines()]
# number of classes
no_of_classes = len(train_classes)
# number of examples (20 images per character)
num_examples = 20
img_width = 28
img_height = 28
channels = 1
# initialize the training dataset array with the number of classes, the number of examples, and the image height and width
train_dataset = np.zeros([no_of_classes, num_examples, img_height, img_width], dtype=np.float32)
# read all the images, convert them to numpy arrays, and store them together with their labels in the train_dataset array
for label, name in enumerate(train_classes):
    alphabet, character, rotation = name.split('/')
    # rotation entries look like 'rot090'; strip the 'rot' prefix to get the angle
    rotation = float(rotation[3:])
    img_dir = os.path.join(root_dir, 'data', alphabet, character)
    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    for index, img_file in enumerate(img_files):
        # invert the binary image so that strokes are 1 and the background is 0
        values = 1. - np.array(Image.open(img_file).rotate(rotation).resize((img_width, img_height)), np.float32, copy=False)
        train_dataset[label, index] = values
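As a quick sanity check after loading, the array shape follows from the constants defined above:

print(train_dataset.shape)  # (no_of_classes, 20, 28, 28)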
def convolution_block(inputs, out_channels, name='conv'):
    '''
    After loading the training data, we need to extract feature vectors. Since the
    inputs are images, we use convolution operations for feature extraction: each
    convolutional block applies 64 filters, followed by batch normalization, a ReLU
    activation, and max pooling.
    '''
    # a named variable scope lets get_embeddings reuse the encoder weights
    with tf.variable_scope(name):
        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
        conv = tf.nn.relu(conv)
        conv = tf.contrib.layers.max_pool2d(conv, 2)
    return conv
# define the embedding function, which produces embeddings with four convolutional blocks;
# with 28x28 inputs and four stride-2 poolings (28 -> 14 -> 7 -> 3 -> 1), the flattened
# embedding is a z_dim-dimensional vector per image
def get_embeddings(support_set, h_dim, z_dim, reuse=False):
    # reuse=True shares the encoder weights between the support-set and query-set passes
    with tf.variable_scope('encoder', reuse=reuse):
        net = convolution_block(support_set, h_dim, name='conv_1')
        net = convolution_block(net, h_dim, name='conv_2')
        net = convolution_block(net, h_dim, name='conv_3')
        net = convolution_block(net, z_dim, name='conv_4')
        net = tf.contrib.layers.flatten(net)
        return net
# number of classes per episode (60-way)
num_way = 60
# number of examples per class in the support set
num_shot = 5
# number of query points per class in the query set
num_query = 5
# total number of examples per class
num_examples = 20
# number of filters in the hidden and output convolutional blocks
h_dim = 64
z_dim = 64
# initialize placeholders for the support set and the query set
support_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])
query_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])
# store the shapes of the support set and the query set
support_set_shape = tf.shape(support_set)
query_set_shape = tf.shape(query_set)
# get the number of classes, the number of data points in the support set, and the number of data points in the query set
num_classes, num_support_points = support_set_shape[0], support_set_shape[1]
num_query_points = query_set_shape[1]
# define a placeholder for the labels
y = tf.placeholder(tf.int64, [None, None])
# convert the labels to one-hot encoding
y_one_hot = tf.one_hot(y, depth=num_classes)
# generate the embeddings of the support set using the embedding function
support_set_embeddings = get_embeddings(
    tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)
# compute each class prototype: the mean vector of that class's support-set embeddings
embedding_dimension = tf.shape(support_set_embeddings)[-1]
class_prototype = tf.reduce_mean(
    tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)
# get the embeddings of the query set in the same way, reusing the encoder weights
query_set_embeddings = get_embeddings(
    tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)
# with the embeddings in hand, define a distance function between the query-set embeddings and the class prototypes
def euclidean_distance(a, b):
    # pairwise squared Euclidean distance between the rows of a and b;
    # averaging over the feature axis rescales the distance by 1/D,
    # which leaves the softmax ranking over classes unchanged
    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)
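For intuition, here is the same pairwise computation in NumPy (a standalone sketch with made-up shapes, not part of the TensorFlow graph):

a = np.ones((3, 4))   # e.g. 3 embeddings of dimension 4
b = np.zeros((5, 4))  # e.g. 5 embeddings of dimension 4
d = np.mean((a[:, None, :] - b[None, :, :]) ** 2, axis=2)
print(d.shape)  # (3, 5): one distance per pair of rows; here every entry is 1.0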
# compute the distance from every query embedding to every class prototype,
# giving one row per query point and one column per class
distance = euclidean_distance(query_set_embeddings, class_prototype)
# feed the negative distances into a softmax to get the (log) probability of each class
predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])
# loss: mean negative log-probability of the true class over all query points
loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))
# accuracy: fraction of query points whose highest-probability class matches the label
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))
train = tf.train.AdamOptimizer().minimize(loss)
# start a TensorFlow session and train the model
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
num_epochs = 20
num_episodes = 100
# train episodically: in each episode, sample data points to build the support and query sets, and train the model on them
for epoch in range(num_epochs):
    for episode in range(num_episodes):
        # randomly select num_way (60) classes for this episode
        episodic_classes = np.random.permutation(no_of_classes)[:num_way]
        support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
        query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)
        for index, class_ in enumerate(episodic_classes):
            selected = np.random.permutation(num_examples)[:num_shot + num_query]
            support[index] = train_dataset[class_, selected[:num_shot]]
            # 5 query points per class
            query[index] = train_dataset[class_, selected[num_shot:]]
        support = np.expand_dims(support, axis=-1)
        query = np.expand_dims(query, axis=-1)
        labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.uint8)
        _, loss_, accuracy_ = sess.run([train, loss, accuracy],
                                       feed_dict={support_set: support, query_set: query, y: labels})
        if (episode + 1) % 10 == 0:
            print('Epoch {} : Episode {} : Loss: {}, Accuracy: {}'.format(epoch + 1, episode + 1, loss_, accuracy_))
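Evaluation follows the same episodic recipe: sample support and query sets from held-out classes and run only the accuracy op, without the train op. A minimal sketch, assuming a test split has been loaded into a test_dataset array in the same way train.txt was loaded into train_dataset (the split file and the array name are assumptions):

num_test_episodes = 100
num_test_classes = test_dataset.shape[0]  # assumed to be built like train_dataset
avg_acc = 0.
for episode in range(num_test_episodes):
    episodic_classes = np.random.permutation(num_test_classes)[:num_way]
    support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
    query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)
    for index, class_ in enumerate(episodic_classes):
        selected = np.random.permutation(num_examples)[:num_shot + num_query]
        support[index] = test_dataset[class_, selected[:num_shot]]
        query[index] = test_dataset[class_, selected[num_shot:]]
    support = np.expand_dims(support, axis=-1)
    query = np.expand_dims(query, axis=-1)
    labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.uint8)
    # forward pass only: no training on the held-out classes
    acc = sess.run(accuracy, feed_dict={support_set: support, query_set: query, y: labels})
    avg_acc += acc / num_test_episodes
print('Average test accuracy: {}'.format(avg_acc))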