I am new to Spark, PySpark DataFrames and Spark ML. How can I create a custom cross-validator for the ML library? For example, I want to change the way the training folds are formed, e.g. by using stratified splitting.
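Stratified k-fold splitting means that each fold keeps roughly the class proportions of the whole dataset. A minimal pure-Python sketch of the idea, assuming the labels fit in a driver-side list (`assign_stratified_folds` is a hypothetical helper, not a Spark or sklearn API): shuffle the rows of each class, then deal them out to the folds round-robin.

```python
import random
from collections import defaultdict


def assign_stratified_folds(labels, num_folds, seed=42):
    """Return a fold number in [0, num_folds) for each position in `labels`.

    Rows are grouped by label, shuffled within each group and then dealt
    out round-robin, so every fold gets (almost) the same share of each class.
    """
    rng = random.Random(seed)
    positions_by_label = defaultdict(list)
    for pos, label in enumerate(labels):
        positions_by_label[label].append(pos)

    folds = [0] * len(labels)
    offset = 0
    for label in sorted(positions_by_label):
        positions = positions_by_label[label]
        rng.shuffle(positions)
        for i, pos in enumerate(positions):
            folds[pos] = (offset + i) % num_folds
        # Continue the round-robin where this class stopped, so the
        # leftover rows of every class do not all land in fold 0.
        offset = (offset + len(positions)) % num_folds
    return folds
```

With the same seed the assignment is reproducible; for each class, the per-fold counts differ by at most one.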
This is my current code.
```python
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

numFolds = 10
predictions = []

lr = LogisticRegression() \
    .setFeaturesCol("features") \
    .setLabelCol("label")

# Grid search on LR model
lrparamGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.01, 0.1, 0.5, 1.0, 2.0]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.1, 0.5, 0.8, 1.0]) \
    .addGrid(lr.maxIter, [5, 10, 20]) \
    .build()

pipelineModel = Pipeline(stages=[lr])
evaluator = BinaryClassificationEvaluator()

cv = CrossValidator() \
    .setEstimator(pipelineModel) \
    .setEvaluator(evaluator) \
    .setEstimatorParamMaps(lrparamGrid) \
    .setNumFolds(5)

# My own cross-validation with stratified splits.
# df and indexOfStratifiedSplits (a list of (train_ids, test_ids) pairs,
# given as Python lists) are built elsewhere with my own code.
for i in range(numFolds):
    trainingData = df.filter(df.ID.isin(indexOfStratifiedSplits[i][0]))
    testingData = df.filter(df.ID.isin(indexOfStratifiedSplits[i][1]))
    # Training and grid search
    cvModel = cv.fit(trainingData)
    predictions.append(cvModel.transform(testingData))
```
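If `indexOfStratifiedSplits` already holds k (train IDs, test IDs) pairs, it can be reduced to a single fold number per ID, which is the shape a fold column needs. A minimal sketch, assuming each ID appears in exactly one test set; `fold_map_from_splits` is a hypothetical helper, not part of Spark:

```python
def fold_map_from_splits(splits):
    """Turn [(train_ids, test_ids), ...] into {ID: fold_index}.

    In a k-fold split every ID is tested exactly once and never appears in
    the training part of its own fold; anything else raises ValueError.
    """
    fold_of = {}
    all_ids = set()
    for fold, (train_ids, test_ids) in enumerate(splits):
        train_set, test_set = set(train_ids), set(test_ids)
        if train_set & test_set:
            raise ValueError(f"fold {fold}: some IDs are in both train and test")
        for id_ in test_set:
            if id_ in fold_of:
                raise ValueError(f"ID {id_!r} is tested in folds {fold_of[id_]} and {fold}")
            fold_of[id_] = fold
        all_ids |= train_set | test_set
    never_tested = all_ids - set(fold_of)
    if never_tested:
        raise ValueError(f"{len(never_tested)} IDs are never in a test set")
    return fold_of
```

The resulting dict could be turned into a DataFrame with `spark.createDataFrame(list(fold_of.items()), ["ID", "fold"])` and joined onto `df` by `ID`. On Spark 3.1 or later, `CrossValidator` accepts such a column through its `foldCol` parameter, with `numFolds` set to `len(indexOfStratifiedSplits)`.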
I would like to be able to call a custom cross-validator class like this:
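```python
# Option 1: pass precomputed split indexes
cv = MyCrossValidator() \
    .setEstimator(pipelineModel) \
    .setEvaluator(evaluator) \
    .setEstimatorParamMaps(lrparamGrid) \
    .setNumFolds(5) \
    .setSplitIndexes(indexOfStratifiedSplits)

# Option 2: let the validator stratify on a column
cv = MyCrossValidator() \
    .setEstimator(pipelineModel) \
    .setEvaluator(evaluator) \
    .setEstimatorParamMaps(lrparamGrid) \
    .setNumFolds(5) \
    .setSplitType("Stratified", ColumnName)
```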
I do not know whether the best option is to create a class that extends CrossValidator (overriding its fit method) or to pass a custom splitting function to Spark. Either option is difficult for me as a beginner: I tried to copy code from GitHub, but I get a lot of errors, especially since I don't know Scala, even though this pipeline runs faster in the Scala API.
Although I have my own functions to split the data the way I want (based on sklearn), I want to use Pipelines, grid search and cross-validation together, so that all parameter combinations are distributed across the cluster instead of being executed on the master. The "My own cross-validation" loop uses only some of the cluster nodes, since the loop runs on the master/driver.
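To gauge how much work the current code schedules, it helps to count the model fits. The arithmetic below uses the grid and fold counts from the question; the only fact it adds is that `CrossValidator` refits the best parameter map once on its whole input after scoring the folds.

```python
from math import prod

# Grid from the question: regParam x elasticNetParam x maxIter
grid_sizes = {"regParam": 5, "elasticNetParam": 5, "maxIter": 3}
inner_folds = 5    # cv.setNumFolds(5)
outer_folds = 10   # numFolds = 10 in the hand-written loop

param_maps = prod(grid_sizes.values())        # parameter combinations
fits_per_cv = param_maps * inner_folds + 1    # +1: final refit of the best model
total_fits = fits_per_cv * outer_folds

print(param_maps, fits_per_cv, total_fits)    # 75 376 3760
```

Each of these fits is itself a distributed Spark job; the driver only runs the scheduling loop. By default `CrossValidator` submits the fits one at a time, but it also has a `parallelism` parameter (Spark 2.3+, e.g. `.setParallelism(4)`) that lets the driver submit several at once, which may keep more executors busy when individual fits are small.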
An answer using either the Python or the Scala API is welcome, though Python is preferable to Scala.
Thanks.