/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.plan.rules.logical;

import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCorrelate;
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRel;
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.plan.rules.logical.InputRefRewriter;
import org.apache.flink.table.plan.util.CorrelateUtil;
import org.apache.flink.table.plan.util.PythonUtil;
import scala.Option;

/**
 * Planner rule that pulls the filter condition of the Calc on the right input of an INNER
 * {@link FlinkLogicalCorrelate} above the correlate, when that Calc sits on a Python table
 * function scan.
 *
 * <p>Before:
 * <pre>
 *   Correlate(INNER)
 *     +- left input
 *     +- Calc(condition = c)
 *          +- TableFunctionScan(Python table function)
 * </pre>
 *
 * <p>After:
 * <pre>
 *   Calc(condition = c', identity projection)
 *     +- Correlate(INNER)
 *          +- left input
 *          +- TableFunctionScan(Python table function)
 * </pre>
 *
 * <p>Here {@code c'} is {@code c} with its input references rebased onto the correlate's output
 * row (i.e. shifted past the left input's fields).
 *
 * <p>NOTE(review): presumably the intent is that the Python table function ends up directly under
 * the correlate with no interleaved Calc; confirm against the planner program using this rule.
 */
public class CalcPythonCorrelateTransposeRule extends RelOptRule {

    /** Singleton instance of this rule. */
    public static final CalcPythonCorrelateTransposeRule INSTANCE = new CalcPythonCorrelateTransposeRule();

    private CalcPythonCorrelateTransposeRule() {
        // Matches Correlate(any FlinkLogicalRel, FlinkLogicalCalc).
        super(
            operand(FlinkLogicalCorrelate.class,
                operand(FlinkLogicalRel.class, any()),
                operand(FlinkLogicalCalc.class, any())),
            "CalcPythonCorrelateTransposeRule");
    }

    /**
     * Returns whether the rewrite applies. All of the following must hold:
     * <ul>
     *   <li>the correlate is an INNER join;</li>
     *   <li>the (merged) right Calc sits on a Python table function scan;</li>
     *   <li>that Calc has a condition;</li>
     *   <li>its projection is the identity.</li>
     * </ul>
     *
     * @param call the rule call holding the matched correlate (ordinal 0) and right Calc (ordinal 2)
     * @return {@code true} if {@link #onMatch} may rewrite the matched subtree
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalCorrelate correlate = call.rel(0);
        // For non-inner (e.g. LEFT) correlates, filtering the right input is not equivalent to
        // filtering the correlate's output, so only INNER correlates are rewritten.
        if (correlate.getJoinType() != JoinRelType.INNER) {
            return false;
        }

        FlinkLogicalCalc right = call.rel(2);
        // Presumably collapses a chain of Calcs above the scan into one; confirm in CorrelateUtil.
        FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
        Option<FlinkLogicalTableFunctionScan> scan = CorrelateUtil.getTableFunctionScan(mergedCalc);
        if (!scan.isDefined() || !PythonUtil.isPythonCall(scan.get().getCall(), null)) {
            return false;
        }

        RexProgram program = mergedCalc.getProgram();
        // onMatch carries over only the condition. A non-identity projection would be dropped,
        // and the rewritten plan's row type would no longer match the original correlate's.
        return program.getCondition() != null && program.projectsOnlyIdentity();
    }

    /**
     * Replaces the matched correlate with a correlate over the bare table function scan.
     * A Calc on top of it applies the (rebased) condition with an identity projection.
     *
     * @param call the rule call holding the matched correlate (ordinal 0) and right Calc (ordinal 2)
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCorrelate correlate = call.rel(0);
        FlinkLogicalCalc right = call.rel(2);
        RexBuilder rexBuilder = call.builder().getRexBuilder();
        FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
        // matches() has already verified that the scan is present.
        FlinkLogicalTableFunctionScan tableScan = CorrelateUtil.getTableFunctionScan(mergedCalc).get();

        List<RexNode> correlateFilters = rebaseConditionOntoCorrelate(correlate, mergedCalc);

        // Same correlate as before, except its right input is now the scan itself.
        FlinkLogicalCorrelate newCorrelate = new FlinkLogicalCorrelate(
            correlate.getCluster(),
            correlate.getTraitSet(),
            correlate.getLeft(),
            tableScan,
            correlate.getCorrelationId(),
            correlate.getRequiredColumns(),
            correlate.getJoinType());

        RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, correlateFilters);

        // An empty builder over the new correlate's row type contributes only its input references,
        // which serve as an identity projection for the top Calc.
        RexProgram identityProgram = new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram();
        RexProgram topProgram = RexProgram.create(
            newCorrelate.getRowType(),
            identityProgram.getExprList(),
            topCalcCondition,
            newCorrelate.getRowType(),
            rexBuilder);

        FlinkLogicalCalc newTopCalc = new FlinkLogicalCalc(
            newCorrelate.getCluster(),
            newCorrelate.getTraitSet(),
            newCorrelate,
            topProgram);

        call.transformTo(newTopCalc);
    }

    /**
     * Splits the merged Calc's condition into conjuncts and rebases each onto the correlate's
     * output row.
     *
     * @param correlate the matched correlate
     * @param mergedCalc the merged right-hand Calc whose condition is non-null
     * @return the rebased conjuncts of {@code mergedCalc}'s condition
     */
    private static List<RexNode> rebaseConditionOntoCorrelate(
            FlinkLogicalCorrelate correlate, FlinkLogicalCalc mergedCalc) {
        RexProgram program = mergedCalc.getProgram();
        // The correlate's row is the left fields followed by the right (Calc) fields, so this
        // difference is the number of left-input fields.
        int leftFieldCount = correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount();
        // NOTE(review): InputRefRewriter is assumed to shift input-ref indices by this offset.
        InputRefRewriter rewriter = new InputRefRewriter(leftFieldCount);
        return RelOptUtil.conjunctions(program.expandLocalRef(program.getCondition()))
            .stream()
            .map(conjunct -> conjunct.accept(rewriter))
            .collect(Collectors.toList());
    }
}

