diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala
index 26baa44f43d82..98205d0f8fc29 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala
@@ -432,9 +432,7 @@ class ScalarFunctionSplitter(
callFinder: RemoteCallFinder)
extends RexDefaultVisitor[RexNode] {
- private var fieldsRexCall: Map[Int, Int] = Map[Int, Int]()
-
- private val extractedRexNodeRefs: mutable.HashSet[RexNode] = mutable.HashSet[RexNode]()
+ private val extractedRexNodeToIndex = mutable.HashMap.empty[RexNode, Int]
override def visitCall(call: RexCall): RexNode = {
if (needConvert(call)) {
@@ -474,30 +472,33 @@ class ScalarFunctionSplitter(
private def convertInputRefToLocalRefIfNecessary(node: RexNode): RexNode = {
node match {
- case inputRef: RexInputRef if extractedRexNodeRefs.contains(node) =>
+ case inputRef: RexInputRef if inputRef.getIndex >= extractedFunctionOffset =>
new RexLocalRef(inputRef.getIndex, node.getType)
case _ => node
}
}
private def getExtractedRexNode(node: RexNode): RexNode = {
- val newNode = new RexInputRef(extractedFunctionOffset + extractedRexNodes.length, node.getType)
- extractedRexNodes.append(node)
- extractedRexNodeRefs.add(newNode)
- newNode
+ new RexInputRef(extractedFunctionOffset + getExtractedRexNodeIndex(node), node.getType)
}
private def getExtractedRexFieldAccess(node: RexFieldAccess, rexCallIndex: Int): RexNode = {
val remoteCall: RexCall =
program.expandLocalRef(node.getReferenceExpr.asInstanceOf[RexLocalRef]).asInstanceOf[RexCall]
- if (!fieldsRexCall.contains(rexCallIndex)) {
- extractedRexNodes.append(remoteCall)
- fieldsRexCall += rexCallIndex -> (extractedFunctionOffset + extractedRexNodes.length - 1)
- }
rexBuilder.makeFieldAccess(
- new RexInputRef(fieldsRexCall(rexCallIndex), remoteCall.getType),
+ new RexInputRef(
+ extractedFunctionOffset + getExtractedRexNodeIndex(remoteCall),
+ remoteCall.getType),
node.getField.getIndex)
}
+
+ private def getExtractedRexNodeIndex(node: RexNode): Int = {
+ extractedRexNodeToIndex.getOrElseUpdate(
+ node, {
+ extractedRexNodes.append(node)
+ extractedRexNodes.length - 1
+ })
+ }
}
/**
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml
index 726b02847c8a5..d7c7cbc1adec8 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.xml
@@ -716,12 +716,10 @@ LogicalProject(EXPR$0=[func1(func1($0))], EXPR$1=[func1(func1(func1($0)))], EXPR
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml
index 14a1873bf5184..4e94a33aa7d7d 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.xml
@@ -520,6 +520,24 @@ LogicalProject(a=[$0], EXPR$1=[pyFunc1($0, $2)], b=[$1])
FlinkLogicalCalc(select=[a, f0 AS EXPR$1, b])
+- FlinkLogicalCalc(select=[a, b, pyFunc1(a, c) AS f0])
+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, d])
+]]>
+
+
+
+
+
+
+
+
+
+
+
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala
index c4f989d244316..b796345e4a4e3 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRuleTest.scala
@@ -120,6 +120,12 @@ class PythonCalcSplitRuleTest extends TableTestBase {
util.verifyRelPlan(sqlQuery)
}
+ @Test
+ def testSamePythonFunctionUsedInMultipleProjectionExpressions(): Unit = {
+ val sqlQuery = "SELECT pyFunc1(a, c) + 1, pyFunc1(a, c) + 2 FROM MyTable"
+ util.verifyRelPlan(sqlQuery)
+ }
+
@Test
def testReorderPythonCalc(): Unit = {
val sqlQuery = "SELECT a, pyFunc1(a, c), b FROM MyTable"