package ml.dmlc.xgboost4j.scala.rapids.spark;

import ai.rapids.cudf.Table;
import java.util.NoSuchElementException;
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch;
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.QuantileDMatrix;
import ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost;
import ml.dmlc.xgboost4j.scala.spark.PreXGBoost$;
import ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider;
import ml.dmlc.xgboost4j.scala.spark.Watches;
import ml.dmlc.xgboost4j.scala.spark.XGBoost$;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel$;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParams;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel$;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor;
import ml.dmlc.xgboost4j.scala.spark.params.BoosterParams;
import ml.dmlc.xgboost4j.scala.spark.params.HasBaseMarginCol;
import ml.dmlc.xgboost4j.scala.spark.params.HasFeaturesCols;
import ml.dmlc.xgboost4j.scala.spark.params.NonParamVariables;
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.catalyst.CatalystTypeConverters$;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.PartialFunction;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;
import scala.Tuple6;
import scala.collection.BufferedIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.Traversable;
import scala.collection.TraversableOnce;
import scala.collection.generic.CanBuildFrom;
import scala.collection.immutable.$colon;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.collection.immutable.StringOps;
import scala.collection.immutable.Vector;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Nothing$;

/* compiled from: GpuPreXGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost$.class */
public final class GpuPreXGBoost$ implements PreXGBoostProvider {
    // Singleton instance of this Scala `object` (scalac's MODULE$ encoding).
    // NOTE(review): presumably assigned inside the constructor, which is not visible in this chunk.
    public static GpuPreXGBoost$ MODULE$;
    // commons-logging logger; used e.g. to report the GPU device chosen during transform.
    private final Log logger;
    // String constant initialized elsewhere in this class — TODO confirm its value/meaning.
    private final String FEATURES_COLS;
    // String constant initialized elsewhere in this class.
    // NOTE(review): presumably the name of the training entry among the watches — confirm.
    private final String TRAIN_NAME;

    // Class initialization eagerly constructs the singleton instance.
    static {
        new GpuPreXGBoost$();
    }

    /**
     * Accessor for the commons-logging {@link Log} held by this singleton.
     *
     * @return the logger instance
     */
    private Log logger() {
        return logger;
    }

    /**
     * Accessor for the {@code FEATURES_COLS} string constant, whose value is set when this
     * singleton is constructed.
     *
     * @return the stored constant
     */
    private String FEATURES_COLS() {
        return FEATURES_COLS;
    }

    /**
     * Accessor for the {@code TRAIN_NAME} string constant, whose value is set when this
     * singleton is constructed.
     *
     * @return the stored constant
     */
    private String TRAIN_NAME() {
        return TRAIN_NAME;
    }

    /**
     * Decides whether this GPU (RAPIDS) provider should pre-process the job.
     *
     * <p>The Spark runtime config is taken from the given dataset's session when a dataset is
     * supplied, otherwise from the active session. The provider is enabled only when
     * {@code spark.rapids.sql.enabled} is true (treated as true when the key is unset) and
     * {@code spark.sql.extensions} contains {@code com.nvidia.spark.rapids.SQLExecPlugin}.
     *
     * @param option the dataset being processed, if any
     * @return {@code true} if a config is available, RAPIDS SQL is enabled and the RAPIDS SQL
     *     plugin is registered as a session extension; {@code false} otherwise
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public boolean providerEnabled(Option<Dataset<?>> option) {
        Option option2 = (Option) option.map(dataset -> {
            return new Some(dataset.sparkSession().conf());
        }).getOrElse(() -> {
            return SparkSession$.MODULE$.getActiveSession().map(sparkSession -> {
                return sparkSession.conf();
            });
        });
        if (!option2.isDefined()) {
            return false;
        }
        RuntimeConfig runtimeConfig = (RuntimeConfig) option2.get();
        boolean rapidsEnabled;
        try {
            rapidsEnabled = new StringOps(Predef$.MODULE$.augmentString(runtimeConfig.get("spark.rapids.sql.enabled"))).toBoolean();
        } catch (NoSuchElementException unset) {
            // The RAPIDS plugin defaults spark.rapids.sql.enabled to true when it is not set.
            rapidsEnabled = true;
        } catch (Exception malformed) {
            // e.g. a value that is neither "true" nor "false": treat as disabled. JVM Errors
            // (OOM, linkage errors, ...) are deliberately no longer swallowed here.
            rapidsEnabled = false;
        }
        if (!rapidsEnabled) {
            return false;
        }
        // Spark trims whitespace around each comma-separated entry of spark.sql.extensions,
        // so "a, com.nvidia.spark.rapids.SQLExecPlugin" must also be recognized here.
        for (String extension : runtimeConfig.get("spark.sql.extensions", "").split(",")) {
            if (extension.trim().equals("com.nvidia.spark.rapids.SQLExecPlugin")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a factory that converts the training dataset (and any evaluation sets) into GPU
     * column batches and, once execution params are known, into an RDD of {@link Watches}.
     *
     * <p>Column names (label, weight, base margin, features and, for regressors, the optional
     * group column) are resolved and the column batches are built eagerly. The returned
     * function feeds them through {@code prepareInputData} and {@code buildRDDWatches}.
     *
     * <p>The group column is applied to the training batch as well as to every evaluation
     * batch. Previously the training batch always received an empty group name, so a
     * regressor's {@code groupCol} was honored for evaluation data but dropped for training.
     *
     * @param estimator an {@code XGBoostRegressor} or {@code XGBoostClassifier}
     * @param dataset   the training dataset
     * @param map       parameters used to look up the evaluation sets
     * @return a function from execution params to (watches RDD, {@code None})
     * @throws RuntimeException if the estimator is not a supported XGBoost estimator
     * @throws IllegalArgumentException (via Scala {@code require}) if neither
     *     {@code device=cuda|gpu} nor {@code tree_method=gpu_hist} is configured
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public Function1<XGBoostExecutionParams, Tuple2<RDD<Function0<Watches>>, Option<RDD<?>>>> buildDatasetToRDD(Estimator<?> estimator, Dataset<?> dataset, Map<String, Object> map) {
        if (!(estimator instanceof XGBoostEstimatorCommon)) {
            throw new RuntimeException(new StringBuilder(23).append("Unsupported estimator: ").append(estimator).toString());
        }
        BoosterParams boosterParams = (BoosterParams) estimator;
        // Accept either the `device` param (cuda/gpu) or the legacy tree_method=gpu_hist.
        Predef$.MODULE$.require((estimator.isDefined(boosterParams.device()) && (boosterParams.getDevice().equals("cuda") || boosterParams.getDevice().equals("gpu"))) || (estimator.isDefined(boosterParams.treeMethod()) && boosterParams.getTreeMethod().equals("gpu_hist")), () -> {
            return "GPU train requires `device` set to `cuda` or `gpu`.";
        });
        // Only regressors carry an (optional) group column; classifiers never do.
        String groupName;
        if (estimator instanceof XGBoostRegressor) {
            XGBoostRegressor regressor = (XGBoostRegressor) estimator;
            groupName = regressor.isDefined(regressor.groupCol()) ? regressor.getGroupCol() : "";
        } else if (estimator instanceof XGBoostClassifier) {
            groupName = "";
        } else {
            throw new RuntimeException(new StringBuilder(23).append("Unsupported estimator: ").append(estimator).toString());
        }
        Seq columnNames = (Seq) GpuUtils$.MODULE$.getColumnNames(estimator, Predef$.MODULE$.wrapRefArray(new Param[]{((HasLabelCol) estimator).labelCol(), ((HasWeightCol) estimator).weightCol(), ((HasBaseMarginCol) estimator).baseMarginCol()}));
        String[] featureNames = ((HasFeaturesCols) estimator).getFeaturesCols();
        Map evalSets = (Map) ((NonParamVariables) estimator).getEvalSets(map);
        // getColumnNames must yield exactly (label, weight, baseMargin); anything else is reported
        // as a MatchError carrying the same tuple the original pattern match would have reported.
        if (columnNames == null || ((SeqLike) columnNames).lengthCompare(3) != 0) {
            throw new MatchError(new Tuple4(columnNames, featureNames, groupName, evalSets));
        }
        String labelName = (String) ((SeqLike) columnNames).apply(0);
        String weightName = (String) ((SeqLike) columnNames).apply(1);
        String marginName = (String) ((SeqLike) columnNames).apply(2);
        // Training data: cast columns, then build the batch with the same group column the
        // evaluation sets use below.
        ColumnDataBatch trainingData = GpuUtils$.MODULE$.buildColumnDataBatch(Predef$.MODULE$.wrapRefArray(featureNames), labelName, weightName, marginName, groupName, GpuUtils$.MODULE$.prepareColumnType(dataset, Predef$.MODULE$.wrapRefArray(featureNames), labelName, weightName, marginName, GpuUtils$.MODULE$.prepareColumnType$default$6()));
        // Each evaluation set is cast and batched with the same column layout as the training data.
        Map evalDataMap = (Map) evalSets.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            return new Tuple2((String) tuple2._1(), GpuUtils$.MODULE$.buildColumnDataBatch(Predef$.MODULE$.wrapRefArray(featureNames), labelName, weightName, marginName, groupName, GpuUtils$.MODULE$.prepareColumnType((Dataset) tuple2._2(), Predef$.MODULE$.wrapRefArray(featureNames), labelName, weightName, marginName, GpuUtils$.MODULE$.prepareColumnType$default$6())));
        }, Map$.MODULE$.canBuildFrom());
        return xGBoostExecutionParams -> {
            return new Tuple2(MODULE$.buildRDDWatches(MODULE$.prepareInputData(trainingData, evalDataMap, xGBoostExecutionParams.numWorkers(), xGBoostExecutionParams.cacheTrainingSet()), xGBoostExecutionParams, evalDataMap.isEmpty()), None$.MODULE$);
        };
    }

    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public Dataset<Row> transformDataset(Model<?> model, Dataset<?> dataset) {
        Tuple5 tuple5;
        if (model instanceof XGBoostClassificationModel) {
            XGBoostClassificationModel xGBoostClassificationModel = (XGBoostClassificationModel) model;
            new $colon.colon(XGBoostClassificationModel$.MODULE$._rawPredictionCol(), new $colon.colon(XGBoostClassificationModel$.MODULE$._probabilityCol(), new $colon.colon(xGBoostClassificationModel.leafPredictionCol(), new $colon.colon(xGBoostClassificationModel.contribPredictionCol(), Nil$.MODULE$))));
            Function3 function3 = (booster, dMatrix, iterator) -> {
                Iterator<Row>[] producePredictionItrs = xGBoostClassificationModel.producePredictionItrs(booster, dMatrix);
                Option unapplySeq = Array$.MODULE$.unapplySeq(producePredictionItrs);
                if (unapplySeq.isEmpty() || unapplySeq.get() == null || ((SeqLike) unapplySeq.get()).lengthCompare(4) != 0) {
                    throw new MatchError(producePredictionItrs);
                }
                Tuple4 tuple4 = new Tuple4((Iterator) ((SeqLike) unapplySeq.get()).apply(0), (Iterator) ((SeqLike) unapplySeq.get()).apply(1), (Iterator) ((SeqLike) unapplySeq.get()).apply(2), (Iterator) ((SeqLike) unapplySeq.get()).apply(3));
                return xGBoostClassificationModel.produceResultIterator(iterator, (Iterator) tuple4._1(), (Iterator) tuple4._2(), (Iterator) tuple4._3(), (Iterator) tuple4._4());
            };
            StructType structType = new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset.schema().fields())).$plus$plus(new $colon.colon(new StructField(XGBoostClassificationModel$.MODULE$._rawPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()), Nil$.MODULE$), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))))).$plus$plus(new $colon.colon(new StructField(XGBoostClassificationModel$.MODULE$._probabilityCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()), Nil$.MODULE$), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
            if (xGBoostClassificationModel.isDefined(xGBoostClassificationModel.leafPredictionCol())) {
                structType = structType.add(new StructField(xGBoostClassificationModel.getLeafPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            if (xGBoostClassificationModel.isDefined(xGBoostClassificationModel.contribPredictionCol())) {
                structType = structType.add(new StructField(xGBoostClassificationModel.getContribPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            tuple5 = new Tuple5(xGBoostClassificationModel._booster(), function3, structType, xGBoostClassificationModel.getFeaturesCols(), BoxesRunTime.boxToFloat(xGBoostClassificationModel.getMissing()));
        } else {
            if (!(model instanceof XGBoostRegressionModel)) {
                throw new MatchError(model);
            }
            XGBoostRegressionModel xGBoostRegressionModel = (XGBoostRegressionModel) model;
            new $colon.colon(XGBoostRegressionModel$.MODULE$._originalPredictionCol(), new $colon.colon(xGBoostRegressionModel.leafPredictionCol(), new $colon.colon(xGBoostRegressionModel.contribPredictionCol(), Nil$.MODULE$)));
            Function3 function32 = (booster2, dMatrix2, iterator2) -> {
                Iterator<Row>[] producePredictionItrs = xGBoostRegressionModel.producePredictionItrs(booster2, dMatrix2);
                Option unapplySeq = Array$.MODULE$.unapplySeq(producePredictionItrs);
                if (unapplySeq.isEmpty() || unapplySeq.get() == null || ((SeqLike) unapplySeq.get()).lengthCompare(3) != 0) {
                    throw new MatchError(producePredictionItrs);
                }
                Tuple3 tuple3 = new Tuple3((Iterator) ((SeqLike) unapplySeq.get()).apply(0), (Iterator) ((SeqLike) unapplySeq.get()).apply(1), (Iterator) ((SeqLike) unapplySeq.get()).apply(2));
                return xGBoostRegressionModel.produceResultIterator(iterator2, (Iterator) tuple3._1(), (Iterator) tuple3._2(), (Iterator) tuple3._3());
            };
            StructType structType2 = new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset.schema().fields())).$plus$plus(new $colon.colon(new StructField(XGBoostRegressionModel$.MODULE$._originalPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()), Nil$.MODULE$), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
            if (xGBoostRegressionModel.isDefined(xGBoostRegressionModel.leafPredictionCol())) {
                structType2 = structType2.add(new StructField(xGBoostRegressionModel.getLeafPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            if (xGBoostRegressionModel.isDefined(xGBoostRegressionModel.contribPredictionCol())) {
                structType2 = structType2.add(new StructField(xGBoostRegressionModel.getContribPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            tuple5 = new Tuple5(xGBoostRegressionModel._booster(), function32, structType2, xGBoostRegressionModel.getFeaturesCols(), BoxesRunTime.boxToFloat(xGBoostRegressionModel.getMissing()));
        }
        Tuple5 tuple52 = tuple5;
        if (tuple52 == null) {
            throw new MatchError(tuple52);
        }
        Tuple5 tuple53 = new Tuple5((Booster) tuple52._1(), (Function3) tuple52._2(), (StructType) tuple52._3(), (String[]) tuple52._4(), BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(tuple52._5())));
        Booster booster3 = (Booster) tuple53._1();
        Function3 function33 = (Function3) tuple53._2();
        StructType structType3 = (StructType) tuple53._3();
        String[] strArr = (String[]) tuple53._4();
        float unboxToFloat = BoxesRunTime.unboxToFloat(tuple53._5());
        SparkContext sparkContext = dataset.sparkSession().sparkContext();
        Broadcast broadcast = sparkContext.broadcast(dataset.schema(), ClassTag$.MODULE$.apply(StructType.class));
        Broadcast broadcast2 = sparkContext.broadcast(structType3, ClassTag$.MODULE$.apply(StructType.class));
        Broadcast broadcast3 = sparkContext.broadcast(booster3, ClassTag$.MODULE$.apply(Booster.class));
        Broadcast broadcast4 = sparkContext.broadcast(new BoosterFlag(), ClassTag$.MODULE$.apply(BoosterFlag.class));
        boolean isLocal = sparkContext.isLocal();
        ArrayOps.ofRef ofref = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(strArr)).distinct()));
        StructType schema = dataset.schema();
        int[] iArr = (int[]) ofref.map(str -> {
            return BoxesRunTime.boxToInteger(schema.fieldIndex(str));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        RDD<Table> columnarRdd = GpuUtils$.MODULE$.toColumnarRdd(dataset);
        RDD mapPartitions = columnarRdd.mapPartitions(iterator3 -> {
            final UnsafeProjection create = UnsafeProjection$.MODULE$.create((StructType) broadcast.value());
            final Booster booster4 = (Booster) broadcast3.value();
            BoosterFlag boosterFlag = (BoosterFlag) broadcast4.value();
            ?? r0 = MODULE$;
            synchronized (r0) {
                if (!boosterFlag.isGpuParamsSet()) {
                    int gPUAddrFromResources = !isLocal ? XGBoost$.MODULE$.getGPUAddrFromResources() : 0;
                    booster4.setParam("device", new StringBuilder(5).append("cuda:").append(gPUAddrFromResources).toString());
                    MODULE$.logger().info(new StringBuilder(25).append("GPU transform on device: ").append(gPUAddrFromResources).toString());
                    r0 = boosterFlag;
                    r0.isGpuParamsSet_$eq(true);
                }
            }
            return new Iterator<Row>(broadcast, iterator3, iArr, unboxToFloat, create, function33, booster4) { // from class: ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost$$anon$1
                private final Function1<InternalRow, Row> converter;
                private transient ColumnarBatch currentBatch;
                private Iterator<Row> iter;
                private final Broadcast bOrigSchema$1;
                private final Iterator tableIters$1;
                private final int[] featureIds$1;
                private final float missing$1;
                private final UnsafeProjection toUnsafe$1;
                private final Function3 predictFunc$1;
                private final Booster booster$1;

                /* renamed from: seq, reason: merged with bridge method [inline-methods] */
                public Iterator<Row> m8seq() {
                    return Iterator.seq$(this);
                }

                public boolean isEmpty() {
                    return Iterator.isEmpty$(this);
                }

                public boolean isTraversableAgain() {
                    return Iterator.isTraversableAgain$(this);
                }

                public boolean hasDefiniteSize() {
                    return Iterator.hasDefiniteSize$(this);
                }

                public Iterator<Row> take(int i) {
                    return Iterator.take$(this, i);
                }

                public Iterator<Row> drop(int i) {
                    return Iterator.drop$(this, i);
                }

                public Iterator<Row> slice(int i, int i2) {
                    return Iterator.slice$(this, i, i2);
                }

                public Iterator<Row> sliceIterator(int i, int i2) {
                    return Iterator.sliceIterator$(this, i, i2);
                }

                public <B> Iterator<B> map(Function1<Row, B> function1) {
                    return Iterator.map$(this, function1);
                }

                public <B> Iterator<B> $plus$plus(Function0<GenTraversableOnce<B>> function0) {
                    return Iterator.$plus$plus$(this, function0);
                }

                public <B> Iterator<B> flatMap(Function1<Row, GenTraversableOnce<B>> function1) {
                    return Iterator.flatMap$(this, function1);
                }

                public Iterator<Row> filter(Function1<Row, Object> function1) {
                    return Iterator.filter$(this, function1);
                }

                public <B> boolean corresponds(GenTraversableOnce<B> genTraversableOnce, Function2<Row, B, Object> function2) {
                    return Iterator.corresponds$(this, genTraversableOnce, function2);
                }

                public Iterator<Row> withFilter(Function1<Row, Object> function1) {
                    return Iterator.withFilter$(this, function1);
                }

                public Iterator<Row> filterNot(Function1<Row, Object> function1) {
                    return Iterator.filterNot$(this, function1);
                }

                public <B> Iterator<B> collect(PartialFunction<Row, B> partialFunction) {
                    return Iterator.collect$(this, partialFunction);
                }

                public <B> Iterator<B> scanLeft(B b, Function2<B, Row, B> function2) {
                    return Iterator.scanLeft$(this, b, function2);
                }

                public <B> Iterator<B> scanRight(B b, Function2<Row, B, B> function2) {
                    return Iterator.scanRight$(this, b, function2);
                }

                public Iterator<Row> takeWhile(Function1<Row, Object> function1) {
                    return Iterator.takeWhile$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> partition(Function1<Row, Object> function1) {
                    return Iterator.partition$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> span(Function1<Row, Object> function1) {
                    return Iterator.span$(this, function1);
                }

                public Iterator<Row> dropWhile(Function1<Row, Object> function1) {
                    return Iterator.dropWhile$(this, function1);
                }

                public <B> Iterator<Tuple2<Row, B>> zip(Iterator<B> iterator3) {
                    return Iterator.zip$(this, iterator3);
                }

                public <A1> Iterator<A1> padTo(int i, A1 a1) {
                    return Iterator.padTo$(this, i, a1);
                }

                public Iterator<Tuple2<Row, Object>> zipWithIndex() {
                    return Iterator.zipWithIndex$(this);
                }

                public <B, A1, B1> Iterator<Tuple2<A1, B1>> zipAll(Iterator<B> iterator3, A1 a1, B1 b1) {
                    return Iterator.zipAll$(this, iterator3, a1, b1);
                }

                public <U> void foreach(Function1<Row, U> function1) {
                    Iterator.foreach$(this, function1);
                }

                public boolean forall(Function1<Row, Object> function1) {
                    return Iterator.forall$(this, function1);
                }

                public boolean exists(Function1<Row, Object> function1) {
                    return Iterator.exists$(this, function1);
                }

                public boolean contains(Object obj) {
                    return Iterator.contains$(this, obj);
                }

                public Option<Row> find(Function1<Row, Object> function1) {
                    return Iterator.find$(this, function1);
                }

                public int indexWhere(Function1<Row, Object> function1) {
                    return Iterator.indexWhere$(this, function1);
                }

                public int indexWhere(Function1<Row, Object> function1, int i) {
                    return Iterator.indexWhere$(this, function1, i);
                }

                public <B> int indexOf(B b) {
                    return Iterator.indexOf$(this, b);
                }

                public <B> int indexOf(B b, int i) {
                    return Iterator.indexOf$(this, b, i);
                }

                public BufferedIterator<Row> buffered() {
                    return Iterator.buffered$(this);
                }

                public <B> Iterator<Row>.GroupedIterator<B> grouped(int i) {
                    return Iterator.grouped$(this, i);
                }

                public <B> Iterator<Row>.GroupedIterator<B> sliding(int i, int i2) {
                    return Iterator.sliding$(this, i, i2);
                }

                public <B> int sliding$default$2() {
                    return Iterator.sliding$default$2$(this);
                }

                public int length() {
                    return Iterator.length$(this);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> duplicate() {
                    return Iterator.duplicate$(this);
                }

                public <B> Iterator<B> patch(int i, Iterator<B> iterator3, int i2) {
                    return Iterator.patch$(this, i, iterator3, i2);
                }

                public <B> void copyToArray(Object obj, int i, int i2) {
                    Iterator.copyToArray$(this, obj, i, i2);
                }

                public boolean sameElements(Iterator<?> iterator3) {
                    return Iterator.sameElements$(this, iterator3);
                }

                /* renamed from: toTraversable, reason: merged with bridge method [inline-methods] */
                public Traversable<Row> m7toTraversable() {
                    return Iterator.toTraversable$(this);
                }

                public Iterator<Row> toIterator() {
                    return Iterator.toIterator$(this);
                }

                public Stream<Row> toStream() {
                    return Iterator.toStream$(this);
                }

                public String toString() {
                    return Iterator.toString$(this);
                }

                public List<Row> reversed() {
                    return TraversableOnce.reversed$(this);
                }

                public int size() {
                    return TraversableOnce.size$(this);
                }

                public boolean nonEmpty() {
                    return TraversableOnce.nonEmpty$(this);
                }

                public int count(Function1<Row, Object> function1) {
                    return TraversableOnce.count$(this, function1);
                }

                public <B> Option<B> collectFirst(PartialFunction<Row, B> partialFunction) {
                    return TraversableOnce.collectFirst$(this, partialFunction);
                }

                public <B> B $div$colon(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.$div$colon$(this, b, function2);
                }

                public <B> B $colon$bslash(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.$colon$bslash$(this, b, function2);
                }

                public <B> B foldLeft(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.foldLeft$(this, b, function2);
                }

                public <B> B foldRight(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.foldRight$(this, b, function2);
                }

                public <B> B reduceLeft(Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.reduceLeft$(this, function2);
                }

                public <B> B reduceRight(Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.reduceRight$(this, function2);
                }

                public <B> Option<B> reduceLeftOption(Function2<B, Row, B> function2) {
                    return TraversableOnce.reduceLeftOption$(this, function2);
                }

                public <B> Option<B> reduceRightOption(Function2<Row, B, B> function2) {
                    return TraversableOnce.reduceRightOption$(this, function2);
                }

                public <A1> A1 reduce(Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.reduce$(this, function2);
                }

                public <A1> Option<A1> reduceOption(Function2<A1, A1, A1> function2) {
                    return TraversableOnce.reduceOption$(this, function2);
                }

                public <A1> A1 fold(A1 a1, Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.fold$(this, a1, function2);
                }

                public <B> B aggregate(Function0<B> function0, Function2<B, Row, B> function2, Function2<B, B, B> function22) {
                    return (B) TraversableOnce.aggregate$(this, function0, function2, function22);
                }

                public <B> B sum(Numeric<B> numeric) {
                    return (B) TraversableOnce.sum$(this, numeric);
                }

                public <B> B product(Numeric<B> numeric) {
                    return (B) TraversableOnce.product$(this, numeric);
                }

                public Object min(Ordering ordering) {
                    return TraversableOnce.min$(this, ordering);
                }

                public Object max(Ordering ordering) {
                    return TraversableOnce.max$(this, ordering);
                }

                public Object maxBy(Function1 function1, Ordering ordering) {
                    return TraversableOnce.maxBy$(this, function1, ordering);
                }

                // ---- Scala TraversableOnce / GenTraversableOnce trait forwarders ----
                // Every method in this run simply hands `this` to the trait's static implementation;
                // none of them touches the iterator's own state directly.

                public Object minBy(Function1 keyFn, Ordering ord) {
                    return TraversableOnce.minBy$(this, keyFn, ord);
                }

                public <B> void copyToBuffer(Buffer<B> dest) {
                    TraversableOnce.copyToBuffer$(this, dest);
                }

                public <B> void copyToArray(Object dest, int from) {
                    TraversableOnce.copyToArray$(this, dest, from);
                }

                public <B> void copyToArray(Object dest) {
                    TraversableOnce.copyToArray$(this, dest);
                }

                public <B> Object toArray(ClassTag<B> tag) {
                    return TraversableOnce.toArray$(this, tag);
                }

                public List<Row> toList() {
                    return TraversableOnce.toList$(this);
                }

                /* toIterable: renamed by the decompiler (merged with bridge method) */
                public Iterable<Row> m6toIterable() {
                    return TraversableOnce.toIterable$(this);
                }

                /* toSeq: renamed by the decompiler (merged with bridge method) */
                public Seq<Row> m5toSeq() {
                    return TraversableOnce.toSeq$(this);
                }

                public IndexedSeq<Row> toIndexedSeq() {
                    return TraversableOnce.toIndexedSeq$(this);
                }

                public <B> Buffer<B> toBuffer() {
                    return TraversableOnce.toBuffer$(this);
                }

                /* toSet: renamed by the decompiler (merged with bridge method) */
                public <B> Set<B> m4toSet() {
                    return TraversableOnce.toSet$(this);
                }

                public Vector<Row> toVector() {
                    return TraversableOnce.toVector$(this);
                }

                public <Col> Col to(CanBuildFrom<Nothing$, Row, Col> factory) {
                    return (Col) TraversableOnce.to$(this, factory);
                }

                /* toMap: renamed by the decompiler (merged with bridge method) */
                public <T, U> Map<T, U> m3toMap(Predef$.less.colon.less<Row, Tuple2<T, U>> evidence) {
                    return TraversableOnce.toMap$(this, evidence);
                }

                public String mkString(String start, String sep, String end) {
                    return TraversableOnce.mkString$(this, start, sep, end);
                }

                public String mkString(String sep) {
                    return TraversableOnce.mkString$(this, sep);
                }

                public String mkString() {
                    return TraversableOnce.mkString$(this);
                }

                public StringBuilder addString(StringBuilder sb, String start, String sep, String end) {
                    return TraversableOnce.addString$(this, sb, start, sep, end);
                }

                public StringBuilder addString(StringBuilder sb, String sep) {
                    return TraversableOnce.addString$(this, sb, sep);
                }

                public StringBuilder addString(StringBuilder sb) {
                    return TraversableOnce.addString$(this, sb);
                }

                public int sizeHintIfCheap() {
                    return GenTraversableOnce.sizeHintIfCheap$(this);
                }

                // Accessors for this iterator's private state (the Scala vals/vars of the anonymous class).

                /** Converter from Catalyst {@code InternalRow} to external {@code Row}. */
                private Function1<InternalRow, Row> converter() {
                    return converter;
                }

                /** Host-side batch currently backing {@code iter}, or {@code null}. */
                private ColumnarBatch currentBatch() {
                    return currentBatch;
                }

                private void currentBatch_$eq(ColumnarBatch batch) {
                    currentBatch = batch;
                }

                /** Rows of the currently loaded batch, or {@code null} before the first load / after exhaustion. */
                private Iterator<Row> iter() {
                    return iter;
                }

                private void iter_$eq(Iterator<Row> rows) {
                    iter = rows;
                }

                /* JADX INFO: Access modifiers changed from: private */
                /**
                 * Closes the host-side {@code ColumnarBatch} currently held by this iterator, if there is one,
                 * and clears the reference so it cannot be closed twice.
                 */
                public void closeCurrentBatch() {
                    ColumnarBatch batch = currentBatch();
                    if (batch == null) {
                        return;
                    }
                    batch.close();
                    currentBatch_$eq(null);
                }

                /**
                 * Advances to the next GPU table from {@code tableIters$1} and replaces {@code iter} with the
                 * prediction rows produced for it.
                 *
                 * <p>The previous host batch is always closed first. When no table is left, {@code iter} is set to
                 * {@code null}; {@code hasNext()} / {@code next()} treat that as end of input.
                 */
                private void loadNextBatch() {
                    closeCurrentBatch();
                    if (!this.tableIters$1.hasNext()) {
                        iter_$eq(null);
                    } else {
                        // Data types of every field in the original schema; used to copy the table to host memory.
                        DataType[] dataTypeArr = (DataType[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(((StructType) this.bOrigSchema$1.value()).fields())).map(structField -> {
                            return structField.dataType();
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(DataType.class)));
                        // withResource closes the table once this lambda returns (or throws).
                        iter_$eq((Iterator) GpuPreXGBoost$.MODULE$.withResource((AutoCloseable) this.tableIters$1.next(), table -> {
                            Iterator iterator3;
                            // Only the feature columns (featureIds$1) go into the DMatrix.
                            Table slice = new GpuColumnBatch(table, (StructType) this.bOrigSchema$1.value()).slice((java.util.List) JavaConverters$.MODULE$.seqAsJavaListConverter(GpuUtils$.MODULE$.seqIntToSeqInteger(Predef$.MODULE$.wrapIntArray(this.featureIds$1))).asJava());
                            if (slice == null) {
                                throw new RuntimeException("Something wrong for feature indices");
                            }
                            try {
                                DMatrix dMatrix3 = new DMatrix(new CudfColumnBatch(slice, (Table) null, (Table) null, (Table) null), this.missing$1, 1);
                                // NOTE(review): a freshly constructed object is never null, so this branch looks
                                // unreachable (a Scala-side null check preserved by the decompiler).
                                if (dMatrix3 == null) {
                                    iterator3 = package$.MODULE$.Iterator().empty();
                                } else {
                                    try {
                                        // The rows handed to predictFunc are read from currentBatch, so the batch is left
                                        // open here and released later by closeCurrentBatch().
                                        this.currentBatch_$eq(new ColumnarBatch(GpuUtils$.MODULE$.extractBatchToHost(table, dataTypeArr), (int) table.getRowCount()));
                                        iterator3 = (Iterator) this.predictFunc$1.apply(this.booster$1, dMatrix3, ((Iterator) JavaConverters$.MODULE$.asScalaIteratorConverter(this.currentBatch().rowIterator()).asScala()).map(this.toUnsafe$1).map(internalRow -> {
                                            return (Row) this.converter().apply(internalRow);
                                        }));
                                    } finally {
                                        // NOTE(review): the DMatrix is freed as soon as predictFunc returns; this assumes
                                        // predictFunc consumes it eagerly rather than from the returned iterator — confirm.
                                        dMatrix3.delete();
                                    }
                                }
                                return iterator3;
                            } finally {
                                slice.close();
                            }
                        }));
                    }
                }

                /**
                 * Returns whether another prediction row is available, loading further GPU batches as needed.
                 *
                 * <p>A batch that yields no rows is skipped instead of ending the iteration, so an empty table in
                 * the middle of a partition no longer truncates the remaining predictions.
                 */
                public boolean hasNext() {
                    if (iter() != null && iter().hasNext()) {
                        return true;
                    }
                    // loadNextBatch() consumes one table per call and sets iter to null once tableIters$1 is
                    // exhausted, so this loop always terminates.
                    do {
                        loadNextBatch();
                    } while (iter() != null && !iter().hasNext());
                    return iter() != null;
                }

                /* renamed from: next, reason: merged with bridge method [inline-methods] */
                /**
                 * Returns the next prediction row.
                 *
                 * @throws NoSuchElementException if all batches are exhausted
                 */
                public Row m9next() {
                    // hasNext() loads (and skips empty) batches as needed, keeping next() consistent with it.
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    return (Row) iter().next();
                }

                // Instance initializer: stores the values captured from the enclosing closure, runs the Scala
                // trait initializers, and sets up the initial (empty) iteration state.
                {
                    this.bOrigSchema$1 = broadcast;
                    this.tableIters$1 = iterator3;
                    this.featureIds$1 = iArr;
                    this.missing$1 = unboxToFloat;
                    this.toUnsafe$1 = create;
                    this.predictFunc$1 = function33;
                    this.booster$1 = booster4;
                    GenTraversableOnce.$init$(this);
                    TraversableOnce.$init$(this);
                    Iterator.$init$(this);
                    // Converter built from the broadcast original schema (same broadcast as bOrigSchema$1).
                    this.converter = CatalystTypeConverters$.MODULE$.createToScalaConverter((DataType) broadcast.value());
                    // No batch is loaded yet; the first hasNext()/next() triggers loadNextBatch().
                    this.currentBatch = null;
                    this.iter = null;
                    // Release the last host batch when the Spark task completes, even if the consumer stops early.
                    // NOTE(review): assumes this runs inside a Spark task (TaskContext.get() non-null).
                    TaskContext$.MODULE$.get().addTaskCompletionListener(taskContext -> {
                        this.closeCurrentBatch();
                        return BoxedUnit.UNIT;
                    });
                }
            };
        }, columnarRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Row.class));
        broadcast.unpersist(false);
        broadcast2.unpersist(false);
        broadcast3.unpersist(false);
        return dataset.sparkSession().createDataFrame(mapPartitions, structType3);
    }

    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    /**
     * Validates {@code structType} against the estimator's feature, label, weight and base-margin columns.
     *
     * @param xGBoostEstimatorCommon estimator or model whose column params are resolved
     * @param structType             input schema
     * @return the schema produced by {@code GpuUtils.validateSchema}
     * @throws MatchError if the column-name lookup does not yield exactly three names
     */
    public StructType transformSchema(XGBoostEstimatorCommon xGBoostEstimatorCommon, StructType structType) {
        // true for the estimator classes; presumably tells validateSchema to require training-only
        // columns — TODO confirm against GpuUtils.validateSchema.
        boolean isEstimator = xGBoostEstimatorCommon instanceof XGBoostClassifier
                || xGBoostEstimatorCommon instanceof XGBoostRegressor;
        Param[] columnParams = new Param[]{
                xGBoostEstimatorCommon.labelCol(),
                xGBoostEstimatorCommon.weightCol(),
                xGBoostEstimatorCommon.baseMarginCol()};
        Seq<String> columnNames = GpuUtils$.MODULE$.getColumnNames(xGBoostEstimatorCommon, Predef$.MODULE$.wrapRefArray(columnParams));
        Some unapplySeq = Seq$.MODULE$.unapplySeq(columnNames);
        boolean exactlyThree = !unapplySeq.isEmpty()
                && unapplySeq.get() != null
                && ((SeqLike) unapplySeq.get()).lengthCompare(3) == 0;
        if (!exactlyThree) {
            throw new MatchError(columnNames);
        }
        SeqLike names = (SeqLike) unapplySeq.get();
        String labelName = (String) names.apply(0);
        String weightName = (String) names.apply(1);
        String baseMarginName = (String) names.apply(2);
        return GpuUtils$.MODULE$.validateSchema(structType, Predef$.MODULE$.wrapRefArray(xGBoostEstimatorCommon.getFeaturesCols()), labelName, weightName, baseMarginName, isEstimator);
    }

    /**
     * Combines the training batch with the evaluation batches and repartitions each one.
     *
     * <p>A batch with a group column is repartitioned via {@code repartitionForGroup}; any other batch
     * uses {@code repartitionInputData}.
     *
     * @param columnDataBatch training data, stored under {@code TRAIN_NAME()}
     * @param map             additional named batches
     * @param i               target partition count
     * @param z               cache flag; only triggers a warning because the GPU pipeline ignores it
     * @return name → repartitioned batch
     */
    private Map<String, ColumnDataBatch> prepareInputData(ColumnDataBatch columnDataBatch, Map<String, ColumnDataBatch> map, int i, boolean z) {
        if (z) {
            logger().warn("the cache param will be ignored by GPU pipeline!");
        }
        Tuple2 trainEntry = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(TRAIN_NAME()), columnDataBatch);
        Map allBatches = (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{trainEntry})).$plus$plus(map);
        return (Map) allBatches.map(entry -> {
            if (entry == null) {
                throw new MatchError(entry);
            }
            String name = (String) entry._1();
            ColumnDataBatch batch = (ColumnDataBatch) entry._2();
            Dataset repartitioned = (Dataset) batch.groupColName()
                    .map(groupCol -> {
                        return MODULE$.repartitionForGroup(groupCol, batch.rawDF(), i);
                    })
                    .getOrElse(() -> {
                        return MODULE$.repartitionInputData(batch.rawDF(), i);
                    });
            ColumnDataBatch rebuilt = new ColumnDataBatch(repartitioned, batch.colIndices(), batch.groupColName());
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(name), rebuilt);
        }, Map$.MODULE$.canBuildFrom());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Reshuffles {@code dataset} into exactly {@code i} partitions.
     *
     * @param dataset rows to redistribute
     * @param i       number of output partitions
     * @return the repartitioned dataset
     */
    public Dataset<Row> repartitionInputData(Dataset<Row> dataset, int i) {
        final int numPartitions = i;
        final Dataset<Row> reshuffled = dataset.repartition(numPartitions);
        return reshuffled;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Repartitions {@code dataset} into {@code i} partitions so that all rows sharing a value of
     * group column {@code str} end up in the same partition.
     *
     * <p>Steps:
     * <ol>
     *   <li>Collect each group's rows into one aggregated row {@code (group value, list)}.</li>
     *   <li>Repartition the aggregated rows.</li>
     *   <li>Flatten each partition back into the original rows.</li>
     * </ol>
     *
     * @param str     name of the group column
     * @param dataset input rows
     * @param i       target number of partitions
     * @return the original rows, each group kept within a single partition
     */
    public Dataset<Row> repartitionForGroup(String str, Dataset<Row> dataset, int i) {
        logger().info("Start groupBy for LTR");
        StructType schema = dataset.schema();
        Dataset<Row> agg = dataset.groupBy(str, Predef$.MODULE$.wrapRefArray(new String[0])).agg(functions$.MODULE$.collect_list(functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(schema.fieldNames())).map(str2 -> {
            return functions$.MODULE$.col(str2);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))))).as("list"), Predef$.MODULE$.wrapRefArray(new Column[0]));
        // Column 0 of each aggregated row is the group value; column 1 ("list") holds the collected rows.
        // flatMap replaces the former hand-written iterator. Unlike that iterator, it skips groups whose
        // list is empty instead of ending the partition early. It also works when next() is called
        // without a preceding hasNext().
        return repartitionInputData(agg, i).mapPartitions(iterator -> {
            return iterator.flatMap(row -> ((Row) row).getSeq(1));
        }, RowEncoder$.MODULE$.apply(schema));
    }

    /**
     * Builds an RDD whose partitions each contain one deferred {@code Watches} factory.
     *
     * <p>The {@code Watches} is only constructed when the returned {@code Function0} is invoked.
     *
     * @param map                     named input batches; must contain {@code TRAIN_NAME()}
     * @param xGBoostExecutionParams  execution params (missing value, external memory, worker count, max_bin)
     * @param z                       when true, only the training entry of {@code map} is used
     * @return one {@code Function0<Watches>} per partition
     */
    private RDD<Function0<Watches>> buildRDDWatches(Map<String, ColumnDataBatch> map, XGBoostExecutionParams xGBoostExecutionParams, boolean z) {
        SparkContext sparkContext = ((ColumnDataBatch) map.apply(TRAIN_NAME())).rawDF().sparkSession().sparkContext();
        // max_bin defaults to 256 when absent. NOTE(review): unboxToInt assumes the stored value is an Integer.
        int unboxToInt = BoxesRunTime.unboxToInt(xGBoostExecutionParams.toMap().getOrElse("max_bin", () -> {
            return 256;
        }));
        if (z) {
            // Training-only path: wrap each partition's tables and build a single "train" watch.
            ColumnIndices colIndices = ((ColumnDataBatch) map.apply(TRAIN_NAME())).colIndices();
            RDD<Table> columnarRdd = GpuUtils$.MODULE$.toColumnarRdd(((ColumnDataBatch) map.apply(TRAIN_NAME())).rawDF());
            return columnarRdd.mapPartitions(iterator -> {
                Iterator map2 = iterator.map(table -> {
                    return new GpuColumnBatch(table, null);
                });
                return package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Function0[]{() -> {
                    return MODULE$.buildWatches(PreXGBoost$.MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory()), xGBoostExecutionParams.missing(), colIndices, map2, unboxToInt);
                }}));
            }, columnarRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class));
        }
        // Evaluation path: zip all named datasets partition-wise and build one watch per name.
        Map map2 = (Map) map.map(tuple2 -> {
            return new Tuple2(tuple2._1(), ((ColumnDataBatch) tuple2._2()).colIndices());
        }, Map$.MODULE$.canBuildFrom());
        RDD<Tuple2<String, Iterator<GpuColumnBatch>>> coPartitionForGpu = coPartitionForGpu(map, sparkContext, xGBoostExecutionParams.numWorkers());
        return coPartitionForGpu.mapPartitions(iterator2 -> {
            return package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Function0[]{() -> {
                return MODULE$.buildWatchesWithEval(PreXGBoost$.MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory()), xGBoostExecutionParams.missing(), map2, iterator2, unboxToInt);
            }}));
        }, coPartitionForGpu.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a training-only {@code Watches} from one partition's GPU batches and logs the build time.
     *
     * @param option        cache directory name passed through to {@code Watches}
     * @param f             missing value
     * @param columnIndices column layout of the batches
     * @param iterator      the partition's GPU batches
     * @param i             max_bin
     * @return a {@code Watches} holding the DMatrix under "train", or empty when none was built
     */
    public Watches buildWatches(Option<String> option, float f, ColumnIndices columnIndices, Iterator<GpuColumnBatch> iterator, int i) {
        Tuple2 timed = GpuUtils$.MODULE$.time(() -> {
            return MODULE$.buildDMatrix(iterator, columnIndices, f, i);
        });
        if (timed == null) {
            throw new MatchError(timed);
        }
        DMatrix trainMatrix = (DMatrix) timed._1();
        float elapsed = BoxesRunTime.unboxToFloat(timed._2());
        logger().debug(new StringBuilder(46).append("Benchmark[Train: Build DMatrix incrementally] ").append(elapsed).toString());
        if (trainMatrix == null) {
            return new Watches(new DMatrix[0], new String[0], option);
        }
        return new Watches(new DMatrix[]{trainMatrix}, new String[]{"train"}, option);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a {@code Watches} with one DMatrix per named input of this partition, logging each build time.
     * Inputs for which no DMatrix was produced ({@code null}) are left out.
     *
     * @param option   cache directory name passed through to {@code Watches}
     * @param f        missing value
     * @param map      name → column layout
     * @param iterator (name, GPU batches) pairs for this partition
     * @param i        max_bin
     * @return the assembled {@code Watches}, in iteration order
     */
    public Watches buildWatchesWithEval(Option<String> option, float f, Map<String, ColumnIndices> map, Iterator<Tuple2<String, Iterator<GpuColumnBatch>>> iterator, int i) {
        java.util.List<DMatrix> matrices = new java.util.ArrayList<>();
        java.util.List<String> names = new java.util.ArrayList<>();
        while (iterator.hasNext()) {
            Tuple2<String, Iterator<GpuColumnBatch>> entry = iterator.next();
            if (entry == null) {
                throw new MatchError(entry);
            }
            String name = (String) entry._1();
            Iterator<GpuColumnBatch> batches = (Iterator<GpuColumnBatch>) entry._2();
            Tuple2 timed = GpuUtils$.MODULE$.time(() -> {
                return MODULE$.buildDMatrix(batches, (ColumnIndices) map.apply(name), f, i);
            });
            if (timed == null) {
                throw new MatchError(timed);
            }
            DMatrix built = (DMatrix) timed._1();
            float elapsed = BoxesRunTime.unboxToFloat(timed._2());
            MODULE$.logger().debug(new StringBuilder(32).append("Benchmark[Train build ").append(name).append(" DMatrix] ").append(elapsed).toString());
            if (built != null) {
                matrices.add(built);
                names.add(name);
            }
        }
        return new Watches(matrices.toArray(new DMatrix[0]), names.toArray(new String[0]), option);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Wraps the GPU batches in a {@code RapidsIterator} and builds a {@code QuantileDMatrix} from them.
     *
     * @param iterator      GPU batches to consume
     * @param columnIndices column layout of the batches
     * @param f             missing value (callers pass {@code missing()})
     * @param i             max_bin (callers pass the resolved "max_bin" param)
     * @return the new {@code QuantileDMatrix}
     */
    public DMatrix buildDMatrix(Iterator<GpuColumnBatch> iterator, ColumnIndices columnIndices, float f, int i) {
        GpuPreXGBoost.RapidsIterator batches = new GpuPreXGBoost.RapidsIterator(iterator, columnIndices);
        return new QuantileDMatrix(batches, f, i, 1);
    }

    /**
     * Zips every named dataset in {@code map} partition-by-partition into one RDD whose partitions hold
     * {@code (name, iterator of GpuColumnBatch)} pairs.
     *
     * <p>The fold is seeded with an RDD of {@code i} partitions, each holding a single {@code null}
     * placeholder. Each dataset's columnar RDD is then zipped in, and its entry is appended to each
     * partition's list.
     *
     * <p>NOTE(review): {@code zipPartitions} requires equal partition counts, so this assumes every rawDF
     * already has exactly {@code i} partitions (e.g. after prepareInputData) — TODO confirm.
     *
     * @param map          named input batches
     * @param sparkContext context used to create the seed RDD
     * @param i            number of partitions (workers)
     * @return the co-partitioned RDD
     */
    private RDD<Tuple2<String, Iterator<GpuColumnBatch>>> coPartitionForGpu(Map<String, ColumnDataBatch> map, SparkContext sparkContext, int i) {
        return (RDD) map.foldLeft(sparkContext.parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd, tuple2) -> {
            Tuple2 tuple2 = new Tuple2(rdd, tuple2);
            if (tuple2 != null) {
                RDD rdd = (RDD) tuple2._1();
                Tuple2 tuple22 = (Tuple2) tuple2._2();
                if (tuple22 != null) {
                    String str = (String) tuple22._1();
                    // Append (str -> this partition's batches) to the entries accumulated so far. The filter
                    // uses $anonfun$coPartitionForGpu$5, presumably to drop the null seed placeholders.
                    return rdd.zipPartitions(GpuUtils$.MODULE$.toColumnarRdd(((ColumnDataBatch) tuple22._2()).rawDF()), (iterator, iterator2) -> {
                        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2.map(table -> {
                            return new GpuColumnBatch(table, null);
                        })), ClassTag$.MODULE$.apply(Tuple2.class)))).filter(tuple23 -> {
                            return BoxesRunTime.boxToBoolean($anonfun$coPartitionForGpu$5(tuple23));
                        }))).toIterator();
                    }, ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple2);
        });
    }

    /**
     * Applies {@code function1} to {@code t}, then closes {@code t}, whether the function returns or throws.
     *
     * <p>Implemented with try-with-resources for two reasons:
     * <ul>
     *   <li>if both the function and {@code close()} throw, the close failure is recorded as a
     *       suppressed exception instead of replacing the function's exception;</li>
     *   <li>a {@code null} resource is not closed.</li>
     * </ul>
     *
     * @param t         resource handed to the function and closed before this method returns
     * @param function1 work to run against the resource
     * @return whatever {@code function1} returns
     */
    public <T extends AutoCloseable, V> V withResource(T t, Function1<T, V> function1) {
        try (T resource = t) {
            return (V) function1.apply(resource);
        }
    }

    /**
     * Filter predicate used by {@code buildWatchesWithEval}.
     *
     * @param tuple2 a {@code (name, DMatrix)} pair
     * @return {@code true} when the pair's second element is non-null
     */
    public static final /* synthetic */ boolean $anonfun$buildWatchesWithEval$3(Tuple2 tuple2) {
        final Object matrix = tuple2._2();
        return matrix != null;
    }

    /**
     * Filter predicate used by {@code coPartitionForGpu} to drop the seed RDD's null placeholders.
     *
     * @param tuple2 candidate element
     * @return {@code false} for {@code null}, {@code true} otherwise
     */
    public static final /* synthetic */ boolean $anonfun$coPartitionForGpu$5(Tuple2 tuple2) {
        if (tuple2 == null) {
            return false;
        }
        return true;
    }

    /**
     * Private constructor of the {@code GpuPreXGBoost$} module instance (the Scala-object encoding
     * exposed through {@code MODULE$}).
     *
     * <p>Statement order matters and must be kept as-is:
     * <ol>
     *   <li>{@code this} is published as {@code MODULE$};</li>
     *   <li>the {@code PreXGBoostProvider} trait initializer runs;</li>
     *   <li>the logger and the string fields are assigned.</li>
     * </ol>
     */
    private GpuPreXGBoost$() {
        MODULE$ = this;
        PreXGBoostProvider.$init$(this);
        // The logger is looked up by the fixed name "XGBoostSpark", not by this class.
        this.logger = LogFactory.getLog("XGBoostSpark");
        this.FEATURES_COLS = "features_cols";
        this.TRAIN_NAME = "train";
    }
}
