参数与显存

date
icon
password
Sub-item
Blocked by
Parent item
type
status
slug
summary
tags
category
Blocking

显存与参数的估算

推理

训练

训练时所需显存分为三大块
  1. 载入模型所需的显存
  1. 输入值、输出值、中间激活值所需显存
  1. 梯度+动量
以LLama 7b为例分析

显存与参数的估算

推理

推理时所需显存分为两大块
  1. 载入模型所需的显存
  1. 输入值、输出值、中间激活值所需显存

载入模型

载入模型时,需要把模型的权重、偏置等参数载入显存中。这部分显存通常较小,不会成为瓶颈。

输入输出和中间激活值

这部分显存需要存储输入、输出和中间激活值。这部分显存的大小与输入的形状有关,通常是模型参数大小的倍数。在前向传播过程中,这部分显存会逐渐释放,因此在推理过程中这部分显存是不会成为瓶颈的。

训练

训练时所需显存分为三大块
  1. 载入模型所需的显存
  1. 输入值、输出值、中间激活值所需显存
  1. 梯度+动量

载入模型

载入模型时,需要把模型的权重、偏置等参数载入显存中。这部分显存通常较小,不会成为瓶颈。

输入输出和中间激活值

这部分显存需要存储输入、输出和中间激活值。这部分显存的大小与输入的形状有关,通常是模型参数大小的倍数。在前向传播过程中,这部分显存会逐渐释放,因此在训练过程中这部分显存是不会成为瓶颈的。

梯度+动量

在反向传播过程中,需要存储梯度和动量。这部分显存的大小也与模型参数大小有关。在反向传播结束后,这部分显存会被释放。
以LLama 7b为例分析
LLama 7b是一种比较小的模型,它有7亿个参数。在训练时,载入模型所需的显存大约为200MB。输入、输出和中间激活值所需显存大约为1.8GB。梯度和动量所需显存大约为1.8GB。因此,LLama 7b在训练时需要约3.6GB的显存。
Reference
Reference
     
    ES算法原理Constitutional AI: Harmlessness from AI Feedback