/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.resource.batch.managedmem;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.plan.nodes.exec.BatchExecNode;
import org.apache.flink.table.plan.nodes.exec.batch.BatchExecNodeVisitorImpl;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecBoundedStreamScan;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecCalc;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecCorrelate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecExchange;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecExpand;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashAggregateBase;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashJoinBase;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashWindowAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecHashWindowAggregateBase;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLimit;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLocalHashAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLocalHashWindowAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLocalSortAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLocalSortWindowAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecNestedLoopJoinBase;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecOverAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSort;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSortAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSortLimit;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSortMergeJoinBase;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSortWindowAggregate;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecTableSourceScan;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecTemporalTableJoin;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecUnion;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecValues;
import org.apache.flink.table.util.NodeResourceUtil;

/**
 * Visitor that walks a tree of batch exec nodes and stamps a managed-memory
 * triple onto each node's resource, using values looked up from the table
 * {@link Configuration} through {@link NodeResourceUtil}.
 *
 * <p>The three arguments handed to {@code setManagedMem} are named
 * reserved / preferred / max throughout this class, following the
 * {@code NodeResourceUtil} getter names they are read from.
 *
 * <p>Every {@code visit} overload first descends into the node's inputs and
 * only then touches the node itself. Exchange and union nodes are pure
 * pass-throughs: their inputs are visited but nothing is set on them here.
 */
public class BatchManagedMemCalculatorOnConfig extends BatchExecNodeVisitorImpl {

    /** Configuration that every managed-memory lookup is performed against. */
    private final Configuration tableConf;

    /**
     * Creates a calculator bound to the given configuration.
     *
     * @param tableConf configuration passed to every {@link NodeResourceUtil} lookup
     */
    public BatchManagedMemCalculatorOnConfig(Configuration tableConf) {
        this.tableConf = tableConf;
    }

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /**
     * Visits the inputs of {@code batchExecNode} and then sets its managed
     * memory to {@code (0, 0, 0)}.
     */
    private void calculateNoManagedMem(BatchExecNode<?> batchExecNode) {
        super.visitInputs(batchExecNode);
        batchExecNode.getResource().setManagedMem(0, 0, 0);
    }

    /**
     * Handles a hash aggregate. An aggregate with an empty grouping gets zero
     * managed memory; any other gets the hash-agg values from the configuration.
     */
    private void calculateHashAgg(BatchExecHashAggregateBase hashAgg) {
        if (hashAgg.getGrouping().length != 0) {
            super.visitInputs(hashAgg);
            int reserved = NodeResourceUtil.getHashAggManagedMemory(tableConf);
            int preferred = NodeResourceUtil.getHashAggManagedPreferredMemory(tableConf);
            int max = NodeResourceUtil.getHashAggManagedMaxMemory(tableConf);
            hashAgg.getResource().setManagedMem(reserved, preferred, max);
        } else {
            calculateNoManagedMem(hashAgg);
        }
    }

    /**
     * Handles a hash window aggregate: always assigns the hash-agg values from
     * the configuration (no empty-grouping special case is applied here).
     */
    private void calculateHashWindowAgg(BatchExecHashWindowAggregateBase hashWindowAgg) {
        super.visitInputs(hashWindowAgg);
        int reserved = NodeResourceUtil.getHashAggManagedMemory(tableConf);
        int preferred = NodeResourceUtil.getHashAggManagedPreferredMemory(tableConf);
        int max = NodeResourceUtil.getHashAggManagedMaxMemory(tableConf);
        hashWindowAgg.getResource().setManagedMem(reserved, preferred, max);
    }

    /** Returns {@code true} as soon as any element of {@code flags} is {@code true}. */
    private static boolean containsTrue(boolean[] flags) {
        for (boolean flag : flags) {
            if (flag) {
                return true;
            }
        }
        return false;
    }

    // ------------------------------------------------------------------
    // Pass-through nodes: only the inputs are visited
    // ------------------------------------------------------------------

    @Override
    public void visit(BatchExecExchange exchange) {
        super.visitInputs(exchange);
    }

    @Override
    public void visit(BatchExecUnion union) {
        super.visitInputs(union);
    }

    // ------------------------------------------------------------------
    // Nodes that always receive zero managed memory
    // ------------------------------------------------------------------

    @Override
    public void visit(BatchExecBoundedStreamScan boundedStreamScan) {
        calculateNoManagedMem(boundedStreamScan);
    }

    @Override
    public void visit(BatchExecTableSourceScan scanTableSource) {
        calculateNoManagedMem(scanTableSource);
    }

    @Override
    public void visit(BatchExecValues values) {
        calculateNoManagedMem(values);
    }

    @Override
    public void visit(BatchExecCalc calc) {
        calculateNoManagedMem(calc);
    }

    @Override
    public void visit(BatchExecCorrelate correlate) {
        calculateNoManagedMem(correlate);
    }

    @Override
    public void visit(BatchExecExpand expand) {
        calculateNoManagedMem(expand);
    }

    @Override
    public void visit(BatchExecSortAggregate sortAggregate) {
        calculateNoManagedMem(sortAggregate);
    }

    @Override
    public void visit(BatchExecLocalSortAggregate localSortAggregate) {
        calculateNoManagedMem(localSortAggregate);
    }

    @Override
    public void visit(BatchExecSortWindowAggregate sortAggregate) {
        calculateNoManagedMem(sortAggregate);
    }

    @Override
    public void visit(BatchExecLocalSortWindowAggregate localSortAggregate) {
        calculateNoManagedMem(localSortAggregate);
    }

    @Override
    public void visit(BatchExecLimit limit) {
        calculateNoManagedMem(limit);
    }

    @Override
    public void visit(BatchExecSortLimit sortLimit) {
        calculateNoManagedMem(sortLimit);
    }

    @Override
    public void visit(BatchExecRank rank) {
        calculateNoManagedMem(rank);
    }

    @Override
    public void visit(BatchExecTemporalTableJoin joinTable) {
        calculateNoManagedMem(joinTable);
    }

    // ------------------------------------------------------------------
    // Hash aggregates
    // ------------------------------------------------------------------

    @Override
    public void visit(BatchExecHashAggregate hashAggregate) {
        calculateHashAgg(hashAggregate);
    }

    @Override
    public void visit(BatchExecLocalHashAggregate localHashAggregate) {
        calculateHashAgg(localHashAggregate);
    }

    @Override
    public void visit(BatchExecHashWindowAggregate hashAggregate) {
        calculateHashWindowAgg(hashAggregate);
    }

    @Override
    public void visit(BatchExecLocalHashWindowAggregate localHashAggregate) {
        calculateHashWindowAgg(localHashAggregate);
    }

    // ------------------------------------------------------------------
    // Joins
    // ------------------------------------------------------------------

    /** Assigns the hash-join-table values from the configuration. */
    @Override
    public void visit(BatchExecHashJoinBase hashJoin) {
        super.visitInputs(hashJoin);
        int reserved = NodeResourceUtil.getHashJoinTableManagedMemory(tableConf);
        int preferred = NodeResourceUtil.getHashJoinTableManagedPreferredMemory(tableConf);
        int max = NodeResourceUtil.getHashJoinTableManagedMaxMemory(tableConf);
        hashJoin.getResource().setManagedMem(reserved, preferred, max);
    }

    /**
     * Each of the three values is twice the corresponding sort-buffer setting
     * plus a fixed external-buffer share, where that share is the configured
     * external-buffer value multiplied by the join's external buffer count.
     */
    @Override
    public void visit(BatchExecSortMergeJoinBase sortMergeJoin) {
        super.visitInputs(sortMergeJoin);
        // The external-buffer share is identical across reserved / preferred / max.
        int externalBufferTotal =
                NodeResourceUtil.getExternalBufferManagedMemory(tableConf)
                        * sortMergeJoin.getExternalBufferNum();
        int reserved =
                2 * NodeResourceUtil.getSortBufferManagedMemory(tableConf) + externalBufferTotal;
        int preferred =
                2 * NodeResourceUtil.getSortBufferManagedPreferredMemory(tableConf)
                        + externalBufferTotal;
        int max =
                2 * NodeResourceUtil.getSortBufferManagedMaxMemory(tableConf) + externalBufferTotal;
        sortMergeJoin.getResource().setManagedMem(reserved, preferred, max);
    }

    /**
     * A single-row join gets zero managed memory; any other nested loop join
     * gets the configured external-buffer value for all three settings.
     */
    @Override
    public void visit(BatchExecNestedLoopJoinBase nestedLoopJoin) {
        if (nestedLoopJoin.singleRowJoin()) {
            calculateNoManagedMem(nestedLoopJoin);
            return;
        }
        super.visitInputs(nestedLoopJoin);
        int bufferMem = NodeResourceUtil.getExternalBufferManagedMemory(tableConf);
        nestedLoopJoin.getResource().setManagedMem(bufferMem, bufferMem, bufferMem);
    }

    // ------------------------------------------------------------------
    // Over aggregate and sort
    // ------------------------------------------------------------------

    /**
     * Inspects the first component of the pair returned by
     * {@code needBufferDataToNeedResetAcc()}. If none of its flags is set the
     * node gets zero managed memory; otherwise it gets the configured
     * external-buffer value for all three settings.
     */
    @Override
    public void visit(BatchExecOverAggregate overWindowAgg) {
        // The flags are read before any input is visited, as in the original ordering.
        boolean[] needBufferFlags = (boolean[]) overWindowAgg.needBufferDataToNeedResetAcc()._1;
        if (!containsTrue(needBufferFlags)) {
            calculateNoManagedMem(overWindowAgg);
            return;
        }
        super.visitInputs(overWindowAgg);
        int bufferMem = NodeResourceUtil.getExternalBufferManagedMemory(tableConf);
        overWindowAgg.getResource().setManagedMem(bufferMem, bufferMem, bufferMem);
    }

    /** Assigns the sort-buffer values from the configuration. */
    @Override
    public void visit(BatchExecSort sort) {
        super.visitInputs(sort);
        int reserved = NodeResourceUtil.getSortBufferManagedMemory(tableConf);
        int preferred = NodeResourceUtil.getSortBufferManagedPreferredMemory(tableConf);
        int max = NodeResourceUtil.getSortBufferManagedMaxMemory(tableConf);
        sort.getResource().setManagedMem(reserved, preferred, max);
    }
}

