I want to filter a DataFrame using a condition related to the length of a column. This question might be very easy, but I didn't find any related question on SO.
More specifically, I have a DataFrame with only one column, which is of ArrayType(StringType()), and I want to filter the DataFrame using the length as the filter condition. Here is a snippet:
df = sqlContext.read.parquet("letters.parquet")
df.show()
# The output will be
# +------------+
# | tokens|
# +------------+
# |[L, S, Y, S]|
# |[L, V, I, S]|
# |[I, A, N, A]|
# |[I, L, S, A]|
# |[E, N, N, Y]|
# |[E, I, M, A]|
# |[O, A, N, A]|
# | [S, U, S]|
# +------------+
# But I want only the entries with length 3 or less
fdf = df.filter(len(df.tokens) <= 3)
fdf.show()
# But it fails with "TypeError: object of type 'Column' has no len()", so the previous statement is obviously incorrect.
I read the Column documentation but didn't find any property useful for this. I'd appreciate any help!
In Spark >= 1.5 you can use the size function:
from pyspark.sql.functions import col, size
df = sqlContext.createDataFrame([
(["L", "S", "Y", "S"], ),
(["L", "V", "I", "S"], ),
(["I", "A", "N", "A"], ),
(["I", "L", "S", "A"], ),
(["E", "N", "N", "Y"], ),
(["E", "I", "M", "A"], ),
(["O", "A", "N", "A"], ),
(["S", "U", "S"], )],
("tokens", ))
df.where(size(col("tokens")) <= 3).show()
## +---------+
## | tokens|
## +---------+
## |[S, U, S]|
## +---------+
In Spark < 1.5 a UDF should do the trick:
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import col, udf
# Python UDF returning the number of elements in the array
size_ = udf(lambda xs: len(xs), IntegerType())
df.where(size_(col("tokens")) <= 3).show()
## +---------+
## | tokens|
## +---------+
## |[S, U, S]|
## +---------+
If you use HiveContext, then the size UDF with raw SQL should work with any version:
df.registerTempTable("df")
sqlContext.sql("SELECT * FROM df WHERE size(tokens) <= 3").show()
## +--------------------+
## | tokens|
## +--------------------+
## |ArrayBuffer(S, U, S)|
## +--------------------+
For string columns you can use either the UDF defined above or the length function:
from pyspark.sql.functions import length
df = sqlContext.createDataFrame([("fooo", ), ("bar", )], ("k", ))
df.where(length(col("k")) <= 3).show()
## +---+
## | k|
## +---+
## |bar|
## +---+