An excellent coder starts out with CV (copy-paste) engineering...
First I copied the notebook "LB 0.804 - EfficientNet B0 Pytorch Pipeline | Kaggle" and tried submitting it (the notebook bundles the author's already-trained model, so this submission was just to get familiar with the workflow). OK, 0.804. Then I downloaded the author's code, intending to modify the model architecture locally and train my own.
Leaderboard diary
20250416: Submitted once after training EfficientNet down to a loss of around 0.03. Damn, why only 0.510?
20250418:
Switched the backbone to 'convnext_tiny.in12k_ft_in1k' and added an attention block after it (a sketch follows this entry); submitted at a loss of about 0.023: 0.596! Nice, almost a passing grade. This backbone is the smallest model in the ConvNeXt family, so I may try a larger one later.
Also tried MaxViT, specifically 'maxvit_base_tf_384'. DeepSeek says the model is 119M; with batch size 16 it trains on a 4090 using 20.18 GB of VRAM. Anything bigger and I'd have to rent GPUs...
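The "attention block after the backbone" above is only described in words, so here is a minimal sketch of one way to wire it up. The SEBlock (squeeze-and-excitation channel attention) is my own illustrative assumption, not necessarily the exact block used:

import torch
import torch.nn as nn
import timm

class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel attention (illustrative choice)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):  # x: [B, C, H, W]
        w = self.fc(x.mean(dim=(2, 3)))  # squeeze to [B, C], then excite
        return x * w[:, :, None, None]   # reweight the channels

# Hypothetical wiring: backbone feature maps -> attention -> pooling -> head
backbone = timm.create_model('convnext_tiny.in12k_ft_in1k',
                             pretrained=False, in_chans=1)
attn = SEBlock(backbone.num_features)
head = nn.Linear(backbone.num_features, 206)

x = torch.randn(2, 1, 256, 256)                 # dummy mel-spectrogram batch
feats = attn(backbone.forward_features(x))      # 4D feature maps [B, C, H', W']
logits = head(feats.mean(dim=(2, 3)))           # global average pool, then classify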
20250419:
OK, no need for larger models after all: convnext_base turns out to cause a timeout. GPU is not allowed at inference time in this competition, and there is a CPU time limit (90 min). So today is about trying some lightweight models. I also realized I had not been logging val_loss during training; the code is now updated so training can be monitored in wandb.
Asked DS: the audio has background noise; without doing anything to the data, recommend some model architectures suited to audio classification:
- tf_efficientnetv2_b0.in1k
- resnext50_32x4d.a1h_in1k
- mobilevit_s.cvnets_in1k
I'll train each of them in turn (see the quick check below).
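Swapping backbones should only require changing model_name, since BirdCLEFModel's else branch resets any timm classifier. A quick sanity check I'd run first (my own illustration; pretrained=False avoids downloading weights):

import timm
import torch

candidates = [
    'tf_efficientnetv2_b0.in1k',
    'resnext50_32x4d.a1h_in1k',
    'mobilevit_s.cvnets_in1k',
]
for name in candidates:
    # Build each candidate with 1 input channel and 206 classes, then
    # check parameter count and output shape on a dummy spectrogram.
    model = timm.create_model(name, pretrained=False, in_chans=1, num_classes=206)
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    with torch.no_grad():
        out = model(torch.randn(1, 1, 256, 256))
    print(f'{name}: {n_params:.1f}M params, output shape {tuple(out.shape)}')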
Planning to change the learning-rate schedule to warmup + cosine annealing (sketched below).
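A minimal sketch of warmup + cosine annealing using PyTorch's LambdaLR; the warmup_epochs value is an assumption for illustration:

import math
import torch

model = torch.nn.Linear(10, 2)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs, warmup_epochs = 20, 2  # warmup_epochs is an assumed value

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        # Linear warmup up to the base lr
        return (epoch + 1) / warmup_epochs
    # Cosine decay from the base lr down to 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
for epoch in range(num_epochs):
    # ... run train_fn / validate_fn here ...
    scheduler.step()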
Below is the code from the author's notebook, which can be submitted as-is. Since it only contains the model-loading and inference code, I named it test.py; that way train.py can later import the model definition from test.py for training, so the model only needs to be defined, and modified, in one place.
Imports
import os
import gc
import warnings
import logging
import time
import math
import cv2
from pathlib import Path
import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm.auto import tqdm
# Suppress warnings and limit logging output
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
Configuration parameters
class CFG:
"""
Configuration class holding all paths and parameters required for the inference pipeline.
"""
test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'
submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'
taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
model_path = '/kaggle/input/birdclef-2025-efficientnet-b0' # upload the model here?
# Audio parameters
FS = 32000
WINDOW_SIZE = 5
# Mel spectrogram parameters
N_FFT = 1034
HOP_LENGTH = 64
N_MELS = 136
FMIN = 20
FMAX = 16000
TARGET_SHAPE = (256, 256)
model_name = 'efficientnet_b0'
in_channels = 1
device = 'cpu'
# Inference parameters
batch_size = 16
use_tta = False
tta_count = 3
threshold = 0.7
use_specific_folds = False # If False, use all found models
folds = [0, 1] # Used only if use_specific_folds is True
debug = False
debug_count = 3
Model definition
You can see that the author uses a timm model as the backbone (EfficientNet-B0 here);
its output passes through a pooling layer,
and finally a classification head adapted to the competition's classification task.
class BirdCLEFModel(nn.Module):
"""
Custom neural network model for BirdCLEF-2025 that uses a timm backbone.
"""
def __init__(self, cfg, num_classes):
"""
Initialize the BirdCLEFModel.
:param cfg: Configuration parameters.
:param num_classes: Number of output classes.
"""
super().__init__()
self.cfg = cfg
# Create backbone using timm with specified parameters.
self.backbone = timm.create_model(
cfg.model_name,
pretrained=False,
in_chans=cfg.in_channels,
drop_rate=0.0,
drop_path_rate=0.0
)
# Adjust final layers based on model type
if 'efficientnet' in cfg.model_name:
backbone_out = self.backbone.classifier.in_features
self.backbone.classifier = nn.Identity()
elif 'resnet' in cfg.model_name:
backbone_out = self.backbone.fc.in_features
self.backbone.fc = nn.Identity()
else:
backbone_out = self.backbone.get_classifier().in_features
self.backbone.reset_classifier(0, '')
self.pooling = nn.AdaptiveAvgPool2d(1)
self.feat_dim = backbone_out
self.classifier = nn.Linear(backbone_out, num_classes)
def forward(self, x):
"""
Forward pass through the network.
:param x: Input tensor.
:return: Logits for each class.
"""
features = self.backbone(x)
if isinstance(features, dict):
features = features['features']
# If features are 4D, apply global average pooling.
if len(features.shape) == 4:
features = self.pooling(features)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return logits
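A quick shape check of the model on a dummy batch (my own sanity check, assuming timm can build efficientnet_b0 locally):

model = BirdCLEFModel(CFG, num_classes=206)
x = torch.randn(2, 1, 256, 256)  # batch of two single-channel mel spectrograms
print(model(x).shape)  # torch.Size([2, 206]), one logit per species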
Pipeline definition
Now someone is bound to ask: what exactly is a pipeline??
class BirdCLEF2025Pipeline:
"""
Pipeline for the BirdCLEF-2025 inference task.
This class organizes the complete inference process:
- Loading taxonomy data.
- Loading the pre-trained model files.
- Converting audio files into mel spectrograms.
- Running predictions on each audio segment.
- Generating the results file required for submission.
- Post-processing the results file to smooth predictions, i.e. averaging each segment's scores with its neighbors for temporal consistency (see smooth_submission below).
"""
def __init__(self, cfg):
"""
Initialize the inference pipeline from the given configuration.
:param cfg: Configuration object with paths and parameters.
"""
self.cfg = cfg
self.taxonomy_df = None
self.species_ids = []
self.models = []
self._load_taxonomy()
def _load_taxonomy(self):
"""
Load taxonomy data from CSV and extract species identifiers.
"""
print("Loading taxonomy data...")
self.taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
self.species_ids = self.taxonomy_df['primary_label'].tolist()
print(f"Number of classes: {len(self.species_ids)}")
def audio2melspec(self, audio_data):
"""
Convert raw audio into a mel spectrogram.
:param audio_data: 1D numpy array of audio samples.
:return: Normalized mel spectrogram.
"""
if np.isnan(audio_data).any():
mean_signal = np.nanmean(audio_data)
audio_data = np.nan_to_num(audio_data, nan=mean_signal)
mel_spec = librosa.feature.melspectrogram(
y=audio_data,
sr=self.cfg.FS,
n_fft=self.cfg.N_FFT,
hop_length=self.cfg.HOP_LENGTH,
n_mels=self.cfg.N_MELS,
fmin=self.cfg.FMIN,
fmax=self.cfg.FMAX,
power=2.0
)
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)
return mel_spec_norm
def process_audio_segment(self, audio_data):
"""
Process an audio segment to obtain a mel spectrogram with the target shape.
:param audio_data: 1D numpy array of audio samples.
:return: Processed mel spectrogram as a float32 numpy array.
"""
# Pad audio if it is shorter than the required window size.
if len(audio_data) < self.cfg.FS * self.cfg.WINDOW_SIZE:
audio_data = np.pad(
audio_data,
(0, self.cfg.FS * self.cfg.WINDOW_SIZE - len(audio_data)),
mode='constant'
)
mel_spec = self.audio2melspec(audio_data)
# Resize spectrogram to the target shape if necessary.
if mel_spec.shape != self.cfg.TARGET_SHAPE:
mel_spec = cv2.resize(mel_spec, self.cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)
return mel_spec.astype(np.float32)
def find_model_files(self):
"""
Find all .pth model files in the specified model directory.
:return: List of model file paths.
"""
model_files = []
model_dir = Path(self.cfg.model_path)
for path in model_dir.glob('**/*.pth'):
model_files.append(str(path))
return model_files
def load_models(self):
"""
Load all found model files and prepare them for ensemble inference.
:return: List of loaded PyTorch models.
"""
self.models = []
model_files = self.find_model_files()
if not model_files:
print(f"Warning: No model files found under {self.cfg.model_path}!")
return self.models
print(f"Found a total of {len(model_files)} model files.")
# If specific folds are required, filter the model files.
if self.cfg.use_specific_folds:
filtered_files = []
for fold in self.cfg.folds:
fold_files = [f for f in model_files if f"fold{fold}" in f]
filtered_files.extend(fold_files)
model_files = filtered_files
print(f"Using {len(model_files)} model files for the specified folds ({self.cfg.folds}).")
# Load each model file.
for model_path in model_files:
try:
print(f"Loading model: {model_path}")
checkpoint = torch.load(model_path, map_location=torch.device(self.cfg.device))
model = BirdCLEFModel(self.cfg, len(self.species_ids))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(self.cfg.device)
model.eval()
self.models.append(model)
except Exception as e:
print(f"Error loading model {model_path}: {e}")
return self.models
def apply_tta(self, spec, tta_idx):
"""
Apply test-time augmentation (TTA) to the spectrogram.
:param spec: Input mel spectrogram.
:param tta_idx: Index indicating which TTA to apply.
:return: Augmented spectrogram.
"""
if tta_idx == 0:
# No augmentation.
return spec
elif tta_idx == 1:
# Horizontal flip (reverses the time axis).
return np.flip(spec, axis=1)
elif tta_idx == 2:
# Vertical flip (reverses the frequency axis).
return np.flip(spec, axis=0)
else:
return spec
def predict_on_spectrogram(self, audio_path):
"""
Process a single audio file and predict species presence for each 5-second segment.
:param audio_path: Path to the audio file.
:return: Tuple (row_ids, predictions) for each segment.
"""
predictions = []
row_ids = []
soundscape_id = Path(audio_path).stem
try:
print(f"Processing {soundscape_id}")
audio_data, _ = librosa.load(audio_path, sr=self.cfg.FS)
total_segments = int(len(audio_data) / (self.cfg.FS * self.cfg.WINDOW_SIZE))
for segment_idx in range(total_segments):
start_sample = segment_idx * self.cfg.FS * self.cfg.WINDOW_SIZE
end_sample = start_sample + self.cfg.FS * self.cfg.WINDOW_SIZE
segment_audio = audio_data[start_sample:end_sample]
end_time_sec = (segment_idx + 1) * self.cfg.WINDOW_SIZE
row_id = f"{soundscape_id}_{end_time_sec}"
row_ids.append(row_id)
if self.cfg.use_tta:
all_preds = []
for tta_idx in range(self.cfg.tta_count):
mel_spec = self.process_audio_segment(segment_audio)
mel_spec = self.apply_tta(mel_spec, tta_idx)
mel_spec_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
mel_spec_tensor = mel_spec_tensor.to(self.cfg.device)
if len(self.models) == 1:
with torch.no_grad():
outputs = self.models[0](mel_spec_tensor)
probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
all_preds.append(probs)
else:
segment_preds = []
for model in self.models:
with torch.no_grad():
outputs = model(mel_spec_tensor)
probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
segment_preds.append(probs)
avg_preds = np.mean(segment_preds, axis=0)
all_preds.append(avg_preds)
final_preds = np.mean(all_preds, axis=0)
else:
mel_spec = self.process_audio_segment(segment_audio)
mel_spec_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
mel_spec_tensor = mel_spec_tensor.to(self.cfg.device)
if len(self.models) == 1:
with torch.no_grad():
outputs = self.models[0](mel_spec_tensor)
final_preds = torch.sigmoid(outputs).cpu().numpy().squeeze()
else:
segment_preds = []
for model in self.models:
with torch.no_grad():
outputs = model(mel_spec_tensor)
probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
segment_preds.append(probs)
final_preds = np.mean(segment_preds, axis=0)
predictions.append(final_preds)
except Exception as e:
print(f"Error processing {audio_path}: {e}")
return row_ids, predictions
def run_inference(self):
"""
Run inference on all test soundscape audio files.
:return: Tuple (all_row_ids, all_predictions) aggregated from all files.
"""
test_files = list(Path(self.cfg.test_soundscapes).glob('*.ogg'))
if self.cfg.debug:
print(f"Debug mode enabled, using only {self.cfg.debug_count} files")
test_files = test_files[:self.cfg.debug_count]
print(f"Found {len(test_files)} test soundscapes")
all_row_ids = []
all_predictions = []
for audio_path in tqdm(test_files):
row_ids, predictions = self.predict_on_spectrogram(str(audio_path))
all_row_ids.extend(row_ids)
all_predictions.extend(predictions)
return all_row_ids, all_predictions
def create_submission(self, row_ids, predictions):
"""
Create the submission dataframe based on predictions.
:param row_ids: List of row identifiers for each segment.
:param predictions: List of prediction arrays.
:return: A pandas DataFrame formatted for submission.
"""
print("Creating submission dataframe...")
submission_dict = {'row_id': row_ids}
for i, species in enumerate(self.species_ids):
submission_dict[species] = [pred[i] for pred in predictions]
submission_df = pd.DataFrame(submission_dict)
submission_df.set_index('row_id', inplace=True)
sample_sub = pd.read_csv(self.cfg.submission_csv, index_col='row_id')
missing_cols = set(sample_sub.columns) - set(submission_df.columns)
if missing_cols:
print(f"Warning: Missing {len(missing_cols)} species columns in submission")
for col in missing_cols:
submission_df[col] = 0.0
submission_df = submission_df[sample_sub.columns]
submission_df = submission_df.reset_index()
return submission_df
def smooth_submission(self, submission_path):
"""
Post-process the submission CSV by smoothing predictions to enforce temporal consistency.
For each soundscape (grouped by the file name part of 'row_id'), each row's predictions
are averaged with those of its neighbors using defined weights.
:param submission_path: Path to the submission CSV file.
"""
print("Smoothing submission predictions...")
sub = pd.read_csv(submission_path)
cols = sub.columns[1:]
# Extract group names by splitting row_id on the last underscore
groups = sub['row_id'].str.rsplit('_', n=1).str[0].values
unique_groups = np.unique(groups)
for group in unique_groups:
# Get indices for the current group
idx = np.where(groups == group)[0]
sub_group = sub.iloc[idx].copy()
predictions = sub_group[cols].values
new_predictions = predictions.copy()
if predictions.shape[0] > 1:
# Smooth the predictions using neighboring segments
new_predictions[0] = (predictions[0] * 0.8) + (predictions[1] * 0.2)
new_predictions[-1] = (predictions[-1] * 0.8) + (predictions[-2] * 0.2)
for i in range(1, predictions.shape[0] - 1):
new_predictions[i] = (predictions[i - 1] * 0.2) + (predictions[i] * 0.6) + (
predictions[i + 1] * 0.2)
# Replace the smoothed values in the submission dataframe
sub.iloc[idx, 1:] = new_predictions
sub.to_csv(submission_path, index=False)
print(f"Smoothed submission saved to {submission_path}")
def run(self):
"""
Main method to execute the complete inference pipeline.
This method:
- Loads the pre-trained models.
- Processes test audio files and runs predictions.
- Creates the submission CSV.
- Applies smoothing to the predictions.
"""
start_time = time.time()
print("Starting BirdCLEF-2025 inference...")
print(f"TTA enabled: {self.cfg.use_tta} (variations: {self.cfg.tta_count if self.cfg.use_tta else 0})")
self.load_models()
if not self.models:
print("No models found! Please check model paths.")
return
print(f"Model usage: {'Single model' if len(self.models) == 1 else f'Ensemble of {len(self.models)} models'}")
row_ids, predictions = self.run_inference()
submission_df = self.create_submission(row_ids, predictions)
submission_path = 'submission.csv'
submission_df.to_csv(submission_path, index=False)
print(f"Initial submission saved to {submission_path}")
# Apply smoothing on the submission predictions.
self.smooth_submission(submission_path)
end_time = time.time()
print(f"Inference completed in {(end_time - start_time) / 60:.2f} minutes")
# Run the BirdCLEF2025 Pipeline:
if __name__ == "__main__":
cfg = CFG()
print(f"Using device: {cfg.device}")
pipeline = BirdCLEF2025Pipeline(cfg)
pipeline.run()
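To make the smoothing step concrete: each segment's scores get blended with their neighbors, with weights 0.2/0.6/0.2 for interior segments and 0.8/0.2 at the edges. A tiny worked example with made-up numbers:

import numpy as np

# One species column over three consecutive 5 s segments (toy values)
preds = np.array([0.9, 0.1, 0.8])
smoothed = preds.copy()
smoothed[0] = 0.8 * preds[0] + 0.2 * preds[1]                    # edge: 0.74
smoothed[-1] = 0.8 * preds[-1] + 0.2 * preds[-2]                 # edge: 0.66
smoothed[1] = 0.2 * preds[0] + 0.6 * preds[1] + 0.2 * preds[2]   # interior: 0.40
print(smoothed)  # [0.74 0.4 0.66]; the isolated dip in the middle is pulled up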
Training code
Since I want to train a model myself, I wrote a separate train.py.
Note that the following paths
train_audio_dir = '/root/autodl-tmp/BirdCLEF2025/train_audio'
train_csv = '/root/autodl-tmp/BirdCLEF2025/train.csv'
taxonomy_csv = '/root/autodl-tmp/BirdCLEF2025/taxonomy.csv'
output_dir = ""
need to be changed to wherever you actually keep the data.
The complete train.py follows. If you hit multiprocessing-related errors, just set num_workers in TrainCFG to 0.
(DataLoader workers run as subprocesses, which can fail in notebooks or on Windows; I haven't fully figured this part out myself.)
# train.py
import os
import pandas as pd
import numpy as np
import librosa
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm
# Reuse the original components from test.py
from test import CFG, BirdCLEFModel
import warnings  # should sit at the very top of the file
warnings.filterwarnings("ignore")  # suppress all warnings
# ---------------------- Extended training config ----------------------
class TrainCFG(CFG):
"""新增训练专用参数"""
# 数据路径需要覆盖父类配置
train_audio_dir = '/root/autodl-tmp/BirdCLEF2025/train_audio' # "./data/birdclef-2025/train_audio"
train_csv = '/root/autodl-tmp/BirdCLEF2025/train.csv' # "./data/birdclef-2025/train.csv"
taxonomy_csv = '/root/autodl-tmp/BirdCLEF2025/taxonomy.csv' # './data/birdclef-2025/taxonomy.csv'
output_dir = "./checkpoints"
# Training parameters
device = "cuda" # if torch.cuda.is_available() else "cpu"
num_epochs = 20
lr = 1e-4
batch_size = 256
num_workers = 4
num_folds = 5
seed = 42
# Label smoothing factor
label_smoothing = 0.05
# Mixed-precision training
use_amp = True
# ---------------------- Core data processing ----------------------
class BirdDataset(Dataset):
def __init__(self, cfg, df, audio_dir, is_train=True):
"""
Keeps the spectrogram generation logic consistent with test.py.
:param df: DataFrame loaded from train.csv
"""
self.cfg = cfg
self.df = df.reset_index(drop=True)
self.audio_dir = audio_dir
self.is_train = is_train
# Build the label mapping from the taxonomy
taxonomy = pd.read_csv(cfg.taxonomy_csv)
self.label_mapping = {
row['primary_label']: idx
for idx, row in taxonomy.iterrows()
}
print(f"Total classes: {len(self.label_mapping)}")
def __len__(self):
return len(self.df)
def _load_audio(self, filename):
"""严格保持与test.py相同的音频加载逻辑"""
audio_path = os.path.join(self.audio_dir, filename)
# Exception handling matches test.py
try:
audio, _ = librosa.load(audio_path, sr=self.cfg.FS)
if np.isnan(audio).any():
audio = np.nan_to_num(audio, nan=np.mean(audio))
except Exception as e:
print(f"Error loading {audio_path}: {e}")
audio = np.zeros(self.cfg.FS * 5)
return audio
def _process_segment(self, audio):
"""严格复制test.py中的频谱生成代码"""
# 填充逻辑需要完全相同
if len(audio) < self.cfg.FS * self.cfg.WINDOW_SIZE:
audio = np.pad(
audio,
(0, self.cfg.FS * self.cfg.WINDOW_SIZE - len(audio)),
mode='constant'
)
# Mel spectrogram parameters identical to test.py
mel_spec = librosa.feature.melspectrogram(
y=audio,
sr=self.cfg.FS,
n_fft=self.cfg.N_FFT,
hop_length=self.cfg.HOP_LENGTH,
n_mels=self.cfg.N_MELS,
fmin=self.cfg.FMIN,
fmax=self.cfg.FMAX,
power=2.0
)
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)
# Resizing done exactly as in test.py
return cv2.resize(mel_spec_norm, self.cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)
def __getitem__(self, idx):
row = self.df.iloc[idx]
# 1. Load and preprocess the audio
audio = self._load_audio(row['filename'])
# 2. Keep data augmentation compatible with test.py
# (note: training needs custom augmentation, but it must not be enabled at inference)
if self.is_train:
# Random temporal crop (core logic kept, extended for training)
if len(audio) > self.cfg.FS * self.cfg.WINDOW_SIZE:
start = np.random.randint(0, len(audio) - self.cfg.FS * self.cfg.WINDOW_SIZE)
audio = audio[start: start + self.cfg.FS * self.cfg.WINDOW_SIZE]
# 3. Strictly use the test.py spectrogram method
spec = self._process_segment(audio) # shape (256,256)
# 4. Build the target vector (consistent with the model's 206 output classes)
target = torch.zeros(len(self.label_mapping), dtype=torch.float32)
primary_idx = self.label_mapping.get(row['primary_label'], -1)
if primary_idx != -1:
target[primary_idx] = 1.0 - self.cfg.label_smoothing
target += self.cfg.label_smoothing / len(target)
return {
'spec': torch.tensor(spec).unsqueeze(0), # shape [1,256,256]
'target': target # shape [206]
}
# ---------------------- Training loop ----------------------
def train_fn(cfg, model, train_loader, optimizer, criterion):
model.train()
total_loss = 0.0
progress = tqdm(train_loader, desc="Training", leave=False)
scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
for batch in progress:
specs = batch['spec'].to(cfg.device) # shape [B,1,256,256]
targets = batch['target'].to(cfg.device) # shape [B,206]
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=cfg.use_amp):
outputs = model(specs)  # forward logic from test.py, unchanged
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
progress.set_postfix(loss=loss.item())
return total_loss / len(train_loader)
def validate_fn(cfg, model, val_loader, criterion):
model.eval()
total_loss = 0.0
progress = tqdm(val_loader, desc="Validating", leave=False)
with torch.no_grad():
for batch in progress:
specs = batch['spec'].to(cfg.device)
targets = batch['target'].to(cfg.device)
outputs = model(specs)
loss = criterion(outputs, targets)
total_loss += loss.item()
return total_loss / len(val_loader)
# ---------------------- Main ----------------------
def main():
cfg = TrainCFG()
os.makedirs(cfg.output_dir, exist_ok=True)
# Keep configs from different sources in sync
cfg.TARGET_SHAPE = (256, 256)  # identical to test.py
torch.manual_seed(cfg.seed)
# Load data
train_df = pd.read_csv(cfg.train_csv)
taxonomy = pd.read_csv(cfg.taxonomy_csv)
assert len(taxonomy) == 206, "Taxonomy class count must match the model output"
# Cross-validation training loop
skf = StratifiedKFold(n_splits=cfg.num_folds)
for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df['primary_label'])):
print(f"\n{'=' * 25} Fold {fold + 1}/{cfg.num_folds} {'=' * 25}")
# Data loaders
print('loading dataset...')
train_ds = BirdDataset(cfg, train_df.iloc[train_idx], cfg.train_audio_dir)
val_ds = BirdDataset(cfg, train_df.iloc[val_idx], cfg.train_audio_dir, is_train=False)
train_loader = DataLoader(
train_ds,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=0,  # cfg.num_workers; hardcoded to 0 to sidestep multiprocessing issues
pin_memory=True
)
val_loader = DataLoader(
val_ds,
batch_size=cfg.batch_size * 2,
shuffle=False,
num_workers=0,  # cfg.num_workers; hardcoded to 0 to sidestep multiprocessing issues
)
# Initialize a model structurally identical to the one in test.py
print('constructing MODEL...')
model = BirdCLEFModel(cfg, num_classes=len(taxonomy)).to(cfg.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
criterion = torch.nn.BCEWithLogitsLoss()  # objective consistent with the sigmoid used at inference
# Training loop
best_val_loss = float('inf')
for epoch in range(1, cfg.num_epochs + 1):
print(f"Epoch {epoch}/{cfg.num_epochs}")
train_loss = train_fn(cfg, model, train_loader, optimizer, criterion)
val_loss = validate_fn(cfg, model, val_loader, criterion)
# Save the best model (format fully compatible with test.py's loading code)
if val_loss < best_val_loss:
best_val_loss = val_loss
ckpt_path = os.path.join(cfg.output_dir, f"best_fold{fold}.pth")
torch.save({
'model_state_dict': model.state_dict(),
'config': vars(cfg)
}, ckpt_path)
print(f"Fold {fold} New best model saved (val_loss={val_loss:.4f})")
print(f"Fold {fold} completed. Best val loss: {best_val_loss:.4f}")
if __name__ == "__main__":
main()
Learning from the code: num_folds

num_folds (the number of folds) usually refers to how many subsets the data is split into for cross-validation, which is used to estimate a model's generalization performance. In detail:

1. Core purposes
- Better data utilization: split the dataset into K subsets (K = num_folds) and run K rounds of training/validation, each time training on K-1 subsets and validating on the remaining one, so limited data is used fully.
- More stable evaluation: averaging results over several different validation sets reduces the evaluation bias introduced by a single random split.

2. Common scenarios

Scenario | How it is applied
---|---
Cross-validation training | Set num_folds=5, run 5 training rounds, average the results
Ensembling | Train one sub-model per fold; the final prediction is a vote or average over the models
Hyperparameter tuning | Search for the best parameters within each fold, pick the config with the best average performance
Small-dataset validation | Improves validation reliability when data is scarce (num_folds=5 or 10 is common)

3. Worked example (5-fold cross-validation)

Data split: original data ➜ 5 equal parts (F1 to F5)

Round | Training set | Validation set | Model
---|---|---|---
Fold 1 | F2+F3+F4+F5 | F1 | Model_1
Fold 2 | F1+F3+F4+F5 | F2 | Model_2
Fold 3 | F1+F2+F4+F5 | F3 | Model_3
Fold 4 | F1+F2+F3+F5 | F4 | Model_4
Fold 5 | F1+F2+F3+F4 | F5 | Model_5

Final performance: take the mean of the 5 validation results (accuracy, F1 score, etc.). A runnable sketch follows.
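A minimal sketch of this 5-fold workflow with scikit-learn, on random placeholder data (the features, labels, and classifier here are illustrative stand-ins):

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

X = np.random.randn(100, 8)            # placeholder features
y = np.random.randint(0, 2, size=100)  # placeholder binary labels

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    # Train on 4 parts, validate on the held-out part
    clf = LogisticRegression().fit(X[train_idx], y[train_idx])
    acc = accuracy_score(y[val_idx], clf.predict(X[val_idx]))
    scores.append(acc)
    print(f'Fold {fold + 1}: val accuracy = {acc:.3f}')
print(f'Final performance: mean of 5 validation scores = {np.mean(scores):.3f}')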