package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link VertexParallelismAndInputInfosDecider}.
 *
 * <p>A vertex that consumes no blocking results gets the configured default source parallelism,
 * capped by the vertex's maximum parallelism. For other vertices whose parallelism is not yet
 * decided, the parallelism is derived from the bytes produced by their non-broadcast inputs
 * divided by the configured data volume per task, then clamped into the applicable
 * {@code [min, max]} range.
 *
 * <p>When every input is all-to-all and not every input is broadcast, subpartitions are grouped
 * into ranges by data volume. Otherwise, and as a fallback, the input infos are computed by
 * {@link VertexInputInfoComputationUtils#computeVertexInputInfos}.
 */
public class DefaultVertexParallelismAndInputInfosDecider
        implements VertexParallelismAndInputInfosDecider {

    private static final Logger LOG =
            LoggerFactory.getLogger(DefaultVertexParallelismAndInputInfosDecider.class);

    /**
     * Upper bound on the number of subpartitions one task consumes. It is used both as a lower
     * bound on the decided parallelism and to cap the size of a subpartition range.
     */
    private static final int MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME = 32768;

    /** Initial-parallelism value meaning "not decided yet"; any other accepted value is > 0. */
    private static final int UNDECIDED_PARALLELISM = -1;

    private final int globalMaxParallelism;
    private final int globalMinParallelism;
    private final long dataVolumePerTask;
    private final int globalDefaultSourceParallelism;

    /**
     * Creates a decider with validated global bounds.
     *
     * @param globalMaxParallelism upper bound for decided parallelism; must be >= the minimum
     * @param globalMinParallelism lower bound for decided parallelism; must be > 0
     * @param dataVolumePerTask target number of bytes per task; must not be null
     * @param globalDefaultSourceParallelism parallelism for vertices without inputs; must be > 0
     */
    private DefaultVertexParallelismAndInputInfosDecider(
            int globalMaxParallelism,
            int globalMinParallelism,
            MemorySize dataVolumePerTask,
            int globalDefaultSourceParallelism) {
        Preconditions.checkArgument(
                globalMinParallelism > 0, "The minimum parallelism must be larger than 0.");
        Preconditions.checkArgument(
                globalMaxParallelism >= globalMinParallelism,
                "Maximum parallelism should be greater than or equal to the minimum parallelism.");
        Preconditions.checkArgument(
                globalDefaultSourceParallelism > 0,
                "The default source parallelism must be larger than 0.");
        Preconditions.checkNotNull(dataVolumePerTask);
        this.globalMaxParallelism = globalMaxParallelism;
        this.globalMinParallelism = globalMinParallelism;
        this.dataVolumePerTask = dataVolumePerTask.getBytes();
        this.globalDefaultSourceParallelism = globalDefaultSourceParallelism;
    }

    /**
     * Decides the parallelism of a job vertex and how its consumed results are split among the
     * vertex's subtasks.
     *
     * @param jobVertexId the vertex being decided (used for logging)
     * @param consumedResults the blocking results consumed by the vertex; empty for a source
     * @param vertexInitialParallelism {@code -1} if undecided, otherwise a fixed parallelism > 0
     * @param vertexMaxParallelism the vertex's max parallelism; > 0 and >= the initial parallelism
     * @return the decided parallelism together with the per-result input infos
     */
    @Override // org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismAndInputInfosDecider
    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int vertexInitialParallelism,
            int vertexMaxParallelism) {
        Preconditions.checkArgument(
                vertexInitialParallelism == UNDECIDED_PARALLELISM || vertexInitialParallelism > 0);
        Preconditions.checkArgument(
                vertexMaxParallelism > 0 && vertexMaxParallelism >= vertexInitialParallelism);

        if (consumedResults.isEmpty()) {
            // No inputs: only the parallelism has to be chosen, there is nothing to distribute.
            int parallelism =
                    vertexInitialParallelism > 0
                            ? vertexInitialParallelism
                            : computeSourceParallelism(jobVertexId, vertexMaxParallelism);
            return new ParallelismAndInputInfos(parallelism, Collections.emptyMap());
        }

        boolean parallelismUndecided = vertexInitialParallelism == UNDECIDED_PARALLELISM;
        int minParallelism = globalMinParallelism;
        int maxParallelism = globalMaxParallelism;
        // The vertex's own max parallelism narrows the global bounds when we decide ourselves.
        if (parallelismUndecided && vertexMaxParallelism < minParallelism) {
            LOG.info(
                    "The vertex maximum parallelism {} is smaller than the global minimum parallelism {}. Use {} as the lower bound to decide parallelism of job vertex {}.",
                    vertexMaxParallelism,
                    minParallelism,
                    vertexMaxParallelism,
                    jobVertexId);
            minParallelism = vertexMaxParallelism;
        }
        if (parallelismUndecided && vertexMaxParallelism < maxParallelism) {
            LOG.info(
                    "The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. Use {} as the upper bound to decide parallelism of job vertex {}.",
                    vertexMaxParallelism,
                    maxParallelism,
                    vertexMaxParallelism,
                    jobVertexId);
            maxParallelism = vertexMaxParallelism;
        }
        Preconditions.checkState(maxParallelism >= minParallelism);

        if (parallelismUndecided
                && areAllInputsAllToAll(consumedResults)
                && !areAllInputsBroadcast(consumedResults)) {
            return decideParallelismAndEvenlyDistributeData(
                    jobVertexId,
                    consumedResults,
                    vertexInitialParallelism,
                    minParallelism,
                    maxParallelism);
        }
        return decideParallelismAndEvenlyDistributeSubpartitions(
                jobVertexId, consumedResults, vertexInitialParallelism, minParallelism, maxParallelism);
    }

    /** Returns the global default source parallelism, capped at {@code maxParallelism}. */
    private int computeSourceParallelism(JobVertexID jobVertexId, int maxParallelism) {
        if (globalDefaultSourceParallelism <= maxParallelism) {
            return globalDefaultSourceParallelism;
        }
        LOG.info(
                "The global default source parallelism {} is larger than the maximum parallelism {}. Use {} as the parallelism of source job vertex {}.",
                globalDefaultSourceParallelism,
                maxParallelism,
                maxParallelism,
                jobVertexId);
        return maxParallelism;
    }

    /** True if no consumed result is pointwise. */
    private static boolean areAllInputsAllToAll(List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream().noneMatch(BlockingResultInfo::isPointwise);
    }

    /** True if every consumed result is broadcast. */
    private static boolean areAllInputsBroadcast(List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream().allMatch(BlockingResultInfo::isBroadcast);
    }

    /**
     * Uses the initial parallelism if set (otherwise decides one from data volume) and delegates
     * the input split to {@link VertexInputInfoComputationUtils#computeVertexInputInfos}.
     */
    private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeSubpartitions(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int initialParallelism,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        int parallelism =
                initialParallelism > 0
                        ? initialParallelism
                        : decideParallelism(
                                jobVertexId, consumedResults, minParallelism, maxParallelism);
        return new ParallelismAndInputInfos(
                parallelism,
                VertexInputInfoComputationUtils.computeVertexInputInfos(
                        parallelism, consumedResults, true));
    }

    /**
     * Decides a parallelism from the total bytes of the non-broadcast inputs, clamped into
     * {@code [minParallelism, maxParallelism]}.
     *
     * <p>If every input is broadcast, {@code minParallelism} is returned. Otherwise the result is
     * the larger of {@code ceil(totalBytes / dataVolumePerTask)} and
     * {@code ceil(maxSubpartitions / MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME)}, then clamped.
     */
    int decideParallelism(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        // Only non-broadcast inputs contribute to the data volume that drives the parallelism.
        List<BlockingResultInfo> nonBroadcastResults = getNonBroadcastResultInfos(consumedResults);
        if (nonBroadcastResults.isEmpty()) {
            return minParallelism;
        }

        long totalBytes =
                nonBroadcastResults.stream()
                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
                        .sum();
        // Divide in floating point: integer division would truncate before ceil() and round the
        // parallelism down instead of up.
        int parallelismByDataVolume = (int) Math.ceil((double) totalBytes / dataVolumePerTask);
        int parallelismBySubpartitions =
                (int)
                        Math.ceil(
                                (double) getMaxNumSubpartitions(nonBroadcastResults)
                                        / MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME);
        int parallelism = Math.max(parallelismByDataVolume, parallelismBySubpartitions);
        LOG.debug(
                "The total size of non-broadcast data is {}, the initially decided parallelism of job vertex {} is {}.",
                new MemorySize(totalBytes),
                jobVertexId,
                parallelism);

        if (parallelism < minParallelism) {
            LOG.info(
                    "The initially decided parallelism {} is smaller than the minimum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.",
                    parallelism,
                    minParallelism,
                    minParallelism,
                    jobVertexId);
            return minParallelism;
        }
        if (parallelism > maxParallelism) {
            LOG.info(
                    "The initially decided parallelism {} is larger than the maximum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.",
                    parallelism,
                    maxParallelism,
                    maxParallelism,
                    jobVertexId);
            return maxParallelism;
        }
        return parallelism;
    }

    /**
     * Groups subpartitions into consecutive ranges by aggregated byte size so each subtask reads
     * roughly {@code dataVolumePerTask} bytes.
     *
     * <p>If the resulting number of ranges lies outside {@code [minParallelism, maxParallelism]},
     * the byte limit is adjusted by bisection. If no legal parallelism is found, this falls back to
     * {@link #decideParallelismAndEvenlyDistributeSubpartitions}. Requires all inputs to be
     * all-to-all, and all non-broadcast inputs to have the same subpartition count.
     */
    private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeData(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int initialParallelism,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(initialParallelism == UNDECIDED_PARALLELISM);
        Preconditions.checkArgument(!consumedResults.isEmpty());
        consumedResults.forEach(
                resultInfo -> Preconditions.checkState(!resultInfo.isPointwise()));

        List<BlockingResultInfo> nonBroadcastResults = getNonBroadcastResultInfos(consumedResults);
        int numSubpartitions = checkAndGetSubpartitionNum(nonBroadcastResults);

        // Sum, per subpartition index, the bytes of all non-broadcast inputs.
        long[] bytesBySubpartition = new long[numSubpartitions];
        for (BlockingResultInfo resultInfo : nonBroadcastResults) {
            List<Long> subpartitionBytes =
                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
            for (int i = 0; i < numSubpartitions; i++) {
                bytesBySubpartition[i] += subpartitionBytes.get(i);
            }
        }

        // A range of k subpartitions makes a task read k subpartitions from each of up to
        // maxNumPartitions partitions, so cap k to respect the per-task subpartition limit.
        int maxRangeSize =
                MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME / getMaxNumPartitions(nonBroadcastResults);
        List<IndexRange> subpartitionRanges =
                computeSubpartitionRanges(bytesBySubpartition, dataVolumePerTask, maxRangeSize);

        if (!isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism)) {
            Optional<List<IndexRange>> adjustedRanges =
                    adjustToClosestLegalParallelism(
                            dataVolumePerTask,
                            subpartitionRanges.size(),
                            minParallelism,
                            maxParallelism,
                            Arrays.stream(bytesBySubpartition).min().getAsLong(),
                            Arrays.stream(bytesBySubpartition).sum(),
                            limit -> computeParallelism(bytesBySubpartition, limit, maxRangeSize),
                            limit ->
                                    computeSubpartitionRanges(
                                            bytesBySubpartition, limit, maxRangeSize));
            if (!adjustedRanges.isPresent()) {
                LOG.info(
                        "Cannot find a legal parallelism to evenly distribute data for job vertex {}. Fall back to compute a parallelism that can evenly distribute subpartitions.",
                        jobVertexId);
                return decideParallelismAndEvenlyDistributeSubpartitions(
                        jobVertexId,
                        consumedResults,
                        initialParallelism,
                        minParallelism,
                        maxParallelism);
            }
            subpartitionRanges = adjustedRanges.get();
        }

        Preconditions.checkState(
                isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism));
        return createParallelismAndInputInfos(consumedResults, subpartitionRanges);
    }

    /** True if {@code parallelism} lies within {@code [minParallelism, maxParallelism]}. */
    private static boolean isLegalParallelism(
            int parallelism, int minParallelism, int maxParallelism) {
        return parallelism >= minParallelism && parallelism <= maxParallelism;
    }

    /**
     * Returns the subpartition count shared by every partition of the given results.
     *
     * @throws IllegalStateException if the partitions do not all have exactly one common count
     */
    private static int checkAndGetSubpartitionNum(List<BlockingResultInfo> consumedResults) {
        Set<Integer> subpartitionNums =
                consumedResults.stream()
                        .flatMap(
                                resultInfo ->
                                        IntStream.range(0, resultInfo.getNumPartitions())
                                                .boxed()
                                                .map(resultInfo::getNumSubpartitions))
                        .collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNums.size() == 1);
        return subpartitionNums.iterator().next();
    }

    /**
     * Searches for a per-range byte limit whose parallelism is legal.
     *
     * <ul>
     *   <li>If the current parallelism is too small, the largest limit in
     *       {@code [minLimit, currentLimit]} still yielding at least {@code minParallelism} is
     *       found. Then the smallest limit giving that same parallelism is chosen.
     *   <li>If the current parallelism is too large, the smallest limit in
     *       {@code [currentLimit, maxLimit]} yielding at most {@code maxParallelism} is chosen.
     * </ul>
     *
     * @return the ranges for the adjusted limit, or empty if that limit is still illegal
     */
    private static Optional<List<IndexRange>> adjustToClosestLegalParallelism(
            long currentDataVolumeLimit,
            int currentParallelism,
            int minParallelism,
            int maxParallelism,
            long minLimit,
            long maxLimit,
            Function<Long, Integer> parallelismComputer,
            Function<Long, List<IndexRange>> subpartitionRangesComputer) {
        long adjustedLimit = currentDataVolumeLimit;
        if (currentParallelism < minParallelism) {
            long maxLegalLimit =
                    BisectionSearchUtils.findMaxLegalValue(
                            limit -> parallelismComputer.apply(limit) >= minParallelism,
                            minLimit,
                            currentDataVolumeLimit);
            int closestParallelism = parallelismComputer.apply(maxLegalLimit);
            // Several limits can yield the same parallelism; the smallest one is picked.
            adjustedLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            limit -> parallelismComputer.apply(limit) == closestParallelism,
                            minLimit,
                            maxLegalLimit);
        } else if (currentParallelism > maxParallelism) {
            adjustedLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            limit -> parallelismComputer.apply(limit) <= maxParallelism,
                            currentDataVolumeLimit,
                            maxLimit);
        }

        int adjustedParallelism = parallelismComputer.apply(adjustedLimit);
        return isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)
                ? Optional.of(subpartitionRangesComputer.apply(adjustedLimit))
                : Optional.empty();
    }

    /**
     * Builds the result with one subtask per subpartition range.
     *
     * <p>Each subtask reads every partition of each consumed result. Broadcast results always
     * contribute subpartition range {@code [0, 0]}.
     */
    private static ParallelismAndInputInfos createParallelismAndInputInfos(
            List<BlockingResultInfo> consumedResults, List<IndexRange> subpartitionRanges) {
        return new ParallelismAndInputInfos(
                subpartitionRanges.size(),
                consumedResults.stream()
                        .collect(
                                Collectors.toMap(
                                        BlockingResultInfo::getResultId,
                                        resultInfo ->
                                                createJobVertexInputInfo(
                                                        resultInfo, subpartitionRanges),
                                        // Same as HashMap#put: a later entry for a duplicate
                                        // result id replaces the earlier one.
                                        (previous, latest) -> latest,
                                        HashMap::new)));
    }

    /** Builds the per-subtask input infos of one consumed result. */
    private static JobVertexInputInfo createJobVertexInputInfo(
            BlockingResultInfo resultInfo, List<IndexRange> subpartitionRanges) {
        IndexRange partitionRange = new IndexRange(0, resultInfo.getNumPartitions() - 1);
        List<ExecutionVertexInputInfo> executionVertexInputInfos =
                new ArrayList<>(subpartitionRanges.size());
        for (int subtaskIndex = 0; subtaskIndex < subpartitionRanges.size(); subtaskIndex++) {
            IndexRange subpartitionRange =
                    resultInfo.isBroadcast()
                            ? new IndexRange(0, 0)
                            : subpartitionRanges.get(subtaskIndex);
            executionVertexInputInfos.add(
                    new ExecutionVertexInputInfo(subtaskIndex, partitionRange, subpartitionRange));
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    /**
     * Greedily splits consecutive subpartitions into ranges.
     *
     * <p>A subpartition joins the current range while the range's byte sum stays within
     * {@code limit} and the range holds at most {@code maxRangeSize} subpartitions. Otherwise a new
     * range starts. Every range holds at least one subpartition, even if that subpartition alone
     * exceeds the limit.
     *
     * @param bytesBySubpartition aggregated bytes per subpartition index
     * @param limit the maximum bytes per range (soft: a single subpartition may exceed it)
     * @param maxRangeSize the maximum number of subpartitions per range
     * @return the ranges in ascending subpartition order
     */
    public static List<IndexRange> computeSubpartitionRanges(
            long[] bytesBySubpartition, long limit, int maxRangeSize) {
        List<IndexRange> subpartitionRanges = new ArrayList<>();
        long rangeBytes = 0;
        int rangeStart = 0;
        for (int i = 0; i < bytesBySubpartition.length; i++) {
            long bytes = bytesBySubpartition[i];
            if (i == rangeStart
                    || (rangeBytes + bytes <= limit && i - rangeStart + 1 <= maxRangeSize)) {
                rangeBytes += bytes;
            } else {
                subpartitionRanges.add(new IndexRange(rangeStart, i - 1));
                rangeStart = i;
                rangeBytes = bytes;
            }
        }
        subpartitionRanges.add(new IndexRange(rangeStart, bytesBySubpartition.length - 1));
        return subpartitionRanges;
    }

    /**
     * Returns the number of ranges {@link #computeSubpartitionRanges} would produce, without
     * allocating them. This is cheap enough to call repeatedly during the bisection search.
     * The loop must stay in sync with {@link #computeSubpartitionRanges}.
     */
    private static int computeParallelism(
            long[] bytesBySubpartition, long limit, int maxRangeSize) {
        long rangeBytes = 0;
        int rangeStart = 0;
        int numRanges = 1;
        for (int i = 0; i < bytesBySubpartition.length; i++) {
            long bytes = bytesBySubpartition[i];
            if (i == rangeStart
                    || (rangeBytes + bytes <= limit && i - rangeStart + 1 <= maxRangeSize)) {
                rangeBytes += bytes;
            } else {
                rangeStart = i;
                rangeBytes = bytes;
                numRanges++;
            }
        }
        return numRanges;
    }

    /** Largest {@code getNumPartitions()} among the given, non-empty results. */
    private static int getMaxNumPartitions(List<BlockingResultInfo> consumedResults) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        return consumedResults.stream()
                .mapToInt(BlockingResultInfo::getNumPartitions)
                .max()
                .getAsInt();
    }

    /** Largest total subpartition count (summed over partitions) among the given results. */
    private static int getMaxNumSubpartitions(List<BlockingResultInfo> consumedResults) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        return consumedResults.stream()
                .mapToInt(
                        resultInfo ->
                                IntStream.range(0, resultInfo.getNumPartitions())
                                        .map(resultInfo::getNumSubpartitions)
                                        .sum())
                .max()
                .getAsInt();
    }

    /** Returns a new list with the non-broadcast results, in their original order. */
    private static List<BlockingResultInfo> getNonBroadcastResultInfos(
            List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream()
                .filter(resultInfo -> !resultInfo.isBroadcast())
                .collect(Collectors.toList());
    }

    /**
     * Creates a decider from the adaptive auto-parallelism options in {@code configuration}.
     *
     * @param maxParallelism the global maximum parallelism
     * @param configuration source of the min parallelism, the average data volume per task and the
     *     default source parallelism
     */
    public static DefaultVertexParallelismAndInputInfosDecider from(
            int maxParallelism, Configuration configuration) {
        return new DefaultVertexParallelismAndInputInfosDecider(
                maxParallelism,
                configuration.getInteger(
                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM),
                configuration.get(
                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK),
                configuration.get(
                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM));
    }
}
