Method
Data2vec combines masked prediction with the learning of latent target representations but generalizes the latter by using multiple network layers as targets and shows this approach works across several modalities.
Data2vec is trained to predict the model representations of the full input given a partial view of that input. In student mode, the model learns to predict representations from a masked version of the input. The training targets are produced by the model in teacher mode from the unmasked version of the input. The weights of the teacher are an exponentially moving average of the student's weights.
Model architecture. The authors use the standard Transformer architecture with a modality-specific encoding of the input data.
- Computer vision: the ViT strategy of encoding images as 16x16 patches, each followed by a linear transformation.
- Speech: a multi-layer 1D convolutional neural network that maps the 16 kHz waveform to 50 Hz representations.
- Text: embeddings of sub-word units in distributional space via learned embedding vectors.
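The vision encoding above can be sketched in a few lines: split the image into non-overlapping 16x16 patches and apply a linear projection to each flattened patch. This is a minimal numpy illustration, not the authors' implementation; the image size (224x224x3) and projection dimension (768) are standard ViT choices assumed here.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an (H, W, C) image into flattened non-overlapping patches."""
    H, W, C = image.shape
    ph, pw = H // patch_size, W // patch_size
    patches = image.reshape(ph, patch_size, pw, patch_size, C)
    # Reorder so each row is one patch, flattened to patch_size*patch_size*C values.
    return patches.transpose(0, 2, 1, 3, 4).reshape(ph * pw, patch_size * patch_size * C)

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))
patches = patchify(image)                       # (196, 768): 14*14 patches of 16*16*3 values
W_proj = rng.standard_normal((768, 768)) * 0.02  # stand-in for the learned linear transformation
tokens = patches @ W_proj                        # (196, 768) patch embeddings
```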
Masking. A portion of the input tokens is replaced with a learned MASK embedding token.
- Computer vision: block-wise masking strategy.
- Speech: mask spans of latent representations.
- Text: mask tokens.
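The span-masking idea (used for speech) can be sketched as follows: pick random spans of time-steps and overwrite them with the learned MASK embedding. Span length, number of spans, and the random MASK vector are illustrative placeholders, not the paper's hyperparameters.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 50, 16
tokens = rng.standard_normal((T, D))
mask_embedding = rng.standard_normal(D)  # stands in for the learned MASK token

def mask_spans(tokens, mask_embedding, span_len=5, n_spans=3, rng=rng):
    """Replace random contiguous spans of time-steps with the MASK embedding."""
    masked = tokens.copy()
    is_masked = np.zeros(len(tokens), dtype=bool)
    for _ in range(n_spans):
        start = rng.integers(0, len(tokens) - span_len + 1)
        masked[start:start + span_len] = mask_embedding
        is_masked[start:start + span_len] = True
    return masked, is_masked

masked_tokens, is_masked = mask_spans(tokens, mask_embedding)
```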
Teacher parameterization. The weights $\Delta$ of the teacher model are an exponentially moving average (EMA) of the student model parameters $\theta$, given by
\[ \Delta \leftarrow \tau\Delta + (1-\tau)\theta \]
where a linear warm-up schedule is applied to $\tau$. Further, the teacher and student model share the parameters of the feature encoder and the positional encoder.
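The EMA update and the linear warm-up of $\tau$ amount to the following sketch. The specific values (0.999 to 0.9999 over 30k steps) are illustrative, not the paper's settings for any particular modality.

```python
import numpy as np

def ema_update(teacher_params, student_params, tau):
    """One EMA step: teacher weights track the student, Delta <- tau*Delta + (1-tau)*theta."""
    return {name: tau * teacher_params[name] + (1 - tau) * student_params[name]
            for name in teacher_params}

def tau_schedule(step, tau_start=0.999, tau_end=0.9999, warmup_steps=30_000):
    """Linearly warm tau up from tau_start to tau_end, then hold it constant."""
    if step >= warmup_steps:
        return tau_end
    return tau_start + (tau_end - tau_start) * step / warmup_steps
```

A small value of $\tau$ updates the teacher quickly early in training (when the student is still random), while a value close to 1 keeps the targets stable later on.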
Training targets. Training targets are constructed from the output of the top $K$ blocks of the teacher network, for the time-steps that are masked in student mode. The output of block $l$ at time-step $t$ is denoted $a_t^l$ and normalized to $\hat{a}_t^l$. The training target is then
\[ y_t = \frac{1}{K} \sum_{l=L-K+1}^{L} \hat{a}_t^l \]
Normalizing the targets helps prevent the model from collapsing into a constant representation for all time-steps, and it also prevents layers with high norms from dominating the target features. The authors note that these targets are contextualized representations, due to the use of self-attention in the Transformer network.
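Target construction can be sketched directly from the formula: normalize each of the top $K$ block outputs, then average them. The per-time-step normalization below is a stand-in for the paper's modality-dependent choice of normalization.

```python
import numpy as np

def build_targets(layer_outputs, K):
    """Average normalized outputs of the top K teacher blocks.

    layer_outputs: list of L arrays, each (T, D), ordered bottom to top.
    Returns y of shape (T, D), one target vector per time-step.
    """
    normed = []
    for a in layer_outputs[-K:]:  # top K blocks, l = L-K+1 .. L
        mu = a.mean(axis=-1, keepdims=True)
        sigma = a.std(axis=-1, keepdims=True)
        normed.append((a - mu) / (sigma + 1e-6))  # hat{a}_t^l
    return np.mean(normed, axis=0)                # y_t = (1/K) * sum of hat{a}_t^l

rng = np.random.default_rng(0)
layers = [rng.standard_normal((10, 8)) for _ in range(6)]  # 6 teacher blocks
y = build_targets(layers, K=3)
```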
Objective. The model is trained with a smooth L1 loss between the predicted representations of the student model and the training targets; the loss is quadratic for small errors and linear for large ones, with the transition controlled by a threshold parameter.
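The smooth L1 loss can be written out in a few lines; the threshold `beta` is a tunable hyperparameter (the value 1.0 here is just a default, not the paper's setting).

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 loss: quadratic for |pred - target| <= beta, linear beyond that."""
    diff = np.abs(pred - target)
    per_element = np.where(diff <= beta,
                           0.5 * diff ** 2 / beta,  # quadratic region
                           diff - 0.5 * beta)       # linear region
    return per_element.mean()
```

The quadratic region makes the gradient shrink as predictions approach the target, while the linear region keeps large errors from producing exploding gradients.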