Visualizing PointNet++ prediction results

Posted by pianoparadise on Sun, 12 Dec 2021 05:00:05 +0100

There is currently little material online about visualizing PointNet++ prediction results; most examples visualize the datasets directly. Here is the code I use to visualize predictions with Matplotlib. I hope it is helpful to you.

Principle:

Briefly, the code works as follows: first, the network produces prediction results for the input point clouds, which are stored as txt files; then Matplotlib reads the txt files and draws 3D images.
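
As a minimal sketch of the second step (my own example: 'model_0.txt' is a hypothetical file name, and each row is assumed to hold x y z nx ny nz label, which is the layout the script below writes):

import numpy as np
import matplotlib
matplotlib.use("Agg")  # render to file without a display
import matplotlib.pyplot as plt

data = np.loadtxt('model_0.txt')  # one point per row: x y z nx ny nz label
xyz, labels = data[:, :3], data[:, -1]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=labels, s=1, marker='.')
plt.savefig('prediction.png', dpi=300, bbox_inches='tight')
plt.close(fig)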

Prediction results

Since 3D point cloud files cannot be displayed directly without dedicated software, we render them as 2D images instead. The results are shown below.

Preparation of materials

1) Network model. This code is built on the PyTorch version of PointNet++, so the model's input and output formats must match it: two inputs and two outputs (this is noted again in the code comments, and a wrapper sketch follows this list).
2) Dataset. Since we are already generating prediction results, you should have a solid understanding of the overall pipeline, so dataset preparation is not covered in detail here. Note that different datasets return different items: ShapeNet returns the point cloud, label (object category), and target (per-point segmentation label), while S3DIS returns only the point cloud and segmentation labels. Adjust the code accordingly.
3) Trained weights. Save the trained weights in the PointNet++ checkpoint format (a dict containing 'model_state_dict').
When these materials are ready, you can start generating prediction results.
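
If one of your models does not follow this two-input, two-output convention, a thin wrapper such as the hypothetical sketch below can adapt it (TwoOutputWrapper is my own name, not part of the repo):

import torch.nn as nn

class TwoOutputWrapper(nn.Module):
    """Adapt a model that returns a single tensor to the (seg_pred, trans_feat) convention."""
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, points, cls_one_hot):
        out = self.base_model(points, cls_one_hot)
        if isinstance(out, tuple):
            return out            # already (seg_pred, trans_feat)
        return out, None          # pad a single-output model with None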

Main code

class Generate_txt_and_3d_img:
    def __init__(self,img_root,target_root,num_classes,testDataLoader,model_dict,color_map=None):
        self.img_root = img_root # Point cloud data path
        self.target_root = target_root  # Output path for the generated label and prediction files
        self.testDataLoader = testDataLoader
        self.num_classes = num_classes
        self.color_map = color_map
        self.heat_map = False # Controls whether heatmap is output
        self.label_path_txt = os.path.join(self.target_root, 'label_txt') # txt file for label
        self.make_dir(self.label_path_txt)

        # Get the model and load the weight
        self.model_name = []
        self.model = []
        self.model_weight_path = []

        for k,v in model_dict.items():
            self.model_name.append(k)
            self.model.append(v[0])
            self.model_weight_path.append(v[1])

        # Load weight
        self.load_checkpoint_for_models(self.model_name,self.model,self.model_weight_path)

        # create folder
        self.all_pred_image_path = [] # Paths for all prediction images
        self.all_pred_txt_path = [] # Paths for all prediction txt files
        for n in self.model_name:
            self.make_dir(os.path.join(self.target_root,n+'_predict_txt'))
            self.make_dir(os.path.join(self.target_root, n + '_predict_image'))
            self.all_pred_txt_path.append(os.path.join(self.target_root,n+'_predict_txt'))
            self.all_pred_image_path.append(os.path.join(self.target_root, n + '_predict_image'))
        "Forecast corresponding to the model txt Results and img When the results are generated, several elements are added to the list corresponding to several models"

        self.generate_predict_to_txt() # Generate forecast txt
        self.draw_3d_img() # Drawing

    def generate_predict_to_txt(self):

        for batch_id, (points, label, target) in tqdm.tqdm(enumerate(self.testDataLoader),
                                                                      total=len(self.testDataLoader),smoothing=0.9):

            # points: un-normalized point cloud data with features; label: category of the whole cloud; target: per-point labels. points is e.g. torch.Size([1, 7, 2048]) after the transpose
            points = points.transpose(2, 1)
            #print('1',target.shape) # torch.Size([1, 2048])
            xyz_feature_point = points[:, :6, :] # keep the first 6 feature channels, shape [B, 6, N]
            # Save the labels as a txt file
            point_set_without_normal = np.asarray(torch.cat([points.permute(0, 2, 1),target[:,:,None]],dim=-1)).squeeze(0)  # un-normalized point cloud with the per-point label appended as the last column
            np.savetxt(os.path.join(self.label_path_txt,f'{batch_id}_label.txt'), point_set_without_normal, fmt='%.04f') # save as a txt file
            # points: torch.Size([16, 2048, 6])  label: torch.Size([16, 1])  target: torch.Size([16, 2048])

            assert len(self.model) == len(self.all_pred_txt_path), 'The number of paths does not match the number of models, please check'

            for n,model,pred_path in zip(self.model_name,self.model,self.all_pred_txt_path):
                points = points.float()  # the network expects float input (.long() would truncate the coordinates)
                seg_pred, trans_feat = model(points, self.to_categorical(label, 16))  # one-hot encode the 16 ShapeNet object categories
                seg_pred = seg_pred.cpu().data.numpy()
                #=================================================

                seg_pred = np.argmax(seg_pred, axis=-1)  # argmax over the class axis: [B, N, num_classes] -> [B, N]
                #=================================================
                seg_pred = np.concatenate([np.asarray(xyz_feature_point), seg_pred[:, None, :]],
                        axis=1).transpose((0, 2, 1)).squeeze(0)  # append the predictions to the point cloud, ready to write a txt file
                save_path = os.path.join(pred_path, f'{n}_{batch_id}.txt')
                np.savetxt(save_path, seg_pred, fmt='%.04f')

    def draw_3d_img(self):
        # Use matplotlib to draw the 3d images

        each_label = os.listdir(self.label_path_txt)  # file names of all label txt files
        self.label_path_3d_img = os.path.join(self.target_root, 'label_3d_img')
        self.make_dir(self.label_path_3d_img)

        assert  len(self.all_pred_txt_path) == len(self.all_pred_image_path)


        for i,(pre_txt_path,save_img_path,name) in enumerate(zip(self.all_pred_txt_path,self.all_pred_image_path,self.model_name)):
            each_txt_path = os.listdir(pre_txt_path) # Get the full name of the txt file

            for idx,(txt,lab) in tqdm.tqdm(enumerate(zip(each_txt_path,each_label)),total=len(each_txt_path)):
                if i == 0:
                    self.draw_each_img(os.path.join(self.label_path_txt, lab), idx,heat_maps=False)
                self.draw_each_img(os.path.join(pre_txt_path,txt),idx,name=name,save_path=save_img_path,heat_maps=self.heat_map)

        print(f'All prediction images have been generated. Please check: {self.all_pred_image_path}')

    def draw_each_img(self,root,idx,name=None,skip=1,save_path=None,heat_maps=False):
        """root: path to a single txt file; skip: take every skip-th point"""
        points = np.loadtxt(root)[:, :3]  # xyz coordinates of the point cloud
        points_all = np.loadtxt(root)  # all columns of the point cloud
        points = self.pc_normalize(points)

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        point_range = range(0, points.shape[0], skip)  # skip points to prevent crash
        x = points[point_range, 0]
        z = points[point_range, 1]
        y = points[point_range, 2]

        "Generate customized dye board according to the number of categories passed in. Label 0 corresponds to random color 1. Label 1 corresponds to random color 2"
        if self.color_map is not None:
            color_map = self.color_map
        else:
            color_map  = {idx: i for idx, i in enumerate(np.linspace(0, 0.9, num_classes))}
        Label = points_all[point_range, -1] # Get the label
        # Pass the label into the previous dictionary, find the corresponding color and put it in the list

        Color = list(map(lambda x: color_map[x], Label))
        
        ax.scatter(x,  # x
                   y,  # y
                   z,  # z
                   c=Color,  # per-point color looked up from its label
                   s=25,
                   marker=".")
        ax.axis('auto')  # {equal, scaled}
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.axis('off')  # Sets the axis invisible
        ax.grid(False)  # Sets the background grid invisible
        ax.view_init(elev=0, azim=0)


        if save_path is None:
            plt.savefig(os.path.join(self.label_path_3d_img,f'{idx}_label_img.png'), dpi=300,bbox_inches='tight',transparent=True)
        else:
            plt.savefig(os.path.join(save_path, f'{idx}_{name}_img.png'), dpi=300, bbox_inches='tight',
                        transparent=True)
        plt.close(fig)  # close the figure so figures do not pile up in memory over thousands of images

    def pc_normalize(self,pc):
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        pc = pc / m
        return pc

    def make_dir(self, root):
        if os.path.exists(root):
            print(f'{root} Path already exists, no need to create')
        else:
            os.mkdir(root)
    def to_categorical(self,y, num_classes):
        """ 1-hot encodes a tensor """
        new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
        if (y.is_cuda):
            return new_y.cuda()
        return new_y

    def load_checkpoint_for_models(self,names,models,checkpoints):

        assert checkpoints is not None, 'Please provide the weight files'
        assert models is not None, 'Please instantiate the models'

        for n,m,c in zip(names,models,checkpoints):
            print(f'Loading weights for {n} ...')
            weight_dict = torch.load(os.path.join(c,'best_model.pth'))
            m.load_state_dict(weight_dict['model_state_dict'])
            print(f'{n} weights loaded')


if __name__ =='__main__':
    import copy


    img_root = r'Your dataset path' # Dataset path
    target_root = r'Path to save forecast results' # Output result path

    num_classes = 13 # Fill in the number of categories of the dataset
    choice_dataset = 'S3dis'
    # Import model section
    "All models in PointNet++Two parameters for standard input and two parameters for output. If the model outputs only one, it can be modified to output one more None!!!!"
    #==============================================
    from models.pointnet2_sem_seg import get_model as pointnet2
    from models.finally_csa_part_seg import get_model as csa
    from models.pointcouldtransformer_part_seg import get_model as pct
    model1 = pointnet2(num_classes=num_classes).eval()
    model2 = csa(num_classes, normal_channel=True).eval()
    model3 = pct(num_class=num_classes,normal_channel=False).eval()
    #============================================
    # Instantiate the dataset
    # The dataset should likewise return three variables in ShapeNet format: point_set, cls, seg.
    # point_set is the point cloud, cls is the object category (16 classes), seg is the per-point part label.
    # If your dataset does not follow this format, adapt it manually (a hypothetical adapter is sketched below).

    if choice_dataset == 'ShapeNet':
        print('Instantiating ShapeNet')
        TEST_DATASET = PartNormalDataset(root=img_root, npoints=2048, split='test', normal_channel=True)
        testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=0,
                                                     drop_last=True)
        color_map = {idx: i for idx, i in enumerate(np.linspace(0, 0.9, num_classes))}
    else:
        TEST_DATASET = S3DISDataset(split='test', data_root=img_root, num_point=4096, test_area=5,
                                    block_size=1.0, sample_rate=1.0, transform=None)
        testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=0,
                                                     pin_memory=True, drop_last=True)
        color_maps = [(152, 223, 138), (174, 199, 232), (255, 127, 14), (91, 163, 138), (255, 187, 120),
                      (188, 189, 34), (140, 86, 75), (255, 152, 150), (214, 39, 40), (197, 176, 213),
                      (196, 156, 148), (23, 190, 207), (112, 128, 144)]
        # Convert the 0-255 RGB tuples to the 0-1 range expected by matplotlib
        color_map = [tuple(channel / 255 for channel in rgb) for rgb in color_maps]
        print('Instantiating S3DIS')


    # Fill the models and their weight paths into the dictionary in the following format.
    # If loading the weights raises an error, check the weight-loading section of the class and adjust it.
    model_dict = {
        'PointNet': [model1,r'Weight path 1'],
        'CSA': [model2, r'Weight path 2'],
        'PCT': [model3,r'Weight path 3']
    }


    c = Generate_txt_and_3d_img(img_root,target_root,num_classes,testDataLoader,model_dict,color_map)

With this script you can draw the prediction results using Matplotlib. Since my code is built on PointNet++, the following points (also annotated in the code) are emphasized again:
1. The models take two inputs and return two outputs.
2. Remember to import your dataset. By default the dataset is expected to return three values, with ShapeNet as the template; see the ShapeNet dataset for details (the dataset import is not included, add it yourself if needed). The preprocessing of the training and prediction datasets must be identical, otherwise the prediction quality will be very poor.
3. All models and weight paths go into model_dict.
4. If the saved images have an odd viewing angle, set it via ax.view_init(elev=0, azim=0) (see the sketch after this list).
I will add more points as I think of them.
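
For point 4, here is a hedged sketch (the elev/azim values are arbitrary examples): inside draw_each_img, the single view_init/savefig pair can be replaced with a loop over several viewpoints.

for elev, azim in [(0, 0), (30, 45), (90, -60)]:
    ax.view_init(elev=elev, azim=azim)  # camera elevation and azimuth in degrees
    plt.savefig(os.path.join(save_path, f'{idx}_{name}_e{elev}_a{azim}.png'),
                dpi=300, bbox_inches='tight', transparent=True)
plt.close(fig)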

Imports

The import section at the beginning of the script:

"Pass in the model weight files, run prediction on the point clouds, and generate the prediction txt files"

import tqdm
import matplotlib.pyplot as plt
import matplotlib
import torch
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
matplotlib.use("Agg")
def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc
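
As a quick sanity check of pc_normalize (my own example; any Nx3 array works): after normalization the cloud is centered at the origin and fits inside the unit sphere.

pts = np.random.rand(2048, 3) * 10.0  # arbitrary raw point cloud
pts_n = pc_normalize(pts)
assert np.allclose(pts_n.mean(axis=0), 0.0, atol=1e-6)      # centered at the origin
assert np.max(np.linalg.norm(pts_n, axis=1)) <= 1.0 + 1e-6  # inside the unit sphere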

The following two lines must be included:

warnings.filterwarnings('ignore')
matplotlib.use("Agg")

Otherwise the drawing will be interrupted: Matplotlib apparently cannot render too many figures in one run, and with too many (roughly more than 10,000) it may raise an error. Closing each figure with plt.close(fig) after saving also prevents figures from piling up in memory.
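
If you still hit memory problems with very many frames, the save-and-close pattern below (a generic Matplotlib idiom, not specific to this code) keeps figure memory bounded:

import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to files
import matplotlib.pyplot as plt
import numpy as np

for i in range(3):  # imagine thousands of frames here
    pts = np.random.rand(2048, 3)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=1, marker='.')
    plt.savefig(f'frame_{i}.png', dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure so memory does not accumulate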

The code above is my own and works without problems for me. If you run into issues, read the code and adjust it accordingly, or leave a message in the comment section. I hope it helps.
Reference:
https://blog.csdn.net/ssq183/article/details/104603454/

Topics: Python Pytorch Computer Vision 3d