/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.truffle.llvm.initialization;

import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.llvm.initialization.StaticInitsNode;
import com.oracle.truffle.llvm.initialization.StaticInitsNodeGen;
import com.oracle.truffle.llvm.parser.LLVMParserResult;
import com.oracle.truffle.llvm.parser.model.symbols.constants.Constant;
import com.oracle.truffle.llvm.parser.model.symbols.constants.aggregate.ArrayConstant;
import com.oracle.truffle.llvm.parser.model.symbols.constants.aggregate.StructureConstant;
import com.oracle.truffle.llvm.parser.model.symbols.globals.GlobalVariable;
import com.oracle.truffle.llvm.parser.nodes.LLVMSymbolReadResolver;
import com.oracle.truffle.llvm.parser.util.Pair;
import com.oracle.truffle.llvm.runtime.CommonNodeFactory;
import com.oracle.truffle.llvm.runtime.LLVMContext;
import com.oracle.truffle.llvm.runtime.LLVMLanguage;
import com.oracle.truffle.llvm.runtime.LLVMScope;
import com.oracle.truffle.llvm.runtime.NodeFactory;
import com.oracle.truffle.llvm.runtime.datalayout.DataLayout;
import com.oracle.truffle.llvm.runtime.global.LLVMGlobal;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMHasDatalayoutNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMStatementNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMVoidStatementNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMVoidStatementNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.others.LLVMStatementRootNode;
import com.oracle.truffle.llvm.runtime.types.FunctionType;
import com.oracle.truffle.llvm.runtime.types.PointerType;
import com.oracle.truffle.llvm.runtime.types.PrimitiveType;
import com.oracle.truffle.llvm.runtime.types.StructureType;
import com.oracle.truffle.llvm.runtime.types.Type;
import java.util.ArrayList;
import java.util.Comparator;

/**
 * Performs the static initialization of one parsed LLVM module.
 *
 * <p>On {@link #execute(VirtualFrame, LLVMContext)} it registers the module's destructor call
 * target with the {@link LLVMContext} (only if the module defines destructors) and then runs the
 * module's constructors. Constructors are read from the {@code llvm.global_ctors} global and run
 * in ascending priority order; destructors are read from {@code llvm.global_dtors} and are
 * arranged in descending priority order.
 */
public final class InitializeModuleNode
extends LLVMNode
implements LLVMHasDatalayoutNode {
    /** Priority assigned to a structor entry whose priority field is not an evaluable integer constant. */
    private static final int LEAST_CONSTRUCTOR_PRIORITY = 65535;

    // Integer.compare rather than subtraction: the difference of two arbitrary ints can overflow
    // and flip its sign, which would violate the Comparator contract and mis-order structors.
    private static final Comparator<Pair<Integer, ?>> ASCENDING_PRIORITY = (p1, p2) -> Integer.compare(p1.getFirst(), p2.getFirst());
    private static final Comparator<Pair<Integer, ?>> DESCENDING_PRIORITY = (p1, p2) -> Integer.compare(p2.getFirst(), p1.getFirst());

    /** Call target running all destructors of the module, or {@code null} if it defines none. */
    private final RootCallTarget destructor;
    private final DataLayout dataLayout;
    /** Node running all constructors of the module (possibly none). */
    @Node.Child
    private StaticInitsNode constructor;

    /**
     * Builds the constructor node and destructor call target for a parsed module.
     *
     * @param language the language instance the destructor root node is created for
     * @param parserResult the parsed module providing globals, data layout and node factory
     * @param moduleName module name passed on to the created {@link StaticInitsNode}s
     */
    public InitializeModuleNode(LLVMLanguage language, LLVMParserResult parserResult, String moduleName) {
        this.destructor = InitializeModuleNode.createDestructor(parserResult, moduleName, language);
        this.dataLayout = parserResult.getDataLayout();
        this.constructor = InitializeModuleNode.createConstructor(parserResult, moduleName);
    }

    /**
     * Registers the destructors (if any) with {@code ctx}, then executes the constructors.
     *
     * @param frame the current frame, handed to the constructor node
     * @param ctx the context the destructor call target is registered with
     */
    public void execute(VirtualFrame frame, LLVMContext ctx) {
        if (this.destructor != null) {
            ctx.registerDestructorFunctions(this.destructor);
        }
        this.constructor.execute(frame);
    }

    @Override
    public DataLayout getDatalayout() {
        return this.dataLayout;
    }

    /**
     * Creates a call target that runs the module's {@code llvm.global_dtors} entries in descending
     * priority order.
     *
     * @param parserResult the parsed module
     * @param moduleName module name passed on to the created {@link StaticInitsNode}
     * @param language the language instance for the created root node
     * @return the destructor call target, or {@code null} if the module has no destructors
     */
    public static RootCallTarget createDestructor(LLVMParserResult parserResult, String moduleName, LLVMLanguage language) {
        final LLVMStatementNode[] destructors = InitializeModuleNode.createStructor("llvm.global_dtors", parserResult, DESCENDING_PRIORITY);
        if (destructors.length == 0) {
            return null;
        }
        final FrameDescriptor frameDescriptor = new FrameDescriptor();
        final LLVMStatementRootNode root = new LLVMStatementRootNode(language, StaticInitsNodeGen.create(destructors, "fini", moduleName), frameDescriptor,
                        parserResult.getRuntime().getNodeFactory().createStackAccess(frameDescriptor));
        return Truffle.getRuntime().createCallTarget((RootNode) root);
    }

    /**
     * Creates the node that runs the module's {@code llvm.global_ctors} entries in ascending
     * priority order.
     */
    private static StaticInitsNode createConstructor(LLVMParserResult parserResult, String moduleName) {
        return StaticInitsNodeGen.create(InitializeModuleNode.createStructor("llvm.global_ctors", parserResult, ASCENDING_PRIORITY), "init", moduleName);
    }

    /**
     * Looks up the defined global called {@code name} and turns its entries into call statements.
     *
     * @return the ordered call statements, or {@link LLVMStatementNode#NO_STATEMENTS} if the module
     *         defines no global with that name
     */
    private static LLVMStatementNode[] createStructor(String name, LLVMParserResult parserResult, Comparator<Pair<Integer, ?>> priorityComparator) {
        for (GlobalVariable globalVariable : parserResult.getDefinedGlobals()) {
            if (globalVariable.getName().equals(name)) {
                return InitializeModuleNode.resolveStructor(parserResult.getRuntime().getFileScope(), globalVariable, priorityComparator, parserResult.getDataLayout(),
                                parserResult.getRuntime().getNodeFactory());
            }
        }
        return LLVMStatementNode.NO_STATEMENTS;
    }

    /**
     * Builds one call statement per element of a structor array global.
     *
     * <p>Each array element is a struct whose field 0 holds the priority and whose field 1 holds a
     * pointer to the function to call. Every generated statement loads that function pointer from
     * the global at run time and calls it with the stack obtained from the frame as the sole
     * argument. Statements are ordered with {@code priorityComparator}; because
     * {@code Stream.sorted} is stable, entries with equal priority keep their array order.
     *
     * @return the ordered statements; {@link LLVMStatementNode#NO_STATEMENTS} if the global's value
     *         is not an {@link ArrayConstant}; or a single overflow-handling statement if computing
     *         a type size overflows
     */
    private static LLVMStatementNode[] resolveStructor(LLVMScope fileScope, GlobalVariable globalSymbol, Comparator<Pair<Integer, ?>> priorityComparator, DataLayout dataLayout,
                    NodeFactory nodeFactory) {
        if (!(globalSymbol.getValue() instanceof ArrayConstant)) {
            // NOTE(review): presumably a zero-length array initialized with a scalar null/zero — confirm.
            return LLVMStatementNode.NO_STATEMENTS;
        }
        final LLVMGlobal global = (LLVMGlobal) fileScope.get(globalSymbol.getName());
        final ArrayConstant arrayConstant = (ArrayConstant) globalSymbol.getValue();
        final int elemCount = arrayConstant.getElementCount();
        final StructureType elementType = (StructureType) arrayConstant.getType().getElementType();
        try {
            final long elementSize = elementType.getSize(dataLayout);
            final FunctionType functionType = (FunctionType) ((PointerType) elementType.getElementType(1L)).getPointeeType();
            // NOTE(review): the function type's alignment is used as the stride for indexing field 1;
            // kept as-is, confirm against the struct layout.
            final int indexedTypeLength = functionType.getAlignment(dataLayout);
            final ArrayList<Pair<Integer, LLVMVoidStatementNode>> structors = new ArrayList<>(elemCount);
            for (int i = 0; i < elemCount; ++i) {
                // Fresh nodes for every element: Truffle AST nodes are adopted by a single parent
                // and must not be shared between call statements.
                final LLVMExpressionNode globalVarAddress = CommonNodeFactory.createLiteral(global, new PointerType(globalSymbol.getType()));
                final LLVMExpressionNode iNode = CommonNodeFactory.createLiteral(i, PrimitiveType.I32);
                final LLVMExpressionNode structPointer = nodeFactory.createTypedElementPointer(elementSize, elementType, globalVarAddress, iNode);
                final LLVMExpressionNode loadedStruct = nodeFactory.createLoad(elementType, structPointer);
                final LLVMExpressionNode oneLiteralNode = CommonNodeFactory.createLiteral(1, PrimitiveType.I32);
                final LLVMExpressionNode functionLoadTarget = nodeFactory.createTypedElementPointer(indexedTypeLength, functionType, loadedStruct, oneLiteralNode);
                final LLVMExpressionNode loadedFunction = nodeFactory.createLoad(functionType, functionLoadTarget);
                final LLVMExpressionNode[] argNodes = new LLVMExpressionNode[]{nodeFactory.createGetStackFromFrame()};
                final LLVMVoidStatementNode functionCall = LLVMVoidStatementNodeGen.create(CommonNodeFactory.createFunctionCall(loadedFunction, argNodes, functionType));

                // The priority is read at parse time from the constant initializer, not at run time.
                final StructureConstant structorDefinition = (StructureConstant) arrayConstant.getElement(i);
                final Constant prioritySymbol = structorDefinition.getElement(0);
                final Integer priority = LLVMSymbolReadResolver.evaluateIntegerConstant(prioritySymbol);
                structors.add(new Pair<>(priority != null ? priority : LEAST_CONSTRUCTOR_PRIORITY, functionCall));
            }
            return structors.stream().sorted(priorityComparator).map(Pair::getSecond).toArray(LLVMStatementNode[]::new);
        } catch (Type.TypeOverflowException e) {
            return new LLVMStatementNode[]{Type.handleOverflowStatement(e)};
        }
    }
}

