Intro
It looks like no one is writing about reading CSV in Spark 2.1.0 anymore. The only reference I could find was https://elbauldelprogramador.com/en/how-to-convert-column-to-vectorudt-densevector-spark/.
So, here comes my 5 cents on the issue.
Code
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{OneHotEncoder, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._

// `master` is the Spark master URL, assumed to be defined elsewhere
val sc: SparkContext = new SparkContext(master, "SuperApp", System.getenv("SPARK_HOME"))
val session: SparkSession = SparkSession.builder().getOrCreate()

// First, define the schema
val struct = StructType(
  StructField("price", DoubleType, false) ::
  StructField("_id", StringType, false) ::
  StructField("modelYearId", IntegerType, false) ::
  StructField("zip", IntegerType, false) ::
  StructField("modelYear", IntegerType, false) ::
  StructField("modelId", IntegerType, false) ::
  StructField("makeId", IntegerType, false) ::
  StructField("mileage", DoubleType, false) :: Nil)

val df: DataFrame = session.sqlContext.read.schema(struct)
  .option("header", "true").csv("cars.csv")

var data: DataFrame = df

// Transform each categorical variable into a one-hot vector;
// the assembler below expects all four *Vec columns to exist
for (c <- Seq("modelYearId", "zip", "modelId", "makeId")) {
  data = new OneHotEncoder()
    .setInputCol(c)
    .setOutputCol(c + "Vec").transform(data)
}

// Assemble the features that matter
val assembler = new VectorAssembler()
  .setInputCols(Array("modelYearIdVec", "zipVec", "modelYear",
    "modelIdVec", "makeIdVec", "mileage"))
  .setOutputCol("features")

// Verify the schema is what we want
data.printSchema()

data = assembler.transform(data)
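To make the encoding step a bit more concrete, here is a minimal sketch in plain Scala (no Spark needed) of what one-hot encoding does to a single category index. The `OneHotSketch` object and its `oneHot` helper are hypothetical names made up for illustration, not Spark API. The sketch assumes the input is already a zero-based category index and, like Spark's `OneHotEncoder` with its default `dropLast` setting, leaves the last category out, so that category becomes an all-zero vector.

```scala
object OneHotSketch {
  // Hypothetical helper: turn a zero-based category index into a 0/1 vector.
  // With dropLast = true the vector has numCategories - 1 slots and the
  // last category is represented by all zeros.
  def oneHot(index: Int, numCategories: Int, dropLast: Boolean = true): Array[Double] = {
    require(index >= 0 && index < numCategories, s"index $index out of range")
    val size = if (dropLast) numCategories - 1 else numCategories
    val vec = Array.fill(size)(0.0)
    if (index < size) vec(index) = 1.0
    vec
  }

  def main(args: Array[String]): Unit = {
    // Four made-up makeId values, assumed to be indices 0..3
    val makeIds = Seq(0, 2, 1, 3)
    makeIds.foreach { id =>
      println(s"makeId=$id -> ${oneHot(id, numCategories = 4).mkString("[", ", ", "]")}")
    }
  }
}
```

One thing this makes visible: the vector length follows from the number of categories, so feeding raw ZIP codes in as indices probably produces a very wide (sparse) vector. Mapping the values to compact indices first, for instance with `StringIndexer`, may be worth considering.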
And then the usual MLlib tutorial stuff:
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}

// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data
  .randomSplit(Array(0.7, 0.3))

val lr = new LinearRegression()
  .setLabelCol("price")
  .setMaxIter(200)
  .setRegParam(10)
  .setElasticNetParam(0.75)

// Fit the model
val lrModel: LinearRegressionModel = lr.fit(trainingData)

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
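The RMSE and r2 values printed above come from the training summary. As a rough sketch of what those two numbers measure, here they are computed by hand in plain Scala over (label, prediction) pairs. `RegressionMetricsSketch` and its helpers are hypothetical names for illustration only, and the r2 formula is the textbook `1 - SS_res / SS_tot`, which I am assuming matches what the summary reports for a model with an intercept.

```scala
object RegressionMetricsSketch {
  // Root mean squared error over (label, prediction) pairs.
  def rmse(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one pair")
    val sumSquaredErrors = pairs.map { case (y, yHat) => math.pow(y - yHat, 2) }.sum
    math.sqrt(sumSquaredErrors / pairs.size)
  }

  // Coefficient of determination: 1 - SS_res / SS_tot.
  // Assumes the labels are not all identical (otherwise SS_tot is zero).
  def r2(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one pair")
    val meanLabel = pairs.map(_._1).sum / pairs.size
    val ssRes = pairs.map { case (y, yHat) => math.pow(y - yHat, 2) }.sum
    val ssTot = pairs.map { case (y, _) => math.pow(y - meanLabel, 2) }.sum
    1.0 - ssRes / ssTot
  }

  def main(args: Array[String]): Unit = {
    // Made-up prices and predictions, purely for illustration
    val pairs = Seq((10000.0, 11000.0), (15000.0, 14500.0), (20000.0, 21000.0))
    println(s"RMSE: ${rmse(pairs)}")
    println(s"r2: ${r2(pairs)}")
  }
}
```

The same arithmetic applies to the held-out `testData`: `lrModel.transform(testData)` adds a `prediction` column next to `price`, and Spark's `RegressionEvaluator` can compute these metrics on that result directly.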