文章目录
- 一、tqdm基本知识
- 二、在pytorch中使用tqdm
提示:以下是本篇文章正文内容,下面案例可供参考
一、tqdm基本知识
“tqdm” 是一个 Python 库,用于在命令行界面中创建进度条。
基本使用如下:
from tqdm import tqdm
import time
items = range(10)
for item in tqdm(items, desc="Test", total=len(items)):
time.sleep(1)
其文档如下:
只介绍传入的三个参数:iterable, desc, total
iterable:是一个可迭代对象
desc:进度条前的描述性信息
total:可迭代对象的长度
结果如下:
可以看到有描述性信息,进度条,已经运行了多少时间,还差多少时间,速度。在之后还可以添加后缀描述,见下面。
二、在pytorch中使用tqdm
一般都是在train函数中使用tqdm,讲dataloader做为一个可迭代对象传入tqdm
loop = tqdm((dataloader_train), desc=f"Epoch: [{epoch}/20]", total=len(dataloader_train))
for img, label in loop:
img = img.to(device)
label = label.to(device)
output = model(img)
optimizer.zero_grad()
loss = criterion(output,label)
loss.backward()
optimizer.step()
train_loss += loss.item()
correct += (torch.argmax(output,dim=1) == label).sum().item()
loop.set_postfix(loss=loss.item() / label.shape[0])
print("epoch: {i} Train Loss: {loss}".format(i=epoch, loss=train_loss))
print("epoch: {i} Train Accuracy: {acc}".format(i=epoch, acc=correct / len(dataset_train)))