/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.writer;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shaded.com.google.common.collect.Queues;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shaded.org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataPusher
implements Closeable {
    private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);
    private final ExecutorService executorService;
    private final ShuffleWriteClient shuffleWriteClient;
    private final Map<String, Set<Long>> taskToSuccessBlockIds;
    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
    private String rssAppId;
    private final Set<String> failedTaskIds;

    public DataPusher(ShuffleWriteClient shuffleWriteClient, Map<String, Set<Long>> taskToSuccessBlockIds, Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker, Set<String> failedTaskIds, int threadPoolSize, int threadKeepAliveTime) {
        this.shuffleWriteClient = shuffleWriteClient;
        this.taskToSuccessBlockIds = taskToSuccessBlockIds;
        this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
        this.failedTaskIds = failedTaskIds;
        this.executorService = new ThreadPoolExecutor(threadPoolSize, threadPoolSize * 2, (long)threadKeepAliveTime, TimeUnit.SECONDS, Queues.newLinkedBlockingQueue(Integer.MAX_VALUE), ThreadUtils.getThreadFactory(this.getClass().getName()));
    }

    public CompletableFuture<Long> send(AddBlockEvent event) {
        if (this.rssAppId == null) {
            throw new RssException("RssAppId should be set.");
        }
        return CompletableFuture.supplyAsync(() -> {
            Set<Long> succeedBlockIds;
            List<ShuffleBlockInfo> shuffleBlockInfoList;
            String taskId = event.getTaskId();
            List<ShuffleBlockInfo> validBlocks = this.filterOutStaleAssignmentBlocks(taskId, shuffleBlockInfoList = event.getShuffleDataInfoList());
            if (CollectionUtils.isEmpty(validBlocks)) {
                return 0L;
            }
            SendShuffleDataResult result = null;
            try {
                result = this.shuffleWriteClient.sendShuffleData(this.rssAppId, event.getStageAttemptNumber(), validBlocks, () -> !this.isValidTask(taskId));
                this.putBlockId(this.taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
                this.putFailedBlockSendTracker(this.taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
                WriteBufferManager bufferManager = event.getBufferManager();
                if (bufferManager != null && result != null) {
                    ShuffleServerPushCostTracker shuffleServerPushCostTracker = result.getShuffleServerPushCostTracker();
                    bufferManager.merge(shuffleServerPushCostTracker);
                }
                succeedBlockIds = this.getSucceedBlockIds(result);
            }
            catch (Throwable throwable) {
                WriteBufferManager bufferManager = event.getBufferManager();
                if (bufferManager != null && result != null) {
                    ShuffleServerPushCostTracker shuffleServerPushCostTracker = result.getShuffleServerPushCostTracker();
                    bufferManager.merge(shuffleServerPushCostTracker);
                }
                Set<Long> succeedBlockIds2 = this.getSucceedBlockIds(result);
                for (ShuffleBlockInfo block : validBlocks) {
                    block.executeCompletionCallback(succeedBlockIds2.contains(block.getBlockId()));
                }
                List<Runnable> callbackChain = Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
                for (Runnable runnable : callbackChain) {
                    runnable.run();
                }
                throw throwable;
            }
            for (ShuffleBlockInfo block : validBlocks) {
                block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
            }
            List<Runnable> callbackChain = Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
            for (Runnable runnable : callbackChain) {
                runnable.run();
            }
            Set<Long> succeedBlockIds3 = this.getSucceedBlockIds(result);
            return validBlocks.stream().filter(x -> succeedBlockIds3.contains(x.getBlockId())).map(x -> x.getFreeMemory()).reduce((a, b) -> a + b).orElse(0L);
        }, this.executorService).exceptionally(ex -> {
            LOGGER.error("Unexpected exceptions occurred while sending shuffle data", ex);
            return null;
        });
    }

    private List<ShuffleBlockInfo> filterOutStaleAssignmentBlocks(String taskId, List<ShuffleBlockInfo> blocks) {
        FailedBlockSendTracker staleBlockTracker = new FailedBlockSendTracker();
        ArrayList<ShuffleBlockInfo> validBlocks = new ArrayList<ShuffleBlockInfo>();
        for (ShuffleBlockInfo block : blocks) {
            List<ShuffleServerInfo> servers = block.getShuffleServerInfos();
            if (servers == null || servers.size() != 1) {
                validBlocks.add(block);
                continue;
            }
            if (block.isStaleAssignment()) {
                staleBlockTracker.add(block, block.getShuffleServerInfos().get(0), StatusCode.INTERNAL_ERROR);
                continue;
            }
            validBlocks.add(block);
        }
        this.putFailedBlockSendTracker(this.taskToFailedBlockSendTracker, taskId, staleBlockTracker);
        return validBlocks;
    }

    private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {
        if (result == null || result.getSuccessBlockIds() == null) {
            return Collections.emptySet();
        }
        return result.getSuccessBlockIds();
    }

    private synchronized void putBlockId(Map<String, Set<Long>> taskToBlockIds, String taskAttemptId, Set<Long> blockIds) {
        if (blockIds == null || blockIds.isEmpty()) {
            return;
        }
        taskToBlockIds.computeIfAbsent(taskAttemptId, x -> Sets.newConcurrentHashSet()).addAll(blockIds);
    }

    private synchronized void putFailedBlockSendTracker(Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker, String taskAttemptId, FailedBlockSendTracker failedBlockSendTracker) {
        if (failedBlockSendTracker == null || failedBlockSendTracker.isEmpty()) {
            return;
        }
        taskToFailedBlockSendTracker.computeIfAbsent(taskAttemptId, x -> new FailedBlockSendTracker()).merge(failedBlockSendTracker);
    }

    public boolean isValidTask(String taskId) {
        return !this.failedTaskIds.contains(taskId);
    }

    public void setRssAppId(String rssAppId) {
        this.rssAppId = rssAppId;
    }

    @Override
    public void close() throws IOException {
        if (this.executorService != null) {
            try {
                ThreadUtils.shutdownThreadPool(this.executorService, 5);
            }
            catch (InterruptedException interruptedException) {
                LOGGER.error("Errors on shutdown thread pool of [{}].", (Object)this.getClass().getSimpleName());
            }
        }
    }
}

