Add dynamic batching, FP16, and /metrics to the mBERT API #177

Open

ajamous wants to merge 2 commits into main from feat/mbert-dynamic-batching

Conversation

ajamous (Collaborator) commented on Apr 15, 2026

Why

Recent GPU load tests on a g4dn.4xlarge (Tesla T4) showed the API serialises one HTTP request per GPU forward pass and pads every input to the full 512-token max length. The result: the GPU sits idle ~93% of the time, sustained throughput caps at ~5 MPS, and a 600 MPS burst takes ~80 minutes to drain — despite the hardware being capable of hundreds of MPS. See the test report in the accompanying issue/discussion for the full numbers.

This PR captures the highest-leverage optimisations identified in that analysis and makes them tunable via environment variables so the bridge team can dial them in per deployment.

What's in the PR

Dynamic batching

  • New DynamicBatcher in src/api_interface/services/batching_service.py
  • Coalesces concurrent single-message requests into padded batches (one tokenizer call, one forward pass, results split back to each caller's asyncio.Future), as sketched below
  • Configurable collection window (batch_wait_ms) and max_batch_size
  • In-memory metrics: request/batch counters, queue depth, batch-size histogram, inference time
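
For reference, a minimal sketch of the collect-then-flush pattern the batcher follows (the class name comes from this PR, but the constructor, method names, and the predict_batch callable here are illustrative rather than the actual batching_service.py API):

```python
import asyncio
import time


class DynamicBatcher:
    """Coalesce concurrent requests into one forward pass (sketch only)."""

    def __init__(self, predict_batch, max_batch_size: int = 32, batch_wait_ms: int = 15):
        self._predict_batch = predict_batch      # callable: list[str] -> list of results
        self._max_batch_size = max_batch_size
        self._batch_wait_ms = batch_wait_ms
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker: asyncio.Task | None = None

    def start(self) -> None:
        # Call from an async startup hook so a running event loop exists.
        self._worker = asyncio.create_task(self._run())

    async def submit(self, text: str):
        """Called once per HTTP request; resolves when the batch containing it finishes."""
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((text, future))
        return await future

    async def _run(self) -> None:
        while True:
            text, future = await self._queue.get()            # block until the first request
            texts, futures = [text], [future]
            deadline = time.monotonic() + self._batch_wait_ms / 1000
            while len(texts) < self._max_batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    text, future = await asyncio.wait_for(self._queue.get(), remaining)
                except asyncio.TimeoutError:
                    break                                      # window closed; flush partial batch
                texts.append(text)
                futures.append(future)
            try:
                # One tokenizer call + one forward pass for the whole batch,
                # run off the event loop so new requests keep queueing meanwhile.
                results = await asyncio.to_thread(self._predict_batch, texts)
                for fut, result in zip(futures, results):
                    fut.set_result(result)
            except Exception as exc:                           # a model error fails every waiter
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(exc)
```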

FP16 + sequence-length fixes

  • FP16 weights on CUDA (guarded; CPU/MPS keep FP32) for ~2× throughput on tensor-core GPUs; see the sketch below
  • max_text_length lowered from 512 → 96 with padding='longest' so short SMS no longer waste ~10× the FLOPs
  • torch.inference_mode() in place of torch.no_grad()
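
Roughly, the guarded FP16 path and the dynamic-padding tokenizer call look like the following (a sketch only; the checkpoint name, label handling, and function name are placeholders, not the project's model_loader / prediction_service code):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")   # placeholder checkpoint
model = AutoModelForSequenceClassification.from_pretrained("bert-base-multilingual-cased").to(device)

# Guarded FP16: only on CUDA, where tensor cores deliver the ~2x win; CPU/MPS stay FP32.
if device == "cuda":
    model = model.half()
model.eval()


def predict_batch(texts: list[str]) -> list[int]:
    # padding="longest" pads only to the longest text in this batch (capped at 96 tokens)
    # instead of always padding to the 512-token maximum.
    encoded = tokenizer(
        texts,
        padding="longest",
        truncation=True,
        max_length=96,
        return_tensors="pt",
    ).to(device)
    with torch.inference_mode():          # stricter and slightly cheaper than torch.no_grad()
        logits = model(**encoded).logits
    return logits.argmax(dim=-1).tolist()
```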

Observability

  • GET /metrics Prometheus-compatible endpoint (no extra dependency; hand-rolled exposition format, sketched below)
  • Series: ots_requests_total, ots_batches_total, ots_inference_seconds_total, ots_queue_depth, ots_last_batch_size, ots_batch_size_bucket{le="N"}, ots_api_info{device,fp16,max_text_length,version}
  • Designed for the ots-bridge to scrape and drive adaptive concurrency
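
A rough sketch of what the hand-rolled exposition looks like (illustrative; the actual routers/metrics.py reads the batcher's live counters and also emits the histogram buckets and the ots_api_info series):

```python
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

router = APIRouter()

# Illustrative in-memory counters; in the real service these come from the batcher.
METRICS = {"ots_requests_total": 0, "ots_batches_total": 0, "ots_queue_depth": 0}


@router.get("/metrics", response_class=PlainTextResponse)
def metrics() -> str:
    # Prometheus text exposition format: a "# TYPE name kind" line followed by
    # "name value" lines is enough for a scraper -- no prometheus_client needed.
    lines = [
        "# TYPE ots_requests_total counter",
        f"ots_requests_total {METRICS['ots_requests_total']}",
        "# TYPE ots_batches_total counter",
        f"ots_batches_total {METRICS['ots_batches_total']}",
        "# TYPE ots_queue_depth gauge",
        f"ots_queue_depth {METRICS['ots_queue_depth']}",
    ]
    return "\n".join(lines) + "\n"
```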

Config knobs (all OTS_-prefixed env vars)

| Setting | Default | Description |
| --- | --- | --- |
| batching_enabled | true | Master switch |
| max_batch_size | 32 | Max requests per forward pass |
| batch_wait_ms | 15 | Collection window (ms) |
| max_text_length | 96 | Token truncation length |
| use_fp16 | true | FP16 on CUDA only |
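
Assuming the settings module uses pydantic-settings (v2-style shown here; a pydantic v1 Config class with env_prefix would be equivalent), the mapping to OTS_-prefixed env vars is roughly:

```python
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    # Each field maps to an OTS_-prefixed env var, e.g. OTS_MAX_BATCH_SIZE=64.
    model_config = SettingsConfigDict(env_prefix="OTS_")

    batching_enabled: bool = True
    max_batch_size: int = 32
    batch_wait_ms: int = 15
    max_text_length: int = 96
    use_fp16: bool = True


settings = Settings()  # reads the environment on instantiation
```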

Tests

  • 10 new pytest tests (8 batcher + 2 metrics endpoint) — all passing
  • Exercise real async logic against a stub torch.nn.Module (no mBERT weights needed); an example is sketched below
  • Run locally: pytest src/api_interface/tests/ --asyncio-mode=auto
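
For flavour, a stub-model test in this style might look like the following (illustrative only; it reuses the DynamicBatcher sketch above, and the stub labels and asserted values are hypothetical):

```python
import asyncio

import pytest


def stub_predict_batch(texts):
    # Stand-in for the stub torch.nn.Module + tokenizer used by the real tests.
    return ["spam" if "win" in t.lower() else "ham" for t in texts]


@pytest.mark.asyncio
async def test_concurrent_requests_coalesce():
    # DynamicBatcher here refers to the sketch above; the real tests import it
    # from the batching service module.
    batcher = DynamicBatcher(stub_predict_batch, max_batch_size=32, batch_wait_ms=15)
    batcher.start()
    # Eight concurrent submissions should ride the same forward pass
    # and each caller should still get its own result back.
    results = await asyncio.gather(*(batcher.submit(f"win a prize {i}") for i in range(8)))
    assert results == ["spam"] * 8
```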

Docs

  • CLAUDE.md — new "Dynamic Batching" and "Observability: /metrics" sections
  • README.md — updated Performance section + Health Checks section

Test plan

  • Unit tests: 10/10 passing on CPU
  • FastAPI app imports cleanly with /metrics registered in route table
  • Single-request HTTP contract unchanged (/predict/ response shape identical)
  • Backward-compatible: OTS_BATCHING_ENABLED=false falls back to original per-request path
  • GPU benchmark against v2.8-amd64 baseline on a T4 (needs environment with mBERT weights + CUDA — cannot be done in CI)
  • Verify ots-bridge can raise its concurrency throttle (25 → 200+) against the new image

Backward compatibility

  • /predict/ request/response shape is unchanged
  • Per-request thread-pool fallback preserved for batching_enabled=false
  • No new runtime dependencies
  • Existing /health, /audit, /feedback, TMForum endpoints untouched

Expected performance impact (from analysis)

| Configuration | Est. per-batch GPU time (T4) | Effective MPS |
| --- | --- | --- |
| Today (batch=1, 512 tokens, FP32) | ~460 ms | ~5 MPS |
| After this PR (batch=32, 96 tokens, FP16) | ~20–40 ms | ~800–1500 MPS |

Numbers to be validated by the bridge team's benchmark run against the built image.

Files changed

  • New: src/api_interface/services/batching_service.py, src/api_interface/routers/metrics.py, src/api_interface/tests/{__init__,conftest,test_batching_service,test_metrics_endpoint}.py
  • Modified: src/api_interface/services/{model_loader,prediction_service}.py, src/api_interface/config/settings.py, src/api_interface/main.py, src/api_interface/routers/__init__.py, CLAUDE.md, README.md

claude added 2 commits on April 15, 2026 at 16:17

Under load, the API currently serialises one HTTP request per GPU forward
pass and pads every input to the full 512-token max length. Recent load
tests on a g4dn.4xlarge (Tesla T4) confirmed this leaves the GPU idle
~93% of the time and caps sustained throughput at ~5 MPS — a 600 MPS
burst takes ~80 minutes to drain even though the hardware can do
hundreds of MPS.

This change introduces:

- DynamicBatcher service that coalesces concurrent single-message
  requests into padded batches (configurable max size / wait window).
  One tokenizer call, one forward pass, results split back to each
  caller's asyncio Future.
- FP16 weights on CUDA for ~2x throughput on T4/A10/L4 tensor cores,
  guarded so CPU/MPS keep FP32.
- max_text_length lowered from 512 -> 96 with dynamic padding
  ('longest') so short SMS no longer waste ~10x the FLOPs.
- torch.inference_mode() in place of torch.no_grad() for a small but
  free speedup and cleaner semantics.
- /metrics Prometheus-compatible endpoint (no extra dep) exposing
  request/batch counters, queue depth, batch-size histogram, and
  inference time, so ots-bridge can drive adaptive concurrency.

All new knobs are env-var tunable: OTS_BATCHING_ENABLED,
OTS_MAX_BATCH_SIZE, OTS_BATCH_WAIT_MS, OTS_MAX_TEXT_LENGTH, OTS_USE_FP16.
Docs updated in CLAUDE.md and README.md.

Covers the async batching logic end-to-end against a stub model +
tokenizer (no mBERT weights required):

- single request returns the correct label
- 8 concurrent submissions coalesce into one batch
- max_batch_size is respected (10 requests => batches of <=4)
- partial batches flush after batch_wait_ms, not later
- model errors propagate to every future in the batch
- metrics counters increment correctly
- shutdown fails in-flight requests instead of hanging
- init_batcher honours OTS_BATCHING_ENABLED=false

The /metrics endpoint is exercised with FastAPI TestClient in both
the disabled-batcher and active-batcher states, asserting the
Prometheus exposition format (counters, gauges, histogram buckets).

Run:
    pytest src/api_interface/tests/ --asyncio-mode=auto