LK 博客
VIT:图像分割与嵌入技术详解
大数据
约 1 分钟阅读 1 赞 0 条评论 鸿蒙黑体

VIT:图像分割与嵌入技术详解

zihan
王子翰 @zihan
累计点赞 1 登录后每个账号只能点一次
内容长度 0 正文词元数
正文
目录会跟随阅读位置移动。
阅读进度

visual transform(视觉变换)是一种基于深度学习的计算机视觉模型,它通过图像分割技术使计算机能够深入理解图像内容。这种模型的核心思想是将图像分割为多个有意义的区域或对象,从而实现对图像内容的语义级理解。

主要特点包括:

像素级分割能力:能够精确到单个像素级别进行图像分割 上下文理解:通过自注意力机制捕捉图像中各部分之间的空间关系 多尺度特征提取:同时处理不同尺度的视觉特征 典型应用场景:

医学影像分析:精确分割肿瘤区域或器官结构 自动驾驶:识别道路、行人和其他车辆 遥感图像处理:区分不同地物类型 工业质检:检测产品缺陷区域 工作流程示例:

输入图像预处理(尺寸调整、归一化等) 通过编码器提取多层级特征 使用注意力机制建立特征间的空间关系 解码器将特征映射回原始分辨率 输出每个像素的类别预测 相比传统计算机视觉方法,visual transform的优势在于:

更好的长距离依赖建模能力 更准确的边界划分 更强的泛化性能 端到端的训练方式 目前最先进的实现包括Vision Transformer (ViT)、Swin Transformer等架构,这些模型在ImageNet等基准测试中展现了卓越的性能。

Embedding层

将图像分割处理

以224*224像素的图像为例

输入维度为【B,3,224,224】

采用16×16的patch_size进行分割,得到14*14=196个图像块 输出维度为【B,196,768】 代码实现

class Embedding(nn.Module):
    def __init__(self, patch_size=16, image_size=224, num_channels=3, hidden_dim=768):
        super().__init__()

        self.patch_size = patch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim

        # patch 数量
        self.num_patches = (image_size // patch_size) * (image_size // patch_size)

        # 用卷积实现 patch embedding
        self.conv_proj = nn.Conv2d(
            in_channels=num_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

将196个图像块展平后,得到16163=768维的特征向量 最终得到维度为[B,196,768]的token序列 至此,原始图像已被转换为token表示

x = x.flatten(2)     

接着增加cls(class),作用为在训练时连接其他token,保存信息,节省head层的计算量 这时,尺寸为 【B,197,768】

self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))


n = x.shape[0]
batch_class_token = self.class_token.expand(n, -1, -1)   # [B, 1, hidden_dim]
x = torch.cat([batch_class_token, x], dim=1)    # [B, num_patches+1, hidden_dim]

加入位置编码,使计算机知道分割后的图片原本的位置,同时让计算机知道cls是可学习的token

# 位置编码
self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_dim))

x = x + self.pos_embedding

return x

最后将这串 token 输入 Transformer Encoder。

Ecoder部分

主要模块

class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

(1)首先由layerNorm进行归一化,并利用残差连接保存原token

self.ln_1 = norm_layer(hidden_dim)

(2)接着进行多头自注意力机制,将token用不同角度的自注意力机制进行思考,观察

self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, 

dropout=attention_dropout, batch_first=True)

(一般还会有一个dropout层进行限制防止过拟合)

self.dropout = nn.Dropout(dropout)

(3)最后和原token拼接,避免将原token磨掉

x = x + input

(4)接着继续残差连接和归一化

self.ln_1 = norm_layer(hidden_dim)

(5)最后MLP模块进行全连接,进行前向传播

class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

(6)拼接原token

return x + y

MLP head部分

先取出之前的cls,因为自注意力的原因已经充分吸收其他token的信息了,可以减少计算量

接着就是MLP head

先全连接

在用Tanh激活函数

最后全连接

输出结果

heads_layers: OrderedDict[str, nn.Module] = OrderedDict()

heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            
heads_layers["act"] = nn.Tanh()
            
heads_layers["head"] = nn.Linear(representation_size, num_classes)

self.heads = nn.Sequential(heads_layers)

作者名片

zihan
王子翰
@zihan

这个作者暂时还没有填写个人简介。

评论区
文章作者和管理员都可以管理这里的评论。
0 条评论
登录后即可参与评论。 去登录
还没有评论,欢迎留下第一条交流内容。