Simplifying parallel processing in Ray with ray_map and ray_starmap
Ray has a nice way to parallelize function invocations:
ray.get([f.remote(x) for x in iterable])
I find this cleaner than the built-in option with concurrent.futures:
with concurrent.futures.ProcessPoolExecutor() as executor:
    results = list(executor.map(f, iterable))
However, I feel like the ultimate API for mapping with Ray's remote functions would be something like f.map(iterables). So, I've written a convenience function, ray_map, which gets us close: ray_map(f, iterables).
Similarly, I've written a convenience function, ray_starmap, that helps us process iterables of sequences: ray_starmap(f, iterable).
The code for ray_map and ray_starmap is here: https://github.com/crypdick/ray-map
Key Features of ray_map and ray_starmap:
- Parallel Execution: Automatically distributes tasks across all available CPUs in the cluster.
- Order Control: Choose whether to receive results in the order of input or as they complete (set order_outputs=False for completion order).
- Reduced Boilerplate: Cleaner code than vanilla Ray.
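To make the two ordering modes concrete, here is a minimal standard-library analogue of the interface described above. It is a hypothetical sketch, not the code from the ray-map repo: pool_map is an illustrative name, and a concurrent.futures.ThreadPoolExecutor stands in for Ray workers. In Ray itself, the completion-order branch would more likely be driven by ray.wait than by as_completed.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Iterable, Iterator, Optional


def pool_map(
    f: Callable[..., Any],
    *iterables: Iterable[Any],
    kwargs: Optional[dict] = None,
    order_outputs: bool = True,
    max_workers: Optional[int] = None,
) -> Iterator[Any]:
    """Hypothetical stand-in for ray_map: threads instead of Ray tasks."""
    kwargs = kwargs or {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every call up front, zipping the input iterables together.
        futures = [executor.submit(f, *args, **kwargs) for args in zip(*iterables)]
        if order_outputs:
            # Input order: wait on each future in submission order.
            for fut in futures:
                yield fut.result()
        else:
            # Completion order: yield whichever call finishes first.
            for fut in as_completed(futures):
                yield fut.result()


def power(x, exp=2):
    return x ** exp


print(list(pool_map(power, [1, 2, 3], kwargs={"exp": 3})))  # [1, 8, 27]
print(sorted(pool_map(power, [1, 2, 3], order_outputs=False)))  # [1, 4, 9]
```

With order_outputs=False the arrival order is not fixed, which is why the second example sorts before printing.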
How to use ray_map
Using ray_map is straightforward:
import ray
# ray_map itself comes from the ray-map repo linked above
@ray.remote
def square(x):
    return x * x
results = list(ray_map(square, [1, 2, 3]))
print(results) # Output: [1, 4, 9]
Handling Multiple Arguments and Keyword Arguments
ray_map also supports functions with multiple arguments and keyword arguments:
@ray.remote
def add(x, y):
    return x + y
# Multiple arguments
results = list(ray_map(add, [1, 2, 3], [4, 5, 6]))
print(results) # Output: [5, 7, 9]
# With keyword arguments
@ray.remote
def power(x, exp=2):
    return x ** exp
results = list(ray_map(power, [1, 2, 3], kwargs={'exp': 3}))
print(results) # Output: [1, 8, 27]
Notice that in the add example, the input iterables got zipped together, like [f.remote(*args, **kwargs) for args in zip(*input_iterators)].
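The zip expression above can be checked in plain Python. This sketch replaces f.remote with an ordinary local call, so it only shows which argument tuples get built, not any parallelism. One consequence of zip is worth knowing: it stops at the shortest input, so if ray_map zips exactly as described, inputs of mismatched length would be silently truncated rather than raise an error (an inference from the expression, not something verified against the repo).

```python
# Local stand-ins for the remote functions used in this post.
def add(x, y):
    return x + y


def power(x, exp=2):
    return x ** exp


# The argument tuples that zip(*input_iterators) builds for the add example.
input_iterators = ([1, 2, 3], [4, 5, 6])
arg_tuples = list(zip(*input_iterators))
print(arg_tuples)  # [(1, 4), (2, 5), (3, 6)]

# Sequential version of [f.remote(*args, **kwargs) for args in zip(...)].
print([add(*args) for args in arg_tuples])  # [5, 7, 9]

# The same kwargs dict is applied to every call.
kwargs = {"exp": 3}
print([power(*args, **kwargs) for args in zip([1, 2, 3])])  # [1, 8, 27]

# zip truncates to the shortest input: only two pairs are built here.
print(list(zip([1, 2, 3], [4, 5])))  # [(1, 4), (2, 5)]
```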
How to use ray_starmap
Sometimes, you want to process a single iterable of sequences. This is where ray_starmap comes in handy. It works similarly to itertools.starmap from Python's standard library.
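For reference, this is the sequential standard-library behavior that ray_starmap mirrors. itertools.starmap unpacks each tuple into positional arguments, so starmap(f, zip(xs, ys)) produces the same values as map(f, xs, ys). Judging from the two core expressions given in this post, ray_starmap over pre-zipped tuples should likewise match ray_map over separate iterables. operator.add is used here as a stand-in for the add function from the examples.

```python
from itertools import starmap
from operator import add  # add(a, b) returns a + b, like the add examples here

pairs = [(1, 4), (2, 5), (3, 6)]
# starmap unpacks each tuple: add(1, 4), add(2, 5), add(3, 6).
print(list(starmap(add, pairs)))  # [5, 7, 9]

# starmap over zipped iterables gives the same values as map over separate ones.
xs, ys = [1, 2, 3], [4, 5, 6]
print(list(starmap(add, zip(xs, ys))) == list(map(add, xs, ys)))  # True
```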
How ray_starmap Works
ray_starmap takes a Ray remote function and an iterable of argument sequences. It executes the function in parallel and yields the results; as with ray_map, order_outputs controls whether they arrive in input order or as they complete.
Here's an example:
@ray.remote
def add(x, y):
    return x + y
# Using ray_starmap with a single iterable of sequences
results = list(ray_starmap(add, [(1, 4), (2, 5), (3, 6)]))
print(results) # Output: [5, 7, 9]
At its core, this is like [f.remote(*args, **kwargs) for args in input_iterator].
Flexibility with Keyword Arguments
Just like ray_map, ray_starmap also supports keyword arguments:
@ray.remote
def power(x, exp=2):
    return x ** exp
# Using ray_starmap with keyword arguments
results = list(ray_starmap(power, [(1,), (2,), (3,)], kwargs={'exp': 3}))
print(results) # Output: [1, 8, 27]
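Note the trailing commas in [(1,), (2,), (3,)]: each element must itself be a sequence of arguments, and (1) without the comma is just the integer 1. The sketch below shows the same gotcha with the standard-library itertools.starmap, using functools.partial in place of kwargs because starmap has no kwargs parameter; cube is an illustrative name. That ray_starmap fails the same way on bare integers is an assumption based on the *args unpacking shown above.

```python
from functools import partial
from itertools import starmap


def power(x, exp=2):
    return x ** exp


# cube is an illustrative helper: partial fixes exp=3, like kwargs={'exp': 3}.
cube = partial(power, exp=3)

print(type((1)), type((1,)))  # <class 'int'> <class 'tuple'>
print(list(starmap(cube, [(1,), (2,), (3,)])))  # [1, 8, 27]

# Without the trailing commas each element is an int, which cannot be unpacked.
try:
    list(starmap(cube, [(1), (2), (3)]))
    raised = False
except TypeError as exc:
    raised = True
    print("TypeError:", exc)
```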
Performance considerations
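- If you are running your results through multiple transforms, avoid calling ray.get until the end of all your transformations. Passing ObjectRefs straight into the next remote function keeps intermediate results in Ray's object store instead of pulling them back to the driver.
- Processing tasks in submission order hurts performance, because smaller tasks might finish earlier than larger tasks submitted before them, and their results then sit waiting. Use order_outputs=False when the order of results doesn't matter.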
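The cost of submission-order processing can be seen without Ray. In this standard-library sketch, threads stand in for Ray workers and time.sleep stands in for uneven task sizes; work, DURATIONS, and time_to_first_result are illustrative names. The first task is the slowest, so consuming results in submission order makes the caller wait for it before seeing anything, while concurrent.futures.as_completed hands back the quick results first, which is the behavior order_outputs=False is meant to provide.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple

# Illustrative workload: task 0 is large, tasks 1-3 are small.
DURATIONS = [0.5, 0.05, 0.05, 0.05]


def work(task_id: int, seconds: float) -> int:
    time.sleep(seconds)  # simulate a task of the given size
    return task_id


def time_to_first_result(ordered: bool) -> Tuple[int, float]:
    """Return (task_id, seconds elapsed) for the first result the caller receives."""
    with ThreadPoolExecutor(max_workers=len(DURATIONS)) as executor:
        start = time.perf_counter()
        futures = [executor.submit(work, i, d) for i, d in enumerate(DURATIONS)]
        # Submission order walks the futures list; as_completed follows finish order.
        stream = iter(futures) if ordered else as_completed(futures)
        first = next(stream).result()
        return first, time.perf_counter() - start


for ordered in (True, False):
    task_id, elapsed = time_to_first_result(ordered)
    print(f"ordered={ordered}: first result from task {task_id} after {elapsed:.2f}s")
```

In submission order the first result is always task 0 after roughly half a second; in completion order one of the small tasks arrives almost immediately.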