/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.estim.sample;

import java.util.HashMap;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.UnivariateSolverUtils;
import org.apache.sysds.runtime.compress.estim.sample.FrequencyCount;
import org.apache.sysds.runtime.compress.utils.ABitmap;

public class HassAndStokes {
    public static final double HAAS_AND_STOKES_ALPHA1 = 0.9;
    public static final double HAAS_AND_STOKES_ALPHA2 = 30.0;
    public static final int HAAS_AND_STOKES_UJ2A_C = 50;
    public static final boolean HAAS_AND_STOKES_UJ2A_CUT2 = true;
    public static final boolean HAAS_AND_STOKES_UJ2A_SOLVE = true;
    public static final int MAX_SOLVE_CACHE_SIZE = 65536;

    public static int haasAndStokes(ABitmap ubm, int nRows, int sampleSize, HashMap<Integer, Double> solveCache) {
        int numVals = ubm.getNumValues();
        int[] freqCounts = FrequencyCount.get(ubm);
        if (numVals == 0) {
            return 1;
        }
        double q = (double)sampleSize / (double)nRows;
        double f1 = freqCounts[0];
        double duj1 = HassAndStokes.getDuj1Estimate(q, f1, sampleSize, numVals);
        double gamma = HassAndStokes.getGammaSquared(duj1, freqCounts, sampleSize, nRows);
        double d = -1.0;
        d = gamma < 0.9 ? HassAndStokes.getDuj2Estimate(q, f1, sampleSize, numVals, gamma) : (gamma < 30.0 ? HassAndStokes.getDuj2aEstimate(q, freqCounts, sampleSize, numVals, gamma, nRows, solveCache) : HassAndStokes.getSh3Estimate(q, freqCounts, numVals));
        return Math.max(1, (int)Math.round(d));
    }

    private static double getDuj1Estimate(double q, double f1, int n, int dn) {
        return (double)dn / (1.0 - (1.0 - q) * f1 / (double)n);
    }

    private static double getDuj2Estimate(double q, double f1, int n, int dn, double gammaDuj1) {
        return ((double)dn - (1.0 - q) * f1 * Math.log(1.0 - q) * gammaDuj1 / q) / (1.0 - (1.0 - q) * f1 / (double)n);
    }

    private static double getDuj2aEstimate(double q, int[] f, int n, int dn, double gammaDuj1, int N, HashMap<Integer, Double> solveCache) {
        int i;
        int c = f.length / 2 + 1;
        int nB = 0;
        int cardB = 0;
        for (int i2 = c; i2 <= f.length; ++i2) {
            if (f[i2 - 1] == 0) continue;
            nB += f[i2 - 1] * i2;
            cardB += f[i2 - 1];
        }
        if (n - nB == 0) {
            return HassAndStokes.getDuj2Estimate(q, f[0], n, dn, gammaDuj1);
        }
        int updatedN = N;
        for (i = c; i <= f.length; ++i) {
            if (f[i - 1] == 0) continue;
            updatedN = (int)((double)updatedN - (double)f[i - 1] * HassAndStokes.getMethodOfMomentsEstimate(i, q, 1.0, N, solveCache));
        }
        for (i = c; i <= f.length; ++i) {
            f[i - 1] = 0;
        }
        double updatedDuj1 = HassAndStokes.getDuj1Estimate(q, f[0], n - nB, dn - cardB);
        double updatedGammaDuj1 = HassAndStokes.getGammaSquared(updatedDuj1, f, n - nB, updatedN);
        double duj2 = HassAndStokes.getDuj2Estimate(q, f[0], n - nB, dn - cardB, updatedGammaDuj1);
        return duj2 + (double)cardB;
    }

    private static double getGammaSquared(double D, int[] f, int n, int N) {
        double gamma = 0.0;
        for (int i = 1; i <= f.length; ++i) {
            if (f[i - 1] == 0) continue;
            gamma += (double)(i * (i - 1) * f[i - 1]);
        }
        gamma *= D / (double)n / (double)n;
        return Math.max(0.0, gamma += D / (double)N - 1.0);
    }

    private static double getSh3Estimate(double q, int[] f, double dn) {
        double fraq11 = 0.0;
        double fraq12 = 0.0;
        double fraq21 = 0.0;
        double fraq22 = 0.0;
        for (int i = 1; i <= f.length; ++i) {
            if (f[i - 1] == 0) continue;
            fraq11 += (double)i * q * q * Math.pow(1.0 - q * q, i - 1) * (double)f[i - 1];
            fraq12 += (Math.pow(1.0 - q * q, i) - Math.pow(1.0 - q, i)) * (double)f[i - 1];
            fraq21 += Math.pow(1.0 - q, i) * (double)f[i - 1];
            fraq22 += (double)i * q * Math.pow(1.0 - q, i - 1) * (double)f[i - 1];
        }
        return dn + (double)f[0] * fraq11 / fraq12 * Math.pow(fraq21 / fraq22, 2.0);
    }

    private static double getMethodOfMomentsEstimate(int nj, double q, double min, double max, HashMap<Integer, Double> solveCache) {
        if (solveCache.containsKey(nj)) {
            return solveCache.get(nj);
        }
        double est = UnivariateSolverUtils.solve((UnivariateFunction)new MethodOfMomentsFunction(nj, q), (double)min, (double)max, (double)1.0E-9);
        if (solveCache.size() < 65536) {
            solveCache.put(nj, est);
        }
        return est;
    }

    private static class MethodOfMomentsFunction
    implements UnivariateFunction {
        private final int _nj;
        private final double _q;

        public MethodOfMomentsFunction(int nj, double q) {
            this._nj = nj;
            this._q = q;
        }

        public double value(double x) {
            return this._q * x / (1.0 - Math.pow(1.0 - this._q, x)) - (double)this._nj;
        }
    }
}

