From b32930de5e17b9e15c43a2b83f3a27f9aa40bfc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=BCleyman=20Melih=20Portakal?= <56885279+smelihportakal@users.noreply.github.com> Date: Wed, 22 Apr 2026 22:53:03 +0200 Subject: [PATCH] [SYSTEMDS-3941] Add new rewrites This patch adds two algebraic simplification rewrites. First, it pushes scalar multiplication through rowSums and colSums, rewriting rowSums(a*A) to a*rowSums(A) and colSums(a*A) to a*colSums(A). Second, it simplifies sum(matrix(a, rows=b, cols=c)) to a*b*c when the matrix dimensions are known. The patch also adds rewrite tests for rowSums, colSums, and constant-value matrix sums with rewrites enabled and disabled to validate correctness against the reference outputs. --- ...RewriteAlgebraicSimplificationDynamic.java | 26 +++++ .../RewriteAlgebraicSimplificationStatic.java | 54 ++++++++++ .../rewrite/RewriteFusedRandTest.java | 2 +- .../RewritePushdownColSumBinaryMultTest.java | 100 ++++++++++++++++++ .../RewritePushdownRowSumBinaryMultTest.java | 100 ++++++++++++++++++ .../RewriteSimplifySumConstantMatrixTest.java | 86 +++++++++++++++ .../functions/rewrite/RewriteFusedRandLit.dml | 2 +- .../rewrite/RewritePushdownColSumBinaryMult.R | 24 +++++ .../RewritePushdownColSumBinaryMult.dml | 25 +++++ .../RewritePushdownColSumBinaryMult2.R | 24 +++++ .../RewritePushdownColSumBinaryMult2.dml | 25 +++++ .../rewrite/RewritePushdownRowSumBinaryMult.R | 24 +++++ .../RewritePushdownRowSumBinaryMult.dml | 25 +++++ .../RewritePushdownRowSumBinaryMult2.R | 24 +++++ .../RewritePushdownRowSumBinaryMult2.dml | 25 +++++ .../RewriteSimplifySumConstantMatrix.R | 28 +++++ .../RewriteSimplifySumConstantMatrix.dml | 23 ++++ 17 files changed, 615 insertions(+), 2 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R create mode 100644 src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index eb51348a8e3..79b6d8a39bd 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -184,6 +184,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector hi = simplifyLowerTriExtraction(hop, hi, i); //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri hi = simplifyConstantCumsum(hop, hi, i); //e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n) + hi = simplifySumConstantMatrix(hop, hi, i); //e.g., sum(matrix(a,rows=b,cols=c)) -> a*b*c hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { @@ -1273,6 +1274,31 @@ private static Hop simplifyConstantCumsum(Hop parent, Hop hi, int pos) { } return hi; } + + private static Hop simplifySumConstantMatrix(Hop parent, Hop hi, int pos) { + //pattern: sum(matrix(a, rows=b, cols=c)) -> a*b*c + if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) + && HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(0)) + && hi.getInput(0).dimsKnown() + && hi.getInput(0).getDim1() >= 1 + && hi.getInput(0).getDim2() >= 1 + && hi.getInput(0).getParent().size() == 1 ) + { + DataGenOp datagen = (DataGenOp) hi.getInput(0); + Hop constVal = datagen.getConstantValue(); + Hop rows = new LiteralOp(datagen.getDim1()); + Hop cols = new LiteralOp(datagen.getDim2()); + + Hop hnew = HopRewriteUtils.createBinary( + HopRewriteUtils.createBinary(constVal, rows, OpOp2.MULT), cols, OpOp2.MULT); + HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); + HopRewriteUtils.cleanupUnreferenced(hi, datagen); + + hi = hnew; + LOG.debug("Applied simplifySumConstantMatrix (line "+hi.getBeginLine()+")."); + } + return hi; + } private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 2ae15502575..b014ab7920c 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -170,6 +170,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = pushdownDetMultOperation(hop, hi, i); //e.g., det(X%*%Y) -> det(X)*det(Y) hi = pushdownDetScalarMatrixMultOperation(hop, hi, i); //e.g., det(lambda*X) -> lambda^nrow(X)*det(X) hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lambda*X) -> lambda*sum(X) + hi = pushdownRowSumBinaryMult(hop, hi, i); //e.g., rowSums(lambda*X) -> lambda*rowSums(X) + hi = pushdownColSumBinaryMult(hop, hi, i); //e.g., colSums(lambda*X) -> lambda*colSums(X) hi = pullupAbs(hop, hi, i); //e.g., abs(X)*abs(Y) --> abs(X*Y) hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); @@ -1447,6 +1449,58 @@ private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) { return hi; } + private static Hop pushdownRowSumBinaryMult(Hop parent, Hop hi, int pos ) { + //pattern: rowSums(lamda*X) -> lamda*rowSums(X) + if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.Row + && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the rowSums + && HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT, 1) + && ((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX) + ||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR))) + { + Hop operand1 = hi.getInput(0).getInput(0); + Hop operand2 = hi.getInput(0).getInput(1); + + //check which operand is the Scalar and which is the matrix + Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; + Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; + + AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Row); + Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); + + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); + + LOG.debug("Applied pushdownRowSumBinaryMult (line "+hi.getBeginLine()+")."); + return bop; + } + return hi; + } + + private static Hop pushdownColSumBinaryMult(Hop parent, Hop hi, int pos ) { + //pattern: colSums(lamda*X) -> lamda*colSums(X) + if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.Col + && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the colSums + && HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT, 1) + && ((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX) + ||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR))) + { + Hop operand1 = hi.getInput(0).getInput(0); + Hop operand2 = hi.getInput(0).getInput(1); + + //check which operand is the Scalar and which is the matrix + Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; + Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; + + AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Col); + Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); + + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); + + LOG.debug("Applied pushdownColSumBinaryMult (line "+hi.getBeginLine()+")."); + return bop; + } + return hi; + } + private static Hop pullupAbs(Hop parent, Hop hi, int pos ) { if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java index ef580848fba..e840accd068 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java @@ -121,7 +121,7 @@ private void testRewriteFusedRand( String testname, String pdf, boolean rewrites //compare matrices Double ret = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1)); if( testname.equals(TEST_NAME1) ) - Assert.assertEquals("Wrong result", Double.valueOf(rows), ret); + Assert.assertEquals("Wrong result", Double.valueOf(rows*cols), ret); else if( testname.equals(TEST_NAME2) ) Assert.assertEquals("Wrong result", Double.valueOf(Math.pow(rows*cols, 2)), ret); else if( testname.equals(TEST_NAME3) ) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java new file mode 100644 index 00000000000..eb31f12c3d1 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import java.util.HashMap; + +import org.junit.Test; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; + +public class RewritePushdownColSumBinaryMultTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewritePushdownColSumBinaryMult"; + private static final String TEST_NAME2 = "RewritePushdownColSumBinaryMult2"; + + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownColSumBinaryMultTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" })); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" })); + } + + @Test + public void testPushdownColSumBinaryMultNoRewrite() { + testRewritePushdownColSumBinaryMult(TEST_NAME1, false); + } + + @Test + public void testPushdownColSumBinaryMultRewrite() { + testRewritePushdownColSumBinaryMult(TEST_NAME1, true); + } + + @Test + public void testPushdownColSumBinaryMultNoRewrite2() { + testRewritePushdownColSumBinaryMult(TEST_NAME2, false); + } + + @Test + public void testPushdownColSumBinaryMultRewrite2() { + testRewritePushdownColSumBinaryMult(TEST_NAME2, true); + } + + private void testRewritePushdownColSumBinaryMult(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-stats", "-args", output("R") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", "R"); + + if(rewrites) + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("n*")); + else + Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("*")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java new file mode 100644 index 00000000000..cfa18ee335d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import java.util.HashMap; + +import org.junit.Test; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; + +public class RewritePushdownRowSumBinaryMultTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewritePushdownRowSumBinaryMult"; + private static final String TEST_NAME2 = "RewritePushdownRowSumBinaryMult2"; + + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRowSumBinaryMultTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" })); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" })); + } + + @Test + public void testPushdownRowSumBinaryMultNoRewrite() { + testRewritePushdownRowSumBinaryMult(TEST_NAME1, false); + } + + @Test + public void testPushdownRowSumBinaryMultRewrite() { + testRewritePushdownRowSumBinaryMult(TEST_NAME1, true); + } + + @Test + public void testPushdownRowSumBinaryMultNoRewrite2() { + testRewritePushdownRowSumBinaryMult(TEST_NAME2, false); + } + + @Test + public void testPushdownRowSumBinaryMultRewrite2() { + testRewritePushdownRowSumBinaryMult(TEST_NAME2, true); + } + + private void testRewritePushdownRowSumBinaryMult(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-stats", "-args", output("R") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", "R"); + + if(rewrites) + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("n*")); + else + Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("*")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java new file mode 100644 index 00000000000..9530740ff8e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class RewriteSimplifySumConstantMatrixTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "RewriteSimplifySumConstantMatrix"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifySumConstantMatrixTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" })); + } + + @Test + public void testSimplifySumConstantMatrixNoRewritePositive() { + testRewriteSimplifySumConstantMatrix(2.5, 7, 11, false); + } + + @Test + public void testSimplifySumConstantMatrixRewritePositive() { + testRewriteSimplifySumConstantMatrix(2.5, 7, 11, true); + } + + + private void testRewriteSimplifySumConstantMatrix(double value, long rows, long cols, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] { + "-stats", "-args", + String.valueOf(value), String.valueOf(rows), String.valueOf(cols), output("R") + }; + + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(String.valueOf(value), String.valueOf(rows), String.valueOf(cols), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + double actual = readDMLScalarFromOutputDir("R").get(new CellIndex(1, 1)); + double expected = readRScalarFromExpectedDir("R").get(new CellIndex(1, 1)); + Assert.assertEquals(expected, actual, 1e-15); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString("rand")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml index ab00f047727..2e97afdba9b 100644 --- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml +++ b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml @@ -25,5 +25,5 @@ while(FALSE){} #prevent cse X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; -R = as.matrix(sum(rowSums(X1)==rowSums(X2))); +R = as.matrix(sum(abs(X1)==abs(X2))); write(R, $5); \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R new file mode 100644 index 00000000000..8813cc5a504 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +library("Matrix") +X=matrix(1, 100, 1) %*% t(seq(1,100)) +R=matrix(2*colSums(3*X), nrow=1) +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml new file mode 100644 index 00000000000..6494dc45e4a --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X=matrix(1, 100, 1) %*% t(seq(1,100)) +while(FALSE){} +R=2*colSums(3*X) +write(R, $1) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R new file mode 100644 index 00000000000..a951ea42c4f --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +library("Matrix") +X=matrix(1, 100, 1) %*% t(seq(1,100)) +R=matrix(2*colSums(X*3), nrow=1) +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml new file mode 100644 index 00000000000..c69492b5713 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X=matrix(1, 100, 1) %*% t(seq(1,100)) +while(FALSE){} +R=2*colSums(X*3) +write(R, $1) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R new file mode 100644 index 00000000000..2b9f7b4b161 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +library("Matrix") +X=matrix(1, 100, 1) %*% t(seq(1,100)) +R=matrix(2*rowSums(3*X), ncol=1) +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml new file mode 100644 index 00000000000..31f5e0bd1e4 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X=matrix(1, 100, 1) %*% t(seq(1,100)) +while(FALSE){} +R=2*rowSums(3*X) +write(R, $1) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R new file mode 100644 index 00000000000..782e2fa687e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +library("Matrix") +X=matrix(1, 100, 1) %*% t(seq(1,100)) +R=matrix(2*rowSums(X*3), ncol=1) +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml new file mode 100644 index 00000000000..d579df77582 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X=matrix(1, 100, 1) %*% t(seq(1,100)) +while(FALSE){} +R=2*rowSums(X*3) +write(R, $1) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R new file mode 100644 index 00000000000..d6c07b4baf0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +value <- as.numeric(args[1]) +rows <- as.integer(args[2]) +cols <- as.integer(args[3]) + +write(sum(matrix(value, nrow=rows, ncol=cols)), paste(args[4], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml new file mode 100644 index 00000000000..0b54eeb12eb --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml @@ -0,0 +1,23 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +R = sum(matrix($1, rows=$2, cols=$3)) +write(R, $4)