/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.cocode;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.cocode.AColumnCoCoder;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;

/**
 * Column co-coder that greedily merges column groups while a TSMM-style cost estimate over all
 * groups keeps decreasing.
 *
 * <p>Groups are processed in ascending order of their number of distinct values. The two
 * cheapest groups are tentatively joined. The join is kept if the total estimated cost drops.
 * Otherwise the smaller group is finalized and the other one goes back into the queue.
 */
public class CoCodeCostTSMM
extends AColumnCoCoder {
    /** Weight of the post-scaling term relative to the per-row pre-aggregation term. */
    private static final double POST_SCALE_WEIGHT = 5.0;

    /**
     * Tuple sparsity above which a multi-column group's post-scale cost is charged in full
     * rather than discounted by its tuple sparsity.
     */
    private static final double TUPLE_SPARSITY_THRESHOLD = 0.4;

    protected CoCodeCostTSMM(CompressedSizeEstimator e, CompressionSettings cs) {
        super(e, cs);
    }

    /**
     * Greedily joins the given column groups.
     *
     * <p>The groups are joined once. If the sampling ratio is below 0.1 and the estimator is
     * sample based, the sample is doubled. The resulting groups are then re-estimated and
     * joined a second time.
     *
     * @param colInfos initial per-group size information; updated in place with the result
     * @param k        parallelization degree forwarded to the re-estimation
     * @return {@code colInfos} holding the joined groups
     */
    @Override
    protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
        List<CompressedSizeInfoColGroup> joinRes = join(colInfos.getInfo());
        if (_cs.samplingRatio < 0.1 && _est instanceof CompressedSizeEstimatorSample) {
            LOG.debug("Performing second join with double sample rate");
            CompressedSizeEstimatorSample estS = (CompressedSizeEstimatorSample) _est;
            estS.sampleData(estS.getSample().getNumRows() * 2);
            // Re-estimate the already joined groups on the larger sample before joining again.
            ArrayList<int[]> colG = new ArrayList<>(joinRes.size());
            for (CompressedSizeInfoColGroup g : joinRes) {
                colG.add(g.getColumns());
            }
            joinRes = join(estS.computeCompressedSizeInfos(colG, k));
        }
        colInfos.setInfo(joinRes);
        return colInfos;
    }

    /**
     * Greedily joins groups in ascending order of their number of distinct values, keeping a
     * join only if it lowers the total estimated cost.
     *
     * @param currentGroups groups to consider (may be empty)
     * @return the final list of (possibly joined) groups
     */
    private List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups) {
        Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparingInt(CompressedSizeInfoColGroup::getNumVals);
        // PriorityQueue rejects an initial capacity below 1, so guard against empty input.
        PriorityQueue<CompressedSizeInfoColGroup> que = new PriorityQueue<>(Math.max(1, currentGroups.size()), comp);
        que.addAll(currentGroups);
        List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
        double currentCost = getCost(que, ret);
        while (!que.isEmpty()) {
            CompressedSizeInfoColGroup l = que.poll();
            CompressedSizeInfoColGroup r = que.poll();
            if (r == null) {
                // Only one group left; nothing to join it with.
                ret.add(l);
                break;
            }
            CompressedSizeInfoColGroup g = joinWithAnalysis(l, r);
            double newCost = getCost(que, ret, g);
            if (newCost < currentCost) {
                // Keep the join; the joined group may be joined further.
                currentCost = newCost;
                que.add(g);
            } else {
                // Finalize l and reconsider r against the next candidate.
                // Moving l from que to ret does not change the total cost.
                ret.add(l);
                que.add(r);
            }
        }
        ret.addAll(que);
        return ret;
    }

    /** Total cost of all groups currently in {@code que} and {@code ret}. */
    private double getCost(Queue<CompressedSizeInfoColGroup> que, List<CompressedSizeInfoColGroup> ret) {
        CompressedSizeInfoColGroup[] queValues = que.toArray(new CompressedSizeInfoColGroup[que.size()]);
        return getCost(queValues, ret);
    }

    /**
     * Total cost of all groups in {@code que} and {@code ret}, plus the additional group
     * {@code g}. This includes g's self cost and its pairwise cost with every other group.
     */
    private double getCost(Queue<CompressedSizeInfoColGroup> que, List<CompressedSizeInfoColGroup> ret, CompressedSizeInfoColGroup g) {
        CompressedSizeInfoColGroup[] queValues = que.toArray(new CompressedSizeInfoColGroup[que.size()]);
        double cost = getCost(queValues, ret);
        cost += getCostOfSelfTSMM(g);
        for (CompressedSizeInfoColGroup q : queValues) {
            cost += getCostOfLeftTransposedMM(q, g);
        }
        for (CompressedSizeInfoColGroup o : ret) {
            cost += getCostOfLeftTransposedMM(o, g);
        }
        return cost;
    }

    /**
     * Sum over all groups of their self cost, plus the pairwise cost of every unordered pair of
     * distinct groups drawn from {@code queValues} and {@code ret} combined.
     */
    private double getCost(CompressedSizeInfoColGroup[] queValues, List<CompressedSizeInfoColGroup> ret) {
        double cost = 0.0;
        for (int i = 0; i < queValues.length; ++i) {
            cost += getCostOfSelfTSMM(queValues[i]);
            for (int j = i + 1; j < queValues.length; ++j) {
                cost += getCostOfLeftTransposedMM(queValues[i], queValues[j]);
            }
            for (CompressedSizeInfoColGroup g : ret) {
                cost += getCostOfLeftTransposedMM(queValues[i], g);
            }
        }
        for (int i = 0; i < ret.size(); ++i) {
            cost += getCostOfSelfTSMM(ret.get(i));
            for (int j = i + 1; j < ret.size(); ++j) {
                cost += getCostOfLeftTransposedMM(ret.get(i), ret.get(j));
            }
        }
        return cost;
    }

    /**
     * Self cost of a group: {@code numVals * nCol * (nCol + 1) / 2}.
     *
     * <p>The product is computed in double because it easily exceeds {@link Integer#MAX_VALUE}.
     */
    private static double getCostOfSelfTSMM(CompressedSizeInfoColGroup g) {
        final double nCol = g.getColumns().length;
        return g.getNumVals() * nCol * (nCol + 1) / 2;
    }

    /**
     * Pairwise cost between two groups.
     *
     * <p>The cost is the minimum over both sides of
     * {@code nRows + POST_SCALE_WEIGHT * postScale(side)}, so it is symmetric in its arguments.
     */
    private double getCostOfLeftTransposedMM(CompressedSizeInfoColGroup gl, CompressedSizeInfoColGroup gr) {
        final double nRows = _est.getNumRows();
        final double costLeft = nRows + getPostScaleCost(gl) * POST_SCALE_WEIGHT;
        final double costRight = nRows + getPostScaleCost(gr) * POST_SCALE_WEIGHT;
        return Math.min(costLeft, costRight);
    }

    /**
     * Post-scale term of one group: {@code numVals * nCols}. The term is multiplied by the
     * group's tuple sparsity unless the group has more than one column and a tuple sparsity
     * above {@link #TUPLE_SPARSITY_THRESHOLD}.
     */
    private static double getPostScaleCost(CompressedSizeInfoColGroup g) {
        final int nCols = g.getColumns().length;
        final double ts = g.getTupleSparsity();
        // Computed in double: numVals * nCols can overflow int.
        final double full = (double) g.getNumVals() * nCols;
        return nCols > 1 && ts > TUPLE_SPARSITY_THRESHOLD ? full : full * ts;
    }
}

