Guide to map-reduce jobs in Ray Data

Ray Data is a powerful library for scalable data processing that provides streaming execution and a clean Pythonic API. In this post, I'll walk through a simple “map-reduce” pipeline that extracts bigrams from a text corpus, then aggregates and counts them. While my previous post demonstrated a basic map-reduce using built-in reducers like Mean, this tutorial covers writing custom reducers with AggregateFn.

All the code for this tutorial is available as a script here.

Why Ray Data?

I like Ray Data for various reasons. First and foremost, it offers a much nicer developer experience than writing PySpark. The API is clean, and when things break you don't get a nasty JVM stack trace to decipher.

Beyond that, Ray Data can scale the same code from a laptop to a multi-node cluster, stream data through your pipeline to keep memory overhead under control, and spill intermediate objects to disk when they don't fit in memory.

1. Setting Up and Initializing Ray

We'll import Ray here (Python's regular expression module and a few Ray Data primitives get imported later, as we need them). Then we can start a Ray runtime session with:

import ray
ray.init()  # Initialize the Ray runtime

At this point, Ray is ready to schedule tasks locally or across multiple distributed nodes, if available.
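
If you already have a cluster up, ray.init can attach to it instead of starting a fresh local instance; a minimal sketch:

# Alternative: attach to an existing cluster (started e.g. with `ray start --head`).
# This errors if no cluster is running, in which case plain ray.init() is what you want.
ray.init(address="auto")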

2. Ingesting Our Data

Next, let’s prepare a small text corpus for demonstration. In practice, you would probably connect to your database, object storage, or existing dataframes (see the data loaders docs).
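
For example, reading Parquet from object storage or wrapping an existing pandas DataFrame might look like this (the bucket path is just a placeholder):

import pandas as pd
import ray

# Read Parquet files straight from object storage (placeholder path).
ds_from_parquet = ray.data.read_parquet("s3://my-bucket/corpus/")

# Or build a dataset from an in-memory pandas DataFrame.
df = pd.DataFrame({"item": ["the cat sat on the mat"]})
ds_from_pandas = ray.data.from_pandas(df)

For this walkthrough, though, a handful of in-memory strings is plenty: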

text_lines = [
    "the cat sat on the mat every day",
    "the cat ate a mouse every day",
    "the cat and the man became friends",
    "I like to eat pizza, but so does the cat.",
    "my cat has a meme coin named after him",
    "I eat pizza every day",
]

# Construct a dataset from these lines.
ds = ray.data.from_items(text_lines)

We now have a Ray Dataset, which supports parallelized operations. Under the hood, your data is split into partition blocks. On a large cluster, these blocks get distributed and processed in parallel in a streaming fashion.
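
If you're curious, you can inspect the partitioning directly; a small sketch (repartitioning is optional and unnecessary for a dataset this tiny):

# How many blocks the dataset is currently split into.
print(ds.num_blocks())

# If you wanted more parallelism, you could redistribute the rows, e.g.:
# ds = ds.repartition(4)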

3. Extracting Bigrams (Map Step)

We’ll now define a function to extract bigrams (pairs of consecutive words) from each line. Remember, because Ray Data is lazy, transformations won’t run until we actually consume the data.

import re

def extract_bigrams(item: dict):
    """
    Extract bigrams from text.

    Args:
        item: Dictionary containing text under the 'item' key
        
    Yields:
        dict: Records of the form {"bigram": "word1 word2", "count": 1}
    """
    line = item["item"]
    tokens = re.findall(r"\w+", line.lower())
    for i in range(len(tokens) - 1):
        # gotcha: returning a tuple(w1, w2) will not work, because Ray Data will convert this
        # to an array. Then, the groupby will fail with an error like:
        # `The truth value of an array with more than one element is ambiguous.`
        yield {"bigram": f"{tokens[i]} {tokens[i + 1]}", "count": 1}

# Apply the flat_map transformation to get {bigram, count} pairs for each line.
# flat_map is a 1-to-many mapping
ds = ds.flat_map(extract_bigrams)
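
To sanity-check the map step, we can peek at a few rows; note that this forces the lazy pipeline to execute up to this point:

# Pull a handful of rows. Expect records roughly like:
#   [{'bigram': 'the cat', 'count': 1}, {'bigram': 'cat sat', 'count': 1}, ...]
print(ds.take(3))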

4. Counting Bigrams (Reduce Step)

We next group all the records by their “bigram” key and sum up each bigram’s count. This combines duplicate records into one row per bigram, making it a "reduce" step:

# FIXME: Sum moving to `ray.data.aggregate` in the next release
from ray.data._internal.aggregate import Sum

ds = ds.groupby("bigram").aggregate(
    Sum("count", alias_name="count")
)

When this aggregation executes, Ray Data will shuffle the records so that rows sharing the same bigram end up together, and then sum the counts within each group.

Ray Data comes out-of-the-box with a bunch of built-in aggregators: Sum, Mean, Std, Min, Max, and Count.
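
The other built-ins work the same way as Sum. Here is a small standalone sketch using Mean, assuming it lives next to Sum (newer Ray releases expose both under ray.data.aggregate):

from ray.data._internal.aggregate import Mean

scores = ray.data.from_items([
    {"team": "a", "score": 3},
    {"team": "a", "score": 5},
    {"team": "b", "score": 2},
])

# Average score per team, stored under the "avg_score" column.
# -> rows like {'team': 'a', 'avg_score': 4.0}
print(scores.groupby("team").aggregate(Mean("score", alias_name="avg_score")).take_all())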

5. Creating a Custom Aggregator to Collect a List of Bigrams

What about more complex reductions? For that, you need AggregateFn.

To demonstrate, let's collect bigrams into lists based on their total occurrences in the corpus.

from ray.data.aggregate import AggregateFn

def make_list_aggregator(input_key: str, output_key: str = "items"):
    """
    Creates an aggregator that collects grouped values
    into a single (merged) list.

    For example, if rows grouped by {'count': 1} are:
      [{"bigram": "the cat"}, {"bigram": "the man"}],
    the final accumulated value would be:
      ["the cat", "the man"]
    """
    return AggregateFn(
        # Initialize an empty list "accumulator" for each group.
        init=lambda _key: [],
        # Append the relevant field (e.g. row[input_key]) to the accumulator list.
        accumulate_row=lambda acc, row: acc + [row[input_key]],
        # Merge two accumulator lists.
        merge=lambda accum1, accum2: accum1 + accum2,
        # We don't need to do anything with the final accumulated value,
        # so we just return it as-is.
        finalize=lambda acc: acc,
        # The key to store the resulting accumulator under.
        name=output_key,
    )

This aggregator starts an empty list (init), appends an item to the list for each row in that group (accumulate_row), merges two lists if needed (merge), and returns the final list (finalize).
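
To make those four phases concrete, here is a tiny pure-Python illustration of how the callbacks combine when one group's rows end up spread across two blocks. This mimics the flow conceptually; it is not Ray Data's actual execution code:

# Suppose the group with count == 1 was split across two blocks.
block_a = [{"bigram": "the mat"}, {"bigram": "cat sat"}]
block_b = [{"bigram": "meme coin"}]

# The same four callbacks, written out as plain functions.
init = lambda _key: []
accumulate_row = lambda acc, row: acc + [row["bigram"]]
merge = lambda a, b: a + b
finalize = lambda acc: acc

# Each block accumulates independently...
acc_a = init(None)
for row in block_a:
    acc_a = accumulate_row(acc_a, row)
acc_b = init(None)
for row in block_b:
    acc_b = accumulate_row(acc_b, row)

# ...then the partial results are merged and finalized.
print(finalize(merge(acc_a, acc_b)))  # ['the mat', 'cat sat', 'meme coin']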

6. Reducing Bigrams by Their Counts

Now we apply our custom aggregator. First we group by “count,” then collect all the bigrams that share that count into a single list under a new "bigrams" key.

ds = ds.groupby("count").aggregate(
    make_list_aggregator(input_key="bigram", output_key="bigrams")
)

The resulting dataset will have rows like:

{"count": 1, "bigrams": ["door mat", "meme coin", ...]}

7. Mapping the List of Bigrams to a Simple Metric

Finally, suppose we just want to know how many bigrams have a given count—like a histogram. We map each row to a simpler record (one-to-one transformation):

ds = ds.map(
    lambda row: {
        "count": row["count"],
        "num_bigrams": len(row["bigrams"]),
    }
)

This transformation yields, for instance:

[
  {'count': 1, 'num_bigrams': 31},
  {'count': 2, 'num_bigrams': 1},
  ...
]

8. Triggering Execution and Inspecting Results

At this point, our pipeline is defined but not yet fully executed. This is Ray Data’s lazy evaluation in action. By calling a method like ds.take_all(), we trigger the actual compute:

results = ds.take_all()
print("Histogram of how many bigrams share each count:")
print(results)

Ray Data will (in parallel) extract bigrams, shuffle and sum them, group them again, and finally map them to the smaller output, returning the final result. If our data is huge, Ray Data’s streaming execution keeps memory overhead under control and can spill intermediate objects to disk as needed.

9. Scheduling, Observability, and Going Bigger

One nice benefit of Ray is its dashboard. You can use it to monitor the cluster utilization, diagnose failing tasks or actors, examine flamegraphs to find bottlenecks, read logs on your various nodes, etc.
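
For a local run, the dashboard is typically served at http://127.0.0.1:8265, and ray.init() reports the URL it ends up on. A small sketch (the dashboard_url attribute is from recent Ray versions and may differ in older ones):

import ray

ctx = ray.init()
# e.g. "127.0.0.1:8265" on a default local run
print(ctx.dashboard_url)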

Bringing it all together

For teaching purposes, I've over-commented the code. In practice, my pipeline code would be cleaner:

# polished version
results = (
    ray.data.from_items(text_lines)
    .flat_map(extract_bigrams)
    .groupby("bigram")
    .aggregate(Sum("count", alias_name="count"))
    .groupby("count")
    .aggregate(make_list_aggregator(input_key="bigram", output_key="bigrams"))
    .map(lambda row: {
        "count": row["count"],
        "num_bigrams": len(row["bigrams"]),
    })
    .take_all()
)

All this code is available here.

Copyright Richard Decal. richarddecal.com