
ViT: Image Patching and Embedding Explained
The Vision Transformer (ViT) is a deep-learning computer-vision model that understands image content by splitting an image into patches and relating them with self-attention. The core idea is to break the image into many small regions (tokens) so the model can reason about the image at a semantic level; ViT-style backbones also power pixel-level tasks such as semantic segmentation.
Main characteristics:

- Pixel-level prediction: ViT-based segmentation models can classify the image down to individual pixels
- Context modeling: self-attention captures the spatial relationships between all parts of the image
- Multi-scale features: visual features at different scales can be processed together

Typical applications:

- Medical imaging: precisely segmenting tumors or organ structures
- Autonomous driving: recognizing roads, pedestrians, and other vehicles
- Remote sensing: distinguishing different land-cover types
- Industrial inspection: locating defect regions on products

Example workflow:

1. Preprocess the input image (resizing, normalization, etc.)
2. Extract multi-level features with the encoder
3. Use attention to model spatial relationships between features
4. Map the features back to the original resolution with a decoder
5. Output a class prediction for every pixel

Compared with traditional computer-vision methods, the Vision Transformer offers:

- Better modeling of long-range dependencies
- More accurate boundary delineation
- Stronger generalization
- End-to-end training

State-of-the-art implementations include the Vision Transformer (ViT) and the Swin Transformer, which have shown excellent performance on benchmarks such as ImageNet.
The Embedding Layer
The first step is to split the image into patches.

Take a 224×224-pixel image as an example:

- Input shape: [B, 3, 224, 224]
- With a patch_size of 16×16, the image is cut into 224/16 = 14 patches per side, giving 14×14 = 196 patches
- Output shape: [B, 196, 768]

Code implementation:
import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, patch_size=16, image_size=224, num_channels=3, hidden_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim
        # number of patches: (224 // 16) ** 2 = 196
        self.num_patches = (image_size // patch_size) * (image_size // patch_size)
        # patch embedding implemented as a convolution whose kernel and stride
        # both equal patch_size, so each patch is projected independently
        self.conv_proj = nn.Conv2d(
            in_channels=num_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
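Setting kernel_size = stride = patch_size makes this convolution equivalent to cutting the image into non-overlapping 16×16 patches and pushing each one through the same linear projection. A quick numerical check of that equivalence (a sketch; the tensor names here are illustrative, not part of the model):

conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
img = torch.randn(1, 3, 224, 224)

# Path 1: convolution, then flatten the 14x14 spatial grid into tokens
a = conv(img).flatten(2).transpose(1, 2)               # [1, 196, 768]

# Path 2: explicit patches, then the conv weights used as a linear layer
p = img.unfold(2, 16, 16).unfold(3, 16, 16)            # [1, 3, 14, 14, 16, 16]
p = p.permute(0, 2, 3, 1, 4, 5).reshape(1, 196, 768)   # (C, pH, pW) flattened
b = p @ conv.weight.reshape(768, -1).t() + conv.bias   # [1, 196, 768]

print(torch.allclose(a, b, atol=1e-5))                 # True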
After the projection, each of the 196 patches corresponds to a 768-dimensional feature vector (matching the 16×16×3 = 768 raw pixel values per patch). Flattening the 14×14 spatial grid and transposing yields a token sequence of shape [B, 196, 768]; at this point the original image has been converted into tokens:
x = self.conv_proj(x)             # [B, 768, 14, 14]
x = x.flatten(2).transpose(1, 2)  # [B, 196, 768]
Next a cls (class) token is prepended. During training it attends to, and aggregates information from, all the other tokens, so the classification head later only needs this single token, saving computation. The shape becomes [B, 197, 768]:
# defined in __init__: a learnable classification token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

# in forward: broadcast it over the batch and prepend it to the patch tokens
n = x.shape[0]
batch_class_token = self.class_token.expand(n, -1, -1)  # [B, 1, hidden_dim]
x = torch.cat([batch_class_token, x], dim=1)            # [B, num_patches + 1, hidden_dim]
A positional encoding is then added so the model knows where each patch sat in the original image; the learnable cls token gets a position of its own in the same way:
# defined in __init__: learnable position embeddings, one per token (cls included)
self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_dim))

# in forward: added element-wise to the token sequence
x = x + self.pos_embedding
return x
Finally, this token sequence is fed into the Transformer encoder.
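For reference, here are the scattered snippets above assembled into one runnable module (a minimal sketch: dropout and weight initialization are omitted):

import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, patch_size=16, image_size=224, num_channels=3, hidden_dim=768):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        self.conv_proj = nn.Conv2d(num_channels, hidden_dim,
                                   kernel_size=patch_size, stride=patch_size)
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_dim))

    def forward(self, x):
        x = self.conv_proj(x)                         # [B, 768, 14, 14]
        x = x.flatten(2).transpose(1, 2)              # [B, 196, 768]
        cls = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)                # [B, 197, 768]
        return x + self.pos_embedding

emb = Embedding()
print(emb(torch.randn(2, 3, 224, 224)).shape)         # torch.Size([2, 197, 768])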
The Encoder
Main modules (the code below follows torchvision's ViT implementation):
from torchvision.ops.misc import MLP

class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
        # initialize every linear layer: Xavier-uniform weights, near-zero biases
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            # Old checkpoints name the layers linear_1/linear_2; in the new MLP
            # they live at sequential indices 0 and 3, so remap the keys.
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
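With in_dim = 768 and mlp_dim = 3072 (the ViT-Base sizes), the parent MLP class expands to roughly the following stack, which is why the state-dict remapping above targets sequential indices 0 and 3 (a sketch of the resulting layers, not torchvision's literal code):

nn.Sequential(
    nn.Linear(768, 3072),   # index 0 -> keys "0.weight" / "0.bias"
    nn.GELU(),
    nn.Dropout(dropout),
    nn.Linear(3072, 768),   # index 3 -> keys "3.weight" / "3.bias"
    nn.Dropout(dropout),
)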
from functools import partial
from typing import Callable

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input           # first residual connection

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y            # second residual connection
(1) First the tokens are normalized with LayerNorm (pre-norm); the unnormalized input is kept aside for the residual connection used later:
self.ln_1 = norm_layer(hidden_dim)
(2) Multi-head self-attention then lets every token attend to all the others from several representation subspaces ("heads") in parallel:
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads,
                                            dropout=attention_dropout, batch_first=True)
(A dropout layer usually follows to help prevent overfitting.)
self.dropout = nn.Dropout(dropout)
(3) The attention output is added back to the original tokens (a residual connection), so the original information is not lost:
x = x + input
(4) The second branch then starts with another LayerNorm (the residual for this branch is taken from x):
self.ln_2 = norm_layer(hidden_dim)
(5) The MLP block performs the position-wise feed-forward transformation: two fully connected layers with a GELU activation in between (see the MLPBlock definition above):
y = self.mlp(y)
(6) Finally the second residual connection adds the original tokens back:
return x + y
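A quick shape check of the block (a minimal sketch; the hyperparameters follow ViT-Base, and a full encoder simply stacks 12 such blocks):

block = EncoderBlock(num_heads=12, hidden_dim=768, mlp_dim=3072,
                     dropout=0.1, attention_dropout=0.1)
tokens = torch.randn(2, 197, 768)   # [B, seq_len, hidden_dim]
print(block(tokens).shape)          # torch.Size([2, 197, 768])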
The MLP Head

First the cls token is taken out; thanks to self-attention it has already absorbed information from all the other tokens, so the head only needs to process this one token, which keeps the computation small (see the sketch after the code below).

The MLP head itself is:

1. A fully connected layer
2. A Tanh activation
3. A final fully connected layer that produces the output
from collections import OrderedDict

heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
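A sketch of how the head consumes the encoder output (following torchvision's ViT, where the classifier token sits at index 0; encoder_output is a hypothetical name for the [B, 197, 768] tensor produced by the encoder):

# inside the model's forward pass:
x = encoder_output[:, 0]   # take the cls token: [B, 197, 768] -> [B, 768]
logits = self.heads(x)     # [B, num_classes]

Note that in torchvision's implementation the pre_logits/Tanh pair is only present when representation_size is specified; otherwise the head is a single nn.Linear(hidden_dim, num_classes).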