微信搜索lxw1234bigdata | 邀请体验:数阅–数据管理、OLAP分析与可视化平台 | 赞助作者:赞助作者

Spark MLlib实现的广告点击预测–Gradient-Boosted Trees

Spark lxw1234@qq.com 40773℃ 9评论

关键字:spark、mllib、Gradient-Boosted Trees、广告点击预测

本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点击广告。

训练和测试数据使用Kaggle Avazu CTR 比赛的样例数据,下载地址:https://www.kaggle.com/c/avazu-ctr-prediction/data

数据格式如下:

spark

包含24个字段:

  • 1-id: ad identifier
  • 2-click: 0/1 for non-click/click
  • 3-hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
  • 4-C1 — anonymized categorical variable
  • 5-banner_pos
  • 6-site_id
  • 7-site_domain
  • 8-site_category
  • 9-app_id
  • 10-app_domain
  • 11-app_category
  • 12-device_id
  • 13-device_ip
  • 14-device_model
  • 15-device_type
  • 16-device_conn_type
  • 17~24—C14-C21 — anonymized categorical variables

其中5到15列为分类特征,16~24列为数值型特征。

Spark代码如下:

package com.lxw1234.test

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

/**
 * By: lxw
 * http://lxw1234.com
 */
object CtrPredict {
  
  //input (1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9)
  //output ((0:1fbe01fe),(1:f3845767),(2:28905ebd),(3:ecad2386),(4:7801e8d9))
    def parseCatFeatures(catfeatures: Array[String]) :  List[(Int, String)] = {
      var catfeatureList = new ListBuffer[(Int, String)]()
      for (i <- 0 until catfeatures.length){
          catfeatureList += i -> catfeatures(i).toString
      }
      catfeatureList.toList
    }
  
  def main(args: Array[String]) {
      val conf = new SparkConf().setMaster("yarn-client")
      val sc = new SparkContext(conf)
      
      var ctrRDD = sc.textFile("/tmp/lxw1234/sample.txt",10);
      println("Total records : " + ctrRDD.count)
      
      //将整个数据集80%作为训练数据,20%作为测试数据集
      var train_test_rdd = ctrRDD.randomSplit(Array(0.8, 0.2), seed = 37L)
      var train_raw_rdd = train_test_rdd(0)
      var test_raw_rdd = train_test_rdd(1)
      
      println("Train records : " + train_raw_rdd.count)
      println("Test records : " + test_raw_rdd.count)
      
      //cache train, test
      train_raw_rdd.cache()
      test_raw_rdd.cache()
      
      var train_rdd = train_raw_rdd.map{ line =>
          var tokens = line.split(",",-1)
          //key为id和是否点击广告
          var catkey = tokens(0) + "::" + tokens(1)
          //第6列到第15列为分类特征,需要One-Hot-Encoding
          var catfeatures = tokens.slice(5, 14)
          //第16列到24列为数值特征,直接使用
          var numericalfeatures = tokens.slice(15, tokens.size-1)
          (catkey, catfeatures, numericalfeatures)
      }
      
      //拿一条出来看看
      train_rdd.take(1)
      //scala> train_rdd.take(1)
      //res6: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,Array(1fbe01fe, 
      //            f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
      //              Array(2, 15706, 320, 50, 1722, 0, 35, -1)))
      
      //将分类特征先做特征ID映射
      var train_cat_rdd  = train_rdd.map{
        x => parseCatFeatures(x._2)
      }
      
      train_cat_rdd.take(1)
      //scala> train_cat_rdd.take(1)
      //res12: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd), 
      //        (3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))
      
      //将train_cat_rdd中的(特征ID:特征)去重,并进行编号
      var oheMap = train_cat_rdd.flatMap(x => x).distinct().zipWithIndex().collectAsMap()
      //oheMap: scala.collection.Map[(Int, String),Long] = Map((7,608511e9) -> 31527, (7,b2d8fbed) -> 42207, 
      //  (7,1d3e2fdb) -> 52791
      println("Number of features")
      println(oheMap.size)
      
      //create OHE for train data
      var ohe_train_rdd = train_rdd.map{ case (key, cateorical_features, numerical_features) =>
              var cat_features_indexed = parseCatFeatures(cateorical_features)                        
              var cat_feature_ohe = new ArrayBuffer[Double]
              for (k <- cat_features_indexed) {
                if(oheMap contains k){
                cat_feature_ohe += (oheMap get (k)).get.toDouble
                }else {
                  cat_feature_ohe += 0.0
                }               
              }
              var numerical_features_dbl  = numerical_features.map{
                        x => 
                          var x1 = if (x.toInt < 0) "0" else x
                        x1.toDouble
              }
              var features = cat_feature_ohe.toArray ++  numerical_features_dbl           
              LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))                                               
     }
      
     ohe_train_rdd.take(1)
     //res15: Array[org.apache.spark.mllib.regression.LabeledPoint] = 
     //  Array((0.0,[43127.0,50023.0,57445.0,13542.0,31092.0,14800.0,23414.0,54121.0,
     //     17554.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))
     
     //训练模型
     //val boostingStrategy = BoostingStrategy.defaultParams("Regression")
     val boostingStrategy = BoostingStrategy.defaultParams("Classification")
     boostingStrategy.numIterations = 100
     boostingStrategy.treeStrategy.numClasses = 2
     boostingStrategy.treeStrategy.maxDepth = 10
     boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
     
     
     val model = GradientBoostedTrees.train(ohe_train_rdd, boostingStrategy)
     //保存模型
     model.save(sc, "/tmp/myGradientBoostingClassificationModel")
     //加载模型
     val sameModel = GradientBoostedTreesModel.load(sc,"/tmp/myGradientBoostingClassificationModel")
     
     //将测试数据集做OHE
     var test_rdd = test_raw_rdd.map{ line =>
        var tokens = line.split(",")
        var catkey = tokens(0) + "::" + tokens(1)
        var catfeatures = tokens.slice(5, 14)
        var numericalfeatures = tokens.slice(15, tokens.size-1)
        (catkey, catfeatures, numericalfeatures)
     }
     
     var ohe_test_rdd = test_rdd.map{ case (key, cateorical_features, numerical_features) =>
            var cat_features_indexed = parseCatFeatures(cateorical_features)      
            var cat_feature_ohe = new ArrayBuffer[Double]
            for (k <- cat_features_indexed) {               
              if(oheMap contains k){
                cat_feature_ohe += (oheMap get (k)).get.toDouble
              }else {
                cat_feature_ohe += 0.0
              }
            }
          var numerical_features_dbl  = numerical_features.map{x => 
                              var x1 = if (x.toInt < 0) "0" else x
                              x1.toDouble}
            var features = cat_feature_ohe.toArray ++  numerical_features_dbl           
            LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))                                               
     }
     
     //验证测试数据集
     var b = ohe_test_rdd.map {
        y => var s = model.predict(y.features)
        (s,y.label,y.features)
     }
     
     b.take(10).foreach(println)
     
     //预测准确率
      var predictions = ohe_test_rdd.map(lp => sameModel.predict(lp.features))
      predictions.take(10).foreach(println)
      var predictionAndLabel = predictions.zip( ohe_test_rdd.map(_.label))
      var accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2 ).count/ohe_test_rdd.count
      println("GBTR accuracy " + accuracy)
      //GBTR accuracy 0.8227084119200302
    
  }
  
}

其中,训练数据集: Train records : 104558, 测试数据集:Test records : 26510

程序主要输出:

scala> train_rdd.take(1)
res23: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,
		Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
		Array(2, 15706, 320, 50, 1722, 0, 35, -1)))


scala> train_cat_rdd.take(1)
res24: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd), 
		(3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))


scala> println("Number of features")
Number of features

scala> println(oheMap.size)
57606


scala> ohe_train_rdd.take(1)
res27: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array(
		(0.0,[11602.0,22813.0,11497.0,16828.0,30657.0,23893.0,13182.0,31723.0,39722.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))


scala> println("GBTR accuracy " + accuracy)
GBTR accuracy 0.8227084119200302


 

如果觉得本博客对您有帮助,请 赞助作者

转载请注明:lxw的大数据田地 » Spark MLlib实现的广告点击预测–Gradient-Boosted Trees

喜欢 (40)
分享 (0)
发表我的评论
取消评论
表情

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
(9)个小伙伴在吐槽
  1. 你好,这里One-Hot-Encoding的做法是直接把特征在map中所在位置作为特征的数值表示,这样做有什么trick么,因为传统的做法是把特征位置的值置为1,其它位置为0。
    smartlan2016-01-18 11:22 回复
  2. 在做特征提取的时候你用了和上一篇贝叶斯分类不同的计算方式 有什么讲究吗
    郭彦超2016-01-27 09:33 回复
    • 没什么讲究,就是测试效果。
      lxw1234@qq.com2016-01-27 11:48 回复
      • hashTf 与 zipWithIndex成一个大集合 两种方式哪个效率高 更易于维护和使用
        郭彦超2016-01-28 13:36 回复
  3. 写的很好,想学习一下,请问源码有没有开源 ,地址能共享下吗 790753906@qq.com
    郭彦超2016-01-28 16:56 回复
  4. 你好,作为初学者,表示没有看懂代码中哪里体现了“预测一个用户是否会点击广告”这一功能,望解答 :smile:
    zero2016-05-05 13:00 回复
  5. 这里用准确率做为评价标准不适合,因为正样本率比较低,如果全部预测为0,准确率也非常高。建议用auc,ks之类指标
    rookie2016-12-16 17:50 回复
  6. spark gbdt 这个可以输出预测的具体概率值吗,而不是分为 0 和1
    babyxingqing2017-06-27 15:14 回复
  7. 训练模型的语句那边编译能通过?
    Hunglish2018-03-21 12:05 回复