Implementation of Prototypical Networks based on PyTorch

Posted by - - NC - - on Wed, 05 Jan 2022 10:50:28 +0100

Based on Jay2coomzz's implementation, I modified the data processing and the evaluation_model() method. As a beginner, if there is anything wrong, please feel free to criticize and correct me. The original post is linked below:
https://blog.csdn.net/weixin_38471579/article/details/102631018

The dataset and source code links will be placed in the comments section.

1. Dataset

A brief introduction to the dataset

The Omniglot dataset contains 1,623 handwritten character classes, each with 20 samples. These 1,623 classes come from the alphabets of 50 different regions (or civilizations); for example, the Latin alphabet contains 26 letters and the Greek alphabet contains 24. The images_background/Greek folder, for instance, holds the 24 letter classes of the Greek alphabet, and each letter has only 20 samples.
Typically, the 964 classes (letters from 30 regions) under the images_background folder are used for training, and the 659 classes (letters from 20 regions) under the images_evaluation folder are used for testing.
The goal is to train the model on the 964 training classes and then recognize the 659 new classes. The test set and training set are completely disjoint: the tasks are of the same kind, but the test classes have never been seen during training. This is the essence of meta-learning, learning to learn.
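
As a quick sanity check of the folder layout (a minimal sketch, assuming the dataset has been unpacked into ../data/ as in the loading code below), you can count the classes per split:

import os

# Count the character classes in each split; 964 and 659 are expected
# for the full images_background and images_evaluation folders.
for split in ('images_background', 'images_evaluation'):
	n_classes = 0
	for alphabet in os.listdir('../data/' + split):
		n_classes += len(os.listdir('../data/' + split + '/' + alphabet))
	print(split, n_classes)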

Dataset processing

First, load the training set and test set according to the folder names:

import os
import matplotlib.image as mpimg
import numpy as np
import csv

# Read each image into a numpy array, group the samples by class,
# and store them in dictionaries keyed by class label

os.chdir('E:/pytorch/prototypical_network/scripts')
def load_data():
	#Training set
	labels_trainData = {}
	label = 0
	for file in os.listdir('../data/images_background'):
		for dir in os.listdir('../data/images_background/' + file):
			labels_trainData[label] = []
			data = []
			for png in os.listdir('../data/images_background/' + file +'/' + dir):
				image_np = mpimg.imread('../data/images_background/' + file +'/' + dir+'/' +png)
				image_np = image_np.astype(np.float32)	#astype returns a new array, so assign it; float32 matches the model's default dtype
				data.append(image_np)
			labels_trainData[label] = np.array(data)
			label += 1
	#Test set
	labels_testData = {}
	label = 0
	for file in os.listdir('../data/images_evaluation'):
		for dir in os.listdir('../data/images_evaluation/' + file):
			labels_testData[label] = []
			data = []
			for png in os.listdir('../data/images_evaluation/' + file +'/' + dir):
				image_np = mpimg.imread('../data/images_evaluation/' + file +'/' + dir+'/' +png)
				image_np = image_np.astype(np.float32)	#astype returns a new array, so assign it; float32 matches the model's default dtype
				data.append(image_np)
			labels_testData[label] = np.array(data)
			label += 1
	return labels_trainData,labels_testData

The images are stored as dictionaries in labels_trainData and labels_testData. The class label (i.e. the dictionary key) runs 0, 1, 2, 3, ... (964 classes in the training set and 659 in the test set).

labels_trainData ,labels_testData = load_data()
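
A quick check of how many classes were loaded (964 for the training set and 659 for the test set if the full splits are present):

print(len(labels_trainData), len(labels_testData))	# expected: 964 659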

Let's take a look at the data format (taking class 0 of the test set as an example):

print(labels_testData[0].shape)

The output is (20, 105, 105): there are 20 samples of this character, and each sample is a 105 × 105 matrix. If you print one of the images, you will see that it is read in as a binary matrix of 0s and 1s (not even a grayscale image, which makes things easier).
Plot one of the images to take a look:

import matplotlib.pyplot as plt
plt.imshow(labels_testData[0][3])
plt.show()


Next, add a channel dimension to the data so it can be fed to the neural network later:

wide = labels_trainData[0][0].shape[0]
length = labels_trainData[0][0].shape[1]
	
for label in labels_trainData.keys():
	labels_trainData[label] = np.reshape(labels_trainData[label], [-1,1,wide, length])

for label in labels_testData.keys():
	labels_testData[label] = np.reshape(labels_testData[label], [-1,1,wide, length])
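
After the reshape, each class is a 4-D array of shape (samples, channels, height, width), which is the layout Conv2d expects. For example:

print(labels_trainData[0].shape)	# (20, 1, 105, 105)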

2. Build network

The required imports:

import os
import numpy as np
import h5py
import random
import csv

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable

CNN

class CNNnet(torch.nn.Module):
	def __init__(self,input_shape,outDim):	
		super(CNNnet,self).__init__()
		self.conv1 = torch.nn.Sequential(
			torch.nn.Conv2d(in_channels=input_shape[0],
							out_channels=16,
							kernel_size=3,
							stride=1,
							padding=1),
			torch.nn.BatchNorm2d(16),
			torch.nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv2 = torch.nn.Sequential(
			torch.nn.Conv2d(16,32,3,1,1),
			torch.nn.BatchNorm2d(32),
			nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv3 = torch.nn.Sequential(
			torch.nn.Conv2d(32,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv4 = torch.nn.Sequential(
			torch.nn.Conv2d(64,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			#nn.MaxPool2d(2)
			torch.nn.ReLU()
		)
		self.conv5 = torch.nn.Sequential(
			torch.nn.Conv2d(64,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			#nn.MaxPool2d(2)
			torch.nn.ReLU()
		)
		self.mlp1 = torch.nn.Linear(10816,125)	#10816 = 64*13*13 for a 105x105 input after three 2x2 max-pools; change the first argument here if the input size changes
		self.mlp2 = torch.nn.Linear(125,outDim)
		
	def forward(self, x):	#adjust to match the layers defined in __init__
		x = self.conv1(x)
		x = self.conv2(x)
		x = self.conv3(x)
		x = self.conv4(x)
		x = self.conv5(x)
		x = self.mlp1(x.view(x.size(0),-1))
		x = self.mlp2(x)
		return x

This is a 5-layer CNN, and it is what we actually train. The goal is a CNN with generalization ability that maps inputs into an outDim-dimensional embedding space, where images of the same class are clustered as close together as possible and images of different classes are pushed as far apart as possible.
To achieve this, the loss function and the prediction function need to be customized.
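
For reference, the loss implemented by the Protonets class below is the prototypical-network objective as written here (a sketch of the math, with f denoting the CNN embedding, S_k and Q_k the support and query sets of class k):

$$c_k = \frac{1}{N_s}\sum_{x_i \in S_k} f(x_i), \qquad \mathcal{L} = -\frac{1}{N_c N_q}\sum_{k=1}^{N_c}\sum_{x \in Q_k}\log\frac{\exp\left(-\lVert f(x)-c_k\rVert\right)}{\sum_{j=1}^{N_c}\exp\left(-\lVert f(x)-c_j\rVert\right)}$$

The eucli_tensor function below returns the negative (unsquared) Euclidean distance, so passing these values through log_softmax yields exactly the term inside the sum.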

Protonets

The CNN is only one part of Protonets. The definition of the Protonets class is given below:

def eucli_tensor(x,y):	#negative Euclidean distance between two tensors, used for the loss and for prediction
	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)

class Protonets(object):
	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
		#Ns: number of support samples per class, Nq: number of query samples per class, Nc: number of classes sampled per episode,
		#log_data: directory where the model and the class centers are stored,
		#step: if trainval==True, load the model and centers saved at this training step,
		#trainval: if True, load a previously saved model and centers instead of training from scratch
		self.input_shape = input_shape
		self.outDim = outDim
		self.batchSize = 1
		self.Ns = Ns
		self.Nq = Nq
		self.Nc = Nc
		if trainval == False:
			#If training a new model, initialize CNN and central point
			self.center = {}
			self.model = CNNnet(input_shape,outDim).cuda()
		else:
			#Otherwise, load CNN model and center point
			self.center = {}
			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#modify here: file name of the saved model
			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#modify here: file name of the saved centers
	
	def compute_center(self,data_set):	#data_set is the support set of one class (a numpy/list object); compute the prototype (center) of that support set
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data)).cuda()
			data = self.model(data)[0]	#embed the support sample into the output space
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#Network training
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#select Nc classes (20 in the 20-way example below)
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#randomly draw the support set and query set from D_set
			support_set,query_set = self.randomSample(D_set)
			#compute the prototype (center) of the support set
			self.center[label] = self.compute_center(support_set)
			#Store the center and query set in the list
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#optimizer (note: it is re-created on every call, so Adam's running statistics reset each episode)
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()
	
	def loss(self,sample):	#Custom loss
		loss_1 = autograd.Variable(torch.FloatTensor([0])).cuda()
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = Variable(torch.from_numpy(data)).cuda()
				data = self.model(data)[0]	#Embedding query points in another space
				#The Euclidean distance between the query point and each center point is calculated one by one
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				#Stack for loss
				loss_1 += -1*F.log_softmax(predict,dim=0)[i]
		loss_1 /= self.Nq*self.Nc
		return loss_1
	
	def randomSample(self,D_set): #randomly draw the support set and query set from D_set (one class's data, shape [20, 105, 105])
		index_list = list(range(D_set.shape[0]))#sample Ns support and Nq query images out of the 20 available
		random.shuffle(index_list)
		support_data_index = index_list[:self.Ns]
		query_data_index = index_list[self.Ns:self.Ns + self.Nq]
		support_set = []
		query_set = []
		for i in support_data_index:
			support_set.append(D_set[i])
		for i in query_data_index:
			query_set.append(D_set[i])
		return support_set,query_set
	
	def evaluation_model(self,labels_data,class_number):
		test_accury = []
		center_for_test={}
		class_index = list(range(class_number))#all classes in the test set (659)
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#select Nc classes
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#randomly draw the support set and query set from D_set
			support_set,query_set = self.randomSample(D_set)
			#compute the prototype (center) of the support set
			center_for_test[label] = self.compute_center(support_set)
			#Store the center and query set in the list
			sample['xc'].append(center_for_test[label])	#list
			sample['xq'].append(query_set)
		
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = torch.from_numpy(data).cuda()
				data = self.model(data)[0]	#Embedding query points in another space
				#The Euclidean distance between the query point and each center point is calculated one by one
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				y_pre_j = int(torch.argmax(F.log_softmax(predict,dim=0)))	#index of the nearest center
				test_accury.append(1 if y_pre_j == i else 0)
		return sum(test_accury)/len(test_accury)
	
	def save_center(self,path):
		datas = []
		for label in self.center.keys():
			datas.append([label] + list(self.center[label].cpu().detach().numpy()))
		with open(path,"w", newline="") as datacsv:
			csvwriter = csv.writer(datacsv,dialect = ("excel"))
			csvwriter.writerows(datas)
	
	def load_center(self,path):
		csvReader = csv.reader(open(path))
		for line in csvReader:
			label = int(line[0])
			center = [ float(line[i]) for i in range(1,len(line))]
			center = np.array(center)
			center = Variable(torch.from_numpy(center)).cuda()
			self.center[label] = center

The general idea is:
1. Only the training set is used to train the network: select Nc classes from labels_trainData (each class has 20 samples), then draw Ns support samples and Nq query samples from those 20 for each class. Pass the support samples through the CNN and average the outputs to obtain each class's center (prototype), then compute the loss on the query samples, update the gradients, and move on to the next episode.
2. Every 50 episodes, evaluate the model's accuracy on the test set. The evaluation scheme is: select Nc classes from labels_testData (each class has 20 samples), draw Ns support samples and Nq query samples from those 20, pass the support samples through the CNN and average the outputs to obtain each center, then compute each query sample's distance to every center and take the nearest center as the prediction, which measures the network's performance.
The training code is as follows, taking 20-way 5-shot as an example:

protonets = Protonets((1,wide,length),10,5,5,20,'../log/',50)	#input_shape=(1,105,105), outDim=10, Ns=5, Nq=5, Nc=20, log_data='../log/', step=50
for n in range(10000):	#each iteration randomly selects Nc classes and trains one episode
	protonets.train(labels_trainData,class_number_train)
	if n % 50 == 0 and n != 0:	#every 50 episodes, save the model and centers and evaluate accuracy; results are appended to model_step_eval.txt
		torch.save(protonets.model, '../log/model_net_'+str(n)+'.pkl')
		protonets.save_center('../log/model_center_'+str(n)+'.csv')
		test_accury = protonets.evaluation_model(labels_testData,class_number_test)
		print(test_accury)
		str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'
		with open('../log/model_step_eval.txt', "a") as f:
			f.write(str_data)
	print(n)
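
One thing to note: class_number_train and class_number_test are not defined in the snippet above; they are simply the number of classes in each split, for example (a minimal sketch):

class_number_train = len(labels_trainData)	# 964 for the full background split
class_number_test = len(labels_testData)	# 659 for the full evaluation split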

Training converges very quickly. Some results are listed below (saved in model_step_eval.txt):

50,       test_accury     0.72
100,       test_accury     0.86
150,       test_accury     0.88
200,       test_accury     0.76
250,       test_accury     0.77
300,       test_accury     0.88
350,       test_accury     0.86
400,       test_accury     0.89
450,       test_accury     0.95
500,       test_accury     0.96
550,       test_accury     0.87
600,       test_accury     0.93
650,       test_accury     0.9
700,       test_accury     0.98
750,       test_accury     0.98
...........
...........
3400,       test_accury     0.9
3450,       test_accury     0.99
3500,       test_accury     0.92
3550,       test_accury     1.0

Topics: Python network Pytorch Deep Learning Deeplearning