Guide to map-reduce jobs in Ray Data
Ray Data is a powerful library for scalable data processing that provides streaming execution and a clean Pythonic API. In this post, I'll walk through a simple “map-reduce” pipeline that extracts bigrams from a text corpus, counts them, and then aggregates those counts. My previous post demonstrated a basic map-reduce using built-in reducers like Mean. This tutorial covers writing custom reducers with AggregateFn.
All the code for this tutorial is available as a script here.
Why Ray Data?
I like Ray Data for several reasons. First and foremost, it offers a much nicer developer experience than PySpark. The API is pleasant to use, and when things break you don't have to dig through a nasty JVM stack trace.
Beyond that, Ray Data can:
- Scale to large clusters (hundreds of machines, hundreds of TBs of data) with minimal or no code changes. This is a huge plus for future-proofing your data pipelines.
- Schedule computation on heterogeneous resources (e.g., some nodes have GPUs, others do not). You can get very fine-grained here by defining custom resources, using fractional resources, etc.
- Automatically spill data to disk when it doesn’t fit in memory, instead of failing with out-of-memory errors. It tries to avoid spilling in the first place through a memory backpressure mechanism.
- Lazily execute the graph of transformations, which lets it optimize the whole plan (for example, by fusing adjacent operations) before running anything.
- Stream data through the steps of the pipeline. This keeps all your cluster resources busy at the same time (for example, network bandwidth for ingest, CPUs for preprocessing, GPUs for inference, and disk IOPS for storing results).
- Provide observability and debugging via the Ray Dashboard.
- Offer additional features, such as pipeline checkpointing for resuming failed runs, if you're using Ray Turbo.
1. Setting Up and Initializing Ray
We need to import Ray (we'll also pull in Python’s re module and a few Ray Data primitives as we go). Then, we can start a Ray runtime session with:
import ray
ray.init() # Initialize the Ray runtime
At this point, Ray is ready to schedule tasks locally or across multiple distributed nodes, if available.
2. Ingesting Our Data
Next, let’s prepare a small text corpus for demonstration. In practice, you would probably read from a database, object storage, or dataframes (see the data loaders docs).
text_lines = [
    "the cat sat on the mat every day",
    "the cat ate a mouse every day",
    "the cat and the man became friends",
    "I like to eat pizza, but so does the cat.",
    "my cat has a meme coin named after him",
    "I eat pizza every day",
]
# Construct a dataset from these lines.
ds = ray.data.from_items(text_lines)
We now have a Ray Dataset, which supports parallel operations. Under the hood, the data is split into blocks. On a large cluster, these blocks are distributed across nodes and processed in parallel in a streaming fashion.
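To make the idea of blocks concrete, here is a minimal pure-Python sketch that wraps each line as a row dictionary under the "item" key (the shape from_items gives plain values) and cuts the rows into contiguous, near-equal blocks. split_into_blocks is a hypothetical helper for illustration only; Ray's real partitioning logic is different and decides block counts and sizes itself.

```python
from typing import Any, Dict, List


def split_into_blocks(items: List[Any], num_blocks: int) -> List[List[Dict[str, Any]]]:
    """Hypothetical helper (not part of Ray): wrap each value as {"item": value},
    then cut the rows into contiguous, near-equal blocks."""
    rows = [{"item": value} for value in items]
    base, extra = divmod(len(rows), num_blocks)
    blocks: List[List[Dict[str, Any]]] = []
    start = 0
    for i in range(num_blocks):
        # The first `extra` blocks take one additional row each.
        size = base + (1 if i < extra else 0)
        blocks.append(rows[start:start + size])
        start += size
    return blocks


text_lines = [
    "the cat sat on the mat every day",
    "the cat ate a mouse every day",
    "the cat and the man became friends",
    "I like to eat pizza, but so does the cat.",
    "my cat has a meme coin named after him",
    "I eat pizza every day",
]
blocks = split_into_blocks(text_lines, num_blocks=4)
print([len(block) for block in blocks])  # [2, 2, 1, 1]
```

Each block can then be handed to a different worker, which is what lets later steps run in parallel.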
3. Extracting Bigrams (Map Step)
We’ll now define a function to extract bigrams (pairs of consecutive words) from each line. Remember, because Ray Data is lazy, transformations won’t run until we actually consume the data.
import re
def extract_bigrams(item: dict):
    """
    Extract bigrams from text.
    Args:
        item: Dictionary containing text under the 'item' key
    Yields:
        dict: Records of the form {"bigram": "word1 word2", "count": 1}
    """
    line = item["item"]
    tokens = re.findall(r"\w+", line.lower())
    for i in range(len(tokens) - 1):
        # gotcha: returning a tuple(w1, w2) will not work, because Ray Data will convert this
        # to an array. Then, the groupby will fail with an error like:
        # `The truth value of an array with more than one element is ambiguous.`
        yield {"bigram": f"{tokens[i]} {tokens[i + 1]}", "count": 1}
# Apply the flat_map transformation to get {bigram, count} pairs for each line.
# flat_map is a 1-to-many mapping
ds = ds.flat_map(extract_bigrams)
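Conceptually, flat_map calls the function once per row and concatenates everything it yields. The single-machine sketch below mimics that with itertools.chain, so you can check what the map step produces without a cluster. local_flat_map is a hypothetical name; Ray applies the same idea in parallel, block by block.

```python
import re
from itertools import chain
from typing import Callable, Iterable, Iterator, List


def extract_bigrams(item: dict) -> Iterator[dict]:
    # Same logic as the function above: one record per adjacent word pair.
    tokens = re.findall(r"\w+", item["item"].lower())
    for i in range(len(tokens) - 1):
        yield {"bigram": f"{tokens[i]} {tokens[i + 1]}", "count": 1}


def local_flat_map(rows: Iterable[dict], fn: Callable[[dict], Iterable[dict]]) -> List[dict]:
    """Hypothetical single-machine stand-in for flat_map: call fn on every
    row and concatenate all the records it produces, in order."""
    return list(chain.from_iterable(fn(row) for row in rows))


text_lines = [
    "the cat sat on the mat every day",
    "the cat ate a mouse every day",
    "the cat and the man became friends",
    "I like to eat pizza, but so does the cat.",
    "my cat has a meme coin named after him",
    "I eat pizza every day",
]
records = local_flat_map(({"item": line} for line in text_lines), extract_bigrams)
print(len(records))  # 40
print(records[0])  # {'bigram': 'the cat', 'count': 1}
```

A line with n words yields n - 1 records, so the six lines above (8, 7, 7, 10, 9, and 5 words) produce 40 records before any reduction.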
4. Counting Bigrams (Reduce Step)
Next, we group all the records by their “bigram” key and sum each bigram’s count. This collapses duplicate records into one row per bigram, making it a "reduce" step:
# FIXME: Sum moving to `ray.data.aggregate` in the next release
from ray.data._internal.aggregate import Sum
ds = ds.groupby("bigram").aggregate(
    Sum("count", alias_name="count")
)
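When this aggregation runs, Ray Data will:
- Shuffle and group the data by “bigram.”
- Sum all counts for matching bigrams.
- Produce a new dataset with rows of the form {"bigram": "w1 w2", "count": total_count}.
The grouping key keeps its name, which is why the next step can still read row["bigram"]. Without alias_name, Sum would name its output column sum(count); passing alias_name="count" keeps the original column name.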
Ray comes out of the box with a bunch of built-in aggregators: Count, Sum, Min, Max, Mean, and Std.
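The shuffle-then-sum steps listed above can be sketched on one machine. The code below routes each record to a partition chosen by a deterministic hash of its bigram (the shuffle), then sums counts inside each partition independently (the reduce). The helper names are hypothetical, and Ray Data's real shuffle is distributed and more sophisticated; the point is only that equal keys must land in the same partition before summing.

```python
import zlib
from collections import defaultdict
from typing import Dict, List


def shuffle_by_key(records: List[dict], key: str, num_partitions: int) -> List[List[dict]]:
    """Hypothetical shuffle: route each record to a partition picked by a
    deterministic hash of its key, so equal keys always land together."""
    partitions: List[List[dict]] = [[] for _ in range(num_partitions)]
    for record in records:
        index = zlib.crc32(record[key].encode("utf-8")) % num_partitions
        partitions[index].append(record)
    return partitions


def sum_per_partition(partitions: List[List[dict]], key: str, value: str) -> Dict[str, int]:
    """Reduce each partition on its own. Because the shuffle never splits a key
    across partitions, the per-partition totals can simply be combined."""
    totals: Dict[str, int] = {}
    for partition in partitions:
        local: Dict[str, int] = defaultdict(int)
        for record in partition:
            local[record[key]] += record[value]
        totals.update(local)
    return totals


records = [
    {"bigram": "the cat", "count": 1},
    {"bigram": "every day", "count": 1},
    {"bigram": "the cat", "count": 1},
    {"bigram": "eat pizza", "count": 1},
    {"bigram": "the cat", "count": 1},
]
partitions = shuffle_by_key(records, key="bigram", num_partitions=3)
totals = sum_per_partition(partitions, key="bigram", value="count")
print(sorted(totals.items()))  # [('eat pizza', 1), ('every day', 1), ('the cat', 3)]
```

Hashing with zlib.crc32 rather than Python's built-in hash() keeps the partition assignment stable across processes, since string hashing is randomized per interpreter.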
5. Creating a Custom Aggregator to Collect a List of Bigrams
What about more complex reductions? For that, you need AggregateFn.
To demonstrate, let’s collect the bigrams that share the same total number of occurrences in the corpus.
from ray.data.aggregate import AggregateFn
def make_list_aggregator(input_key: str, output_key: str = "items"):
    """
    Creates an aggregator that collects grouped values
    into a single (merged) list.
    For example, if rows grouped by {'count': 1} are:
    [{"bigram": "the cat"}, {"bigram": "the man"}],
    the final accumulated value would be:
    ["the cat", "the man"]
    """
    return AggregateFn(
        # Initialize an empty list "accumulator" for each group.
        init=lambda _key: [],
        # Merge two accumulator lists.
        merge=lambda accum1, accum2: accum1 + accum2,
        # The key to store the resulting accumulator under
        name=output_key,
        # Append the relevant field (e.g. row[input_key]) to the accumulator list.
        accumulate_row=lambda acc, row: acc + [row[input_key]],
        # we don't need to do anything with the final accumulated value,
        # so we just return it as-is
        finalize=lambda acc: acc,
    )
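This aggregator starts each group with an empty list (init), appends the input_key value of every row in that group (accumulate_row), merges partial lists when the same group was accumulated in more than one place, for example on different blocks (merge), and returns the final list unchanged (finalize).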
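The following pure-Python sketch models how an engine can drive the four callables described above: each block builds partial lists per group, partials for the same group are combined with merge, and finalize runs once per group at the end. LocalAggregateFn and run_grouped are hypothetical names, and this is a simplified model rather than Ray's internal execution plan.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class LocalAggregateFn:
    """Hypothetical stand-in holding the same four callables as AggregateFn."""

    init: Callable[[Any], Any]
    accumulate_row: Callable[[Any, dict], Any]
    merge: Callable[[Any, Any], Any]
    finalize: Callable[[Any], Any]


def run_grouped(agg: LocalAggregateFn, blocks: List[List[dict]], key: str) -> Dict[Any, Any]:
    """Simplified model of a grouped aggregation over several blocks."""
    merged: Dict[Any, Any] = {}
    for block in blocks:
        # Phase 1: build a partial accumulator per group within this block.
        partial: Dict[Any, Any] = {}
        for row in block:
            group = row[key]
            if group not in partial:
                partial[group] = agg.init(group)
            partial[group] = agg.accumulate_row(partial[group], row)
        # Phase 2: fold this block's partials into the running result with merge.
        for group, acc in partial.items():
            merged[group] = agg.merge(merged[group], acc) if group in merged else acc
    # Phase 3: finalize each group's accumulator exactly once.
    return {group: agg.finalize(acc) for group, acc in merged.items()}


list_agg = LocalAggregateFn(
    init=lambda _key: [],
    accumulate_row=lambda acc, row: acc + [row["bigram"]],
    merge=lambda acc1, acc2: acc1 + acc2,
    finalize=lambda acc: acc,
)
blocks = [
    [{"count": 1, "bigram": "the man"}, {"count": 4, "bigram": "the cat"}],
    [{"count": 1, "bigram": "meme coin"}],
]
print(run_grouped(list_agg, blocks, key="count"))
# {1: ['the man', 'meme coin'], 4: ['the cat']}
```

Group 1 appears in both blocks, so its two partial lists only come together through merge, which is why AggregateFn needs that callable at all.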
6. Reducing Bigrams by Their Counts
Now we apply our custom aggregator. First we group by “count,” then collect all the bigrams that share that count into a single list under a new "bigrams" key.
ds = ds.groupby("count").aggregate(
    make_list_aggregator(input_key="bigram", output_key="bigrams")
)
The resulting dataset will have rows like:
{"count": 1, "bigrams": ["the mat", "meme coin", ...]}
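For intuition, here is a local sketch of the same regrouping, starting from a handful of bigram totals for this corpus (those counts match the six example lines). group_bigrams_by_count is a hypothetical helper that plays the role of groupby("count") plus the list aggregator; it sorts by count only so the local output is deterministic.

```python
from collections import defaultdict
from typing import Dict, List


def group_bigrams_by_count(totals: Dict[str, int]) -> List[dict]:
    """Hypothetical local version of groupby("count") + the list aggregator:
    turn {bigram: count} into rows of {"count": c, "bigrams": [...]}."""
    by_count: Dict[int, List[str]] = defaultdict(list)
    for bigram, count in totals.items():
        by_count[count].append(bigram)
    # Sort by count so the output order is deterministic.
    return [{"count": c, "bigrams": by_count[c]} for c in sorted(by_count)]


# A few of the totals produced by the previous step for this corpus.
totals = {"the cat": 4, "every day": 3, "eat pizza": 2, "the mat": 1, "meme coin": 1}
rows = group_bigrams_by_count(totals)
print(rows[0])  # {'count': 1, 'bigrams': ['the mat', 'meme coin']}
```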
7. Mapping the List of Bigrams to a Simple Metric
Finally, suppose we just want to know how many bigrams have a given count—like a histogram. We map each row to a simpler record (one-to-one transformation):
ds = ds.map(
    lambda row: {
        "count": row["count"],
        "num_bigrams": len(row["bigrams"]),
    }
)
This transformation yields, for instance:
[
{'count': 1, 'num_bigrams': 31},
{'count': 2, 'num_bigrams': 1},
...
]
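Because this corpus is tiny, you can cross-check the histogram in plain Python with collections.Counter, no Ray required. This is only a single-machine sanity check, not how Ray Data computes the result.

```python
import re
from collections import Counter

text_lines = [
    "the cat sat on the mat every day",
    "the cat ate a mouse every day",
    "the cat and the man became friends",
    "I like to eat pizza, but so does the cat.",
    "my cat has a meme coin named after him",
    "I eat pizza every day",
]

# First reduce: total occurrences of every bigram in the corpus.
bigram_counts: Counter = Counter()
for line in text_lines:
    tokens = re.findall(r"\w+", line.lower())
    bigram_counts.update(f"{a} {b}" for a, b in zip(tokens, tokens[1:]))

# Second reduce + map: how many distinct bigrams share each total.
histogram = Counter(bigram_counts.values())
print(sorted(histogram.items()))  # [(1, 31), (2, 1), (3, 1), (4, 1)]
```

The first two pairs agree with the rows shown above; the remaining counts come from "every day" (3 occurrences) and "the cat" (4 occurrences).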
8. Triggering Execution and Inspecting Results
At this point, our pipeline is defined but not yet fully executed. This is Ray Data’s lazy evaluation in action. By calling a method like ds.take_all(), we trigger the actual compute:
results = ds.take_all()
print("Histogram of how many bigrams share each count:")
print(results)
Ray Data will (in parallel) extract bigrams, shuffle and sum them, group them again, and finally map them to the smaller output, returning the final result. If our data is huge, Ray Data’s streaming execution keeps memory overhead under control and can spill intermediate objects to disk as needed.
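To illustrate what lazy, streaming execution means, here is a toy pipeline in plain Python. Each transformation is only recorded; take_all() chains the recorded stages together as generators and then pulls rows through them one at a time. LazyPipeline is a hypothetical teaching class, far simpler than Ray Data's planner and executor, and unlike Ray's methods it mutates and returns itself instead of returning a new dataset.

```python
from typing import Callable, Iterable, Iterator, List

Stage = Callable[[Iterator[dict]], Iterator[dict]]


class LazyPipeline:
    """Hypothetical toy: record transformations now, run them in take_all()."""

    def __init__(self, source: Iterable[dict]):
        self._source = source
        self._stages: List[Stage] = []

    def map(self, fn: Callable[[dict], dict]) -> "LazyPipeline":
        # One output row per input row; only recorded, not executed.
        self._stages.append(lambda rows: (fn(row) for row in rows))
        return self

    def flat_map(self, fn: Callable[[dict], Iterable[dict]]) -> "LazyPipeline":
        # Zero or more output rows per input row; only recorded, not executed.
        self._stages.append(lambda rows: (out for row in rows for out in fn(row)))
        return self

    def take_all(self) -> List[dict]:
        rows: Iterator[dict] = iter(self._source)
        for stage in self._stages:
            rows = stage(rows)  # chains generators; still nothing computed
        return list(rows)  # consuming the last stage pulls rows through every stage


calls: List[int] = []


def double(row: dict) -> dict:
    calls.append(row["x"])  # side effect so we can see when work happens
    return {"x": row["x"] * 2}


pipeline = LazyPipeline([{"x": 1}, {"x": 2}]).map(double)
print(len(calls))  # 0
print(pipeline.take_all())  # [{'x': 2}, {'x': 4}]
print(len(calls))  # 2
```

No call to double happens until take_all() consumes the final generator, which is the same reason the Ray pipeline above does no work until ds.take_all().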
9. Scheduling, Observability, and Going Bigger
One nice benefit of Ray is its dashboard. You can use it to monitor the cluster utilization, diagnose failing tasks or actors, examine flamegraphs to find bottlenecks, read logs on your various nodes, etc.
Bringing it all together
For teaching purposes, I've over-commented the code. In practice, my pipeline code would be cleaner:
# polished version
results = (
    ray.data.from_items(text_lines)
    .flat_map(extract_bigrams)
    .groupby("bigram")
    .aggregate(Sum("count", alias_name="count"))
    .groupby("count")
    .aggregate(make_list_aggregator(input_key="bigram", output_key="bigrams"))
    .map(lambda row: {
        "count": row["count"],
        "num_bigrams": len(row["bigrams"]),
    })
    .take_all()
)
All this code is available here.