/*
 * Decompiled with CFR 0.152.
 */
package org.graalvm.compiler.phases.common;

import jdk.vm.ci.code.CodeUtil;
import org.graalvm.collections.Pair;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
import org.graalvm.compiler.nodes.calc.FixedBinaryNode;
import org.graalvm.compiler.nodes.calc.IntegerDivRemNode;
import org.graalvm.compiler.nodes.calc.IntegerMulHighNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.calc.NarrowNode;
import org.graalvm.compiler.nodes.calc.RightShiftNode;
import org.graalvm.compiler.nodes.calc.SignExtendNode;
import org.graalvm.compiler.nodes.calc.SignedDivNode;
import org.graalvm.compiler.nodes.calc.SignedRemNode;
import org.graalvm.compiler.nodes.calc.UnsignedRightShiftNode;
import org.graalvm.compiler.phases.Phase;

public class OptimizeDivPhase
extends Phase {
    @Override
    protected void run(StructuredGraph graph) {
        for (IntegerDivRemNode rem : graph.getNodes(IntegerDivRemNode.TYPE)) {
            if (!(rem instanceof SignedRemNode) || !OptimizeDivPhase.divByNonZeroConstant(rem)) continue;
            this.optimizeRem(rem);
        }
        for (IntegerDivRemNode div : graph.getNodes(IntegerDivRemNode.TYPE)) {
            if (!(div instanceof SignedDivNode) || !OptimizeDivPhase.divByNonZeroConstant(div)) continue;
            OptimizeDivPhase.optimizeSignedDiv((SignedDivNode)div);
        }
    }

    @Override
    public float codeSizeIncrease() {
        return 5.0f;
    }

    protected static boolean divByNonZeroConstant(IntegerDivRemNode divRemNode) {
        return divRemNode.getY().isConstant() && divRemNode.getY().asJavaConstant().asLong() != 0L;
    }

    protected final void optimizeRem(IntegerDivRemNode rem) {
        assert (rem.getOp() == IntegerDivRemNode.Op.REM);
        StructuredGraph graph = rem.graph();
        ValueNode div = this.findDivForRem(rem);
        ValueNode mul = BinaryArithmeticNode.mul(graph, div, rem.getY(), NodeView.DEFAULT);
        ValueNode result = BinaryArithmeticNode.sub(graph, rem.getX(), mul, NodeView.DEFAULT);
        graph.replaceFixedWithFloating(rem, result);
    }

    private ValueNode findDivForRem(IntegerDivRemNode rem) {
        ValueNode div;
        if (rem.next() instanceof IntegerDivRemNode && ((IntegerDivRemNode)(div = (IntegerDivRemNode)rem.next())).getOp() == IntegerDivRemNode.Op.DIV && ((IntegerDivRemNode)div).getType() == rem.getType() && ((FixedBinaryNode)div).getX() == rem.getX() && ((FixedBinaryNode)div).getY() == rem.getY()) {
            return div;
        }
        if (rem.predecessor() instanceof IntegerDivRemNode && ((IntegerDivRemNode)(div = (IntegerDivRemNode)rem.predecessor())).getOp() == IntegerDivRemNode.Op.DIV && ((IntegerDivRemNode)div).getType() == rem.getType() && ((FixedBinaryNode)div).getX() == rem.getX() && ((FixedBinaryNode)div).getY() == rem.getY()) {
            return div;
        }
        div = rem.graph().addOrUniqueWithInputs(this.createDiv(rem));
        if (div instanceof FixedNode) {
            rem.graph().addAfterFixed(rem, (FixedNode)div);
        }
        return div;
    }

    protected ValueNode createDiv(IntegerDivRemNode rem) {
        assert (rem instanceof SignedRemNode);
        return SignedDivNode.create(rem.getX(), rem.getY(), rem.getZeroCheck(), NodeView.DEFAULT);
    }

    protected static void optimizeSignedDiv(SignedDivNode div) {
        ConstantNode s;
        ValueNode value;
        ValueNode forX = div.getX();
        long c = div.getY().asJavaConstant().asLong();
        assert (c != 1L && c != -1L && c != 0L);
        IntegerStamp dividendStamp = (IntegerStamp)forX.stamp(NodeView.DEFAULT);
        int bitSize = dividendStamp.getBits();
        Pair<Long, Integer> nums = OptimizeDivPhase.magicDivideConstants(c, bitSize);
        long magicNum = (Long)nums.getLeft();
        int shiftNum = (Integer)nums.getRight();
        assert (shiftNum >= 0);
        ConstantNode m = ConstantNode.forLong(magicNum);
        if (bitSize == 32) {
            value = new MulNode(new SignExtendNode(forX, 64), m);
            if (c > 0L && magicNum < 0L || c < 0L && magicNum > 0L) {
                value = NarrowNode.create(new RightShiftNode(value, ConstantNode.forInt(32)), 32, NodeView.DEFAULT);
                value = c > 0L ? BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT) : BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
                if (shiftNum > 0) {
                    value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
                }
            } else {
                value = new RightShiftNode(value, ConstantNode.forInt(32 + shiftNum));
                value = new NarrowNode(value, 32);
            }
        } else {
            assert (bitSize == 64);
            value = new IntegerMulHighNode(forX, m);
            if (c > 0L && magicNum < 0L) {
                value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
            } else if (c < 0L && magicNum > 0L) {
                value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
            }
            if (shiftNum > 0) {
                value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
            }
        }
        if (c < 0L) {
            s = ConstantNode.forInt(bitSize - 1);
            ValueNode sign = UnsignedRightShiftNode.create(value, s, NodeView.DEFAULT);
            value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
        } else if (dividendStamp.canBeNegative()) {
            s = ConstantNode.forInt(bitSize - 1);
            ValueNode sign = UnsignedRightShiftNode.create(forX, s, NodeView.DEFAULT);
            value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
        }
        StructuredGraph graph = div.graph();
        graph.replaceFixed(div, graph.addOrUniqueWithInputs(value));
    }

    private static Pair<Long, Integer> magicDivideConstants(long divisor, int size) {
        long delta;
        long twoW = 1L << size - 1;
        long t = twoW + (divisor >>> 63);
        long ad = Math.abs(divisor);
        long anc = t - 1L - Long.remainderUnsigned(t, ad);
        long q1 = Long.divideUnsigned(twoW, anc);
        long r1 = Long.remainderUnsigned(twoW, anc);
        long q2 = Long.divideUnsigned(twoW, ad);
        long r2 = Long.remainderUnsigned(twoW, ad);
        int p = size - 1;
        do {
            ++p;
            q1 = 2L * q1;
            if (Long.compareUnsigned(r1 = 2L * r1, anc) >= 0) {
                ++q1;
                r1 -= anc;
            }
            q2 = 2L * q2;
            if (Long.compareUnsigned(r2 = 2L * r2, ad) < 0) continue;
            ++q2;
            r2 -= ad;
        } while (Long.compareUnsigned(q1, delta = ad - r2) < 0 || q1 == delta && r1 == 0L);
        long magic = CodeUtil.signExtend((long)(q2 + 1L), (int)size);
        if (divisor < 0L) {
            magic = -magic;
        }
        return Pair.create((Object)magic, (Object)(p - size));
    }
}

