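```python
import torch


def LFA(img_feats, cls_prototypes, labels, beta, test_img_features):
    """
    img_feats: [N, d]
    cls_prototypes: [C, d]
    labels: [N]
    test_img_features: [M, d]

    N: number of training image features
    C: number of classes
    d: feature dimensionality
    M: number of test image features
    """
    d = img_feats.shape[1]  # feature dimensionality (undefined in the original snippet)

    # One-to-one matchings: pair each training image feature with its class prototype
    text_feats = cls_prototypes[labels]  # [N, d]

    # Orthogonal Procrustes: W_op = U V^T, where U S V^T = SVD(X^T Y)
    # (torch.svd returns V itself, not V^T)
    u, _, v = torch.svd(img_feats.T @ text_feats)
    W_op = u @ v.T

    # Beta-Procrustes: W_beta = (1 - beta) * W_op + beta * I
    identity = torch.eye(d, dtype=W_op.dtype, device=W_op.device)
    W_beta = W_op - (W_op - identity) * beta

    # Refine W_beta by minimizing the adaptive reranking (ARerank) loss
    # (adaptive_rerank_refine is not defined in this snippet)
    W = adaptive_rerank_refine(W_beta)

    test_logits = (test_img_features @ W) @ cls_prototypes.T  # [M, C]
    test_preds = test_logits.argmax(-1)  # [M]
    return test_preds
```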
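Explanation of the `LFA` function:

**What the function does**: It implements the Linear Feature Alignment (LFA) algorithm to classify image features in the supervised setting, i.e., it learns a mapping from labeled training features and then predicts the classes of test image features.

**Parameters**:
- `img_feats`: tensor of shape `[N, d]` holding the training image features, where `N` is the number of training image features and `d` is the feature dimensionality.
- `cls_prototypes`: tensor of shape `[C, d]` holding the class prototypes, where `C` is the number of classes.
- `labels`: tensor of shape `[N]` holding the labels of the training images.
- `beta`: hyperparameter that controls how strongly the `Beta-Procrustes` step interpolates towards the identity.
- `test_img_features`: tensor of shape `[M, d]` holding the test image features, where `M` is the number of test image features.

**Function body**:
1. `text_feats = cls_prototypes[labels]` (one-to-one matchings):
   - Uses the training labels to select the matching class prototype (text feature) for every training image feature.
2. `u, _, v = torch.svd(img_feats.T @ text_feats)` (Orthogonal Procrustes):
   - Computes the singular value decomposition (SVD) of the product of the transposed training image features and the matched text features, yielding `u`, the singular values, and `v`.
3. `W_op = u @ v.T`:
   - Forms the Orthogonal Procrustes solution as the product of `u` and the transpose of `v`.
4. `identity = torch.eye(d)` (Beta-Procrustes):
   - Creates a `d × d` identity matrix.
5. `W_beta = W_op - (W_op - identity) * beta`:
   - Computes the `Beta-Procrustes` solution, `(1 - beta) * W_op + beta * identity`, which pulls the orthogonal solution towards the identity matrix to limit overfitting.
6. `W = adaptive_rerank_refine(W_beta)` (refinement):
   - Refines `W_beta` by minimizing the adaptive reranking (ARerank) loss, giving the final mapping matrix `W`.
7. `test_logits = (test_img_features @ W) @ cls_prototypes.T`:
   - Maps the test image features with `W`, then multiplies them by the transposed class prototypes to obtain the test logits (one score per class).
8. `test_preds = test_logits.argmax(-1)`:
   - Takes the index of the largest logit along the last dimension as the predicted class of each test image.
9. `return test_preds`:
   - Returns the predicted classes of the test images.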
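The closed-form part of the pipeline (one-to-one matching, Orthogonal Procrustes, Beta-Procrustes, and prediction) can be checked on synthetic data. This is a minimal sketch under stated assumptions, not the authors' implementation: it uses `torch.linalg.svd` (which returns `V^T` directly) in place of the deprecated `torch.svd`, it skips the ARerank refinement entirely and predicts with `W_beta`, and it generates image features as a random rotation of the class prototypes plus small noise. The helper names `closed_form_lfa`, `predict`, and `make_features` are hypothetical.

```python
import torch


def closed_form_lfa(img_feats, cls_prototypes, labels, beta):
    """Hypothetical helper: LFA's closed-form steps only (no ARerank refinement)."""
    d = img_feats.shape[1]
    text_feats = cls_prototypes[labels]                     # one-to-one matchings, [N, d]
    u, _, vh = torch.linalg.svd(img_feats.T @ text_feats)  # X^T Y = U S V^T; vh is V^T
    W_op = u @ vh                                          # Orthogonal Procrustes solution
    identity = torch.eye(d, dtype=W_op.dtype)
    W_beta = W_op - (W_op - identity) * beta               # = (1 - beta) * W_op + beta * I
    return W_op, W_beta


def predict(test_img_features, W, cls_prototypes):
    """Hypothetical helper: map test features with W and score them against each prototype."""
    test_logits = (test_img_features @ W) @ cls_prototypes.T  # [M, C]
    return test_logits.argmax(-1)                            # [M]


def make_features(labels, cls_prototypes, rotation, noise=0.01):
    """Synthetic image features (assumption): rotated class prototype plus Gaussian noise."""
    clean = cls_prototypes[labels] @ rotation.T
    return clean + noise * torch.randn(clean.shape, dtype=clean.dtype)


torch.manual_seed(0)
C, d, beta = 5, 64, 0.1
cls_prototypes = torch.nn.functional.normalize(torch.randn(C, d, dtype=torch.float64), dim=-1)
rotation, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))  # random orthogonal matrix

labels = torch.arange(C).repeat_interleave(20)       # 20 training features per class
img_feats = make_features(labels, cls_prototypes, rotation)
test_labels = torch.arange(C).repeat_interleave(10)  # 10 test features per class
test_img_features = make_features(test_labels, cls_prototypes, rotation)

W_op, W_beta = closed_form_lfa(img_feats, cls_prototypes, labels, beta)
test_preds = predict(test_img_features, W_beta, cls_prototypes)
accuracy = (test_preds == test_labels).double().mean().item()
print(f"test accuracy with W_beta: {accuracy:.2f}")
```

Because the synthetic image features differ from their prototypes only by a rotation (plus noise), the Procrustes map undoes that rotation on the span of the prototypes, and the small `beta` keeps `W_beta` close to it. Setting `beta=0` returns `W_op` unchanged, and `beta=1` returns the identity.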
LFA
Arch
Approximating Soft Prompts with a Linear Transformation
LFA
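- To learn W, Linear Feature Alignment (LFA) first initializes it with the closed-form solution of a least-squares optimization problem and then refines this initial solution by minimizing a reranking loss.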
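In the notation of the `LFA` code, with $X$ = `img_feats` and $Y$ = `text_feats` (both $N \times d$), the closed-form initialization that the snippet computes can be written as below. This restates the code's "Orthogonal Procrustes" and "Beta-Procrustes" steps; the orthogonality constraint $W^\top W = I$ is what makes the least-squares problem solvable by a single SVD.

```math
W_{\mathrm{op}} = \arg\min_{W^\top W = I} \lVert X W - Y \rVert_F^2 = U V^\top,
\qquad U \Sigma V^\top = \operatorname{SVD}\left(X^\top Y\right)
```

```math
W_\beta = W_{\mathrm{op}} - \beta \left(W_{\mathrm{op}} - I\right) = (1 - \beta)\, W_{\mathrm{op}} + \beta I
```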
Reference