博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
批量进行 One-Hot 编码并拼接特征字段,完成模型训练 demo
阅读量:5967 次
发布时间:2019-06-19

本文共 3614 字,大约阅读时间需要 12 分钟。

hot3.png

// Spark ML + xgboost4j-spark demo: index and one-hot encode every string column,
// assemble all feature columns into one vector, then tune an XGBoost binary
// classifier with cross-validation and report metrics on a held-out split.
// Assumes a SparkSession named `spark` is already in scope (e.g. spark-shell).
// Multi-line chains are wrapped in parentheses so the script can be pasted into the REPL.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel

// Name of the binary label column produced by the SQL query below.
val labelCol = "affairs"

val data = (spark.read.format("csv")
  .option("sep", ",")
  .option("inferSchema", "true")
  .option("header", "true")
  .load("/user/spark/security/Affairs.csv"))
data.createOrReplaceTempView("res1")

// Binarize the label: affairs > 0 becomes 1, everything else becomes 0.
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = (spark.sql("select " + affairs +
  "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
  " from res1 "))

// Every string-typed column is treated as categorical: index it, then one-hot encode the index.
val categoricals = df.dtypes.filter(_._2 == "StringType").map(_._1)
val indexers = categoricals.map(c =>
  new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx"))
// setDropLast(false) keeps one vector slot for every category instead of dropping the last one.
val encoders = categoricals.map(c =>
  new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc").setDropLast(false))

// Feature columns = the non-string columns as-is, plus the encoded categoricals, minus the label.
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.filter(_._2 != "StringType").map(_._1)
// Exact-name match. A substring test would also drop any feature whose name merely contains the label name.
val final_colArray = (colArray_numeric ++ colArray_enc).filterNot(_ == labelCol)
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")

// To inspect the assembled features without training:
//   new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler)).fit(df).transform(df)

// Hold out a test split so the reported metrics are not computed on the data the model was fitted on.
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 0L)

// Create an XGBoost classifier.
// NOTE(review): num_class is normally only meaningful for multi:* objectives;
// confirm it is required together with binary:logistic for this xgboost4j-spark version.
val xgb = (new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree"))
  .setLabelCol(labelCol)
  .setFeaturesCol("features"))

// XGBoost parameter grid. Only maxDepth has more than one candidate value.
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10, 20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic"))
  .build())

// Full pipeline: indexers, then encoders, then the assembler, with XGBoost as the final stage.
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))

// Binary classifier evaluator.
// NOTE(review): AUC is computed from the hard 0/1 "prediction" column, which only gives a coarse
// estimate. If the XGBoost model emits a probability/raw-score column, point
// setRawPredictionCol at it instead. TODO: confirm that column name for the xgboost4j-spark version in use.
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol(labelCol)
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC"))

// Cross-validation over the whole pipeline, scored by the evaluator above, searching xgbParamGrid.
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))

// Fit on the training split only.
val xgbModel = cv.fit(train)

// Score the held-out split.
val results = xgbModel.transform(test)

// Print the parameters chosen for XGBoost. The best model is a PipelineModel, and the
// XGBoost model is its last stage (the stage order is fixed by setStages above).
(xgbModel.bestModel.asInstanceOf[PipelineModel]
  .stages.last.asInstanceOf[XGBoostClassificationModel]
  .extractParamMap().toSeq.foreach(println))

results.select(labelCol, "prediction").show()
println("---Confusion Matrix------")
results.stat.crosstab(labelCol, "prediction").show()

// Overall quality of the model on the held-out split, measured as AUC.
val auc = evaluator.evaluate(results)
println("----AUC--------")
println("auc=" + auc)

转载于:https://my.oschina.net/kyo4321/blog/2050708

你可能感兴趣的文章
DataUml Design 介绍11 - DataUML 1.5版本功能-支持无Oracle客户端
查看>>
我的友情链接
查看>>
你一个人能独处多久
查看>>
Octopress使用中经验总结
查看>>
spring结合ehcache-spring-annotations配置缓存
查看>>
一个简单的数据库工具类
查看>>
我的友情链接
查看>>
理解 Glance - 每天5分钟玩转 OpenStack(20)
查看>>
Unshelve Instance 操作详解 - 每天5分钟玩转 OpenStack(39)
查看>>
init.d文件夹 2012-02-09
查看>>
CKeditor的几种配置方式
查看>>
解决Android 输入法InputMethodService 显示时让原Activity大小计算错误问题
查看>>
s3c6410烧写u-boot&&Linux
查看>>
TensorBoard:嵌入可视化
查看>>
FreeSWITCH的NAT穿越
查看>>
gitlab版本控制系统源码部署
查看>>
java反射机制中的getDeclaredField()
查看>>
java数据流无法输出验证码
查看>>
JAVA中的IO流
查看>>
PHP 正则表达式
查看>>