How to return the "Tuple" type to UDF in PySpark?

All the data types in pyspark.sql.types:

    __all__ = [
        "DataType", "NullType", "StringType", "BinaryType", "BooleanType",
        "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType",
        "ByteType", "IntegerType", "LongType", "ShortType",
        "ArrayType", "MapType", "StructField", "StructType"]

I need to write a UDF (in PySpark) which returns an array of tuples. What should I pass as the second argument to udf, which is the return type of the UDF? It would be something along the lines of ArrayType(TupleType())...

3 answers

There is no such thing as TupleType. Product types are represented as structs with fields of specific types. For example, if you want to return an array of (integer, string) pairs, you can use a schema like this:

    from pyspark.sql.types import *

    schema = ArrayType(StructType([
        StructField("char", StringType(), False),
        StructField("count", IntegerType(), False)
    ]))

Usage example:

    from pyspark.sql.functions import udf
    from collections import Counter

    char_count_udf = udf(
        lambda s: Counter(s).most_common(),
        schema
    )

    df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])
    df.select("*", char_count_udf(df["value"])).show(2, False)

    ## +---+-----+-------------------------+
    ## |id |value|PythonUDF#<lambda>(value)|
    ## +---+-----+-------------------------+
    ## |1  |foo  |[[o,2], [f,1]]           |
    ## |2  |bar  |[[r,1], [a,1], [b,1]]    |
    ## +---+-----+-------------------------+
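If you need the individual pairs afterwards, the array of structs can be unpacked with built-in functions. A minimal sketch, assuming a Spark 2.x SparkSession named `spark` rather than the `sc` used above:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import explode, col, udf
    from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType
    from collections import Counter

    spark = SparkSession.builder.getOrCreate()

    schema = ArrayType(StructType([
        StructField("char", StringType(), False),
        StructField("count", IntegerType(), False)
    ]))

    char_count_udf = udf(lambda s: Counter(s).most_common(), schema)

    df = spark.createDataFrame([(1, "foo"), (2, "bar")], ["id", "value"])

    # explode() turns the array of structs into one row per (char, count)
    # pair; the struct fields are then addressable with dot notation.
    (df.withColumn("pair", explode(char_count_udf("value")))
       .select("id", col("pair.char"), col("pair.count"))
       .show())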

Stack Overflow keeps directing me to this question, so I guess I'll add some information here.

Returning simple types from UDF:

    from pyspark.sql.types import *
    from pyspark.sql.functions import udf
    from pyspark.sql import functions as F

    def get_df():
        d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
        df = sqlContext.createDataFrame(d, ['x', 'y'])
        return df

    df = get_df()
    df.show()
    # +---+---+
    # |  x|  y|
    # +---+---+
    # |0.0|0.0|
    # |0.0|3.0|
    # |1.0|6.0|
    # |1.0|9.0|
    # +---+---+

    func = udf(lambda x: str(x), StringType())
    df = df.withColumn('y_str', func('y'))

    func = udf(lambda x: int(x), IntegerType())
    df = df.withColumn('y_int', func('y'))

    df.show()
    # +---+---+-----+-----+
    # |  x|  y|y_str|y_int|
    # +---+---+-----+-----+
    # |0.0|0.0|  0.0|    0|
    # |0.0|3.0|  3.0|    3|
    # |1.0|6.0|  6.0|    6|
    # |1.0|9.0|  9.0|    9|
    # +---+---+-----+-----+

    df.printSchema()
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- y_str: string (nullable = true)
    #  |-- y_int: integer (nullable = true)
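For casts like these, the UDF is not actually necessary: Column.cast does the same thing without the Python round-trip. A minimal sketch, continuing from the same get_df() setup:

    df = get_df()
    # Built-in casts avoid serializing every row through Python.
    df = df.withColumn('y_str', df['y'].cast(StringType()))
    df = df.withColumn('y_int', df['y'].cast(IntegerType()))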

When integers are not enough:

    df = get_df()

    func = udf(lambda x: [0] * int(x), ArrayType(IntegerType()))
    df = df.withColumn('list', func('y'))

    func = udf(lambda x: {float(y): str(y) for y in range(int(x))},
               MapType(FloatType(), StringType()))
    df = df.withColumn('map', func('y'))

    df.show()
    # +---+---+--------------------+--------------------+
    # |  x|  y|                list|                 map|
    # +---+---+--------------------+--------------------+
    # |0.0|0.0|                  []|               Map()|
    # |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
    # |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
    # |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
    # +---+---+--------------------+--------------------+

    df.printSchema()
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- list: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- map: map (nullable = true)
    #  |    |-- key: float
    #  |    |-- value: string (valueContainsNull = true)
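Elements of the returned array and map columns can be read back with Column.getItem. A small sketch continuing from the DataFrame above (for the empty list in the first row, the lookup simply returns null):

    df.select(
        df['list'].getItem(0).alias('first_elem'),   # array index 0
        df['map'].getItem(0.0).alias('label_for_0')  # map lookup by key
    ).show()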

Returning complex data types from UDF:

    df = get_df()
    df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
    df.show()
    # +---+----------+
    # |  x|       y[]|
    # +---+----------+
    # |0.0|[0.0, 3.0]|
    # |1.0|[9.0, 6.0]|
    # +---+----------+

    schema = StructType([
        StructField("min", FloatType(), True),
        StructField("size", IntegerType(), True),
        StructField("edges", ArrayType(FloatType()), True),
        StructField("val_to_index", MapType(FloatType(), IntegerType()), True)
        # StructField('insanity', StructType([
        #     StructField("min_", FloatType(), True),
        #     StructField("size_", IntegerType(), True)]))
    ])

    def func(values):
        mn = min(values)
        size = len(values)
        lst = sorted(values)[::-1]
        val_to_index = {x: i for i, x in enumerate(values)}
        return (mn, size, lst, val_to_index)

    func = udf(func, schema)
    dff = df.select('*', func('y[]').alias('complex_type'))
    dff.show(10, False)
    # +---+----------+------------------------------------------------------+
    # |x  |y[]       |complex_type                                          |
    # +---+----------+------------------------------------------------------+
    # |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
    # |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
    # +---+----------+------------------------------------------------------+

    dff.printSchema()
    # root
    #  |-- x: double (nullable = true)
    #  |-- y[]: array (nullable = true)
    #  |    |-- element: double (containsNull = true)
    #  |-- complex_type: struct (nullable = true)
    #  |    |-- min: float (nullable = true)
    #  |    |-- size: integer (nullable = true)
    #  |    |-- edges: array (nullable = true)
    #  |    |    |-- element: float (containsNull = true)
    #  |    |-- val_to_index: map (nullable = true)
    #  |    |    |-- key: float
    #  |    |    |-- value: integer (valueContainsNull = true)
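Once the struct column exists, its fields behave like nested columns and can be selected with dot notation. A brief sketch against dff from above:

    dff.select(
        'x',
        'complex_type.min',    # float field of the struct
        'complex_type.size',   # integer field
        # edges is sorted descending, so index 0 is the largest value
        F.col('complex_type.edges')[0].alias('largest')
    ).show()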

Passing multiple arguments to UDF:

    df = get_df()
    func = udf(lambda arr: arr[0] * arr[1], FloatType())
    df = df.withColumn('x*y', func(F.array('x', 'y')))
    df.show()
    # +---+---+---+
    # |  x|  y|x*y|
    # +---+---+---+
    # |0.0|0.0|0.0|
    # |0.0|3.0|0.0|
    # |1.0|6.0|6.0|
    # |1.0|9.0|9.0|
    # +---+---+---+
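Packing the columns into an array is one option, but a udf can also simply take several columns as separate arguments. A sketch of the same product without F.array (the product of two doubles is declared as DoubleType here):

    from pyspark.sql.types import DoubleType

    df = get_df()
    # Each column becomes one positional argument of the lambda.
    mult = udf(lambda x, y: x * y, DoubleType())
    df = df.withColumn('x*y', mult('x', 'y'))
    df.show()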

The code is for demonstration purposes only: all of the above transformations are available as built-in Spark functions and will yield much better performance. As @zero323 notes in the comments above, UDFs should generally be avoided in PySpark; if you find yourself returning complex types, consider simplifying your logic.


The Scala version, instead of Python (Spark 2.4):

    import org.apache.spark.sql.types._

    val testschema: StructType = StructType(
      StructField("number", IntegerType) ::
      StructField("Array", ArrayType(StructType(
        StructField("cnt_rnk", IntegerType) ::
        StructField("comp", StringType) :: Nil))) ::
      StructField("comp", StringType) :: Nil)

The tree structure is as follows.

    testschema.printTreeString
    root
     |-- number: integer (nullable = true)
     |-- Array: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- cnt_rnk: integer (nullable = true)
     |    |    |-- comp: string (nullable = true)
     |-- comp: string (nullable = true)
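For comparison with the rest of this page, the same schema can also be written in PySpark; a sketch:

    from pyspark.sql.types import *

    # Nested struct inside an array, mirroring the Scala definition above.
    testschema = StructType([
        StructField("number", IntegerType()),
        StructField("Array", ArrayType(StructType([
            StructField("cnt_rnk", IntegerType()),
            StructField("comp", StringType())
        ]))),
        StructField("comp", StringType())
    ])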
