[PyTorch framework] 3.1 Logistic regression in practice

Posted by coditoergosum on Tue, 08 Mar 2022 21:05:33 +0100

import torch
import torch.nn as nn
import numpy as np

3.1 Logistic regression in practice

In this chapter, we work with structured data and use logistic regression to classify it.

3.1.1 Introduction to logistic regression

Logistic regression is a generalized linear model and has much in common with multiple linear regression. Both models have essentially the same form, wx+b, where w and b are the parameters to be learned. The difference lies in the dependent variable. Multiple linear regression uses wx+b directly as the dependent variable, that is, y = wx+b, whereas logistic regression maps wx+b through a function L to a hidden state p, p = L(wx+b), and then decides the value of the dependent variable by comparing p with 1-p. If L is the logistic function, the model is logistic regression; if L is a polynomial function, it is polynomial regression.
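
Written as formulas, and taking the logistic (sigmoid) function as L, the two models and the decision rule from the paragraph above read as follows (the condition p > 1 - p is the same as p > 0.5):

```latex
\begin{aligned}
\text{linear regression:}\quad & y = wx + b \\
\text{logistic regression:}\quad & p = L(wx + b) = \frac{1}{1 + e^{-(wx + b)}} \\
\text{prediction:}\quad & \hat{y} =
  \begin{cases}
    1, & p > 1 - p \quad (\text{i.e. } p > 0.5) \\
    0, & \text{otherwise}
  \end{cases}
\end{aligned}
```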

Put more simply, logistic regression is linear regression with a call to the logistic function added on top of its output.
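
In PyTorch terms this is a linear layer followed by torch.sigmoid. The minimal sketch below uses a made-up layer with 3 input features and 5 random samples (both sizes are arbitrary, chosen only for illustration) to show the two stages: the linear output z can be any real number, and the sigmoid squashes it into (0, 1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the random weights and inputs are reproducible

linear = nn.Linear(3, 1)  # the linear part wx + b; 3 input features chosen only for illustration
x = torch.randn(5, 3)     # 5 made-up samples

z = linear(x)             # linear regression output: any real number
p = torch.sigmoid(z)      # logistic (sigmoid) function maps it into (0, 1)

print(z.squeeze(1))
print(p.squeeze(1))
```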

Logistic regression is mainly used for binary classification. We covered the Sigmoid function when we discussed activation functions. It is the most common logistic function, because its output is a value between 0 and 1 that can be read as a probability. When the probability is greater than 0.5, the prediction is 1; when it is less than 0.5, the prediction is 0.
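
The decision rule is a simple threshold. In this sketch the probabilities are made-up values rather than real model outputs. It also shows that p > 0.5 holds exactly when the linear output z = wx + b is positive, which torch.logit (the inverse of the sigmoid) makes visible:

```python
import torch

# Made-up probabilities, as a sigmoid output might produce them
probs = torch.tensor([0.10, 0.45, 0.55, 0.93])

# Predict 1 when p > 0.5 (equivalently p > 1 - p), otherwise 0
preds = (probs > 0.5).long()
print(preds)  # tensor([0, 0, 1, 1])

# The same rule on the linear output z = wx + b: p > 0.5 exactly when z > 0
z = torch.logit(probs)
print((z > 0).long())  # tensor([0, 0, 1, 1])
```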

Let's use a public dataset to illustrate this.

3.1.2 UCI German Credit dataset

UCI German Credit is the German credit dataset from the UCI Machine Learning Repository. It is provided both as the original data and as a numeric version.

The German Credit data is used to predict a customer's tendency to default on a loan from their personal bank loan information and the overdue status of customers applying for loans. The dataset contains 1000 samples, each with 24 features.

Here we use the preprocessed numeric data directly.


3.1.3 Code practice

The file we use, german.data-numeric, is the numeric version of the data, already in a form NumPy can read, so we load it directly with NumPy's loadtxt function:

data = np.loadtxt("german.data-numeric")
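
np.loadtxt splits each line on whitespace and returns a 2-D float array with one row per line. So that the sketch runs without the downloaded file, it feeds loadtxt two made-up rows through io.StringIO. The values are random placeholders in the same layout (24 features, then a label), not real records. On the real file, the same call should give a 1000 × 25 array.

```python
import io
import numpy as np

rng = np.random.default_rng(0)
# Two fake rows in the same layout as german.data-numeric:
# 24 integer features followed by a class label (1 or 2)
rows = [list(rng.integers(0, 10, size=24)) + [label] for label in (1, 2)]
text = "\n".join(" ".join(str(v) for v in row) for row in rows)

data = np.loadtxt(io.StringIO(text))  # same call as np.loadtxt("german.data-numeric")
print(data.shape)  # (2, 25)
```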


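After reading the data, we standardize each feature column: subtract its mean and divide by its standard deviation. The last column is the label and is left unchanged.

for j in range(data.shape[1] - 1):  # loop over every column except the last (label) column
    data[:, j] = (data[:, j] - data[:, j].mean()) / data[:, j].std()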
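
The per-column loop can also be written as a single vectorized expression using NumPy broadcasting. The sketch below uses a made-up 1000 × 25 array in place of the loaded data, then checks that every feature column ends up with mean 0 and standard deviation 1 while the label column is untouched:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up stand-in for the loaded array: 24 feature columns plus a label column of 1s and 2s
data = np.hstack([rng.normal(5.0, 3.0, size=(1000, 24)),
                  rng.integers(1, 3, size=(1000, 1))])

features = data[:, :-1]
# Broadcasting standardizes every feature column at once; same result as the per-column loop
data[:, :-1] = (features - features.mean(axis=0)) / features.std(axis=0)

print(np.allclose(data[:, :-1].mean(axis=0), 0))  # True
print(np.allclose(data[:, :-1].std(axis=0), 1))   # True
```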

Next, shuffle the rows:

np.random.shuffle(data)  # shuffles the rows in place
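
np.random.shuffle draws from NumPy's global random state, so the train/test split changes from run to run. If a reproducible split is wanted, one option (not part of the original code) is a seeded np.random.Generator, whose shuffle method also moves whole rows in place. The small array here is made up for illustration:

```python
import numpy as np

data = np.arange(12, dtype=float).reshape(6, 2)  # small made-up array: 6 rows, 2 columns

rng = np.random.default_rng(42)  # fixed seed -> same order on every run
rng.shuffle(data)                # shuffles along axis 0, i.e. whole rows move together

print(data)
```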


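Next, split the data into a training set and a test set. Since there is no validation set here, we simply use accuracy on the test set to judge how good the model is.

Splitting rule: the first 900 rows are used for training and the remaining 100 for testing.

In german.data-numeric the first 24 columns are the 24 features and the last column is the class label. The raw file encodes the two classes as 1 and 2, so we subtract 1 to get labels 0 and 1, and we separate features from labels as we split:

train_data = data[:900, :-1]  # first 900 rows, all feature columns
train_lab = data[:900, -1] - 1  # labels 1/2 -> 0/1
test_data = data[900:, :-1]  # remaining 100 rows
test_lab = data[900:, -1] - 1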
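
A quick sanity check of the slicing is to run it on a made-up array with the same layout (1000 rows, 24 features, label 1 or 2) and look at the resulting shapes and label values. The random data is only a placeholder:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up array with the same layout as the real data: 1000 rows, 24 features + label (1 or 2)
data = np.hstack([rng.normal(size=(1000, 24)), rng.integers(1, 3, size=(1000, 1))])

train_data, train_lab = data[:900, :-1], data[:900, -1] - 1
test_data, test_lab = data[900:, :-1], data[900:, -1] - 1

print(train_data.shape, test_data.shape)  # (900, 24) (100, 24)
print(np.unique(train_lab))               # [0. 1.]
```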


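Now let's define the model. It is very simple: one linear layer followed by the sigmoid function.

class LR(nn.Module):
    def __init__(self):
        super(LR, self).__init__()  # required before registering submodules
        self.fc = nn.Linear(24, 2)  # the data has exactly 24 features, so 24 is hard-coded here
    def forward(self, x):
        out = self.fc(x)  # linear part: wx + b, one score per class
        out = torch.sigmoid(out)  # logistic function applied to each score
        return out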
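
The linear layer produces two scores per sample, one per class, instead of the single score wx+b from 3.1.1. The two views are closely related: a softmax over a pair of scores (z0, z1) gives the same class-1 probability as the sigmoid of their difference z1 - z0, because e^{z1} / (e^{z0} + e^{z1}) = 1 / (1 + e^{-(z1 - z0)}). A numerical check on made-up random scores:

```python
import torch

torch.manual_seed(0)
z = torch.randn(8, 2)  # made-up scores for 8 samples, one column per class

p_softmax = torch.softmax(z, dim=1)[:, 1]     # probability of class 1 from the 2-way softmax
p_sigmoid = torch.sigmoid(z[:, 1] - z[:, 0])  # sigmoid of the score difference

print(torch.allclose(p_softmax, p_sigmoid, atol=1e-6))  # True
```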

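Next, a function that computes accuracy on the test set. The predicted class is the index of the larger of the two outputs:

def test(pred, lab):
    t = pred.max(-1)[1] == lab  # True where the predicted class matches the label
    return torch.mean(t.float())  # fraction of correct predictions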

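Here are the remaining settings: the model instance, the loss function, the optimizer and the number of epochs.

net = LR()  # create the model
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
optm = torch.optim.Adam(net.parameters())  # Adam optimizer with its default learning rate
epochs = 1000  # train for 1000 epochs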
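
nn.CrossEntropyLoss takes one score per class and integer class indices (which is why the labels are converted with .long() before training). Internally it applies log-softmax to the scores and averages the negative log-probability of the correct class. The sketch below checks this on made-up scores and labels by rebuilding the loss from F.log_softmax and F.nll_loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 2)           # made-up outputs for 4 samples, 2 classes
labels = torch.tensor([0, 1, 1, 0])  # class indices must be integers (long)

ce = nn.CrossEntropyLoss()(scores, labels)
# log-softmax over the classes, then the mean negative log-probability of the true class
manual = F.nll_loss(F.log_softmax(scores, dim=1), labels)

print(torch.allclose(ce, manual))  # True
```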

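Now let's start training:

# All inputs need to be converted into torch Tensors
x = torch.from_numpy(train_data).float()
y = torch.from_numpy(train_lab).long()
test_in = torch.from_numpy(test_data).float()
test_l = torch.from_numpy(test_lab).long()

for i in range(epochs):
    # Put the model in training mode so gradients are computed
    net.train()
    y_hat = net(x)
    loss = criterion(y_hat, y)  # compute the loss
    optm.zero_grad()  # clear the gradients from the previous step
    loss.backward()  # backpropagation
    optm.step()  # update the parameters
    if (i + 1) % 100 == 0:  # print progress every 100 epochs
        # Put the model in evaluation mode; no gradients are needed here
        net.eval()
        with torch.no_grad():
            test_out = net(test_in)
        # Use our test function to compute the accuracy
        accu = test(test_out, test_l)
        print("Epoch:{},Loss:{:.4f},Accuracy: {:.2f}".format(i + 1, loss.item(), accu.item()))
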
Epoch:100,Loss:0.6313,Accuracy: 0.76
Epoch:200,Loss:0.6065,Accuracy: 0.79
Epoch:300,Loss:0.5909,Accuracy: 0.80
Epoch:400,Loss:0.5801,Accuracy: 0.81
Epoch:500,Loss:0.5720,Accuracy: 0.82
Epoch:600,Loss:0.5657,Accuracy: 0.81
Epoch:700,Loss:0.5606,Accuracy: 0.81
Epoch:800,Loss:0.5563,Accuracy: 0.81
Epoch:900,Loss:0.5527,Accuracy: 0.81
Epoch:1000,Loss:0.5496,Accuracy: 0.80

Training is complete, and the model reaches about 80% accuracy on the test set.
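
The code above needs the downloaded german.data-numeric file. To try the whole pipeline without it, here is a self-contained sketch of the same steps (standardize, shuffle, split 900/100, train the LR model with CrossEntropyLoss and Adam for 1000 epochs) on synthetic data. The data generator (24 Gaussian features, labels from a noisy linear rule) is an assumption made only for this sketch, so the accuracy it prints says nothing about the German Credit results above.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Synthetic stand-in for german.data-numeric (an assumption, not the real data):
# 1000 rows, 24 features, and a label of 1 or 2 from a noisy linear rule
n, d = 1000, 24
X = rng.normal(size=(n, d))
true_w = rng.choice([-1.0, 1.0], size=d)
label = np.where(X @ true_w + rng.normal(scale=0.5, size=n) > 0, 2, 1)
data = np.hstack([X, label[:, None]])

# Preprocessing as in the tutorial: standardize, shuffle, split 900/100, shift labels to 0/1
data[:, :-1] = (data[:, :-1] - data[:, :-1].mean(axis=0)) / data[:, :-1].std(axis=0)
rng.shuffle(data)
train_data, train_lab = data[:900, :-1], data[:900, -1] - 1
test_data, test_lab = data[900:, :-1], data[900:, -1] - 1


class LR(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(d, 2)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))  # linear layer followed by the sigmoid


net = LR()
criterion = nn.CrossEntropyLoss()
optm = torch.optim.Adam(net.parameters())

x = torch.from_numpy(train_data).float()
y = torch.from_numpy(train_lab).long()
test_in = torch.from_numpy(test_data).float()
test_l = torch.from_numpy(test_lab).long()

for i in range(1000):
    net.train()
    loss = criterion(net(x), y)
    optm.zero_grad()
    loss.backward()
    optm.step()

net.eval()
with torch.no_grad():
    accu = (net(test_in).argmax(dim=1) == test_l).float().mean().item()
print("Loss:{:.4f}, Accuracy: {:.2f}".format(loss.item(), accu))
```

Because the sigmoid is monotonic, taking the argmax over the two sigmoid outputs picks the same class as taking it over the raw linear scores, so the accuracy computation is unaffected by where the sigmoid sits.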
