/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop;

import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.llvm.runtime.LLVMContext;
import com.oracle.truffle.llvm.runtime.LLVMFunctionDescriptor;
import com.oracle.truffle.llvm.runtime.LLVMLanguage;
import com.oracle.truffle.llvm.runtime.except.LLVMPolyglotException;
import com.oracle.truffle.llvm.runtime.interop.LLVMTypedForeignObject;
import com.oracle.truffle.llvm.runtime.memory.LLVMMemory;
import com.oracle.truffle.llvm.runtime.memory.LLVMStack;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.nodes.func.LLVMDispatchNode;
import com.oracle.truffle.llvm.runtime.nodes.func.LLVMDispatchNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.llvm.LLVMIntrinsic;
import com.oracle.truffle.llvm.runtime.nodes.memory.load.LLVMDerefHandleGetReceiverNode;
import com.oracle.truffle.llvm.runtime.pointer.LLVMManagedPointer;
import com.oracle.truffle.llvm.runtime.pointer.LLVMNativePointer;
import com.oracle.truffle.llvm.runtime.types.FunctionType;
import com.oracle.truffle.llvm.runtime.types.Type;

@NodeChildren(value={@NodeChild(type=LLVMExpressionNode.class), @NodeChild(type=LLVMExpressionNode.class)})
public abstract class LLVMTruffleDecorateFunction
extends LLVMIntrinsic {
    @Node.Child
    private LLVMDerefHandleGetReceiverNode derefHandleGetReceiverNode;
    @CompilerDirectives.CompilationFinal
    private TruffleLanguage.ContextReference<LLVMContext> contextRef;
    @CompilerDirectives.CompilationFinal
    private LLVMMemory cachedMemory;

    protected LLVMDerefHandleGetReceiverNode getDerefHandleGetReceiverNode() {
        if (this.derefHandleGetReceiverNode == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            this.derefHandleGetReceiverNode = (LLVMDerefHandleGetReceiverNode)this.insert(LLVMDerefHandleGetReceiverNode.create());
        }
        return this.derefHandleGetReceiverNode;
    }

    private LLVMMemory getLLVMMemoryCached() {
        if (this.cachedMemory == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            this.cachedMemory = LLVMTruffleDecorateFunction.getLLVMMemory();
        }
        return this.cachedMemory;
    }

    protected boolean isAutoDerefHandle(LLVMNativePointer addr) {
        return this.getLLVMMemoryCached().isDerefHandleMemory(addr.asNative());
    }

    @Specialization(guards={"!isAutoDerefHandle(func)"})
    protected Object decorate(LLVMNativePointer func, LLVMNativePointer wrapper) {
        return this.decorate(this.getContext().getFunctionDescriptor(func), this.getContext().getFunctionDescriptor(wrapper));
    }

    @Specialization(guards={"isAutoDerefHandle(func)"})
    protected Object decorateDerefHandle(LLVMNativePointer func, LLVMNativePointer wrapper, @Cached(value="createBinaryProfile()") ConditionProfile isFunctionDescriptorProfile) {
        LLVMManagedPointer resolved = this.getDerefHandleGetReceiverNode().execute(func);
        if (isFunctionDescriptorProfile.profile(LLVMTruffleDecorateFunction.isFunctionDescriptor(resolved.getObject()))) {
            return this.decorate(resolved, wrapper);
        }
        return this.doGeneric(func, wrapper);
    }

    @Specialization(guards={"!isAutoDerefHandle(func)", "isFunctionDescriptor(wrapper.getObject())"})
    protected Object decorate(LLVMNativePointer func, LLVMManagedPointer wrapper) {
        return this.decorate(this.getContext().getFunctionDescriptor(func), (LLVMFunctionDescriptor)wrapper.getObject());
    }

    @Specialization(guards={"isAutoDerefHandle(func)", "isFunctionDescriptor(wrapper.getObject())"})
    protected Object decorateDerefHandle(LLVMNativePointer func, LLVMManagedPointer wrapper, @Cached(value="createBinaryProfile()") ConditionProfile isFunctionDescriptorProfile) {
        LLVMManagedPointer resolved = this.getDerefHandleGetReceiverNode().execute(func);
        if (isFunctionDescriptorProfile.profile(LLVMTruffleDecorateFunction.isFunctionDescriptor(resolved.getObject()))) {
            return this.decorate(resolved, wrapper);
        }
        if (LLVMTruffleDecorateFunction.isForeignFunction(resolved.getObject())) {
            return this.decorateForeign(resolved, wrapper);
        }
        return this.doGeneric(func, wrapper);
    }

    private Object decorateForeign(LLVMManagedPointer resolved, LLVMManagedPointer wrapper) {
        LLVMTypedForeignObject foreign = (LLVMTypedForeignObject)resolved.getObject();
        return this.decorateForeign(foreign, (LLVMFunctionDescriptor)wrapper.getObject());
    }

    @Specialization(guards={"isFunctionDescriptor(func.getObject())"})
    protected Object decorate(LLVMManagedPointer func, LLVMNativePointer wrapper) {
        return this.decorate((LLVMFunctionDescriptor)func.getObject(), this.getContext().getFunctionDescriptor(wrapper));
    }

    @Specialization(guards={"isFunctionDescriptor(func.getObject())", "isFunctionDescriptor(wrapper.getObject())"})
    protected Object decorate(LLVMManagedPointer func, LLVMManagedPointer wrapper) {
        return this.decorate((LLVMFunctionDescriptor)func.getObject(), (LLVMFunctionDescriptor)wrapper.getObject());
    }

    @Fallback
    protected Object doGeneric(Object func, Object wrapper) {
        throw new LLVMPolyglotException(this, "invalid arguments for function composition");
    }

    @CompilerDirectives.TruffleBoundary
    private Object decorate(LLVMFunctionDescriptor function, LLVMFunctionDescriptor wrapperFunction) {
        assert (function != null && wrapperFunction != null);
        FunctionType newFunctionType = new FunctionType(wrapperFunction.getType().getReturnType(), function.getType().getArgumentTypes(), function.getType().isVarargs());
        NativeDecoratedRoot decoratedRoot = new NativeDecoratedRoot(this.lookupLanguageReference(LLVMLanguage.class).get(), function, wrapperFunction);
        return this.registerRoot(function.getLibrary(), newFunctionType, decoratedRoot);
    }

    @CompilerDirectives.TruffleBoundary
    private Object decorateForeign(Object function, LLVMFunctionDescriptor wrapperFunction) {
        assert (function != null && wrapperFunction != null);
        FunctionType newFunctionType = new FunctionType(wrapperFunction.getType().getReturnType(), Type.EMPTY_ARRAY, true);
        ForeignDecoratedRoot decoratedRoot = new ForeignDecoratedRoot(this.lookupLanguageReference(LLVMLanguage.class).get(), newFunctionType, function, wrapperFunction);
        return this.registerRoot(wrapperFunction.getLibrary(), newFunctionType, decoratedRoot);
    }

    private Object registerRoot(LLVMContext.ExternalLibrary lib, FunctionType newFunctionType, DecoratedRoot decoratedRoot) {
        LLVMFunctionDescriptor.LLVMIRFunction function = new LLVMFunctionDescriptor.LLVMIRFunction(Truffle.getRuntime().createCallTarget((RootNode)decoratedRoot), null);
        LLVMFunctionDescriptor wrappedFunction = new LLVMFunctionDescriptor(this.getContext(), "<wrapper>", newFunctionType, -1, function, lib);
        return LLVMManagedPointer.create(wrappedFunction);
    }

    protected static boolean isForeignFunction(Object object) {
        return object instanceof LLVMTypedForeignObject;
    }

    private LLVMContext getContext() {
        if (this.contextRef == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();
            this.contextRef = this.lookupContextReference(LLVMLanguage.class);
        }
        return (LLVMContext)this.contextRef.get();
    }

    protected static class ForeignDecoratedRoot
    extends DecoratedRoot {
        @Node.Child
        private LLVMDispatchNode funcCallNode;
        @Node.Child
        private DirectCallNode wrapperCallNode;
        private final Object func;

        protected ForeignDecoratedRoot(TruffleLanguage<?> language, FunctionType type, Object func, LLVMFunctionDescriptor wrapper) {
            super(language);
            this.funcCallNode = LLVMDispatchNodeGen.create(type);
            this.func = func;
            this.wrapperCallNode = Truffle.getRuntime().createDirectCallNode((CallTarget)wrapper.getLLVMIRFunctionSlowPath());
            this.wrapperCallNode.cloneCallTarget();
        }

        public Object execute(VirtualFrame frame) {
            Object[] arguments = frame.getArguments();
            Object result = this.funcCallNode.executeDispatch(this.func, arguments);
            Object[] wrapperArgs = new Object[]{arguments[0], result};
            try (LLVMStack.StackPointer sp = ((LLVMStack.StackPointer)arguments[0]).newFrame();){
                Object object = this.wrapperCallNode.call(wrapperArgs);
                return object;
            }
        }
    }

    protected static class NativeDecoratedRoot
    extends DecoratedRoot {
        @Node.Child
        private DirectCallNode funcCallNode;
        @Node.Child
        private DirectCallNode wrapperCallNode;

        protected NativeDecoratedRoot(TruffleLanguage<?> language, LLVMFunctionDescriptor func, LLVMFunctionDescriptor wrapper) {
            super(language);
            this.funcCallNode = Truffle.getRuntime().createDirectCallNode((CallTarget)func.getLLVMIRFunctionSlowPath());
            this.wrapperCallNode = Truffle.getRuntime().createDirectCallNode((CallTarget)wrapper.getLLVMIRFunctionSlowPath());
            this.funcCallNode.cloneCallTarget();
            this.wrapperCallNode.cloneCallTarget();
            this.funcCallNode.forceInlining();
            this.wrapperCallNode.forceInlining();
        }

        public Object execute(VirtualFrame frame) {
            Object result;
            Object[] arguments = frame.getArguments();
            try (LLVMStack.StackPointer sp = ((LLVMStack.StackPointer)arguments[0]).newFrame();){
                result = this.funcCallNode.call(arguments);
            }
            Object[] wrapperArgs = new Object[]{arguments[0], result};
            try (LLVMStack.StackPointer sp = ((LLVMStack.StackPointer)arguments[0]).newFrame();){
                Object object = this.wrapperCallNode.call(wrapperArgs);
                return object;
            }
        }
    }

    protected static abstract class DecoratedRoot
    extends RootNode {
        protected DecoratedRoot(TruffleLanguage<?> language) {
            super(language);
        }
    }
}

