/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.server.merge;

import io.netty.buffer.ByteBuf;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.merger.MergeState;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.serializer.SerOutputStream;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.merge.BlockFlushFileReader;
import org.apache.uniffle.server.merge.DefaultMergeEventHandler;
import org.apache.uniffle.server.merge.MergeEvent;
import org.apache.uniffle.server.merge.MergeEventHandler;
import org.apache.uniffle.server.merge.MergeStatus;
import org.apache.uniffle.server.merge.Partition;
import org.apache.uniffle.server.merge.Shuffle;
import org.apache.uniffle.shaded.guava.annotations.VisibleForTesting;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ShuffleMergeManager {
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleMergeManager.class);

    /** Suffix appended to an application id by {@code refreshAppId} before it is handed to the shuffle task manager. */
    public static final String MERGE_APP_SUFFIX = "@RemoteMerge";

    /** Server configuration; read for the class loader jar paths and passed on to shuffles, partitions and the event handler. */
    private ShuffleServerConf serverConf;

    /** Owning shuffle server; used to reach the shuffle task manager and passed to every created {@code Shuffle}. */
    private final ShuffleServer shuffleServer;

    /** Registered shuffles, keyed by application id and then by shuffle id. */
    private final Map<String, Map<Integer, Shuffle>> shuffles = JavaUtils.newConcurrentMap();

    /** Handler that receives merge events; constructed with {@code processEvent} as its callback. */
    private final MergeEventHandler eventHandler;

    /** Class loaders keyed by label; the empty-string label holds the default loader. */
    private final Map<String, ClassLoader> cachedClassLoader = new HashMap<String, ClassLoader>();

    /**
     * Fallback comparator used when a shuffle is registered without a comparator class.
     * Orders objects by {@code hashCode()}, treating {@code null} as hash 0.
     */
    private Comparator defaultComparator = new Comparator(){

        @Override
        public int compare(Object o1, Object o2) {
            int leftHash = o1 == null ? 0 : o1.hashCode();
            int rightHash = o2 == null ? 0 : o2.hashCode();
            return Integer.compare(leftHash, rightHash);
        }
    };

    /**
     * Creates the manager, builds the default merge event handler with
     * {@link #processEvent(MergeEvent)} as its callback, and fills the class loader cache.
     *
     * @param serverConf server configuration
     * @param shuffleServer the shuffle server this manager belongs to
     * @throws Exception if initializing the class loader cache fails
     */
    public ShuffleMergeManager(ShuffleServerConf serverConf, ShuffleServer shuffleServer) throws Exception {
        this.shuffleServer = shuffleServer;
        this.serverConf = serverConf;
        this.eventHandler = new DefaultMergeEventHandler(serverConf, this::processEvent);
        initCacheClassLoader();
    }

    /**
     * Populates the class loader cache from configuration.
     *
     * <p>The value of {@code SERVER_MERGE_CLASS_LOADER_JARS_PATH} is registered under the
     * empty label (the default loader). Every property returned for the prefix
     * {@code <key>.} is then registered under that property's key.
     * NOTE(review): whether {@code getPropsWithPrefix} strips the prefix from the returned
     * keys is not visible here - confirm against its implementation.
     *
     * @throws Exception if creating any class loader fails
     */
    public void initCacheClassLoader() throws Exception {
        String defaultJarsPath = this.serverConf.getString(ShuffleServerConf.SERVER_MERGE_CLASS_LOADER_JARS_PATH);
        this.addCacheClassLoader("", defaultJarsPath);

        String labelPrefix = ShuffleServerConf.SERVER_MERGE_CLASS_LOADER_JARS_PATH.key() + ".";
        Map<?, ?> labelledPaths = this.serverConf.getPropsWithPrefix(labelPrefix);
        for (Map.Entry<?, ?> entry : labelledPaths.entrySet()) {
            this.addCacheClassLoader((String) entry.getKey(), (String) entry.getValue());
        }
    }

    /**
     * Creates a class loader for {@code jarsPath} and caches it under {@code label}.
     *
     * <ul>
     *   <li>Blank path: the current thread's context class loader is cached.</li>
     *   <li>Path does not exist: nothing is cached and a warning is logged.</li>
     *   <li>Regular file: a {@link URLClassLoader} over that single file is cached.</li>
     *   <li>Directory: a {@link URLClassLoader} over every directly contained file whose
     *       name ends with {@code .jar} is cached.</li>
     *   <li>Anything else: the current thread's context class loader is cached.</li>
     * </ul>
     *
     * <p>URLs are derived with {@link File#toURI()} so that relative paths and paths
     * containing spaces or other reserved characters yield valid {@code file:} URLs.
     *
     * @param label key under which the loader is cached; the empty string is the default loader
     * @param jarsPath path to a jar file or to a directory of jar files; may be blank
     * @throws Exception if a URL cannot be built or the class loader cannot be created
     */
    public void addCacheClassLoader(String label, final String jarsPath) throws Exception {
        if (StringUtils.isBlank(jarsPath)) {
            this.cachedClassLoader.put(label, Thread.currentThread().getContextClassLoader());
            return;
        }
        File jarsPathFile = new File(jarsPath);
        if (!jarsPathFile.exists()) {
            // Nothing is cached for this label; make the misconfiguration visible.
            LOG.warn("Jars path {} for class loader label '{}' does not exist, no class loader is cached", jarsPath, label);
            return;
        }
        if (jarsPathFile.isFile()) {
            URL[] urls = new URL[]{jarsPathFile.toURI().toURL()};
            this.cachedClassLoader.put(label, newUrlClassLoader(urls));
        } else if (jarsPathFile.isDirectory()) {
            List<URL> urlList = new ArrayList<URL>();
            File[] files = jarsPathFile.listFiles();
            // listFiles() returns null on I/O error; an empty loader is cached in that case.
            if (files != null) {
                for (File file : files) {
                    if (file.getName().endsWith(".jar")) {
                        urlList.add(file.toURI().toURL());
                    }
                }
            }
            this.cachedClassLoader.put(label, newUrlClassLoader(urlList.toArray(new URL[0])));
        } else {
            this.cachedClassLoader.put(label, Thread.currentThread().getContextClassLoader());
        }
    }

    /** Creates, inside a privileged action, a URLClassLoader over {@code urls} whose parent is the current context class loader. */
    private static URLClassLoader newUrlClassLoader(final URL[] urls) throws Exception {
        return AccessController.doPrivileged(new PrivilegedExceptionAction<URLClassLoader>(){

            @Override
            public URLClassLoader run() throws Exception {
                return new URLClassLoader(urls, Thread.currentThread().getContextClassLoader());
            }
        });
    }

    /**
     * Looks up the cached class loader for {@code label}.
     *
     * @param label class loader label; blank selects the default loader
     * @return the loader cached under {@code label}, or the loader cached under the empty
     *     label when {@code label} is blank or has no entry; may be {@code null} if
     *     nothing is cached
     */
    public ClassLoader getClassLoader(String label) {
        ClassLoader defaultLoader = this.cachedClassLoader.get("");
        return StringUtils.isBlank(label)
                ? defaultLoader
                : this.cachedClassLoader.getOrDefault(label, defaultLoader);
    }

    /**
     * Registers a shuffle for remote merge.
     *
     * <p>Resolves the key class, value class and (optionally) the comparator class named in
     * {@code mergeContext} through the class loader selected by its merge class loader label,
     * then creates the {@code Shuffle} if none is registered yet for this app and shuffle id.
     * When no comparator class is given, the hash-based default comparator is used.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param mergeContext merge settings: class names, class loader label, merged block size
     * @return {@code StatusCode.SUCCESS} on success; {@code StatusCode.INTERNAL_ERROR} if a
     *     class cannot be loaded or the comparator cannot be instantiated, in which case any
     *     buffer registered for this shuffle is removed
     */
    public StatusCode registerShuffle(String appId, int shuffleId, RssProtos.MergeContext mergeContext) {
        try {
            Comparator comparator;
            ClassLoader classLoader = this.getClassLoader(mergeContext.getMergeClassLoader());
            Class kClass = ClassUtils.getClass(classLoader, mergeContext.getKeyClass());
            Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass());
            if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) {
                Constructor<?> constructor = ClassUtils.getClass(classLoader, mergeContext.getComparatorClass()).getDeclaredConstructor();
                // The comparator's no-arg constructor may be non-public.
                constructor.setAccessible(true);
                comparator = (Comparator)constructor.newInstance();
            } else {
                comparator = this.defaultComparator;
            }
            // Chain on the map returned by computeIfAbsent rather than looking the app up
            // again: a concurrent removeBuffer could drop the entry between the two calls.
            this.shuffles
                    .computeIfAbsent(appId, key -> JavaUtils.newConcurrentMap())
                    .computeIfAbsent(shuffleId, key -> new Shuffle((RssConf)this.serverConf, this.eventHandler, this.shuffleServer, appId, shuffleId, kClass, vClass, comparator, mergeContext.getMergedBlockSize(), classLoader));
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
            // Registration failed and the caller receives INTERNAL_ERROR, so log above INFO.
            LOG.warn("Cant register shuffle, caused by ", (Throwable)e);
            this.removeBuffer(appId, shuffleId);
            return StatusCode.INTERNAL_ERROR;
        }
        return StatusCode.SUCCESS;
    }

    /**
     * Removes every shuffle registered for {@code appId}.
     *
     * <p>The per-app map is looked up once, so a concurrent removal of the app cannot cause
     * a {@code NullPointerException} between the existence check and the iteration.
     *
     * @param appId application id; unknown ids are ignored
     */
    public void removeBuffer(String appId) {
        Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
        if (shuffleMap == null) {
            return;
        }
        for (Integer shuffleId : shuffleMap.keySet()) {
            this.removeBuffer(appId, shuffleId);
        }
    }

    /**
     * Removes the given shuffles of {@code appId}.
     *
     * @param appId application id; nothing happens when it has no registered shuffles
     * @param shuffleIds ids of the shuffles to remove
     */
    public void removeBuffer(String appId, List<Integer> shuffleIds) {
        if (!this.shuffles.containsKey(appId)) {
            return;
        }
        for (Integer id : shuffleIds) {
            this.removeBuffer(appId, id);
        }
    }

    /**
     * Removes one shuffle of {@code appId}, cleans it up, and drops the app entry once it
     * has no shuffles left.
     *
     * <p>The shuffle is detached from the map before {@code cleanup()} is called, so when
     * several threads remove the same shuffle only the one that actually detached it runs
     * the cleanup. All lookups are done once to avoid a {@code NullPointerException} when
     * the app entry disappears concurrently.
     *
     * @param appId application id; unknown ids are ignored
     * @param shuffleId shuffle id; unknown ids are ignored
     */
    public void removeBuffer(String appId, int shuffleId) {
        Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
        if (shuffleMap == null) {
            return;
        }
        Shuffle removed = shuffleMap.remove(shuffleId);
        if (removed != null) {
            removed.cleanup();
        }
        // Returning null from the remapping function removes the app entry; the emptiness
        // check and the removal happen as one map operation.
        this.shuffles.computeIfPresent(appId, (key, current) -> current.isEmpty() ? null : current);
    }

    /**
     * Starts the sort-merge of one partition, if its shuffle is registered.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param partitionId partition to merge
     * @param expectedBlockIdMap block ids forwarded to the shuffle's {@code startSortMerge}
     * @throws IOException if the shuffle's {@code startSortMerge} fails
     */
    public void startSortMerge(String appId, int shuffleId, int partitionId, Roaring64NavigableMap expectedBlockIdMap) throws IOException {
        Map<Integer, Shuffle> perApp = this.shuffles.get(appId);
        if (perApp == null) {
            return;
        }
        Shuffle target = perApp.get(shuffleId);
        if (target == null) {
            return;
        }
        target.startSortMerge(partitionId, expectedBlockIdMap);
    }

    /**
     * Handles one merge event: collects the blocks of the event's partition, builds the
     * input segments, and merges them into a newly created output stream.
     *
     * <p>While the event is processed, the thread's context class loader is set to the
     * class loader of the event's shuffle. The previous loader is restored afterwards so
     * that the worker thread does not keep referencing the shuffle's loader.
     *
     * <p>Failures are logged and not rethrown. Whenever a partition was found but the merge
     * did not complete (missing segments or any throwable), the partition state is set to
     * {@code MergeState.INTERNAL_ERROR}. Every buffer placed in the cached-block map is
     * released at the end.
     *
     * @param event the merge event to process
     */
    public void processEvent(MergeEvent event) {
        boolean success = false;
        Partition partition = null;
        HashMap<Long, ByteBuf> cachedBlocks = new HashMap<Long, ByteBuf>();
        Thread worker = Thread.currentThread();
        ClassLoader previousClassLoader = worker.getContextClassLoader();
        try {
            worker.setContextClassLoader(this.getShuffle(event.getAppId(), event.getShuffleId()).getClassLoader());
            partition = this.getPartition(event.getAppId(), event.getShuffleId(), event.getPartitionId());
            if (partition == null) {
                LOG.info("Can not find partition for event: {}", event);
                return;
            }
            boolean allCached = partition.collectBlocks(event.getExpectedBlockIdMap().iterator(), cachedBlocks);
            // A file reader is only created when some expected blocks were not cached.
            BlockFlushFileReader reader = null;
            if (!allCached) {
                reader = partition.createReader((RssConf)this.serverConf);
            }
            ArrayList<Segment> segments = new ArrayList<Segment>();
            boolean allFound = partition.collectSegments((RssConf)this.serverConf, event.getExpectedBlockIdMap().iterator(), event.getKeyClass(), event.getValueClass(), cachedBlocks, segments, reader);
            if (!allFound) {
                // success stays false, so the finally block marks the partition as failed.
                return;
            }
            long totalBytes = segments.stream().mapToLong(segment -> segment.getSize()).sum();
            SerOutputStream output = partition.createSerOutputStream(totalBytes);
            partition.merge(segments, output, reader);
            success = true;
        }
        catch (Throwable e) {
            LOG.error("Merge failed, caused by ", e);
        }
        finally {
            try {
                if (!success && partition != null) {
                    partition.setState(MergeState.INTERNAL_ERROR);
                }
                cachedBlocks.values().forEach(byteBuf -> byteBuf.release());
            }
            finally {
                // Restore even if the cleanup above throws.
                worker.setContextClassLoader(previousClassLoader);
            }
        }
    }

    /**
     * Fetches the data of one block from a partition.
     *
     * <p>Throws {@code NullPointerException} when no partition is found, because the
     * lookup result is dereferenced without a null check.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param partitionId partition id
     * @param blockId id of the block to read
     * @return the result of the partition's {@code getShuffleData(blockId)}
     * @throws IOException if the partition's {@code getShuffleData} fails
     */
    public ShuffleDataResult getShuffleData(String appId, int shuffleId, int partitionId, long blockId) throws IOException {
        Partition source = this.getPartition(appId, shuffleId, partitionId);
        return source.getShuffleData(blockId);
    }

    /**
     * Forwards the {@code direct} flag to a registered shuffle; does nothing when the app
     * or the shuffle is unknown.
     *
     * <p>The shuffle is resolved with single lookups so that a concurrent removal between
     * an existence check and the access cannot cause a {@code NullPointerException}.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param direct value passed to the shuffle's {@code setDirect}
     * @throws IOException if the shuffle's {@code setDirect} fails
     */
    public void setDirect(String appId, int shuffleId, boolean direct) throws IOException {
        Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
        if (shuffleMap == null) {
            return;
        }
        Shuffle shuffle = shuffleMap.get(shuffleId);
        if (shuffle != null) {
            shuffle.setDirect(direct);
        }
    }

    /**
     * Delegates to the partition's {@code tryGetBlock(blockId)}.
     *
     * <p>Throws {@code NullPointerException} when no partition is found, because the
     * lookup result is dereferenced without a null check.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param partitionId partition id
     * @param blockId id of the requested block
     * @return the status returned by the partition
     */
    public MergeStatus tryGetBlock(String appId, int shuffleId, int partitionId, long blockId) {
        Partition source = this.getPartition(appId, shuffleId, partitionId);
        return source.tryGetBlock(blockId);
    }

    /**
     * Exposes the merge event handler for tests.
     *
     * @return the handler created in the constructor
     */
    @VisibleForTesting
    MergeEventHandler getEventHandler() {
        return eventHandler;
    }

    /**
     * Looks up a registered shuffle.
     *
     * <p>An unknown app id yields {@code null}, the same as an unknown shuffle id of a
     * known app, instead of a {@code NullPointerException}.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @return the shuffle, or {@code null} if the app or the shuffle is not registered
     */
    Shuffle getShuffle(String appId, int shuffleId) {
        Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
        return shuffleMap == null ? null : shuffleMap.get(shuffleId);
    }

    /**
     * Looks up a partition of a registered shuffle.
     *
     * <p>The app and the shuffle are each resolved with a single lookup, so a concurrent
     * removal between an existence check and the access cannot cause a
     * {@code NullPointerException}.
     *
     * @param appId application id
     * @param shuffleId shuffle id
     * @param partitionId partition id
     * @return the result of the shuffle's {@code getPartition(partitionId)}, or
     *     {@code null} if the app or the shuffle is not registered
     */
    @VisibleForTesting
    Partition getPartition(String appId, int shuffleId, int partitionId) {
        Map<Integer, Shuffle> shuffleMap = this.shuffles.get(appId);
        if (shuffleMap == null) {
            return null;
        }
        Shuffle shuffle = shuffleMap.get(shuffleId);
        return shuffle == null ? null : shuffle.getPartition(partitionId);
    }

    /**
     * Refreshes the merge variant of an application id (the id followed by
     * {@code MERGE_APP_SUFFIX}) on the shuffle task manager.
     *
     * @param appId application id without the merge suffix
     */
    public void refreshAppId(String appId) {
        String mergeAppId = appId + MERGE_APP_SUFFIX;
        this.shuffleServer.getShuffleTaskManager().refreshAppId(mergeAppId);
    }
}

