/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Inter-procedural analysis pass that selects a federated execution plan for the hops of a program.
 * <p>
 * The pass visits every hop DAG root reachable from the program's statement blocks. This covers
 * generic blocks, while/if predicates, for-loop from/to/increment expressions, and the bodies of
 * loops, branches and function blocks. For each root it:
 * <ol>
 * <li>enumerates, bottom-up, one {@link HopRel} alternative per {@link FEDInstruction.FederatedOutput}
 * value the hop supports, falling back to a single {@code NONE} alternative;</li>
 * <li>picks the alternative with the lowest {@link HopRel#getCost()} at the root and pushes the
 * chosen federated output and cost down along the chosen alternative's input dependencies.</li>
 * </ol>
 * The pass only annotates hops. Every statement block is returned unchanged and in its original
 * position.
 */
public class IPAPassRewriteFederatedPlan
extends IPAPass {
    /**
     * Enumerated alternatives keyed by hop ID.
     * <p>
     * This used to be a static, JVM-wide map that was never cleared. That retained the hops of every
     * compiled program and made later runs reuse stale alternatives. It is now scoped to the pass
     * instance and reset around each {@link #rewriteProgram} call.
     */
    private final Map<Long, List<HopRel>> hopRelMemo = new HashMap<Long, List<HopRel>>();

    /**
     * Indicates whether this pass runs at all.
     *
     * @param fgraph function call graph (unused)
     * @return the value of {@code OptimizerUtils.FEDERATED_COMPILATION}
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return OptimizerUtils.FEDERATED_COMPILATION;
    }

    /**
     * Selects federated outputs for all hops reachable from the program's top-level statement blocks.
     *
     * @param prog       program whose statement blocks are annotated in place
     * @param fgraph     function call graph (unused)
     * @param fcallSizes function call size information (unused)
     * @return always {@code false}
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        // Start from an empty memo so that alternatives computed in an earlier run
        // (possibly for hops that have changed since) are never reused.
        hopRelMemo.clear();
        try {
            this.rewriteStatementBlocks(prog.getStatementBlocks());
        }
        finally {
            // The memo is only needed while the plan is being selected; release the hop references.
            hopRelMemo.clear();
        }
        return false;
    }

    /**
     * Annotates each statement block in order and collects the resulting blocks.
     *
     * @param sbs statement blocks to process
     * @return a new list containing the (unchanged) blocks in their original order
     */
    public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs) {
        ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<StatementBlock>();
        for (StatementBlock stmBlock : sbs) {
            rewrittenStmBlocks.addAll(this.rewriteStatementBlock(stmBlock));
        }
        return rewrittenStmBlocks;
    }

    /**
     * Dispatches on the statement block type and selects federated outputs for its hops.
     *
     * @param sb statement block to process
     * @return a single-element list containing {@code sb}
     */
    public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb) {
        if (sb instanceof WhileStatementBlock) {
            return this.rewriteWhileStatementBlock((WhileStatementBlock)sb);
        }
        if (sb instanceof IfStatementBlock) {
            return this.rewriteIfStatementBlock((IfStatementBlock)sb);
        }
        if (sb instanceof ForStatementBlock) {
            return this.rewriteForStatementBlock((ForStatementBlock)sb);
        }
        if (sb instanceof FunctionStatementBlock) {
            return this.rewriteFunctionStatementBlock((FunctionStatementBlock)sb);
        }
        // Generic (non-control-flow) block: its hop DAG roots are processed directly.
        this.selectFederatedExecutionPlan(sb.getHops());
        return new ArrayList<StatementBlock>(Collections.singletonList(sb));
    }

    /** Processes the loop predicate, then recursively processes the loop body. */
    private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileSB) {
        Hop whilePredicateHop = whileSB.getPredicateHops();
        this.selectFederatedExecutionPlan(whilePredicateHop);
        for (Statement stm : whileSB.getStatements()) {
            WhileStatement whileStm = (WhileStatement)stm;
            whileStm.setBody(this.rewriteStatementBlocks(whileStm.getBody()));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(whileSB));
    }

    /** Processes the branch predicate, then recursively processes the if and else bodies. */
    private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifSB) {
        this.selectFederatedExecutionPlan(ifSB.getPredicateHops());
        for (Statement statement : ifSB.getStatements()) {
            IfStatement ifStatement = (IfStatement)statement;
            ifStatement.setIfBody(this.rewriteStatementBlocks(ifStatement.getIfBody()));
            ifStatement.setElseBody(this.rewriteStatementBlocks(ifStatement.getElseBody()));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(ifSB));
    }

    /**
     * Processes the from/to/increment expressions, then recursively processes the loop body.
     * Any of these expression roots may be absent; absent ones are skipped.
     */
    private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forSB) {
        this.selectFederatedExecutionPlan(forSB.getFromHops());
        this.selectFederatedExecutionPlan(forSB.getToHops());
        this.selectFederatedExecutionPlan(forSB.getIncrementHops());
        for (Statement statement : forSB.getStatements()) {
            ForStatement forStatement = (ForStatement)statement;
            forStatement.setBody(this.rewriteStatementBlocks(forStatement.getBody()));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(forSB));
    }

    /** Recursively processes the body of every function statement in the block. */
    private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock funcSB) {
        for (Statement statement : funcSB.getStatements()) {
            FunctionStatement funcStm = (FunctionStatement)statement;
            funcStm.setBody(this.rewriteStatementBlocks(funcStm.getBody()));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(funcSB));
    }

    /**
     * Chooses the lowest-cost memoized alternative for {@code root} and applies it to the DAG below.
     *
     * @param root hop whose alternatives have already been enumerated by {@link #visitFedPlanHop}
     * @throws DMLException if the root has no alternatives
     */
    private void setFinalFedout(Hop root) {
        HopRel optimalRootHopRel = hopRelMemo.get(root.getHopID()).stream()
            .min(Comparator.comparingDouble(HopRel::getCost))
            .orElseThrow(() -> new DMLException("Hop root " + root + " has no feasible federated output alternatives"));
        this.setFinalFedout(root, optimalRootHopRel);
    }

    /** Applies {@code rootHopRel} to {@code root}, then recursively to the inputs it depends on. */
    private void setFinalFedout(Hop root, HopRel rootHopRel) {
        this.updateFederatedOutput(root, rootHopRel);
        this.visitInputDependency(rootHopRel);
    }

    /** Applies each input alternative recorded in {@code rootHopRel}'s input dependency to its hop. */
    private void visitInputDependency(HopRel rootHopRel) {
        List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
        for (HopRel input : hopRelInputs) {
            this.setFinalFedout(input.getHopRef(), input);
        }
    }

    /** Copies the chosen federated output and cost object from {@code updateHopRel} onto {@code root}. */
    private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
        root.setFederatedOutput(updateHopRel.getFederatedOutput());
        root.setFederatedCost(updateHopRel.getCostObject());
    }

    /**
     * Selects the federated plan for every root in {@code roots}.
     *
     * @param roots hop DAG roots; a {@code null} list or {@code null} elements are ignored
     */
    private void selectFederatedExecutionPlan(ArrayList<Hop> roots) {
        if (roots == null) {
            return;
        }
        for (Hop root : roots) {
            this.selectFederatedExecutionPlan(root);
        }
    }

    /**
     * Enumerates alternatives for the DAG under {@code root} and applies the cheapest plan.
     *
     * @param root hop DAG root; {@code null} (e.g. an absent for-loop increment) is ignored
     */
    private void selectFederatedExecutionPlan(Hop root) {
        if (root == null) {
            return;
        }
        this.visitFedPlanHop(root);
        this.setFinalFedout(root);
    }

    /**
     * Post-order traversal that memoizes the {@link HopRel} alternatives of each hop.
     * <p>
     * Inputs are visited first so that their alternatives are available when this hop's alternatives
     * are built (the {@link HopRel} constructor receives the memo). Already-memoized hops are skipped,
     * so shared sub-DAGs are enumerated only once per run.
     */
    private void visitFedPlanHop(Hop currentHop) {
        if (hopRelMemo.containsKey(currentHop.getHopID())) {
            return;
        }
        if (currentHop.getInput() != null && !currentHop.getInput().isEmpty()) {
            for (Hop input : currentHop.getInput()) {
                this.visitFedPlanHop(input);
            }
        }
        ArrayList<HopRel> hopRels = new ArrayList<HopRel>();
        if (this.isFedInstSupportedHop(currentHop)) {
            for (FEDInstruction.FederatedOutput fedoutValue : FEDInstruction.FederatedOutput.values()) {
                if (!this.isFedOutSupported(currentHop, fedoutValue)) continue;
                hopRels.add(new HopRel(currentHop, fedoutValue, hopRelMemo));
            }
        }
        // Every hop gets at least one alternative so that root selection never sees an empty list.
        if (hopRels.isEmpty()) {
            hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
        }
        hopRelMemo.put(currentHop.getHopID(), hopRels);
    }

    /** Returns whether the hop is one of the operator types for which federated alternatives are enumerated. */
    private boolean isFedInstSupportedHop(Hop hop) {
        return hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
            || hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp;
    }

    /**
     * Returns whether {@code fedOut} is a valid alternative for {@code associatedHop}.
     * {@code NONE} is rejected here because {@link #visitFedPlanHop} adds it only as a fallback,
     * when no other alternative is supported.
     */
    private boolean isFedOutSupported(Hop associatedHop, FEDInstruction.FederatedOutput fedOut) {
        switch (fedOut) {
            case FOUT:
                return this.isFOUTSupported(associatedHop);
            case LOUT:
                return this.isLOUTSupported(associatedHop);
            case NONE:
                return false;
            default:
                return true;
        }
    }

    /**
     * FOUT is not offered for scalar aggregate-unary hops. Otherwise it is offered when the hop is a
     * federated data op, or when at least one input has an alternative with a federated output.
     */
    private boolean isFOUTSupported(Hop associatedHop) {
        if (associatedHop instanceof AggUnaryOp && associatedHop.isScalar()) {
            return false;
        }
        return associatedHop.getInput().stream()
            .anyMatch(input -> hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput))
            || associatedHop.isFederatedDataOp();
    }

    /** LOUT is offered when the hop has no privacy constraint attached. */
    private boolean isLOUTSupported(Hop associatedHop) {
        return associatedHop.getPrivacy() == null || !associatedHop.getPrivacy().hasConstraints();
    }
}

