预测零售/金融/制造数据,TimesFM时间序列模型实战指南
at 2个月前 ca timesfm pv 163 by touch
TimesFM是Google Research开发的一个预训练时间序列基础模型,使用包含1000亿现实世界时间序列数据集进行了预训练,拥有2亿参数,该模型在各种现实世界的预测基准上展现出令人印象深刻的零样本性能。
零样本性能,指模型在没有接受过任何特定任务训练数据的情况下,对该任务的预测能力。
本文将通过使用TimesFM模型对月度进口普通化妆品备案数进行预测来介绍模型的部署和运行过程,以及如何解读最终的输出结果。
01
运行环境配置
操作系统:Ubuntu 22.04
# 克隆官方模型加载和推理仓库 git clone https://github.com/google-research/timesfm.git cd timesfm # 配置专用虚拟环境 conda env create -f environment_cpu.yml # 激活环境 conda activate tfm_env # 安装timesfm库 pip install -e .
import timesfm # 模型载入参数设置 tfm = timesfm.TimesFm( context_len=32, horizon_len=5, input_patch_len=32, output_patch_len=128, num_layers=20, model_dims=1280, backend="cpu", ) # 模型载入 tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
根据官方仓库说明,以下4个参数固定不变,以适配当前2亿参数的timesfm模型。
input_patch_len=32
output_patch_len=128
num_layers=20
model_dims=1280
context_len参数:设置上下文的长度,必须为input_patch_len参数的乘数,当前最大支持512,可选值包括:32、64、96、128、256、... 、512。另外,此参数设置的上下文长度不影响输入时间序列的上下文长度,输入时间序列可以为任意长度,在推理过程中代码会自动处理序列的填充或截断。
horizon_len参数:设置预测的范围,可以设置为任意值,但建议小于等于context_len的值。如果horizon_len的值为3,输入时间序列最后一条数据的日期为2023-12-31,时间频度为天(该参数在后面的推理代码中设置),那么模型就会预测2024-01-01、2024-01-02、2024-01-03这3天的结果。
backend参数:可选值 "cpu"、"gpu"、"tpu",初次尝试建议使用"cpu",其它选项安装依赖库时容易出现冲突的问题。
03
模型推理
# 输入时间序列数据 data = [ ['jkpt', '2022-01-31', 82], ['jkpt', '2022-02-28', 180], ['jkpt', '2022-03-31', 506], ['jkpt', '2022-04-30', 1036], ... 省略 ... ['jkpt', '2024-01-31', 490], ['jkpt', '2024-02-29', 615] ] # 将列表格式转为dataframe格式并设置每列的列名 df = pd.DataFrame(data, columns=['unique_id', 'ds', 'value']) # 将数据第二列时间数据转为datetime格式 df['ds'] = pd.to_datetime(df['ds']) # 进行预测 forecast_df = tfm.forecast_on_df( inputs=df, freq="M", value_name="value", num_jobs=1, ) # 显示结果 print(forecast_df)
data是输入的时间序列数据,通常是当前预测任务的历史数据。数据需包含3列,以下是各列的解释:
第一列,是数据id,用于区分不同的数据组。当前演示示例只有1组数据,所以第一列所有id都一样。本文后面会提及多组数据的情形并提供示例,例如预测股市时,有价格和成交量2组时序数据。
第二列,是数据的时间信息,在进行推理前需要将其转为datetime格式,否则会报错。
第三列,是数据的值,该演示示例写的是各月份的备案数。
此外,使用pd将数据转为dataframe格式时,第一列和第二列的列名必须设置为unique_id和ds。
freq参数:指时间频度,可选值包括T, MIN, H, D, B, U, W, M, Q, Y,T, MIN表示分钟,H, D, B, U, W, M, Q, Y分别表示小时、天、工作日、微秒、周、月、季度、年。虽然这里给出了10种可选项,但最终实际都会转换为0, 1, 2这3种值,分别对应高频、中频、低频,深入了解可查看仓库源码。
value_name参数:指需要预测数据那列的列名。
num_jobs参数:指用于数据帧处理的并行进程数,默认值为1。
04
输出结果解读
输出结果长下面这样:
一共有12列数据,前两列分别是数据id和时间信息,timesfm列即是预测的结果。
其中有9列,timesfm-q-0.1~timesfm-q-0.9,其中q指的是分位数(Quantile)。数据结果第一行,timesfm-q-0.1对应的值是694,表示模型预测中有10%的数据小于694,有90%的数据大于694。数据结果第一行,timesfm-q-0.9对应的值是1309,表示模型预测中有10%的数据小于1309,有90%的数据大于1309。综合以上,预测值会有80%的概率落在694~1309这个区间。
timesfm-q-0.5是中位数,timesfm列的结果和timesfm-q-0.5列一致。
本次示例使用2022-01至2024-02各月份的进口化妆品备案数作为输入时间序列,预测2022-03至2024-07各月的备案数,其中已有3月、4月的备案数据。通过matplotlib库,将预测结果绘制成图表如下:
data = [ ['price', '2022-01-01', 3001.16], ['price', '2022-01-02', 3080.83], ... 省略 ... ['price', '2024-05-31', 3080.83], ['volume', '2022-01-01', 302380768], ['volume', '2022-01-02', 289970336], ... 省略 ... ['volume', '2024-05-31', 289970336] ]
tfm.load_from_checkpoint(checkpoint_path="timesfm-1.0-200m/checkpoints")
07
environment_cpu.yml内容如下
name: tfm_env channels: - conda-forge - defaults - anaconda dependencies: - jupyterlab - pip - python=3.10 - pip: - huggingface_hub[cli] - utilsforecast - praxis - paxml - jax[cpu] - einshape
environment.yml内容如下
name: tfm_env channels: - conda-forge - defaults - anaconda dependencies: - jupyterlab - pip - python=3.10 - pip: - huggingface_hub[cli] - utilsforecast - praxis - paxml - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - "jax[cuda12_pip]" - einshape
conda安装教程:
版权声明
本文仅代表作者观点,不代表码农殇立场。
本文系作者授权码农殇发表,未经许可,不得转载。