发布时间:2023-04-20 文章分类:电脑百科 投稿人:李佳 字号: 默认 | | 超大 打印

【深度学习】详解 MAE

目录

摘要

一、引言

二、相关工作

三、方法

四、ImageNet 实验

4.1 主要属性

4.2 与先前结果的对比

4.3 部分微调

五、迁移学习实验

六、讨论与结论 

七、核心代码


  • Title:Masked Autoencoders Are Scalable Vision Learners
  • Paper:https://arxiv.org/abs/2111.06377
  • Github:https://github.com/facebookresearch/mae

摘要

        本文证明了 masked autoencoder (MAE) 是一种可扩展的 (scalable) CV 自监督学习器。MAE 的思想很简单:mask 输入图像的随机 patches,并重建缺失的 pixels。MAE 基于两个核心设计。首先,我们开发了一个非对称 (asymmetric) 的编码器-解码器架构,编码器只操作于 patches 的可见子集 (无 mask tokens),轻量级解码器 从潜在表示和 mask tokens 中重构原始图像。其次,我们发现高比例地 mask 输入图像 (如 75%) 产生了一个重要 (non-trival) 和有意义的自监督任务。耦合 (coupling) 这两种设计使我们能够有效地训练大型模型:我们加速训练 (3×或更多) 并提高准确率。我们的可扩展方法 允许学习 具有良好泛化能力的高容量 (high-capacity) 模型:例如,在只用 ImageNet-1K 的方法中,一个普通的 ViT-Huge 达到最佳准确率 (87.8%)。在下游任务中的迁移性能优于有监督预训练,并表现出有前景的扩展行为 (scaling behavior)。

一、引言

        深度学习见证了具有不断增长的能力和容量 (capability and capacity) 的架构的爆炸式增长。在硬件的快速增长的帮助下,今天的模型可以很容易地过拟合 100 万张图像,并开始要求数亿张 —— 通常是公开不可访问的 —— 带标签的图像。

        这种对数据的饥渴已经在 NLP 中通过 自监督预训练 被成功地解决了。GPT 中基于 autoregressive 的语言建模,和 BERT 中的 masked autoencoding 的解决方案 在概念上很简单:它们删除部分数据,并学习预测被删除的内容。这些方法现在可以训练包含超过 1000 亿个参数的generalizable NLP 模型。

        Masked autoencoding 的思想,是一种更一般的 denoising autoencoders 的形式,是自然的且适用于 CV 的。事实上,与视觉密切相关的研究早于 BERT。然而,尽管随着 BERT 的成功,人们对这一想法产生了极大的兴趣,但在视觉领域的 autoencoding 方法的进展落后于 NLP。我们会问:是什么使 masked autoencoding 在视觉和语言之间有所不同?我们试图从以下角度来回答这个问题:

        (i) 直到最近的架构都是不同的。在视觉中,卷积网络在过去十年的中占主导地位。卷积通常在常规 grids 上运行,要将 mask token 或位置嵌入等 “indicators” 集成到卷积网络中并不简单。然而,随着 ViT 的引入,架构的鸿沟已经得到了解决且,不应再构成障碍。

        (ii) 语言和视觉之间的信息密度不同。语言是由人类产生的信号,具有高度的语义性和信息密集性 (highly semantic and information-generated)。当训练一个模型只预测每个句子中缺失的一些单词时,这项任务似乎能诱导 (induce) 复杂的语言理解。相反,图像是具有大量空间冗余的自然信号,例如,丢失的 patches 可以从相邻的 patches 中恢复,而几乎不需要对 parts、objects 和 scenes 的高级理解。为克服这种差异并鼓励学习有用的特征,我们展示了一个在 CV 中表现很好的简单策略:masking 非常高比例的随机 patches。这种策略在很大程度上减少了冗余,并创造了一个具有挑战性的自监督任务,需要对低级图像统计之外的整体 (holistic) 理解。要对我们的重建任务进行定性的了解,请参见图 2-4。

        (iii) autoencoder 的解码器,它将潜在的表示映射回输入,在重建文本和图像之间起着不同的作用。在视觉中,解码器重建像素,因此其输出的语义级别低于常见的识别任务。相反,在语言中,解码器预测包含丰富语义信息的缺失单词。虽然在 BERT 中,解码器可以是简单的 (一个MLP),但我们发现,对于图像,解码器的设计 在决定学习到的潜在表示的语义水平中 起着关键作用

        在此分析的驱动下 (Driven by this analysis),我们提出了一种简单的、有效的、可扩展的 masked autoencoder (MAE) 形式,用于视觉表示学习。MAE 从输入图像中 mask 随机 patches,并重建像素空间中缺失的 patches。它具有非对称的编码器-解码器设计。我们的编码器只对可见的 patches 子集进行操作 (没有 mask tokens),并且我们的解码器是轻量级的,可以基于潜在表示和 mask tokens 重建输入 (图 1)。在我们的非对称编码器-解码器中,将 mask tokens 转移到小解码器可以大大减少计算量。在这种设计下,一个非常高的掩蔽率 (如 75%) 可以实现双赢的场景 (win-win scenario):它优化准确率,同时允许编码器只处理一小部分 (如 25%) patches。这可以将整体预训练时间减少 3× 或更多,同样减少内存消耗,使我们能够轻松地将 MAE 扩展为大型模型。

        MAE 学习了非常高容量的模型,且具有很好的泛化能力。通过 MAE 预训练,我们可以在 ImageNet-1K (IN1K) 上训练诸如 ViT-Large/-Huge 这样的对数据饥渴的模型,从而提高泛化性能。使用一个普通的 ViT-Huge 模型,当我们在 ImageNet-1K 上微调时,达到了 87.8% 的准确率,这优于之前所有只使用 ImageNet-1K 数据的结果。我们还评估了迁移学习的目标检测、实例分割和语义分割。在这些任务中,我们的预训练比有监督预训练取得了更好的结果,更重要的是,我们通过扩大 (scaling up) 模型观察到显著的收益。这些观察结果与 NLP 中自监督预训练的观察结果一致,我们希望 NLP 能够使 CV 领域探索类似的轨迹。

二、相关工作

        Masked language modeling (MLM) 及其对手 autoregressive,如 BERT 和 GPT,是 NLP 中非常成功的预训练方法。这些方法包含了部分输入序列,并训练模型来预测缺失的内容。这些方法已被证明能够良好扩展/放缩 (scale excellently),并且大量的证据表明,这些预训练好的表示可以很好地推广到各种下游任务。

        Autoencoding 是学习表示的一种经典方法。它有一个将输入映射到潜在表示 (latent representations) 的编码器,和一个重建输入的解码器。例如,PCA 和 k-means 是 autoencoders。Denoising autoencoders (DAE) 是一类 autoencoders,它会损坏 (corrupt) 一个输入信号,并学习重建原始的、未损坏的信号。一系列的方法可以被认为是在不同的损坏下的广义 DAE,例如,masking pixels 或移除颜色通道。我们的 MAE是 denoising autoencoding 的一种形式,但在许多方面与经典的 DAE 不同。

        Masked image encoding方法从被 masking 损坏的图像中学习表示。Stacked denoising autoencoders 作为开创性工作将 masking 作为DAE 中的一种噪声类型。Context Encoder 使用卷积网络来补全大的缺失区域。由于 NLP 的成功,近期相关方法是基于 Transformer 的。iGPT 处理像素序列并预测未知像素。ViT 研究了自监督学习预测 masked patch。最近,BEiT 提出预测离散 tokens。

        Self-supervised learning 方法对 CV 具有显著兴趣,通常聚焦于预训练的不同前置 (pretext) 任务。最近,对比学习非常流行,它建模两个或多个视图 (views) 之间的图像相似性和不相似性 (或仅为相似性)。对比的 (Contrastive) 和相关方法强烈依赖于数据增广。Autoencoding 追求一个概念上不同的方向,相比于我们将呈现出的内容,它将表现出不同的行为。

三、方法

        我们的 masked autoencoder (MAE) 是一种简单的 autoencoding 方法,它可以重建原始信号。像所有的 autoencoders 一样,本方法有一个 将观察到的信号映射到一个潜在表示 的编码器 ,和一个从潜在表示重建原始信号 的解码器。与经典的 autoencoders 不同,我们采用了一种非对称的设计允许编码器只操作部分的、观察到的信号 (没有 mask tokens),以及一个轻量级的解码器,从潜在表示和 mask tokens 重j建完整的信号。图 1 说明了接下来将介绍的这个想法。

【深度学习】详解 MAE
图 1:MAE 架构。

在预训练过程中,大量的图像 patches 的随机子集 (如 75%) 被 mask。

​​​该编码器应用于可见 patches 的小子集。

在编码器之后引入 mask tokens,全部经编码的可见 patches 和不可见 mask tokens 由一个小解码器处理,从而按像素重建原始图像。

预训练后,解码器被舍弃,编码器被应用于未损坏的图像 (完整的 patches 集合) 进行识别任务。

        Masking。我们 按照 ViT 将一幅图像划分成规则无重叠的 (non-overlapping) patches。然后,从所有 patches 中采样一个子集,并 mask (即移除) 其余未被采样的 patches。采样策略很简单:按照均匀分布随机采样 patches 而不替换。我们仅将其称为 “随机采样”。

        具有 高 masking 比例的随机采样 (即被移除的 patches 的比例) 很大程度上消除了冗余,从而创建了一个不易通过从可见的相邻 patches 进行外推/外插 (extrapolation) 来解决的任务 (见图 2-4)。均匀分布防止了潜在的中心偏差 (即在图像中心附近有更多的 masked patches)。最后,高度稀疏的输入为设计一个有效的编码器创造了机会,接下来将介绍。

【深度学习】详解 MAE

【深度学习】详解 MAE

        MAE encoder。编码器是一个只应用于 visible & unmasked patches 的 ViT。正如一个标准 ViT,编码器 通过一个添加了位置嵌入的线性投影 来嵌入 patches,然后通过一系列 Transformer blocks 处理结果集。然而,编码器只操作于整个集合的一小个子集 (如 25%)。删除 masked patches;不使用 mask tokens。这允许 只用一小部分的计算和内存 来训练非常大的编码器。完整的集合由一个轻量级解码器处理。

        MAE decoder。MAE 解码器的输入是完整的 tokens 集合,包括 (i) 经编码的可见 patches,和 (ii) mask tokens。见图 1。每个 mask token 是一个共享的、经学习的向量,它指示要预测的缺失 patches 的存在。我们 向这个完整集中的所有 tokens 添加位置嵌入,否则 mask token 就没有关于它们在图像中所处的位置的信息。该解码器有另一系列的 Transformer blocks。

        MAE 解码器 仅用于在预训练时执行图像重建任务 (只有编码器产生用于识别的图像表示)。因此,解码器架构可以以独立于编码器设计的方式灵活设计。实验使用非常小的解码器,比编码器更窄更浅 (narrow and shallow)。例如,相比于编码器,默认的解码器的每个 token 计算 < 10%。通过这种不对称的设计,完整的 tokens 集只由轻量级的解码器来处理,这大大减少了预训练时间。

        Reconstruction target。MAE 通过预测每个 masked patch 的像素值来重建输入。解码器输出的每个元素 都是一个 代表一个 patch 的像素值的向量。解码器的最后一层是一个线性投影,其输出通道数 等于 一个 patch 中的像素值数 (如果 L = 单个 patch 的像素数,那么 D = patch 数 ???)。解码器的输出被 reshaped 以形成一幅经重建的图像。损失函数为像素空间中经重建图像和原始图像间的 MSE,且 只计算在 masked patches 上的损失,类似于 BERT (不同于 计算所有像素损失的传统 denoising autoencoders (DAE),仅在 masked patch 上计算损失 纯粹是由 结果驱动的:计算所有像素上的损失会导致准确率的轻微下降,例如 ∼0.5%)。

        我们还研究了一个变体,其 重建目标 是 每个 masked patch 的归一化像素值。具体来说,计算一个 patch 中所有像素的平均值和标准差来归一化 patch。在实验中,使用归一化像素作为重建目标,改善了表示的质量

        Simple implementation。MAE 预训练可以高效地实现 而无需任何专门的稀疏操作。首先,为每个输入 patch 生成一个 token (通过添加位置嵌入的线性投影)。接着,随机 shuffle tokens 列表,并根据 masking 比例删除 tokens 列表的最后一部分。该过程为编码器产生一个小的 tokens 子集,相当于 无替换的 patches 采样 (无放回抽样?)。编码后,为经编码的 patches 列表 额外加入 一个 mask tokens 列表,并 unshuffle (随机 shuffle 的反操作) 该完整列表,以将所有 tokens 与它们的 tagrtes 对齐。解码器被用于这个完整的列表 (添加了位置嵌入)。如上所述 (as noted),无需稀疏操作。这个简单的实现引入了可以忽略不计的开销,因为 shuffle 和 unshuffle 操作很快。

四、ImageNet 实验

        我们用 ImageNet-1K (IN1K) 训练集 自监督预训练。然后,进行有监督训练,通过 (i) end-to-end fine-tuning(ii) linear-probing 来评估表示。我们报告了单个 224×224 裁剪的 top-1 验证准确率。详见附录A.1。

        Baseline: ViT-Large。消融研究 使用 ViT-Large (ViT-L/16) 作为主干。ViT-L 非常大 (比 ResNet-50 大一个数量级),且倾向于过拟合。以下是 从头开始训练的 ViT-L 微调的基线 MAE 的比较。可见,不论是原版 ViT-Large,还是添加了强正则化的 ViT-Large,性能都大幅落后于基线 MAE

【深度学习】详解 MAE

        注意,从头开始训练原版有监督 ViT-L 并不重要 (76.5%),一个具有强正则化的好方法才是需要的 (a good recipe) (82.5%, 见附录 A.2)。即便如此,我们的 MAE 预训练也有了很大的改善 (84.9%)。此处,微调 50 个 epochs,而从头开始训练 200 个 epochs,这意味着 (implying that) 微调的准确率很大程度上依赖于预训练。 

4.1 主要属性

【深度学习】详解 MAE

        我们在表 1 中使用默认设置消融 (ablate) MAE (见说明文字),并观察到了几个有趣的特性。

【深度学习】详解 MAE

        Masking ratio。图 5 显示了掩码率的影响。75% 这么高的最佳掩码率同时有利于 linear probing 和 fine-tuning。这种行为与典型掩码率为 15% 的 BERT。我们的 75% 掩码率也远高于 CV 中的相关工作 (20%-50%)。

        模型推断出缺失的 patches 以产生不同的、但看似可信的输出 (图 4)。它使物体和场景的 gestalt 变得有意义,这不能简单地通过延伸线条或纹理来完成。我们假设这种类似推理 (reasoning-like) 的行为与学习有用的表示有关。

        图 5 还显示出 linear probing 和微调结果遵循不同的趋势。对于 linear probing,准确率随着掩码率稳步增加,直到最佳点 (sweet point):准确率差距高达 ~20% (最低的 54.6% vs 最高的 73.5%)。对于微调,结果对掩码率的敏感度较低,而且掩码率在很大范围内 (40-80%) 模型都能工作得很好。图 5 中的所有微调结果都优于从头开始训练 (82.5%)。

        Decoder design。可以灵活设计 MAE 解码器,如表 1a 和 1b 所示。表 1a 改变了 解码器的深度 (Transformer blocks 数)一个足够深的解码器对 linear probing 而言是很重要的。这可以用像素重建任务和识别任务之间的差距来解释:autoencoder 的最后几层更加专门 (specialized) 用于重建,但与识别的相关性较小。一个合理的/适度的 (reasonably) 深度解码器 能解释 重建的专门化/特殊化 (specialization),将潜在的表示留在一个更抽象的层次上 (A reasonably deep decoder can account for the reconstruction specialization, leaving the latent representations at a more abstract level)。这种设计在 linear probing 方面可以实现高达 8% 的提高 (表 1a, “lin”)。然而,若用微调,则编码器的最后一层可以被调整以适应识别任务。解码器深度对改进微调的影响较小 (表 1a, “ft”)。

        有趣的是,具有单个 block 的解码器 的 MAE 可通过微调实现强大的性能 (84.8%)。注意,单个 Transformer block 是将信息从可见 token 传播到 mask tokens 的最小要求。这样一个小的解码器可以进一步加快训练速度。

        在表 1b 中,我们研究了 解码器的宽度 (通道数)。默认使用的 512-d 在微调和 linear probing 下均表现良好。一个较窄的解码器也可以很好地进行微调。

        总之,默认的 MAE 解码器是轻量级的,有 8 个 blocks,512-d 的宽度 (表 1 中的灰色),每个 token 只有 9% 的 FLOPs;而 ViT-L 则有24 个 blocks 和 1024-d 的宽度)。因此,虽然解码器处理所有的 tokens,但这仍是整个计算的一小部分。

        Mask token。MAE 的一个重要设计是跳过编码器中的 mask token 【深度学习】详解 MAE,然后将其应用到轻量级解码器中。表 1c 研究了该设计。

        编码器如果使用 mask token 会表现更差:在 linear probing 中,其精度下降了14%。在这种情况下,预训练和部署之间存在一个鸿沟:这个编码器在预训练的输入中有大量的 mask token,然而 mask token 在未损坏的图像中是不存在的。这种鸿沟可能会降低部署的准确率。通过从编码器中删除 mask token,能够约束编码器始终看到真实存在的 patches,从而提高准确率

        此外,通过跳过编码器中的 mask token,可以大大减少训练计算量。在表 1c 中,我们将总训练 FLOPs 减少了 3.3×。这导致了在我们的实现中一个 2.8× 的 wall-clock 加速 (见表 2)。对于 一个更小的解码器 (1-block),一个更大的编码器 (ViT-H),或二者都有时,wall-clock 加速甚至更大 (3.5-4.1×)。注意,对于 75% 的掩码率,加速可以 > 4×,部分原因是自注意力的复杂度是平方的。此外,内存的大大减少,使得我们可以训练更大的模型 或 通过大 batch 极大地加速训练。时间和内存的高效性使我们的 MAE 有利于训练非常大的模型。

        Reconstruction target。表 1d 中比较了不同的重建目标。到目前为止,我们的结果是 基于未经 (per-patch) 归一化的像素。使用经归一化的像素可以提高准确率。这种per-patch 归一化 增强了局部的对比度。在另一种变体中,我们在 patch 空间中执行 PCA,并用最大的 PCA 系数 (此处为 96) 作为 target,但这样做会降低准确率。这两个实验都表明,高频分量在我们的方法中是有用的

        我们还比较了一种预测 tokens 的 MAE 变体,其 target 使用于 BEiT。特别是对于这种变体,我们遵循 BEiT 使用 DALLE 预训练的 dVAE 作为 tokenizer。此处,MAE 解码器使用交叉熵损失来 预测 token indices。这种 tokenization 与未经归一化的像素相比提高了 0.4% 的微调准确率,但与经归一化像素相比则没有优势。这种 tokenization 还降低了 linear probing 的准确率。在 §5 中,我们进一步证明了 tokenization 在迁移学习中不是必要的

        我们的基于像素的 MAE 比 tokenization 要简单得多。dVAE tokenizer 需要一个额外的预训练阶段,这可能依赖于额外的数据 (250M 图像)。dVAE 编码器是一个大型的卷积网络 (ViT-L 40% 的 FLOPs),并增加了大量的开销 (adds nontrivial overhead)。使用像素则并未遇到这些问题

        Data augmentation。表 1e 研究了数据扩增对 MAE 预训练的影响。

        MAE 使用只裁剪 (cropping-only) 的扩增,无论是固定尺寸还是随机尺寸 (但都有随机水平翻转),效果都很好。添加 color-jittor 会降低结果,所以我们不在其他实验中使用它。

        令人惊讶的是,即使没用数据扩增 (只有中心裁剪 没有翻转),MAE 也表现得很好这一特性与严重依赖于数据扩增的对比学习及相关方法显著不同。据观察,对于 BYOL 和 SimCLR,使用只裁剪的扩增会分别降低 13% 和 28% 的准确率。此外,没有证据表明对比学习可以在无需扩增的情况下工作:一幅图像的两个视图 (view) 相同,且可以很容易地满足一个平凡的解 (trivial solution)。

        在 MAE 中,数据扩增的作用主要通过随机 masking (ablated next) 来实现。每次迭代的 masks 都不同,所以无论如何数据扩增,它们都会生成新的训练样本。前置任务 (pretext) 因 masking 变得困难,且需要较少的扩增来正则化训练。

【深度学习】详解 MAE

        Mask sampling strategy。表 1f 比较了不同的 mask 采样策略,如图 6 所示。

        BEiT 中提出的 block-wise masking 策略倾向于删除大的 blocks (图 6 中间)。我们的 block-wise masking MAE 在 50% 的比例下工作得相当有效,但在 75% 的比例下性能下降。这个任务比随机抽样更难,因为观察到了更高的训练损失。重建结果也更加模糊。

        我们还研究了 grid-wise 采样,它规律地保留每 4 个 patches 中的 1 个 (图 6 右侧)。这是一项更容易完成的任务,且训练损失也更低。重建结果更加锐利。但是,表示的质量更低。

        简单随机抽样最适合 MAE。它允许一个更高的掩码率,这提供了一个更大的加速收益,同时也享受良好的准确率。

【深度学习】详解 MAE

        Training schedule。到目前为止,我们的消融是基于 800-epoch 的预训练。图 7 显示了训练计划长度的影响。随着更长时间的训练,准确率会稳步提高。事实上,即使在 1600 个 epoches 时也未能观察到 linear probing 准确率的饱和。这种行为有别于对比学习方法,例如,MoCo v3 在 ViT-L 的 300 个 epoches 时便达到饱和。注意,MAE 编码器在每个 epoch 只看到 25% 的 patches;而在对比学习中,编码器每个 epoch 可以看到 200% (两次复制) 甚至更多 (多次裁剪) 的 patches。

4.2 与先前结果的对比

【深度学习】详解 MAE

        Comparisons with self-supervised methods。表 3 比较了自监督 ViT 模型的微调结果。对于 ViT-B,所有方法的性能都很接近。对于ViT-L,方法之间的差距则更大,这表明 更大的模型面临的挑战是降低过拟合

        MAE 可以很容易地扩大 (scale up),并从更大的模型中显示出稳步改善。我们用 ViT-H (尺寸 224) 获得了 86.9% 的准确率。仅用 IN1K 数据,通过用尺寸 448 微调实现了 87.8% 的准确率。在所有仅用 IN1K 数据的方法中,之前基于先进的网络的最佳准确率是 87.1% (尺寸 512)。在高度竞争激烈的 IN1K 基准测试 (没有外部数据) 中,我们以显著的优势 (by a nontrivial margin) 提高了 SOTA 水平。我们的研究结果 基于普通的 (vanilla) ViT,我们期望先进的网络会表现得更好。

        与 BEiT 相比,MAE 更准确、更简单、更迅速。MAE 重建像素,与预测 tokens 的 BEiT 相比:BEiT 报告了在使用 ViT-B 重建像素时,退化了 1.8% (我们在使用 ViT-L 的 BEiT 中也观察到了退化:它产生了 85.2% (tokens) 和 83.5% (pixels),拷贝自官方代码)。我们无需 dVAE 预训练。此外,MAE 比 BEiT 要快得多 (3.5× 每 epoch),原因如表 1c 所示。

        表 3 中的 MAE 模型预训练了 1600 个 epochs 以获取更好的准确率 (图 7)。即便如此,在相同的硬件上训练时,我们的预训练总时间比其他方法少。例如,在 128 个TPU-v3 内核上训练 ViT-L,MAE 训练 1600 epochs 用时 31 小时,而 MoCov3 训练 300 epochs 用时 36 小时。

【深度学习】详解 MAE

        Comparisons with supervised pre-training。在最初的 ViT 论文中,ViT-L 在 IN1K 中训练时退化。我们实施的有监督训练 (见 A.2) 效果更好,但准确率达到饱和。见图 8。

        MAE 的预训练只用 IN1K,可以更好地泛化:对于更高容量的模型,从头开始训练的增益更大。它遵循了类似于 ViT 原文中的 JFT-300M 监督预训练的趋势。这种比较表明,MAE 有助于扩大模型尺寸。

4.3 部分微调

【深度学习】详解 MAE
微调 0 个 blocks 即为 linear probing (只有最后的 FC 分类器可学习),24 个即为完全微调 (所有层都可以学习)

        表 1 显示,linear probing 和微调的结果在很大程度上是不相关的 (largely uncorrelated)。在过去的几年中,linear probing 一直是一种流行的 protocol;然而,它错过了追求强大但非线性特征的机会 —— 这确实是深度学习的一种优势。作为中间立场 (as a middle ground),我们研究了一个 部分微调 protocol:微调最后几层,同时冻结其他层。这个 protocol 也被用于早期的工作中。

        图 9 显示了结果。值得注意的是 (notably),仅微调一个 Transformer block 就可以将准确率从 73.5% 显著提高到 81.0%。此外,如果只微调最后一个 block (即它的 MLP sub-block),可以得到 79.1%,比 linear probing 好得多。这个变体本质上是对 MLP 头进行微调。微调几个 blocks (例如 4 或 6 个) 可以实现接近完全微调的准确率

        图 9 中还与 MoCo v3 比较,这是一种与 ViT-L 结果相结合的对比方法。MoCo v3 的 linear probing 准确率较高,但它所有部分微调结果都不如 MAE。当调优 4 个块时,差距为 2.6%。虽然 MAE 表示的线性可分离性较少,但它们具有更强的非线性特征,且当非线性头部被调优时表现良好。

        这些观察结果表明,线性可分性 (linear separability) 并不是评价表示质量的唯一指标。人们还观察到,linear probing 与迁移学习性能并没有很好的相关性,例如,对于目标检测。据我们所知,线性评价在 NLP 中并不经常用于预训练的基准测试

五、迁移学习实验

【深度学习】详解 MAE

        我们使用表 3 中的预训练模型来评估下游任务中的迁移学习。

【深度学习】详解 MAE

        Object detection and segmentation。我们在 COCO 上对 Mask R-CNN 进行端到端微调。ViT 主干适和与 FPN 共用 (见 A.3)。我们对表4 中的所有条目都应用了这种方法。我们报告了用于目标检测的 box AP 和用于实例分割的 mask AP。

        与有监督的预训练相比,MAE 在所有配置下都表现得更好 (表 4)。对于较小的 ViT-B,MAE 比有监督的预训练高 2.4 个点 (50.3 vs. 47.9, APbox)。更重要的是,对于较大的 ViT-L,MAE 预训练比有监督的预训练高 4.0 个点 (53.3 vs. 49.3)。基于 pixel 的 MAE 比基于 token 的 BEiT 更好或相当,而 MAE 更简单、更快。MAE 和 BEiT 都优于 MoCov3,MoCov3 与有监督的预训练相当

【深度学习】详解 MAE

        Semantic segmentation。我们用 UperNet 在 ADE20K 进行实验 (见 A.4)。表 5 显示,我们的预训练 比 有监督预训练 显著提高了结果,例如,ViT-L 提高了 3.7 个点。基于 pixel 的 MAE 也优于基于 token 的 BEiT。这些观察结果与在 COCO 上的一致。

【深度学习】详解 MAE

        Classification tasks。表 6 研究了在 iNaturalists 和 Places 任务上的迁移学习 (见 A.5)。在 iNat 上,我们的方法显示出很强的扩展行为:准确率随着模型增大显著提高。我们的结果远超 (by large margins) 之前的最佳结果。在 Places 上,MAE 优于之前的对数十亿张图像预训练获得的最佳结果。

【深度学习】详解 MAE

         Pixels vs. tokens。表 7 比较了 pixel 与 token 作为 MAE 重建目标。虽然使用 dVAE token 比使用未经归一化的像素更好,但它在统计上与 在我们测试的所有案例中 使用的经归一化像素相似。这再次表明,tokenization 对 MAE 并不是必要的

六、讨论与结论 

        具有良好扩展性 (scale well) 的简单算法是深度学习的核心。在 NLP 中,简单的自监督学习方法可以从指数级扩展 (exponentially scaling) 的模型中获益。在 CV 中,尽管在自监督学习方面取得了进展,但实际的预训练范式仍主要是有监督的。在这项研究中,我们在 ImageNet 和迁移学习中观察到,一个 autoencoder —— 一种类似于 NLP 技术的简单的自监督方法 —— 提供了可扩展/放缩 (scalable) 的好处。视觉中的自监督学习现在可能开始了 (be embarking on) 与 NLP 类似的轨迹

        另一方面,我们注意到图像和语言是不同性质的信号,必须仔细处理这种差异。图像只是记录到的光,它没有语义分解成单词的视觉模拟。我们没有试图删除目标物体,而是删除了那些很可能不会形成语义分片 (segment) 的随机 patches。同样地,我们的 MAE 重建了像素,它们不是语义实体 (semantic entities)。然而,我们观察到 (如图 4),我们的 MAE 推断出了复杂的、整体的重建,表明它已经学习了许多视觉概念,即语义 (semantics)。我们 假设这种行为是基于 在 MAE 内部的丰富的隐藏表示 而发生的。我们希望这一观点能启发未来的工作。

        Broader impacts。所提出的方法 基于训练数据集的已学习到统计信息 (statistic) 来预测内容,因此将反映这些数据中的偏差/偏置 (bias),包括那些具有负面社会影响的数据。该模型可能会生成不存在的内容。当在这基础上生成图像时,这些问题值得进一步研究 (warrant further research) 和考虑。

七、核心代码

# https://github.com/facebookresearch/mae/blob/main/models_mae.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()
        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------
        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------
        self.norm_pix_loss = norm_pix_loss
        self.initialize_weights()
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, mask, ids_restore
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        x = x + self.decoder_pos_embed
        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        # predictor projection
        x = self.decoder_pred(x)
        # remove cls token
        x = x[:, 1:, :]
        return x
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks