Skip to content

vllm.benchmarks.iterations

Batch iteration benchmark for precise prefill/decode phase measurement.

On the server side, run:

    vllm serve --profiler-config.profiler torch \
        --profiler-config.torch_profiler_dir /path/to/traces

On the client side, run:

    # Prefill benchmark: measure prefill of 8K new tokens (no existing context)
    vllm bench iterations \
        --endpoints 127.0.0.1:8000 \
        --input-len 8192 \
        --batch-size 1 \
        --mode prefill \
        --profile \
        --model <your_model>

# Prefill benchmark: measure prefill of 2K new tokens against 4K existing context
vllm bench iterations \
    --endpoints 127.0.0.1:8000 \
    --context-len 4096 \
    --input-len 2048 \
    --batch-size 1 \
    --mode prefill \
    --profile \
    --model <your_model>

# Decode benchmark: warmup with 8K context, measure 128 decode iterations
vllm bench iterations \
    --endpoints 127.0.0.1:8000 \
    --context-len 8192 \
    --batch-size 64 \
    --mode decode \
    --iterations 128 \
    --profile \
    --model <your_model>

This benchmark uses sleep(level=0) to pause scheduling, queues requests, then resumes scheduling to measure precise batch execution times.

Prefix Cache Warmup

Before each benchmark run, the client sends warmup requests with context_len tokens to populate the prefix cache. The benchmark requests share this prefix, so the server can skip prefilling the context portion (prefix cache hit).

Modes

prefill: First warms up prefix cache with context_len tokens. Then measures prefill of input_len NEW tokens against existing context. Total prompt = context_len + input_len tokens. context_len=0 is valid (clean prefill of new input only).

decode: First warms up prefix cache with context_len tokens. Then measures decode throughput for --iterations output tokens. The benchmark prompt matches the warmup (full prefix cache hit), so we measure ONLY decode latency, not prefill. context_len > 0 is REQUIRED (cannot decode without context).

Batch Size Semantics

--batch-size specifies the batch size PER DP domain, matching the standalone benchmark (fbcode) semantics. The client automatically queries the server's DP configuration and multiplies to get the global batch size.

Example: With DP=8 and --batch-size 64, the client sends 64*8=512 total requests distributed round-robin across all DP ranks.

NOTE: For accurate prefill benchmarks, do NOT use --enable-chunked-prefill on the server. Chunked prefill breaks long prefills into multiple steps, which interferes with measuring true prefill performance.

BenchmarkConfig dataclass

Configuration for the iterations benchmark.

Source code in vllm/benchmarks/iterations.py
@dataclass
class BenchmarkConfig:
    """Configuration for the iterations benchmark.

    Holds the full parameter sweep: run_benchmark measures every
    combination of context_lens x input_lens x batch_sizes.
    """

    # Server endpoints to benchmark; requests are distributed round-robin.
    endpoints: list[str]
    # Context lengths (tokens) used for prefix-cache warmup (sweep values).
    context_lens: list[int]
    # New input lengths (tokens) measured in prefill mode (sweep values).
    input_lens: list[int]
    # Per-DP-domain batch sizes; multiplied by server DP size at runtime.
    batch_sizes: list[int]
    mode: str  # "prefill" or "decode"
    # Number of decode tokens to generate per request (decode mode only).
    iterations: int
    # Whether to start/stop server-side profiling and download traces.
    profile: bool
    # Model name sent in completion requests.
    model: str

EndpointRotator

Round-robin endpoint selection (matches disagg_benchmarks pattern).

Source code in vllm/benchmarks/iterations.py
class EndpointRotator:
    """Round-robin endpoint selection (matches disagg_benchmarks pattern)."""

    def __init__(self, endpoints: list[str]):
        """Normalize endpoints and prepare an infinite round-robin cycle."""
        self.endpoints = [self._normalize(e) for e in endpoints]
        self.cycle = itertools.cycle(self.endpoints)

    def _normalize(self, endpoint: str) -> str:
        """Ensure endpoint has http:// prefix and no trailing slash."""
        # Bug fix: previously the trailing slash was stripped only when the
        # endpoint already carried a scheme, so "host:8000/" normalized to
        # "http://host:8000/" and URL joins like f"{ep}{path}" produced a
        # double slash. Strip the slash on both paths.
        if not endpoint.startswith(("http://", "https://")):
            endpoint = f"http://{endpoint}"
        return endpoint.rstrip("/")

    def next(self) -> str:
        """Return the next endpoint in round-robin order."""
        return next(self.cycle)

    def all(self) -> list[str]:
        """Return all normalized endpoints."""
        return self.endpoints

_normalize

_normalize(endpoint: str) -> str

Ensure endpoint has http:// prefix.

Source code in vllm/benchmarks/iterations.py
def _normalize(self, endpoint: str) -> str:
    """Ensure endpoint has http:// prefix."""
    if not endpoint.startswith(("http://", "https://")):
        return f"http://{endpoint}"
    return endpoint.rstrip("/")

IterationResult dataclass

Result of a single benchmark iteration.

Source code in vllm/benchmarks/iterations.py
@dataclass
class IterationResult:
    """Result of a single benchmark iteration."""

    # Comma-joined list of all endpoints the batch was spread across.
    endpoint: str
    # "prefill" or "decode".
    mode: str
    # Context length (tokens) used for prefix-cache warmup.
    context_len: int
    # New input tokens measured (prefill mode).
    input_len: int
    # Per-DP-domain batch size (not the global batch size).
    batch_size: int
    # Number of output tokens per request (1 for prefill, else iterations).
    iteration: int
    # Wall-clock time for the whole batch, milliseconds.
    total_latency_ms: float
    # total_latency_ms divided by the number of output tokens.
    latency_per_iter_ms: float
    # (prompt + completion tokens) / elapsed seconds.
    tokens_per_second: float
    # Summed prompt tokens reported by the server across the batch.
    prompt_tokens: int
    # Summed completion tokens reported by the server across the batch.
    completion_tokens: int

ServerConfig dataclass

Server parallelism configuration.

Source code in vllm/benchmarks/iterations.py
@dataclass
class ServerConfig:
    """Server parallelism configuration.

    Populated from the server's /debug/config endpoint; all fields fall
    back to 1 when the endpoint is unavailable.
    """

    # Data-parallel size; multiplies the per-DP batch size.
    data_parallel_size: int = 1
    # Tensor-parallel size (informational, printed in the summary).
    tensor_parallel_size: int = 1
    # Pipeline-parallel size (informational, printed in the summary).
    pipeline_parallel_size: int = 1
    # Total world size reported by the server.
    world_size: int = 1

add_cli_args

add_cli_args(parser: ArgumentParser) -> None

Add CLI arguments for the iterations benchmark.

Source code in vllm/benchmarks/iterations.py
def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """Add CLI arguments for the iterations benchmark."""
    # (flag, add_argument kwargs) pairs, registered in declaration order.
    specs: list[tuple[str, dict]] = [
        (
            "--endpoints",
            {
                "type": str,
                "required": True,
                "help": "Comma-separated list of endpoints (e.g., host1:8000,host2:8000)",
            },
        ),
        (
            "--context-len",
            {
                "type": str,
                "default": "0",
                "help": "Prompt tokens to prefill KV cache before decode (decode mode only, "
                "comma-separated for sweep, e.g., 512,1024,2048)",
            },
        ),
        (
            "--input-len",
            {
                "type": str,
                "default": "128",
                "help": "Prompt tokens to measure prefill latency (prefill mode only, "
                "comma-separated for sweep, e.g., 128,256,512)",
            },
        ),
        (
            "--batch-size",
            {
                "type": str,
                "default": "1",
                "help": "Batch size per DP domain, comma-separated for sweep (e.g., 1,4,8). "
                "Automatically multiplied by server DP size to get global batch size.",
            },
        ),
        (
            "--mode",
            {
                "type": str,
                "choices": ["prefill", "decode"],
                "required": True,
                "help": "Benchmark mode: prefill or decode",
            },
        ),
        (
            "--iterations",
            {
                "type": int,
                "default": 128,
                "help": "Number of decode tokens to generate (decode mode only, default: 128)",
            },
        ),
        (
            "--profile",
            {
                "action": "store_true",
                "help": "Enable GPU profiling and download traces",
            },
        ),
        (
            "--model",
            {
                "type": str,
                "required": True,
                "help": "Model name to use for requests",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "default": None,
                "help": "Output JSON file for results",
            },
        ),
    ]
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)

build_prompts

build_prompts(
    context_len: int, input_len: int, mode: str
) -> tuple[str | None, str]

Build context and benchmark prompts for a parameter combination.

Returns (context_prompt, benchmark_prompt) where:

- context_prompt: Used for prefix cache warmup (None if context_len <= 0)
- benchmark_prompt: Used for the actual benchmark run

For prefill mode

context_prompt = context_len tokens ("hello " repeated)
benchmark_prompt = context_len + input_len tokens

The first context_len tokens match context_prompt (prefix cache hit). We measure prefill of the remaining input_len new tokens.

For decode mode

context_prompt = context_len tokens
benchmark_prompt = same as context_prompt (full prefix cache hit)

We measure only decode iterations (no prefill work).

Source code in vllm/benchmarks/iterations.py
def build_prompts(
    context_len: int, input_len: int, mode: str
) -> tuple[str | None, str]:
    """Build context and benchmark prompts for a parameter combination.

    Returns (context_prompt, benchmark_prompt) where:
    - context_prompt: Used for prefix cache warmup (None if context_len <= 0)
    - benchmark_prompt: Used for the actual benchmark run

    For prefill mode:
        context_prompt = context_len tokens ("hello " repeated)
        benchmark_prompt = context_len + input_len tokens
        The first context_len tokens match context_prompt (prefix cache hit).
        We measure prefill of the remaining input_len new tokens.

    For decode mode:
        context_prompt = context_len tokens
        benchmark_prompt = same as context_prompt (full prefix cache hit)
        We measure only decode iterations (no prefill work).
    """
    # Build context portion ("hello " is roughly 1-2 tokens per word)
    context_words = context_len // 2
    context_part = "hello " * max(1, context_words) if context_len > 0 else ""

    # Context prompt for prefix cache warmup
    context_prompt = context_part if context_len > 0 else None

    # Build benchmark prompt
    if mode == "prefill":
        # Add new input tokens after context (these will be prefilled)
        input_words = input_len // 2
        input_part = "world " * max(1, input_words)
        benchmark_prompt = context_part + input_part
    else:
        # Decode: same as context (full prefix cache hit, no prefill)
        benchmark_prompt = context_part

    return context_prompt, benchmark_prompt

call_debug_endpoint async

call_debug_endpoint(
    session: ClientSession,
    rotator: EndpointRotator,
    path: str,
    params: dict | None = None,
) -> bool

Call debug endpoint on ALL endpoints (for sleep/wake_up/profile).

Returns True if all calls succeeded, False if any failed.

Source code in vllm/benchmarks/iterations.py
async def call_debug_endpoint(
    session: aiohttp.ClientSession,
    rotator: EndpointRotator,
    path: str,
    params: dict | None = None,
) -> bool:
    """Call debug endpoint on ALL endpoints (for sleep/wake_up/profile).

    Returns True if all calls succeeded, False if any failed.
    """
    endpoints = rotator.all()
    # Fire the POSTs concurrently; exceptions are captured per endpoint
    # instead of aborting the whole gather.
    posts = [session.post(f"{ep}{path}", params=params) for ep in endpoints]
    outcomes = await asyncio.gather(*posts, return_exceptions=True)

    ok = True
    for ep, outcome in zip(endpoints, outcomes):
        if isinstance(outcome, Exception):
            logger.warning("Failed to call %s%s: %s", ep, path, outcome)
            ok = False
            continue
        # Drain the body so the connection can be reused.
        body = await outcome.read()
        if outcome.status >= 400:
            logger.warning(
                "HTTP %d from %s%s: %s",
                outcome.status,
                ep,
                path,
                body.decode()[:200],
            )
            ok = False
    return ok

count_tokens

count_tokens(response_data: dict) -> tuple[int, int]

Extract token counts from completion response.

Source code in vllm/benchmarks/iterations.py
def count_tokens(response_data: dict) -> tuple[int, int]:
    """Extract token counts from completion response.

    Returns (prompt_tokens, completion_tokens); missing fields default to 0.
    """
    usage = response_data.get("usage", {})
    prompt = usage.get("prompt_tokens", 0)
    completion = usage.get("completion_tokens", 0)
    return prompt, completion

fetch_server_config async

fetch_server_config(
    session: ClientSession, rotator: EndpointRotator
) -> ServerConfig

Fetch server parallelism config from first endpoint.

Source code in vllm/benchmarks/iterations.py
async def fetch_server_config(
    session: aiohttp.ClientSession,
    rotator: EndpointRotator,
) -> ServerConfig:
    """Fetch server parallelism config from first endpoint.

    Best effort: any failure (network error, non-200, bad payload) falls
    back to a default ServerConfig with all sizes = 1.
    """
    first = rotator.all()[0]
    try:
        resp = await session.get(f"{first}/debug/config")
        if resp.status == 200:
            payload = await resp.json()
            fields = {
                name: payload.get(name, 1)
                for name in (
                    "data_parallel_size",
                    "tensor_parallel_size",
                    "pipeline_parallel_size",
                    "world_size",
                )
            }
            return ServerConfig(**fields)
    except Exception as e:
        logger.warning("Failed to fetch server config: %s", e)
    return ServerConfig()

fetch_traces async

fetch_traces(
    session: ClientSession,
    rotator: EndpointRotator,
    prefix: str,
    output_dir: str,
) -> list[str]

Download trace files from all endpoints.

Source code in vllm/benchmarks/iterations.py
async def fetch_traces(
    session: aiohttp.ClientSession,
    rotator: EndpointRotator,
    prefix: str,
    output_dir: str,
) -> list[str]:
    """Download trace files from all endpoints.

    Lists /debug/traces on each endpoint, downloads every file whose name
    contains `prefix`, and saves it as endpoint<i>_<name> under output_dir.
    Returns the local paths of all files written.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved: list[str] = []

    for idx, endpoint in enumerate(rotator.all()):
        try:
            listing_resp = await session.get(f"{endpoint}/debug/traces")
            listing = await listing_resp.json()

            matching = [t for t in listing.get("traces", []) if prefix in t]
            for trace_file in matching:
                trace_resp = await session.get(
                    f"{endpoint}/debug/traces/{trace_file}"
                )
                # Prefix the endpoint index so same-named traces from
                # different servers do not overwrite each other.
                local_path = os.path.join(output_dir, f"endpoint{idx}_{trace_file}")
                with open(local_path, "wb") as f:
                    f.write(await trace_resp.read())
                logger.info("Downloaded: %s", local_path)
                saved.append(local_path)
        except Exception as e:
            # Best effort per endpoint; keep going for the others.
            logger.warning("Failed to fetch traces from %s: %s", endpoint, e)

    return saved

parse_comma_list

parse_comma_list(value: str) -> list[int]

Parse comma-separated integers.

Source code in vllm/benchmarks/iterations.py
def parse_comma_list(value: str) -> list[int]:
    """Parse comma-separated integers (whitespace around items is ignored)."""
    pieces = value.split(",")
    return [int(piece.strip()) for piece in pieces]

print_results_summary

print_results_summary(
    results: list[IterationResult],
    server_config: ServerConfig | None = None,
) -> None

Print a summary of benchmark results.

Source code in vllm/benchmarks/iterations.py
def print_results_summary(
    results: list[IterationResult],
    server_config: ServerConfig | None = None,
) -> None:
    """Print a summary of benchmark results.

    Emits a fixed-width table of one row per result; prepends the server
    parallelism line when server_config is provided.
    """
    if not results:
        logger.warning("No results to summarize")
        return

    rule = "=" * 110
    thin = "-" * 110

    print("\n" + rule)
    print("BENCHMARK RESULTS SUMMARY")
    print(rule)

    if server_config:
        print(
            f"Server: DP={server_config.data_parallel_size}, "
            f"TP={server_config.tensor_parallel_size}, "
            f"PP={server_config.pipeline_parallel_size}, "
            f"World={server_config.world_size}"
        )
        print(thin)

    header = (
        f"{'Mode':>8} {'Context':>8} {'Input':>8} {'Batch':>6} {'Iters':>6} "
        f"{'Total (ms)':>12} {'Per-Iter (ms)':>14} {'Tokens/s':>12}"
    )
    print(header)
    print(thin)

    for row in results:
        print(
            f"{row.mode:>8} {row.context_len:>8} {row.input_len:>8} "
            f"{row.batch_size:>6} {row.iteration:>6} "
            f"{row.total_latency_ms:>12.2f} {row.latency_per_iter_ms:>14.2f} "
            f"{row.tokens_per_second:>12.2f}"
        )

    print(rule + "\n")

run_benchmark async

run_benchmark(
    config: BenchmarkConfig,
) -> tuple[list[IterationResult], ServerConfig]

Main benchmark loop with parameter sweeping.

Source code in vllm/benchmarks/iterations.py
async def run_benchmark(
    config: BenchmarkConfig,
) -> tuple[list[IterationResult], ServerConfig]:
    """Main benchmark loop with parameter sweeping.

    For every (context_len, input_len, batch_size) combination in ``config``:

    1. Warm the prefix cache with the context prompt (not profiled).
    2. Optionally start server-side profiling for this combination.
    3. Run one timed batch via run_single_iteration and record the result.
    4. Stop profiling for this combination.

    After the sweep, profiler traces (if any) are downloaded into ./traces/.

    Returns:
        (results, server_config): one IterationResult per parameter
        combination, plus the parallelism config fetched from the server.
    """

    rotator = EndpointRotator(config.endpoints)
    results: list[IterationResult] = []

    # Unbounded connection pool: every request in a batch needs its own
    # concurrent connection while the server's scheduler is paused.
    connector = aiohttp.TCPConnector(
        limit=0,  # No limit
        ttl_dns_cache=300,
        keepalive_timeout=60,
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        # Fetch server config once
        server_config = await fetch_server_config(session, rotator)
        dp_size = server_config.data_parallel_size
        logger.info(
            "Server config: DP=%d, TP=%d, PP=%d",
            dp_size,
            server_config.tensor_parallel_size,
            server_config.pipeline_parallel_size,
        )

        # Warmup: trigger runtime compilation before benchmarking
        await run_compilation_warmup(session, rotator, config.model)

        # Sweep all parameter combinations
        param_combos = list(
            itertools.product(
                config.context_lens,
                config.input_lens,
                config.batch_sizes,
            )
        )

        logger.info(
            "Running %d parameter combinations",
            len(param_combos),
        )

        # For prefill: 1 output token. For decode: config.iterations tokens.
        num_output_tokens = 1 if config.mode == "prefill" else config.iterations

        # Track all trace prefixes for fetching at the end
        trace_prefixes: list[str] = []

        for ctx_len, in_len, batch_size_per_dp in param_combos:
            # Scale batch size by DP to match standalone benchmark semantics
            # User specifies per-DP batch size, we compute global batch size
            global_batch_size = batch_size_per_dp * dp_size

            # Build all prompts for this parameter combination
            context_prompt, benchmark_prompt = build_prompts(
                ctx_len, in_len, config.mode
            )

            logger.info(
                "Running: mode=%s, ctx=%d, input=%d, batch=%d/dp (global=%d), "
                "output_tokens=%d",
                config.mode,
                ctx_len,
                in_len,
                batch_size_per_dp,
                global_batch_size,
                num_output_tokens,
            )

            # Prefix cache warmup: populate KV cache before profiling
            # This is NOT profiled - we only profile the actual benchmark
            await run_prefix_cache_warmup(
                session, rotator, config.model, context_prompt, global_batch_size
            )

            # Start profiling for this param combo (after prefix cache warmup)
            # Use batch_size_per_dp in trace prefix for consistency with SA naming
            trace_prefix = None
            if config.profile:
                trace_prefix = (
                    f"{config.mode}_ctx{ctx_len}_in{in_len}_bs{batch_size_per_dp}"
                )
                started = await call_debug_endpoint(
                    session, rotator, "/debug/profile/start", {"prefix": trace_prefix}
                )
                if started:
                    trace_prefixes.append(trace_prefix)
                else:
                    logger.warning("Failed to start profiling for %s", trace_prefix)

            (
                elapsed_ms,
                prompt_tokens,
                completion_tokens,
            ) = await run_single_iteration(
                session,
                config,
                rotator,
                benchmark_prompt,
                global_batch_size,
            )

            # Stop profiling for this param combo
            # NOTE(review): stop is attempted even when start failed above
            # (trace_prefix stays set on failure); harmless but confirm intended.
            if config.profile and trace_prefix:
                await call_debug_endpoint(session, rotator, "/debug/profile/stop")

            total_tokens = prompt_tokens + completion_tokens
            tokens_per_second = (
                total_tokens / (elapsed_ms / 1000) if elapsed_ms > 0 else 0
            )
            latency_per_iter = (
                elapsed_ms / num_output_tokens if num_output_tokens > 0 else elapsed_ms
            )

            # Record per-DP batch size for comparison with standalone benchmark
            results.append(
                IterationResult(
                    endpoint=",".join(rotator.all()),
                    mode=config.mode,
                    context_len=ctx_len,
                    input_len=in_len,
                    batch_size=batch_size_per_dp,
                    iteration=num_output_tokens,
                    total_latency_ms=elapsed_ms,
                    latency_per_iter_ms=latency_per_iter,
                    tokens_per_second=tokens_per_second,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
            )

            logger.info(
                "  Result: %.2fms total, %.2fms/iter, %.2f tok/s",
                elapsed_ms,
                latency_per_iter,
                tokens_per_second,
            )

        # Fetch traces for all param combos
        if config.profile and trace_prefixes:
            logger.info("Fetching traces for %d runs...", len(trace_prefixes))

            # Retry fetching traces (async write by torch profiler)
            max_retries = 3
            all_downloaded: list[str] = []
            for attempt in range(max_retries):
                await asyncio.sleep(2.0)
                for prefix in trace_prefixes:
                    downloaded = await fetch_traces(session, rotator, prefix, "traces")
                    all_downloaded.extend(downloaded)
                if all_downloaded:
                    logger.info(
                        "Downloaded %d trace files to ./traces/", len(all_downloaded)
                    )
                    break
                logger.info(
                    "No traces yet, retrying (%d/%d)...", attempt + 1, max_retries
                )

            if not all_downloaded:
                logger.warning(
                    "No trace files found after %d attempts. "
                    "Check server profiler directory.",
                    max_retries,
                )

    return results, server_config

run_compilation_warmup async

run_compilation_warmup(
    session: ClientSession,
    rotator: EndpointRotator,
    model: str,
) -> None

Send a warmup request to trigger runtime compilation.

Source code in vllm/benchmarks/iterations.py
async def run_compilation_warmup(
    session: aiohttp.ClientSession,
    rotator: EndpointRotator,
    model: str,
) -> None:
    """Send a warmup request to trigger runtime compilation.

    Best effort: failures are logged and never raised, since the benchmark
    can proceed without the warmup.
    """
    first = rotator.all()[0]
    logger.info("Sending warmup request to trigger compilation...")
    payload = {
        "model": model,
        "prompt": "Hello",
        "max_tokens": 1,
        "stream": False,
    }
    try:
        resp = await session.post(f"{first}/v1/completions", json=payload)
        if resp.status == 200:
            # Consume the body before declaring success.
            await resp.json()
            logger.info("Compilation warmup complete")
        else:
            logger.warning("Warmup request failed: HTTP %d", resp.status)
    except Exception as e:
        logger.warning("Warmup request failed: %s", e)

run_prefix_cache_warmup async

run_prefix_cache_warmup(
    session: ClientSession,
    rotator: EndpointRotator,
    model: str,
    context_prompt: str | None,
    batch_size: int,
) -> None

Populate prefix cache with context tokens before benchmarking.

Sends batch_size requests with context_prompt to populate the prefix cache. The benchmark requests will share this prefix.

Source code in vllm/benchmarks/iterations.py
async def run_prefix_cache_warmup(
    session: aiohttp.ClientSession,
    rotator: EndpointRotator,
    model: str,
    context_prompt: str | None,
    batch_size: int,
) -> None:
    """Populate prefix cache with context tokens before benchmarking.

    Sends batch_size requests with context_prompt to populate the
    prefix cache. The benchmark requests will share this prefix.
    """
    # No context portion -> nothing to warm up.
    if context_prompt is None:
        return

    logger.info("Populating prefix cache...")

    body = {
        "model": model,
        "prompt": context_prompt,
        "max_tokens": 1,
        "stream": False,
    }
    # Round-robin the warmup requests across every endpoint; ensure_future
    # starts them immediately rather than leaving bare coroutines.
    in_flight = [
        asyncio.ensure_future(
            session.post(f"{rotator.next()}/v1/completions", json=body)
        )
        for _ in range(batch_size)
    ]

    outcomes = await asyncio.gather(*in_flight, return_exceptions=True)

    succeeded = 0
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            continue
        if outcome.status == 200:
            succeeded += 1
        # Drain the body so the connection can be reused.
        await outcome.read()

    logger.info(
        "Prefix cache warmup: %d/%d requests succeeded", succeeded, batch_size
    )

run_single_iteration async

run_single_iteration(
    session: ClientSession,
    config: BenchmarkConfig,
    rotator: EndpointRotator,
    benchmark_prompt: str,
    batch_size: int,
) -> tuple[float, int, int]

Run one iteration: sleep → queue requests → wake → measure.

Source code in vllm/benchmarks/iterations.py
async def run_single_iteration(
    session: aiohttp.ClientSession,
    config: BenchmarkConfig,
    rotator: EndpointRotator,
    benchmark_prompt: str,
    batch_size: int,
) -> tuple[float, int, int]:
    """Run one iteration: sleep → queue requests → wake → measure.

    Pauses scheduling on every endpoint, queues ``batch_size`` completion
    requests (round-robin across endpoints), then resumes scheduling and
    times how long the whole batch takes to finish.

    Returns:
        (elapsed_ms, total_prompt_tokens, total_completion_tokens); the
        elapsed time spans from the wake-up call to the last response.
    """

    # 1. Pause scheduling on ALL endpoints
    await call_debug_endpoint(session, rotator, "/debug/sleep", {"level": "0"})

    # 2. Build requests and start sending them (they queue while server sleeps)
    # We use asyncio.ensure_future to actually start the requests immediately,
    # not just create coroutines. The requests will be sent to the server
    # and queue there while scheduling is paused.
    max_tokens = 1 if config.mode == "prefill" else config.iterations
    tasks = []
    for _ in range(batch_size):
        endpoint = rotator.next()

        # ensure_future schedules the coroutine immediately
        task = asyncio.ensure_future(
            session.post(
                f"{endpoint}/v1/completions",
                json={
                    "model": config.model,
                    "prompt": benchmark_prompt,
                    "max_tokens": max_tokens,
                    "stream": False,
                },
            )
        )
        tasks.append(task)

    # Small delay to ensure requests are queued on server
    await asyncio.sleep(0.1)

    # 3. Resume scheduling on ALL endpoints and time the batch
    # NOTE(review): gather() without return_exceptions=True — a single
    # failed request raises out of this function and skips token counting;
    # confirm fail-fast is the intended behavior for benchmark runs.
    start = time.perf_counter()
    await call_debug_endpoint(session, rotator, "/debug/wake_up")
    responses = await asyncio.gather(*tasks)
    elapsed_ms = (time.perf_counter() - start) * 1000

    # 4. Count tokens
    total_prompt_tokens = 0
    total_completion_tokens = 0
    for resp in responses:
        try:
            data = await resp.json()
            prompt_tokens, completion_tokens = count_tokens(data)
            total_prompt_tokens += prompt_tokens
            total_completion_tokens += completion_tokens
        except Exception as e:
            # Best effort: a malformed response only loses its token counts.
            logger.warning("Failed to parse response: %s", e)

    return elapsed_ms, total_prompt_tokens, total_completion_tokens

write_results_json

write_results_json(
    results: list[IterationResult], output_path: str
) -> None

Write results to JSON file.

Source code in vllm/benchmarks/iterations.py
def write_results_json(results: list[IterationResult], output_path: str) -> None:
    """Write results to JSON file.

    Serializes each IterationResult as a dict under "results" and stamps
    the payload with the current local time.
    """
    payload = {
        "results": [asdict(item) for item in results],
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)
    logger.info("Results written to: %s", output_path)