package net.maizegenetics.stats.EMMA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.IntStream;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrixFactory;
import net.maizegenetics.matrixalgebra.decomposition.EigenvalueDecomposition;
import net.maizegenetics.stats.linearmodels.LinearModelUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:net/maizegenetics/stats/EMMA/EMMAforDoubleMatrix.class */
/**
 * Fits the mixed model {@code y = X*beta + Z*u + e} with {@code var(u) = sigma_g^2 * K} and
 * {@code var(e) = delta * sigma_g^2 * I}, using the spectral (EMMA-style) decomposition to maximise
 * the restricted log-likelihood over {@code delta}.
 *
 * <p>Typical use: construct, call {@link #solve()}, then optionally
 * {@link #calculateBlupsPredictedResiduals()} and the getters. Instances are mutable and keep
 * all intermediate results in protected fields.
 */
public class EMMAforDoubleMatrix {
    private static final Logger myLogger = LogManager.getLogger(EMMAforDoubleMatrix.class);

    /** Response as a column matrix (observed rows only). */
    protected DoubleMatrix y;
    /** The N - q retained eigenvalues of the projected ZKZ', in descending absolute value. */
    protected double[] lambda;
    /** Squared entries of U'y, filled by {@link #solve()}. */
    protected double[] eta2;
    /** Constant term of the restricted log-likelihood, set by {@link #init()}. */
    protected double c;
    /** Number of observations (rows of y). */
    protected int N;
    /** Number of fixed-effect columns (columns of X). */
    protected int q;
    /** Row count recorded by the constructor (rows of K, or rows of Z for the explicit-Z constructor). */
    protected int Nran;
    /** Marker degrees of freedom; values below 1 disable {@link #getMarkerFp()}. */
    protected int dfMarker = 0;
    /** Full fixed-effect design including rows with missing y (missing-data constructor only). */
    protected DoubleMatrix Xoriginal = null;
    protected DoubleMatrix X;
    /** Full incidence matrix including rows with missing y (missing-data constructor only). */
    protected DoubleMatrix Zoriginal = null;
    protected DoubleMatrix Z = null;
    protected DoubleMatrix K;
    /** Eigen-decomposition of the projected, diagonal-shifted ZKZ'. */
    protected EigenvalueDecomposition eig;
    /** Eigen-decomposition of ZKZ' itself, used to build H^-1. */
    protected EigenvalueDecomposition eigA;
    protected DoubleMatrix U;
    protected DoubleMatrix invH;
    protected DoubleMatrix invXHX;
    protected DoubleMatrix beta;
    protected DoubleMatrix Xbeta;
    protected double ssModel;
    protected double ssError;
    protected double SST;
    protected double Rsq;
    protected int dfModel;
    protected int dfError;
    protected double delta;
    protected double varResidual;
    protected double varRandomEffect;
    protected DoubleMatrix blup;
    protected DoubleMatrix pred;
    protected DoubleMatrix res;
    protected DoubleMatrix pev;
    protected double lnLikelihood;
    /** When false, {@link #delta} was supplied by the caller and is not estimated. */
    protected boolean findDelta = true;
    protected boolean calculatePEV = false;
    /** Search interval and grid size for delta. */
    protected double lowerlimit = 1.0E-5d;
    protected double upperlimit = 100000.0d;
    protected int nregions = 100;
    /** Newton-Raphson stopping tolerance and iteration cap used by findMaximum. */
    protected double convergence = 1.0E-10d;
    protected int maxiter = 50;
    /** Recursion guard for findMaximum re-scanning a bracket. */
    protected int subintervalCount = 0;

    /**
     * Pairs an index with the absolute value of a number. Natural ordering is descending by
     * absolute value; NaN comparisons are treated as ties.
     */
    public class C1Pair implements Comparable<C1Pair> {
        int order;
        double absvalue;

        C1Pair(int index, double value) {
            this.order = index;
            this.absvalue = Math.abs(value);
        }

        @Override
        public int compareTo(C1Pair other) {
            if (this.absvalue < other.absvalue) {
                return 1;
            }
            return this.absvalue > other.absvalue ? -1 : 0;
        }
    }

    /**
     * Builds a model with Z = I and delta estimated by {@link #solve()}.
     *
     * @param data response; a single row is transposed to a column
     * @param fixed fixed-effect design; must have full column rank
     * @param kinship relationship matrix K
     * @param markerLevels dfMarker is set to {@code markerLevels - 1}
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kinship, int markerLevels) {
        this(data, fixed, kinship, markerLevels, Double.NaN);
    }

    /**
     * Builds a model with Z = I (dimension = rows of K).
     *
     * @param delta fixed value for delta, or NaN to estimate it in {@link #solve()}
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kinship, int markerLevels, double delta) {
        requireFullColumnRank(fixed);
        useFixedDeltaIfGiven(delta);
        this.y = asColumn(data);
        this.N = this.y.numberOfRows();
        this.X = fixed;
        this.q = fixed.numberOfColumns();
        this.K = kinship;
        this.Nran = kinship.numberOfRows();
        this.Z = DoubleMatrixFactory.DEFAULT.identity(this.Nran);
        this.dfMarker = markerLevels - 1;
        init();
    }

    /**
     * Builds a model with an explicit random-effect incidence matrix Z.
     *
     * @param delta fixed value for delta, or NaN to estimate it in {@link #solve()}
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kinship, DoubleMatrix incidence, int markerLevels, double delta) {
        requireFullColumnRank(fixed);
        useFixedDeltaIfGiven(delta);
        this.y = asColumn(data);
        this.N = this.y.numberOfRows();
        this.X = fixed;
        this.q = fixed.numberOfColumns();
        this.Z = incidence;
        this.K = kinship;
        this.Nran = incidence.numberOfRows();
        this.dfMarker = markerLevels - 1;
        init();
    }

    /**
     * Builds a model that drops observations whose response is NaN. The full X and an identity Z
     * are retained in {@link #Xoriginal}/{@link #Zoriginal} so predictions cover every row of K.
     *
     * @param data response column; must have (at least) as many rows as K
     * @throws IllegalArgumentException if {@code fixed} is rank deficient or {@code data} is a row
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kinship) {
        requireFullColumnRank(fixed);
        if (data.numberOfColumns() > 1 && data.numberOfRows() == 1) {
            throw new IllegalArgumentException("The phenotype data must be a column matrix.");
        }
        this.K = kinship;
        this.Nran = kinship.numberOfRows();
        int[] observed = IntStream.range(0, this.Nran)
                .filter(row -> !Double.isNaN(data.get(row, 0)))
                .toArray();
        this.y = data.getSelection(observed, null);
        this.Zoriginal = DoubleMatrixFactory.DEFAULT.identity(this.Nran);
        this.Z = this.Zoriginal.getSelection(observed, null);
        this.N = this.y.numberOfRows();
        this.Xoriginal = fixed;
        this.X = fixed.getSelection(observed, null);
        this.q = this.X.numberOfColumns();
        init();
    }

    /** Stores the fixed-effect column count in dfModel and rejects rank-deficient designs. */
    private void requireFullColumnRank(DoubleMatrix fixed) {
        this.dfModel = fixed.numberOfColumns();
        if (fixed.columnRank() < this.dfModel) {
            throw new IllegalArgumentException("The fixed effect design matrix has less than full column rank. The analysis will not be run.");
        }
    }

    /** Uses {@code value} as delta (and disables the search) unless it is NaN. */
    private void useFixedDeltaIfGiven(double value) {
        if (!Double.isNaN(value)) {
            this.delta = value;
            this.findDelta = false;
        }
    }

    /** Returns a single-row matrix transposed to a column; any other shape is returned unchanged. */
    private static DoubleMatrix asColumn(DoubleMatrix m) {
        if (m.numberOfColumns() > 1 && m.numberOfRows() == 1) {
            return m.transpose();
        }
        return m;
    }

    /**
     * Computes the likelihood constant and the spectral quantities (eigA, eig, U, lambda) that
     * make each likelihood evaluation O(N).
     */
    protected void init() {
        int nReml = this.N - this.q;
        // (n-q) * ln((n-q) / (2*pi)) - (n-q). Must use floating-point division: integer division
        // truncated odd n-q, and gave log(0) = -Infinity when n-q == 1.
        this.c = (nReml * Math.log(nReml / (2.0d * Math.PI))) - nReml;
        this.lambda = new double[nReml];

        DoubleMatrix zkzt = this.Z.mult(this.K).tcrossproduct(this.Z);
        this.eigA = zkzt.getEigenvalueDecomposition();

        double minEigenvalue = this.eigA.getEigenvalues()[0];
        for (double value : this.eigA.getEigenvalues()) {
            minEigenvalue = Math.min(minEigenvalue, value);
        }
        // Shift the diagonal when ZKZ' is near-singular or indefinite; the shift is subtracted
        // back from the retained eigenvalues below.
        double shift = minEigenvalue < 0.01d ? -minEigenvalue + 0.5d : 0.0d;
        for (int k = 0, n = zkzt.numberOfRows(); k < n; k++) {
            zkzt.set(k, k, zkzt.get(k, k) + shift);
        }

        // NOTE(review): getXtXGM()[2] is presumably the projection M = I - X(X'X)^-1X'; confirm.
        DoubleMatrix projection = this.X.getXtXGM()[2];
        this.eig = projection.mult(zkzt.mult(projection)).getEigenvalueDecomposition();
        double[] projectedEigenvalues = this.eig.getEigenvalues();
        int[] byMagnitude = getSortedIndexofAbsoluteValues(projectedEigenvalues);
        this.U = this.eig.getEigenvectors().getSelection(null, byMagnitude);
        for (int k = 0; k < nReml; k++) {
            this.lambda[k] = projectedEigenvalues[byMagnitude[k]] - shift;
        }
    }

    /** Returns indices of {@code values} ordered by descending absolute value (stable for ties). */
    private int[] getSortedIndexofAbsoluteValues(double[] values) {
        C1Pair[] pairs = new C1Pair[values.length];
        for (int k = 0; k < values.length; k++) {
            pairs[k] = new C1Pair(k, values[k]);
        }
        Arrays.sort(pairs);
        int[] indices = new int[values.length];
        for (int k = 0; k < pairs.length; k++) {
            indices[k] = pairs[k].order;
        }
        return indices;
    }

    /**
     * Estimates delta (unless fixed), then computes lnLikelihood, H^-1, beta, the variance
     * components and the model/error degrees of freedom.
     */
    public void solve() {
        DoubleMatrix uty = this.U.crossproduct(this.y);
        int nEta = uty.numberOfRows();
        this.eta2 = new double[nEta];
        for (int k = 0; k < nEta; k++) {
            double eta = uty.get(k, 0);
            this.eta2[k] = eta * eta;
        }
        if (this.findDelta) {
            this.delta = findDeltaInInterval(new double[]{this.lowerlimit, this.upperlimit});
        }
        this.lnLikelihood = lnlk(this.delta);
        this.invH = inverseH(this.delta);
        this.beta = calculateBeta();
        double genvar = getGenvar(this.beta);
        this.dfModel = this.q - 1;
        this.dfError = this.N - this.q;
        this.varResidual = genvar * this.delta;
        this.varRandomEffect = genvar;
    }

    /** Computes BLUPs (and PEV if enabled), predictions and residuals. Requires {@link #solve()}. */
    public void calculateBlupsPredictedResiduals() {
        calculateBLUP();
        this.pred = calculatePred();
        this.res = calculateRes();
    }

    /** Computes BLUPs (and PEV if enabled) and predictions. Requires {@link #solve()}. */
    public void calculateBlupsPredicted() {
        calculateBLUP();
        this.pred = calculatePred();
    }

    /**
     * Returns the delta in [interval[0], interval[1]] with the largest log-likelihood, considering
     * the grid points of {@link #scanlnlk} and a Newton-Raphson refinement inside every bracket
     * where the derivative changes from positive to non-positive.
     */
    private double findDeltaInInterval(double[] interval) {
        double[][] grid = scanlnlk(interval[0], interval[1]);
        double[][] brackets = findSignChanges(grid);

        double[] bestRow = {Double.NaN, Double.NaN, Double.NaN};
        for (double[] row : grid) {
            if (Double.isNaN(bestRow[1]) || (!Double.isNaN(row[1]) && row[1] > bestRow[1])) {
                bestRow = row;
            }
        }
        double bestDelta = bestRow[0];
        double bestLnlk = bestRow[1];

        for (double[] bracket : brackets) {
            double candidate = findMaximum(bracket);
            if (Double.isNaN(candidate)) {
                continue;
            }
            double candidateLnlk = lnlk(candidate);
            if (!Double.isNaN(candidateLnlk) && candidateLnlk > bestLnlk) {
                bestDelta = candidate;
                bestLnlk = candidateLnlk;
            }
        }
        return bestDelta;
    }

    /** Restricted log-likelihood at {@code d}; NaN if any {@code lambda + d} is negative. */
    private double lnlk(double d) {
        double weightedSs = 0.0d;
        double logDetSum = 0.0d;
        int nReml = this.N - this.q;
        for (int k = 0; k < nReml; k++) {
            double shifted = this.lambda[k] + d;
            if (shifted < 0.0d) {
                return Double.NaN;
            }
            weightedSs += this.eta2[k] / shifted;
            logDetSum += Math.log(shifted);
        }
        return ((this.c - (nReml * Math.log(weightedSs))) - logDetSum) / 2.0d;
    }

    /** First derivative of {@link #lnlk} with respect to delta. */
    private double d1lnlk(double d) {
        double sumEta = 0.0d;
        double sumEtaSq = 0.0d;
        double sumInv = 0.0d;
        int nReml = this.N - this.q;
        for (int k = 0; k < nReml; k++) {
            double inv = 1.0d / (this.lambda[k] + d);
            double etaInv = this.eta2[k] * inv;
            sumEta += etaInv;
            sumEtaSq += etaInv * inv;
            sumInv += inv;
        }
        return (((nReml * sumEtaSq) / sumEta) / 2.0d) - (sumInv / 2.0d);
    }

    /**
     * Evaluates the likelihood on {@link #nregions} points evenly spaced in log10 between
     * {@code lower} and {@code upper}. Each row is {delta, lnlk, d1lnlk}.
     */
    private double[][] scanlnlk(double lower, double upper) {
        double[][] grid = new double[this.nregions][3];
        double logUpper = Math.log10(upper);
        double logLower = Math.log10(lower);
        double step = (logUpper - logLower) / (this.nregions - 1);
        for (int k = 0; k < this.nregions; k++) {
            double d = Math.pow(10.0d, logLower + (k * step));
            grid[k][0] = d;
            grid[k][1] = lnlk(d);
            grid[k][2] = d1lnlk(d);
        }
        return grid;
    }

    /**
     * Returns {left, right} delta brackets where the derivative goes from positive to
     * non-positive and the likelihood at the left point is defined.
     */
    private double[][] findSignChanges(double[][] grid) {
        ArrayList<double[]> brackets = new ArrayList<>();
        for (int k = 0; k < grid.length - 1; k++) {
            if (grid[k][2] > 0.0d && grid[k + 1][2] <= 0.0d && !Double.isNaN(grid[k][1])) {
                brackets.add(new double[]{grid[k][0], grid[k + 1][0]});
            }
        }
        return brackets.toArray(new double[0][]);
    }

    /**
     * Newton-Raphson search for a root of the likelihood derivative inside {@code bracket}.
     * If an iterate leaves the bracket, the bracket is re-scanned (at most three nested times,
     * after which NaN is returned).
     */
    private double findMaximum(double[] bracket) {
        double d = bracket[0];
        boolean converged = false;
        int nReml = this.N - this.q;
        for (int iter = 0; !converged && iter < this.maxiter; iter++) {
            double s1 = 0.0d;   // sum eta2 / (lambda + d)
            double s2 = 0.0d;   // sum eta2 / (lambda + d)^2
            double s3 = 0.0d;   // sum eta2 / (lambda + d)^3
            double inv1 = 0.0d; // sum 1 / (lambda + d)
            double inv2 = 0.0d; // sum 1 / (lambda + d)^2
            for (int k = 0; k < nReml; k++) {
                double shifted = this.lambda[k] + d;
                double shiftedSq = shifted * shifted;
                s1 += this.eta2[k] / shifted;
                s2 += this.eta2[k] / shiftedSq;
                s3 += this.eta2[k] / (shiftedSq * shifted);
                inv1 += 1.0d / shifted;
                inv2 += 1.0d / shiftedSq;
            }
            // score = 2 * d1lnlk(d); the step divides by its derivative.
            double score = ((nReml * s2) / s1) - inv1;
            if (Math.abs(score) < this.convergence) {
                converged = true;
            } else {
                d -= score / (inv2 + (((nReml * ((s2 * s2) - ((2.0d * s1) * s3))) / s1) / s1));
            }
            if (d < bracket[0] || d > bracket[1]) {
                this.subintervalCount++;
                if (this.subintervalCount > 3) {
                    this.subintervalCount = 0;
                    return Double.NaN;
                }
                d = findDeltaInInterval(bracket);
                converged = true;
            }
        }
        this.subintervalCount = 0;
        return d;
    }

    /** Returns H^-1 = (ZKZ' + d*I)^-1 from the eigen-decomposition of ZKZ'. */
    private DoubleMatrix inverseH(double d) {
        DoubleMatrix vectors = this.eigA.getEigenvectors();
        // Copy before editing in place so repeated solves cannot corrupt eigA, in case the
        // decomposition returns shared state (NOTE(review): confirm whether it returns a fresh matrix).
        DoubleMatrix inverseValues = this.eigA.getEigenvalueMatrix().copy();
        for (int k = 0, n = inverseValues.numberOfRows(); k < n; k++) {
            inverseValues.set(k, k, 1.0d / (inverseValues.get(k, k) + d));
        }
        return vectors.mult(inverseValues.tcrossproduct(vectors));
    }

    /** Generalised least-squares beta; also stores (X'H^-1X)^-1 in {@link #invXHX}. */
    private DoubleMatrix calculateBeta() {
        DoubleMatrix xtInvH = this.X.crossproduct(this.invH);
        this.invXHX = xtInvH.mult(this.X).inverse();
        return this.invXHX.mult(xtInvH.mult(this.y));
    }

    /**
     * Computes BLUPs {@code KZ'H^-1(y - X*beta)}. When PEV is enabled, also computes the
     * prediction error variance diag(K - KZ'PZK) * varRandomEffect, where
     * P = H^-1 - H^-1X(X'H^-1X)^-1X'H^-1.
     */
    public void calculateBLUP() {
        this.Xbeta = this.X.mult(this.beta);
        DoubleMatrix residualFixed = this.y.minus(this.Xbeta);
        DoubleMatrix kzt = this.K.mult(this.Z.transpose());
        DoubleMatrix kztInvH = kzt.mult(this.invH);
        this.blup = kztInvH.mult(residualFixed);
        if (this.calculatePEV) {
            DoubleMatrix kztInvHX = kztInvH.mult(this.X);
            // K - KZ'PZK expands to K - KZ'H^-1ZK + KZ'H^-1X(X'H^-1X)^-1X'H^-1ZK: the
            // fixed-effect term is ADDED (uncertainty in beta increases PEV).
            DoubleMatrix pevMatrix = this.K.copy()
                    .minus(kztInvH.tcrossproduct(kzt))
                    .plus(kztInvHX.mult(this.invXHX).tcrossproduct(kztInvHX));
            int n = pevMatrix.numberOfRows();
            this.pev = DoubleMatrixFactory.DEFAULT.make(n, 1);
            for (int k = 0; k < n; k++) {
                this.pev.set(k, 0, pevMatrix.get(k, k));
            }
            this.pev.scalarMultEquals(this.varRandomEffect);
        }
    }

    /** Predictions X*beta + Z*blup, using the full (unfiltered) designs when available. */
    private DoubleMatrix calculatePred() {
        if (this.Xoriginal == null) {
            this.Xbeta = this.X.mult(this.beta);
            return this.Xbeta.plus(this.Z.mult(this.blup));
        }
        this.Xbeta = this.Xoriginal.mult(this.beta);
        return this.Xbeta.plus(this.Zoriginal.mult(this.blup));
    }

    /**
     * Residuals on the observed rows: y - (X*beta + Z*blup). Computed from the filtered
     * designs so the dimension matches y even when {@link #pred} covers unobserved rows.
     */
    private DoubleMatrix calculateRes() {
        return this.y.minus(this.X.mult(this.beta).plus(this.Z.mult(this.blup)));
    }

    /** Returns (y - X*b)' H^-1 (y - X*b) / (N - q). */
    private double getGenvar(DoubleMatrix b) {
        DoubleMatrix r = this.y.copy();
        r.minusEquals(this.X.mult(b));
        return r.crossproduct(this.invH.mult(r)).get(0, 0) / (this.N - this.q);
    }

    public int getDfMarker() {
        return this.dfMarker;
    }

    public DoubleMatrix getBeta() {
        return this.beta;
    }

    public int getDfModel() {
        return this.dfModel;
    }

    public int getDfError() {
        return this.dfError;
    }

    public double getDelta() {
        return this.delta;
    }

    public DoubleMatrix getInvH() {
        return this.invH;
    }

    public double getVarRes() {
        return this.varResidual;
    }

    public double getVarRan() {
        return this.varRandomEffect;
    }

    public DoubleMatrix getBlup() {
        return this.blup;
    }

    public DoubleMatrix getPev() {
        return this.pev;
    }

    public DoubleMatrix getPred() {
        return this.pred;
    }

    public DoubleMatrix getRes() {
        return this.res;
    }

    public double getLnLikelihood() {
        return this.lnLikelihood;
    }

    /**
     * Wald F test that the last {@link #dfMarker} elements of beta are zero.
     *
     * @return {NaN, NaN, NaN} if dfMarker &lt; 1; {F, p} if dfMarker != 2; otherwise
     *     {F, p, est1, F1, p1, est2, F2, p2} where contrast 1 weights the last two betas by
     *     (0.5, -0.5) and contrast 2 by (-0.5, -0.5). Any p (or contrast F) whose computation
     *     throws is reported as NaN.
     */
    public double[] getMarkerFp() {
        if (this.dfMarker < 1) {
            return new double[]{Double.NaN, Double.NaN, Double.NaN};
        }
        int nBeta = this.beta.numberOfRows();
        int firstMarker = nBeta - this.dfMarker;
        DoubleMatrix selector = DoubleMatrixFactory.DEFAULT.make(this.dfMarker, nBeta);
        for (int k = 0; k < this.dfMarker; k++) {
            selector.set(k, k + firstMarker, 1.0d);
        }
        DoubleMatrix markerBeta = selector.mult(this.beta);
        DoubleMatrix markerVarInv = selector.mult(this.invXHX.tcrossproduct(selector));
        markerVarInv.invert();
        double fMarker = (markerBeta.crossproduct(markerVarInv.mult(markerBeta)).get(0, 0) / this.varRandomEffect) / this.dfMarker;
        double pMarker;
        try {
            pMarker = LinearModelUtils.Ftest(fMarker, this.dfMarker, this.N - this.q);
        } catch (Exception e) {
            pMarker = Double.NaN;
        }
        if (this.dfMarker != 2) {
            return new double[]{fMarker, pMarker};
        }
        double[] first = lastTwoBetaContrast(nBeta, 0.5d, -0.5d);
        double[] second = lastTwoBetaContrast(nBeta, -0.5d, -0.5d);
        return new double[]{fMarker, pMarker, first[0], first[1], first[2], second[0], second[1], second[2]};
    }

    /**
     * Single-df Wald test of a contrast on the last two elements of beta.
     *
     * @return {estimate, F, p}; F or p is NaN if its computation throws
     */
    private double[] lastTwoBetaContrast(int nBeta, double weightSecondLast, double weightLast) {
        DoubleMatrix contrast = DoubleMatrixFactory.DEFAULT.make(1, nBeta, 0.0d);
        contrast.set(0, nBeta - 2, weightSecondLast);
        contrast.set(0, nBeta - 1, weightLast);
        double estimate = contrast.mult(this.beta).get(0, 0);
        double f;
        try {
            f = ((estimate * estimate) / contrast.mult(this.invXHX.tcrossproduct(contrast)).get(0, 0)) / this.varRandomEffect;
        } catch (Exception e) {
            f = Double.NaN;
        }
        double p;
        try {
            p = LinearModelUtils.Ftest(f, 1.0d, this.N - this.q);
        } catch (Exception e) {
            p = Double.NaN;
        }
        return new double[]{estimate, f, p};
    }

    /**
     * Replaces y and re-runs {@link #solve()} with the existing decompositions.
     * NOTE(review): no NaN filtering is applied here, so {@code newData} presumably must have the
     * same rows as the y used at construction.
     */
    public void solveWithNewData(DoubleMatrix newData) {
        this.y = newData;
        solve();
    }

    /** Enables or disables PEV computation in {@link #calculateBLUP()}. */
    public void setCalculatePEV(boolean z) {
        this.calculatePEV = z;
    }
}
