Local vs. global task parallelism in Ray

By default, ray.remote tasks are executed with global parallelism: Ray may schedule them on any CPU available in the cluster. Sometimes it is better to run tasks with local parallelism instead, pinning them to the node they were launched from. A typical example is when your tasks generate large objects and you want to avoid costly, unnecessary network I/O. This is why other distributed frameworks such as Spark offer both local and global flavors of their primitives, e.g. combiners and reducers.
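
To see where the default scheduler places work, you can have each task report the node it ran on. The snippet below is a small illustrative sketch (which_node is just a helper of my own, and it assumes ray.init() has already connected to a multi-node cluster); it counts how many tasks landed on each node under the default scheduling strategy:

from collections import Counter

import ray


@ray.remote
def which_node():
    # Each task reports the ID of the node it was scheduled on
    return ray.get_runtime_context().get_node_id()


# With the default strategy, tasks may be placed on any node in the cluster
print(Counter(ray.get([which_node.remote() for _ in range(100)])))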

Here is an example of how to achieve local parallelism in Ray using .options() together with locality-aware scheduling. The crux is submitting the inner tasks with task.options(scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=ray.get_runtime_context().get_node_id())).remote():

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(scheduling_strategy="SPREAD")
def global_parallel_task():
    # The node this task landed on; the inner tasks are pinned to the same node
    node_id = ray.get_runtime_context().get_node_id()
    pin_to_this_node = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
    futures = [
        local_parallel_task.options(scheduling_strategy=pin_to_this_node).remote(node_id)
        for _ in range(100)
    ]
    return sum(ray.get(futures))


@ray.remote
def local_parallel_task(originating_id):
    node_id = ray.get_runtime_context().get_node_id()
    print(f"Originating node ID: {originating_id} | current node ID: {node_id}")
    return 1


sum(ray.get([global_parallel_task.remote() for _ in range(1000)]))
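
One caveat with soft=False: if the pinned node has died or cannot fit the task, Ray fails the task instead of rescheduling it elsewhere. If you only want locality as a preference, the same strategy accepts soft=True, which lets Ray fall back to other nodes. Here is a sketch of that variant, reusing local_parallel_task and the import from above (global_parallel_task_soft is just an illustrative name):

@ray.remote(scheduling_strategy="SPREAD")
def global_parallel_task_soft():
    node_id = ray.get_runtime_context().get_node_id()
    # soft=True: prefer this node, but let Ray schedule elsewhere if the node is
    # unavailable or out of resources, rather than failing the task
    prefer_this_node = NodeAffinitySchedulingStrategy(node_id=node_id, soft=True)
    futures = [
        local_parallel_task.options(scheduling_strategy=prefer_this_node).remote(node_id)
        for _ in range(100)
    ]
    return sum(ray.get(futures))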

Creating a drop-in replacement for ThreadPoolExecutor

For those of us who want a drop-in replacement for ThreadPoolExecutor, I've implemented RayLocalParallelExecutor, which does just that:

class RayLocalParallelExecutor:
    """Context manager for running Ray tasks locally in parallel using locality aware scheduling"""
    def __init__(self, max_workers=None):
        # max_workers is accepted for interface parity with ThreadPoolExecutor,
        # but it is not used: Ray manages its own worker processes
        self.max_workers = max_workers
        self.node_id = None
        
    def __enter__(self):
        # Get the current node ID when entering the context
        self.node_id = ray.get_runtime_context().get_node_id()
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to clean up: Ray owns the worker lifecycle
        pass
    
    def submit(self, fn, *args, **kwargs):
        """Submit a function to be executed as a Ray remote task"""
        # Convert the function to a Ray remote function if it isn't already
        if not hasattr(fn, 'remote'):
            fn = ray.remote(fn)
        
        # Submit the task with node affinity scheduling, pinning it to the node
        # captured when the context manager was entered
        scheduling_strategy = NodeAffinitySchedulingStrategy(
            node_id=self.node_id,
            soft=False  # strict scheduling: only run on the specified node
        )
        return fn.options(scheduling_strategy=scheduling_strategy).remote(*args, **kwargs)
    
    def map(self, fn, *iterables):
        """Map a function over iterables, executing in parallel"""
        futures = [self.submit(fn, *args) for args in zip(*iterables)]
        return (ray.get(future) for future in futures)

Creating the local parallel tasks then simplifies to:

@ray.remote(scheduling_strategy="SPREAD")
def global_parallel_task_with_context_manager():
    node_id = ray.get_runtime_context().get_node_id()
    with RayLocalParallelExecutor() as executor:
        futures = [executor.submit(local_parallel_task, node_id) for _ in range(100)]
    return sum(ray.get(futures))

sum(ray.get([global_parallel_task_with_context_manager.remote() for _ in range(1000)]))
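
For comparison, here is what the same pattern looks like against the standard library's ThreadPoolExecutor. The interface is intentionally close; the practical differences are that submit() on the Ray version returns an ObjectRef rather than a Future, so results come back through ray.get() instead of .result(), and that ThreadPoolExecutor calls plain Python functions rather than Ray remote functions (plain_task below is an illustrative stand-in of my own):

from concurrent.futures import ThreadPoolExecutor


def plain_task(originating_id):
    # An ordinary Python function; ThreadPoolExecutor calls it directly in a thread
    return 1


with ThreadPoolExecutor() as executor:
    futures = [executor.submit(plain_task, "local") for _ in range(100)]

print(sum(future.result() for future in futures))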

Copyright Richard Decal. richarddecal.com