Use these normalization values for torchvision datasets

When training an image classifier, it's important to normalize the images using the mean and standard deviation of your actual dataset. Otherwise, the incorrectly normalized features can create elongated valleys in the loss landscape, which causes optimization problems.
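
As a concrete example, here is a minimal sketch of how the values from the reference table below plug into a standard torchvision preprocessing pipeline (using the CIFAR10 row):

from torchvision import transforms

# CIFAR10 statistics (scaled to [0, 1]) from the table below
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),  # uint8 [0, 255] -> float [0, 1]
    transforms.Normalize(
        mean=(0.4914, 0.48216, 0.44653),
        std=(0.2022, 0.19932, 0.20086),
    ),
])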

To my dismay, I often see people blindly copy-pasting the same normalization values across different reference implementations: tutorials, quickstart demos, Kaggle notebooks, and so on. For example, the vit_huge_patch14_224_in21k model was trained using apparently bogus mean and standard deviation values of (0.5, 0.5, 0.5).

To make matters worse, I wasn't able to easily Google the correct normalization values to use for many popular vision datasets. So, I've computed the correct mean and standard deviation for a bunch of popular image datasets available in torchvision.

Reference values

| Dataset | Mean (unnormalized, [0-255]) | StdDev (unnormalized, [0-255]) | Mean (normalized, [0-1]) | StdDev (normalized, [0-1]) |
| --- | --- | --- | --- | --- |
| CIFAR10 | (125.31, 122.95, 113.87) | (51.56, 50.83, 51.22) | (0.4914, 0.48216, 0.44653) | (0.2022, 0.19932, 0.20086) |
| CIFAR100 | (129.3, 124.07, 112.43) | (51.2, 50.58, 51.56) | (0.50708, 0.48655, 0.44092) | (0.2008, 0.19835, 0.2022) |
| Country211 | (116.52, 114.82, 107.3) | (60.09, 58.75, 61.33) | (0.45694, 0.45027, 0.42077) | (0.23566, 0.23038, 0.24052) |
| DTD | (134.84, 120.63, 108.3) | (46.91, 47.12, 46.07) | (0.52879, 0.47304, 0.42472) | (0.18396, 0.18477, 0.18067) |
| EuroSAT | (87.82, 96.97, 103.98) | (23.3, 16.61, 14.09) | (0.34438, 0.38029, 0.40777) | (0.09137, 0.06512, 0.05524) |
| FGVCAircraft | (122.72, 130.6, 136.57) | (50.12, 49.84, 55.36) | (0.48125, 0.51215, 0.53555) | (0.19655, 0.19544, 0.21711) |
| FakeData | (127.3, 127.32, 127.31) | (74.02, 74.02, 74.02) | (0.49921, 0.49929, 0.49924) | (0.29027, 0.29026, 0.29027) |
| FashionMNIST | (72.94) | (81.66) | (0.28604) | (0.32025) |
| Flowers102 | (110.4, 97.39, 75.58) | (66.83, 54.38, 57.33) | (0.43296, 0.38192, 0.29638) | (0.26207, 0.21327, 0.22484) |
| Food101 | (138.97, 113.09, 87.62) | (59.54, 62.28, 61.75) | (0.54499, 0.44349, 0.3436) | (0.23349, 0.24423, 0.24216) |
| ImageNet | (123.67, 116.28, 103.53) | (58.4, 57.12, 57.38) | (0.485, 0.456, 0.406) | (0.229, 0.224, 0.225) |
| Imagenette | (117.94, 116.79, 109.52) | (62.51, 60.94, 62.94) | (0.46252, 0.45801, 0.42948) | (0.24515, 0.23898, 0.24681) |
| KMNIST | (48.9) | (86.27) | (0.19176) | (0.33831) |
| Kitti | (93.83, 98.76, 95.88) | (78.78, 80.13, 81.2) | (0.36797, 0.3873, 0.37599) | (0.30895, 0.31424, 0.31843) |
| MNIST | (33.32) | (76.83) | (0.13066) | (0.30131) |
| Omniglot | (235.13) | (66.87) | (0.92206) | (0.26225) |
| QMNIST | (33.36) | (77.01) | (0.13083) | (0.30199) |
| RenderedSST2 | (251.1, 251.1, 251.1) | (26.52, 26.52, 26.52) | (0.9847, 0.9847, 0.9847) | (0.10399, 0.10399, 0.10399) |
| SBDataset | (116.92, 111.92, 103.47) | (61.13, 60.04, 61.18) | (0.45853, 0.43888, 0.40577) | (0.23974, 0.23546, 0.23992) |
| SBU | (120.16, 115.99, 106.92) | (57.69, 56.38, 58.82) | (0.47123, 0.45488, 0.4193) | (0.22624, 0.22112, 0.23065) |
| SEMEION | (83.8) | (118.62) | (0.32863) | (0.46517) |
| STL10 | (113.91, 112.15, 103.69) | (57.16, 56.48, 57.09) | (0.44671, 0.43981, 0.40665) | (0.22415, 0.22149, 0.2239) |
| SVHN | (111.61, 113.16, 120.57) | (30.61, 31.38, 26.81) | (0.43768, 0.44377, 0.4728) | (0.12003, 0.12308, 0.10515) |
| USPS | (62.95) | (71.53) | (0.24688) | (0.28051) |
| VOCDetection | (116.55, 111.75, 103.57) | (60.97, 59.95, 61.13) | (0.45705, 0.43825, 0.40617) | (0.23909, 0.2351, 0.23973) |
| VOCSegmentation | (116.48, 113.0, 104.12) | (60.41, 59.48, 60.93) | (0.4568, 0.44313, 0.4083) | (0.23691, 0.23326, 0.23893) |
| WIDERFace | (119.86, 110.81, 104.15) | (67.26, 64.71, 64.71) | (0.47002, 0.43454, 0.40842) | (0.26378, 0.25377, 0.25377) |
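
Grayscale datasets have a single channel, so their statistics are one-element tuples. For MNIST, for example:

from torchvision import transforms

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.13066,), std=(0.30131,)),  # MNIST, scaled values
])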

Note that it's important to compute these statistics on the train split only; if you include the val split, you are leaking information into model training. Also note that while image values natively span [0, 255], it's common practice to rescale them to [0, 1] using PyTorch's ToTensor or ToDtype transforms, so I've included both the raw and the scaled values in the table.
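
For a dataset small enough to fit in memory, a brute-force version of this computation looks roughly like the sketch below. Note that it computes the global per-channel standard deviation, whereas the Ray implementation in the next section averages per-image statistics, so the two can differ slightly.

import torch
from torchvision import datasets, transforms

# Use the *train* split only; ToTensor rescales to [0, 1]
train = datasets.CIFAR10(root="./data", train=True, download=True,
                         transform=transforms.ToTensor())
# Stack everything into one (N, C, H, W) tensor (~600 MB for CIFAR10)
x = torch.stack([img for img, _ in train])
mean = x.mean(dim=(0, 2, 3))  # per-channel mean over all pixels
std = x.std(dim=(0, 2, 3))    # global per-channel standard deviation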

My hope is that these values will be published alongside the official dataset sources, as well as in 3rd party libraries such as torchvision and timm.

In the interest of time, I skipped a bunch of datasets that had gotchas. Please send me a note with the normalization constants if you figure them out!

Implementation

I implemented this code using Ray Data, a nice library for scaling ML workloads.

The full code to calculate these numbers is available here. The meat of the code is quite simple:

import ray
from ray.data.aggregate import Mean

ds = ray.data.from_items(dataset)
# (PIL img, label) -> np.array
ds = ds.map(extract_and_process_image)
# np.array -> per-channel mean, standard deviation
ds = ds.map(compute_channel_stats)
# Count channels from the first sample (1 for grayscale, 3 for RGB)
num_channels = len([k for k in ds.take(1)[0].keys() if k.startswith("mean_")])
# Average the per-image statistics across the whole dataset
results = ds.aggregate(
    *[Mean(f"mean_{i}", alias_name=f"mean_{i}") for i in range(num_channels)],
    *[Mean(f"stddev_{i}", alias_name=f"stddev_{i}") for i in range(num_channels)],
)
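
For context, here are simplified sketches of the two helper functions; the full versions are in the linked repo.

import numpy as np

def extract_and_process_image(row):
    # Sketch only: ray.data.from_items wraps each (PIL img, label)
    # tuple under the "item" key
    img = np.asarray(row["item"][0], dtype=np.float32) / 255.0  # -> [0, 1]
    if img.ndim == 2:
        img = img[..., None]  # grayscale -> (H, W, 1)
    return {"image": img}

def compute_channel_stats(row):
    # Sketch only: per-image, per-channel mean and standard deviation
    img = row["image"]
    stats = {}
    for i in range(img.shape[-1]):
        stats[f"mean_{i}"] = float(img[..., i].mean())
        stats[f"stddev_{i}"] = float(img[..., i].std())
    return stats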

Copyright Richard Decal. richarddecal.com