Stack Overflow keeps sending me back to this question, so I'll add some more information here.
Returning simple types from a UDF. First, a small DataFrame that all of the examples below use:
from pyspark.sql.types import *
from pyspark.sql import functions as F


def get_df():
    """Return the small (x, y) DataFrame shared by every example below.

    Uses a ``sqlContext`` global that is not defined in this snippet
    (presumably the one the PySpark shell provides -- confirm in your setup).
    """
    rows = [
        (0.0, 0.0),
        (0.0, 3.0),
        (1.0, 6.0),
        (1.0, 9.0),
    ]
    columns = ['x', 'y']
    return sqlContext.createDataFrame(rows, columns)


df = get_df()
df.show()
Returning arrays and maps from a UDF (when a single scalar is not enough):
df = get_df()


def zeros(x):
    """Return a list of int(x) zeros, or None for a null input."""
    if x is None:
        return None
    return [0] * int(x)


def index_to_str(x):
    """Return {0.0: '0', 1.0: '1', ...} with int(x) entries, or None for null."""
    if x is None:
        return None
    return {float(i): str(i) for i in range(int(x))}


# `udf` itself is never imported in this answer; use it via the `F` alias.
func = F.udf(zeros, ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))
func = F.udf(index_to_str, MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))
df.show()
# +---+---+--------------------+--------------------+
# |  x|  y|                list|                 map|
# +---+---+--------------------+--------------------+
# |0.0|0.0|                  []|               Map()|
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+
df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- list: array (nullable = true)
#  |    |-- element: integer (containsNull = true)
#  |-- map: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: string (valueContainsNull = true)
Returning complex (struct) types from a UDF:
df = get_df()
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
df.show()
# +---+----------+
# |  x|       y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+
# Note: element order inside collect_list is not guaranteed; it differs
# between this output and the one further down.

# Struct returned by the UDF. The Python function must return a tuple whose
# items line up with these fields, in order.
schema = StructType([
    StructField("min", FloatType(), True),
    StructField("size", IntegerType(), True),
    StructField("edges", ArrayType(FloatType()), True),
    StructField("val_to_index", MapType(FloatType(), IntegerType()), True),
    # Fields may themselves be structs, e.g.:
    # StructField('insanity', StructType([StructField("min_", FloatType(), True),
    #                                     StructField("size_", IntegerType(), True)])),
])


def summarize(values):
    """Summarize a list of numbers as a tuple matching ``schema``.

    Returns (minimum, length, values sorted descending, value -> index in
    ``values``). Returns None (a null struct) for an empty or null list,
    instead of letting ``min([])`` raise inside the executor.
    """
    if not values:
        return None
    mn = min(values)
    size = len(values)
    edges = sorted(values, reverse=True)
    # If a value repeats, the index of its last occurrence wins.
    val_to_index = {v: i for i, v in enumerate(values)}
    return (mn, size, edges, val_to_index)


# `udf` itself is never imported in this answer; use it via the `F` alias.
func = F.udf(summarize, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)
# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+
dff.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y[]: array (nullable = false)
#  |    |-- element: double (containsNull = false)
#  |-- complex_type: struct (nullable = true)
#  |    |-- min: float (nullable = true)
#  |    |-- size: integer (nullable = true)
#  |    |-- edges: array (nullable = true)
#  |    |    |-- element: float (containsNull = true)
#  |    |-- val_to_index: map (nullable = true)
#  |    |    |-- key: float
#  |    |    |-- value: integer (valueContainsNull = true)
Passing multiple arguments to a UDF:
df = get_df()


def multiply(x, y):
    """Return x * y, or None if either input is null."""
    if x is None or y is None:
        return None
    return x * y


# A UDF can take several columns directly; each one arrives as a separate
# Python argument. (Packing them with F.array('x', 'y') also works, but it
# builds an intermediate array column that needs a common element type.)
# DoubleType matches the double inputs, so no precision is lost.
func = F.udf(multiply, DoubleType())
df = df.withColumn('x*y', func('x', 'y'))
df.show()
# +---+---+---+
# |  x|  y|x*y|
# +---+---+---+
# |0.0|0.0|0.0|
# |0.0|3.0|0.0|
# |1.0|6.0|6.0|
# |1.0|9.0|9.0|
# +---+---+---+
This code is for demonstration purposes only. Everything above can be done with built-in Spark functions, which perform much better. As @zero323 says in the comment above, UDFs should be avoided in PySpark where possible. If you find yourself returning complex types from one, take it as a sign to simplify your logic.