Parameter "stratify" from method "train_test_split" (scikit Learn)

Daneel Olivaw picture Daneel Olivaw · Jan 17, 2016 · Viewed 130.6k times · Source

I am trying to use train_test_split from package scikit Learn, but I am having trouble with parameter stratify. Hereafter is the code:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

However, I keep getting the following problem:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Does someone have an idea what is going on? Below is the function documentation.

[...]

stratify : array-like or None (default is None)

If not None, data is split in a stratified fashion, using this as the labels array.

New in version 0.17: stratify splitting

[...]

Answer

Fazzolini picture Fazzolini · Aug 11, 2016

This stratify parameter makes a split so that the proportion of values in the sample produced will be the same as the proportion of values provided to parameter stratify.

For example, if variable y is a binary categorical variable with values 0 and 1 and there are 25% of zeros and 75% of ones, stratify=y will make sure that your random split has 25% of 0's and 75% of 1's.