Where do you need to use lit() in PySpark SQL?

flybonzai · Jun 9, 2016

I'm trying to make sense of where you need to use a lit value, which is defined as a literal column in the documentation.

Take, for example, this UDF, which returns the element at a given index of an array column:

def find_index(column, index):
    # inside a Python UDF, `column` is a list and `index` an int; returns the element at that position
    return column[index]

If I were to pass a plain integer into this I would get an error; I would need to pass lit(n) into the UDF to get the element at index n of the array.

Is there a place I can better learn the hard and fast rules of when to use lit and possibly col as well?

Answer

zero323 · Jun 9, 2016

To keep it simple, you need a Column (it can be one created using lit, but that is not the only option) when the JVM counterpart expects a column and there is no internal conversion in the Python wrapper, or when you want to call a Column-specific method.

In the first case, the only strict rule is the one that applies to UDFs: a UDF (Python or JVM) can be called only with arguments of Column type. This also typically applies to functions from pyspark.sql.functions. In other cases it is always best to check the documentation and docstrings first and, if they are not sufficient, the docs of the corresponding Scala counterpart.

In the second case the rules are simple. If, for example, you want to compare a column to a value, then the value has to be on the RHS:

col("foo") > 0  # OK

or the value has to be wrapped with lit:

lit(0) < col("foo")  # OK

In Python, many operators (<, ==, <=, &, |, +, -, *, /) can also take a non-Column object on the LHS:

0 < col("foo")  # also OK in Python

but such expressions are not supported in Scala.

It goes without saying that you have to use lit if you want to call any of the pyspark.sql.Column methods on a standard Python scalar treated as a constant column. For example, you'll need

c = lit(1)

not

c = 1