/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.MarkForLineageReuse;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.RewriteAlgebraicSimplificationDynamic;
import org.apache.sysds.hops.rewrite.RewriteAlgebraicSimplificationStatic;
import org.apache.sysds.hops.rewrite.RewriteBlockSizeAndReblock;
import org.apache.sysds.hops.rewrite.RewriteCommonSubexpressionElimination;
import org.apache.sysds.hops.rewrite.RewriteCompressedReblock;
import org.apache.sysds.hops.rewrite.RewriteConstantFolding;
import org.apache.sysds.hops.rewrite.RewriteElementwiseMultChainOptimization;
import org.apache.sysds.hops.rewrite.RewriteFederatedExecution;
import org.apache.sysds.hops.rewrite.RewriteForLoopVectorization;
import org.apache.sysds.hops.rewrite.RewriteGPUSpecificOps;
import org.apache.sysds.hops.rewrite.RewriteHoistLoopInvariantOperations;
import org.apache.sysds.hops.rewrite.RewriteIndexingVectorization;
import org.apache.sysds.hops.rewrite.RewriteInjectSparkLoopCheckpointing;
import org.apache.sysds.hops.rewrite.RewriteInjectSparkPReadCheckpointing;
import org.apache.sysds.hops.rewrite.RewriteMarkLoopVariablesUpdateInPlace;
import org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization;
import org.apache.sysds.hops.rewrite.RewriteMergeBlockSequence;
import org.apache.sysds.hops.rewrite.RewriteRemoveEmptyBasicBlocks;
import org.apache.sysds.hops.rewrite.RewriteRemoveForLoopEmptySequence;
import org.apache.sysds.hops.rewrite.RewriteRemoveReadAfterWrite;
import org.apache.sysds.hops.rewrite.RewriteRemoveUnnecessaryBranches;
import org.apache.sysds.hops.rewrite.RewriteRemoveUnnecessaryCasts;
import org.apache.sysds.hops.rewrite.RewriteSplitDagDataDependentOperators;
import org.apache.sysds.hops.rewrite.RewriteSplitDagUnknownCSVRead;
import org.apache.sysds.hops.rewrite.RewriteTransientWriteParentHandling;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * Applies HOP-DAG rewrite rules ({@link HopRewriteRule}) and statement-block rewrite rules
 * ({@link StatementBlockRewriteRule}) to a {@link DMLProgram}, to individual functions, or to
 * individual statement blocks.
 *
 * <p>DAG rules are applied one after another in registration order. Statement-block rules are
 * applied to whole lists of blocks (before and after recursing into them) as well as to every
 * single block; a rule may return a different number of blocks than it received.
 */
public class ProgramRewriter {
    /** HOP-DAG rewrites, applied in insertion order. */
    private final List<HopRewriteRule> _dagRuleSet = new ArrayList<>();
    /** Statement-block rewrites, applied in insertion order. */
    private final List<StatementBlockRewriteRule> _sbRuleSet = new ArrayList<>();

    /** Creates a rewriter with both the static and the dynamic default rule sets. */
    public ProgramRewriter() {
        this(true, true);
    }

    /**
     * Creates a rewriter populated with the default rule sets, each rule gated by the
     * corresponding {@link OptimizerUtils} / compiler-configuration flag. A small set of cleanup
     * rules is registered regardless of the two arguments.
     *
     * @param staticRewrites  whether to register the static HOP-DAG and statement-block rules
     * @param dynamicRewrites whether to register the dynamic HOP-DAG rules
     */
    public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) {
        if (staticRewrites) {
            addStaticRewrites();
        }
        if (dynamicRewrites) {
            addDynamicRewrites();
        }
        addCleanupRewrites();
    }

    /**
     * Creates a rewriter with exactly the given HOP-DAG rules and no statement-block rules.
     *
     * @param rewrites DAG rules, applied in the given order
     */
    public ProgramRewriter(HopRewriteRule... rewrites) {
        for (HopRewriteRule rewrite : rewrites) {
            _dagRuleSet.add(rewrite);
        }
    }

    /**
     * Creates a rewriter with exactly the given statement-block rules and no HOP-DAG rules.
     *
     * @param rewrites statement-block rules, applied in the given order
     */
    public ProgramRewriter(StatementBlockRewriteRule... rewrites) {
        for (StatementBlockRewriteRule rewrite : rewrites) {
            _sbRuleSet.add(rewrite);
        }
    }

    /**
     * Creates a rewriter with copies of the given rule lists.
     *
     * @param hRewrites  HOP-DAG rules, applied in list order
     * @param sbRewrites statement-block rules, applied in list order
     */
    public ProgramRewriter(ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites) {
        _dagRuleSet.addAll(hRewrites);
        _sbRuleSet.addAll(sbRewrites);
    }

    /** Registers the static HOP-DAG and statement-block rules enabled by the optimizer flags. */
    private void addStaticRewrites() {
        // --- HOP-DAG rules ---
        _dagRuleSet.add(new RewriteTransientWriteParentHandling());
        _dagRuleSet.add(new RewriteRemoveReadAfterWrite());
        _dagRuleSet.add(new RewriteBlockSizeAndReblock());
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts());
        }
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
        }
        if (OptimizerUtils.ALLOW_CONSTANT_FOLDING) {
            _dagRuleSet.add(new RewriteConstantFolding());
        }
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic());
        }
        // Second CSE pass; presumably to catch subexpressions exposed by the preceding
        // folding/simplification rules (NOTE(review): confirm intent).
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
        }
        if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) {
            _dagRuleSet.add(new RewriteIndexingVectorization());
        }
        _dagRuleSet.add(new RewriteInjectSparkPReadCheckpointing());

        // --- statement-block rules ---
        if (OptimizerUtils.ALLOW_BRANCH_REMOVAL) {
            _sbRuleSet.add(new RewriteRemoveUnnecessaryBranches());
        }
        if (OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL) {
            _sbRuleSet.add(new RewriteRemoveForLoopEmptySequence());
        }
        if (OptimizerUtils.ALLOW_BRANCH_REMOVAL || OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL) {
            _sbRuleSet.add(new RewriteMergeBlockSequence());
        }
        if (OptimizerUtils.ALLOW_COMPRESSION_REWRITE) {
            _sbRuleSet.add(new RewriteCompressedReblock());
        }
        if (OptimizerUtils.ALLOW_SPLIT_HOP_DAGS) {
            _sbRuleSet.add(new RewriteSplitDagUnknownCSVRead());
        }
        if (ConfigurationManager.getCompilerConfigFlag(CompilerConfig.ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)) {
            _sbRuleSet.add(new RewriteSplitDagDataDependentOperators());
        }
        if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) {
            _sbRuleSet.add(new RewriteForLoopVectorization());
        }
        _sbRuleSet.add(new RewriteInjectSparkLoopCheckpointing(true));
        if (OptimizerUtils.ALLOW_CODE_MOTION) {
            _sbRuleSet.add(new RewriteHoistLoopInvariantOperations());
        }
        if (OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE) {
            _sbRuleSet.add(new RewriteMarkLoopVariablesUpdateInPlace());
        }
        if (LineageCacheConfig.getCompAssRW()) {
            _sbRuleSet.add(new MarkForLineageReuse());
        }
    }

    /** Registers the dynamic HOP-DAG rules enabled by the optimizer flags. */
    private void addDynamicRewrites() {
        if (DMLScript.USE_ACCELERATOR) {
            _dagRuleSet.add(new RewriteGPUSpecificOps());
        }
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
            _dagRuleSet.add(new RewriteMatrixMultChainOptimization());
            _dagRuleSet.add(new RewriteElementwiseMultChainOptimization());
        }
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteAlgebraicSimplificationDynamic());
            _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic());
        }
        if (OptimizerUtils.FEDERATED_COMPILATION) {
            _dagRuleSet.add(new RewriteFederatedExecution());
        }
    }

    /**
     * Registers the trailing rules that run last in every default configuration, independent of
     * the static/dynamic selection.
     */
    private void addCleanupRewrites() {
        if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
            _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts());
        }
        if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) {
            _dagRuleSet.add(new RewriteCommonSubexpressionElimination(true));
        }
        if (OptimizerUtils.ALLOW_CONSTANT_FOLDING) {
            _dagRuleSet.add(new RewriteConstantFolding());
        }
        _sbRuleSet.add(new RewriteRemoveEmptyBasicBlocks());
    }

    /**
     * Removes every registered HOP-DAG rule whose runtime class is exactly {@code clazz}.
     *
     * @param clazz rule class to remove (subclasses are not matched)
     */
    public void removeHopRewrite(Class<? extends HopRewriteRule> clazz) {
        _dagRuleSet.removeIf(r -> r.getClass().equals(clazz));
    }

    /**
     * Removes every registered statement-block rule whose runtime class is exactly {@code clazz}.
     *
     * @param clazz rule class to remove (subclasses are not matched)
     */
    public void removeStatementBlockRewrite(Class<? extends StatementBlockRewriteRule> clazz) {
        _sbRuleSet.removeIf(r -> r.getClass().equals(clazz));
    }

    /**
     * Rewrites the whole program, allowing statement-block rules that split DAGs.
     *
     * @param dmlp program to rewrite in place
     * @return the rewrite status collected during the pass
     */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) {
        return rewriteProgramHopDAGs(dmlp, true);
    }

    /**
     * Rewrites the whole program with a fresh rewrite status.
     *
     * @param dmlp      program to rewrite in place
     * @param splitDags if {@code false}, statement-block rules reporting
     *                  {@code createsSplitDag()} are skipped
     * @return the rewrite status collected during the pass
     */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
        return rewriteProgramHopDAGs(dmlp, splitDags, new ProgramRewriteStatus());
    }

    /**
     * Rewrites the whole program: first every function in every namespace, then the HOP DAGs of
     * the main program's statement blocks, and finally (if any statement-block rules are
     * registered) the main program's statement-block list itself.
     *
     * @param dmlp      program to rewrite in place
     * @param splitDags if {@code false}, statement-block rules reporting
     *                  {@code createsSplitDag()} are skipped
     * @param state     status shared across the whole pass; a fresh one is used if {@code null}
     * @return the status used for the pass
     */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags, ProgramRewriteStatus state) {
        // Guard so that a null status is shared by all blocks instead of re-created per block.
        if (state == null) {
            state = new ProgramRewriteStatus();
        }
        for (String namespaceKey : dmlp.getNamespaces().keySet()) {
            for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
                FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
                rewriteHopDAGsFunction(fsblock, state, splitDags);
            }
        }
        for (int i = 0; i < dmlp.getNumStatementBlocks(); ++i) {
            rRewriteStatementBlockHopDAGs(dmlp.getStatementBlock(i), state);
        }
        if (!_sbRuleSet.isEmpty()) {
            dmlp.setStatementBlocks(rRewriteStatementBlocks(dmlp.getStatementBlocks(), state, splitDags));
        }
        return state;
    }

    /**
     * Rewrites a single function with a fresh rewrite status.
     *
     * @param fsb       function statement block to rewrite in place
     * @param splitDags whether DAG-splitting statement-block rules may run
     */
    public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, boolean splitDags) {
        rewriteHopDAGsFunction(fsb, new ProgramRewriteStatus(), splitDags);
    }

    /**
     * Rewrites the HOP DAGs of a function and then, if any statement-block rules are registered,
     * its statement blocks.
     *
     * @param fsb       function statement block to rewrite in place
     * @param state     rewrite status to update
     * @param splitDags whether DAG-splitting statement-block rules may run
     */
    public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, ProgramRewriteStatus state, boolean splitDags) {
        rRewriteStatementBlockHopDAGs(fsb, state);
        if (!_sbRuleSet.isEmpty()) {
            // NOTE(review): the returned list is discarded, so only in-place changes to the
            // function body take effect; a rule replacing fsb itself is ignored. Confirm intent.
            rRewriteStatementBlock(fsb, state, splitDags);
        }
    }

    /**
     * Recursively applies the HOP-DAG rules to all DAGs reachable from {@code current}: loop and
     * branch predicates, for-loop bounds/increments, and the DAGs of nested and basic blocks.
     *
     * @param current block to rewrite in place
     * @param state   rewrite status; a fresh one is used if {@code null}
     */
    public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) {
        if (state == null) {
            state = new ProgramRewriteStatus();
        }
        if (current instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) current;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock sb : fstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(sb, state);
            }
        } else if (current instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) current;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
            for (StatementBlock sb : wstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(sb, state);
            }
        } else if (current instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) current;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
            for (StatementBlock sb : istmt.getIfBody()) {
                rRewriteStatementBlockHopDAGs(sb, state);
            }
            for (StatementBlock sb : istmt.getElseBody()) {
                rRewriteStatementBlockHopDAGs(sb, state);
            }
        } else if (current instanceof ForStatementBlock) {
            // also covers ParForStatementBlock
            ForStatementBlock fsb = (ForStatementBlock) current;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
            fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
            fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
            for (StatementBlock sb : fstmt.getBody()) {
                rRewriteStatementBlockHopDAGs(sb, state);
            }
        } else {
            // basic block: rewrite its own DAG
            current.setHops(rewriteHopDAG(current.getHops(), state));
        }
    }

    /**
     * Applies every registered HOP-DAG rule, in order, to a list of DAG roots. Visit status is
     * reset before each rule.
     *
     * @param roots DAG roots
     * @param state rewrite status passed to each rule
     * @return the (possibly replaced) roots returned by the last rule, or {@code roots} if no
     *         rules are registered
     */
    public ArrayList<Hop> rewriteHopDAG(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        for (HopRewriteRule r : _dagRuleSet) {
            Hop.resetVisitStatus(roots);
            roots = r.rewriteHopDAGs(roots, state);
        }
        return roots;
    }

    /**
     * Applies every registered HOP-DAG rule, in order, to a single DAG root. Visit status is
     * reset before each rule.
     *
     * @param root  DAG root, may be {@code null}
     * @param state rewrite status passed to each rule
     * @return the (possibly replaced) root, or {@code null} if {@code root} is {@code null}
     */
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        for (HopRewriteRule r : _dagRuleSet) {
            root.resetVisitStatus();
            root = r.rewriteHopDAG(root, state);
        }
        return root;
    }

    /**
     * Recursively applies the statement-block rules to a list of blocks: list-level rules run
     * first, then each block is rewritten recursively (possibly expanding into several blocks),
     * then list-level rules run again over the resulting list.
     *
     * @param sbs       blocks to rewrite; this list is cleared and refilled with the result
     * @param status    rewrite status; a fresh one is used if {@code null}
     * @param splitDags whether DAG-splitting statement-block rules may run
     * @return {@code sbs}, now containing the rewritten blocks
     */
    public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        // list-level rules before recursion
        List<StatementBlock> tmp = sbs;
        for (StatementBlockRewriteRule r : _sbRuleSet) {
            if (isApplicable(r, splitDags)) {
                tmp = r.rewriteStatementBlocks(tmp, status);
            }
        }
        // per-block recursion; each block may turn into zero or more blocks
        List<StatementBlock> tmp2 = new ArrayList<>();
        for (StatementBlock sb : tmp) {
            tmp2.addAll(rRewriteStatementBlock(sb, status, splitDags));
        }
        // list-level rules again over the recursively rewritten list
        for (StatementBlockRewriteRule r : _sbRuleSet) {
            if (isApplicable(r, splitDags)) {
                tmp2 = r.rewriteStatementBlocks(tmp2, status);
            }
        }
        sbs.clear();
        sbs.addAll(tmp2);
        return sbs;
    }

    /**
     * Recursively rewrites the children of a single block and then applies every statement-block
     * rule to the block. Inside a {@link ParForStatementBlock} the status is marked as being in a
     * parfor context; the previous value is restored afterwards, even on failure.
     *
     * @param sb        block to rewrite
     * @param status    rewrite status; a fresh one is used if {@code null}
     * @param splitDags whether DAG-splitting statement-block rules may run
     * @return the block(s) that replace {@code sb}
     */
    public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
        // Same null guard as the other public entry points; the for-branch dereferences status.
        if (status == null) {
            status = new ProgramRewriteStatus();
        }
        ArrayList<StatementBlock> ret = new ArrayList<>();
        ret.add(sb);
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), status, splitDags));
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags));
            istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, splitDags));
        } else if (sb instanceof ForStatementBlock) {
            boolean prestatus = status.isInParforContext();
            if (sb instanceof ParForStatementBlock) {
                status.setInParforContext(true);
            }
            try {
                ForStatementBlock fsb = (ForStatementBlock) sb;
                ForStatement fstmt = (ForStatement) fsb.getStatement(0);
                fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
            } finally {
                // restore the enclosing context even if the nested rewrite fails
                status.setInParforContext(prestatus);
            }
        }
        for (StatementBlockRewriteRule r : _sbRuleSet) {
            if (!isApplicable(r, splitDags)) {
                continue;
            }
            ArrayList<StatementBlock> tmp = new ArrayList<>();
            for (StatementBlock sbc : ret) {
                tmp.addAll(r.rewriteStatementBlock(sbc, status));
            }
            ret.clear();
            ret.addAll(tmp);
        }
        return ret;
    }

    /**
     * Returns whether statement-block rule {@code r} may run: rules that create split DAGs are
     * skipped unless {@code splitDags} is set.
     */
    private static boolean isApplicable(StatementBlockRewriteRule r, boolean splitDags) {
        return splitDags || !r.createsSplitDag();
    }
}

