[Spark] User-defined functions UDF and UDAF

Posted by shaunie123 on Wed, 05 Jan 2022 21:44:19 +0100

  • All the data used in this article is user JSON, as shown below

{"username": "zhangsan","age": 20}
{"username": "lisi","age": 21}
{"username": "wangwu","age": 19}

Custom UDF

Introduction to UDF

UDF: takes one row as input and returns one result, a one-to-one relationship. Passing a value into the function returns a single value, never multiple values, as the following example shows:

(x: String) => "Name=" + x

This function takes a single input parameter and returns a single value, rather than returning multiple values.

Concrete implementation

import org.apache.spark.sql.SparkSession

object UDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDF")
      .getOrCreate()
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")

    //Register udf
    spark.udf.register("prefixName", (name: String) => {
      "Name:" + name
    })

    spark.sql("select age,prefixName(username) from user").show()

    spark.close()
  }
}
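
The same UDF can also be applied through the DataFrame API instead of SQL. A minimal sketch using org.apache.spark.sql.functions.udf, reusing the df from the example above:

import org.apache.spark.sql.functions.{col, udf}

//Wrap the same lambda as a DataFrame-API UDF and apply it column-wise
val prefixName = udf((name: String) => "Name:" + name)
df.select(col("age"), prefixName(col("username"))).show()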

Result display
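
On the sample JSON above, the query should print something like this (the exact column header can vary by Spark version):

+---+--------------------+
|age|prefixName(username)|
+---+--------------------+
| 20|       Name:zhangsan|
| 21|           Name:lisi|
| 19|         Name:wangwu|
+---+--------------------+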

Explanation

  • Before using a UDF, you must first register it with spark.udf.register


Custom UDAF

Introduction to UDAF

UDAF can be divided into strongly typed and weakly typed implementations

  • The main difference is that a strongly typed UDAF works with concrete Scala types (such as case classes), while a weakly typed one operates on untyped Row data

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregation functions, such as count(), countDistinct(), avg(), max(), and min(). Users can also define their own aggregate functions. A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. However, UserDefinedAggregateFunction is deprecated as of Spark 3.0, and the strongly typed aggregate function Aggregator can be used uniformly instead.
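
For comparison, the built-in avg already covers this use case without any custom code. A minimal sketch against the user view from the example above:

import org.apache.spark.sql.functions.avg

//Both forms compute the average age with the built-in aggregate
spark.sql("select avg(age) from user").show()
df.select(avg("age")).show()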

Weakly typed UDAF

Custom UDAF class

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  import org.apache.spark.sql.types._

  class MyAvgUDAF extends UserDefinedAggregateFunction {
    /**
     * The structure of the input data. We are averaging age, so the input is age.
     * As an aggregate function it consumes many rows and finally returns one value
     * (the average), so the schema describes a single input column of type LongType.
     */
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
        )
      )
    }

    /**
     * The structure of the buffer.
     * The buffer holds intermediate state while the aggregation runs.
     * For an average, the sum and the row count are accumulated here,
     * and the average is produced at the end.
     *
     * @return
     */
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )
    }

    /**
     * The data type of the output, i.e. the type of the final result
     *
     * @return
     */
    override def dataType: DataType = LongType

    /**
     * Whether the function is deterministic, i.e. the same input
     * always produces the same output
     *
     * @return
     */
    override def deterministic: Boolean = true

    /**
     * Buffer initialization
     *
     * @param buffer
     */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //Reset the buffer to its zero values; either style below works
      //Style 1
      //buffer(0) = 0L
      //buffer(1) = 0L

      //Style 2
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }

    /**
     * Update the buffer from one input row, i.e. the per-row calculation rule
     *
     * @param buffer
     * @param input
     */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //Slot 0 is the running sum: add the input age to the buffered total
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      //Slot 1 is the row count: increment by one for each input row
      buffer.update(1, buffer.getLong(1) + 1)
    }

    /**
     * Merge two buffers (partial aggregates from different partitions);
     * the merged result is kept in buffer1
     *
     * @param buffer1
     * @param buffer2
     */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    /**
     * Compute the final result: total / count (Long division)
     *
     * @param buffer
     * @return
     */
    override def evaluate(buffer: Row): Any = buffer.getLong(0) / buffer.getLong(1)
  }

Main steps:

  • Extend the UserDefinedAggregateFunction class
  • Implement its abstract methods

What does each method mean?

  • inputSchema: the schema of the input data. Since this is an aggregation, many input rows are consumed to produce one result
  • bufferSchema: the schema of the aggregation buffer, which holds the intermediate state of the calculation. To compute an average, the buffer must hold the running sum and the row count
  • dataType: the data type of the output, i.e. of the final result
  • deterministic: whether the function always returns the same result for the same input; this is generally true
  • initialize: resets the buffer to its zero values
  • update: updates the buffer from one input row, i.e. the per-row calculation rule
  • merge: merges two buffers (partial results computed on different partitions)
  • evaluate: computes the final result from the buffer (here, the average)

Register and use

  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAF")
      .getOrCreate()
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")

    //Register function
    spark.udf.register("ageAvg",new MyAvgUDAF())

    spark.sql("select ageAvg(age) from user").show()

    spark.close()
  }

Operation results
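
On the sample data the average of 20, 21, and 19 is 60 / 3 = 20, so the query should print something like (column header may vary):

+-----------+
|ageavg(age)|
+-----------+
|         20|
+-----------+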


Strongly typed UDAF

Define two case classes

  //Store buffer data
  case class Buff(var total: Long, var count: Long)

  //Store input data
  case class User(var username: String, var age: Long)

Custom strongly typed UDAF class

  import org.apache.spark.sql.{Encoder, Encoders}
  import org.apache.spark.sql.expressions.Aggregator

  class MyAvgAgeUDAF extends Aggregator[User, Buff, Long] {
    /**
     * The initial (zero) value of the buffer
     *
     * @return
     */
    override def zero: Buff = {
      Buff(0L, 0L)
    }

    /**
     * Update the buffer from one input row
     *
     * @param b
     * @param a
     * @return
     */
    override def reduce(b: Buff, a: User): Buff = {
      b.total += a.age
      b.count += 1
      b
    }

    /**
     * Merge two buffers
     *
     * @param b1
     * @param b2
     * @return
     */
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total += b2.total
      b1.count += b2.count
      b1
    }

    /**
     * Compute the final result (Long division)
     *
     * @param reduction
     * @return
     */
    override def finish(reduction: Buff): Long = reduction.total / reduction.count

    /**
     * Encoder for the buffer type; this is essentially boilerplate.
     * For a user-defined case class, use Encoders.product
     *
     * @return
     */
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    /**
     * Encoder for the output type; also boilerplate.
     * For a built-in Scala type, pick the matching encoder
     * (e.g. Encoders.scalaLong, Encoders.scalaInt, Encoders.STRING)
     *
     * @return
     */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

Explanation

  • Extend the Aggregator abstract class
  • Implement its methods
  • Unlike the weakly typed version, you must supply type parameters for the input, buffer, and output data (here Aggregator[User, Buff, Long])

Brief introduction to the methods

  • zero: initialization of the buffer (its zero value)
  • reduce: updates the buffer from one input row, i.e. accumulates the row count and the running sum
  • merge: merges two buffers
  • finish: computes the final result
  • bufferEncoder and outputEncoder: the encoders for the buffer and the output, respectively. These are essentially boilerplate: if the type is a user-defined case class, use Encoders.product; if it is a built-in Scala type, use the matching encoder, e.g. Encoders.scalaLong, depending on the type you output

Register and use

  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAF")
      .getOrCreate()
    import spark.implicits._
    val df = spark.read
      .json("data/user.json")
    df.createOrReplaceTempView("user")
    val ds = df.as[User]

    //Turn UDAF into a query column object
    val udafCol = new MyAvgAgeUDAF().toColumn
    ds.select(udafCol).show()
    spark.close()
  }

Result display
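
As before, the result is the single value 20; the column header is derived from the aggregator class, so the output looks roughly like:

+------------------+
|MyAvgAgeUDAF(User)|
+------------------+
|                20|
+------------------+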

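Since Spark 3.0, a strongly typed Aggregator can also be registered as a SQL function through functions.udaf, so it can be called from SQL just like the weakly typed version. A minimal sketch, assuming a hypothetical variant of the aggregator whose input type is a plain Long (a SQL call passes a single column, not a whole User):

  import org.apache.spark.sql.{Encoder, Encoders, functions}
  import org.apache.spark.sql.expressions.Aggregator

  //Same logic as MyAvgAgeUDAF, but the input type is Long so the function
  //can be applied to a single column from SQL (hypothetical variant)
  class MyAvgAgeSqlUDAF extends Aggregator[Long, Buff, Long] {
    override def zero: Buff = Buff(0L, 0L)
    override def reduce(b: Buff, a: Long): Buff = { b.total += a; b.count += 1; b }
    override def merge(b1: Buff, b2: Buff): Buff = { b1.total += b2.total; b1.count += b2.count; b1 }
    override def finish(r: Buff): Long = r.total / r.count
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  spark.udf.register("ageAvg", functions.udaf(new MyAvgAgeSqlUDAF()))
  spark.sql("select ageAvg(age) from user").show()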

Topics: Big Data Spark