第三章 使用原型网络对字符集进行分类

tech2024-05-25  88

原型网络（prototypical network）对训练集中不存在的类别同样具有泛化能力。与孪生网络一样，它也试图学习一个度量空间来进行分类。其基本思想是：为每个类别创建一个原型表示（即该类支撑集样本嵌入向量的均值），然后根据查询点（新数据点）与各类原型之间的距离，对查询点进行分类。

使用原型网络对 Omniglot 字符集进行分类

import os
import glob

from PIL import Image
import numpy as np
import tensorflow as tf

# Prototypical network trained episodically on Omniglot, written against the
# TensorFlow 1.x graph API (tf.placeholder, tf.layers, tf.contrib, sessions).

root_dir = 'data/'

# Each line of the split file is "<alphabet>/<character>/<rotation tag>"; the
# rotation tag's first three characters are a prefix (presumably "rot") followed
# by the rotation angle in degrees.
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
    # Skip blank lines (e.g. a trailing newline) so the 3-way split below
    # does not raise on an empty entry.
    train_classes = [line.rstrip() for line in train_split if line.strip()]

# Number of classes: every alphabet/character/rotation triple is its own class.
no_of_classes = len(train_classes)
# Number of images stored per class.
num_examples = 20
img_width = 28
img_height = 28
channels = 1

# train_dataset[class_index, example_index] is one img_height x img_width image.
train_dataset = np.zeros([no_of_classes, num_examples, img_height, img_width], dtype=np.float32)

# Read every image, rotate and resize it, and store it under its class index.
for label, name in enumerate(train_classes):
    alphabet, character, rotation = name.split('/')
    rotation = float(rotation[3:])
    img_dir = os.path.join(root_dir, 'data', alphabet, character)
    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    for index, img_file in enumerate(img_files):
        # The context manager closes the file handle. np.asarray is used instead
        # of np.array(..., copy=False) because the float32 conversion always needs
        # a copy, which NumPy >= 2.0 rejects when copy=False is requested.
        with Image.open(img_file) as img:
            pixels = np.asarray(img.rotate(rotation).resize((img_width, img_height)), dtype=np.float32)
        # Invert pixel values (presumably so strokes become 1 on a 0 background --
        # assumes dark strokes on a light background; TODO confirm).
        values = 1. - pixels
        train_dataset[label, index] = values


def convolution_block(inputs, out_channels, name='conv'):
    """Apply 3x3 conv (SAME) -> batch norm -> ReLU -> 2x2 max pooling.

    Args:
        inputs: 4-D channels-last feature tensor [batch, height, width, channels].
        out_channels: number of convolution filters.
        name: variable scope holding this block's variables. Distinct names let
            get_embeddings re-enter the same scopes to reuse the weights.

    Returns:
        The pooled feature tensor.
    """
    with tf.variable_scope(name):
        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
        conv = tf.nn.relu(conv)
        conv = tf.contrib.layers.max_pool2d(conv, 2)
    return conv


def get_embeddings(support_set, h_dim, z_dim, reuse=False):
    """Embed a batch of images with four convolution blocks, then flatten.

    All calls share one set of weights under the 'encoder' variable scope: the
    first call (reuse=False) creates the variables, and later calls must pass
    reuse=True. Support and query images are then mapped by the *same* network
    into the same metric space.

    Args:
        support_set: image tensor [batch, height, width, channels]. Despite the
            name, it is also used for the query images.
        h_dim: number of filters in the first three blocks.
        z_dim: number of filters in the last block.
        reuse: reuse the existing 'encoder' variables instead of creating them.

    Returns:
        2-D tensor [batch, embedding_dim].
    """
    with tf.variable_scope('encoder', reuse=reuse):
        net = convolution_block(support_set, h_dim, name='conv_1')
        net = convolution_block(net, h_dim, name='conv_2')
        net = convolution_block(net, h_dim, name='conv_3')
        net = convolution_block(net, z_dim, name='conv_4')
        net = tf.contrib.layers.flatten(net)
    return net


# Classes sampled per training episode (the "way").
num_way = 60
# Support examples per class (the "shot").
num_shot = 5
# Query points per class.
num_query = 5
h_dim = 64
z_dim = 64

# Support and query images: [num_classes, points_per_class, height, width, channels].
support_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])
query_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])

support_set_shape = tf.shape(support_set)
query_set_shape = tf.shape(query_set)
num_classes, num_support_points = support_set_shape[0], support_set_shape[1]
num_query_points = query_set_shape[1]

# Class index of every query point: [num_classes, num_query_points].
y = tf.placeholder(tf.int64, [None, None])
y_one_hot = tf.one_hot(y, depth=num_classes)

# Embed all support images as one flat batch; this call creates the encoder.
support_set_embeddings = get_embeddings(
    tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]),
    h_dim, z_dim)

# Each class prototype is the mean of that class's support embeddings,
# giving [num_classes, embedding_dimension].
embedding_dimension = tf.shape(support_set_embeddings)[-1]
class_prototype = tf.reduce_mean(
    tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]),
    axis=1)

# Embed the query images with the same encoder weights.
query_set_embeddings = get_embeddings(
    tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]),
    h_dim, z_dim, reuse=True)


def euclidean_distance(a, b):
    """Return pairwise distances between the rows of a [N, D] and b [M, D].

    The result is [N, M]. Entry (i, j) is the mean over the D features of
    (a[i] - b[j]) ** 2, i.e. the squared Euclidean distance divided by D.
    """
    N = tf.shape(a)[0]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)


# Distance from every query point to every prototype:
# [num_classes * num_query_points, num_classes]. Queries go first so that the
# softmax below normalizes over classes, not over query points.
distance = euclidean_distance(query_set_embeddings, class_prototype)

# log p(class | query) = log_softmax(-distance), reshaped to
# [num_classes, num_query_points, num_classes] to line up with y_one_hot.
predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])
loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))
train = tf.train.AdamOptimizer().minimize(loss)

# Start a TensorFlow session and initialize all variables.
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)

num_epochs = 20
num_episodes = 100

# Episodic training: each episode samples num_way classes, splits each class's
# examples into disjoint support and query points, and takes one optimizer step.
for epoch in range(num_epochs):
    for episode in range(num_episodes):
        # Pick num_way distinct classes for this episode.
        episodic_classes = np.random.permutation(no_of_classes)[:num_way]
        support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
        query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)
        for index, class_ in enumerate(episodic_classes):
            # num_shot support and num_query query examples, drawn without overlap.
            selected = np.random.permutation(num_examples)[:num_shot + num_query]
            support[index] = train_dataset[class_, selected[:num_shot]]
            query[index] = train_dataset[class_, selected[num_shot:]]
        # Add the trailing channel axis expected by the placeholders.
        support = np.expand_dims(support, axis=-1)
        query = np.expand_dims(query, axis=-1)
        # Query labels are the episode-local class indices 0..num_way-1. int64
        # matches the placeholder dtype; uint8 would wrap for num_way > 256.
        labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.int64)
        _, loss_, accuracy_ = sess.run([train, loss, accuracy],
                                       feed_dict={support_set: support, query_set: query, y: labels})
        if (episode + 1) % 10 == 0:
            print('Epoch {} : Episode {} : Loss: {}, Accuracy: {}'.format(epoch + 1, episode + 1, loss_, accuracy_))

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

最新回复(0)