difference between StratifiedKFold and StratifiedShuffleSplit in sklearn
Output examples of KFold, StratifiedKFold, StratifiedShuffleSplit:
The above pictorial output is an extension of @Ken Syme's code:
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit
SEED = 43
SPLIT = 3
X_train = [0,1,2,3,4,5,6,7,8]
y_train = [0,0,0,0,0,0,1,1,1] # note 6,7,8 are labelled class '1'
print("KFold, shuffle=False (default)")
# random_state is omitted here: it has no effect when shuffle=False,
# and recent scikit-learn versions raise an error if it is set.
kf = KFold(n_splits=SPLIT)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
print("KFold, shuffle=True")
kf = KFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
print("\nStratifiedKFold, shuffle=False (default)")
# random_state is omitted for the same reason as above.
skf = StratifiedKFold(n_splits=SPLIT)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
print("StratifiedKFold, shuffle=True")
skf = StratifiedKFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
print("\nStratifiedShuffleSplit")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=3)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
print("\nStratifiedShuffleSplit (can customise test_size)")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=2)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
In StratifiedKFold, the test sets do not overlap, even when shuffling is enabled. With StratifiedKFold and shuffle=True, the data is shuffled once at the start and then divided into the desired number of splits. The test data is always one of the splits; the train data is the rest.
In StratifiedShuffleSplit, the data is shuffled every time and then split. This means the test sets may overlap between the splits.
See this block for an example of the difference. Note the overlap of the elements in the test sets for StratifiedShuffleSplit.
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold

splits = 5
tx = range(10)
ty = [0] * 5 + [1] * 5

stratKfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("stratKFold")
for train_index, test_index in stratKfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)
Output:
stratKFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]
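If reading the overlap off the printed arrays is awkward, one way to check it is to count how many test sets each sample index lands in. This is only a rough sketch on the same toy data; the helper name `count_test_appearances` is made up for illustration and is not part of scikit-learn.

```python
from collections import Counter

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

X = list(range(10))
y = [0] * 5 + [1] * 5


def count_test_appearances(splitter, X, y):
    """Count how many test sets each sample index appears in (illustrative helper)."""
    counts = Counter()
    for _, test_index in splitter.split(X, y):
        counts.update(int(i) for i in test_index)
    return counts


skf_counts = count_test_appearances(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=42), X, y
)
sss_counts = count_test_appearances(
    StratifiedShuffleSplit(n_splits=5, test_size=2, random_state=42), X, y
)

# StratifiedKFold: every index should appear in exactly one test set.
print("StratifiedKFold:", dict(sorted(skf_counts.items())))
# StratifiedShuffleSplit: some indices may appear several times, others never.
print("StratifiedShuffleSplit:", dict(sorted(sss_counts.items())))
```

The exact indices drawn depend on the scikit-learn version and the seed, so only the pattern (one appearance per index versus a mix of repeats and gaps) should be relied on.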
As for when to use them, I tend to use StratifiedKFold for any cross validation, and I use StratifiedShuffleSplit with a split of 2 for my train/test set splits. But I'm sure there are other use cases for both.
@Ken Syme already has a very good answer. I just want to add something.
StratifiedKFold is a variation of KFold. First, StratifiedKFold optionally shuffles your data (only when shuffle=True), then splits the data into n_splits parts, and it is done. It then uses each part in turn as a test set. Note that it shuffles the data at most once, before splitting.
With shuffle=True, the shuffling is controlled by your random_state. If random_state is left as None, NumPy's global random state (np.random) is used. With shuffle=False (the default), the data is not shuffled at all.
For example, with n_splits = 4 and data with 3 classes (labels) for y (the dependent variable), the 4 test sets cover all the data without any overlap.
- On the other hand, StratifiedShuffleSplit is a variation of ShuffleSplit. First, StratifiedShuffleSplit shuffles your data and splits it into a train set and a test set, whose sizes are set by test_size and train_size rather than by n_splits. It then repeats the same process n_splits - 1 more times, to get n_splits - 1 other test sets. Look at the picture below, with the same data, but this time, the 4 test sets do not cover all the data, i.e. there are overlaps among test sets.
So, the difference here is that StratifiedKFold shuffles and splits just once, so the test sets do not overlap, while StratifiedShuffleSplit shuffles each time before splitting and does so n_splits times, so the test sets can overlap.
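To make the repeated shuffle-and-split behaviour concrete, here is a small sketch (toy data and variable names are my own assumptions) where the number of repetitions is deliberately large. With 12 samples, 10 repetitions and 4 test samples per repetition, 40 test slots are drawn from only 12 indices, so some index has to be reused.

```python
from sklearn.model_selection import StratifiedShuffleSplit

X = list(range(12))
y = [0] * 6 + [1] * 6

# n_splits only sets how many times the shuffle-and-split is repeated;
# test_size alone decides how many samples go into each test set.
sss = StratifiedShuffleSplit(n_splits=10, test_size=4, random_state=0)

held_out = [set(int(i) for i in test_index) for _, test_index in sss.split(X, y)]

sizes = [len(s) for s in held_out]
total_draws = sum(sizes)
distinct = set().union(*held_out)

print("test set sizes:", sizes)
print("total test draws:", total_draws, "distinct indices:", len(distinct))
```

StratifiedKFold could not be configured this way: its test set size follows from n_splits, because the folds have to partition the data.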
- Note: both methods use stratification (that is why "stratified" appears in both names). It means each part preserves approximately the same percentage of samples of each class (label) as the original data. You can read more in the cross_validation documentation.
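As a quick check of the stratification claim, the sketch below counts the classes in each test set for both splitters. The data (15 samples of class 0, 5 of class 1) and the helper name `class_counts_per_test_set` are assumptions chosen so that the 75/25 ratio divides evenly into test sets of 4.

```python
from collections import Counter

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

X = list(range(20))
y = [0] * 15 + [1] * 5  # 75% class 0, 25% class 1


def class_counts_per_test_set(splitter, X, y):
    """Return the class counts of each test set as a list of dicts (illustrative helper)."""
    result = []
    for _, test_index in splitter.split(X, y):
        result.append(dict(Counter(y[i] for i in test_index)))
    return result


skf_folds = class_counts_per_test_set(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=0), X, y
)
sss_folds = class_counts_per_test_set(
    StratifiedShuffleSplit(n_splits=5, test_size=4, random_state=0), X, y
)

print("StratifiedKFold:", skf_folds)
print("StratifiedShuffleSplit:", sss_folds)
```

With these numbers every test set should hold three samples of class 0 and one of class 1, matching the original 75/25 ratio. When the class counts do not divide evenly, the proportions are only approximately preserved.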