type
status
date
slug
summary
tags
category
icon
password

原文摘要翻译

表格数据,即以行列形式组织的电子表格,在从生物医学到粒子物理、经济学和气候科学等多个科学领域中无处不在。基于其余列来填补标签列缺失值这一基本预测任务对于各种应用至关重要,这些应用范围广泛,包括生物医学风险模型、药物发现和材料科学。尽管深度学习已经在原始数据学习方面带来了革命性变化,并产生了众多瞩目的成功案例,但在过去20年中,梯度提升决策树在处理表格数据方面占据主导地位。本文介绍了一种名为 Tabular Prior-data Fitted Network(TabPFN)的表格基础模型,它在样本量不超过10,000的数据集上大幅超越了所有先前的方法,且训练时间显著减少。在2.8秒内,TabPFN在一个分类设置中的表现超过了经过4小时调优的强大基线模型组合。作为一个基于generative transformer的基础模型,该模型还支持微调、数据生成、密度估计和学习可复用嵌入。TabPFN是一种通过在数百万个合成数据集上学习而来的学习算法,展示了这种方法在算法开发中的强大能力。通过提高跨多个领域的建模能力,TabPFN有潜力加速科学研究并增强各个领域的重要决策制定。

速览表
T 目标
跨域能力强的的表型数据生成
I 输入(推理)
一个有部分单元格丢失的表格
P 处理(推理)
单次前向生成过程(模型架构总体而言类似Encoder-Only Transformer)
O 输出(推理)
恢复好的表格
P 问题
训练数据集量小;训练出的模型若不优化,性能开销过大
C 条件
应用于最多10000个样本*500个特征的表型数据
D 难点
在 表格数据中,相同的数值在不同数据集中可能代表完全不同的含义。这种高度专业化导致了大量小型、独立的数据集,每个数据集可能需要不同的模型——泛化能力差、难以迁移学习
L 水平
发表于《nature》正刊、“它在样本量不超过10,000的数据集上大幅超越了所有先前的方法,且训练时间显著减少。在2.8秒内,TabPFN在一个分类设置中的表现超过了经过4小时调优的强大基线模型组合。”

1. 功能介绍

总结:跨域能力强的表格数据修复模型(TabPFN)+用于生成表格数据的结构化因果模型(SCM)
  • TabPFN
TabPFN 使用起来很像用cahtGPT时先给一个prompt。给定一个有部分单元格缺失的表格,经过模型处理即可生成缺失数据,从而修复表格。
该TabPFN先是在数百万个各种表格数据上训练,得到一个基础模型。在使用该模型时,模型接受一个从未见过的有数据丢失的表格,输出一个修复了缺失数据的表格。
  • SCM
作者应该只是对前人工作进行了表型数据结构上的适配后,借用来生成表型数据的。作者在原文中给出了引用,而且没有任何实验证明SCM生成数据的质量,也没有公开代码。该文宣传的效果具有一定不可复现性。
  • TabPFN的扩展功能
TabPFN 还具备一些基础模型的核心能力,作者指出,能做到数据生成,密度估计、学习可复用的嵌入和微调功能。

2. 模型架构

这篇工作可以粗略的描述为“一个用大量表型数据预训练出的‘LLM’,利用‘prompt’来引导表型数据的生成。”
前半句话,预训练所使用的“大量表型数据”是由作者魔改SCM模型生成的,“LLM”则是一个基于Transformer-Encoder架构的模型,作者在其上做出了双向注意力机制的改动,随后进行了预训练(作者使用了8 张 NVIDIA RTX 2080 GPU,训练 2 周)。
后半句,“利用‘prompt’”指的是上下文学习,作者用这种方式来解决预训练模型跨域能力不做的问题。除此之外,还做了一些性能优化,是的上下文学习效率提升了很多很多倍(个人觉得作者提出翻了5000倍有点“米式对比法”)
文章的主要贡献是前半句话概括的内容。后半句的上下文学习,我觉得作者只在硬套概念,仅仅是解决了一些工程上的问题。用于生成预训练数据的SCM模型既没有实验,也没有公开代码。

双向注意力机制

Transformer 模型主要处理序列数据,并通过注意力机制 在序列项之间聚合信息,从而有效捕捉长程依赖关系,学习数据中的复杂模式。虽然你也可以把表型数据当序列交给transformer处理,也能跑,但是这显然无法利用表格数据的结构信息。
为了更好地利用表格数据的结构特性,作者提出了一种新的架构,该架构借鉴了 [22,28] 的研究成果,为表格中的每个单元格赋予独立的表示。如图 1b 所示,作者的架构采用双向注意力机制,其中:
  • 每个单元格 先关注 同一行的其他特征(即同一个样本的不同特征)。
  • 然后,它再关注 同一列的其他样本(即所有样本在该特征上的值)。
这种设计使得模型 对样本顺序和特征顺序具有不变性(permutation invariance),从而可以更高效地训练,并在推理时适应比训练时更大规模的表格数据(无论是样本数量还是特征数量)。
notion image
这种设计使得模型 对样本顺序和特征顺序具有不变性(permutation invariance),从而可以更高效地训练,并在推理时适应比训练时更大规模的表格数据(无论是样本数量还是特征数量)。

训练状态缓存优化机制

在这里,作者对模型推理性能做出了优化。为了理清思路,我们先来简单分一下类。
  1. 传统机器学习模型(如 XGBoost、随机森林、神经网络等)中,训练和推理是两个独立的阶段,这被该文作者称为“Fit-Predic”方式——训练好的模型可以对多个测试集进行推理
  1. Transformer 里的上下文学习方式(类似于平时用chatGPT时先给一个Prompt那样)。训练样本和测试样本是一起输入模型的,整个过程在单次前向传播中完成。上下文学习模式不会有模型的参数更新,模型通过注意力机制在训练数据和测试数据之间建立联系。(注意区分上下文学习的训练集与预训练大模型本身的训练集。)
    1. 但是存在这些缺点。对于通用大语言模型,其实影响不大,但对于表型数据来说,每一行都是一个样本,计算成本就会指数级增长了。
      • 每次推理时,都必须重新计算训练数据的影响
      • 无法缓存训练数据的计算结果,导致计算开销大。
      • 当测试样本较多时,计算成本会指数级增长
那TabPFN 如何优化 ICL 呢?
TabPFN 通过 训练状态缓存 解决了 ICL 计算开销过大的问题:
  1. 先在训练集上运行 ICL(例如图1a中绿底中的 ),并缓存中间计算状态
  1. 当有新的测试数据时,直接复用训练状态,而不需要重新计算训练数据的影响
  1. 这样,训练数据只计算一次,测试数据的推理速度大幅提升
个人推测,这里是或许是上下文学习中训练集计算出的K和V矩阵缓存起来了。

训练数据集来源

为何要把数据集来源放到模型架构里呢?那是因为该文章作者魔改出了一个基于结构化因果模型[31](structural causal models ,SCMs)的方法,用这个方法来生成大量表型数据,用于上面提到的TabPFN。下图2展示了这个生成模型的框架。
notion image
SCMs 提供了一种正式框架,用于表示数据背后的因果关系和生成过程。通过依赖合成数据而非大规模的公开表格数据集,作者可以避免基础模型常见的问题,比如说隐私和版权侵犯、训练数据与测试数据污染以及真实世界表格数据的获取限制
生成流程:
  1. 图2(a):作者的数据生成管道首先随机采样高层级超参数,包括:数据集大小、特征数量、任务难度之类的,这些超参数决定了每个合成数据集的整体特性。在这些超参数的指导下,构造一个有向无环图,用于指定该数据集的因果结构
  1. 图2(b)
    1. 初始化:
        • 在因果图的根节点处,作者注入随机噪声,这被称为“初始化数据”。
        • 这些初始化数据从随机正态分布或均匀分布中采样,并且样本之间可能具有不同程度的非独立性
    2. 数据传播与计算
        • 初始化数据沿着因果图传播,通过不同的计算映射进行变换,包括:
          • 小型神经网络,使用线性或非线性激活函数(如 Sigmoid、ReLU、取模运算、正弦函数等)。
          • 离散化机制,用于生成类别特征
          • 决策树结构,用于编码局部的规则依赖
        • 在因果图的每条边上,作者添加高斯噪声,以引入数据的不确定性。
        • 每个节点的中间数据表示都会被存储,以便后续检索和使用(详见 “计算边映射” 章节)。
    3. 数据集样本的最终提取
        • 经过因果图的计算后,作者从特征节点和目标变量节点提取最终的样本,形成包含特征值和相应目标值的样本
通过上述生成流程,作者构建了一个庞大的合成数据集库,用于 TabPFN 的训练。
  • 每次模型训练时,作者会生成大约 1 亿个(100 million)合成数据集
  • 这些数据集具有独特的因果结构、特征类型和函数特性,从而赋予 TabPFN 强大的泛化能力

3. 效果

  • 性能开销
    • 作者专门进行了性能开销优化
    • 每个单元格的存储需求小于 1,000 字节,单张 H100 GPU 可以处理包含 5,000 万个单元格的数据集

定型分析

表型数据的一大特点与难点便是混合数据类型。现实世界中表格数据通常由混合的离散和连续数据类型组成,而同时对离散列和连续列进行联合建模较为困难,可能无法捕捉到单个特征列的边缘分布——各种多峰分布、非高斯分布、长尾分布混杂。作者为了证明TabPFN模型建模能力的强大,直接把各种分布类型的数据搬了出来,train一发看看拟合效果(图3(a)):
notion image
橙色点是Ground Truth,蓝色点就是生成出来的点了。
  • 线性回归只能自然地建模线性函数,这使得其预测结果简单且可解释,但在许多玩具函数上会完全失效
  • 多层感知机高度非平滑模式的数据集上表现较差 ,这一问题在阶梯函数上尤为明显。
  • TabPFN 能够直接建模各种函数类型,无论是平滑还是非平滑,甚至可以很好地逼近阶梯函数,尽管 TabPFN 本身是一个神经网络。
  • CatBoost (代表基于树的方法) 只能拟合分段常数函数,这会导致逼近误差和非直观的预测,但至少避免了完全失效的情况。
  • TabPFN 相比所有基线方法的主要优势在于,它能够以零额外成本建模不确定性
    • 传统回归方法通常输出的是单一的实值预测,而 TabPFN 返回的是目标分布,从而能够表达预测的不确定性。
    • 这种不确定性建模能力不仅适用于简单分布,还能处理复杂的多峰分布
这个实验跑下来,证明他效果还是不错。图 3b 进一步展示了这一能力:作者使用 TabPFN 建模双缝实验中光照射到探测屏上的密度分布,并考察不同的缝间距和宽度对结果的影响。
在这个经典实验中,光子通过两条狭缝后,由于光的波动干涉行为,会产生多峰强度分布。TabPFN 仅需单次前向传播即可预测这些复杂模式,计算时间仅 1.2 秒
相比之下,传统方法(如 CatBoost)需要:
  • 训练多个分位数模型,在不同分位点上进行预测,
  • 然后重建整个目标分布,这一过程更加复杂且低效。
  • 即使经过专门针对该任务的调优,CatBoost 仍然比 TabPFN 预测效果更差(详见 图 3b)。
  • 默认设置下,CatBoost 需要 169.3 秒,并且结果进一步恶化。

定量分析

实验一:分类与回归任务,及其耗时评估

作者在两个数据集集合上对 TabPFN 进行了定量评估:AutoML Benchmark 和 OpenML-CTR233 。这些基准数据集涵盖了多种真实世界的表格数据,并经过筛选以确保复杂性、相关性和领域多样性。从这些基准数据集中,作者选取了:
  • 29 个分类数据集 和 28 个回归数据集
这里的分类与回归并不是指作者基于这两个数据集做分类与回归的下游任务。而是表格中,人工去掉离散类别数据(构建为分类任务)或者连续数值数据(构建为回归任务),模型随后恢复这些数据,再将恢复的数据与Ground Truth做对比得来。作者在文章中没有写明,而是写在Github的readme.md中的
  • 数据规模最高为 10,000 个样本、500 个特征和 10 个类别。
Metrics:
  • 分类任务:ROC AUC和准确率。
  • 回归任务:R²(决定系数)、负 RMSE(均方根误差)。
得分:
  • 1.0 代表最佳性能(相对于所有基线方法)。
  • 0.0 代表最差性能
实验设置
  • 每个数据集和方法都运行 10 次实验(不同的随机种子和 90% 训练集 / 10% 测试集划分)。
  • 超参数调优使用 随机搜索(random search)+ 五折交叉验证(five-fold cross-validation),时间预算从 30 秒到 4 小时 不等。
  • 所有方法均在 8 核 CPU 上运行,TabPFN 额外使用消费级 GPU(RTX 2080 Ti)(其他方法未从 GPU 受益,详见 扩展数据图 2d)。
  • TabPFN 仅需一次预训练(使用 8 张 NVIDIA RTX 2080 GPU,训练 2 周),之后就能在所有新数据集上 单次前向传播完成 上下文学习。
结果如图4所示
notion image

实验二:评估不同数据属性的影响&数据缺失情况

在 图 5a、5b 中,作者展示了 TabPFN 对于某些传统上难以处理的数据集特征的鲁棒性,这些特征通常会对基于神经网络的方法造成挑战 。
notion image
  • 图 5a 分析了 TabPFN 在不同类型数据集上的表现
      1. 无信息特征和 异常值处理能力:
          • 无信息特征 是指从原始数据集中随机打乱的特征,这些特征对预测任务没有贡献。
          • 异常值 是指以 2% 的概率 将数据单元格的值乘以一个介于 0 和异常值因子之间的随机数。
          • 结果表明,TabPFN 对无信息特征和异常值具有很强的鲁棒性,而神经网络(如 MLP)通常难以处理这些情况。
      1. 样本或特征丢失的影响:
          • 删除数据样本或特征会对所有方法的性能造成影响。
          • 即使只剩下一半的样本,TabPFN 的性能仍然与使用全部样本的次优方法相当,表现出较强的稳定性。
  • 图 5b 进一步分析了测试数据集的不同子组:
    • 作者基于以下特征对数据集进行子组划分,并对每个子组进行单独分析:
      • 是否包含类别特征。
      • 是否存在缺失值。
      • 数据集的样本数量。
      • 数据集的特征数量。
    • 样本数和特征数划分:为了均衡数据分布,作者将 1/3 的数据集划入每个子组。
    • 结果表明,这些数据特征并不会显著影响 TabPFN 相对于其他方法的表现。
    • 但需要注意,这些结果不能作为 TabPFN 在 10,000 个样本和 500 个特征以上数据集表现良好的直接证据,即其可扩展性仍需进一步测试。

实验三:与调优的集成方法对比

走着将 TabPFN 与 AutoGluon 1.0 进行了比较。AutoGluon 采用了一种集成学习方法,结合了多种机器学习模型(包括作者的基线方法),并且:
  1. 自动调优超参数,
  1. 使用后验集成(post hoc ensembling, PHE)[42,43] 生成最终预测。
由于 AutoGluon 代表的是不同类别的方法(集成方法),作者还探讨了 TabPFN 是否也可以通过集成学习进一步提升性能。为此,作者引入了一种新方法 —— TabPFN(PHE):
  • TabPFN(PHE)自动集成多个 TabPFN 模型,
  • 使用 PHE 进行最终预测,
  • 在搜索空间中随机选择超参数进行调优。
图 5c–5d 展示了 TabPFN、TabPFN(PHE)、AutoGluon 和 CatBoost 的性能对比。
  • TabPFN(PHE)和 AutoGluon 需要一定的调优时间,因此作者设定最低 300 秒的超参数调优预算,因为 AutoGluon 在更短时间内无法稳定返回结果。
  • TabPFN(默认配置)仅需 2.8 秒,其分类任务性能就超越了 AutoGluon,即使 AutoGluon 被允许运行 4 小时。
    • 加速比:5,140×(TabPFN 仅用 2.8 秒,而 AutoGluon 需要 4 小时)。
  • TabPFN(PHE)在分类任务中进一步提升了性能:
    • TabPFN(PHE) 的归一化 ROC AUC 为 0.971,
    • TabPFN(默认)为 0.939,
    • AutoGluon 为 0.914,
    • TabPFN(PHE)在分类任务上大幅超越 AutoGluon。
对于 回归任务,超参数调优显得更加重要:
  • TabPFN(PHE)在仅 300 秒的调优时间后,就超越了 AutoGluon(允许 4 小时运行)。
  • 加速比:48×(TabPFN(PHE)仅用 300 秒,而 AutoGluon 需要 4 小时)。

实验四:扩展功能效果

notion image
除了强大的预测性能之外,TabPFN 还具备一些基础模型的核心能力,例如数据生成、密度估计、学习可复用的嵌入和微调。
作者通过概念验证实验展示了这些能力,使用的数据集包括:
  • German Credit 数据集,包含信用风险信息。
  • mfeat-factors 数据集,用于基于表格表示的手写数字分类。
1. 数据生成与密度估计
TabPFN 可以估计数值特征的概率密度函数,并且可以估计类别特征的概率质量函数。计算样本密度有助于异常检测,可用于识别欺诈行为、设备故障、医疗紧急情况或低质量数据。此外,TabPFN 还能合成新的表格数据样本,模拟真实数据集的特征。这对于数据增强或隐私保护的数据共享具有重要意义。
2. 可复用的嵌入学习
TabPFN 的架构能够学习具有实际意义的特征表示,这些表示可以用于数据填补和聚类。在 mfeat-factors 数据集中,作者可视化了 TabPFN 学习到的嵌入。与原始数据相比,TabPFN 提取的特征在前两个主成分上具有更好的类别分离性,说明其学习到了更有意义的特征表示。
3. 通过微调提升性能
TabPFN 能够通过微调在相关数据集上提升性能。与基于树的方法不同,TabPFN 采用的是神经网络架构,因此可以针对特定类别的数据集进行微调。作者进行了概念验证实验:
  • 使用不同偏移量的正弦曲线数据集进行微调,并在测试数据上评估性能。
  • 作者发现,即使微调数据的标签与测试数据存在显著差异,TabPFN 仍然能成功迁移知识。
  • 当数据分布更相似时,微调的性能提升更明显。
潜在应用:
  • 例如,在医学研究中,TabPFN 可以通过微调来自多个医学数据集的模型,从而构建更强大的通用医学诊断模型。
4. TabPFN 的可解释性
作者还开发了一种方法,使 TabPFN 的预测更易解释。在高影响力领域部署模型时,可解释性对于建立信任和问责至关重要。如何计算特征重要性呢:
  • 作者使用 SHAP 方法来计算特征重要性。SHAP 是一种博弈论方法,用于解释模型的预测结果,SHAP 值可以表示每个特征对模型输出的贡献。
在实验中,作者比较了逻辑回归、CatBoost 和 TabPFN 的特征重要性及其影响:
  • TabPFN 既能达到高准确率,同时学习到简单、可解释的特征关系。
  • 逻辑回归可解释性较强,但准确率较低。
  • CatBoost 准确率较高,但由于决策边界复杂且不平滑,可解释性较差。

4. 代码

作者仅提供了TabPFN的代码,用于生成合成预训练数据的代码(SCM模型)并未随作者的模型一同发布。
TabPFN代码:Accurate predictions on small data with a tabular foundation model - Prior Labs
暂未尝试复现,但作者提供了完善的TabPFN生态系统
  • TabPFN Client:基于云端推理的易用 API 客户端。
  • TabPFN Extensions:社区扩展与集成。
  • TabPFN(本仓库):适用于本地部署与研究的核心实现。
  • TabPFN UX:无代码使用 TabPFN。
在CoLab中也有教程,从Github仓库现状看来,部署到实际项目中是可行的:
  • 自仓库见库以来,每周均有维护
  • 2.7k star & 已有128人调用该库
 
四轴III——飞控算法【tmux】打造使用终端
Loading...