分层图像金字塔变压器

news2025/2/23 1:53:07

文章来源:hierarchical-image-pyramid-transformers

2024 年 2 月 5 日

本文介绍了分层图像金字塔变换器 (HIPT),这是一种新颖的视觉变换器 (ViT) 架构,设计用于分析计算病理学中的十亿像素全幻灯片图像 (WSI)。 HIPT 利用 WSI 固有的层次结构通过自我监督学习来学习高分辨率图像表示。 HIPT 在涵盖 33 种癌症类型的大型数据集上进行预训练,并在多个幻灯片级任务中进行评估,在癌症亚型分型和生存预测方面表现出卓越的性能,展示了自我监督学习模型在捕获肿瘤微环境中关键的归纳偏差和表型方面的潜力。

1

本图展示了计算病理学中使用的全切片图像 (WSI) 的分层结构。左图显示的是多层次方法,在这种方法中,大型组织图像(150,000 x 150,000 像素)被分解成更小、更易于管理的部分:首先是显示组织表型的 4096 x 4096 区域,然后是 256 x 256 细胞组织斑块,最后是最小的 16 x 16 细胞特征。右图展示了 256 x 256 图像是如何由 256 个较小的 16 x 16 标记序列组成的,反过来,每个 256 x 256 图像又是如何成为 4096 x 4096 区域内 256 x 256 标记的更大的不连续序列的一部分。这种分层标记化方法可以处理和分析不同分辨率和比例的超大图像。

该模型由三个阶段的分层聚合组成,首先是自下而上地聚合各自 256x256 和 4096x4096 窗口中的 16x16 视觉标记,最终形成幻灯片级表示。HIPT 模型的主要组成部分可写如下:

1. 分层聚合: HIPT 在细胞、斑块和区域层面聚合视觉标记,形成幻灯片表征。这种分层方法是受自然语言处理中使用分层表示法的启发,在自然语言处理中,嵌入可以在不同层次上聚合,形成文档表示法。同样,在 WSI 的背景下,分层聚合允许模型捕捉不同粒度级别的信息,从单个细胞到更广泛的组织结构。

2. Transformer自注意力: 为了在聚合的每个阶段对视觉概念之间的重要依赖关系进行建模,HIPT 将 Transformer 自注意力调整为包络变换聚合层。这样,该模型就能捕捉视觉标记之间的复杂关系,并学习能编码图像中局部和全局上下文的表征。

3. 预训练和自我监督学习: HIPT 采用自我监督学习的方式对 33 种癌症类型的千兆像素 WSI 大数据集进行预训练。该模型利用两个层次的自我监督学习来学习高分辨率图像表征,并利用学生-教师知识提炼来对每个聚合层进行预训练,对大至 4096x4096 的区域进行自我监督学习。

4. 性能和应用: 研究结果表明,采用分层预训练的 HIPT 在幻灯片级任务上的表现优于目前最先进的方法。该模型的性能在包括癌症亚型和生存预测在内的 9 项幻灯片级任务上进行了评估,并显示其在捕捉组织微环境中更广泛的预后特征方面表现出色。

2

图中从左到右显示了三个聚合级别:

  1. 细胞级聚合: 单个细胞由 16 px tokens表示,然后使用 ViT256-16 模型将其聚合为片段级表示,再进行全局池化以获得单一矢量表示。
  2. 斑块级聚合: 使用专为 256 px 输入设计的更大 ViT 变体来处理 256 px 补丁,然后再次使用池化层将补丁级特征汇总为区域级表示。
  3. 区域级聚合: 最后,对 4096 px 的区域进行聚合,这一次使用的是将整个区域作为输入的 ViT,从而形成一个全局注意力汇集层,提供幻灯片级表示。

这一分层过程将问题分解为易于处理的部分,并关注从细胞到组织结构等不同层次的细节,从而使模型能够处理规模巨大的 WSI。

下面的脚本利用了专门用于高分辨率图像分析的视觉转换器(ViTs),并结合了几种先进的功能和技术:

1. 截断法线初始化: 这是一种用于初始化神经网络权重的技术,可避免与平均值产生较大偏差,从而确保早期训练阶段的稳定性。

2. Drop Path: 一种正则化方法,在训练过程中随机丢弃网络中的路径,通过模拟更薄的网络来提高泛化效果,类似于dropout,但针对的是残余连接。

3. 多层感知器(MLP)模块: 定义一个简单的双层 MLP,具有 GELU 激活函数和滤除功能,用于在转换器模块中处理特征。

4. 注意机制:采用可选偏置和缩放的自注意机制,这对捕捉输入数据中的全局依赖性至关重要。

5. Transformer模块: 将规范层、注意机制和 MLP 组合成一个内聚块,并可选择路径剔除进行正则化。

6. VisionTransformer4K:Vision Transformer 的专用版本,专为超高分辨率图像而设计,采用了位置嵌入插值等技术,以适应不同的图像尺寸,其结构也针对处理大规模图像进行了优化。

7. 实用功能: 包括用于截断法线权重初始化、下落路径模拟和参数计算的函数,以帮助进行模型设置和分析。

import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models
import vision_transformer as vits
from vision_transformer import DINOHead
import math
from functools import partial
import torch
import torch.nn as nn
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()
        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class VisionTransformer4K(nn.Module):
    """ Vision Transformer 4K """
    def __init__(self, num_classes=0, img_size=[224], input_embed_dim=384, output_embed_dim = 192,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, num_prototypes=64, **kwargs):
        super().__init__()
        embed_dim = output_embed_dim
        self.num_features = self.embed_dim = embed_dim
        self.phi = nn.Sequential(*[nn.Linear(input_embed_dim, output_embed_dim), nn.GELU(), nn.Dropout(p=drop_rate)])
        num_patches = int(img_size[0] // 16)**2
        print("# of Patches:", num_patches)
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // 1
        h0 = h // 1
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def prepare_tokens(self, x):
        #print('preparing tokens (after crop)', x.shape)
        self.mpp_feature = x
        B, embed_dim, w, h = x.shape
        x = x.flatten(2, 3).transpose(1,2)
        x = self.phi(x)
        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)
    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]
    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)
    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
    
def vit4k_xs(patch_size=16, **kwargs):
    model = VisionTransformer4K(
        patch_size=patch_size, input_embed_dim=384, output_embed_dim=192,
        depth=6, num_heads=6, mlp_ratio=4, 
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

下面的代码脚本概述了加载和评估用于图像分析的 Vision Transformer (ViT) 模型的实现过程,该模型专为计算病理学中的高分辨率图像而设计。它定义了以下功能

1. 加载预训练的 ViT 模型(`get_vit256` 和 `get_vit4k`),并为不同的架构和设备设置提供选项,在不进行梯度计算的评估模式下对其进行初始化。

2. 为模型评估应用变换 (`eval_transforms`),以特定的平均值和标准偏差对图像进行归一化处理。

3. 将成批的图像张量转换为单个 PIL 图像(`roll_batch2img`)或 numpy 数组(`tensorbatch2im`),便于处理图像数据,以实现可视化或进一步处理。

### Dependencies
# Base Dependencies
import argparse
import colorsys
from io import BytesIO
import os
import random
import requests
import sys
# LinAlg / Stats / Plotting Dependencies
import cv2
import h5py
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 
from scipy.stats import rankdata
import skimage.io
from skimage.measure import find_contours
from tqdm import tqdm
import webdataset as wds
# Torch Dependencies
import torch
import torch.multiprocessing
import torchvision
from torchvision import transforms
from einops import rearrange, repeat
torch.multiprocessing.set_sharing_strategy('file_system')
# Local Dependencies
import vision_transformer as vits
import vision_transformer4k as vits4k
def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cuda:0')):
    r"""
    Builds ViT-256 Model.
    
    Args:
    - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
    - arch (str): Which model architecture.
    - device (torch): Torch device to save model.
    
    Returns:
    - model256 (torch.nn): Initialized model.
    """
    
    checkpoint_key = 'teacher'
    device = torch.device("cpu")
    model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
    for p in model256.parameters():
        p.requires_grad = False
    model256.eval()
    model256.to(device)
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model256.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
        
    return model256
def get_vit4k(pretrained_weights, arch='vit4k_xs', device=torch.device('cuda:1')):
    r"""
    Builds ViT-4K Model.
    
    Args:
    - pretrained_weights (str): Path to ViT-4K Model Checkpoint.
    - arch (str): Which model architecture.
    - device (torch): Torch device to save model.
    
    Returns:
    - model256 (torch.nn): Initialized model.
    """
    
    checkpoint_key = 'teacher'
    device = torch.device("cpu")
    model4k = vits4k.__dict__[arch](num_classes=0)
    for p in model4k.parameters():
        p.requires_grad = False
    model4k.eval()
    model4k.to(device)
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model4k.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
        
    return model4k
def eval_transforms():
 """
 """
 mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
 eval_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = mean, std = std)])
 return eval_t
def roll_batch2img(batch: torch.Tensor, w: int, h: int, patch_size=256):
 """
 Rolls an image tensor batch (batch of [256 x 256] images) into a [W x H] Pil.Image object.
 Args:
  batch (torch.Tensor): [B x 3 x 256 x 256] image tensor batch.
 Return:
  Image.PIL: [W x H X 3] Image.
 """
 batch = batch.reshape(w, h, 3, patch_size, patch_size)
 img = rearrange(batch, 'p1 p2 c w h-> c (p1 w) (p2 h)').unsqueeze(dim=0)
 return Image.fromarray(tensorbatch2im(img)[0])
def tensorbatch2im(input_image, imtype=np.uint8):
    r""""
    Converts a Tensor array into a numpy image array.
    
    Args:
        - input_image (torch.Tensor): (B, C, W, H) Torch Tensor.
        - imtype (type): the desired type of the converted numpy array
        
    Returns:
        - image_numpy (np.array): (B, W, H, C) Numpy Array.
    """
    if not isinstance(input_image, np.ndarray):
        image_numpy = input_image.cpu().float().numpy()  # convert it into a numpy array
        #if image_numpy.shape[0] == 1:  # grayscale to RGB
        #    image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

该脚本定义了 HIPT_4K 模型,整合了用于处理高分辨率图像的视觉转换器模型。它为 256x256 和 4K 分辨率加载预先训练好的 ViT 模型,并将其分层应用于输入图像。这一过程包括将输入图像裁剪成 256x256 的补丁,使用 ViT_256 从每个补丁中提取特征,然后将这些特征输入 ViT_4K,以获得全局表示。这种分层方法能有效处理非方形的高分辨率图像,优化局部和全局尺度的详细特征提取,与本文利用分层结构进行图像分析的方法相一致。

import torch
from einops import rearrange, repeat
from HIPT_4K.hipt_model_utils import get_vit256, get_vit4k
class HIPT_4K(torch.nn.Module):
    """
    HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with 
    [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens.
    """
    def __init__(self, 
        model256_path: str = 'path/to/Checkpoints/vit256_small_dino.pth',
        model4k_path: str = 'path/to/Checkpoints/vit4k_xs_dino.pth', 
        device256=torch.device('cuda:0'), 
        device4k=torch.device('cuda:1')):
        super().__init__()
        self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
        self.device256 = device256
        self.device4k = device4k
        self.patch_filter_params = patch_filter_params
 
    def forward(self, x):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256).
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in the batch.
        4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".)
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.
        Args:
          - x (torch.Tensor): [1 x C x W' x H'] image tensor.
        Return:
          - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        """
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)                    # 1. [1 x 3 x W x H].
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)           # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')    # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256)
        features_cls256 = []
        for mini_bs in range(0, batch_256.shape[0], 256):                       # 3. B may be too large for ViT_256. We further take minibatches of 256.
            minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True)
            features_cls256.append(self.model256(minibatch_256).detach().cpu()) # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches.
        features_cls256 = torch.vstack(features_cls256)                         # 3. [B x 384], where 384 == dim of ViT-256 [ClS] token.
        features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0,1).transpose(0,2).unsqueeze(dim=0) 
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)  # 4. [1 x 384 x w_256 x h_256]
        features_cls4k = self.model4k.forward(features_cls256)                  # 5. [1 x 192], where 192 == dim of ViT_4K [ClS] token.
        return features_cls4k

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面经总结系列(二): 面壁智能大模型算法工程师

&#x1f468;‍&#x1f4bb;作者简介&#xff1a; CSDN、阿里云人工智能领域博客专家&#xff0c;新星计划计算机视觉导师&#xff0c;百度飞桨PPDE&#xff0c;专注大数据与AI知识分享。✨公众号&#xff1a;GoAI的学习小屋 &#xff0c;免费分享书籍、简历、导图等&#xf…

Mysql基础篇(一)Mysql概述

基本概念 数据库(DataBase,DB) 数据库的定义 按照数据结构来组织、存储和管理数据的仓库。 严格意义上来说&#xff0c;数据库是一个实体&#xff0c;它是能够合理保管数据的“仓库”&#xff0c;用户在该“仓库”中存放要管理的事务数据&#xff0c;“数据”和“库”两个概念…

HTML5+CSS3小实例:无限循环loading动画

实例:无限循环loading动画 技术栈:HTML+CSS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-sc…

大数据分析入门之10分钟掌握GROUP BY语法

前言 书接上回大数据分析入门10分钟快速了解SQL。 本篇将会进一步介绍group by语法。 基本语法 SELECT column_name, aggregate_function(column_name) FROM table_name GROUP BY column_name HAVING condition假设我们有students表&#xff0c;其中有id,grade_number,class…

Matlab画箱线图

⚠申明&#xff1a; 未经许可&#xff0c;禁止以任何形式转载&#xff0c;若要引用&#xff0c;请标注链接地址。 全文共计3077字&#xff0c;阅读大概需要3分钟 &#x1f308;更多学习内容&#xff0c; 欢迎&#x1f44f;关注&#x1f440;【文末】我的个人微信公众号&#xf…

内网安全-代理Socks协议路由不出网后渗透通讯CS-MSF控制上线简单总结

我这里只记录原理&#xff0c;具体操作看文章后半段或者这篇文章内网渗透—代理Socks协议、路由不出网、后渗透通讯、CS-MSF控制上线_内网渗透 代理-CSDN博客 注意这里是解决后渗透通讯问题&#xff0c;之后怎么提权&#xff0c;控制后面再说 背景 只有win7有网&#xff0c;其…

Unity Trail Renderer入门

概述&#xff1a; 在项目的开发过程中&#xff0c;一定有时候需要炫酷的尾迹效果&#xff0c;那接下来这部分的内容&#xff0c;一定不要错过&#xff01; Trail Renderer&#xff08;尾迹渲染&#xff09; Time&#xff1a;尾迹存在的时间&#xff0c;时间越长尾迹存在的越久…

无人机+无人车:自组网协同技术及应用前景详解

无人车&#xff0c;也被称为自动驾驶汽车、电脑驾驶汽车或轮式移动机器人&#xff0c;是一种通过电脑系统实现无人驾驶的智能汽车。这种汽车依靠人工智能、视觉计算、雷达、监控装置和全球定位系统协同合作&#xff0c;使得电脑可以在没有任何人类主动操作的情况下&#xff0c;…

总分420+专业140+哈工大哈尔滨工业大学803信号与系统和数字逻辑电路考研电子信息与通信工程,真题,大纲,参考书。

考研复习一路走来&#xff0c;成绩还是令人满意&#xff0c;专业803信号和数电140&#xff0c;总分420&#xff0c;顺利上岸&#xff0c;总结一下自己这一年复习经历&#xff0c;希望大家可以所有参考&#xff0c;这一年复习跌跌拌拌&#xff0c;有时面对压力也会焦虑&#xff…

【算法系列】字符串

目录 leetcode题目 一、最长公共前缀 二、最长回文子串 三、二进制求和 四、字符串相加 五、字符串相乘 六、仅仅反转字母 七、字符串最后一个单词的长度 八、验证回文串 九、反转字符串 十、反转字符串 II 十一、反转字符串中的单词 III leetcode题目 一、最长公…

[Kubernetes] 安装KubeSphere

选择4核8G&#xff08;master&#xff09;、8核16G&#xff08;node1&#xff09;、8核16G&#xff08;node2&#xff09; 三台机器&#xff0c;按量付费进行实验&#xff0c;CentOS7.9安装Docker安装Kubernetes安装KubeSphere前置环境: nfs和监控安装KubeSphere masternode1no…

从零开始学AI绘画,万字Stable Diffusion终极教程(三)

【第3期】Lora模型 欢迎来到SD的终极教程&#xff0c;这是我们的第三节课 这套课程分为六节课&#xff0c;会系统性的介绍sd的全部功能&#xff0c;让你打下坚实牢靠的基础 1.SD入门 2.关键词 3.Lora模型 4.图生图 5.controlnet 6.知识补充 在SD里面&#xff0c;有一个…

基础I/O--文件系统

文章目录 回顾C文件接口初步理解文件理解文件使用和并认识系统调用open概述标记位传参理解返回值 closewriteread总结 文件描述符fd0&1&2理解 回顾C文件接口 C代码&#xff1a; #include<stdio.h> int main() { FILE *fpfopen("log.txt",&…

基于Pytorch深度学习——GPU安装/使用

本文章来源于对李沐动手深度学习代码以及原理的理解&#xff0c;并且由于李沐老师的代码能力很强&#xff0c;以及视频中讲解代码的部分较少&#xff0c;所以这里将代码进行尽量逐行详细解释 并且由于pytorch的语法有些小伙伴可能并不熟悉&#xff0c;所以我们会采用逐行解释小…

用git上传本地文件到github

两种方式&#xff1a;都需要git软件&#xff08;1&#xff09;VScode上传 &#xff08;2&#xff09;直接命令行&#xff0c;后者不需要VScode软件 &#xff08;1&#xff09;vscode 上传非常方便&#xff0c;前提是下载好了vscode和git软件 1 在项目空白处右击&#xff0c;弹…

字符函数与字符串函数(2)

遇见她如春水映莲花 字符函数与字符串函数&#xff08;2&#xff09; 前言一、strcatstrncat 二、strcmpstrncmp在这里插入图片描述 三、strstr四、strtok五、strerror总结 前言 根据上期字符函数与字符串函数我们可以了解到字符函数与个别字符串函数的用法&#xff0c; 那么接…

手写一个uart协议——rs232

先了解一下关于uart和rs232的基础知识 文章目录 一、RS232的回环测试1.1模块整体架构1.2 rx模块设计1.2.1 波形设计1.2.2代码实现与tb1.2.4 仿真 1.3 tx模块设计1.3.1 波形设计1.3.2 代码实现与tb1.3.4 顶层设计1.3.3 仿真 本篇内容&#xff1a; 一、RS232的回环测试 上位机…

改变视觉创造力:图像合成中基于样式的生成架构的影响和创新

原文地址&#xff1a;revolutionizing-visual-creativity-the-impact-and-innovations-of-style-based-generative 2024 年 4 月 30 日 介绍 基于风格的生成架构已经开辟了一个利基市场&#xff0c;它将机器学习的技术严谨性与类人创造力的微妙表现力融为一体。这一发展的核…

4.3 JavaScript变量

4.3.1 变量的声明 JavaScript是一种弱类型的脚本语言&#xff0c;无论是数字、文本还是其他内容&#xff0c;统一使用关键词var加上变量名称进行声明&#xff0c;其中关键词var来源于英文单词variable&#xff08;变量&#xff09;的前三个字母。 可以在声明变量的同时对其指定…

汇川AM400PLC通过EtherCAT总线控制禾川X3E伺服使能和点动控制

进行通信之前需要安装禾川X3E的XML文件&#xff0c;具体方法如下&#xff1a; 1、汇川AM400PLC和X3E通信配置 汇川AM400PLC和禾川X3E伺服EtherCAT通信-CSDN博客文章浏览阅读29次。1、汇川H5UPLC和X3E伺服EtherCAT总线控制汇川H5U PLC通过EtherCAT总线控制SV660N和X3E伺服_伺服…