torch显存分析——如何在不关闭进程的情况下释放显存

torch显存分析——如何在不关闭进程的情况下释放显存

1. 基本概念——allocator和block

2. torch.cuda的三大常用方法

3. 可以释放的显存

4. 无法释放的显存?

5. 清理“显存钉子户”

一直以来,对于torch的显存管理,我都没有特别注意,只是大概了解到,显存主要分为几个部分(cuda context、模型占用、数据占用),然而在接触大模型之后,遇到越来越多的显存合理利用的问题,尤其是利用大模型进行推理时,怎样规划好一个进程的显存占用,是一件非常重要的事情。

本文就近期针对torch显存管理的工作进行整理总结,主要目的就是解决一个问题——如何在不关闭进程的情况下释放显存。

1. 基本概念——allocator和block

首先需要了解两个基本概念,allocator与block。

Allocator是torch用来管理显存的工具,以下是chatgpt的解释:

在PyTorch中,allocator是用于动态分配内存的抽象接口。

PyTorch使用allocator来分配张量所需的内存,并使用该内存来存储张量的数据和元数据。

这使得PyTorch能够管理内存的使用,避免内存泄漏和浪费,并最大化系统的使用效率。

而block可以理解为显存中的若干分区,这些分区有大有小,torch将tensor从cpu移动到gpu上,实际上是将tensor移动到某个block上。

根据我的理解,可以将相关的要点总结如下:

从功能上讲,allocator是torch用来获取和管理block的工具,torch通过allocator从gpu获取到所需要的block,然后将所有获取到的block放在一个block pool中;

当需要将某个tensor放到gpu上时,会将其放在其中一个block上;

tensor不能分割开,放在不同的block,例如一个6Mb的tensor,会要求一个大于等于6Mb的block,而无法将其分散在2个4Mb的block上;

一般情况下,torch不会主动去释放掉block,当一个tensor不再使用时,其所占用的block仍然处在block pool中,此时查看进程所占用显存,不会出现下降;

当又有一个tensor需要放在gpu上时,会优先检查block pool中,是否存在可以放得下这个tensor的block,如果有,则有限使用这个block,如果没有,则allocator会再尝试向显卡申请其他block,如果显卡上也没有符合条件的空闲block,则程序就会报OOM;

可以利用torch.cuda.empty_cache方法,手动释放掉未被占用的block,但是会造成程序运行变慢。

2. torch.cuda的三大常用方法

我在学习torch的显存管理时,参考了这篇文章,其中很具体的介绍了torch显存管理的三个常用的方法,这里不再重复详细的介绍,仅将其作用简单介绍如下:

torch.cuda.memory_allocated():查看当前tensor占用的显存

torch.cuda.memory_reserved():查看进程占用的总共的显存

torch.cuda.empty_cache():释放掉未使用的缓存

除了参考文章中所介绍的三个常用方法,这里再补充另一个比较实用的方法,查看显存占用的方法:torch.cuda.memory_stats(),可以查看当前显存的更加具体的占用情况。

具体说明可以参考:https://pytorch.org/docs/1.13/generated/torch.cuda.memory_stats.html#torch.cuda.memory_stats

看起来一切都很合理,当我需要释放block pool中没有被使用到的block,还给gpu时,就调用torch.cuda.empty_cache()方法即可。但问题偏偏就出在这里,当我们执行这一行指令的时候,显存真的会像所想的那样被释放吗?

3. 可以释放的显存

为了分析和验证显存占用情况的机制,我做了一个简单的实验。

实验只考虑推理阶段,所以所有的代码是在torch.no_grad()模式下进行的,这种模式下不会保存中间变量和梯度,所以显存的占用=模型参数占用+输入数据占用+输出结果占用。

完成这个实验,只需要一个for循环即可,通过逐渐增加输入的长度,来观察显存的变换情况:

# 以chatGLM-6B为代表进行实验

# 用一个列表来存储每一个时刻的显存信息

points = []

for cur_len in tqdm(range(0, 6000, 10)):

# 输入序列的长度从0,10,20,...,一直增长到OOM为止

real_inputs = inputs['input_ids'][..., : cur_len, ...].to(model.device)

# 开始阶段记录两个数值,分别是将inputs放