User-defined aggregate functions (UDAFs)

Applies to: Databricks Runtime

User-defined aggregate functions (UDAFs) are user-programmable routines that act on multiple rows at once and return a single aggregated value as a result. This documentation lists the classes that are required for creating and registering UDAFs. It also contains examples that demonstrate how to define and register UDAFs in Scala and Java and invoke them in Spark SQL.

Aggregator

Syntax: Aggregator[-IN, BUF, OUT]

A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value.

  • IN: The input type for the aggregation.

  • BUF: The type of the intermediate value of the reduction.

  • OUT: The type of the final output result.

  • bufferEncoder: Encoder[BUF]

    The Encoder for the intermediate value type.

  • finish(reduction: BUF): OUT

    Transform the output of the reduction.

  • merge(b1: BUF, b2: BUF): BUF

    Merge two intermediate values.

  • outputEncoder: Encoder[OUT]

    The Encoder for the final output value type.

  • reduce(b: BUF, a: IN): BUF

    Aggregate input value a into the current intermediate value b. For performance, the function may modify b and return it instead of constructing a new object for b.

  • zero: BUF

    The initial value of the intermediate result for this aggregation.
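
Conceptually, Spark evaluates an Aggregator like a distributed fold: each partition's rows are folded into a buffer with reduce starting from zero, the per-partition buffers are combined with merge, and finish turns the final buffer into the result. The following is a minimal local sketch of that contract; runLocally is a hypothetical helper for illustration, not part of the Spark API:

Scala

import org.apache.spark.sql.expressions.Aggregator

// Hypothetical helper (not part of the Spark API): evaluates an Aggregator
// over in-memory "partitions" to show how zero, reduce, merge, and finish compose.
def runLocally[IN, BUF, OUT](agg: Aggregator[IN, BUF, OUT], partitions: Seq[Seq[IN]]): OUT = {
  // reduce: fold each partition's inputs into an intermediate buffer, starting from zero
  val buffers = partitions.map(_.foldLeft(agg.zero)(agg.reduce))
  // merge: combine the per-partition buffers pairwise
  val combined = buffers.reduce(agg.merge)
  // finish: transform the final buffer into the output value
  agg.finish(combined)
}

// With the MyAverage aggregator from the examples below:
// runLocally(MyAverage, Seq(Seq(3000L, 4500L), Seq(3500L, 4000L)))  // 3750.0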

Examples

Type-safe user-defined aggregate functions

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class. For example, a type-safe user-defined average can look like:
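
A minimal sketch, mirroring the MyAverage pattern from the untyped section below; the Employee case class is illustrative, and the employees.json path matches the data used in the later examples:

Scala

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative input type for the strongly typed aggregation
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // The Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // The Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

import spark.implicits._
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
// Convert the aggregator to a TypedColumn and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+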

Untyped user-defined aggregate functions

Typed aggregations, as described above, may also be registered as untyped aggregating UDFs for use with DataFrames. For example, a user-defined average for untyped DataFrames can look like:

Scala

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Long, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, data: Long): Average = {
    buffer.sum += data
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // The Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // The Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the function to access it
spark.udf.register("myAverage", functions.udaf(MyAverage))

val df = spark.read.format("json").load("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
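
Registering the wrapped aggregator makes it callable from SQL, as above; the same functions.udaf result can also be applied directly through the DataFrame API. A short sketch, assuming the df and MyAverage definitions from the example above:

Scala

import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.col

// Apply the wrapped Aggregator like any built-in aggregate expression
val myAverage = functions.udaf(MyAverage)
df.agg(myAverage(col("salary")).as("average_salary")).show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+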
Java

import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public static class Average implements Serializable {
  private long sum;
  private long count;
  // Constructors, getters, setters...
}

public static class MyAverage extends Aggregator<Long, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Long data) {
    long newSum = buffer.getSum() + data;
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // The Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // The Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

// Register the function to access it
spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));

Dataset<Row> df = spark.read().format("json").load("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
SQL

-- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';

SHOW USER FUNCTIONS;
+------------------+
|          function|
+------------------+
| default.myAverage|
+------------------+

CREATE TEMPORARY VIEW employees
USING org.apache.spark.sql.json
OPTIONS (path "examples/src/main/resources/employees.json");

SELECT * FROM employees;
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+

SELECT myAverage(salary) as average_salary FROM employees;
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+