温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

spark-sql 自定义函数

发布时间:2020-06-02 13:21:57 来源:网络 阅读:1203 作者:原生zzy 栏目:大数据

(1)自定义UDF

object SparkSqlTest {
    def main(args: Array[String]): Unit = {
        //屏蔽多余的日志
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)
        //构建编程入口
        val conf: SparkConf = new SparkConf()
        conf.setAppName("SparkSqlTest")
            .setMaster("local[2]")
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()

        //创建sqlcontext对象
        val sqlContext: SQLContext = spark.sqlContext

        /**
          * 注册定义的UDF:
          * 这里的泛型[Int,String]
          * 第一个是返回值类型,后面可以是一个或者多个,是方法参数类型
          */
        sqlContext.udf.register[Int,String]("strLen",strLen)
        val sql=
            """
              |select strLen("zhangsan")
            """.stripMargin
        spark.sql(sql).show()
    }
    //自定义UDF方法
    def strLen(str:String):Integer={
        str.length
    }
}

(2) 自定义UDAF

这里举的例子是实现一个count:
自定义UDAF类:

    class MyCountUDAF extends UserDefinedAggregateFunction{
    //该UDAF输入的数据类型
    override def inputSchema: StructType = {
        StructType(List(
            StructField("age",DataTypes.IntegerType)
        ))
    }

    //在该UDAF中聚合的数据类型
    override def bufferSchema: StructType = {
        StructType(List(
            StructField("age",DataTypes.IntegerType)
        ))
    }
    //该UDAF输出的数据类型
    override def dataType: DataType = DataTypes.IntegerType

    //确定性判断,通常特定输入和输出的类型一致
    override def deterministic: Boolean = true

    //buffer:计算过程中临时的存储了聚合结果的Buffer
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0,0)
    }

    /**
      * 分区内的数据聚合合并
      * @param buffer:就是我们在initialize方法中声明初始化的临时缓冲区
      * @param input:聚合操作新传入的值
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val oldValue=buffer.getInt(0)
        buffer.update(0,oldValue+1)
    }
    /**
      * 分区间的聚合
      * @param buffer1:分区一聚合的临时结果
      * @param buffer2;分区二聚合的临时结果
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        val p1=buffer1.getInt(0)
        val p2=buffer2.getInt(0)
        buffer1.update(0,p1+p2)
    }

    //该聚合函数最终输出的值
    override def evaluate(buffer: Row): Any = {
        buffer.get(0)
    }
}

调用:

object SparkSqlTest {
    def main(args: Array[String]): Unit = {
        //屏蔽多余的日志
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)
        //构建编程入口
        val conf: SparkConf = new SparkConf()
        conf.setAppName("SparkSqlTest")
            .setMaster("local[2]")
            .set("spark.serializer","org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(Array(classOf[Student]))
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()

        //创建sqlcontext对象
        val sqlContext: SQLContext = spark.sqlContext

        //注册UDAF
        sqlContext.udf.register("myCount",new MyCountUDAF())

        val stuList = List(
            new Student("委xx", 18),
            new Student("吴xx", 18),
            new Student("戚xx", 18),
            new Student("王xx", 19),
            new Student("薛xx", 19)
        )
        import spark.implicits._
        val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)
        stuDS.createTempView("student")
        val sql=
            """
              |select myCount(1) counts
              |from student
              |group by age
              |order by counts
            """.stripMargin
        spark.sql(sql).show()
    }

}
case class Student(name:String,age:Int)
向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI