YOLOv7 Pruning Workflow

news2024/11/16 23:59:58

1. Training

1) Split the dataset to prepare for training. The usual splitting procedure is fine.
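
The article does not say how the split is made, so the following is only a minimal sketch of one common approach: a seeded random split of image paths into train and val lists, which could then be written to the files that the `train` and `val` entries of the data yaml point to. The function name `split_dataset`, the 80/20 ratio, and the example paths are all assumptions for illustration.

```python
import random


def split_dataset(image_paths, val_ratio=0.2, seed=0):
    """Shuffle image paths reproducibly and split them into (train, val)."""
    paths = sorted(image_paths)          # sort first so the seed fully fixes the order
    random.Random(seed).shuffle(paths)   # local RNG, global random state is untouched
    n_val = int(len(paths) * val_ratio)  # number of validation images
    return paths[n_val:], paths[:n_val]


# Hypothetical file names, standing in for a real image folder.
images = [f'images/{i:04d}.jpg' for i in range(10)]
train_files, val_files = split_dataset(images)
print(len(train_files), len(val_files))  # 8 2
```

Fixing the seed keeps the split stable between the normal training run and the later fine-tuning run, so both are evaluated on the same validation images.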

2) Modify train.py

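The first change is in the argument list: add a pruning flag there. Set its default to False for normal training and to True when fine-tuning after pruning.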
parser.add_argument('--pruned', action='store_true', default=True, help='pruned model fine train')
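
One point worth checking here: with action='store_true' and default=True, the flag is True whether or not --pruned appears on the command line, so the only way to turn it off is to edit the default in the source. The sketch below shows this with plain argparse. The second parser, with default=False, is a hypothetical alternative and not the article's code; it lets normal training run without the flag and fine-tuning opt in with --pruned.

```python
import argparse

# As written in the article: store_true combined with default=True.
as_written = argparse.ArgumentParser()
as_written.add_argument('--pruned', action='store_true', default=True,
                        help='pruned model fine train')

# Hypothetical alternative (an assumption, not from the article):
# default to False and enable pruned fine-tuning only via the flag.
alternative = argparse.ArgumentParser()
alternative.add_argument('--pruned', action='store_true', default=False,
                         help='pruned model fine train')

flag_as_written_absent = as_written.parse_args([]).pruned
flag_as_written_present = as_written.parse_args(['--pruned']).pruned
flag_alt_absent = alternative.parse_args([]).pruned
flag_alt_present = alternative.parse_args(['--pruned']).pruned

print(flag_as_written_absent, flag_as_written_present)  # True True
print(flag_alt_absent, flag_alt_present)                # False True
```
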
The second change is at
# Resume

Modify the code below this comment. The original code is:

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Results
        if ckpt.get('training_results') is not None:
            results_file.write_text(ckpt['training_results'])  # write results.txt

Change it to:

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        if not opt.pruned:
            # Optimizer
            if ckpt['optimizer'] is not None:
                optimizer.load_state_dict(ckpt['optimizer'])
                best_fitness = ckpt['best_fitness']

            # EMA
            if ema and ckpt.get('ema'):
                ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
                ema.updates = ckpt['updates']

            # Results
            if ckpt.get('training_results') is not None:
                results_file.write_text(ckpt['training_results'])  # write results.txt
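
Why the guard helps can be illustrated without torch. The sketch below mimics the Resume block on plain dictionaries, assuming the pruned checkpoint holds only a 'model' entry (this matches what the modified Save model block later in this walkthrough writes, but the contents of the file produced by the pruning step itself are not shown in the article). The function name `restore_training_state` and the stand-in values are hypothetical.

```python
def restore_training_state(ckpt, pruned):
    """Mimic the guarded Resume block on a plain dict (no torch needed).

    Returns the best_fitness to start from and the names of the entries
    that would be restored from the checkpoint.
    """
    best_fitness, restored = 0.0, []
    if not pruned:
        if ckpt['optimizer'] is not None:
            restored.append('optimizer')
            best_fitness = ckpt['best_fitness']
        if ckpt.get('ema'):
            restored.append('ema')
        if ckpt.get('training_results') is not None:
            restored.append('training_results')
    return best_fitness, restored


# Hypothetical checkpoints: strings stand in for the real objects.
full_ckpt = {'model': 'full', 'optimizer': 'sgd-state', 'best_fitness': 0.61,
             'ema': 'ema-model', 'updates': 1200, 'training_results': 'log'}
pruned_ckpt = {'model': 'pruned'}  # assumed to hold only the model

full_state = restore_training_state(full_ckpt, pruned=False)
pruned_state = restore_training_state(pruned_ckpt, pruned=True)

# Without the guard, the first lookup of the original block fails.
try:
    restore_training_state(pruned_ckpt, pruned=False)
    unguarded_error = None
except KeyError as exc:
    unguarded_error = exc.args[0]

print(full_state)       # (0.61, ['optimizer', 'ema', 'training_results'])
print(pruned_state)     # (0.0, [])
print(unguarded_error)  # optimizer
```

With the guard in place, fine-tuning a pruned model starts with a fresh optimizer, a fresh EMA, and best_fitness at 0.0.
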
The third change is at
# Epochs

The original code is:

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

Change it to:

        # Epochs
        if not opt.pruned:
            start_epoch = ckpt['epoch'] + 1
            if opt.resume:
                assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        elif opt.pruned:
            ckpt['epoch'] = 0
            start_epoch = ckpt['epoch'] + 1
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs
        if not opt.pruned:
            del ckpt, state_dict
        elif opt.pruned:
            del ckpt
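
To see what the modified Epochs block computes, the sketch below reproduces its arithmetic as a standalone function. The name `resolve_epochs` and the example checkpoints are hypothetical; only the branching mirrors the code above.

```python
def resolve_epochs(ckpt, epochs, pruned):
    """Reproduce the start_epoch / epochs arithmetic of the modified block."""
    if not pruned:
        start_epoch = ckpt['epoch'] + 1
    else:
        ckpt['epoch'] = 0                # pruned checkpoint: reset the counter
        start_epoch = ckpt['epoch'] + 1
    if epochs < start_epoch:
        epochs += ckpt['epoch']          # fine-tune for additional epochs
    return start_epoch, epochs


# Pruned fine-tuning: the epoch counter is forced to 0.
pruned_run = resolve_epochs({'model': 'pruned'}, epochs=600, pruned=True)
# Normal run from a checkpoint whose stored epoch is -1.
fresh_run = resolve_epochs({'epoch': -1}, epochs=600, pruned=False)
# Normal run from a checkpoint that already finished 300 epochs (0..299).
extended_run = resolve_epochs({'epoch': 299}, epochs=100, pruned=False)

print(pruned_run)    # (1, 600)
print(fresh_run)     # (0, 600)
print(extended_run)  # (300, 399)
```

Note that in the pruned case start_epoch is 1, so a training loop of the form range(start_epoch, epochs) runs epochs 1 to 599, one fewer than the value passed to --epochs.
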
The fourth change is at
# Save model

The original code is:

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

Change it to:

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                if opt.pruned:
                    ckpt = {
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                    }
                elif not opt.pruned:
                    ckpt = {'epoch': epoch,
                            'best_fitness': best_fitness,
                            'training_results': results_file.read_text(),
                            'model': deepcopy(model.module if is_parallel(model) else model).half(),
                            'ema': deepcopy(ema.ema).half(),
                            'updates': ema.updates,
                            'optimizer': optimizer.state_dict(),
                            'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
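
The effect of this change on the saved file can be sketched with plain dictionaries. The helper `build_checkpoint` and its stand-in arguments are hypothetical; the real code additionally deep-copies the model and converts it to half precision, which is omitted here.

```python
def build_checkpoint(pruned, model, epoch, best_fitness, results_text,
                     ema_model, ema_updates, optimizer_state, wandb_id=None):
    """Return the checkpoint dict that the modified Save model block would write."""
    if pruned:
        # Pruned fine-tuning: keep only the model object itself.
        return {'model': model}
    return {'epoch': epoch,
            'best_fitness': best_fitness,
            'training_results': results_text,
            'model': model,
            'ema': ema_model,
            'updates': ema_updates,
            'optimizer': optimizer_state,
            'wandb_id': wandb_id}


# Strings stand in for the real model, EMA, and optimizer objects.
state = dict(model='net', epoch=5, best_fitness=0.5, results_text='log',
             ema_model='ema-net', ema_updates=100, optimizer_state='sgd-state')

pruned_keys = sorted(build_checkpoint(True, **state))
normal_keys = sorted(build_checkpoint(False, **state))

print(pruned_keys)  # ['model']
print(normal_keys)
```

A checkpoint written in pruned mode therefore carries no optimizer, EMA, or epoch information, which is consistent with the guards added to the Resume and Epochs blocks: a later run that loads it is expected to set the pruned flag as well.
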
The complete modified train.py is as follows:
import argparse
import logging
import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)


def train(hyp, opt, device, tb_writer=None):
    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    is_coco = opt.data.endswith('coco.yaml')

    # Logging- Doing this before checking the dataset. Might update data_dict
    loggers = {'wandb': None}  # loggers dict
    if rank in [-1, 0]:
        opt.hyp = hyp  # add hyperparameters
        run_id = torch.load(weights, map_location=device).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
        loggers['wandb'] = wandb_logger.wandb
        data_dict = wandb_logger.data_dict
        if wandb_logger.wandb:
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming

    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        if opt.pruned:
            from models.yolo import attempt_load
            #model = attempt_load(weights,map_location=device)
            ckpt = torch.load(weights, map_location=device)
            model = ckpt['model']
        else:
            with torch_distributed_zero_first(rank):
                attempt_download(weights)  # download if not found locally
            ckpt = torch.load(weights, map_location=device)  # load checkpoint
            # define the model
            model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
            exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) else []  # exclude keys
            state_dict = ckpt['model'].float().state_dict()  # to FP32; weights from the pretrained checkpoint
            state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
            # load the weights into the model
            model.load_state_dict(state_dict, strict=False)  # load
            logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
        # with torch_distributed_zero_first(rank):
        #     attempt_download(weights)  # download if not found locally
        # ckpt = torch.load(weights, map_location=device)  # load checkpoint
        # model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        # exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
        # state_dict = ckpt['model'].float().state_dict()  # to FP32
        # state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        # model.load_state_dict(state_dict, strict=False)  # load
        # logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']

    # Freeze
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay
        if hasattr(v, 'im'):
            if hasattr(v.im, 'implicit'):           
                pg0.append(v.im.implicit)
            else:
                for iv in v.im:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imc'):
            if hasattr(v.imc, 'implicit'):           
                pg0.append(v.imc.implicit)
            else:
                for iv in v.imc:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imb'):
            if hasattr(v.imb, 'implicit'):           
                pg0.append(v.imb.implicit)
            else:
                for iv in v.imb:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imo'):
            if hasattr(v.imo, 'implicit'):           
                pg0.append(v.imo.implicit)
            else:
                for iv in v.imo:
                    pg0.append(iv.implicit)
        if hasattr(v, 'ia'):
            if hasattr(v.ia, 'implicit'):           
                pg0.append(v.ia.implicit)
            else:
                for iv in v.ia:
                    pg0.append(iv.implicit)
        if hasattr(v, 'attn'):
            if hasattr(v.attn, 'logit_scale'):   
                pg0.append(v.attn.logit_scale)
            if hasattr(v.attn, 'q_bias'):   
                pg0.append(v.attn.q_bias)
            if hasattr(v.attn, 'v_bias'):  
                pg0.append(v.attn.v_bias)
            if hasattr(v.attn, 'relative_position_bias_table'):  
                pg0.append(v.attn.relative_position_bias_table)
        if hasattr(v, 'rbr_dense'):
            if hasattr(v.rbr_dense, 'weight_rbr_origin'):  
                pg0.append(v.rbr_dense.weight_rbr_origin)
            if hasattr(v.rbr_dense, 'weight_rbr_avg_conv'): 
                pg0.append(v.rbr_dense.weight_rbr_avg_conv)
            if hasattr(v.rbr_dense, 'weight_rbr_pfir_conv'):  
                pg0.append(v.rbr_dense.weight_rbr_pfir_conv)
            if hasattr(v.rbr_dense, 'weight_rbr_1x1_kxk_idconv1'): 
                pg0.append(v.rbr_dense.weight_rbr_1x1_kxk_idconv1)
            if hasattr(v.rbr_dense, 'weight_rbr_1x1_kxk_conv2'):   
                pg0.append(v.rbr_dense.weight_rbr_1x1_kxk_conv2)
            if hasattr(v.rbr_dense, 'weight_rbr_gconv_dw'):   
                pg0.append(v.rbr_dense.weight_rbr_gconv_dw)
            if hasattr(v.rbr_dense, 'weight_rbr_gconv_pw'):   
                pg0.append(v.rbr_dense.weight_rbr_gconv_pw)
            if hasattr(v.rbr_dense, 'vector'):   
                pg0.append(v.rbr_dense.vector)

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        if not opt.pruned:
            # Optimizer
            if ckpt['optimizer'] is not None:
                optimizer.load_state_dict(ckpt['optimizer'])
                best_fitness = ckpt['best_fitness']

            # EMA
            if ema and ckpt.get('ema'):
                ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
                ema.updates = ckpt['updates']

            # Results
            if ckpt.get('training_results') is not None:
                results_file.write_text(ckpt['training_results'])  # write results.txt

        # Epochs
        if not opt.pruned:
            start_epoch = ckpt['epoch'] + 1
            if opt.resume:
                assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        elif opt.pruned:
            ckpt['epoch'] = 0
            start_epoch = ckpt['epoch'] + 1
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs
        if not opt.pruned:
            del ckpt, state_dict
        elif opt.pruned:
            del ckpt
        # start_epoch = ckpt['epoch'] + 1
        # if opt.resume:
        #     assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        # if epochs < start_epoch:
        #     logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
        #                 (weights, ckpt['epoch'], epochs))
        #     epochs += ckpt['epoch']  # finetune additional epochs

        # del ckpt, state_dict

    # Image sizes
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                       world_size=opt.world_size, workers=opt.workers,
                                       pad=0.5, prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                #plot_labels(labels, names, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
                    # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
                    find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    compute_loss_ota = ComputeLossOTA(model)  # init loss class
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    torch.save(model, wdir / 'init.pt')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
                    loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
                else:
                    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 10:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  # add model graph
                elif plots and ni == 10 and wandb_logger.wandb:
                    wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                  save_dir.glob('train*.jpg') if x.exists()]})

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                wandb_logger.current_epoch = epoch + 1
                results, maps, times = test.test(data_dict,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 wandb_logger=wandb_logger,
                                                 compute_loss=compute_loss,
                                                 is_coco=is_coco,
                                                 v5_metric=opt.v5_metric)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb_logger.wandb:
                    wandb_logger.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
            wandb_logger.end_epoch(best_result=best_fitness == fi)

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                if opt.pruned:
                    ckpt = {
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                    }
                elif not opt.pruned:
                    ckpt = {'epoch': epoch,
                            'best_fitness': best_fitness,
                            'training_results': results_file.read_text(),
                            'model': deepcopy(model.module if is_parallel(model) else model).half(),
                            'ema': deepcopy(ema.ema).half(),
                            'updates': ema.updates,
                            'optimizer': optimizer.state_dict(),
                            'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (best_fitness == fi) and (epoch >= 200):
                    torch.save(ckpt, wdir / 'best_{:03d}.pt'.format(epoch))
                if epoch == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif ((epoch+1) % 25) == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif epoch >= (epochs-5):
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                if wandb_logger.wandb:
                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                        wandb_logger.log_model(
                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                del ckpt

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training
    if rank in [-1, 0]:
        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb_logger.wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                              if (save_dir / f).exists()]})
        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for m in (last, best) if best.exists() else (last,):  # speed, mAP tests
                results, _, _ = test.test(opt.data,
                                          batch_size=batch_size * 2,
                                          imgsz=imgsz_test,
                                          conf_thres=0.001,
                                          iou_thres=0.7,
                                          model=attempt_load(m, device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=True,
                                          plots=False,
                                          is_coco=is_coco,
                                          v5_metric=opt.v5_metric)

        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
        if wandb_logger.wandb and not opt.evolve:  # Log the stripped model
            wandb_logger.wandb.log_artifact(str(final), type='model',
                                            name='run_' + wandb_logger.wandb_run.id + '_model',
                                            aliases=['last', 'best', 'stripped'])
        wandb_logger.finish_run()
    else:
        dist.destroy_process_group()
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='my_dataset/layer_pruning.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='cfg/training/yolov7-tiny.yaml', help='model.yaml path')
    parser.add_argument('--data', type=str, default='my_dataset/coco.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--batch-size', type=int, default=32, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', default=True, help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--workers', type=int, default=16, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
    parser.add_argument('--save_period', type=int, default=2, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
    parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
    parser.add_argument('--pruned', action='store_true', default=True, help='pruned model fine train')
    opt = parser.parse_args()

    # Set DDP variables
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    #if opt.global_rank in [-1, 0]:
    #    check_git_status()
    #    check_requirements()

    # Resume
    wandb_run = check_wandb_resume(opt)
    if opt.resume and not wandb_run:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # replace
        opt.save_dir = Path(ckpt).parent.parent  # increment run
        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run

    # DDP mode
    opt.total_batch_size = opt.batch_size
    device = select_device(opt.device, batch_size=opt.batch_size)
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        opt.batch_size = opt.total_batch_size // opt.world_size

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

    # Train
    logger.info(opt)
    if not opt.evolve:
        tb_writer = None  # init loggers
        if opt.global_rank in [-1, 0]:
            prefix = colorstr('tensorboard: ')
            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
        train(hyp, opt, device, tb_writer)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0),   # image mixup (probability)
                'copy_paste': (1, 0.0, 1.0),  # segment copy-paste (probability)
                'paste_in': (1, 0.0, 1.0)}    # segment copy-paste (probability)
        
        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
                
        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([x[0] for x in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')

2. Pruning

Pruning method

The channel pruning here follows the CSDN blog post "YOLOv4剪枝【附代码】" (YOLOv4 pruning, with code), which applies the technique from the paper "Pruning Filters for Efficient ConvNets".
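
As a rough illustration of the L1-norm criterion from that paper, the sketch below ranks the filters of one convolution by the sum of their absolute weights and returns the indices of the weakest ones. It is plain Python on nested lists, not the torch_pruning implementation; the function name `l1_prune_indices` and the toy weights are made up for illustration.

```python
def l1_prune_indices(filters, amount):
    """Return the indices of the filters with the smallest L1 norm.

    filters: list of filters, each a flat list of weights (hypothetical toy layout).
    amount:  fraction of filters to remove, between 0 and 1.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    n_prune = int(amount * len(filters))  # rounds down, so tiny layers may lose nothing
    ranked = sorted(range(len(filters)), key=lambda i: norms[i])
    return sorted(ranked[:n_prune])


toy_filters = [
    [0.9, -0.8, 0.7],     # large weights, likely important
    [0.01, 0.02, -0.01],  # near-zero filter
    [0.5, -0.5, 0.5],
    [0.1, 0.0, -0.1],     # weak filter
]
print(l1_prune_indices(toy_filters, amount=0.5))  # [1, 3]
```

Filters whose weights are all close to zero contribute little to the output, which is the intuition behind removing them first.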

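Add the prunmodel.py file

When installing the torch_pruning module, be sure to pin an old 0.2.x release. The original text names version 0.2.5, while the install command below pins 0.2.7. The script depends on the older 0.2.x API (tp.strategy.L1Strategy, DG.get_pruning_plan), so newer releases may not work.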

pip install torch_pruning==0.2.7
pip install loguru

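Pass in the path of the trained model:

layer_pruning('/home/jovyan/exp_3046/runs/train/best.pt')

Set the path where the pruned model is saved:

torch.save(model_, '/home/jovyan/exp_3047/data/layer_pruning.pt')

In prunmodel.py, change which weights get pruned. The lines that matter most are lines 58~62 of the script. The example here prunes within the first 10 layers of the model. The head layers must not be pruned, so we pick layers from the backbone.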

    included_layers = []
    for layer in model.model[:10]:  # take the backbone layers
        if type(layer) is Conv:
            included_layers.append(layer.conv)
            included_layers.append(layer.bn)

The code below prunes the conv and BN layers. The key call is tp.prune_conv; adjust amount, which is the pruning ratio, to suit your model.

        if isinstance(m, nn.Conv2d) and m in included_layers:
            # amount is the pruning ratio
            # prune the convolution
            pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=0.8))
            logger.info(pruning_plan)
            # execute the pruning plan
            pruning_plan.exec()
        if isinstance(m, nn.BatchNorm2d) and m in included_layers:
            # prune the BN layer
            pruning_plan = DG.get_pruning_plan(m, tp.prune_batchnorm, idxs=strategy(m.weight, amount=0.8))
            logger.info(pruning_plan)
            pruning_plan.exec()
Output like the following indicates that pruning succeeded (剪枝完成 means "pruning finished"):
2023-03-15 14:57:40.825 | INFO     | __main__:layer_pruning:84 -   Params: 37196556 => 36839795
2023-03-15 14:57:41.176 | INFO     | __main__:layer_pruning:95 - 剪枝完成
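
The two parameter counts in that log (37196556 and 36839795) give a quick sense of how much the model shrank. A small worked calculation, using only the numbers from the log:

```python
params_before = 37196556  # parameter count before pruning, from the log
params_after = 36839795   # parameter count after pruning, from the log

removed = params_before - params_after
reduction_pct = 100.0 * removed / params_before
print(f"removed {removed} parameters ({reduction_pct:.2f}% of the model)")
# removed 356761 parameters (0.96% of the model)
```

A reduction of under 1% suggests that this particular run touched only a small part of the network, so the amount value and the range of included layers are the knobs to revisit if a smaller model is the goal.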

The full code is as follows:

import sys

sys.path.append("/home/jovyan/exp_3046")
# print(sys.path)

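import torch
import torch.nn as nn
import torch_pruning as tp
from loguru import logger
from models.common import *
from models.experimental import Ensemble
from utils.torch_utils import select_device


"""
Prune according to the actual model structure; do not guess blindly.
After pruning, the model needs a round of fine-tuning.
"""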


# Load the model
def attempt_load(weights, map_location=None, inplace=True):
    from models.yolo import Detect, Model

    model = Ensemble()
    ckpt = torch.load(weights, map_location=map_location)  # load weights
    model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # load the weights without fusing layers

    # Compatibility updates
    for m in model.modules():  # visit every layer
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = inplace  # pytorch 1.7.0 compatibility
            if type(m) is Detect:  # detection head
                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
                    delattr(m, 'anchor_grid')
                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
        elif type(m) is Conv:  # convolution block
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1], ckpt  # return model
    else:
        print(f'Ensemble created with {weights}\n')
        for k in ['names']:
            setattr(model, k, getattr(model[-1], k))
        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
        return model, ckpt  # return ensemble


@logger.catch
def layer_pruning(weights):
    logger.add('../logs/layer_pruning.log', rotation='1 MB')
    device = select_device('cpu')
    model, ckpt = attempt_load(weights, map_location=device)
    for para in model.parameters():
        para.requires_grad = True
    # Create an example input; change the input size here if needed
    x = torch.zeros(1, 3, 640, 640)
    # ----------------- prune the whole model --------------------
    strategy = tp.strategy.L1Strategy()  # L1 strategy
    DG = tp.DependencyGraph()  # dependency graph
    DG = DG.build_dependency(model, example_inputs=x)

    """
    List the layers to prune here.
    The backbone is used as the example.
    """
    included_layers = []
    for layer in model.model[:41]:  # take the backbone layers
        if type(layer) is Conv:
            included_layers.append(layer.conv)
            included_layers.append(layer.bn)
    logger.info(included_layers)
    # Parameter count before pruning
    num_params_before_pruning = tp.utils.count_params(model)
    # Walk through the model
    for m in model.modules():
        # Only handle convolutions that are in the list of layers to prune
        if isinstance(m, nn.Conv2d) and m in included_layers:
            # amount is the pruning ratio
            # Prune the convolution
            # Layers that must not be pruned have to be filtered out first; for YOLO, the prediction head is excluded
            pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=0.4))
            logger.info(pruning_plan)
            # Execute the pruning plan
            pruning_plan.exec()
        if isinstance(m, nn.BatchNorm2d) and m in included_layers:
            # Prune the BN layer
            pruning_plan = DG.get_pruning_plan(m, tp.prune_batchnorm, idxs=strategy(m.weight, amount=0.3))
            logger.info(pruning_plan)
            pruning_plan.exec()
    # Parameter count after pruning
    num_params_after_pruning = tp.utils.count_params(model)
    # Log the parameter counts before and after pruning
    logger.info("  Params: %s => %s\n" % (num_params_before_pruning, num_params_after_pruning))
    # Save the pruned model (do not use torch.save(model.state_dict(), ...))

    model_ = {
        'model': model.half(),
        # 'optimizer': ckpt['optimizer'],
        # 'training_results': ckpt['training_results'],
        'epoch': ckpt['epoch']
    }
    torch.save(model_, '/home/jovyan/exp_3046/my_dataset/layer_pruning.pt')
    del model_, ckpt
    logger.info("剪枝完成\n")  # "pruning finished"


layer_pruning('/home/jovyan/exp_3046/runs/train/best.pt')  
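
The script routes every change through DG.get_pruning_plan because removing output channels from one convolution also shrinks the BatchNorm that follows it and the input side of the next convolution. The shape-only sketch below counts parameters for such a conv, BN, conv chain before and after dropping 40% of the middle channels. It uses plain Python with hypothetical channel sizes (32, 64, 128) and a made-up helper `conv_bn_conv_params`; it does not call torch or torch_pruning.

```python
def conv_bn_conv_params(c_in, c_mid, c_out, k=3):
    """Parameter count of conv(c_in->c_mid) + BN(c_mid) + conv(c_mid->c_out), no conv bias."""
    conv1 = c_mid * c_in * k * k   # first convolution: one k x k kernel per (out, in) pair
    bn = 2 * c_mid                 # learnable scale and shift per channel
    conv2 = c_out * c_mid * k * k  # next convolution consumes the c_mid channels
    return conv1 + bn + conv2


before = conv_bn_conv_params(c_in=32, c_mid=64, c_out=128)
kept = 64 - int(0.4 * 64)  # pruning 40% of 64 channels removes 25, leaving 39
after = conv_bn_conv_params(c_in=32, c_mid=kept, c_out=128)
print(before, "=>", after)  # 92288 => 56238
```

All three layers lose parameters even though only the middle channel count changed. The resulting shapes no longer match the ones described by the model YAML, which is presumably why the script saves the whole model object rather than only a state_dict.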

3. Fine-tuning

Run train.py.

Use the pruned layer_pruning.pt as the weights, and keep the pruned flag enabled:

parser.add_argument('--pruned', action='store_true', default=True, help='pruned model fine train')
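
With action='store_true' and default=True, the flag is always on and can only be switched off by editing the file. If you would rather toggle it from the command line, one option is argparse.BooleanOptionalAction (Python 3.9+). This is an alternative to the author's setup, not part of the original script:

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction (Python 3.9+) registers both --pruned and --no-pruned.
parser.add_argument('--pruned', action=argparse.BooleanOptionalAction, default=False,
                    help='fine-tune a pruned model')

print(parser.parse_args([]).pruned)               # False: normal training
print(parser.parse_args(['--pruned']).pruned)     # True: fine-tuning after pruning
print(parser.parse_args(['--no-pruned']).pruned)  # False: explicit opt-out
```

With this variant, normal training needs no flag and fine-tuning is started by adding --pruned to the train.py command.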

4. Results

Prediction with the originally trained weights took 249.217 s.

Prediction with the pruned model took 172.418 s.
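
To put the two timings in relative terms, here is a short worked calculation using only the numbers reported above:

```python
baseline_s = 249.217  # prediction time with the originally trained weights
pruned_s = 172.418    # prediction time with the pruned model

speedup = baseline_s / pruned_s
time_saved_pct = 100.0 * (baseline_s - pruned_s) / baseline_s
print(f"speedup: {speedup:.2f}x, time saved: {time_saved_pct:.1f}%")
# speedup: 1.45x, time saved: 30.8%
```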


