Stratified train / test split in Spark

I'm curious whether there is something similar to sklearn's StratifiedShuffleSplit (http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html) for Apache Spark in the latest version, 2.0.1.

So far I could only find https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling, which doesn't seem very suitable for splitting a heavily unbalanced dataset into train / test samples.

3 answers

Spark supports stratified sampling, as described at https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html.

df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0) 
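
For a train / test split specifically, one way to build on sampleBy (a sketch, not part of the original answer, assuming a DataFrame df with a Double label column) is to sample the train set per class and take the complement as the test set. Note that sampleBy keeps each row with the given per-class probability, so the class proportions are only approximately preserved and the split sizes will vary between runs with different seeds:

    import org.apache.spark.sql.DataFrame

    // Approximate 80 / 20 stratified split: each class goes to train with probability 0.8.
    val fractions = Map(0.0 -> 0.8, 1.0 -> 0.8)
    val train: DataFrame = df.stat.sampleBy("label", fractions, seed = 42L)
    val test: DataFrame = df.except(train) // the complement becomes the test set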

Although this answer is not specific to Spark, here is how I do this in Apache Beam to split into 66% train and 33% test (just an illustrative example; you could make partition_fn below more sophisticated and have it accept arguments to specify the number of buckets, bias the selection toward something, ensure the randomization is fair across dimensions, etc.):

    import random

    import apache_beam as beam

    raw_data = p | 'Read Data' >> Read(...)
    clean_data = raw_data | 'Clean Data' >> beam.ParDo(CleanFieldsFn())

    # Beam passes the number of partitions as the second argument.
    def partition_fn(element, num_partitions):
        return random.randint(0, num_partitions - 1)

    random_buckets = clean_data | beam.Partition(partition_fn, 3)
    clean_train_data = (random_buckets[0], random_buckets[1]) | beam.Flatten()
    clean_eval_data = random_buckets[2]
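
One caveat with this approach: because partition_fn uses random.randint, the assignment is nondeterministic, so rerunning the pipeline produces a different split. If you need a reproducible split, a common alternative is to partition on a hash of a stable element key modulo the number of buckets instead of a random draw.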

Suppose we have a dataset like this:

    +---+-----+
    | id|label|
    +---+-----+
    |  0|  0.0|
    |  1|  1.0|
    |  2|  0.0|
    |  3|  1.0|
    |  4|  0.0|
    |  5|  1.0|
    |  6|  0.0|
    |  7|  1.0|
    |  8|  0.0|
    |  9|  1.0|
    +---+-----+

This data set is perfectly balanced, but this approach will work for unbalanced data.

Now let's augment this DataFrame with additional information that will be useful in deciding which rows should go to the train set. The steps are as follows:

  • Determine how many examples of each label should be part of the train set, given a certain ratio.
  • Shuffle the rows of the DataFrame.
  • Use a window function to partition and order the DataFrame by label, and then rank each label's observations with row_number() (see the sketch just below).
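
A condensed sketch of the last two steps, assuming a DataFrame df with a label column (the full working example at the end of this answer uses the same calls):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, rand, row_number}

    // Shuffle the rows, then rank each row within its label partition.
    val w = Window.partitionBy(col("label")).orderBy(col("label"))
    val dfRowNum = df.sort(rand()).select(col("*"), row_number().over(w) as "row_number")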

As a result, we get the following data frame:

    +---+-----+----------+
    | id|label|row_number|
    +---+-----+----------+
    |  6|  0.0|         1|
    |  2|  0.0|         2|
    |  0|  0.0|         3|
    |  4|  0.0|         4|
    |  8|  0.0|         5|
    |  9|  1.0|         1|
    |  5|  1.0|         2|
    |  3|  1.0|         3|
    |  1|  1.0|         4|
    |  7|  1.0|         5|
    +---+-----+----------+

Note: the rows are shuffled (see the random order in the id column), partitioned by label (see the label column), and ranked.

Suppose we would like an 80% split. In this case, we would like four 1.0 labels and four 0.0 labels to go to the training set, and one 1.0 label and one 0.0 label to go to the test set. We have this information in the row_number column, so now we can simply use it in a user-defined function (if row_number is less than or equal to four, the example goes to the training set).
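
A minimal sketch of that rule (the trainCounts values are hard-coded here for illustration; the full example below computes them from the data as floor(count * ratio), e.g. floor(5 * 0.8) = 4 per class):

    import org.apache.spark.sql.functions.udf

    // Hypothetical hard-coded per-class train counts: 80% of 5 rows per class.
    val trainCounts: Map[Double, Long] = Map(0.0 -> 4L, 1.0 -> 4L)
    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= trainCounts(label))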

After applying the UDF, the resulting DataFrame looks as follows:

    +---+-----+----------+----------+
    | id|label|row_number|isTrainSet|
    +---+-----+----------+----------+
    |  6|  0.0|         1|      true|
    |  2|  0.0|         2|      true|
    |  0|  0.0|         3|      true|
    |  4|  0.0|         4|      true|
    |  8|  0.0|         5|     false|
    |  9|  1.0|         1|      true|
    |  5|  1.0|         2|      true|
    |  3|  1.0|         3|      true|
    |  1|  1.0|         4|      true|
    |  7|  1.0|         5|     false|
    +---+-----+----------+----------+

Now, to get the train / test data, you need to do:

    val train = df.where(col("isTrainSet") === true)
    val test = df.where(col("isTrainSet") === false)
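
If nothing downstream needs the helper columns, you may want to drop them at this point (an optional cleanup, not part of the original answer):

    val trainFeatures = train.drop("row_number", "isTrainSet")
    val testFeatures = test.drop("row_number", "isTrainSet")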

These sorting and partitioning steps can be prohibitively expensive for some really large datasets, so I suggest first filtering the dataset down as much as possible. The physical plan is as follows:

    == Physical Plan ==
    *(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
    +- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
       +- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
          +- Exchange hashpartitioning(label#5, 200)
             +- *(1) Project [id#4, label#5]
                +- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
                   +- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
                      +- LocalTableScan [id#4, label#5, _nondeterministic#9]

Here's a full working example (tested with Spark 2.3.0 and Scala 2.11.12):

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.sql.functions.{col, row_number, udf, rand}

    class StratifiedTrainTestSplitter {

      // Returns, for each label, how many examples should end up in the train set.
      def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
        df.groupBy(label).count().createOrReplaceTempView("labelCounts")
        val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
        import ss.implicits._
        ss.sql(query)
          .select("ratioLabel", "trainExamples")
          .map((r: Row) => r.getDouble(0) -> r.getLong(1))
          .collect()
          .toMap
      }

      // Adds an isTrainSet column: shuffle, rank within each label partition,
      // then compare the rank against the per-label train count.
      def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
        val w = Window.partitionBy(col(label)).orderBy(col(label))
        val rowNumPartitioner = row_number().over(w)

        val dfRowNum = df.sort(rand).select(col("*"), rowNumPartitioner as "row_number")
        dfRowNum.show()

        val observationsPerLabel: Map[Double, Long] =
          getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

        val addIsTrainColumn =
          udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

        dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
      }
    }

    object StratifiedTrainTestSplitter {

      def getDf(ss: SparkSession): DataFrame = {
        val data = Seq(
          (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0),
          (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
        )
        ss.createDataFrame(data).toDF("id", "label")
      }

      def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession
          .builder()
          .config(new SparkConf().setMaster("local[1]"))
          .getOrCreate()

        val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)
        df.cache()

        df.where(col("isTrainSet") === true).show()
        df.where(col("isTrainSet") === false).show()
      }
    }

Note: the labels are Doubles in this case. If your labels are Strings, you will have to switch the types here and there.
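
A minimal sketch of those changes, assuming a DataFrame df whose label column holds Strings (a hypothetical adaptation, not part of the original answer):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.udf

    val trainRatio = 0.8
    // Key the per-class count map by String instead of Double.
    val observationsPerLabel: Map[String, Long] =
      df.groupBy("label").count()
        .collect()
        .map((r: Row) => r.getString(0) -> (r.getLong(1) * trainRatio).toLong)
        .toMap

    // The UDF then takes a String label.
    val addIsTrainColumn = udf((label: String, rowNumber: Int) => rowNumber <= observationsPerLabel(label))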

