Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

paper LFA #480

Open
junxnone opened this issue Oct 11, 2024 · 0 comments
Open

paper LFA #480

junxnone opened this issue Oct 11, 2024 · 0 comments

Comments

@junxnone
Copy link
Owner

junxnone commented Oct 11, 2024

LFA

  • Linear Feature Alignment 重新对齐视觉语言特征
  • 通过最小二乘法问题的闭式解进行初始化,然后通过最小化重排序损失进行迭代更新

image

Arch

Approximating Soft Prompts with a Linear Transformation

  • $Y \in \mathbb{R}^{C ×d}$ 是以矩阵形式表示的类名嵌入
  • $Y' \in \mathbb{R}^{C ×d}$ 是通过软提示学习获得的类原型
  • 目标是通过求解以下最小二乘问题来学习一个试图近似提示学习的线性变换 $W \in \mathbb{R}^{d ×d}$
    • $Y \stackrel{w}{\to } Y'$$min _{W \in \mathbb{R}^{d × d}}\parallel Y W - Y'\parallel _{F}^{2}, (1)$
    • $|\cdot|_{F}$ 是弗罗贝尼乌斯范数。

LFA

  • 目标是学习一个线性映射 W,用于将图像嵌入 X 与它们对应的文本类别原型 Y 对齐,即 $X\stackrel{w}{\to}Y$
  • 一旦学习到W,为了对一个新样本 x 进行分类,我们可以从 softmax $(xW\cdot Y^{\top}/\tau)$ 中获得其 c 类概率
    • τ 是温度参数
      -为了学习W,线性特征对齐(LFA)首先使用最小二乘优化问题的闭式解进行初始化,然后最小化一个重排序损失来细化初始解。
    • $X\in\mathbb{R}^{N×d}$ 是由 CLIP 图像编码器生成的 N 个样本的图像嵌入
    • $Y\in\mathbb{R}^{C×d}$ 是使用 CLIP 文本编码器(即没有任何提示)对类别名称进行编码后得到的 C 个类别原型
    • $P\in P_{N×C}$ 是一个分配矩阵,它将每个类别原型分配给其对应的图像嵌入
    • $P_{N×C}={P\in{0,1}^{N×C}, P1_{C}=1_{N}}$
    • 作为二进制置换矩阵的集合,它将 N 行中的每一行映射到 C 列中的一列,即输入图像映射到它们对应的类别。
    • 在有监督的设置中,当我们有 N 个(图像-类别)对时,P 是堆叠的 N 个 C 维的独热向量。
  • 目标是找到一个最优的线性映射,以弥合模态差距并将每个图像嵌入与其文本类别原型对齐。
  • 通过求解以下最小二乘问题来学习线性映射: $\underset{W \in \mathbb{R}^{d × d}}{argmin}\parallel X W-P Y\parallel _{F}^{2}.$
  • 即在实数空间 $\mathbb{R}^{d × d}$ 中寻找一个线性映射 $W$ ,使得 $\parallel X W-P Y\parallel _{F}^{2}$ 最小。其中, $X$ 是图像嵌入矩阵, $Y$ 是文本类别原型矩阵, $P$ 是分配矩阵。目标是找到一个最优的线性映射,来弥合模态差距,让每个图像嵌入与对应的文本类别原型对齐。
def LFA(img_feats, cls_prototypes, labels, beta, test_img_features):
    """
    img_feats: [N, d] 
    cls_prototypes: [C, d] 
    labels: [N] 
    test_img_features: [M, d]
    # N: number of training image features 
    # C: number of classes 
    # d: features dimensionality 
    # M: number of test image features 
    """
    # One-to-one matchings 
    text_feats = cls_prototypes[labels]
    # Orthogonal Procrustes 
    u, _, v = torch.svd(img_feats.T @ text_feats) 
    W_op = u @ v.T
    # Beta-Procrustes 
    identity = torch.eye(d) 
    W_beta = W_op - (W_op - identity) * beta
    # Refine 
    W = adaptive_rerank_refine(W_beta)
    test_logits = (test_img_features @ W) @ cls_prototypes.T 
    test_preds = test_logits.argmax(-1)
    return test_preds
以下是对这段代码的解释:

**函数功能**:

这个函数实现了线性特征对齐(LFA)算法
用于对图像特征进行分类预测,包括有监督学习的情况,
即使用有标注数据进行训练和对测试图像特征进行预测。

**参数解释**:
- `img_feats`:形状为`[N, d]`的张量,表示训练图像特征,其中`N`是训练图像特征的数量,`d`是特征的维度。
- `cls_prototypes`:形状为`[C, d]`的张量,表示类别原型,其中`C`是类别的数量,`d`是特征的维度。
- `labels`:形状为`[N]`的张量,表示训练图像的标签。
- `beta`:超参数,用于控制`Beta-Procrustes`步骤中的插值程度。
- `test_img_features`:形状为`[M, d]`的张量,表示测试图像特征,其中`M`是测试图像特征的数量,`d`是特征的维度。

**函数主体解释**:

1. `# One-to-one matchings text_feats = cls_prototypes[labels]`:
   - 根据训练图像的标签,从类别原型中选取对应的文本特征,得到与每个训练图像特征对应的文本特征。

2. `# Orthogonal Procrustes u, _, v = torch.svd(img_feats.T @ text_feats)`:
   - 对训练图像特征和对应的文本特征的乘积进行奇异值分解(SVD),得到`u`、奇异值和`v`。

3. `W_op = u @ v.T`:
   - 计算正交普罗克汝斯忒斯(Orthogonal Procrustes)解,即`u`和`v`的乘积。

4. `# Beta-Procrustes identity = torch.eye(d)`:
   - 创建一个维度为`d`的单位矩阵。

5. `W_beta = W_op - (W_op - identity) * beta`:
   - 计算`Beta-Procrustes`解,通过插值将正交解向单位矩阵靠近,控制过拟合。

6. `# Refine W = adaptive_rerank_refine(W_beta)`:
   - 使用自适应重排序(ARerank)损失对`W_beta`进行细化,得到最终的映射矩阵`W`。

7. `test_logits = (test_img_features @ W) @ cls_prototypes.T`:
   - 对测试图像特征进行变换,然后与类别原型进行矩阵乘法,得到测试图像的对数几率。

8. `test_preds = test_logits.argmax(-1)`:
   - 取对数几率在最后一个维度上的最大值索引,得到测试图像的预测类别。

9. `return test_preds`:
   - 返回测试图像的预测类别。

Reference

@junxnone junxnone changed the title x LFA paper LFA Oct 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant