Custom transformer for sklearn Pipeline that alters both X and y

MarkAWard · Aug 28, 2014

I want to create my own transformer for use with the sklearn Pipeline, so I am writing a class that implements both fit and transform methods. The purpose of the transformer is to remove rows from the matrix that have more than a specified number of NaNs. The problem I am facing is how to change both the X and y matrices that are passed to the transformer. I believe this has to happen in the fit method, since it is the one that has access to both X and y. However, Python passes arguments by assignment, so once I reassign X to a new matrix with fewer rows, I am only rebinding a local name and the caller's original X is left untouched (and of course the same is true for y). Is it possible to keep that reference and modify the original?
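To illustrate the point about arguments being passed by assignment, here is a minimal sketch (the function names `rebind` and `mutate` are hypothetical, made up for this illustration). Rebinding a parameter inside a function never changes the caller's variable, while writing into the same object in place does. There is no NumPy operation that removes selected rows from an existing array in place (`np.delete` and boolean indexing both return new arrays), so for row removal only the rebinding route exists.

```python
import numpy as np


def rebind(X):
    # Rebinding the local name X to a smaller array object;
    # the caller's variable still refers to the original array.
    X = X[:1]
    return X


def mutate(X):
    # Writing into the existing array changes the object the caller also sees.
    X[0, 0] = -1.0


X = np.zeros((3, 2))
rebind(X)
print(X.shape)   # (3, 2): the caller's X still has three rows
mutate(X)
print(X[0, 0])   # -1.0: the in-place write is visible to the caller
```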

I’m using a pandas DataFrame to easily drop the rows that have too many NaNs; this may not be the right approach for my use case. My current code looks like this:

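import pandas as pd


class Dropna:
    """Drop rows that contain more than `thresh` NaNs (counting X and y together)."""

    # thresh is the max number of NaNs allowed in a row
    def __init__(self, thresh=0):
        self.thresh = thresh

    def fit(self, X, y):
        total = X.shape[1]
        # +1 to account for 'y' being added to the DataFrame;
        # dropna(thresh=n) keeps rows with at least n non-NaN values
        new_thresh = total + 1 - self.thresh
        df = pd.DataFrame(X)
        df['y'] = y
        df.dropna(thresh=new_thresh, inplace=True)
        # These assignments only rebind the local names X and y;
        # the arrays held by the caller (e.g. the Pipeline) are unchanged.
        X = df.drop('y', axis=1).values
        y = df['y'].values
        return self

    def transform(self, X):
        return X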
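The threshold conversion in fit relies on `DataFrame.dropna(thresh=n)` keeping rows that have at least n non-NaN values. A small worked example, assuming 3 feature columns plus the added 'y' column and at most 1 NaN allowed per row (so n = 3 + 1 - 1 = 3); the data values are made up for illustration:

```python
import numpy as np
import pandas as pd

X = np.array([
    [1.0, 2.0, 3.0],          # 0 NaNs -> kept
    [np.nan, 2.0, 3.0],       # 1 NaN  -> kept
    [np.nan, np.nan, 3.0],    # 2 NaNs -> dropped
])
y = np.array([10.0, 20.0, 30.0])

max_nans = 1
df = pd.DataFrame(X)
df['y'] = y                       # 4 columns in total
n = X.shape[1] + 1 - max_nans     # minimum number of non-NaN values to keep a row
kept = df.dropna(thresh=n)

X_new = kept.drop('y', axis=1).values
y_new = kept['y'].values
print(X_new.shape, y_new)   # (2, 3) [10. 20.]
```

The filtering itself works; it only helps when the filtered arrays are handed back to the caller, which a `fit` method that returns `self` cannot do.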

Answer

eickenberg · Aug 28, 2014

Modifying the sample axis, e.g. removing samples, does not (yet?) comply with the scikit-learn transformer API. So if you need to do this, do it outside any calls to scikit-learn, as a preprocessing step.
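Following that advice, here is a minimal sketch of the row filtering done as preprocessing, before the data reaches the Pipeline. The helper name `drop_nan_rows` and its `max_nans` parameter are hypothetical, and the scaler plus linear model is only an example pipeline. Unlike the question's code, this helper counts NaNs in X only. Because it returns both arrays, X and y stay aligned.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def drop_nan_rows(X, y, max_nans=0):
    """Hypothetical helper: keep rows of X with at most `max_nans` NaNs,
    together with the matching entries of y."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    keep = np.isnan(X).sum(axis=1) <= max_nans   # boolean row mask
    return X[keep], y[keep]


X = np.array([[1.0, 2.0],
              [np.nan, 1.0],
              [3.0, 4.0],
              [5.0, 7.0]])
y = np.array([1.0, 2.0, 3.0, 4.0])

# Filter first, outside scikit-learn, then fit the pipeline on the result.
X_clean, y_clean = drop_nan_rows(X, y, max_nans=0)
model = make_pipeline(StandardScaler(), LinearRegression())
model.fit(X_clean, y_clean)
print(X_clean.shape, y_clean)   # (3, 2) [1. 3. 4.]
```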

As it is now, the transformer API is used to transform the features of each sample into new features. The transformation can implicitly use information from other samples (for example, statistics learned in fit), but samples are never deleted: transform returns one output row per input row.
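A minimal sketch of a transformer that respects this contract. The class name `ColumnCenterer` is hypothetical; it uses scikit-learn's `BaseEstimator` and `TransformerMixin` (the mixin supplies `fit_transform`). It learns a per-column mean from all training samples in `fit`, so the output depends on other samples, yet `transform` still returns exactly one row per input row.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnCenterer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer: subtract the per-column mean seen in fit."""

    def fit(self, X, y=None):
        # Information from all training samples is stored on the estimator...
        self.mean_ = np.nanmean(np.asarray(X, dtype=float), axis=0)
        return self

    def transform(self, X):
        # ...but each input row maps to exactly one output row.
        return np.asarray(X, dtype=float) - self.mean_


X = np.array([[1.0, 10.0],
              [3.0, 20.0],
              [5.0, 30.0]])
Xt = ColumnCenterer().fit_transform(X)
print(Xt.shape)   # (3, 2): same number of samples as X
```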

Another option is to impute the missing values instead. But again, if you need to delete samples, treat it as preprocessing before using scikit-learn.
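Imputation fits the transformer API because it keeps every sample, so it can live inside the Pipeline. A minimal sketch, assuming a scikit-learn version that provides `sklearn.impute.SimpleImputer` (0.20 or later); the mean strategy, the linear model and the data are arbitrary choices for illustration:

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

X = np.array([[1.0, 2.0],
              [np.nan, 1.0],
              [3.0, np.nan],
              [5.0, 7.0]])
y = np.array([1.0, 2.0, 3.0, 4.0])

# The imputer replaces NaNs with the column means learned in fit, so every
# sample is kept and X and y stay aligned without removing any rows.
model = make_pipeline(SimpleImputer(strategy="mean"), LinearRegression())
model.fit(X, y)
print(model.predict(X).shape)   # (4,)
```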