How to submit many concurrent requests to Ray Serve
Ray Serve can dynamically batch incoming requests and process them in chunks. However, sequential requests.post()
calls block until each response arrives, so only one request is ever in flight and the server never accumulates a full batch.
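To see the problem, here is a minimal sketch of the anti-pattern (the URL and payload are placeholders matching the full example below):

import requests

# Anti-pattern: each post() blocks until its response arrives,
# so the server only ever sees one request at a time and
# @serve.batch has nothing to batch.
url = "http://127.0.0.1:8000/"
responses = [requests.post(url, json={"value": 1.0}).json() for _ in range(100)]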
Instead, you want many requests in flight at once using an asynchronous HTTP client. Here's how you can achieve this using aiohttp:
import asyncio
import time

import aiohttp
import numpy as np
import requests
from ray import serve
from ray.serve.handle import DeploymentHandle
from starlette.requests import Request

# Stand-in model: returns one random float per input item.
model = lambda x: np.random.rand(len(x))


@serve.deployment
class BatchedModel:
    def __init__(self):
        self.model = model

    @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.1)
    async def process_batch(self, input_data: list[dict]) -> list[float]:
        print(f"Processing batch of size: {len(input_data)}")
        results = self.model(input_data)
        # Return a plain list of floats so each caller gets a
        # JSON-serializable result (np.float64 is not).
        return results.tolist()

    async def __call__(self, request: Request):
        input_data = await request.json()
        # Route the request to the batch handler.
        return await self.process_batch(input_data)


def main():
    app = BatchedModel.bind()
    _handle: DeploymentHandle = serve.run(app, name="batched-model")

    # Simplified sample input
    sample_input = {"value": 1.0}
    url = "http://127.0.0.1:8000/"

    # --- Test with a single request ---
    print("\n--- Sending single request ---")
    start_time = time.time()
    prediction = requests.post(url, json=sample_input).json()
    end_time = time.time()
    print(f"Prediction: {prediction}")
    print(f"Time taken: {end_time - start_time:.4f}s")

    # --- Simulate many concurrent requests ---
    print("\n--- Sending 100 concurrent requests ---")
    sample_input_list = [sample_input] * 100

    async def fetch(session, url, data):
        async with session.post(url, json=data) as response:
            return await response.json()

    async def fetch_all():
        async with aiohttp.ClientSession() as session:
            tasks = [fetch(session, url, input_item) for input_item in sample_input_list]
            start_time_async = time.time()
            responses = await asyncio.gather(*tasks)
            end_time_async = time.time()
            print(f"Time taken for 100 requests: {end_time_async - start_time_async:.4f}s")
            return responses

    start_time_main = time.time()
    responses = asyncio.run(fetch_all())
    end_time_main = time.time()

    # Note: responses might vary depending on how requests are batched.
    print(f"First response: {responses[0]}")
    print(f"Total time (including client-side async setup): {end_time_main - start_time_main:.4f}s")


if __name__ == "__main__":
    main()
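If the client is Python code running alongside Ray, you can also skip HTTP entirely and issue concurrent calls through the DeploymentHandle returned by serve.run(). Below is a minimal sketch, assuming the BatchedModel deployment above is already running and that its process_batch method is called through the handle with one item per call:

import asyncio

async def run_with_handle(handle, inputs):
    # Each .remote() call returns a DeploymentResponse immediately;
    # awaiting them together lets @serve.batch fill batches on the replica.
    responses = [handle.process_batch.remote(item) for item in inputs]
    return await asyncio.gather(*responses)

# Usage (assumes `handle` is the DeploymentHandle from serve.run above):
# results = asyncio.run(run_with_handle(handle, [{"value": 1.0}] * 100))

This avoids HTTP and Starlette serialization overhead, but it ties the client to the Ray cluster; the aiohttp approach works from any machine that can reach the HTTP endpoint.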