package org.apache.flink.table.planner.plan.optimize.program;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalDynamicFilteringTableSourceScan;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalHashJoin;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortMergeJoin;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalGlobalRuntimeFilterBuilder;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalLocalRuntimeFilterBuilder;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalRuntimeFilter;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.DefaultRelShuttle;
import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil;
import org.apache.flink.table.planner.plan.utils.JoinUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.util.Preconditions;

/**
 * A {@link FlinkOptimizeProgram} that injects runtime filters into batch joins.
 *
 * <p>For each {@link BatchPhysicalHashJoin} or {@link BatchPhysicalSortMergeJoin} whose join type
 * is accepted by {@link #isSuitableJoinType}, one input is picked as the probe side. Its estimated
 * data size must reach {@link
 * OptimizerConfigOptions#TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_PROBE_DATA_SIZE}. A suitable build node
 * is then searched below the other input. If one is found, the chain {@code local builder ->
 * singleton exchange -> global builder -> broadcast exchange} is created on top of the build node.
 * That chain feeds a {@link BatchPhysicalRuntimeFilter}, which is pushed down into the probe side
 * as far as the probe keys can be traced to input fields.
 *
 * <p>The program returns the plan unchanged unless {@link
 * OptimizerConfigOptions#TABLE_OPTIMIZER_RUNTIME_FILTER_ENABLED} is true.
 */
public class FlinkRuntimeFilterProgram implements FlinkOptimizeProgram<BatchOptimizeContext> {

    /**
     * A node that can act as the build side of a runtime filter, together with the indices of its
     * fields that correspond to the join keys. Only used internally by this program.
     */
    public static class BuildSideInfo {
        /** The node on which the runtime filter builders are created. */
        private final RelNode buildSide;

        /** Field indices of {@link #buildSide} that the runtime filter is built on. */
        private final ImmutableIntList buildIndices;

        public BuildSideInfo(RelNode buildSide, ImmutableIntList buildIndices) {
            this.buildSide = Preconditions.checkNotNull(buildSide);
            this.buildIndices = Preconditions.checkNotNull(buildIndices);
        }
    }

    @Override
    public RelNode optimize(RelNode root, BatchOptimizeContext context) {
        if (!isRuntimeFilterEnabled(root)) {
            return root;
        }

        // Keep the build-size range (<= max build size) and the probe-size range
        // (>= min probe size) disjoint.
        Preconditions.checkState(
                getMinProbeDataSize(root) > getMaxBuildDataSize(root),
                "The min probe data size should be larger than the max build data size.");

        DefaultRelShuttle shuttle =
                new DefaultRelShuttle() {
                    @Override
                    public RelNode visit(RelNode rel) {
                        if (rel instanceof Join) {
                            // Rewrite both inputs first (bottom-up), then try the join itself.
                            Join join = (Join) rel;
                            RelNode newLeft = join.getLeft().accept(this);
                            RelNode newRight = join.getRight().accept(this);
                            return tryInjectRuntimeFilter(
                                    join.copy(
                                            join.getTraitSet(),
                                            Arrays.asList(newLeft, newRight)));
                        }
                        List<RelNode> newInputs = new ArrayList<>();
                        for (RelNode input : rel.getInputs()) {
                            newInputs.add(input.accept(this));
                        }
                        return rel.copy(rel.getTraitSet(), newInputs);
                    }
                };
        return shuttle.visit(root);
    }

    /**
     * Tries to inject a runtime filter into the probe side of the given join.
     *
     * <p>Only hash and sort-merge joins of a type accepted by {@link #isSuitableJoinType} are
     * considered. The left input is preferred as the probe side; the right input is used only if
     * the left one does not qualify. For LEFT joins the left input is never the probe side, and
     * for RIGHT joins the right input is never the probe side.
     *
     * @param join the join to rewrite
     * @return a copy of the join whose probe input carries a runtime filter, or {@code join}
     *     itself if no runtime filter could be injected
     */
    public static Join tryInjectRuntimeFilter(Join join) {
        if (!isSuitableJoinType(join.getJoinType())) {
            return join;
        }
        if (!(join instanceof BatchPhysicalHashJoin)
                && !(join instanceof BatchPhysicalSortMergeJoin)) {
            return join;
        }

        final boolean leftIsBuild;
        if (canBeProbeSide(join.getLeft())) {
            leftIsBuild = false;
        } else if (canBeProbeSide(join.getRight())) {
            leftIsBuild = true;
        } else {
            return join;
        }

        // The row-preserving side of an outer join must not be used as the probe side.
        if (join.getJoinType() == JoinRelType.LEFT && !leftIsBuild) {
            return join;
        }
        if (join.getJoinType() == JoinRelType.RIGHT && leftIsBuild) {
            return join;
        }

        JoinInfo joinInfo = join.analyzeCondition();
        final RelNode buildSide;
        final RelNode probeSide;
        final ImmutableIntList buildIndices;
        final ImmutableIntList probeIndices;
        if (leftIsBuild) {
            buildSide = join.getLeft();
            probeSide = join.getRight();
            buildIndices = joinInfo.leftKeys;
            probeIndices = joinInfo.rightKeys;
        } else {
            buildSide = join.getRight();
            probeSide = join.getLeft();
            buildIndices = joinInfo.rightKeys;
            probeIndices = joinInfo.leftKeys;
        }

        Optional<BuildSideInfo> suitableBuild =
                findSuitableBuildSide(
                        buildSide,
                        buildIndices,
                        (build, indices) ->
                                isSuitableDataSize(build, probeSide, indices, probeIndices));
        if (!suitableBuild.isPresent()) {
            return join;
        }

        // Injection starts disabled: the filter is never placed directly on the join's input.
        RelNode newProbe =
                tryPushDownProbeAndInjectRuntimeFilter(
                        probeSide, probeIndices, suitableBuild.get(), false);
        return leftIsBuild
                ? join.copy(join.getTraitSet(), Arrays.asList(buildSide, newProbe))
                : join.copy(join.getTraitSet(), Arrays.asList(newProbe, buildSide));
    }

    /**
     * Creates the runtime filter plan for the given build and probe nodes.
     *
     * <p>The resulting plan is {@code RuntimeFilter(BroadcastExchange(GlobalBuilder(
     * SingletonExchange(LocalBuilder(buildSide)))), probeSide)}.
     *
     * @param buildSide node whose rows feed the runtime filter builders
     * @param probeSide node to be filtered
     * @param buildIndices key field indices of {@code buildSide}
     * @param probeIndices key field indices of {@code probeSide}, positionally paired with {@code
     *     buildIndices}
     * @return the new probe node wrapped in a {@link BatchPhysicalRuntimeFilter}
     * @throws IllegalStateException if no row count estimate is available for {@code buildSide}
     */
    private static RelNode createNewProbeWithRuntimeFilter(
            RelNode buildSide,
            RelNode probeSide,
            ImmutableIntList buildIndices,
            ImmutableIntList probeIndices) {
        Optional<Double> buildRowCount = getEstimatedRowCount(buildSide);
        Preconditions.checkState(buildRowCount.isPresent());
        int estimatedRowCount = buildRowCount.get().intValue();
        // Number of rows that fit into the configured max build data size.
        int maxRowCount =
                (int)
                        Math.ceil(
                                getMaxBuildDataSize(buildSide)
                                        / FlinkRelMdUtil.binaryRowAverageSize(buildSide)
                                                .doubleValue());
        double filterRatio = computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices);

        List<String> buildFieldNames = buildSide.getRowType().getFieldNames();
        String[] filterFieldNames =
                buildIndices.stream().map(buildFieldNames::get).toArray(String[]::new);

        BatchPhysicalLocalRuntimeFilterBuilder localBuilder =
                new BatchPhysicalLocalRuntimeFilterBuilder(
                        buildSide.getCluster(),
                        buildSide.getTraitSet(),
                        buildSide,
                        buildIndices.toIntArray(),
                        filterFieldNames,
                        estimatedRowCount,
                        maxRowCount);
        BatchPhysicalGlobalRuntimeFilterBuilder globalBuilder =
                new BatchPhysicalGlobalRuntimeFilterBuilder(
                        localBuilder.getCluster(),
                        localBuilder.getTraitSet(),
                        createExchange(localBuilder, FlinkRelDistribution.SINGLETON()),
                        filterFieldNames,
                        estimatedRowCount,
                        maxRowCount);
        return new BatchPhysicalRuntimeFilter(
                probeSide.getCluster(),
                probeSide.getTraitSet(),
                createExchange(globalBuilder, FlinkRelDistribution.BROADCAST_DISTRIBUTED()),
                probeSide,
                probeIndices.toIntArray(),
                filterRatio);
    }

    /**
     * Searches below {@code rel} for a node that can serve as the build side of a runtime filter.
     *
     * <p>Only the input of an {@link Exchange} can be returned as a build side. The search
     * descends through calcs, joins and group aggregates as long as the key indices can be traced
     * to input fields, and stops at existing runtime filters.
     *
     * @param rel node to search from
     * @param buildIndices key field indices of {@code rel}
     * @param buildSideChecker decides whether a candidate node with the given key indices is
     *     suitable
     * @return the first suitable build side found, or empty if none
     */
    private static Optional<BuildSideInfo> findSuitableBuildSide(
            RelNode rel,
            ImmutableIntList buildIndices,
            BiFunction<RelNode, ImmutableIntList, Boolean> buildSideChecker) {
        if (rel instanceof Exchange) {
            RelNode input = ((Exchange) rel).getInput();
            if (!(input instanceof BatchPhysicalRuntimeFilter)
                    && buildSideChecker.apply(input, buildIndices)) {
                return Optional.of(new BuildSideInfo(input, buildIndices));
            }
            // The search does not continue below an exchange whose input is not suitable.
            return Optional.empty();
        }

        if (rel instanceof BatchPhysicalRuntimeFilter) {
            return Optional.empty();
        }

        if (rel instanceof Calc) {
            Calc calc = (Calc) rel;
            ImmutableIntList inputIndices = getCalcInputIndices(calc, buildIndices);
            return inputIndices.isEmpty()
                    ? Optional.empty()
                    : findSuitableBuildSide(calc.getInput(), inputIndices, buildSideChecker);
        }

        if (rel instanceof Join) {
            Join join = (Join) rel;
            if (!isSuitableJoinType(join.getJoinType())) {
                return Optional.empty();
            }
            Tuple2<ImmutableIntList, ImmutableIntList> sideIndices =
                    getInputIndices(join, buildIndices);
            ImmutableIntList leftIndices = sideIndices.f0;
            ImmutableIntList rightIndices = sideIndices.f1;

            // For outer joins only the row-preserving side is searched.
            if (join.getJoinType() == JoinRelType.LEFT) {
                rightIndices = ImmutableIntList.of();
            } else if (join.getJoinType() == JoinRelType.RIGHT) {
                leftIndices = ImmutableIntList.of();
            }
            if (leftIndices.isEmpty() && rightIndices.isEmpty()) {
                return Optional.empty();
            }

            // Prefer the left input when it is directly an exchange, otherwise try the right one
            // first; fall back to the other side if the preferred one yields nothing.
            if (!leftIndices.isEmpty() && join.getLeft() instanceof Exchange) {
                Optional<BuildSideInfo> result =
                        findSuitableBuildSide(join.getLeft(), leftIndices, buildSideChecker);
                if (!result.isPresent() && !rightIndices.isEmpty()) {
                    result = findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker);
                }
                return result;
            }
            // NOTE(review): when rightIndices is empty here the left input is never searched,
            // even though leftIndices is non-empty (e.g. a LEFT join whose left input is a Calc).
            // Confirm whether this asymmetry is intended.
            if (rightIndices.isEmpty()) {
                return Optional.empty();
            }
            Optional<BuildSideInfo> result =
                    findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker);
            if (!result.isPresent() && !leftIndices.isEmpty()) {
                result = findSuitableBuildSide(join.getLeft(), leftIndices, buildSideChecker);
            }
            return result;
        }

        if (rel instanceof BatchPhysicalGroupAggregateBase) {
            BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase) rel;
            Optional<ImmutableIntList> inputIndices = getGroupingInputIndices(agg, buildIndices);
            return inputIndices.isPresent()
                    ? findSuitableBuildSide(agg.getInput(), inputIndices.get(), buildSideChecker)
                    : Optional.empty();
        }

        return Optional.empty();
    }

    /**
     * Pushes the probe keys down through {@code rel} and wraps the deepest reachable node in a
     * runtime filter.
     *
     * @param rel current probe node
     * @param probeIndices key field indices of {@code rel}
     * @param buildSideInfo build side the runtime filter is created from
     * @param canInject whether a runtime filter may be placed on {@code rel} itself. This is false
     *     for the join's direct probe input and for nodes reached from it only through calcs and
     *     unions. It becomes true once the pushdown has descended through an exchange, a join or a
     *     group aggregate.
     * @return the rewritten probe node, or {@code rel} if nothing was injected
     */
    private static RelNode tryPushDownProbeAndInjectRuntimeFilter(
            RelNode rel,
            ImmutableIntList probeIndices,
            BuildSideInfo buildSideInfo,
            boolean canInject) {
        if (rel instanceof BatchPhysicalRuntimeFilter) {
            return rel;
        }

        if (rel instanceof Exchange) {
            Exchange exchange = (Exchange) rel;
            RelNode newInput =
                    tryPushDownProbeAndInjectRuntimeFilter(
                            exchange.getInput(), probeIndices, buildSideInfo, true);
            return exchange.copy(exchange.getTraitSet(), Collections.singletonList(newInput));
        }

        if (rel instanceof Calc) {
            Calc calc = (Calc) rel;
            ImmutableIntList inputIndices = getCalcInputIndices(calc, probeIndices);
            if (!inputIndices.isEmpty()) {
                RelNode newInput =
                        tryPushDownProbeAndInjectRuntimeFilter(
                                calc.getInput(), inputIndices, buildSideInfo, canInject);
                return calc.copy(calc.getTraitSet(), Collections.singletonList(newInput));
            }
        } else if (rel instanceof Join) {
            Join join = (Join) rel;
            Tuple2<ImmutableIntList, ImmutableIntList> sideIndices =
                    getInputIndices(join, probeIndices);
            ImmutableIntList leftIndices = sideIndices.f0;
            ImmutableIntList rightIndices = sideIndices.f1;
            if (!leftIndices.isEmpty() || !rightIndices.isEmpty()) {
                RelNode newLeft =
                        leftIndices.isEmpty()
                                ? join.getLeft()
                                : tryPushDownProbeAndInjectRuntimeFilter(
                                        join.getLeft(), leftIndices, buildSideInfo, true);
                RelNode newRight =
                        rightIndices.isEmpty()
                                ? join.getRight()
                                : tryPushDownProbeAndInjectRuntimeFilter(
                                        join.getRight(), rightIndices, buildSideInfo, true);
                return join.copy(join.getTraitSet(), Arrays.asList(newLeft, newRight));
            }
        } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
            BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase) rel;
            Optional<ImmutableIntList> inputIndices = getGroupingInputIndices(agg, probeIndices);
            if (inputIndices.isPresent()) {
                RelNode newInput =
                        tryPushDownProbeAndInjectRuntimeFilter(
                                agg.getInput(), inputIndices.get(), buildSideInfo, true);
                return agg.copy(agg.getTraitSet(), Collections.singletonList(newInput));
            }
        } else if (rel instanceof Union) {
            Union union = (Union) rel;
            List<RelNode> newInputs = new ArrayList<>();
            for (RelNode input : union.getInputs()) {
                newInputs.add(
                        tryPushDownProbeAndInjectRuntimeFilter(
                                input, probeIndices, buildSideInfo, canInject));
            }
            return union.copy(union.getTraitSet(), newInputs, union.all);
        } else if (rel instanceof BatchPhysicalDynamicFilteringTableSourceScan
                && new HashSet<>(
                                ((BatchPhysicalDynamicFilteringTableSourceScan) rel)
                                        .dynamicFilteringIndices())
                        .containsAll(probeIndices)) {
            // The scan already filters dynamically on all probe keys; leave it untouched.
            return rel;
        }

        return canInject
                ? createNewProbeWithRuntimeFilter(
                        ignoreExchange(buildSideInfo.buildSide),
                        ignoreExchange(rel),
                        buildSideInfo.buildIndices,
                        probeIndices)
                : rel;
    }

    /** Creates a batch physical exchange on top of {@code rel} with the given distribution. */
    private static BatchPhysicalExchange createExchange(
            RelNode rel, FlinkRelDistribution distribution) {
        return new BatchPhysicalExchange(
                rel.getCluster(),
                rel.getCluster()
                        .getPlanner()
                        .emptyTraitSet()
                        .replace(FlinkConventions.BATCH_PHYSICAL())
                        .replace(distribution),
                rel,
                distribution);
    }

    /**
     * Maps output field indices of a calc to the input field indices they reference.
     *
     * @return the input indices, or an empty list if any of the referenced projections is not a
     *     plain input reference
     */
    private static ImmutableIntList getCalcInputIndices(Calc calc, ImmutableIntList outputIndices) {
        RexProgram program = calc.getProgram();
        List<RexNode> projects =
                program.getProjectList().stream()
                        .map(program::expandLocalRef)
                        .collect(Collectors.toList());
        return getInputIndices(projects, outputIndices);
    }

    /**
     * Maps output field indices of a group aggregate to the input fields of its grouping keys.
     *
     * @return the input indices, or empty if any index lies beyond the grouping keys
     */
    private static Optional<ImmutableIntList> getGroupingInputIndices(
            BatchPhysicalGroupAggregateBase agg, ImmutableIntList outputIndices) {
        int[] grouping = agg.grouping();
        List<Integer> inputIndices = new ArrayList<>(outputIndices.size());
        for (int index : outputIndices) {
            if (index >= grouping.length) {
                return Optional.empty();
            }
            inputIndices.add(grouping[index]);
        }
        return Optional.of(ImmutableIntList.copyOf(inputIndices));
    }

    /**
     * Maps output field indices to input field indices through a list of projections.
     *
     * @param projects the projection expressions, one per output field
     * @param outputIndices output field indices to map
     * @return the referenced input indices, or an empty list if any selected projection is not a
     *     {@link RexInputRef}
     */
    private static ImmutableIntList getInputIndices(
            List<RexNode> projects, ImmutableIntList outputIndices) {
        List<Integer> inputIndices = new ArrayList<>(outputIndices.size());
        for (int index : outputIndices) {
            RexNode project = projects.get(index);
            if (!(project instanceof RexInputRef)) {
                return ImmutableIntList.of();
            }
            inputIndices.add(((RexInputRef) project).getIndex());
        }
        return ImmutableIntList.copyOf(inputIndices);
    }

    /**
     * Maps output field indices of a join to field indices of its left and right inputs.
     *
     * <p>An index is mapped to its own side directly, and to the other side through the equi-join
     * keys if it is one of them.
     *
     * @return a pair of (left indices, right indices); a side is empty unless every output index
     *     could be mapped to it
     */
    private static Tuple2<ImmutableIntList, ImmutableIntList> getInputIndices(
            Join join, ImmutableIntList outputIndices) {
        JoinInfo joinInfo = join.analyzeCondition();
        Map<Integer, Integer> leftToRight =
                createKeysMapping(joinInfo.leftKeys, joinInfo.rightKeys);
        Map<Integer, Integer> rightToLeft =
                createKeysMapping(joinInfo.rightKeys, joinInfo.leftKeys);

        List<Integer> leftIndices = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();
        int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        for (int index : outputIndices) {
            if (index < leftFieldCount) {
                leftIndices.add(index);
                Integer rightKey = leftToRight.get(index);
                if (rightKey != null) {
                    rightIndices.add(rightKey);
                }
            } else {
                int rightIndex = index - leftFieldCount;
                rightIndices.add(rightIndex);
                Integer leftKey = rightToLeft.get(rightIndex);
                if (leftKey != null) {
                    leftIndices.add(leftKey);
                }
            }
        }

        int expected = outputIndices.size();
        return Tuple2.of(
                leftIndices.size() == expected
                        ? ImmutableIntList.copyOf(leftIndices)
                        : ImmutableIntList.of(),
                rightIndices.size() == expected
                        ? ImmutableIntList.copyOf(rightIndices)
                        : ImmutableIntList.of());
    }

    /**
     * Builds a map from each key in {@code sourceKeys} to the key at the same position in {@code
     * targetKeys}. For duplicate source keys, the last position wins.
     *
     * @throws IllegalStateException if the two lists differ in size
     */
    private static Map<Integer, Integer> createKeysMapping(
            ImmutableIntList sourceKeys, ImmutableIntList targetKeys) {
        Preconditions.checkState(sourceKeys.size() == targetKeys.size());
        Map<Integer, Integer> mapping = new HashMap<>();
        for (int i = 0; i < sourceKeys.size(); i++) {
            mapping.put(sourceKeys.get(i), targetKeys.get(i));
        }
        return mapping;
    }

    /** Returns whether the estimated data size of {@code rel} reaches the min probe data size. */
    private static boolean canBeProbeSide(RelNode rel) {
        Optional<Double> dataSize = getEstimatedDataSize(rel);
        return dataSize.isPresent() && dataSize.get() >= getMinProbeDataSize(rel);
    }

    /**
     * Returns whether the build side is small enough, the probe side is large enough, and the
     * estimated filter ratio reaches the configured minimum. Returns false if either data size is
     * unknown.
     */
    private static boolean isSuitableDataSize(
            RelNode buildSide,
            RelNode probeSide,
            ImmutableIntList buildIndices,
            ImmutableIntList probeIndices) {
        Optional<Double> buildSize = getEstimatedDataSize(buildSide);
        Optional<Double> probeSize = getEstimatedDataSize(probeSide);
        if (!buildSize.isPresent() || !probeSize.isPresent()) {
            return false;
        }
        return buildSize.get() <= getMaxBuildDataSize(buildSide)
                && probeSize.get() >= getMinProbeDataSize(probeSide)
                && computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices)
                        >= getMinFilterRatio(buildSide);
    }

    /**
     * Estimates the fraction of probe rows the runtime filter removes.
     *
     * <p>Computed as {@code max(0, 1 - buildNdv / probeNdv)} on the key columns. When either NDV
     * is unknown, it falls back to {@code max(0, 1 - buildRowCount / probeRowCount)}.
     *
     * @throws IllegalStateException if the NDV is unknown and a row count is unknown as well
     */
    private static double computeFilterRatio(
            RelNode buildSide,
            RelNode probeSide,
            ImmutableIntList buildIndices,
            ImmutableIntList probeIndices) {
        Optional<Double> buildNdv = getEstimatedNdv(buildSide, ImmutableBitSet.of(buildIndices));
        Optional<Double> probeNdv = getEstimatedNdv(probeSide, ImmutableBitSet.of(probeIndices));
        if (buildNdv.isPresent() && probeNdv.isPresent()) {
            return Math.max(0.0, 1.0 - buildNdv.get() / probeNdv.get());
        }

        Optional<Double> buildRowCount = getEstimatedRowCount(buildSide);
        Optional<Double> probeRowCount = getEstimatedRowCount(probeSide);
        Preconditions.checkState(buildRowCount.isPresent() && probeRowCount.isPresent());
        return Math.max(0.0, 1.0 - buildRowCount.get() / probeRowCount.get());
    }

    /** Returns the input of {@code rel} if it is an {@link Exchange}, otherwise {@code rel}. */
    private static RelNode ignoreExchange(RelNode rel) {
        return rel instanceof Exchange ? rel.getInput(0) : rel;
    }

    private static Optional<Double> getEstimatedDataSize(RelNode rel) {
        return Optional.ofNullable(JoinUtil.binaryRowRelNodeSize(rel));
    }

    private static Optional<Double> getEstimatedRowCount(RelNode rel) {
        return Optional.ofNullable(rel.getCluster().getMetadataQuery().getRowCount(rel));
    }

    private static Optional<Double> getEstimatedNdv(RelNode rel, ImmutableBitSet keys) {
        return Optional.ofNullable(
                rel.getCluster().getMetadataQuery().getDistinctRowCount(rel, keys, null));
    }

    private static boolean isRuntimeFilterEnabled(RelNode rel) {
        return ShortcutUtils.unwrapTableConfig(rel)
                .get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_ENABLED);
    }

    private static long getMaxBuildDataSize(RelNode rel) {
        return ShortcutUtils.unwrapTableConfig(rel)
                .get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MAX_BUILD_DATA_SIZE)
                .getBytes();
    }

    private static long getMinProbeDataSize(RelNode rel) {
        return ShortcutUtils.unwrapTableConfig(rel)
                .get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_PROBE_DATA_SIZE)
                .getBytes();
    }

    private static double getMinFilterRatio(RelNode rel) {
        return ShortcutUtils.unwrapTableConfig(rel)
                .get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_FILTER_RATIO);
    }

    /**
     * Returns whether runtime filters are supported for the given join type: INNER, SEMI, LEFT
     * and RIGHT. Returns false for any other value, including {@code null}.
     */
    public static boolean isSuitableJoinType(JoinRelType joinType) {
        return joinType == JoinRelType.INNER
                || joinType == JoinRelType.SEMI
                || joinType == JoinRelType.LEFT
                || joinType == JoinRelType.RIGHT;
    }
}
