Stratified sampling is important when you have extremely unbalanced machine learning datasets, to ensure that each class is evenly distributed across your train/test/validation splits. While there are several solutions for multi-class data, there are few for multi-class and multi-label datasets. So, I’m sharing my solution below.
Here’s a handy visual for stratified sampling from Wikipedia:
The solution depends on skmultilearn
’s IterativeStratification
method. Unfortunately, skmultilearn
is not very well maintained and I ran into a few sharp corners while coming up with this solution. I documented those sharp corners in the comments below.
Also note that this function is used inside a Kedro pipeline, where I am using Kedro’s PartitionDatasets. That’s why the return values are sometimes packaged inside a dictionary.
I am also making available my pytest
suites at the end.
import logging
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd

import ipytest
import pytest
from typeguard import typechecked
from skmultilearn.model_selection.iterative_stratification import IterativeStratification

# enables ipytest notebook magic
ipytest.autoconfig()

logger = logging.getLogger(__name__)
@typechecked
def stratify_shuffle_split_subsets(
    full_dataset: Dict[str, Callable], output_partition_name: str, train_fraction: float
) -> Tuple[pd.DataFrame, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Stratify-shuffle-split a multi-class multi-label dataset into
    train/dev/validation sets.

    Args:
        full_dataset: the full supervised dataset, packaged as a Kedro-style partition
            (partition name -> loader callable returning a DataFrame). One column is the
            img uris, and the rest are binary labels.
        output_partition_name: the name of the output partition
        train_fraction: the fraction of data to reserve for the training dataset. The remaining data will be evenly
            split into the dev and validation subsets.

    Returns:
        the supervised dataset, split into train/dev/val subsets. The train subset is not saved as a partitionedDataset
        yet because we are doing label smoothing in the next processing step.

    Raises:
        ValueError: if ``train_fraction`` is outside (0, 1), if duplicate image keys are
            found, or if either stratification round produces overlapping splits.
    """
    if not 0 < train_fraction < 1:
        raise ValueError(f"Training fraction must be value between 0 and 1, got {train_fraction}")

    # unpack the Kedro PartitionDataset: the value is a lazy loader callable
    loader = full_dataset[output_partition_name]
    dataset_df = loader()
    original_cols = dataset_df.columns
    # pandas documentation says to use .to_numpy() instead of .values for consistency
    img_urls = dataset_df["img_url"].to_numpy()

    # sanity check: no duplicate image keys
    if not len(img_urls) == len(set(img_urls)):
        raise ValueError("Duplicate image keys detected.")

    labels = dataset_df.drop(columns=["img_url"]).to_numpy().astype(int)

    # NOTE generators are replicated across workers. do stratified shuffle split beforehand
    logger.info("stratifying dataset iteratively. this may take a while.")
    # NOTE: splits >2 broken; https://github.com/scikit-multilearn/scikit-multilearn/issues/209
    # so, do 2 rounds of iterative splitting
    stratifier = IterativeStratification(
        n_splits=2, order=2, sample_distribution_per_fold=[1.0 - train_fraction, train_fraction],
    )
    # this class is a generator that produces k-folds. we just want to iterate it once to make a single static split
    # NOTE: needs to be computed on hard labels.
    train_indexes, everything_else_indexes = next(stratifier.split(X=img_urls, y=labels))

    num_overlapping_samples = len(set(train_indexes).intersection(set(everything_else_indexes)))
    if num_overlapping_samples != 0:
        raise ValueError(f"First splitting failed, {num_overlapping_samples} overlapping samples detected")

    # s3url array shape (N_samp,)
    x_train, x_else = img_urls[train_indexes], img_urls[everything_else_indexes]
    # labels array shape (N_samp, n_classes)
    Y_train, Y_else = labels[train_indexes, :], labels[everything_else_indexes, :]

    # now, split the "else" data evenly into dev/val splits
    stratifier = IterativeStratification(n_splits=2, order=2)  # splits evenly by default
    dev_indexes, validation_indexes = next(stratifier.split(x_else, Y_else))

    num_overlapping_samples = len(set(dev_indexes).intersection(set(validation_indexes)))
    if num_overlapping_samples != 0:
        raise ValueError(f"Second splitting failed, {num_overlapping_samples} overlapping samples detected")

    x_dev, x_val = (x_else[dev_indexes], x_else[validation_indexes])
    Y_dev, Y_val = (Y_else[dev_indexes, :], Y_else[validation_indexes, :])

    # log per-class counts so the stratification quality can be eyeballed
    for subset_name, frac, encodings_collection in [
        ("train", train_fraction, Y_train),
        ("dev", (1.0 - train_fraction) / 2, Y_dev),
        ("val", (1.0 - train_fraction) / 2, Y_val),
    ]:
        # column-wise sum. sum(counts) > n_samples due to imgs with >1 class
        count_values = np.sum(encodings_collection, axis=0)
        # skip first col, which is the image key, not a class ID
        counts = {class_id: count_val for class_id, count_val in zip(dataset_df.columns[1:], count_values)}
        logger.info(f" {subset_name} subset ({frac * 100:.1f}%) encodings counts after stratification: {counts}")

    # combine (x,y) data into dataframes
    train_subset = pd.DataFrame(Y_train)
    train_subset.insert(0, "img_url", pd.Series(x_train))
    train_subset.columns = original_cols

    dev_subset = pd.DataFrame(Y_dev)
    dev_subset.insert(0, "img_url", pd.Series(x_dev))
    dev_subset.columns = original_cols

    val_subset = pd.DataFrame(Y_val)
    val_subset.insert(0, "img_url", pd.Series(x_val))
    val_subset.columns = original_cols

    # TODO create Great Expectations suite for this node instead
    assert "img_url" in dev_subset.columns
    assert "img_url" in val_subset.columns
    assert "img_url" in train_subset.columns

    return (
        train_subset,
        {output_partition_name: dev_subset},
        {output_partition_name: val_subset},
    )
Testing that it all works
Below are my pytests and their results. Note that the fixtures mimic Kedro PartitionDatasets.
@pytest.fixture
def unsplit_supervised_learning_dataset_partition_name() -> str:
    """Partition name used to simulate a Kedro PartitionDataset key."""
    return "pytest_unsplit_supervised_learning_dataset_partition_name"
@pytest.fixture
def unsplit_supervised_learning_dataset(unsplit_supervised_learning_dataset_partition_name: str) -> Dict[str, Callable]:
    """
    Generates a multi-class multi-label dataset.

    Args:
        unsplit_supervised_learning_dataset_partition_name: partition key under which the loader is stored

    Returns:
        a multi-class multi-label df, packed in a lambda to mimic Kedro PartitionDataset
    """
    df_rows = []

    n_total = 60
    n_created = 0
    # three equally-sized groups: class0 only, class1 only, and class1+class2 (multi-label)
    n_class1 = n_total // 3
    n_class2 = n_total // 3
    n_class2_plus_3 = n_total // 3

    for _ in range(n_class1):
        df_rows.append({"img_url": f"s3://test{n_created}", "class0": 1, "class1": 0, "class2": 0})
        n_created += 1

    for _ in range(n_class2):
        df_rows.append({"img_url": f"s3://test{n_created}", "class0": 0, "class1": 1, "class2": 0})
        n_created += 1

    for _ in range(n_class2_plus_3):
        df_rows.append({"img_url": f"s3://test{n_created}", "class0": 0, "class1": 1, "class2": 1})
        n_created += 1

    df = pd.DataFrame(df_rows)
    # lazy loader callable, mimicking how Kedro PartitionDatasets defer I/O
    data_loader = lambda: df

    return {unsplit_supervised_learning_dataset_partition_name: data_loader}
@pytest.mark.parametrize("train_fraction", [-1.0, 0, 1, 5])
def test_train_fraction_range(
    unsplit_supervised_learning_dataset, train_fraction, unsplit_supervised_learning_dataset_partition_name
):
    """
    Tests that train_fraction is validated to lie strictly between 0 and 1.
    """
    with pytest.raises(ValueError):
        stratify_shuffle_split_subsets(
            full_dataset=unsplit_supervised_learning_dataset,
            output_partition_name=unsplit_supervised_learning_dataset_partition_name,
            train_fraction=train_fraction,
        )
def test_duplicate_uris_detected(unsplit_supervised_learning_dataset_partition_name):
    """Tests whether duplicate image keys are properly detected."""
    # setup: a dataset where every img_url appears more than once
    df = pd.DataFrame(
        [
            {"img_url": "s3://duplicate1", "class0": 1, "class1": 0, "class2": 0},
            {"img_url": "s3://duplicate1", "class0": 1, "class1": 0, "class2": 0},
            {"img_url": "s3://duplicate1", "class0": 1, "class1": 0, "class2": 0},
            {"img_url": "s3://duplicate2", "class0": 0, "class1": 1, "class2": 0},
            {"img_url": "s3://duplicate2", "class0": 0, "class1": 1, "class2": 0},
        ]
    )
    data_loader = lambda: df
    partition_dset = {unsplit_supervised_learning_dataset_partition_name: data_loader}

    # test
    with pytest.raises(ValueError):
        stratify_shuffle_split_subsets(
            full_dataset=partition_dset,
            output_partition_name=unsplit_supervised_learning_dataset_partition_name,
            train_fraction=0.6,
        )
def test_col_names_same(unsplit_supervised_learning_dataset, unsplit_supervised_learning_dataset_partition_name):
    """
    Tests that each output subset keeps the same column names as the input dataset.
    """
    # setup
    train_df: pd.DataFrame
    dev_split: Dict[str, pd.DataFrame]
    val_split: Dict[str, pd.DataFrame]
    # set train fraction to 0.6 to make math easier later
    train_fraction = 0.6
    data_loader = unsplit_supervised_learning_dataset[unsplit_supervised_learning_dataset_partition_name]
    original_df = data_loader()
    original_df_cols = original_df.columns

    # run
    train_df, dev_split, val_split = stratify_shuffle_split_subsets(
        full_dataset=unsplit_supervised_learning_dataset,
        output_partition_name=unsplit_supervised_learning_dataset_partition_name,
        train_fraction=train_fraction,
    )

    # test setup
    dev_df, val_df = (
        dev_split[unsplit_supervised_learning_dataset_partition_name],
        val_split[unsplit_supervised_learning_dataset_partition_name],
    )
    train_df_cols, dev_df_cols, val_df_cols = train_df.columns, dev_df.columns, val_df.columns

    # tests
    assert all(train_df_cols == original_df_cols)
    assert all(dev_df_cols == original_df_cols)
    assert all(val_df_cols == original_df_cols)
def test_stratify_shuffle_split_subsets(
    unsplit_supervised_learning_dataset, unsplit_supervised_learning_dataset_partition_name
):
    """Tests whether a multi-class multi-label dataset gets properly stratify-shuffle-split."""
    # setup
    train_df: pd.DataFrame
    dev_split: Dict[str, pd.DataFrame]
    val_split: Dict[str, pd.DataFrame]
    # set train fraction to 0.6 to make math easier later
    train_fraction = 0.6
    ratio_train_to_val = train_fraction / ((1.0 - train_fraction) / 2)
    data_loader = unsplit_supervised_learning_dataset[unsplit_supervised_learning_dataset_partition_name]
    original_dset_len = len(data_loader())

    # run
    train_df, dev_split, val_split = stratify_shuffle_split_subsets(
        full_dataset=unsplit_supervised_learning_dataset,
        output_partition_name=unsplit_supervised_learning_dataset_partition_name,
        train_fraction=train_fraction,
    )

    # test setup
    dev_df, val_df = (
        dev_split[unsplit_supervised_learning_dataset_partition_name],
        val_split[unsplit_supervised_learning_dataset_partition_name],
    )
    train_keys, dev_keys, val_keys = train_df["img_url"], dev_df["img_url"], val_df["img_url"]
    uniq_train_keys, uniq_dev_keys, uniq_val_keys = set(train_keys), set(dev_keys), set(val_keys)

    # Tests
    ## test keys
    # test no duplicate keys within subsets
    assert len(train_keys) == len(uniq_train_keys)
    assert len(val_keys) == len(uniq_val_keys)
    assert len(dev_keys) == len(uniq_dev_keys)
    # test that all subsets have mutually exclusive S3 URIs
    assert len(uniq_train_keys.intersection(uniq_dev_keys)) == 0
    assert len(uniq_train_keys.intersection(uniq_val_keys)) == 0
    assert len(uniq_val_keys.intersection(uniq_dev_keys)) == 0

    ## test ratios
    # test that dev and val subsets get same number of samples
    assert len(dev_df) == len(val_df)
    # test that the train set holds train_fraction of all samples
    assert len(train_df) / (len(train_df) + len(dev_df) + len(val_df)) == pytest.approx(train_fraction)
    # test that all samples are used
    assert len(train_df) + len(dev_df) + len(val_df) == original_dset_len

    ## test distributions
    ### class 0
    # test whether dev/val set got same amount of class 0
    assert len(dev_df[dev_df.class0 == 1]) == len(val_df[val_df.class0 == 1])
    # test whether train set got the expected multiple of class0 vs val set
    assert len(train_df[train_df.class0 == 1]) / len(val_df[val_df.class0 == 1]) == pytest.approx(ratio_train_to_val)
    ### class 1
    # test whether dev/val set got same amount of class 1
    assert len(dev_df[dev_df.class1 == 1]) == len(val_df[val_df.class1 == 1])
    # test whether train set got the expected multiple of class1 vs val set
    assert len(train_df[train_df.class1 == 1]) / len(val_df[val_df.class1 == 1]) == pytest.approx(ratio_train_to_val)
    ### class 2
    # test whether dev/val set got same amount of class 2
    assert len(dev_df[dev_df.class2 == 1]) == len(val_df[val_df.class2 == 1])
    # test whether train set got the expected multiple of class2 vs val set
    assert len(train_df[train_df.class2 == 1]) / len(val_df[val_df.class2 == 1]) == pytest.approx(ratio_train_to_val)
# execute the pytest suites defined above, inside the notebook
ipytest.run()
.....2020-11-22 17:19:46,228 - __main__ - INFO - stratifying dataset iteratively. this may take a while.
2020-11-22 17:19:46,239 - __main__ - INFO - train subset (60.0%) encodings counts after stratification: {'class0': 12, 'class1': 24, 'class2': 12}
2020-11-22 17:19:46,251 - __main__ - INFO - dev subset (20.0%) encodings counts after stratification: {'class0': 4, 'class1': 8, 'class2': 4}
2020-11-22 17:19:46,253 - __main__ - INFO - val subset (20.0%) encodings counts after stratification: {'class0': 4, 'class1': 8, 'class2': 4}
.2020-11-22 17:19:46,271 - __main__ - INFO - stratifying dataset iteratively. this may take a while.
2020-11-22 17:19:46,280 - __main__ - INFO - train subset (60.0%) encodings counts after stratification: {'class0': 12, 'class1': 24, 'class2': 12}
2020-11-22 17:19:46,283 - __main__ - INFO - dev subset (20.0%) encodings counts after stratification: {'class0': 4, 'class1': 8, 'class2': 4}
2020-11-22 17:19:46,287 - __main__ - INFO - val subset (20.0%) encodings counts after stratification: {'class0': 4, 'class1': 8, 'class2': 4}
. [100%]
========================================== warnings summary ===========================================
/home/richard/src/DENDRA/seeweed/venv/lib/python3.7/site-packages/_pytest/config/__init__.py:1040
/home/richard/src/DENDRA/seeweed/venv/lib/python3.7/site-packages/_pytest/config/__init__.py:1040: PytestAssertRewriteWarning: Module already imported so cannot be rewritten: typeguard
self._mark_plugins_for_rewrite(hook)
-- Docs: https://docs.pytest.org/en/stable/warnings.html
7 passed, 1 warning in 0.15s