多 GPU 与分布式环境#

在现实世界的 LLM 部署场景中,人们通常需要多 GPU 或者多节点的环境来支持较大的模型, 与/或完成高吞吐量的推理任务。 FlagGems 通过允许跨多个 GPU 完成算子执行加速来支持这类使用场景。

1. 单节点与多节点用法#

对于单节点部署而言,集成工作是相对简单直接的。 你可以在你的代码开始部分 import flag_gems 之后调用 flag_gems.enable()。 无需其他的变更,你就可以获得算子加速的效果。

然而,在多节点部署环境中,这种方法是不够的。 分布式的推理框架(例如 vLLM)需要跨多个节点来启动多个工作进程, 每个进程都需要独立初始化 flag_gems。 如果对 flag_gems 的启用或激活操作仅发生在第一个节点的启动代码上, 其他节点上的工作进程会回退为默认的算子实现,也就是没有被加速过的版本。

2. 示例:与 vLLM 和 DeepSeek 集成#

要在一个分布式的 vLLM + DeepSeek 部署环境中启用 FlagGems,需执行以下步骤:

  1. 基线检验

    在开始此实验之前,请先检查模型在没有与 FlagGems 集成的前提下可以正常启动并提供服务。 例如,加载类似于 Deepseek-R1 这类模型通常需要至少 2 块 H100 GPU 卡, 并且其初始化过程可能需要将近 20 分钟的时间,取决于检查点的大小、系统 I/O 的带宽与延迟。

  2. flag_gems 注入到 vLLM 工作进程代码中

    基于你所使用的 vLLM 版本,找到模型运行脚本的位置:

    • 如果所使用的是vLLM v1 架构(在 vLLM ≥ 0.8 环境中可用), 要修改 vllm/v1/worker/gpu_model_runner.py 文件;
    • 如果所使用的是老的 v0 架构, 则需要修改 vllm/worker/model_runner.py 文件。

    打开所找到的文件,在最后一行 import 语句之后插入如下逻辑:

    import os
    if os.getenv("USE_FLAGGEMS", "false").lower() in ("1", "true", "yes"):
        try:
            import flag_gems
            flag_gems.enable()
            flag_gems.apply_gems_patches_to_vllm(verbose=True)
            logger.info("Successfully enabled flag_gems as default ops implementation.")
        except ImportError:
            logger.warning("Failed to import 'flag_gems', falling back to default implementation.")
        except Exception as e:
            logger.warning(f"Failed to enable 'flag_gems': {e}, falling back to default implementation.")
  3. 在所有节点上设置环境变量

    在启动服务之前,要确保所有节点上都设置了下面的环境变量:

    export USE_FLAGGEMS=1
  4. 启动分布式推理服务并检验运行状态

    启动分布式推理服务,在所有节点上检查启动日志,搜索表明算子已经被覆盖的消息。

    Overriding a previously registered kernel for the same operator and the same dispatch key
    operator: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
      registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
    dispatch key: CUDA
    previous kernel: registered at /pytorch/aten/src/ATen/....
         new kernel: registered at /dev/null:488 (Triggered internally at ....)
    self.m.impl(

    出现这类消息则意味着 flag_gems 已经被成功地跨多个节点启用了。