package ml.dmlc.xgboost4j.java.example;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/example/BasicWalkThrough.class */
/**
 * Walk-through example for the XGBoost4J Java API.
 *
 * <p>{@link #main} trains a booster on the agaricus demo data. It then saves the model, a text
 * dump of it, and the test matrix in binary form, and reloads the model and matrix. Finally it
 * trains a second booster from a CSR matrix built in memory. Each prediction comparison prints
 * {@code true}/{@code false} to standard output.
 */
public class BasicWalkThrough {

    /**
     * Checks whether two prediction matrices are exactly equal.
     *
     * <p>Rows are compared with {@link Arrays#equals(float[], float[])} semantics, i.e. bit-exact
     * float comparison: {@code NaN} equals {@code NaN}, and {@code 0.0f} differs from {@code -0.0f}.
     *
     * @param fArr first prediction matrix (may be {@code null})
     * @param fArr2 second prediction matrix (may be {@code null})
     * @return {@code true} if both have the same number of rows and every row is equal, or if
     *     both are {@code null}
     */
    public static boolean checkPredicts(float[][] fArr, float[][] fArr2) {
        // For float[] elements, deepEquals delegates to Arrays.equals(float[], float[]), which
        // matches the original row-by-row loop, and it also tolerates null outer arrays.
        return Arrays.deepEquals(fArr, fArr2);
    }

    /**
     * Writes a booster's text dump to a file, encoded as UTF-8.
     *
     * <p>Each entry of {@code boosterDumps} is preceded by a {@code "booster[i]:\n"} header line.
     *
     * @param path destination file path; an existing file is truncated
     * @param boosterDumps one dump string per booster, e.g. as returned by
     *     {@code Booster.getModelDump}
     * @throws IOException if the file cannot be opened or a write error occurs
     */
    public static void saveDumpModel(String path, String[] boosterDumps) throws IOException {
        // try-with-resources guarantees the file handle is released on every path.
        try (PrintWriter writer = new PrintWriter(path, "UTF-8")) {
            for (int i = 0; i < boosterDumps.length; i++) {
                writer.print("booster[" + i + "]:\n");
                writer.print(boosterDumps[i]);
            }
            // PrintWriter swallows I/O errors; checkError() flushes and reports them so a
            // truncated dump is not silently accepted.
            if (writer.checkError()) {
                throw new IOException("failed to write model dump to " + path);
            }
        }
    }

    /**
     * Runs the walk-through. All paths are relative to the working directory.
     *
     * @param strArr command-line arguments (unused)
     * @throws IOException if the model directory or dump file cannot be written
     * @throws XGBoostError if an XGBoost operation fails
     */
    public static void main(String[] strArr) throws IOException, XGBoostError {
        // Load train/test data. The URI query selects the libsvm text format and
        // indexing_mode=1 (presumably 1-based feature indices -- confirm with DMatrix docs).
        DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm&indexing_mode=1");
        DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm&indexing_mode=1");

        // Booster parameters passed to XGBoost.train.
        HashMap<String, Object> params = new HashMap<>();
        params.put("eta", 1.0d);
        params.put("max_depth", 2);
        params.put("silent", 1);
        params.put("objective", "binary:logistic");

        // Named matrices handed to train() as the watch list.
        HashMap<String, DMatrix> watches = new HashMap<>();
        watches.put("train", trainMat);
        watches.put("test", testMat);

        // Two boosting rounds; null objective/evaluation means no custom callbacks are supplied.
        Booster booster = XGBoost.train(trainMat, params, 2, watches, (IObjective) null, (IEvaluation) null);
        float[][] predicts = booster.predict(testMat);

        // Persist the model, its text dump and the test matrix under ./model.
        File modelDir = new File("./model");
        if (!modelDir.exists() && !modelDir.mkdirs()) {
            throw new IOException("could not create model directory: " + modelDir.getPath());
        }
        booster.saveModel("./model/xgb.model");
        saveDumpModel("./model/dump.raw.txt", booster.getModelDump("../../demo/data/featmap.txt", false));
        testMat.saveBinary("./model/dtest.buffer");

        // Reload the model and the binary test matrix; predictions must match the originals.
        Booster loadedBooster = XGBoost.loadModel("./model/xgb.model");
        DMatrix loadedTestMat = new DMatrix("./model/dtest.buffer");
        System.out.println(checkPredicts(predicts, loadedBooster.predict(loadedTestMat)));

        // Build the training matrix from in-memory CSR arrays instead of a libsvm URI.
        System.out.println("start build dmatrix from csr sparse data ...");
        DataLoader.CSRSparseData csrData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
        // NOTE(review): 127 is the shape argument -- presumably the column count of the
        // agaricus data; confirm against the DMatrix CSR constructor docs.
        DMatrix csrTrainMat = new DMatrix(csrData.rowHeaders, csrData.colIndex, csrData.data, DMatrix.SparseType.CSR, 127);
        csrTrainMat.setLabel(csrData.labels);

        HashMap<String, DMatrix> csrWatches = new HashMap<>();
        csrWatches.put("train", csrTrainMat);
        csrWatches.put("test", loadedTestMat);

        // Training on the same data from the CSR matrix is expected to reproduce the predictions.
        Booster csrBooster = XGBoost.train(csrTrainMat, params, 2, csrWatches, (IObjective) null, (IEvaluation) null);
        System.out.println(checkPredicts(predicts, csrBooster.predict(loadedTestMat)));
    }
}
