PySpark DataFrame Imputation -- Replace Unknown & Missing Values with Column Mean Based on a Specified Condition

midnightfalcon · May 25, 2016 · Viewed 8.6k times

Given a Spark dataframe, I would like to compute a column mean based on the non-missing and non-unknown values for that column. I would then like to take this mean and use it to replace the column's missing & unknown values.

For example, assume I'm working with:

  • A dataframe named df, where each record represents one individual and all columns are integer or numeric
  • A column named age (the age of each individual)
  • A column named missing_age (equal to 1 if that individual has no age, 0 otherwise)
  • A column named unknown_age (equal to 1 if that individual's age is unknown, 0 otherwise)

Then I can compute this mean as shown below.

calc_mean = df.where(
    (col("unknown_age") == 0) & (col("missing_age") == 0)
).agg(avg(col("age")))

Or via SQL and window functions:

mean_compute = hiveContext.sql(
    "select avg(age) over() as mean from df where missing_age = 0 and unknown_age = 0"
)

I don't want to use SQL/window functions if I can help it. My challenge has been taking this mean and replacing the unknown/missing values with it using non-SQL methods.

I've tried using when(), where(), replace(), withColumn, UDFs, and combinations... Regardless of what I do, I either get errors or the results aren't what I expect. Here's an example of one of many things I've tried that didn't work.

imputed = df.when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), calc_mean
).otherwise("age")

I've scoured the web, but haven't found similar imputation type questions so any help is much appreciated. It could be something very simple that I've missed.

A side note -- I'm trying to apply this code to all columns in the Spark DataFrame that don't have unknown_ or missing_ in the column names. Can I just wrap the Spark-related code in a Python for loop and loop through all of the applicable columns to do this?

UPDATE:

Also figured out how to loop through columns... Here's an example.

for x in df.columns:
    if 'unknown_' not in x and 'missing_' not in x:
        avg_compute = df.where(df['missing_' + x] != 1).agg(avg(x)).first()[0]
        df = df.withColumn(x + 'mean_miss_imp',
                           when(df['missing_' + x] == 1, avg_compute)
                           .otherwise(df[x]))

Answer

zero323 · May 25, 2016

If age for unknown or missing individuals is some sentinel value (here -1):

from pyspark.sql.functions import col, avg, when

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (-1, 1, 0), (-1, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

avg_age = df.where(
    (col("unknown_age") != 1) & (col("missing_age") != 1)
).agg(avg("age")).first()[0]

df.withColumn("age_imp", when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age")))
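
The piece that was missing in the question's attempt is that avg_age is extracted as a plain Python scalar via .first()[0] before it goes into when(); passing the aggregated DataFrame itself (calc_mean above) is what fails. For reference, a quick check against the sample data (the mean of the known ages is (10 + 20) / 2 = 15.0; the show() output below is a sketch of what to expect):

df.withColumn("age_imp", when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age"))).show()
# +---+-----------+-----------+-------+
# |age|missing_age|unknown_age|age_imp|
# +---+-----------+-----------+-------+
# | 10|          0|          0|   10.0|
# | 20|          0|          0|   20.0|
# | -1|          1|          0|   15.0|
# | -1|          0|          1|   15.0|
# +---+-----------+-----------+-------+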

If age for unknown or missing is NULL, you can simplify this to:

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (None, 1, 0), (None, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

df.na.fill(df.na.drop().agg(avg("age")).first()[0], ["age"])
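
On the side note about looping over all columns: with the NULL encoding, the same idea extends to every value column in one pass by handing a dict of per-column means to na.fill. A minimal sketch, assuming each value column marks its missing/unknown entries as NULL, and with value_cols, means, and imputed as names made up for the example:

from pyspark.sql.functions import avg

# Value columns are everything that isn't a missing_/unknown_ flag column.
value_cols = [c for c in df.columns
              if not c.startswith("missing_") and not c.startswith("unknown_")]

# Mean of each value column over its non-NULL rows, then filled back in a single pass;
# na.fill only touches NULLs, so the flag columns are left alone.
means = {c: df.na.drop(subset=[c]).agg(avg(c)).first()[0] for c in value_cols}
imputed = df.na.fill(means)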