
こんにちは。テックラボの高橋です。
pytorchにtorch.compileという機能があることをご存知でしょうか?
torch 2.0から導入されたこの機能を利用することで、推論処理や学習処理を高速化できるとのことです。
今回はNVIDIA A100を用いて、torch.compileがどのくらい効果があるか検証してみました。
環境
- pytorch 2.6
- GPU NVIDIA A100 80G
- ubuntu 20.04.6
- nvidia-docker 24.0.9-1
- モデル tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3
torch.compileとは
はじめに、torch.compileについて簡単に説明します。

torch.compileは実行中にモデルを最適化し高速化を図る技術です。
上図中にあるTorchDynamo、TorchInductor、Tritonについて各処理を見ていきます。
torch.compileを実行すると、まずTorchDynamoがPythonバイトコードをFXグラフと呼ばれる計算グラフに変換します。
この際の内部構造はdypyfというライブラリを用いると可視化するこができます。
例えば以下のような処理があるとします。
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
この関数をコンパイルしたものをdepyfで可視化すると、個々のグラフのモジュールが人間に読みやすい形式で、コメント付きで生成されます。 グラフモジュールのひとつを見てみましょう。
from __future__ import annotations
import torch
class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10]", L_b_: "f32[10]"):
l_a_ = L_a_
l_b_ = L_b_
# File: /workspace/check_depyf.py:8 in toy_example, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.abs(l_a_)
add: "f32[10]" = abs_1 + 1; abs_1 = None
x: "f32[10]" = l_a_ / add; l_a_ = add = None
# File: /workspace/check_depyf.py:9 in toy_example, code: if b.sum() < 0:
sum_1: "f32[]" = l_b_.sum(); l_b_ = None
lt: "b8[]" = sum_1 < 0; sum_1 = None
return (x, lt)
生成されたコメントを読むと、元の関数の以下の箇所がモジュール化されていることがわかります。
x = a / (torch.abs(a) + 1) if b.sum() < 0:
その後、このFXグラフをTorchInductorが最適化し、TritonがNVIDIAで処理を実行するための関数であるCUDAカーネルを生成します。
実装
torch.compileのチュートリアルを参考に、以下のようなコードを実行してみます。
import os
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
device = "cuda"
ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model.generation_config.max_length = 128
prompts = ["なぜ犬はこんなにもかわいいの?"] * 10
# without torch.compile
timings_without_compile = []
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
_, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id))
print(elapsed_time)
timings_without_compile.append(elapsed_time)
torch._dynamo.reset()
# compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"
# with torch.compile
torch.compiler.cudagraph_mark_step_begin()
timings_with_compile = []
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
_, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id))
print(elapsed_time)
timings_with_compile.append(elapsed_time)
結果は以下のようになります。

橙色がコンパイル有りの結果ですが、最初の2回だけコンパイルのために時間がかかっています。
そこで、下記のようにして最初の2回を抜いた実行時間の平均を取ると、
print(f"{np.array(timings_without_compile[2:]).mean()=}")
print(f"{np.array(timings_with_compile[2:]).mean()=}")
コンパイル無しの平均値: 3.12秒
コンパイル有りの平均値: 1.79秒
となりました。1.7倍程度のスピードアップですね。
上記は同じプロンプトを10回実行していましたが、 試しにpromptsの内容を以下のように10種類に変更してみます。
prompts = [
"なぜ犬はこんなにもかわいいの?",
"なぜ猫は夜行性なの?",
"宇宙の果てには何があるの?",
"人間はなぜ夢を見るの?",
"海の深さはどれくらい?",
"なぜ空は青いの?",
"恐竜はどのようにして絶滅したの?",
"なぜ音楽を聴くと感動するの?",
"人類はいつ火星に住むことができるの?",
"なぜ植物は光合成をするの?",
"どうして鳥は飛ぶの?"
]
結果はこちらです。

コンパイル有りの場合、文章を変更すると再コンパイルが走る場合があるようです。
vLLMとの比較
vLLMはLLMを高速に実行するためのライブラリです。 PagedAttentionという仕組み等を利用して高速に推論処理を行うことができます。
こちらを以下のコードで実行時間を計測してみます。
import os
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
device = "cuda"
ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
prompts = ["なぜ犬はこんなにもかわいいの?"] * 10
model = LLM(
model=ckpt,
tokenizer=ckpt,
dtype="float16"
)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_length=128
)
timings_with_vllm = []
for prompt in prompts:
_, elapsed_time = timed(lambda: model.generate(prompt, sampling_params))
print(elapsed_time)
timings_with_vllm.append(elapsed_time)
処理の平均値は1.61秒になりました。
デフォルトのtorch、torch.compileと比較すると以下のようになります。

GPU:NVIDIA A100 80G CUDA:12.6 pytorch:2.6 モデル:Llama-3.1-Swallow-8B-Instruct-v0.3 量子化:float16 試行:同一プロンプト x 10回
若干vLLMのほうが早いみたいですね。
今回はtorch.compile側がtorch2.6をほぼデフォルトパラメータで実行しており、FlashAttention等を無効化していません。 また、vLLMもパラメータ最適化は行っていないので、設定によってかなり結果は変わりうると思われます。
他の方の記事等を読むと、vLLMがより高速になる例があるようですので、 あくまでこのコード・環境での場合の値ということでご了承いただければと思います。
おわりに
本記事ではNVIDIA A100でのtorch.compileの検証と、vLLMとの比較を行ってみました。
vLLMなどの推論フレームワークのベンチマークを取った"LLM-Inference-Bench: Inference Benchmarking of Large Language Models on AI Accelerators"という論文によると、バッチ推論において特にvLLMは性能が良さそうです。他の条件についても、引き続き調査していきたいと思います。
参考
PyTorch 2.0の新機能「torch.compile」使ってみた - まったり勉強ノート
Introduction to torch.compile — PyTorch Tutorials 2.6.0+cu124 documentation