Add dynamic batching, FP16, and /metrics to the mBERT API #177
Open

ajamous wants to merge 2 commits into
Conversation
Under load, the API currently serialises one HTTP request per GPU forward
pass and pads every input to the full 512-token max length. Recent load
tests on a g4dn.4xlarge (Tesla T4) confirmed this leaves the GPU idle
~93% of the time and caps sustained throughput at ~5 MPS (messages
per second); a burst arriving at 600 MPS takes ~80 minutes to drain
even though the hardware can do hundreds of MPS.
This change introduces:
- DynamicBatcher service that coalesces concurrent single-message
requests into padded batches (configurable max size / wait window).
One tokenizer call, one forward pass, results split back to each
caller's asyncio Future. A sketch of the loop follows this commit
message.
- FP16 weights on CUDA for ~2x throughput on T4/A10/L4 tensor cores,
guarded so CPU and MPS (Apple Metal) backends keep FP32.
- max_text_length lowered from 512 -> 96 with dynamic padding
('longest') so short SMS no longer waste ~10x the FLOPs.
- torch.inference_mode() in place of torch.no_grad() for a small but
free speedup and cleaner semantics.
- /metrics Prometheus-compatible endpoint (no extra dep) exposing
request/batch counters, queue depth, batch-size histogram, and
inference time, so ots-bridge can drive adaptive concurrency.
All new knobs are env-var tunable: OTS_BATCHING_ENABLED,
OTS_MAX_BATCH_SIZE, OTS_BATCH_WAIT_MS, OTS_MAX_TEXT_LENGTH, OTS_USE_FP16.
Docs updated in CLAUDE.md and README.md.
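
For reviewers, here is a minimal sketch of the coalescing loop, using the knob names from this PR. It is illustrative only: the real `DynamicBatcher` in `batching_service.py` also wires in metrics and graceful shutdown, and `predict_batch` below is a stand-in for the tokenizer + model call.

```python
import asyncio
from dataclasses import dataclass, field

@dataclass
class _Pending:
    text: str
    future: asyncio.Future = field(default_factory=asyncio.Future)

class DynamicBatcher:
    """Coalesce concurrent single-text requests into one forward pass."""

    def __init__(self, predict_batch, max_batch_size=32, batch_wait_ms=15):
        # Must be constructed inside a running event loop (e.g. app startup).
        self._predict_batch = predict_batch      # list[str] -> list[result]
        self._max_batch_size = max_batch_size
        self._batch_wait_s = batch_wait_ms / 1000
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker = asyncio.create_task(self._run())

    async def submit(self, text: str):
        item = _Pending(text)
        await self._queue.put(item)
        return await item.future                 # resolves when its batch completes

    async def _run(self):
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]    # block until at least one request
            deadline = loop.time() + self._batch_wait_s
            # Fill until the wait window closes or the batch is full.
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            texts = [p.text for p in batch]
            try:
                # One tokenizer call, one forward pass for the whole batch.
                results = await asyncio.to_thread(self._predict_batch, texts)
                for pending, result in zip(batch, results):
                    pending.future.set_result(result)
            except Exception as exc:
                for pending in batch:            # errors propagate to every future
                    if not pending.future.done():
                        pending.future.set_exception(exc)
```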
Covers the async batching logic end-to-end against a stub model +
tokenizer (no mBERT weights required; an illustrative test sketch
follows the run command below):
- single request returns the correct label
- 8 concurrent submissions coalesce into one batch
- max_batch_size is respected (10 requests => batches of <=4)
- partial batches flush after batch_wait_ms, not later
- model errors propagate to every future in the batch
- metrics counters increment correctly
- shutdown fails in-flight requests instead of hanging
- init_batcher honours OTS_BATCHING_ENABLED=false
The /metrics endpoint is exercised with FastAPI TestClient in both
the disabled-batcher and active-batcher states, asserting the
Prometheus exposition format (counters, gauges, histogram buckets).
Run:
pytest src/api_interface/tests/ --asyncio-mode=auto
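
For flavour, a sketch of how the coalescing assertion can be written with pytest-asyncio. The test name and stub are illustrative and assume the `DynamicBatcher` sketch above, not the exact contents of `test_batching_service.py`:

```python
import asyncio
import pytest

# Assumes the DynamicBatcher sketch from the previous commit message
# is importable; the real fixtures live in conftest.py.

@pytest.mark.asyncio
async def test_concurrent_submissions_coalesce_into_one_batch():
    batch_sizes = []

    def stub_predict(texts):                   # stands in for tokenizer + mBERT
        batch_sizes.append(len(texts))
        return [f"label-for-{t}" for t in texts]

    batcher = DynamicBatcher(stub_predict, max_batch_size=32, batch_wait_ms=15)
    results = await asyncio.gather(*(batcher.submit(f"sms-{i}") for i in range(8)))

    assert results == [f"label-for-sms-{i}" for i in range(8)]
    assert batch_sizes == [8]                  # all 8 coalesced into one forward pass
```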
Why
Recent GPU load tests on a g4dn.4xlarge (Tesla T4) showed the API serialises one HTTP request per GPU forward pass and pads every input to the full 512-token max length. The result: the GPU sits idle ~93% of the time, sustained throughput caps at ~5 MPS, and a 600 MPS burst takes ~80 minutes to drain — despite the hardware being capable of hundreds of MPS. See the test report in the accompanying issue/discussion for the full numbers.

This PR captures the highest-leverage optimisations identified in that analysis and makes them tunable via environment variables so the bridge team can dial them in per deployment.
What's in the PR
Dynamic batching

- `DynamicBatcher` in `src/api_interface/services/batching_service.py` coalesces concurrent single-message requests into padded batches; each caller gets its result back through an `asyncio.Future`
- Flushes on whichever comes first: the wait window (`batch_wait_ms`) or a full batch (`max_batch_size`); caller-side wiring sketched below
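
Caller-side, the handoff is just an `await`. A sketch of the wiring — the route body, `PredictRequest`, and `predict_single` here are simplified placeholders, not the PR's exact code:

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PredictRequest(BaseModel):
    text: str

batcher = None  # set at startup by init_batcher() when OTS_BATCHING_ENABLED=true

def predict_single(text: str) -> str:
    raise NotImplementedError("placeholder for the original per-request path")

@app.post("/predict/")
async def predict(req: PredictRequest):
    if batcher is not None:
        # Coalesced with concurrent callers: one tokenizer call, one forward pass.
        return {"label": await batcher.submit(req.text)}
    return {"label": predict_single(req.text)}  # original per-request path
```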
FP16 + sequence-length fixes

- `max_text_length` lowered from 512 → 96 with `padding='longest'` so short SMS no longer waste ~10× the FLOPs
- `torch.inference_mode()` in place of `torch.no_grad()` (both sketched below)
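
Roughly what the guarded load and dynamic padding look like. Function names and the `AutoModel*` classes are placeholders for illustration; the real changes live in `model_loader.py` and `prediction_service.py`:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MAX_TEXT_LENGTH = 96  # was 512

def load_model(model_dir: str, use_fp16: bool = True):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
    if use_fp16 and device == "cuda":
        model = model.half()          # tensor cores on T4/A10/L4; CPU/MPS stay FP32
    return model.eval(), device

def predict_batch(model, tokenizer, texts, device):
    # padding='longest' pads to the longest text in *this* batch, capped at 96,
    # instead of always padding to the 512-token model max.
    enc = tokenizer(texts, padding="longest", truncation=True,
                    max_length=MAX_TEXT_LENGTH, return_tensors="pt").to(device)
    with torch.inference_mode():      # replaces torch.no_grad()
        logits = model(**enc).logits
    return logits.argmax(dim=-1).tolist()
```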
Observability

- `GET /metrics`: Prometheus-compatible endpoint (no extra dep — hand-rolled exposition format; sketch below)
- Exposes `ots_requests_total`, `ots_batches_total`, `ots_inference_seconds_total`, `ots_queue_depth`, `ots_last_batch_size`, `ots_batch_size_bucket{le="N"}`, `ots_api_info{device,fp16,max_text_length,version}`
- Intended for `ots-bridge` to scrape and drive adaptive concurrency
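
The exposition format is plain newline-separated text, which is why no client library is needed. A toy sketch of the rendering — the in-memory counters and bucket edges here are made up for illustration; the real registry lives alongside the batcher:

```python
from fastapi import APIRouter, Response

router = APIRouter()

# Illustrative in-memory metrics; real values come from the batching service.
METRICS = {"ots_requests_total": 0, "ots_batches_total": 0, "ots_queue_depth": 0}
BUCKETS = [1, 2, 4, 8, 16, 32]                 # assumed batch-size histogram edges
BUCKET_COUNTS = {le: 0 for le in BUCKETS}

@router.get("/metrics")
def metrics() -> Response:
    lines = []
    for name, value in METRICS.items():
        kind = "gauge" if name == "ots_queue_depth" else "counter"
        lines.append(f"# TYPE {name} {kind}")
        lines.append(f"{name} {value}")
    for le, count in BUCKET_COUNTS.items():
        lines.append(f'ots_batch_size_bucket{{le="{le}"}} {count}')
    # Prometheus text format, served as text/plain.
    return Response("\n".join(lines) + "\n", media_type="text/plain; version=0.0.4")
```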
Config knobs (all `OTS_`-prefixed env vars)

| Setting | Default |
| --- | --- |
| `batching_enabled` | `true` |
| `max_batch_size` | `32` |
| `batch_wait_ms` | `15` |
| `max_text_length` | `96` |
| `use_fp16` | `true` |
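
One plausible mapping of the `OTS_` prefix onto these defaults, sketched with pydantic-settings; the repo's actual `settings.py` may declare them differently:

```python
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    # OTS_BATCHING_ENABLED, OTS_MAX_BATCH_SIZE, ... map onto these fields.
    model_config = SettingsConfigDict(env_prefix="OTS_")

    batching_enabled: bool = True
    max_batch_size: int = 32
    batch_wait_ms: int = 15
    max_text_length: int = 96
    use_fp16: bool = True

settings = Settings()  # e.g. OTS_BATCH_WAIT_MS=5 overrides batch_wait_ms
```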
Tests

- Async batching logic covered end-to-end against a stub `torch.nn.Module` (no mBERT weights needed)
- Run: `pytest src/api_interface/tests/ --asyncio-mode=auto`
Docs

- `CLAUDE.md` — new "Dynamic Batching" and "Observability: /metrics" sections
- `README.md` — updated Performance section + Health Checks section
Test plan

- `/metrics` registered in route table
- `/predict/` response shape identical
- `OTS_BATCHING_ENABLED=false` falls back to original per-request path
- Benchmark against the `v2.8-amd64` baseline on a T4 (needs environment with mBERT weights + CUDA — cannot be done in CI)
- Verify `ots-bridge` can raise its concurrency throttle (25 → 200+) against the new image
Backward compatibility

- `/predict/` request/response shape is unchanged
- Batching can be switched off entirely with `batching_enabled=false`
- `/health`, `/audit`, `/feedback`, TMForum endpoints untouched

Expected performance impact (from analysis)
Numbers to be validated by the bridge team's benchmark run against the built image.
Files changed

New:

- `src/api_interface/services/batching_service.py`
- `src/api_interface/routers/metrics.py`
- `src/api_interface/tests/{__init__,conftest,test_batching_service,test_metrics_endpoint}.py`

Modified:

- `src/api_interface/services/{model_loader,prediction_service}.py`
- `src/api_interface/config/settings.py`
- `src/api_interface/main.py`
- `src/api_interface/routers/__init__.py`
- `CLAUDE.md`, `README.md`