System Framework
1. Dataset Loading
Subclass torch.utils.data.Dataset and override __getitem__ and __len__; the data file is parsed once in __init__, and __getitem__ converts each sample to tensors.
# load.py
import torch


class IrisDataset(torch.utils.data.Dataset):
    def __init__(self, data_file, iris_class):
        super(IrisDataset, self).__init__()
        self.iris_class = iris_class  # class-name -> integer-label mapping
        self.all_data = []
        # parse the CSV-style file once up front
        with open(data_file, 'r') as f:
            lines = [line.rstrip() for line in f.readlines()]
        for l in lines:
            l = l.split(',')
            vec = [float(i) for i in l[:-1]]   # the four feature values
            label = self.iris_class[l[-1]]     # map class name to index
            self.all_data.append([vec, label])

    def __getitem__(self, item):
        fea, label = self.all_data[item]
        fea = torch.tensor(fea, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        # No data augmentation
        return fea, label

    def __len__(self):
        return len(self.all_data)


if __name__ == "__main__":
    import config
    dataset = IrisDataset("iris/train", config.iris_class)
    print(dataset[0])
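A quick way to confirm that batching works is to wrap the dataset in a DataLoader and fetch one batch:

# quick batching check
from torch.utils.data import DataLoader
import load, config

dataset = load.IrisDataset("iris/train", config.iris_class)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
fea, label = next(iter(loader))
print(fea.shape, label.shape)  # torch.Size([4, 4]) torch.Size([4])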
2. Network Model: MLP
# net.py
import torch
import torch.nn as nn


class Net(torch.nn.Module):
    def __init__(self, input_dim=4, num_class=3):
        super(Net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_class),
            # no Softmax here: CrossEntropyLoss in train.py applies
            # log-softmax internally, so the network outputs raw logits
        )

    def forward(self, x):
        return self.fc(x)


if __name__ == "__main__":
    net = Net()
    print(net)
    x = torch.randn(2, 4)
    print(net(x).shape)  # torch.Size([2, 3])
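Because the network returns raw logits, class probabilities have to be recovered explicitly at inference time; a minimal sketch:

# recovering probabilities from logits (sketch)
import torch
import torch.nn.functional as F
import net

model = net.Net()
model.eval()  # BatchNorm1d needs eval mode for a single sample
with torch.no_grad():
    logits = model(torch.randn(1, 4))
    probs = F.softmax(logits, dim=1)  # explicit softmax only at inference
print(probs, probs.sum())             # each row sums to 1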
3. Configuration File: Network and Training Parameters
# config.py
import warnings
warnings.filterwarnings('ignore')

"""dataset"""
iris_class = {
    "Iris-setosa": 0,
    "Iris-versicolor": 1,
    "Iris-virginica": 2
}

"""net args"""
input_dim = 4   # sepal length/width, petal length/width
num_class = 3

"""train & valid"""
train_data = 'iris/train'
valid_data = 'iris/valid'
batch_size = 10
nworks = 1        # DataLoader worker processes
max_epoch = 200
lr = 1e-3         # initial learning rate
factor = 0.9      # exponent of the polynomial LR decay

"""test"""
test_data = "iris/test"
pre_model = "pth/model_100.pth"
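train.py decays the learning rate polynomially, lr * (1 - (epoch-1)/max_epoch) ** factor; a quick check of the schedule, runnable next to config.py:

# inspecting the polynomial LR schedule used in train.py
import config

for epoch in (1, 50, 100, 200):
    cur_lr = config.lr * ((1 - (epoch - 1) / config.max_epoch) ** config.factor)
    print(epoch, cur_lr)
# starts at 1e-3 and decays smoothly toward 0 at max_epoch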
4. Training
# train.py
import torch, os, tqdm
from torch.utils.data import DataLoader
import load, net, config
import matplotlib.pyplot as plt


def train():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda:0")
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current device:", DEVICE)

    # load dataset for train and eval
    train_dataset = load.IrisDataset(config.train_data, config.iris_class)
    train_batchs = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                              num_workers=config.nworks, pin_memory=True)
    valid_dataset = load.IrisDataset(config.valid_data, config.iris_class)
    valid_batchs = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False,
                              num_workers=config.nworks, pin_memory=True)

    model = net.Net(config.input_dim, config.num_class).to(DEVICE)
    loss_criterion = torch.nn.CrossEntropyLoss()  # takes raw logits and class indices
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    os.makedirs("pth", exist_ok=True)
    plt.ion()
    train_loss, valid_loss, valid_acc = [], [], []
    for epoch in tqdm.tqdm(range(1, config.max_epoch + 1)):
        # polynomial learning-rate decay
        optimizer.param_groups[0]['lr'] = config.lr * ((1 - (epoch - 1) / config.max_epoch) ** config.factor)

        """ train """
        model.train()
        total_loss = 0
        for batch, (fea, target) in enumerate(train_batchs):
            fea, target = fea.to(DEVICE), target.to(DEVICE)
            pred = model(fea)
            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_epoch_loss = total_loss / len(train_batchs)  # mean loss per batch
        # print("epoch", epoch, "loss:", train_epoch_loss)
        torch.save(model.state_dict(), os.path.join("pth", 'model_' + str(epoch) + '.pth'))

        """ valid """
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            total_loss = 0
            for fea, labels in valid_batchs:
                fea, labels = fea.to(DEVICE), labels.to(DEVICE)
                pred = model(fea)
                loss = loss_criterion(pred, labels)
                total_loss += loss.item()
                _, predicted = torch.max(pred.data, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum()
            valid_epoch_loss = total_loss / len(valid_batchs)  # mean loss per batch
            # print('Accuracy on valid set: %d%%' % (100 * correct / total))

        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)
        valid_acc.append((correct / total).item())

        # live plot of the training curves, saved to train.jpg
        plt.clf()
        plt.plot(train_loss, color='black', label="train loss")
        plt.plot(valid_loss, color='red', label="valid loss")
        plt.plot(valid_acc, color='green', label="valid acc")
        plt.grid()
        plt.legend()
        plt.savefig("train.jpg")
    plt.ioff()
    plt.close()


if __name__ == '__main__':
    train()
Training-process visualization: the train/valid loss and validation-accuracy curves are redrawn and saved to train.jpg every epoch.
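Since a checkpoint is written every epoch, the per-epoch valid_acc list built in train() can be used to pick which model_<epoch>.pth to test; an illustrative sketch with made-up accuracy values:

# picking the best checkpoint by validation accuracy (illustrative values)
valid_acc = [0.33, 0.70, 0.93, 0.90]  # as collected per epoch inside train()
best_epoch = max(range(len(valid_acc)), key=lambda i: valid_acc[i]) + 1
print("best epoch:", best_epoch)      # -> test with pth/model_<best_epoch>.pth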
5. Testing
# test.py
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import net, load, config


def test():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda:0")
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current device:", DEVICE)

    # load dataset
    test_dataset = load.IrisDataset(config.test_data, config.iris_class)
    test_batchs = DataLoader(test_dataset, batch_size=10, shuffle=False,
                             num_workers=0, pin_memory=True)

    # model: restore the trained weights
    model = net.Net(config.input_dim, config.num_class)
    model.load_state_dict(torch.load(config.pre_model, map_location='cpu'))
    model = model.to(DEVICE)

    # test
    model.eval()
    with torch.no_grad():
        preds, labels = [], []
        for i, (fea, label) in enumerate(test_batchs):
            pred = model(fea.to(DEVICE))
            _, predicted = torch.max(pred.data, dim=1)
            preds.append(predicted)
            labels.append(label)

    # report: cat (rather than stack) handles a final batch smaller than batch_size
    preds = torch.cat(preds, dim=0).cpu().numpy()
    labels = torch.cat(labels, dim=0).numpy()
    report = classification_report(labels, preds, target_names=list(config.iris_class.keys()))
    print(report)


if __name__ == '__main__':
    test()
Test-set results: classification_report prints per-class precision, recall, and F1 to the console.
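The same model also answers single queries; a minimal sketch, assuming the checkpoint named in config.py exists (feature order: sepal length, sepal width, petal length, petal width):

# classifying one measurement vector (sketch; needs a trained checkpoint)
import torch
import net, config

model = net.Net(config.input_dim, config.num_class)
model.load_state_dict(torch.load(config.pre_model, map_location='cpu'))
model.eval()

idx_to_name = {v: k for k, v in config.iris_class.items()}
x = torch.tensor([[5.1, 3.5, 1.4, 0.2]])  # one sample, shape (1, 4)
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()
print(idx_to_name[pred])  # a trained model should print Iris-setosa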
6. File Structure
6.1 requirements.txt
matplotlib==3.7.2
scikit_learn==1.3.2
torch==2.0.0+cu118
tqdm==4.65.2
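With these versions pinned, the environment can be set up with pip install -r requirements.txt before running train.py or test.py.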
6.2 Appendix: Pre-split Dataset
Training set: iris/train
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
Validation set: iris/valid
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
Test set: iris/test
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica