核心在于让模型学会“拉开不同类距离、拉近同类距离”,依赖损失函数(如三元组、对比损失)、数据组织与训练策略协同;三元组损失要求锚点与正样本距离小于锚点与负样本距离。

构建图像嵌入模型的核心,不是堆叠网络层数,而是让模型学会“拉开不同类距离、拉近同类距离”——这靠的是损失函数设计、数据组织方式和训练策略的协同。
用三元组(Triplet)或对比(Contrastive)损失替代分类损失
传统分类模型输出类别概率,但嵌入任务需要向量间的几何关系。三元组损失要求:锚点(anchor)与正样本(same class)距离
- 对比损失可简化实现,适合初学者:只构造正负样本对,加 margin 控制负样本最小距离
- PyTorch 中可用 torch.nn.TripletMarginLoss 或自定义 loss,注意设置合理 margin(通常 0.1–1.0,取决于 embedding 维度和归一化方式)
- 务必对 embedding 向量做 L2 归一化(尤其用余弦相似度时),否则模长干扰距离度量
图像预处理要匹配下游使用场景
嵌入模型最终用于检索或聚类,输入必须和线上推理一致。常见误区是训练用 RandomResizedCrop,而推理用 CenterCrop,导致分布偏移。
- 训练时增强要有“语义一致性”:ColorJitter、RandomGrayscale 可以,但避免 RandomRotation(除非业务允许旋转不变性)
- 统一缩放到固定尺寸(如 224×224),再归一化(ImageNet 均值标准差即可,不必重算)
- 若部署在移动端,可提前模拟量化噪声(如添加 torch.round(x * 128) / 128),提升训练-推理一致性
采样策略比网络结构更影响收敛质量
随机打乱 batch 很难保证每批都有足够正负样本对。尤其类别不均衡时,小众类可能整 epoch 都没被选为正样本。
版权声明:除非特别标注,否则均为本站原创文章,转载时请以链接形式注明文章出处。
还木有评论哦,快来抢沙发吧~