使用sklearn函数对模型进行交叉验证
- 交叉验证用来做什么
- sklearn 中的函数
交叉验证用来做什么
交叉验证(Cross-Validatio),是用于在驯良过程中对训练模型的性能和参数进行评估选择的技术。
它的意义在于能够充分利用优先的数据集,减少数据分布不均匀以及随机性带来的模型评估误差。
交叉验证的作用就是将数据集分割成多个自己进行多次训练,每次训练的训练集与测试机不完全相同。
sklearn 中的函数
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
skf = KFold(n_splits=10, random_state=233, shuffle=True)
n_splits:int, default=5
表示,要分割为多少个K子集
shuffle:bool, default=False
是否打乱数据
random_state:int or RandomState instance, default=None
随机状态,需要配合shuffle参数使用
参考文章 https://blog.csdn.net/weixin_43803950/article/details/120894868
# 如果有额外的标签,train_path 标签数据,如果标签是跟随train_path,第二个可不填入
skf.split(train_path, train_path)
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_path, train_path)):
train_loader = torch.utils.data.DataLoader(
XunFeiDataset(np.array(train_path)[train_idx],
A.Compose([
A.RandomRotate90(),
A.RandomCrop(120, 120),
A.HorizontalFlip(p=0.5),
A.RandomContrast(p=0.5),
A.RandomBrightnessContrast(p=0.5),
])
), batch_size=8, shuffle=True, num_workers=0, pin_memory=False
)
val_loader = torch.utils.data.DataLoader(
XunFeiDataset(np.array(train_path)[val_idx],
A.Compose([
A.RandomCrop(120, 120),
])
), batch_size=8, shuffle=False, num_workers=0, pin_memory=False
)
for epoch_item in range(30):
# adjust_learning_rate(optimizer, epoch_item)
train_loss = train(train_loader, model, criterion, optimizer)
val_acc = validate(val_loader, model, criterion)
train_acc = validate(train_loader, model, criterion)
print(train_loss, train_acc, val_acc)