diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala index 26baa44f43d82..98205d0f8fc29 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala @@ -432,9 +432,7 @@ class ScalarFunctionSplitter( callFinder: RemoteCallFinder) extends RexDefaultVisitor[RexNode] { - private var fieldsRexCall: Map[Int, Int] = Map[Int, Int]() - - private val extractedRexNodeRefs: mutable.HashSet[RexNode] = mutable.HashSet[RexNode]() + private val extractedRexNodeToIndex = mutable.HashMap.empty[RexNode, Int] override def visitCall(call: RexCall): RexNode = { if (needConvert(call)) { @@ -474,30 +472,33 @@ class ScalarFunctionSplitter( private def convertInputRefToLocalRefIfNecessary(node: RexNode): RexNode = { node match { - case inputRef: RexInputRef if extractedRexNodeRefs.contains(node) => + case inputRef: RexInputRef if inputRef.getIndex >= extractedFunctionOffset => new RexLocalRef(inputRef.getIndex, node.getType) case _ => node } } private def getExtractedRexNode(node: RexNode): RexNode = { - val newNode = new RexInputRef(extractedFunctionOffset + extractedRexNodes.length, node.getType) - extractedRexNodes.append(node) - extractedRexNodeRefs.add(newNode) - newNode + new RexInputRef(extractedFunctionOffset + getExtractedRexNodeIndex(node), node.getType) } private def getExtractedRexFieldAccess(node: RexFieldAccess, rexCallIndex: Int): RexNode = { val remoteCall: RexCall = program.expandLocalRef(node.getReferenceExpr.asInstanceOf[RexLocalRef]).asInstanceOf[RexCall] - if (!fieldsRexCall.contains(rexCallIndex)) { - extractedRexNodes.append(remoteCall) - fieldsRexCall += rexCallIndex -> (extractedFunctionOffset + extractedRexNodes.length - 1) - } rexBuilder.makeFieldAccess( - new RexInputRef(fieldsRexCall(rexCallIndex), remoteCall.getType), + new RexInputRef( + extractedFunctionOffset + getExtractedRexNodeIndex(remoteCall), + remoteCall.getType), node.getField.getIndex) } + + private def getExtractedRexNodeIndex(node: RexNode): Int = { + extractedRexNodeToIndex.getOrElseUpdate( + node, { + extractedRexNodes.append(node) + extractedRexNodes.length - 1 + }) + } } /** diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml index 726b02847c8a5..d7c7cbc1adec8 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml @@ -716,12 +716,10 @@ LogicalProject(EXPR$0=[func1(func1($0))], EXPR$1=[func1(func1(func1($0)))], EXPR diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml index 14a1873bf5184..4e94a33aa7d7d 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml @@ -520,6 +520,24 @@ LogicalProject(a=[$0], EXPR$1=[pyFunc1($0, $2)], b=[$1]) FlinkLogicalCalc(select=[a, f0 AS EXPR$1, b]) +- FlinkLogicalCalc(select=[a, b, pyFunc1(a, c) AS f0]) +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, d]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala index c4f989d244316..b796345e4a4e3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala @@ -120,6 +120,12 @@ class PythonCalcSplitRuleTest extends TableTestBase { util.verifyRelPlan(sqlQuery) } + @Test + def testSamePythonFunctionUsedInMultipleProjectionExpressions(): Unit = { + val sqlQuery = "SELECT pyFunc1(a, c) + 1, pyFunc1(a, c) + 2 FROM MyTable" + util.verifyRelPlan(sqlQuery) + } + @Test def testReorderPythonCalc(): Unit = { val sqlQuery = "SELECT a, pyFunc1(a, c), b FROM MyTable"