神经微分方程Resnet变体实现内存下降和保持精度

news2025/1/19 23:08:01

本文内容:

1、学习神经微分方程的笔记,主要锻炼自己学习新知识的能力和看有很多数学原理的论文能力;

2、神经微分方程可以用于时序数据建模、动力学建模等,但是本文专注于分类问题-resnet变体<比较容易理解>;


个人理解:

联合灵敏度的代码实现比较复杂,代码逻辑和算法步骤是一样的,对照看就很容易明白,其实本质上就是把梯度计算归结为求解微分的问题:

工程上实现OdeintAdjointMethod的方法是继承torch.autograd.Function类,实现forward和backward方法,将forward和backward替换成ODE求解器的方式,而不是用原来torch.autograd.Function的链式法则进行梯度求解。


基本原理:

梯度反传法是用于训练神经网络的方法,可以避免使用反向传播训练导数函数时所遇到的可扩展性问题。这种方法涉及使用普通微分方程(ODE)求解器进行前向传播,然后使用联合灵敏度方法进行反向传播,从而使得可以再次使用ODE求解器进行反向传播。为了更新导数函数的参数,需要使用联合灵敏度方法获取损失函数相对于动态函数参数的梯度。最终算法涉及设置某些变量并将信息打包到其中,然后调用ODE求解器反向传播以获得theta并更新CNN编码器和导数函数参数。

梯度反传法算法流程如下:

在这里插入图片描述

更完整的版本:

在这里插入图片描述


联合灵敏度的代码实现比较复杂,代码逻辑和算法步骤是一样的,对照看就很容易明白,其实本质上就是把梯度计算归结为求解微分的问题:

**工程上实现OdeintAdjointMethod的方法是继承torch.autograd.Function类,实现forward和backward方法,将forward和backward替换成ODE求解器的方式,而不是用原来torch.autograd.Function的链式法则进行梯度求解。**odeint则是本文使用的ODE求解器。


代码仓库提供的求解器种类如下:

SOLVERS = {
    'dopri8': Dopri8Solver,
    'dopri5': Dopri5Solver,
    'bosh3': Bosh3Solver,
    'fehlberg2': Fehlberg2,
    'adaptive_heun': AdaptiveHeunSolver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
    'explicit_adams': AdamsBashforth,
    'implicit_adams': AdamsBashforthMoulton,
    # Backward compatibility: use the same name as before
    'fixed_adams': AdamsBashforthMoulton,
    # ~Backwards compatibility
    'scipy_solver': ScipyWrapperODESolver,
}

完整源码在:rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. (github.com),这里放上核心部分和注释,前向和反向传播部分;

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
                adjoint_options, t_requires_grad, *adjoint_params):

        ctx.shapes = shapes
        ctx.func = func
        ctx.adjoint_rtol = adjoint_rtol
        ctx.adjoint_atol = adjoint_atol
        ctx.adjoint_method = adjoint_method
        ctx.adjoint_options = adjoint_options
        ctx.t_requires_grad = t_requires_grad
        ctx.event_mode = event_fn is not None

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

            if event_fn is None:
                y = ans
                ctx.save_for_backward(t, y, *adjoint_params)
            else:
                event_t, y = ans
                ctx.save_for_backward(t, y, event_t, *adjoint_params)

        return ans

    @staticmethod
    def backward(ctx, *grad_y):
        with torch.no_grad():
            func = ctx.func
            adjoint_rtol = ctx.adjoint_rtol
            adjoint_atol = ctx.adjoint_atol
            adjoint_method = ctx.adjoint_method
            adjoint_options = ctx.adjoint_options
            t_requires_grad = ctx.t_requires_grad

            # 反向传播如果积分到达时间,不会在事件时间内反向传播。
            # Backprop as if integrating up to event time.
            # Does NOT backpropagate through the event time.
            event_mode = ctx.event_mode
            if event_mode:
                t, y, event_t, *adjoint_params = ctx.saved_tensors
                _t = t
                t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
                grad_y = grad_y[1]
            else:
                t, y, *adjoint_params = ctx.saved_tensors
                grad_y = grad_y[0]

            adjoint_params = tuple(adjoint_params)

            ##################################
            #      创建初始化状态      #
            ##################################

            # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
            aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
            aug_state.extend([torch.zeros_like(param) for param in adjoint_params])  # vjp_params

            ##################################
            #    创建反向ODE函数    #
            ##################################

            # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
            def augmented_dynamics(t, y_aug):
                # 动力学函数
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                y = y_aug[1]
                adj_y = y_aug[2]
                # ignore gradients wrt time and parameters

                with torch.enable_grad():
                    t_ = t.detach()
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
                    # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
                    #如果使用自适应求解器,不想浪费时间来求解dL/dt,除非我们需要它(如果有分段结构,它甚至不存在),所以关闭梯度
                    # wrt t here means we won't compute that if we don't need it.
                    func_eval = func(t if t_requires_grad else t_, y)

                    # Workaround for PyTorch bug #39784
                    _t = torch.as_strided(t, (), ())  # noqa
                    _y = torch.as_strided(y, (), ())  # noqa
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)  # noqa

                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)

            ##################################
            #      求解联合ODE       #
            ##################################

            if t_requires_grad:
                time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
            else:
                time_vjps = None
            for i in range(len(t) - 1, 0, -1):
                if t_requires_grad:
                    # Compute the effect of moving the current time measurement point.
                    # We don't compute this unless we need to, to save some computation.
                    func_eval = func(t[i], y[i])
                    dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
                    aug_state[0] -= dLd_cur_t
                    time_vjps[i] = dLd_cur_t

                # Run the augmented system backwards in time.
                # 运行增强系统反向
                aug_state = odeint(
                    augmented_dynamics, tuple(aug_state),
                    t[i - 1:i + 1].flip(0),
                    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
                )
                aug_state = [a[1] for a in aug_state]  # extract just the t[i - 1] value
                aug_state[1] = y[i - 1]  # update to use our forward-pass estimate of the state
                aug_state[2] += grad_y[i - 1]  # update any gradients wrt state at this time point

            if t_requires_grad:
                time_vjps[0] = aug_state[0]

            # 计算梯度
            # Only compute gradient wrt initial time when in event handling mode.
            if event_mode and t_requires_grad:
                time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])

            adj_y = aug_state[2]
            adj_params = aug_state[3:]

        return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)

模型架构:

使用small的residual net,在下采样后,有6个标准残差块被ODESolve替代,得到ODE-Net。

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (8): GroupNorm(32, 64, eps=1e-05, affine=True)
  (9): ReLU(inplace=True)
  (10): AdaptiveAvgPool2d(output_size=(1, 1))
  (11): Flatten()
  (12): Linear(in_features=64, out_features=10, bias=True)
)

结果如下:

参数量和内存比Resnet优,但是时间没有明确展示出来,只是展示时间复杂度;

在这里插入图片描述


### 运行resnet demo:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


if __name__ == '__main__':

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )

训练过程:

Epoch 0000 | Time 3.425 (3.425) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0987 | Test Acc 0.0958
Epoch 0001 | Time 3.279 (0.840) | NFE-F 20.3 | NFE-B 0.0 | Train Acc 0.9755 | Test Acc 0.9779
Epoch 0002 | Time 3.500 (0.839) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9858 | Test Acc 0.9875
Epoch 0003 | Time 3.403 (0.828) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9884 | Test Acc 0.9879
Epoch 0004 | Time 3.303 (0.807) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9926 | Test Acc 0.9921
Epoch 0005 | Time 3.308 (0.801) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9940 | Test Acc 0.9930
Epoch 0006 | Time 3.255 (0.804) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9917 | Test Acc 0.9894
Epoch 0007 | Time 3.376 (0.808) | NFE-F 20.2 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9929
Epoch 0008 | Time 3.260 (0.806) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9935 | Test Acc 0.9934
Epoch 0009 | Time 3.248 (0.832) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9909
Epoch 0010 | Time 3.286 (0.817) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9959 | Test Acc 0.9947
Epoch 0011 | Time 3.281 (0.827) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9967 | Test Acc 0.9951
Epoch 0012 | Time 3.382 (0.825) | NFE-F 20.9 | NFE-B 0.0 | Train Acc 0.9949 | Test Acc 0.9929
Epoch 0013 | Time 3.299 (0.862) | NFE-F 22.0 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9949
Epoch 0014 | Time 3.326 (0.824) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9947 | Test Acc 0.9936
Epoch 0015 | Time 3.291 (0.839) | NFE-F 21.3 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9948
Epoch 0016 | Time 3.467 (0.935) | NFE-F 24.4 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9941
Epoch 0017 | Time 3.483 (0.900) | NFE-F 23.2 | NFE-B 0.0 | Train Acc 0.9970 | Test Acc 0.9939
Epoch 0018 | Time 3.309 (0.872) | NFE-F 22.2 | NFE-B 0.0 | Train Acc 0.9961 | Test Acc 0.9932
Epoch 0019 | Time 3.294 (0.913) | NFE-F 23.6 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9954
Epoch 0020 | Time 3.504 (0.984) | NFE-F 25.9 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9951
Epoch 0021 | Time 3.589 (0.966) | NFE-F 25.2 | NFE-B 0.0 | Train Acc 0.9966 | Test Acc 0.9929
Epoch 0022 | Time 3.503 (0.994) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9949
Epoch 0023 | Time 3.457 (0.995) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9978 | Test Acc 0.9939
Epoch 0024 | Time 3.529 (0.985) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9958
Epoch 0025 | Time 3.459 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9973 | Test Acc 0.9947
Epoch 0026 | Time 3.541 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9979 | Test Acc 0.9946
Epoch 0027 | Time 3.513 (0.993) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9959
Epoch 0028 | Time 3.505 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9982 | Test Acc 0.9953
Epoch 0029 | Time 3.501 (0.990) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9953
Epoch 0030 | Time 3.475 (0.992) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9954
Epoch 0031 | Time 3.506 (0.993) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9947
Epoch 0032 | Time 3.527 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9981 | Test Acc 0.9954
Epoch 0033 | Time 3.529 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9945
Epoch 0034 | Time 3.545 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9959
Epoch 0035 | Time 3.479 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9990 | Test Acc 0.9953
Epoch 0036 | Time 3.479 (0.997) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9989 | Test Acc 0.9963
Epoch 0037 | Time 3.540 (0.998) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9957

参考:

1、神经常微分方程 (Neural ODE):入门教程 - 知乎 (zhihu.com)

2、rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/418547.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

StringBuffer,StringBuilder,

StringBuffer 结构示意图&#xff0c; Serializable,可以实现网络传输 package com.jshedu.StringBuffer_;/*** author Mr.jia* version 1.0*/public class StringBuffer01 {public static void main(String[] args) {/*1.在父类中AbstractStringBuilder 属性char[] value不是f…

博弈论在电动车和电网系统中分布式模型预测控制研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

Android中使用有趣的指示器和过渡自定义 Compose Pager

Android使用有趣的指示器和过渡自定义 Compose Pager google最近在compose中新增了Pager控件&#xff0c;HorizontalPager和VerticalPager。 页面之间的转换 该文档涵盖了访问页面从“对齐”位置滚动到多远的基础知识。我们可以使用这些信息来创建页面之间的过渡效果。 例…

DC:4通关详解

信息收集 漏洞发现 访问web 尝试弱口令 账号admin 可以执行ls du df看看发的包,我们是否有机会执行任意命令 发现post传参radio处可以任意命令执行 弹个shell先 提权 从vps上下载LinEnum.sh来枚举脆弱性 优化shell 现在shell就有自动补齐了 在/home/jim下发现密码字典…

cube-studio AI平台 提供开源模型示例列表(3月份)

文章目录背景AI应用商店背景 cube是腾讯音乐开源的一站式云原生机器学习平台&#xff0c;目前主要包含 1、数据管理&#xff1a;特征存储、在线和离线特征&#xff1b;数据集管理、结构数据和媒体数据、数据标签平台 2、开发&#xff1a;notebook(vscode/jupyter)&#xff1b…

【PTA天梯赛】L1-001 L1-002 L1-003 L-004 L-005 L-006 L-007 L-008 L-009 L1-010 c++

&#x1f680; 个人简介&#xff1a;CSDN「博客新星」TOP 10 &#xff0c; C/C 领域新星创作者&#x1f49f; 作 者&#xff1a;锡兰_CC ❣️&#x1f4dd; 专 栏&#xff1a;狠狠的刷题&#xff01;&#xff01;&#xff01;&#x1f308; 若有帮助&#xff0c;还请…

【Ubuntu 22.04 上配置 FTP 服务器步骤】

Ubuntu 22.04 上配置 FTP 服务器步骤 1.安装 vsftpd 软件包&#xff1a; sudo apt-get update sudo apt-get install vsftpd 2.查看vsftpd版本和状态&#xff0c;确认vsftpd安装成功和正常启动 2.修改 vsftpd 配置文件&#xff1a; sudo nano /etc/vsftpd.conf 3.在配置文件中…

Ethercat概念学习

Ethercat技术调研 背景 最近我们要基于Ethercat技术进行开发&#xff0c;首先需要了解其基本原理&#xff0c;github上看到了有相关实现&#xff0c;一起来看看吧。 Ethercat技术 速度更快 传输速率:2*100 Mbaud 全双工 高速性、高实时性 微秒级 像火车一样有帧头、帧尾&a…

如何对农田温室气体进行有效模拟?

农业是甲烷&#xff08;CH4&#xff09;、氧化亚氮&#xff08;N2O&#xff09;和二氧化碳&#xff08;CO2&#xff09;等温室气体的主要排放源&#xff0c;占全产业排放的13.5%。农田温室气体又以施肥产生的N2O和稻田生产产生的CH4为主&#xff0c;如何对农田温室气体进行有效…

计算机组成原理(四)输入/输出系统

一、概述 1.1前言 I/O设备是计算机组成原理之硬件最后的一部分。输入输出系统是计算机系统当中种类最多、功能最多、结构最复杂、构成也最多样的系统。在现代计算机系统当中&#xff0c;外部设备的总成本可以占到计算机总成本的80%以上。可以说&#xff0c;没有这些丰富多彩的外…

「Vue面试题」Vue项目中有封装过axios吗?主要是封装哪方面的?

一、axios是什么 axios 是一个轻量的 HTTP客户端 基于 XMLHttpRequest 服务来执行 HTTP 请求&#xff0c;支持丰富的配置&#xff0c;支持 Promise&#xff0c;支持浏览器端和 Node.js 端。自Vue2.0起&#xff0c;尤大宣布取消对 vue-resource 的官方推荐&#xff0c;转而推荐…

(原创)Flutter基础入门:实现各种Shape效果

前言 上一篇博客讲了Flutter的装饰器Decoration Flutter基础入门&#xff1a;装饰器Decoration 装饰器就可以帮我们实现各种Shape效果 但上篇文章并没有讲如何实现具体的Shape效果 那么具体要怎么做呢&#xff1f;这篇文章就主要讲这块 在Fluter中实现Shape效果时&#xff0c;…

Servlet(一)

目录 1.什么是Servlet 2.servlet程序 2.1 创建项目 2.2 引入依赖 2.3 创建目录 2.4 编写代码 2.5 打包程序 2.6 部署程序 2.7 验证程序 3.更简单的部署方法 3.1 安装 3.2配置 4.访问出错怎么办 4.1 404 4.2 405 4.3 500 4.4 空白页面 4.5 无法访问此页面 5.se…

Gin web框架初步认识

Goland使用及gin框架下载引入 第一次使用Goland时需要配置GOROOT、GOPATH、Go Modules 配置完成后进入面板&#xff0c;右键选择Go Modules文件&#xff0c;或者在go工作区通过命令go mod init [name]创建go mod项目。 创建完的项目一般都有go.mod文件和go.sum&#xff0c;前者…

Mysql【安装教程】

Mysql安装教程 1.安装教程 可以去官网下载这个版本的&#xff1a;mysql-installer-community-8.0.31.0 双击点开&#xff0c;选择自定义&#xff1a; 选择主键&#xff1a;左边选择之后就点蓝色按钮添加到右边去&#xff0c;next&#xff1a; 如果出现这个页面&#xff0c…

机器视觉检测系统的基本流程你知道吗

工业制造业种&#xff0c;首先我们便需要了解其基本流程&#xff0c;作为工厂信息科人员&#xff0c;我们不能只依靠视觉服务商的巡检驻检来解决问题&#xff0c;为了产线的效率提升&#xff0c;我们更多的应该培养产线技术人员&#xff0c;出现问题便可以最快速度解决问题&…

领跑新能源车市“下半场”,这家企业凭什么?

中国新能源汽车市场行至下半场&#xff0c;将围绕技术升级、产品竞争力比拼、整合淘汰等趋势快速发展。 4月7日&#xff0c;在北京水立方发布的奇瑞新能源之夜上&#xff0c;奇瑞汽车全面展示新战略、新技术、新品牌和新产品&#xff0c;宣布将以全新的技术生态加速向全球科技…

光伏电池片技术N型迭代,机器视觉检测赋能完成产量“弯道超车”

电池片是光伏发电的核心部件&#xff0c;其技术路线和工艺水平直接影响光伏组件的发电效率和使用寿命。随着硅料、硅片技术逐渐接近其升级迭代空间的瓶颈&#xff0c;电池片环节正处于技术变革期&#xff0c;是光伏产业链中迭代最快的部分。P型中PERC电池片是现阶段市场的主流产…

已知原根多项式和寄存器初始值时求LFSR的简单例子

线性反馈移位寄存器&#xff08;LFSR&#xff09;是一种用于生成伪随机数序列的简单结构。在这里&#xff0c;我们有一个四项原根多项式 p(x)1x0x21102p(x) 1 x 0x^2 110_2p(x)1x0x21102​ 和初始值 S0100S_0 100S0​100。我们将使用 LFSR 动作过程来生成一个伪随机序列。…

2023美赛春季赛_赛题原文及翻译

目录 Problem Y: Understanding Used Sailboat Prices Y题翻译&#xff1a; Problem Z: The Future of the Olympics Z题翻译&#xff1a; Problem Y: Understanding Used Sailboat Prices ​Like many luxury goods, sailboats vary in value as they age and as market c…