type
status
date
slug
summary
tags
category
icon
password
这篇论文和【CVPR2024】Distribution-Consistent Modal Recovering for Incomplete Multimodal Learning 简直是一稿多投。所以其细节请参考那篇文章,这里就做细致的代码分析论证吧。不过不得不称赞一句,这位作者的代码十分优雅规范,层级清晰,复现方便,也因此有了兴趣仔细学习一下。
6. 代码分析
6.1 Dataloader
我们先来分析下数据集的读取。它是pickle的文件格式,我们这样读取就好:·
读取出来的data是什么呢?简单的用 type( ) 看一下就知道它是一个字典:
其每个键对应的值都是一个字典,下面以train_data为例:
从每个值的含义都能看出,前四个是各个模态的数据,可以转换为3维的nd array,表示文本、音频或视觉的特征,3个维度分别是时间步、序列长度和特征维度。
其余的模块一一拆散了写在这里可能会导致逻辑上的断联,下面贴上代码,以注释的形式呈现:
6.2 模型主体
6.2.1 class IMDER( )
类
这个类是由 ATIO类(All Trains in One) 创建的,需要传入 args 这一个参数。args 是一个规模较大的 dict,里面有模型训练需要的所有参数,以 .json 形式呈现。
该类出去定义计算Loss和评价标准的
__init__
外,还有两个方法,分别用来训练和测试。def do_train(self, model, dataloader, return_epoch_results=False)
方法
平平常常的定义了优化器:
系列初始化:
加载预训练模型:
后面就是很寻常的forward backward了。这里提取其中个人觉得有意思的一个地方
6.2.2 class IMDER(nn.Module)
类
这个类虽然和上面介绍的类重名,但是这两个类不在同一个文件哦。上面那个类主要是管理训练、测试中的各种细节,这个类继承自
nn.Module
,可是一个标准的基于 Pytorch 的神经网络模型。- 初始化方法
1.
__init__
函数__init__
是 Python 中的类构造函数,它负责在创建对象时初始化对象的属性。在 PyTorch 的 nn.Module
中,__init__
函数通常用于定义模型的层和其他组件。- 调用
super(IMDER, self).__init__()
来初始化父类nn.Module
的内容,这是每个自定义 PyTorch 模型类的标准做法。
2. 使用 BERT 文本编码器
args.use_bert
:这个参数决定是否使用 BERT 模型来处理文本数据。如果args.use_bert
为True
,模型将使用一个BertTextEncoder
(可能是自定义的 BERT 文本编码器类)来对文本进行编码。- BERT(Bidirectional Encoder Representations from Transformers) 是一种强大的预训练语言模型,用于提取文本特征。通过
BertTextEncoder
,IMDER
可以利用 BERT 的强大能力对文本进行特征提取。
use_finetune
:如果设置为True
,模型将对 BERT 进行微调(fine-tuning),即在训练过程中更新 BERT 模型的参数,否则 BERT 的参数将被冻结。
transformers
和pretrained
:这些参数用于指定使用哪个版本的 BERT 模型,以及是否从预训练模型开始。
3. 模态统一维度
dst_feature_dims
:这是目标特征维度,用于对不同模态(如文本、音频、视频)的特征进行统一表示。这里d_l
、d_a
、d_v
分别表示文本(l = linguistic)、音频(a = acoustic) 和 视频(v = visual) 的目标特征维度。- 模态统一维度是指将不同模态的特征通过嵌入、线性变换等手段转化为相同的维度,以便于之后的多模态融合。
nheads
:这个变量代表用于注意力机制的**多头注意力(Multi-head Attention)**中的头数。多头注意力是 Transformer 中的重要机制,允许模型从多个不同的角度对输入进行建模。
args.feature_dims
:原始模态的维度,分别是文本、音频和视频的输入特征维度。orig_d_l
、orig_d_a
和orig_d_v
保存这些模态的原始特征维度。
4. 多头注意力机制相关参数
num_heads
:注意力机制中使用的头的数量,即多头注意力中的头数。多个头可以捕获不同子空间中的表示。
layers
(nlevels
):模型的层数,可能是模型中 Transformer 编码器/解码器的层数。
attn_dropout
:注意力机制中的 dropout 率,用于防止过拟合。attn_dropout_a
和attn_dropout_v
分别是音频和视频模态的注意力机制中的 dropout 率。
5. 其他 Dropout 参数
relu_dropout
:在 ReLU 激活后的 Dropout 率。Dropout 是一种正则化技术,通过随机丢弃神经元来防止模型过拟合。
embed_dropout
:嵌入层的 Dropout 率,可能用于对输入(如嵌入后的文本特征)进行正则化。
res_dropout
:残差连接中的 Dropout 率。残差连接(ResNet)帮助信息在深层网络中更容易传播。
output_dropout
:输出层的 Dropout 率,用于输出特征的正则化。
text_dropout
:对文本特征的 Dropout 率,用于文本模态的正则化。
6. Attention Mask
attn_mask
:注意力机制中的掩码(mask),用于在计算注意力权重时屏蔽一些无关的或无效的部分。常见于序列数据(如文本),用于避免模型关注到填充的部分(padding tokens)。
7. MSE 损失函数
MSE()
:定义了一个均方误差(MSE)损失函数。MSE 是一种常见的损失函数,通常用于回归任务,目标是最小化预测值与真实值之间的差异平方和。
这个作者自己写了个均方误差函数,而不是调用的torch库里的:
模型整体架构解释
- IMDER 这个模型的设计表明它是一个多模态模型,可以处理文本、音频和视频等不同模态的数据。
- 通过使用 BERT 来处理文本模态,IMDER 可能会对文本进行深度特征提取,结合音频和视觉模态的特征进行统一的多模态融合。
- 通过统一模态之间的特征维度(
dst_feature_dims
),并且使用多头注意力机制(nheads
),该模型可以捕获不同模态之间的相互关系。
- 各种 Dropout 参数表明模型中使用了多种正则化手段,以防止过拟合,并提高模型的泛化能力。
- 最终,模型可能使用均方误差(MSE)作为损失函数,表明这是一个回归任务,目标是预测某种连续值(如感情强度、情感分数等)。
总结
- IMDER 是一个复杂的多模态神经网络模型,主要用于处理多种模态数据(文本、音频、视频),并通过统一模态维度的方式融合这些数据。
- 模型使用 BERT 来处理文本模态,并通过注意力机制来捕捉不同模态之间的相互关系。
- 使用了大量的 dropout 正则化手段来防止模型过拟合,并使用均方误差作为损失函数来进行回归任务。
如果你想了解模型的具体实现细节,比如
BertTextEncoder
或者 MSE
的定义,请提供更多相关代码,我可以进一步解释。- 作者:CrystalPuNK
- 链接:https://crystalpunk.top/article/papers/IMDer
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章
【nature】Accurate predictions on small data with a tabular foundation model

【ICLR 2024】A Flexible Generative Model for Heterogeneous Tabular EHR with Missing Modality
【ML】队列数据 Cohort Data

【ML】扩散薛定谔桥&扩散模型

Cross Domain Skeleton-based Action Recognition by Using Parameter-free Similarity Matching
【NIPS2017】Attention Is All You Need