package org.apache.flink.table.runtime.operators.join;

import java.io.Serializable;
import org.apache.flink.configuration.AlgorithmOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.hashtable.BinaryHashTable;
import org.apache.flink.table.runtime.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class of the batch hash join operators.
 *
 * <p>Input 1 is the build side and input 2 is the probe side. {@link #nextSelection()} keeps
 * selecting input 1 until {@code endInput(1)} has been called, so the whole build input is inserted
 * into a {@link BinaryHashTable} before any probe row is read. Probe rows are then looked up in the
 * table as they arrive. When the probe input ends, the table's remaining matches are drained via
 * {@code nextMatching()}, and any partitions the table reports as pending for sort-merge join are
 * joined with the configured {@link SortMergeJoinFunction}.
 *
 * <p>The join semantics (inner, outer, semi, anti, ...) are supplied by the nested subclasses via
 * {@link #join(RowIterator, RowData)}; use {@link #newHashJoinOperator} to create an instance.
 */
public abstract class HashJoinOperator extends TableStreamOperator<RowData>
        implements TwoInputStreamOperator<RowData, RowData, RowData>,
                BoundedMultiInput,
                InputSelectable {

    private static final Logger LOG = LoggerFactory.getLogger(HashJoinOperator.class);

    /** Construction parameters; the generated-code members are cleared at the end of open(). */
    private final HashJoinParameter parameter;
    /** If true, output rows are assembled as (probe, build) instead of (build, probe). */
    private final boolean reverseJoinFunction;
    private final HashJoinType type;
    /** Whether the build side is the left input; decides how rows are routed to the SMJ fallback. */
    private final boolean leftIsBuild;
    /** Only opened and used for partitions the hash table hands over to sort-merge join. */
    private final SortMergeJoinFunction sortMergeJoinFunction;

    private transient BinaryHashTable table;
    transient Collector<RowData> collector;
    /** Null-padding row with the build-side arity, paired with unmatched probe rows. */
    transient RowData buildSideNullRow;
    /** Null-padding row with the probe-side arity, paired with unmatched build rows. */
    private transient RowData probeSideNullRow;
    /** Single reused output row; {@link #collect} re-points it for every emitted record. */
    private transient JoinedRowData joinedRow;
    /** Set once the build input (input 1) has ended. */
    private transient boolean buildEnd;
    private transient JoinCondition condition;
    /** Set once the sort-merge join fallback has been opened, so close() knows to close it. */
    private transient boolean fallbackSMJ;

    /** Anti join: emits a probe row only when no build row matches it. */
    private static class AntiHashJoinOperator extends HashJoinOperator {
        AntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            Preconditions.checkNotNull(probeRow);
            if (!buildIter.advanceNext()) {
                collector.collect(probeRow);
            }
        }
    }

    /**
     * Semi/anti join whose output rows are build-side rows. While probing (non-null probe row) the
     * matching build rows are only iterated, never emitted; when called with a {@code null} probe
     * row, every row of the build iterator is emitted unchanged.
     */
    private static class BuildLeftSemiOrAntiHashJoinOperator extends HashJoinOperator {
        BuildLeftSemiOrAntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (!buildIter.advanceNext()) {
                return;
            }
            if (probeRow != null) {
                // Probe phase: the matches are still iterated to the end (presumably so the hash
                // table can record which build rows were probed — TODO confirm), but nothing is
                // emitted here.
                while (buildIter.advanceNext()) {
                    // intentionally empty
                }
            } else {
                // End-of-probe pass: emit the build rows handed to us as-is.
                collector.collect(buildIter.getRow());
                while (buildIter.advanceNext()) {
                    collector.collect(buildIter.getRow());
                }
            }
        }
    }

    /**
     * Build-side outer join: matches are joined while probing; a {@code null} probe row emits the
     * supplied build rows padded with the probe-side null row.
     */
    private static class BuildOuterHashJoinOperator extends HashJoinOperator {
        BuildOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (!buildIter.advanceNext()) {
                return;
            }
            if (probeRow != null) {
                innerJoin(buildIter, probeRow);
            } else {
                buildOuterJoin(buildIter);
            }
        }
    }

    /**
     * Full outer join: combines the probe-outer behavior (unmatched probe rows padded with the
     * build-side null row) with the build-outer behavior for a {@code null} probe row.
     */
    private static class FullOuterHashJoinOperator extends HashJoinOperator {
        FullOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    innerJoin(buildIter, probeRow);
                } else {
                    buildOuterJoin(buildIter);
                }
            } else if (probeRow != null) {
                collect(buildSideNullRow, probeRow);
            }
        }
    }

    /** Serializable holder for the operator's construction parameters. */
    static class HashJoinParameter implements Serializable {
        HashJoinType type;
        boolean leftIsBuild;
        GeneratedJoinCondition condFuncCode;
        boolean reverseJoinFunction;
        boolean[] filterNullKeys;
        GeneratedProjection buildProjectionCode;
        GeneratedProjection probeProjectionCode;
        boolean tryDistinctBuildRow;
        int buildRowSize;
        long buildRowCount;
        long probeRowCount;
        RowType keyType;
        SortMergeJoinFunction sortMergeJoinFunction;

        HashJoinParameter(
                HashJoinType type,
                boolean leftIsBuild,
                GeneratedJoinCondition condFuncCode,
                boolean reverseJoinFunction,
                boolean[] filterNullKeys,
                GeneratedProjection buildProjectionCode,
                GeneratedProjection probeProjectionCode,
                boolean tryDistinctBuildRow,
                int buildRowSize,
                long buildRowCount,
                long probeRowCount,
                RowType keyType,
                SortMergeJoinFunction sortMergeJoinFunction) {
            this.type = type;
            this.leftIsBuild = leftIsBuild;
            this.condFuncCode = condFuncCode;
            this.reverseJoinFunction = reverseJoinFunction;
            this.filterNullKeys = filterNullKeys;
            this.buildProjectionCode = buildProjectionCode;
            this.probeProjectionCode = probeProjectionCode;
            this.tryDistinctBuildRow = tryDistinctBuildRow;
            this.buildRowSize = buildRowSize;
            this.buildRowCount = buildRowCount;
            this.probeRowCount = probeRowCount;
            this.keyType = keyType;
            this.sortMergeJoinFunction = sortMergeJoinFunction;
        }
    }

    /** Inner join: emits one joined row per matching build row; ignores a {@code null} probe row. */
    private static class InnerHashJoinOperator extends HashJoinOperator {
        InnerHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext() && probeRow != null) {
                innerJoin(buildIter, probeRow);
            }
        }
    }

    /**
     * Probe-side outer join: matches are joined; an unmatched probe row is emitted padded with the
     * build-side null row.
     */
    private static class ProbeOuterHashJoinOperator extends HashJoinOperator {
        ProbeOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    innerJoin(buildIter, probeRow);
                }
            } else if (probeRow != null) {
                collect(buildSideNullRow, probeRow);
            }
        }
    }

    /** Semi join: emits a probe row once if at least one build row matches it. */
    private static class SemiHashJoinOperator extends HashJoinOperator {
        SemiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            Preconditions.checkNotNull(probeRow);
            if (buildIter.advanceNext()) {
                collector.collect(probeRow);
            }
        }
    }

    HashJoinOperator(HashJoinParameter parameter) {
        this.parameter = parameter;
        this.type = parameter.type;
        this.leftIsBuild = parameter.leftIsBuild;
        this.reverseJoinFunction = parameter.reverseJoinFunction;
        this.sortMergeJoinFunction = parameter.sortMergeJoinFunction;
    }

    /**
     * Instantiates the generated join condition and projections, creates the hash table, the
     * output collector and the null-padding rows, and registers the table's memory/spill gauges.
     */
    @Override
    public void open() throws Exception {
        super.open();

        ClassLoader cl = getContainingTask().getUserCodeClassLoader();

        AbstractRowDataSerializer buildSerializer =
                (AbstractRowDataSerializer)
                        getOperatorConfig().getTypeSerializerIn1(getUserCodeClassloader());
        AbstractRowDataSerializer probeSerializer =
                (AbstractRowDataSerializer)
                        getOperatorConfig().getTypeSerializerIn2(getUserCodeClassloader());

        boolean useBloomFilters =
                getContainingTask()
                        .getEnvironment()
                        .getTaskConfiguration()
                        .getBoolean(AlgorithmOptions.HASH_JOIN_BLOOM_FILTERS);

        int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();

        this.condition = parameter.condFuncCode.newInstance(cl);
        condition.setRuntimeContext(getRuntimeContext());
        condition.open(new Configuration());

        this.table =
                new BinaryHashTable(
                        getContainingTask().getJobConfiguration(),
                        getContainingTask(),
                        buildSerializer,
                        probeSerializer,
                        parameter.buildProjectionCode.newInstance(cl),
                        parameter.probeProjectionCode.newInstance(cl),
                        getContainingTask().getEnvironment().getMemoryManager(),
                        computeMemorySize(),
                        getContainingTask().getEnvironment().getIOManager(),
                        parameter.buildRowSize,
                        // the overall build row count is spread evenly across subtasks
                        parameter.buildRowCount / parallelism,
                        useBloomFilters,
                        type,
                        condition,
                        reverseJoinFunction,
                        parameter.filterNullKeys,
                        parameter.tryDistinctBuildRow);

        this.collector = new StreamRecordCollector<>(output);
        this.buildSideNullRow = new GenericRowData(buildSerializer.getArity());
        this.probeSideNullRow = new GenericRowData(probeSerializer.getArity());
        this.joinedRow = new JoinedRowData();
        this.buildEnd = false;
        this.fallbackSMJ = false;

        BinaryHashTable hashTable = this.table;
        OperatorMetricGroup metrics = getMetricGroup();
        metrics.gauge("memoryUsedSizeInBytes", hashTable::getUsedMemoryInBytes);
        metrics.gauge("numSpillFiles", hashTable::getNumSpillFiles);
        metrics.gauge("spillInBytes", hashTable::getSpillInBytes);

        // The generated code has been instantiated; drop the references (presumably so the
        // generated sources can be garbage collected).
        parameter.condFuncCode = null;
        parameter.buildProjectionCode = null;
        parameter.probeProjectionCode = null;
    }

    /** Inserts a build-side row into the hash table; only valid before the build input ended. */
    @Override
    public void processElement1(StreamRecord<RowData> element) throws Exception {
        Preconditions.checkState(!buildEnd, "Should not build ended.");
        table.putBuildRow(element.getValue());
    }

    /** Probes the hash table with a probe-side row and joins it if the table reports a match. */
    @Override
    public void processElement2(StreamRecord<RowData> element) throws Exception {
        Preconditions.checkState(buildEnd, "Should build ended.");
        if (table.tryProbe(element.getValue())) {
            joinWithNextKey();
        }
    }

    /** Reads the build input (1) until it has ended, then the probe input (2). */
    @Override
    public InputSelection nextSelection() {
        return buildEnd ? InputSelection.SECOND : InputSelection.FIRST;
    }

    /**
     * Ends the build phase for input 1. For input 2, drains the table's remaining matches and then
     * runs the sort-merge join fallback for partitions pending for SMJ. Other ids are ignored.
     */
    @Override
    public void endInput(int inputId) throws Exception {
        switch (inputId) {
            case 1:
                Preconditions.checkState(!buildEnd, "Should not build ended.");
                LOG.info("Finish build phase.");
                buildEnd = true;
                table.endBuild();
                break;
            case 2:
                Preconditions.checkState(buildEnd, "Should build ended.");
                LOG.info("Finish probe phase.");
                while (table.nextMatching()) {
                    joinWithNextKey();
                }
                LOG.info("Finish rebuild phase.");
                fallbackSMJProcessPartition();
                break;
            default:
                break;
        }
    }

    /** Joins the table's current probe row with its current build-side iterator. */
    private void joinWithNextKey() throws Exception {
        join(table.getBuildSideIterator(), table.getCurrentProbeRow());
    }

    /**
     * Applies the join semantics of the concrete operator.
     *
     * @param buildIter build rows for the current key; not yet advanced
     * @param probeRow the probe row, or {@code null}; subclasses treat {@code null} as the
     *     end-of-probe pass over build-side rows
     */
    public abstract void join(RowIterator<BinaryRowData> buildIter, RowData probeRow)
            throws Exception;

    /**
     * Emits {@code probeRow} joined with the iterator's current row and every following row.
     * Expects {@code buildIter} to already be positioned on its first row.
     */
    void innerJoin(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
        do {
            collect(buildIter.getRow(), probeRow);
        } while (buildIter.advanceNext());
    }

    /**
     * Emits the iterator's current row and every following row padded with the probe-side null
     * row. Expects {@code buildIter} to already be positioned on its first row.
     */
    void buildOuterJoin(RowIterator<BinaryRowData> buildIter) throws Exception {
        do {
            collect(buildIter.getRow(), probeSideNullRow);
        } while (buildIter.advanceNext());
    }

    /**
     * Emits the pair through the shared {@link #joinedRow}, ordered (build, probe) or, when
     * {@code reverseJoinFunction} is set, (probe, build).
     */
    void collect(RowData buildRow, RowData probeRow) throws Exception {
        if (reverseJoinFunction) {
            collector.collect(joinedRow.replace(probeRow, buildRow));
        } else {
            collector.collect(joinedRow.replace(buildRow, probeRow));
        }
    }

    @Override
    public void close() throws Exception {
        super.close();
        closeHashTable();
        // condition is only assigned in open(); guard against close() running without it.
        if (condition != null) {
            condition.close();
        }
        // The sort-merge join function is only opened when spilled partitions fell back to it.
        if (fallbackSMJ) {
            sortMergeJoinFunction.close();
        }
    }

    /** Closes and frees the hash table once; later calls are no-ops. */
    private void closeHashTable() {
        if (table != null) {
            table.close();
            table.free();
            table = null;
        }
    }

    /**
     * Joins the partitions the hash table marked as pending for sort-merge join: each partition's
     * spilled build rows and then its spilled probe rows are fed to the sort-merge join function,
     * after which the hash table is released and both SMJ inputs are ended.
     */
    private void fallbackSMJProcessPartition() throws Exception {
        if (table.getPartitionsPendingForSMJ().isEmpty()) {
            return;
        }
        table.releaseMemoryCacheForSMJ();
        LOG.info("Fallback to sort merge join to process spilled partitions.");
        initialSortMergeJoinFunction();
        fallbackSMJ = true;

        for (BinaryHashPartition partition : table.getPartitionsPendingForSMJ()) {
            RowIterator<BinaryRowData> buildIter =
                    table.getSpilledPartitionBuildSideIter(partition);
            while (buildIter.advanceNext()) {
                processSortMergeJoinElement1(buildIter.getRow());
            }

            ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(partition);
            BinaryRowData probeRow;
            // next() returns null once the spilled probe rows are exhausted; stop there instead of
            // spinning forever.
            while ((probeRow = probeIter.next()) != null) {
                processSortMergeJoinElement2(probeRow);
            }
        }

        closeHashTable();
        sortMergeJoinFunction.endInput(1);
        sortMergeJoinFunction.endInput(2);
        LOG.info("Finish sort merge join for spilled partitions.");
    }

    /** Opens the sort-merge join function, writing to this operator's collector. */
    private void initialSortMergeJoinFunction() throws Exception {
        sortMergeJoinFunction.open(
                true,
                getContainingTask(),
                getOperatorConfig(),
                (StreamRecordCollector) collector,
                computeMemorySize(),
                getRuntimeContext(),
                getMetricGroup());
    }

    /** Feeds a build-side row to whichever SMJ input corresponds to the build side. */
    private void processSortMergeJoinElement1(RowData buildRow) throws Exception {
        if (leftIsBuild) {
            sortMergeJoinFunction.processElement1(buildRow);
        } else {
            sortMergeJoinFunction.processElement2(buildRow);
        }
    }

    /** Feeds a probe-side row to whichever SMJ input corresponds to the probe side. */
    private void processSortMergeJoinElement2(RowData probeRow) throws Exception {
        if (leftIsBuild) {
            sortMergeJoinFunction.processElement2(probeRow);
        } else {
            sortMergeJoinFunction.processElement1(probeRow);
        }
    }

    /**
     * Creates the hash join operator implementing {@code type}.
     *
     * @throws IllegalArgumentException if {@code type} has no hash join implementation
     */
    public static HashJoinOperator newHashJoinOperator(
            HashJoinType type,
            boolean leftIsBuild,
            GeneratedJoinCondition condFuncCode,
            boolean reverseJoinFunction,
            boolean[] filterNullKeys,
            GeneratedProjection buildProjectionCode,
            GeneratedProjection probeProjectionCode,
            boolean tryDistinctBuildRow,
            int buildRowSize,
            long buildRowCount,
            long probeRowCount,
            RowType keyType,
            SortMergeJoinFunction sortMergeJoinFunction) {
        HashJoinParameter parameter =
                new HashJoinParameter(
                        type,
                        leftIsBuild,
                        condFuncCode,
                        reverseJoinFunction,
                        filterNullKeys,
                        buildProjectionCode,
                        probeProjectionCode,
                        tryDistinctBuildRow,
                        buildRowSize,
                        buildRowCount,
                        probeRowCount,
                        keyType,
                        sortMergeJoinFunction);
        switch (type) {
            case INNER:
                return new InnerHashJoinOperator(parameter);
            case BUILD_OUTER:
                return new BuildOuterHashJoinOperator(parameter);
            case PROBE_OUTER:
                return new ProbeOuterHashJoinOperator(parameter);
            case FULL_OUTER:
                return new FullOuterHashJoinOperator(parameter);
            case SEMI:
                return new SemiHashJoinOperator(parameter);
            case ANTI:
                return new AntiHashJoinOperator(parameter);
            case BUILD_LEFT_SEMI:
            case BUILD_LEFT_ANTI:
                return new BuildLeftSemiOrAntiHashJoinOperator(parameter);
            default:
                throw new IllegalArgumentException("invalid: " + type);
        }
    }
}
