Given a Spark DataFrame, I would like to compute a column's mean using only its non-missing and non-unknown values, then use that mean to replace the column's missing and unknown values.
For example, assuming I'm working with a:
Then I can compute this mean as shown below.
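from pyspark.sql.functions import avg, col

# calc_mean is a one-row DataFrame here, not a Python number
calc_mean = (df.where((col("unknown_age") == 0) & (col("missing_age") == 0))
             .agg(avg(col("age"))))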
Or via SQL and window functions:
# assumes df has been registered as a temporary table named "df"
mean_compute = hiveContext.sql("""select avg(age) over() as mean from df
                                  where missing_age = 0 and unknown_age = 0""")
I'd rather not use SQL or window functions if I can avoid it. My challenge has been taking this mean and replacing the unknown/missing values with it using non-SQL methods.
I've tried using when(), where(), replace(), withColumn(), UDFs, and combinations of these. Whatever I do, I either get errors or results that aren't what I expect. Here's one of the many things I've tried that didn't work:
imputed = df.when((col("unknown_age") == 1) | (col("missing_age") == 1),
calc_mean).otherwise("age")
I've scoured the web but haven't found similar imputation questions, so any help is much appreciated. It could be something very simple that I've missed.
A side note: I'm trying to apply this code to all columns in the Spark DataFrame that don't have unknown_ or missing_ in their names. Can I just wrap the Spark code in a Python for loop and loop through all of the applicable columns?
UPDATE:
I also figured out how to loop through the columns. Here's an example:
for x in df.columns:
    if 'unknown_' not in x and 'missing_' not in x:
        avg_compute = df.where(df['missing_' + x] != 1).agg(avg(x)).first()[0]
        df = df.withColumn(x + 'mean_miss_imp', when((df['missing_' + x] == 1),
                                                      avg_compute).otherwise(df[x]))
If the age for unknown or missing rows is stored as some placeholder value (here -1):
from pyspark.sql.functions import col, avg, when
df = sc.parallelize([
(10, 0, 0), (20, 0, 0), (-1, 1, 0), (-1, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])
avg_age = df.where(
(col("unknown_age") != 1) & (col("missing_age") != 1)
).agg(avg("age")).first()[0]
df.withColumn("age_imp", when(
(col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age")))
If the age for unknown or missing rows is NULL, you can simplify this to:
df = sc.parallelize([
(10, 0, 0), (20, 0, 0), (None, 1, 0), (None, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])
df.na.fill(df.na.drop().agg(avg("age")).first()[0], ["age"])