参照[1]中的目标对应的激活图谱
可解释性是AI模型中的一个重要话题。最近复杂的AI模型往往成为黑盒算法,使得人们很难理解AI为什么会得出这些结果。最近,我读了一篇论文《CLIP Surgery for Better Explainability with Enhancement in Open-Vocabulary Tasks》[1],主要讲述了CLIP出色的可解释性技术。尽管这篇论文展示了CLIP出色的可解释性,但很少有博客解释这个技术。因此,我将在本文中介绍CLIP_Surgery的架构及其应用。
目录:
- 快速回顾一下CLIP
- 解释CLIP Surgery算法流程
- 应用:检验在处理真实世界数据的能力,以及与Segment Anything的兼容性
CLIP 是 OpenAI [2] 的一个改变游戏规则的 AI。由于其独特的架构,它能够进行零样本图像识别。其架构如图所示。
来自[2]的CLIP架构图改编自[2]
(Note: I've maintained the "改编自" as it was initially suggested, but if "参考自" is preferred for a more academic tone, it can be used instead.)
CLIP 有图像和文本编码器,用于生成图像和文本嵌入。训练数据是图像和文本对,例如一张带有文本“一只狗的照片”的狗的图像。它利用对比预训练来对齐图像和文本嵌入,仅当图像和文本为配对时才进行对齐。为了更直观地理解,让我们来看一个例子。在此示例中,我们使用三个图像和文本对(如上图所示,N = 3)。
作者展示的对比性预训练图
图像和文本编码器的输出嵌入维度始终为(1, 512)。例如,我们有图像和文本嵌入,每个的维度为 (3, 512)。通过计算嵌入的余弦相似度,我们可以得到相似度矩阵,如上图所示。在对比预训练中,CLIP 利用此相似度矩阵使匹配对(对角元素)更相似,但其他对(其他元素)更不相似。具体来说,如下所示是论文 [2] 中的伪代码过程。
# image_encoder - ResNet 或 Vision Transformer
# text_encoder - CBOW 或 Text Transformer
# I[n, h, w, c] - 对齐的图像小批量
# T[n, l] - 对齐的文本小批量
# W_i[d_i, d_e] - 图像到嵌入空间的学习投影
# W_t[d_t, d_e] - 文本到嵌入空间的学习投影
# t - 学习到的温度参数
# 提取每种模式的特征表示
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# 联合多模态嵌入 [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# 成对余弦相似度的缩放 [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# 对称损失函数
标签 = np.arange(n)
loss_i = 交叉熵损失函数(logits, 标签, axis=0)
loss_t = 交叉熵损失函数(logits, 标签, axis=1)
loss = (loss_i + loss_t)/2
计算图像和文本嵌入的余弦相似度之后,他们应用交叉熵损失,让相似度矩阵的对角线元素接近1,而其他元素接近0。作者将这种计算称为对比损失函数。CLIP仅通过这种对比损失函数进行训练。
对于零样本识别,操作如下。首先,我们输入 n 个候选文本,并得到维度为 (n, 512) 的嵌入。接下来,我们计算目标图像与候选文本嵌入之间的相似度。最后,我们将选择最相似的候选文本作为类别。这么简单,不是吗?
该流程简单且直观,但需要通过数百万个图像和文本对以及数百个GPU来训练CLIP。根据原论文,他们使用了非常大的小批量大小,例如32,768,并且在592个V100 GPU上训练了18天时间。因此,许多公司将其作为基础模型使用,而不是从零开始训练。
2. CLIP Surgery算法的解释CLIP Surgery 主要是为了增强 CLIP 结果的可解释性而开发的,其目标是提高模型的透明度。更令人惊讶的是,CLIP Surgery 可以在不进行任何额外训练的情况下直接展示与标签对应的激活图。由于其出色的激活图可视化能力,这项技术可以应用于分割任务的基础模型Segment Anything。我将在后面的章节中详细介绍其应用。
研究者们深入审查了注意力层以提高可解释性,无需额外训练。请看下面的图表。
CLIP Surgery架构源自论文[1]
左边显示的是原始CLIP的注意力层,而右边显示的是CLIP Surgery的注意力层。它们表明查询-键自注意力激活与标签对应的相反语义区域,这意味着与标签相关的部分会被激活。另一方面,值-值自注意力只能专注于特定的语义区域。这意味着什么?下图展示了查询-键自注意力和值-值自注意力的激活图,这有助于我们更好地理解这些注意力机制的工作原理。
参考自论文[1]的查询-键注意力和value-value注意力机制的激活图可视化
可以看到,查询键自我注意不仅会突出显示目标标签区域,还会突出显示无关区域。相反,值值自注意力可以聚焦在对应的目标标签区域。根据实验,查询-键自我注意可能导致特征图混乱。需要注意的是,这一观察主要是基于实验结果,尚未通过数学定理证明。
另外,他们意识到在各个标签之间存在冗余特征。请参阅下图,如下所示。
CLIP激活图针对某些标签,摘自论文[1]
如你所见,冗余区域在所有标签中出现在相同位置。因此,他们想出了一个办法,可以通过移除所有标签中共同激活的部分来去除冗余。
他们是如何做到的?具体来说是,官方实现如下所示。
# 权重以限制显著类别对其他类别的影响
# (batch_size, 1, 512) @ (标签数量, 即标签的数量, 512).T = (batch_size, 1, 标签数量)
prob = image_features[:, :1, :] @ text_features.t()
# prob 的形状为 (batch_size, 1, 标签数量)
prob = (prob * 2).softmax(-1) # 这样
# w 的形状为 (batch_size, 1, 标签数量)
w = prob / prob.mean(-1, keepdim=True)
# 元素级乘积特征
# b 是批次大小
# n_t 是标签数量
# n_i 是 token 数量 (=197)
# c 是特征维度 (=512)
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
# feats 的形状为 (batch_size, n_i, n_t, c)
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
feats *= w.reshape(1, 1, n_t, 1)
# redundant_feats 的形状为 (batch_size, n_i, n_t, c)
redundant_feats = feats.mean(2, keepdim=True) # 沿着类别维度
feats = feats - redundant_feats
# 将元素级乘积特征作为余弦相似度求和
# similarity 的形状为 (batch_size, n_i, n_t)
similarity = feats.sum(-1)
为了更清楚地说明,我在代码里给每个计算都加了尺寸变化。现在,让我们一步步来分析这个问题。
第一个部分计算权重向量,以保持每个类的影响相等。首先,我们从图像嵌入中提取类别标记。在变压器架构中,类别标记是在标记维度中作为第一个标记。请注意,类别标记应该反映出所有其他标记的信息(如果您不熟悉视觉变换器,可以参考这篇博客 [5])。然后,我们计算余弦相似度并得到相似度矩阵。接下来,我们将相似度矩阵中的值转换成标签维度上的概率值,从而得到权重矩阵。
在第二个部分中,我们计算除了冗余特征外的特征矩阵。首先,我们计算图像和文本嵌入的逐元素特征图。直观来看,跨标签激活的区域在这张图中会有更高的值,如前所示的图。因此,我们可以通过计算跨标签的平均从特征矩阵中得到冗余特征。在从原始的特征矩阵中减去冗余特征后,我们可以得到纯粹的特征矩阵。
在最后一步中,通过对特征矩阵沿特征维度求和来得到相似性矩阵。
对于我们所说的特征图可视化,我们需要对相似性矩阵进行归一化、调整并插值到输入图像的大小(稍后可以通过附带的代码来检查实现过程)。下图是CLIP Surgery的结果。
CLIP 手术激活图谱改编自 [1]
你看,它能够捕捉到标签相关含义的范围。你可以感受到这种可视化的强大之处。
到目前为止,我们已经看到了CLIP Surgery的详细算法。在最后部分,我们将检查它在处理真实世界数据的能力以及它的应用。
3. 应用:验证现实世界数据的能力和Segment Anything的Points提供在最后一节中,我将引导您如何在实际数据和Segment Anything (SAM) 上应用CLIP Surgery。让我们开始吧!
设置环境
作为第一步,你需要搭建一个环境。我使用了ubuntu20.04,cuda11.7和Python3.10。首先,我用conda创建了一个虚拟环境。
conda create --name sam python==3.10 -y
conda activate sam
conda install pip
以下是可选步骤:为了避免在本地环境安装库,
检查将用于存储库的 pip 位置
which pip
# 在我的环境中,我使用 /opt/conda/envs/sam/bin/pip。
接下来,你需要按照官方指南安装Pytorch和torchvision如下所述。你可以根据所处的环境安装相应的版本即可。例如,下面的命令就是我遇到的情况。
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
接下来,你需要通过运行以下命令来安装它们:SAM仓库模型和模型权重。
pip install git+https://github.com/facebookresearch/segment-anything.git # 安装segment-anything
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth # 使用wget下载权重文件,地址为:https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
通过运行上述命令,可以安装segment-anything并下载权重文件。
你也需要安装CLIP Surgery。
运行以下命令克隆代码库:
git clone https://github.com/xmed-lab/CLIP_Surgery.git
最后,你需要安装几个库。你可以使用如下格式通过pip来安装它们:“pip install <library>。”
以下是一些常用的Python库和它们的版本:
tqdm==4.66.5
ftfy==6.2.3
matplotlib
opencv-python
regex
现在你已经完成了环境搭建。
针对 Flickr30k 数据集的 CLIP 手术功能
首先,我要用Flickr30k数据集[4]来验证CLIP Surgery处理实际数据的能力。因此,我将比较CLIP和CLIP Surgery的激活模式。我会在后面附上我用到的代码。下面是对比的结果图。
比较CLIP和CLIP Surgery的激活图谱
如您所见,纯原版的CLIP无法精确检测对象,但CLIP Surgery可以在对象存在的情况下找到与其标签匹配的对象。然而,比如当试图检测猫或植物这样的对象时,CLIP Surgery仍然存在问题。其中一个原因是后处理中的这种最小最大归一化方法。换句话说,当激活图中只有不相关的区域时,最小最大归一化可能会增强它们值之间的差异。为了解决这个问题,我们可以在进行最小最大归一化之前设置一个简单的阈值。在Flickr数据集的情况下,通过相似度图的直方图分析得知,相关区域的值阈值应高于0.1。结果如图所示。
比较修改后的CLIP和CLIP Surgery的激活图
多亏了这个阈值,我们可以去掉那些不相关的部分。阈值可以根据数据集调整;因此,我们应该通过查看直方图找到合适的值。
Segment Anything的点来源
CLIP Surgery 可应用于 Segment Anything 的点提供者部分,因为其激活图可视化非常精确。需要说明的是,SAM 是 Meta 在 2023 年开发的一种分割基础模型。下图展示了该模型的架构。
来自论文[3]的Segment Anything架构
SAM的分割能力非常强大。然而,它并没有通过带有标签的分割数据集标注进行训练,因此当我们需要指定具体对象时,需要提供一些点、边界框或掩码。正如你所猜测的,这样的标注工作相当费时。在这里,CLIP Surgery 可以帮助我们自动找到这些点。接下来我们看看如何在实际操作中将 CLIP Surgery 和 SAM 结合起来。
为了给SAM生成点,我们对激活图进行降采样,并按值排序以选择相关区域。在官方实现中,他们使用7x7的激活图来找到最相关的区域。当目标对象不存在时也会遇到问题,所以我稍微调整了原始实现,添加了一个阈值。结果如下。
CLIP Surgery 抽取的结果
橙色的点指的是与标签相关的点,而蓝色的点代表标签的对立点。如图所示,它能够相当准确地检测目标标签的位置。请注意,点的精度由CLIP决定。因此,如果CLIP无法理解目标,它将无法准确给出目标点的位置。我将附上用于此应用程序的Jupyter笔记本。
这就结束了这篇博客。感谢您花时间阅读这篇博客!
参考[1] 李 Y., 王 H. 等.,增强开放词汇任务可解释性的CLIP手术,arXiv预印本
[2] Radford, A., Kim, J., 等, 从自然语言监督中学习可转移的视觉模型, arXiv
[3] Kirillov, A., Ravi, N., Mintun, E., Mao, H., 等, Segment Anything, arXiv
[4] https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset (Flickr图片数据集)
[5] Callis, S., 视觉Transformer详解, 数据科学