Discussion:
[01/15] mahout git commit: NO-JIRA Trevors updates
r***@apache.org
2018-09-08 23:35:05 UTC
Permalink
Repository: mahout
Updated Branches:
refs/heads/branch-0.14.0 49ad8cb45 -> 545648f6a


http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java b/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java
new file mode 100644
index 0000000..5a0c660
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.lanczos;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.decomposer.SolverTest;
+import org.apache.mahout.math.solver.EigenDecomposition;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class TestLanczosSolver extends SolverTest {
+ private static final Logger log = LoggerFactory.getLogger(TestLanczosSolver.class);
+
+ private static final double ERROR_TOLERANCE = 0.05;
+
+ @Test
+ public void testEigenvalueCheck() throws Exception {
+ int size = 100;
+ Matrix m = randomHierarchicalSymmetricMatrix(size);
+
+ Vector initialVector = new DenseVector(size);
+ initialVector.assign(1.0 / Math.sqrt(size));
+ LanczosSolver solver = new LanczosSolver();
+ int desiredRank = 80;
+ LanczosState state = new LanczosState(m, desiredRank, initialVector);
+ // set initial vector?
+ solver.solve(state, desiredRank, true);
+
+ EigenDecomposition decomposition = new EigenDecomposition(m);
+ Vector eigenvalues = decomposition.getRealEigenvalues();
+
+ float fractionOfEigensExpectedGood = 0.6f;
+ for (int i = 0; i < fractionOfEigensExpectedGood * desiredRank; i++) {
+ double s = state.getSingularValue(i);
+ double e = eigenvalues.get(i);
+ log.info("{} : L = {}, E = {}", i, s, e);
+ assertTrue("Singular value differs from eigenvalue", Math.abs((s-e)/e) < ERROR_TOLERANCE);
+ Vector v = state.getRightSingularVector(i);
+ Vector v2 = decomposition.getV().viewColumn(i);
+ double error = 1 - Math.abs(v.dot(v2)/(v.norm(2) * v2.norm(2)));
+ log.info("error: {}", error);
+ assertTrue(i + ": 1 - cosAngle = " + error, error < ERROR_TOLERANCE);
+ }
+ }
+
+
+ @Test
+ public void testLanczosSolver() throws Exception {
+ int numRows = 800;
+ int numColumns = 500;
+ Matrix corpus = randomHierarchicalMatrix(numRows, numColumns, false);
+ Vector initialVector = new DenseVector(numColumns);
+ initialVector.assign(1.0 / Math.sqrt(numColumns));
+ int rank = 50;
+ LanczosState state = new LanczosState(corpus, rank, initialVector);
+ LanczosSolver solver = new LanczosSolver();
+ solver.solve(state, rank, false);
+ assertOrthonormal(state);
+ for (int i = 0; i < rank/2; i++) {
+ assertEigen(i, state.getRightSingularVector(i), corpus, ERROR_TOLERANCE, false);
+ }
+ //assertEigen(eigens, corpus, rank / 2, ERROR_TOLERANCE, false);
+ }
+
+ @Test
+ public void testLanczosSolverSymmetric() throws Exception {
+ int numCols = 500;
+ Matrix corpus = randomHierarchicalSymmetricMatrix(numCols);
+ Vector initialVector = new DenseVector(numCols);
+ initialVector.assign(1.0 / Math.sqrt(numCols));
+ int rank = 30;
+ LanczosState state = new LanczosState(corpus, rank, initialVector);
+ LanczosSolver solver = new LanczosSolver();
+ solver.solve(state, rank, true);
+ //assertOrthonormal(state);
+ //assertEigen(state, rank / 2, ERROR_TOLERANCE, true);
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java b/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java
new file mode 100644
index 0000000..b5280b4
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.stat;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.io.CharStreams;
+import com.google.common.io.InputSupplier;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.MahoutTestCase;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.Random;
+
+public final class GammaTest extends MahoutTestCase {
+
+  /**
+   * Checks {@link Gamma#gamma} and {@link Gamma#logGamma} at positive integers, where
+   * gamma(x) = (x - 1)!, against tabulated values and against {@link #gammaInteger}.
+   */
+  @Test
+  public void testGamma() {
+    double[] x = {1, 2, 5, 10, 20, 50, 100};
+    double[] expected = {
+        1.000000e+00, 1.000000e+00, 2.400000e+01, 3.628800e+05, 1.216451e+17, 6.082819e+62, 9.332622e+155
+    };
+
+    for (int i = 0; i < x.length; i++) {
+      assertEquals(expected[i], Gamma.gamma(x[i]), expected[i] * 1.0e-5);
+      assertEquals(gammaInteger(x[i]), Gamma.gamma(x[i]), expected[i] * 1.0e-5);
+      assertEquals(gammaInteger(x[i]), Math.exp(Gamma.logGamma(x[i])), expected[i] * 1.0e-5);
+    }
+  }
+
+  /**
+   * Checks {@link Gamma#gamma} at non-integer arguments (mostly negative, where the result may be
+   * negative). {@link Gamma#logGamma} is compared against the magnitude of the expected value only.
+   */
+  @Test
+  public void testNegativeArgForGamma() {
+    double[] x = {-30.3, -20.7, -10.5, -1.1, 0.5, 0.99, -0.999};
+    double[] expected = {
+        -5.243216e-33, -1.904051e-19, -2.640122e-07, 9.714806e+00, 1.772454e+00, 1.005872e+00, -1.000424e+03
+    };
+
+    for (int i = 0; i < x.length; i++) {
+      assertEquals(expected[i], Gamma.gamma(x[i]), Math.abs(expected[i] * 1.0e-5));
+      assertEquals(Math.abs(expected[i]), Math.abs(Math.exp(Gamma.logGamma(x[i]))), Math.abs(expected[i] * 1.0e-5));
+    }
+  }
+
+  /**
+   * Returns the product 2 * 3 * ... * k for every integer k strictly below {@code x}; for an
+   * integer {@code x >= 1} this is (x - 1)!, i.e. gamma(x).
+   */
+  private static double gammaInteger(double x) {
+    double r = 1;
+    for (int i = 2; i < x; i++) {
+      r *= i;
+    }
+    return r;
+  }
+
+  /**
+   * Checks {@link Gamma#gamma} at larger and near-pole arguments, then checks that
+   * {@link Gamma#logGamma} agrees with {@code log|gamma(x)|} to a relative 1e-8.
+   */
+  @Test
+  public void testBigX() {
+    assertEquals(factorial(4), 4 * 3 * 2, 0);
+    assertEquals(factorial(4), Gamma.gamma(5), 0);
+    assertEquals(factorial(14), Gamma.gamma(15), 0);
+    assertEquals(factorial(34), Gamma.gamma(35), 1.0e-15 * factorial(34));
+    assertEquals(factorial(44), Gamma.gamma(45), 1.0e-15 * factorial(44));
+
+    assertEquals(-6.884137e-40 + 3.508309e-47, Gamma.gamma(-35.1), 1.0e-52);
+    assertEquals(-3.915646e-41 - 3.526813e-48 - 1.172516e-55, Gamma.gamma(-35.9), 1.0e-52);
+    assertEquals(-2000000000.577215, Gamma.gamma(-0.5e-9), 1.0e-15 * 2000000000.577215);
+    assertEquals(1999999999.422784, Gamma.gamma(0.5e-9), 1.0e-15 * 1999999999.422784);
+    assertEquals(1.324296658017984e+252, Gamma.gamma(146.1), 1.0e-10 * 1.324296658017984e+252);
+
+    for (double x : new double[]{5, 15, 35, 45, -35.1, -35.9, -0.5e-9, 0.5e-9, 146.1}) {
+      double ref = Math.log(Math.abs(Gamma.gamma(x)));
+      double actual = Gamma.logGamma(x);
+      // ref is negative whenever |gamma(x)| < 1 (e.g. x = -35.1), so normalize by |ref| to keep
+      // the reported relative difference non-negative.
+      double diff = Math.abs(ref - actual) / Math.abs(ref);
+      assertEquals("gamma versus logGamma at " + x + " (diff = " + diff + ')', 0, diff, 1.0e-8);
+    }
+  }
+
+  /** Returns n! computed as a double (1 for n < 2). */
+  private static double factorial(int n) {
+    double r = 1;
+    for (int i = 2; i <= n; i++) {
+      r *= i;
+    }
+    return r;
+  }
+
+  /**
+   * Compares {@link Gamma#beta} at 200 random points, with both shape parameters drawn as
+   * {@code -50 * log1p(-u)} for uniform {@code u}, against
+   * {@code exp(logGamma(a) + logGamma(b) - logGamma(a + b))} to a relative 1e-10.
+   */
+  @Test
+  public void beta() {
+    Random r = RandomUtils.getRandom();
+    for (int i = 0; i < 200; i++) {
+      double alpha = -50 * Math.log1p(-r.nextDouble());
+      double beta = -50 * Math.log1p(-r.nextDouble());
+      double ref = Math.exp(Gamma.logGamma(alpha) + Gamma.logGamma(beta) - Gamma.logGamma(alpha + beta));
+      double actual = Gamma.beta(alpha, beta);
+      double err = (ref - actual) / ref;
+      assertEquals("beta at (" + alpha + ", " + beta + ") relative error = " + err, 0, err, 1.0e-10);
+    }
+  }
+
+  /**
+   * Compares {@link Gamma#incompleteBeta} against the classpath resource
+   * {@code beta-test-data.csv}, whose first line is a header and whose remaining rows hold
+   * {@code alpha, beta, x, expected}; each value must match to a relative 1e-5.
+   */
+  @Test
+  public void incompleteBeta() throws IOException {
+    Splitter onComma = Splitter.on(",").trimResults();
+
+    // Resources.readLines replaces the InputSupplier-based reader, which Guava removed in 18.0.
+    Iterable<String> rows =
+        Iterables.skip(Resources.readLines(Resources.getResource("beta-test-data.csv"), Charsets.UTF_8), 1);
+    for (String line : rows) {
+      Iterable<String> values = onComma.split(line);
+      double alpha = Double.parseDouble(Iterables.get(values, 0));
+      double beta = Double.parseDouble(Iterables.get(values, 1));
+      double x = Double.parseDouble(Iterables.get(values, 2));
+      double ref = Double.parseDouble(Iterables.get(values, 3));
+      double actual = Gamma.incompleteBeta(alpha, beta, x);
+      assertEquals(alpha + "," + beta + ',' + x, ref, actual, Math.abs(ref) * 1.0e-5);
+    }
+  }
+
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java b/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java
new file mode 100644
index 0000000..47bbb0e
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.stat;
+
+import org.apache.mahout.math.MahoutTestCase;
+import org.junit.Test;
+
+import java.util.Locale;
+
+public final class ProbabilityTest extends MahoutTestCase {
+
+  /**
+   * Compares {@link Probability#normal} (standard form, and shifted/scaled forms that reduce to
+   * the same standardized argument) with R's {@code pnorm} on 100 evenly spaced points in [-5, 5].
+   */
+  @Test
+  public void testNormalCdf() {
+    // computed by R
+    // pnorm(seq(-5,5, length.out=100))
+    double[] ref = {
+        2.866516e-07, 4.816530e-07, 8.013697e-07, 1.320248e-06, 2.153811e-06,
+        3.479323e-06, 5.565743e-06, 8.816559e-06, 1.383023e-05, 2.148428e-05,
+        3.305072e-05, 5.035210e-05, 7.596947e-05, 1.135152e-04, 1.679855e-04,
+        2.462079e-04, 3.574003e-04, 5.138562e-04, 7.317683e-04, 1.032198e-03,
+        1.442193e-03, 1.996034e-03, 2.736602e-03, 3.716808e-03, 5.001037e-03,
+        6.666521e-03, 8.804535e-03, 1.152131e-02, 1.493850e-02, 1.919309e-02,
+        2.443656e-02, 3.083320e-02, 3.855748e-02, 4.779035e-02, 5.871452e-02,
+        7.150870e-02, 8.634102e-02, 1.033618e-01, 1.226957e-01, 1.444345e-01,
+        1.686293e-01, 1.952845e-01, 2.243525e-01, 2.557301e-01, 2.892574e-01,
+        3.247181e-01, 3.618436e-01, 4.003175e-01, 4.397847e-01, 4.798600e-01,
+        5.201400e-01, 5.602153e-01, 5.996825e-01, 6.381564e-01, 6.752819e-01,
+        7.107426e-01, 7.442699e-01, 7.756475e-01, 8.047155e-01, 8.313707e-01,
+        8.555655e-01, 8.773043e-01, 8.966382e-01, 9.136590e-01, 9.284913e-01,
+        9.412855e-01, 9.522096e-01, 9.614425e-01, 9.691668e-01, 9.755634e-01,
+        9.808069e-01, 9.850615e-01, 9.884787e-01, 9.911955e-01, 9.933335e-01,
+        9.949990e-01, 9.962832e-01, 9.972634e-01, 9.980040e-01, 9.985578e-01,
+        9.989678e-01, 9.992682e-01, 9.994861e-01, 9.996426e-01, 9.997538e-01,
+        9.998320e-01, 9.998865e-01, 9.999240e-01, 9.999496e-01, 9.999669e-01,
+        9.999785e-01, 9.999862e-01, 9.999912e-01, 9.999944e-01, 9.999965e-01,
+        9.999978e-01, 9.999987e-01, 9.999992e-01, 9.999995e-01, 9.999997e-01
+    };
+    assertEquals(0.682689492137 / 2 + 0.5, Probability.normal(1), 1.0e-7);
+
+    // Rebuild R's grid seq(-5, 5, length.out = ref.length) from the index instead of repeatedly
+    // adding the step: no rounding error accumulates, and the loop is tied to the table size.
+    double step = 10.0 / (ref.length - 1);
+    for (int i = 0; i < ref.length; i++) {
+      double x = -5 + i * step;
+      assertEquals("Test 1 cdf function at " + x, ref[i], Probability.normal(x), 1.0e-6);
+      assertEquals("Test 2 cdf function at " + x, ref[i], Probability.normal(12, 1, x + 12), 1.0e-6);
+      assertEquals("Test 3 cdf function at " + x, ref[i], Probability.normal(12, 0.25, x / 2.0 + 12), 1.0e-6);
+    }
+  }
+
+  /**
+   * Compares {@link Probability#beta} with R's {@code pbeta} for several shape-parameter pairs on
+   * 100 evenly spaced points in [0, 1].
+   */
+  @Test
+  public void testBetaCdf() {
+    // values computed using:
+    //> pbeta(seq(0, 1, length.out=100), 1, 1)
+    //> pbeta(seq(0, 1, length.out=100), 2, 1)
+    //> pbeta(seq(0, 1, length.out=100), 2, 5)
+    //> pbeta(seq(0, 1, length.out=100), 0.2, 5)
+    //> pbeta(seq(0, 1, length.out=100), 0.2, 0.01)
+
+    double[][] ref = new double[5][];
+
+    ref[0] = new double[]{
+        0.00000000, 0.01010101, 0.02020202, 0.03030303, 0.04040404, 0.05050505,
+        0.06060606, 0.07070707, 0.08080808, 0.09090909, 0.10101010, 0.11111111,
+        0.12121212, 0.13131313, 0.14141414, 0.15151515, 0.16161616, 0.17171717,
+        0.18181818, 0.19191919, 0.20202020, 0.21212121, 0.22222222, 0.23232323,
+        0.24242424, 0.25252525, 0.26262626, 0.27272727, 0.28282828, 0.29292929,
+        0.30303030, 0.31313131, 0.32323232, 0.33333333, 0.34343434, 0.35353535,
+        0.36363636, 0.37373737, 0.38383838, 0.39393939, 0.40404040, 0.41414141,
+        0.42424242, 0.43434343, 0.44444444, 0.45454545, 0.46464646, 0.47474747,
+        0.48484848, 0.49494949, 0.50505051, 0.51515152, 0.52525253, 0.53535354,
+        0.54545455, 0.55555556, 0.56565657, 0.57575758, 0.58585859, 0.59595960,
+        0.60606061, 0.61616162, 0.62626263, 0.63636364, 0.64646465, 0.65656566,
+        0.66666667, 0.67676768, 0.68686869, 0.69696970, 0.70707071, 0.71717172,
+        0.72727273, 0.73737374, 0.74747475, 0.75757576, 0.76767677, 0.77777778,
+        0.78787879, 0.79797980, 0.80808081, 0.81818182, 0.82828283, 0.83838384,
+        0.84848485, 0.85858586, 0.86868687, 0.87878788, 0.88888889, 0.89898990,
+        0.90909091, 0.91919192, 0.92929293, 0.93939394, 0.94949495, 0.95959596,
+        0.96969697, 0.97979798, 0.98989899, 1.00000000
+    };
+    ref[1] = new double[]{
+        0.0000000000, 0.0001020304, 0.0004081216, 0.0009182736, 0.0016324865,
+        0.0025507601, 0.0036730946, 0.0049994898, 0.0065299459, 0.0082644628,
+        0.0102030405, 0.0123456790, 0.0146923783, 0.0172431385, 0.0199979594,
+        0.0229568411, 0.0261197837, 0.0294867871, 0.0330578512, 0.0368329762,
+        0.0408121620, 0.0449954086, 0.0493827160, 0.0539740843, 0.0587695133,
+        0.0637690032, 0.0689725538, 0.0743801653, 0.0799918376, 0.0858075707,
+        0.0918273646, 0.0980512193, 0.1044791348, 0.1111111111, 0.1179471483,
+        0.1249872462, 0.1322314050, 0.1396796245, 0.1473319049, 0.1551882461,
+        0.1632486481, 0.1715131109, 0.1799816345, 0.1886542190, 0.1975308642,
+        0.2066115702, 0.2158963371, 0.2253851648, 0.2350780533, 0.2449750026,
+        0.2550760127, 0.2653810836, 0.2758902153, 0.2866034078, 0.2975206612,
+        0.3086419753, 0.3199673503, 0.3314967860, 0.3432302826, 0.3551678400,
+        0.3673094582, 0.3796551372, 0.3922048771, 0.4049586777, 0.4179165391,
+        0.4310784614, 0.4444444444, 0.4580144883, 0.4717885930, 0.4857667585,
+        0.4999489848, 0.5143352719, 0.5289256198, 0.5437200286, 0.5587184981,
+        0.5739210285, 0.5893276196, 0.6049382716, 0.6207529844, 0.6367717580,
+        0.6529945924, 0.6694214876, 0.6860524436, 0.7028874605, 0.7199265381,
+        0.7371696766, 0.7546168758, 0.7722681359, 0.7901234568, 0.8081828385,
+        0.8264462810, 0.8449137843, 0.8635853484, 0.8824609734, 0.9015406591,
+        0.9208244057, 0.9403122130, 0.9600040812, 0.9799000102, 1.0000000000
+    };
+    ref[2] = new double[]{
+        0.000000000, 0.001489698, 0.005799444, 0.012698382, 0.021966298, 0.033393335,
+        0.046779694, 0.061935356, 0.078679798, 0.096841712, 0.116258735, 0.136777178,
+        0.158251755, 0.180545326, 0.203528637, 0.227080061, 0.251085352, 0.275437393,
+        0.300035957, 0.324787463, 0.349604743, 0.374406809, 0.399118623, 0.423670875,
+        0.447999763, 0.472046772, 0.495758466, 0.519086275, 0.541986291, 0.564419069,
+        0.586349424, 0.607746242, 0.628582288, 0.648834019, 0.668481403, 0.687507740,
+        0.705899486, 0.723646086, 0.740739801, 0.757175549, 0.772950746, 0.788065147,
+        0.802520695, 0.816321377, 0.829473074, 0.841983426, 0.853861691, 0.865118615,
+        0.875766302, 0.885818092, 0.895288433, 0.904192771, 0.912547431, 0.920369513,
+        0.927676778, 0.934487554, 0.940820632, 0.946695177, 0.952130629, 0.957146627,
+        0.961762916, 0.965999275, 0.969875437, 0.973411020, 0.976625460, 0.979537944,
+        0.982167353, 0.984532203, 0.986650598, 0.988540173, 0.990218056, 0.991700827,
+        0.993004475, 0.994144371, 0.995135237, 0.995991117, 0.996725360, 0.997350600,
+        0.997878739, 0.998320942, 0.998687627, 0.998988463, 0.999232371, 0.999427531,
+        0.999581387, 0.999700663, 0.999791377, 0.999858864, 0.999907798, 0.999942219,
+        0.999965567, 0.999980718, 0.999990021, 0.999995342, 0.999998111, 0.999999376,
+        0.999999851, 0.999999980, 0.999999999, 1.000000000
+    };
+    ref[3] = new double[]{
+        0.0000000, 0.5858072, 0.6684658, 0.7201859, 0.7578936, 0.7873991, 0.8114552,
+        0.8316029, 0.8487998, 0.8636849, 0.8767081, 0.8881993, 0.8984080, 0.9075280,
+        0.9157131, 0.9230876, 0.9297536, 0.9357958, 0.9412856, 0.9462835, 0.9508414,
+        0.9550044, 0.9588113, 0.9622963, 0.9654896, 0.9684178, 0.9711044, 0.9735707,
+        0.9758356, 0.9779161, 0.9798276, 0.9815839, 0.9831977, 0.9846805, 0.9860426,
+        0.9872936, 0.9884422, 0.9894965, 0.9904638, 0.9913509, 0.9921638, 0.9929085,
+        0.9935900, 0.9942134, 0.9947832, 0.9953034, 0.9957779, 0.9962104, 0.9966041,
+        0.9969621, 0.9972872, 0.9975821, 0.9978492, 0.9980907, 0.9983088, 0.9985055,
+        0.9986824, 0.9988414, 0.9989839, 0.9991113, 0.9992251, 0.9993265, 0.9994165,
+        0.9994963, 0.9995668, 0.9996288, 0.9996834, 0.9997311, 0.9997727, 0.9998089,
+        0.9998401, 0.9998671, 0.9998901, 0.9999098, 0.9999265, 0.9999406, 0.9999524,
+        0.9999622, 0.9999703, 0.9999769, 0.9999823, 0.9999866, 0.9999900, 0.9999927,
+        0.9999947, 0.9999963, 0.9999975, 0.9999983, 0.9999989, 0.9999993, 0.9999996,
+        0.9999998, 0.9999999, 0.9999999, 1.0000000, 1.0000000, 1.0000000, 1.0000000,
+        1.0000000, 1.0000000
+    };
+    ref[4] = new double[]{
+        0.00000000, 0.01908202, 0.02195656, 0.02385194, 0.02530810, 0.02650923,
+        0.02754205, 0.02845484, 0.02927741, 0.03002959, 0.03072522, 0.03137444,
+        0.03198487, 0.03256240, 0.03311171, 0.03363655, 0.03414001, 0.03462464,
+        0.03509259, 0.03554568, 0.03598550, 0.03641339, 0.03683054, 0.03723799,
+        0.03763667, 0.03802739, 0.03841091, 0.03878787, 0.03915890, 0.03952453,
+        0.03988529, 0.04024162, 0.04059396, 0.04094272, 0.04128827, 0.04163096,
+        0.04197113, 0.04230909, 0.04264515, 0.04297958, 0.04331268, 0.04364471,
+        0.04397592, 0.04430658, 0.04463693, 0.04496722, 0.04529770, 0.04562860,
+        0.04596017, 0.04629265, 0.04662629, 0.04696134, 0.04729804, 0.04763666,
+        0.04797747, 0.04832073, 0.04866673, 0.04901578, 0.04936816, 0.04972422,
+        0.05008428, 0.05044871, 0.05081789, 0.05119222, 0.05157213, 0.05195809,
+        0.05235059, 0.05275018, 0.05315743, 0.05357298, 0.05399753, 0.05443184,
+        0.05487673, 0.05533315, 0.05580212, 0.05628480, 0.05678247, 0.05729660,
+        0.05782885, 0.05838111, 0.05895557, 0.05955475, 0.06018161, 0.06083965,
+        0.06153300, 0.06226670, 0.06304685, 0.06388102, 0.06477877, 0.06575235,
+        0.06681788, 0.06799717, 0.06932077, 0.07083331, 0.07260394, 0.07474824,
+        0.07748243, 0.08129056, 0.08771055, 1.00000000
+    };
+
+    double[] alpha = {1.0, 2.0, 2.0, 0.2, 0.2};
+    double[] beta = {1.0, 1.0, 5.0, 5.0, 0.01};
+    // NOTE(review): only the first four curves are compared; ref[4] (shape1 = 0.2,
+    // shape2 = 0.01) is defined above but never checked. Confirm whether that exclusion is
+    // deliberate before extending this bound to ref.length.
+    int checkedCurves = 4;
+    for (int j = 0; j < checkedCurves; j++) {
+      for (int i = 0; i < ref[j].length; i++) {
+        // Same grid as R's seq(0, 1, length.out = 100).
+        double x = i / (double) (ref[j].length - 1);
+        String p = String.format(Locale.ENGLISH,
+            "pbeta(q=%6.4f, shape1=%5.3f shape2=%5.3f) = %.8f",
+            x, alpha[j], beta[j], ref[j][i]);
+        assertEquals(p, ref[j][i], Probability.beta(alpha[j], beta[j], x), 1.0e-7);
+      }
+    }
+  }
+
+  /** Compares {@link Gamma#logGamma} with reference values at a handful of points. */
+  @Test
+  public void testLogGamma() {
+    double[] xValues = {1.1, 2.1, 3.1, 4.1, 5.1, 20.1, 100.1, -0.9};
+    double[] ref = {
+        -0.04987244, 0.04543774, 0.78737508, 1.91877719, 3.32976417, 39.63719250, 359.59427179, 2.35807317
+    };
+    for (int i = 0; i < xValues.length; i++) {
+      double x = xValues[i];
+      assertEquals("logGamma(" + x + ")", ref[i], Gamma.logGamma(x), 1.0e-7);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java b/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java
new file mode 100644
index 0000000..b8c4624
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+public final class ChineseRestaurantTest extends MahoutTestCase {
+
+  /**
+   * Draws 100 samples from each of 1000 fresh {@code ChineseRestaurant(10)} instances. For each
+   * run it sorts the per-value counts in decreasing order and sums them rank by rank across runs.
+   * The sums at a few ranks are compared against empirically derived values.
+   */
+  @Test
+  public void testDepth() {
+    List<Integer> totals = Lists.newArrayList();
+    for (int i = 0; i < 1000; i++) {
+      ChineseRestaurant x = new ChineseRestaurant(10);
+      Multiset<Integer> counts = HashMultiset.create();
+      for (int j = 0; j < 100; j++) {
+        counts.add(x.sample());
+      }
+      List<Integer> tmp = Lists.newArrayList();
+      for (Integer k : counts.elementSet()) {
+        tmp.add(counts.count(k));
+      }
+      Collections.sort(tmp, Collections.reverseOrder());
+      while (totals.size() < tmp.size()) {
+        totals.add(0);
+      }
+      int j = 0;
+      for (Integer k : tmp) {
+        totals.set(j, totals.get(j) + k);
+        j++;
+      }
+    }
+
+    // these are empirically derived values, not principled ones
+    assertEquals(25000.0, (double) totals.get(0), 1000);
+    assertEquals(24000.0, (double) totals.get(1), 1000);
+    assertEquals(8000.0, (double) totals.get(2), 200);
+    assertEquals(1000.0, (double) totals.get(15), 50);
+    assertEquals(1000.0, (double) totals.get(20), 40);
+  }
+
+  /**
+   * With {@code ChineseRestaurant(100, 1)}, every one of 10000 samples must produce a distinct
+   * value, each seen exactly once.
+   */
+  @Test
+  public void testExtremeDiscount() {
+    ChineseRestaurant x = new ChineseRestaurant(100, 1);
+    for (int i = 0; i < 10000; i++) {
+      x.sample();
+    }
+    assertEquals(10000, x.size());
+    for (int i = 0; i < 10000; i++) {
+      assertEquals(1, x.count(i));
+    }
+  }
+
+  /**
+   * Samples three restaurants with discounts 0.0, 0.5 and 0.9 up to 200000 times. At sample
+   * counts whose leading digits are in {@code splits}, it does one of two things:
+   * <ul>
+   * <li>between 51 and 900 samples, it records (log size, log samples) observations;</li>
+   * <li>past 900 samples, it checks that a least-squares fit of those observations predicts the
+   * current log size.</li>
+   * </ul>
+   * Past 10000 samples it also checks the fraction of values seen exactly once. The
+   * zero-discount restaurant is only checked through that fraction.
+   */
+  @Test
+  public void testGrowth() {
+    ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0);
+    ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5);
+    ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9);
+    Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0);
+
+    int k = 0;
+    Matrix m5 = new DenseMatrix(20, 3);
+    Matrix m9 = new DenseMatrix(20, 3);
+    for (int i = 0; i <= 200000; i++) {
+      // Leading digits of i scaled into [1, 10); skip i = 0, where log10 is -Infinity.
+      if (i > 0 && splits.contains(i / Math.pow(10, Math.floor(Math.log10(i))))) {
+        if (i > 900) {
+          double predict5 = predictSize(m5.viewPart(0, k, 0, 3), i, 0.5);
+          assertEquals(predict5, Math.log(s5.size()), 1);
+
+          double predict9 = predictSize(m9.viewPart(0, k, 0, 3), i, 0.9);
+          assertEquals(predict9, Math.log(s9.size()), 1);
+        } else if (i > 50) {
+          m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1});
+          m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1});
+          k++;
+        }
+        if (i > 10000) {
+          assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25);
+          assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1);
+          assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05);
+        }
+      }
+      s0.sample();
+      s5.sample();
+      s9.sample();
+    }
+  }
+
+  /**
+   * Fits {@code log(size) = c * log(n) + b} to the rows of {@code m} by least squares (normal
+   * equations solved with a QR decomposition). It asserts that the fitted slope {@code c} is
+   * within 0.2 of {@code expectedCoefficient}, and evaluates the fit at {@code currentIndex}.
+   *
+   * @param m observations, one per row: column 0 is the log of the number of distinct values,
+   *     column 1 is the log of the number of samples drawn, column 2 is the constant 1
+   * @param currentIndex number of samples drawn so far, at which the fit is evaluated
+   * @param expectedCoefficient the growth exponent the fitted slope should be close to
+   * @return the predicted log of the number of distinct values after {@code currentIndex} samples
+   */
+  private static double predictSize(Matrix m, int currentIndex, double expectedCoefficient) {
+    int rows = m.rowSize();
+    Matrix a = m.viewPart(0, rows, 1, 2);
+    Matrix b = m.viewPart(0, rows, 0, 1);
+
+    Matrix ata = a.transpose().times(a);
+    Matrix atb = a.transpose().times(b);
+    QRDecomposition s = new QRDecomposition(ata);
+    Matrix r = s.solve(atb).transpose();
+    assertEquals(expectedCoefficient, r.get(0, 0), 0.2);
+    return r.times(new DenseVector(new double[]{Math.log(currentIndex), 1})).get(0);
+  }
+
+  /** Counts how many of the values {@code 0 .. s.size() - 1} have a count of exactly one. */
+  private static int hapaxCount(ChineseRestaurant s) {
+    int r = 0;
+    for (int i = 0; i < s.size(); i++) {
+      if (s.count(i) == 1) {
+        r++;
+      }
+    }
+    return r;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java b/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java
new file mode 100644
index 0000000..120b25f
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java
@@ -0,0 +1,355 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.randomized;
+
+import java.lang.reflect.Field;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.apache.mahout.math.set.AbstractIntSet;
+import org.apache.mahout.math.set.OpenHashSet;
+import org.apache.mahout.math.set.OpenIntHashSet;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.Repeat;
+import com.carrotsearch.randomizedtesting.annotations.Seed;
+
+/**
+ * Randomized tests that drive Mahout's open-addressing collections and a {@code java.util}
+ * reference collection through the same random sequence of operations and require both to
+ * report identical results.
+ */
+public class RandomBlasting extends RandomizedTest {
+  /** Operations applied to both the collection under test and its reference. */
+  private enum Operation {
+    ADD, REMOVE, CLEAR, INDEXOF, ISEMPTY, SIZE
+  }
+
+  @Test
+  @Repeat(iterations = 20)
+  public void testAgainstReferenceOpenObjectIntHashMap() {
+    OpenObjectIntHashMap<Integer> base = new OpenObjectIntHashMap<>();
+    Map<Integer, Integer> reference = new HashMap<>();
+
+    List<Operation> ops = operationMix();
+
+    int max = randomIntBetween(1000, 20000);
+    for (int reps = 0; reps < max; reps++) {
+      // Ensure some collisions among keys.
+      int k = randomIntBetween(0, max / 4);
+      int v = randomInt();
+      switch (randomFrom(ops)) {
+        case ADD:
+          // base.put is expected to return true exactly when the key had no previous mapping.
+          assertEquals(reference.put(k, v) == null, base.put(k, v));
+          break;
+
+        case REMOVE:
+          assertEquals(reference.remove(k) != null, base.removeKey(k));
+          break;
+
+        case INDEXOF:
+          assertEquals(reference.containsKey(k), base.containsKey(k));
+          break;
+
+        case CLEAR:
+          reference.clear();
+          base.clear();
+          break;
+
+        case ISEMPTY:
+          assertEquals(reference.isEmpty(), base.isEmpty());
+          break;
+
+        case SIZE:
+          assertEquals(reference.size(), base.size());
+          break;
+
+        default:
+          throw new RuntimeException();
+      }
+    }
+  }
+
+  @Test
+  @Repeat(iterations = 20)
+  public void testAgainstReferenceOpenIntObjectHashMap() {
+    OpenIntObjectHashMap<Integer> base = new OpenIntObjectHashMap<>();
+    Map<Integer, Integer> reference = new HashMap<>();
+
+    List<Operation> ops = operationMix();
+
+    int max = randomIntBetween(1000, 20000);
+    for (int reps = 0; reps < max; reps++) {
+      // Ensure some collisions among keys.
+      int k = randomIntBetween(0, max / 4);
+      int v = randomInt();
+      switch (randomFrom(ops)) {
+        case ADD:
+          assertEquals(reference.put(k, v) == null, base.put(k, v));
+          break;
+
+        case REMOVE:
+          assertEquals(reference.remove(k) != null, base.removeKey(k));
+          break;
+
+        case INDEXOF:
+          assertEquals(reference.containsKey(k), base.containsKey(k));
+          break;
+
+        case CLEAR:
+          reference.clear();
+          base.clear();
+          break;
+
+        case ISEMPTY:
+          assertEquals(reference.isEmpty(), base.isEmpty());
+          break;
+
+        case SIZE:
+          assertEquals(reference.size(), base.size());
+          break;
+
+        default:
+          throw new RuntimeException();
+      }
+    }
+  }
+
+  @Test
+  @Repeat(iterations = 20)
+  public void testAgainstReferenceOpenIntIntHashMap() {
+    OpenIntIntHashMap base = new OpenIntIntHashMap();
+    Map<Integer, Integer> reference = new HashMap<>();
+
+    List<Operation> ops = operationMix();
+
+    int max = randomIntBetween(1000, 20000);
+    for (int reps = 0; reps < max; reps++) {
+      // Ensure some collisions among keys.
+      int k = randomIntBetween(0, max / 4);
+      int v = randomInt();
+      switch (randomFrom(ops)) {
+        case ADD:
+          Integer prevValue = reference.put(k, v);
+
+          if (prevValue == null) {
+            assertTrue(base.put(k, v));
+          } else {
+            // Overwriting: the old value must be visible before the put replaces it.
+            assertEquals(prevValue.intValue(), base.get(k));
+            assertFalse(base.put(k, v));
+          }
+          break;
+
+        case REMOVE:
+          assertEquals(reference.containsKey(k), base.containsKey(k));
+
+          Integer removed = reference.remove(k);
+          if (removed == null) {
+            assertFalse(base.removeKey(k));
+          } else {
+            assertEquals(removed.intValue(), base.get(k));
+            assertTrue(base.removeKey(k));
+          }
+          break;
+
+        case INDEXOF:
+          assertEquals(reference.containsKey(k), base.containsKey(k));
+          break;
+
+        case CLEAR:
+          reference.clear();
+          base.clear();
+          break;
+
+        case ISEMPTY:
+          assertEquals(reference.isEmpty(), base.isEmpty());
+          break;
+
+        case SIZE:
+          assertEquals(reference.size(), base.size());
+          break;
+
+        default:
+          throw new RuntimeException();
+      }
+    }
+  }
+
+  @Test
+  @Repeat(iterations = 20)
+  public void testAgainstReferenceOpenIntHashSet() {
+    AbstractIntSet base = new OpenIntHashSet();
+    Set<Integer> reference = Sets.newHashSet();
+
+    List<Operation> ops = operationMix();
+
+    int max = randomIntBetween(1000, 20000);
+    for (int reps = 0; reps < max; reps++) {
+      // Ensure some collisions among keys.
+      int k = randomIntBetween(0, max / 4);
+      switch (randomFrom(ops)) {
+        case ADD:
+          assertEquals(reference.add(k), base.add(k));
+          break;
+
+        case REMOVE:
+          assertEquals(reference.remove(k), base.remove(k));
+          break;
+
+        case INDEXOF:
+          assertEquals(reference.contains(k), base.contains(k));
+          break;
+
+        case CLEAR:
+          reference.clear();
+          base.clear();
+          break;
+
+        case ISEMPTY:
+          assertEquals(reference.isEmpty(), base.isEmpty());
+          break;
+
+        case SIZE:
+          assertEquals(reference.size(), base.size());
+          break;
+
+        default:
+          throw new RuntimeException();
+      }
+    }
+  }
+
+  @Seed("deadbeef")
+  @Test
+  @Repeat(iterations = 20)
+  public void testAgainstReferenceOpenHashSet() {
+    Set<Integer> base = new OpenHashSet<>();
+    Set<Integer> reference = Sets.newHashSet();
+
+    List<Operation> ops = operationMix();
+
+    int max = randomIntBetween(1000, 20000);
+    for (int reps = 0; reps < max; reps++) {
+      // Ensure some collisions among keys.
+      int k = randomIntBetween(0, max / 4);
+      switch (randomFrom(ops)) {
+        case ADD:
+          // Previously this only compared contains(), so neither set was ever populated and
+          // every other operation ran against two empty sets.
+          assertEquals(reference.add(k), base.add(k));
+          break;
+
+        case REMOVE:
+          assertEquals(reference.remove(k), base.remove(k));
+          break;
+
+        case INDEXOF:
+          assertEquals(reference.contains(k), base.contains(k));
+          break;
+
+        case CLEAR:
+          reference.clear();
+          base.clear();
+          break;
+
+        case ISEMPTY:
+          assertEquals(reference.isEmpty(), base.isEmpty());
+          break;
+
+        case SIZE:
+          assertEquals(reference.size(), base.size());
+          break;
+
+        default:
+          throw new RuntimeException();
+      }
+    }
+  }
+
+  /**
+   * Regression test: adds, clears and re-adds keys, then probes present and absent keys.
+   *
+   * @see "https://issues.apache.org/jira/browse/MAHOUT-1225"
+   */
+  @Test
+  public void testMahout1225() {
+    AbstractIntSet s = new OpenIntHashSet();
+    s.clear();
+    s.add(23);
+    s.add(46);
+    s.clear();
+    s.add(70);
+    s.add(93);
+    // Lookups after clear() must terminate and give correct answers.
+    assertFalse(s.contains(100));
+    assertFalse(s.contains(23));
+    assertTrue(s.contains(70));
+    assertTrue(s.contains(93));
+  }
+
+  /**
+   * Checks via reflection that clear() leaves no stale object references in the map's
+   * internal {@code table} array (every slot must be null).
+   */
+  @Test
+  public void testClearTable() throws Exception {
+    OpenObjectIntHashMap<Integer> m = new OpenObjectIntHashMap<>();
+    m.clear(); // rehash from the default capacity to the next prime after 1 (3).
+    m.put(1, 2);
+    m.clear(); // Should clear internal references.
+
+    Field tableField = m.getClass().getDeclaredField("table");
+    tableField.setAccessible(true);
+    Object[] table = (Object[]) tableField.get(m);
+
+    assertEquals(Sets.newHashSet(Arrays.asList(new Object[] {null})), Sets.newHashSet(Arrays.asList(table)));
+  }
+
+  /** Builds the weighted operation mix shared by all reference tests (fresh list per call). */
+  private static List<Operation> operationMix() {
+    List<Operation> ops = Lists.newArrayList();
+    addOp(ops, Operation.ADD, 60);
+    addOp(ops, Operation.REMOVE, 30);
+    addOp(ops, Operation.INDEXOF, 30);
+    addOp(ops, Operation.CLEAR, 5);
+    addOp(ops, Operation.ISEMPTY, 2);
+    addOp(ops, Operation.SIZE, 2);
+    return ops;
+  }
+
+  /** Add multiple repetitions of op to the list. */
+  private static void addOp(List<Operation> ops, Operation op, int reps) {
+    for (int i = 0; i < reps; i++) {
+      ops.add(op);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java b/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java
new file mode 100644
index 0000000..92fd5bb
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.ssvd;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DiagonalMatrix;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomTrinaryMatrix;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+public final class SequentialBigSvdTest extends MahoutTestCase {
+
+ @Test
+ public void testSingularValues() {
+ Matrix A = lowRankMatrix();
+
+ SequentialBigSvd s = new SequentialBigSvd(A, 8);
+ SingularValueDecomposition svd = new SingularValueDecomposition(A);
+
+ Vector reference = new DenseVector(svd.getSingularValues()).viewPart(0, 8);
+ assertEquals(reference, s.getSingularValues());
+
+ assertEquals(A, s.getU().times(new DiagonalMatrix(s.getSingularValues())).times(s.getV().transpose()));
+ }
+
+ @Test
+ public void testLeftVectors() {
+ Matrix A = lowRankMatrix();
+
+ SequentialBigSvd s = new SequentialBigSvd(A, 8);
+ SingularValueDecomposition svd = new SingularValueDecomposition(A);
+
+ // can only check first few singular vectors because once the singular values
+ // go to zero, the singular vectors are not uniquely determined
+ Matrix u1 = svd.getU().viewPart(0, 20, 0, 4).assign(Functions.ABS);
+ Matrix u2 = s.getU().viewPart(0, 20, 0, 4).assign(Functions.ABS);
+ assertEquals(0, u1.minus(u2).aggregate(Functions.PLUS, Functions.ABS), 1.0e-9);
+ }
+
+ private static void assertEquals(Matrix u1, Matrix u2) {
+ assertEquals(0, u1.minus(u2).aggregate(Functions.MAX, Functions.ABS), 1.0e-10);
+ }
+
+ private static void assertEquals(Vector u1, Vector u2) {
+ assertEquals(0, u1.minus(u2).aggregate(Functions.MAX, Functions.ABS), 1.0e-10);
+ }
+
+ @Test
+ public void testRightVectors() {
+ Matrix A = lowRankMatrix();
+
+ SequentialBigSvd s = new SequentialBigSvd(A, 6);
+ SingularValueDecomposition svd = new SingularValueDecomposition(A);
+
+ Matrix v1 = svd.getV().viewPart(0, 20, 0, 3).assign(Functions.ABS);
+ Matrix v2 = s.getV().viewPart(0, 20, 0, 3).assign(Functions.ABS);
+ assertEquals(v1, v2);
+ }
+
+ private static Matrix lowRankMatrix() {
+ Matrix u = new RandomTrinaryMatrix(1, 20, 4, false);
+ Matrix d = new DiagonalMatrix(new double[]{5, 3, 1, 0.5});
+ Matrix v = new RandomTrinaryMatrix(2, 23, 4, false);
+
+ return u.times(d).times(v.transpose());
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java b/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
new file mode 100644
index 0000000..681213b
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.jet.random.AbstractContinousDistribution;
+import org.apache.mahout.math.jet.random.Gamma;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Checks that {@link OnlineSummarizer}'s running mean and standard deviation agree with a
+ * reference computation on samples from a normal, an exponential and a gamma distribution.
+ */
+public final class OnlineSummarizerTest extends MahoutTestCase {
+
+  @Test
+  public void testStats() {
+    // NOTE: the quantile checks in check() (min, quartiles, median) are currently disabled.
+    // They were meant to compare against 5%/95% confidence limits derived by simulation
+    // (1000 runs of 10,000 samples each). Only the mean and standard deviation are asserted.
+
+    // symmetrical, well behaved
+    System.out.printf("normal%n");
+    check(normal(10000));
+
+    // asymmetrical, well behaved
+    System.out.printf("exp%n");
+    check(exp(10000));
+
+    // asymmetrical, wacko distribution where mean/median is about 200
+    System.out.printf("gamma%n");
+    check(gamma(10000, 0.1));
+  }
+
+  /**
+   * Feeds {@code samples} into an {@link OnlineSummarizer} and compares its mean and standard
+   * deviation with a reference computed here using Welford's online algorithm.
+   * Note that {@code samples} is sorted in place.
+   */
+  private static void check(double[] samples) {
+    OnlineSummarizer s = new OnlineSummarizer();
+    double mean = 0;
+    double sumSquares = 0; // running sum of squared deviations from the running mean
+    int n = 1;
+    for (double x : samples) {
+      s.add(x);
+      double old = mean;
+      mean += (x - mean) / n;
+      sumSquares += (x - old) * (x - mean);
+      n++;
+    }
+    // Population standard deviation: divide by the number of samples, not by count - 1.
+    double sd = Math.sqrt(sumSquares / samples.length);
+
+    Arrays.sort(samples);
+
+    // TODO: quartile checks are disabled; re-enable once getQuartile accuracy is settled.
+//    for (int i = 0; i < 5; i++) {
+//      int index = Math.abs(Arrays.binarySearch(samples, s.getQuartile(i)));
+//      assertEquals("quartile " + i, i * (samples.length - 1) / 4.0, index, 10);
+//    }
+//    assertEquals(s.getQuartile(2), s.getMedian(), 0);
+
+    // expected (reference) value first, actual (summarizer) value second
+    assertEquals("mean", mean, s.getMean(), 0);
+    assertEquals("sd", sd, s.getSD(), 1.0e-8);
+  }
+
+  /** Returns {@code n} standard normal samples from a generator seeded with 1. */
+  private static double[] normal(int n) {
+    double[] r = new double[n];
+    Random gen = RandomUtils.getRandom(1L);
+    for (int i = 0; i < n; i++) {
+      r[i] = gen.nextGaussian();
+    }
+    return r;
+  }
+
+  /** Returns {@code n} unit-rate exponential samples (inverse-CDF method), seeded with 1. */
+  private static double[] exp(int n) {
+    double[] r = new double[n];
+    Random gen = RandomUtils.getRandom(1L);
+    for (int i = 0; i < n; i++) {
+      r[i] = -Math.log1p(-gen.nextDouble());
+    }
+    return r;
+  }
+
+  /**
+   * Returns {@code n} samples from {@code Gamma(shape, shape)}. Seeded with 1 like the other
+   * generators so every distribution in the test is reproducible in the same way.
+   */
+  private static double[] gamma(int n, double shape) {
+    double[] r = new double[n];
+    Random gen = RandomUtils.getRandom(1L);
+    AbstractContinousDistribution gamma = new Gamma(shape, shape, gen);
+    for (int i = 0; i < n; i++) {
+      r[i] = gamma.nextDouble();
+    }
+    return r;
+  }
+}
+
+
r***@apache.org
2018-09-08 23:35:09 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/sampling/RandomSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/sampling/RandomSampler.java b/core/src/main/java/org/apache/mahout/math/jet/random/sampling/RandomSampler.java
new file mode 100644
index 0000000..6804547
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/sampling/RandomSampler.java
@@ -0,0 +1,503 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random.sampling;
+
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+/**
+ * Space and time efficiently computes a sorted <i>Simple Random Sample Without Replacement
+ * (SRSWOR)</i>, that is, a sorted set of <tt>n</tt> random numbers from an interval of <tt>N</tt> numbers;
+ * Example: Computing <tt>n=3</tt> random numbers from the interval <tt>[1,50]</tt> may yield
+ * the sorted random set <tt>(7,13,47)</tt>.
+ * Since we are talking about a set (sampling without replacement), no element will occur more than once.
+ * Each number from the <tt>N</tt> numbers has the same probability to be included in the <tt>n</tt> chosen numbers.
+ *
+ * <p><b>Problem:</b> This class solves problems including the following: <i>
+ * Suppose we have a file containing 10^12 objects.
+ * We would like to take a truly random subset of 10^6 objects and do something with it,
+ * for example, compute the sum over some instance field, or whatever.
+ * How do we choose the subset? In particular, how do we avoid multiple equal elements?
+ * How do we do this quickly and without consuming excessive memory?
+ * How do we avoid slowly jumping back and forth within the file? </i>
+ *
+ * <p><b>Sorted Simple Random Sample Without Replacement (SRSWOR):</b>
+ * What are the exact semantics of this class? What is a SRSWOR? In which sense exactly is a returned set "random"?
+ * It is random in the sense that each number from the <tt>N</tt> numbers has the
+ * same probability to be included in the <tt>n</tt> chosen numbers.
+ * For those who think in implementations rather than abstract interfaces:
+ * <i>Suppose, we have an empty list.
+ * We pick a random number between 1 and 10^12 and add it to the list only if it was not
+ * already picked before, i.e. if it is not already contained in the list.
+ * We then do the same thing again and again until we have eventually collected 10^6 distinct numbers.
+ * Now we sort the set ascending and return it.</i>
+ * <dl>
+ * <dt>It is exactly in this sense that this class returns "random" sets.
+ * <b>Note, however, that the implementation of this class uses a technique orders of magnitude
+ * better (both in time and space) than the one outlined above.</b></dt></dl>
+ *
+ * <p><b>Performance:</b> Space requirements are zero. Running time is <tt>O(n)</tt> on average,
+ * <tt>O(N)</tt> in the worst case.
+ * <h2>Performance (200Mhz Pentium Pro, JDK 1.2, NT)</h2>
+ * <center>
+ * <table border="1" summary="performance table">
+ * <tr>
+ * <td align="center" width="20%">n</td>
+ * <td align="center" width="20%">N</td>
+ * <td align="center" width="20%">Speed [seconds]</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">10<sup>3</sup></td>
+ * <td align="center" width="20%">1.2*10<sup>3</sup></td>
+ * <td align="center" width="20">0.0014</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">10<sup>3</sup></td>
+ * <td align="center" width="20%">10<sup>7</sup></td>
+ * <td align="center" width="20">0.006</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">10<sup>5</sup></td>
+ * <td align="center" width="20%">10<sup>7</sup></td>
+ * <td align="center" width="20">0.7</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">9.0*10<sup>6</sup></td>
+ * <td align="center" width="20%">10<sup>7</sup></td>
+ * <td align="center" width="20">8.5</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">9.9*10<sup>6</sup></td>
+ * <td align="center" width="20%">10<sup>7</sup></td>
+ * <td align="center" width="20">2.0 (samples more than 95%)</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">10<sup>4</sup></td>
+ * <td align="center" width="20%">10<sup>12</sup></td>
+ * <td align="center" width="20">0.07</td>
+ * </tr>
+ * <tr>
+ * <td align="center" width="20%">10<sup>7</sup></td>
+ * <td align="center" width="20%">10<sup>12</sup></td>
+ * <td align="center" width="20">60</td>
+ * </tr>
+ * </table>
+ * </center>
+ *
+ * <p><b>Scalability:</b> This random sampler is designed to be scalable. In iterator style,
+ * it is able to compute and deliver sorted random sets stepwise in units called <i>blocks</i>.
+ * Example: Computing <tt>n=9</tt> random numbers from the interval <tt>[1,50]</tt> in
+ * 3 blocks may yield the blocks <tt>(7,13,14), (27,37,42), (45,46,49)</tt>.
+ * (The maximum of a block is guaranteed to be less than the minimum of its successor block.
+ * Every block is sorted ascending. No element will ever occur twice, both within a block and among blocks.)
+ * A block can be computed and retrieved with method <tt>nextBlock</tt>.
+ * Successive calls to method <tt>nextBlock</tt> will deliver as many random numbers as required.
+ *
+ * <p>Computing and retrieving samples in blocks is useful if you need very many random
+ * numbers that cannot be stored in main memory at the same time.
+ * For example, if you want to compute 10^10 such numbers you can do this by computing
+ * them in blocks of, say, 500 elements each.
+ * You then need only space to keep one block of 500 elements (i.e. 4 KB).
+ * When you are finished processing the first 500 elements you call <tt>nextBlock</tt> to
+ * fill the next 500 elements into the block, process them, and so on.
+ * If you have the time and need, by using such blocks you can compute random sets
+ * up to <tt>n=10^19</tt> random numbers.
+ *
+ * <p>If you do not need the block feature, you can also directly call
+ * the static methods of this class without needing to construct a <tt>RandomSampler</tt> instance first.
+ *
+ * <p><b>Random number generation:</b> By default uses <tt>MersenneTwister</tt>, a very
+ * strong random number generator, much better than <tt>java.util.Random</tt>.
+ * You can also use other strong random number generators of Paul Houle's RngPack package.
+ * For example, <tt>Ranecu</tt>, <tt>Ranmar</tt> and <tt>Ranlux</tt> are strong well
+ * analyzed research grade pseudo-random number generators with known periods.
+ *
+ * <p><b>Implementation:</b> after J.S. Vitter, An Efficient Algorithm for Sequential Random Sampling,
+ * ACM Transactions on Mathematical Software, Vol 13, 1987.
+ * Paper available <A HREF="http://www.cs.duke.edu/~jsv"> here</A>.
+ */
+public final class RandomSampler {
+
+  /** Private constructor: the entry points shown here are all static methods. */
+  private RandomSampler() {
+    // intentionally empty
+  }
+
+  /**
+   * Efficiently computes a sorted random set of <tt>count</tt> elements from the interval <tt>[low,low+N-1]</tt>. Since
+   * we are talking about a random set, no element will occur more than once.
+   *
+   * <p>This variant works on the complement. It uses Vitter's method D to choose the <tt>N-n</tt> elements to
+   * <em>reject</em> and emits every element in between. This is cheaper when <tt>n</tt> is close to <tt>N</tt>.
+   *
+   * <p>Running time is <tt>O(count)</tt>, on average. Space requirements are zero.
+   *
+   * <p>Numbers are filled into the specified array starting at index <tt>fromIndex</tt> to the right. The array is
+   * returned sorted ascending in the range filled with numbers.
+   *
+   * @param n the total number of elements to choose (must be &gt;= 0 and &lt;= <tt>N</tt>).
+   * @param N the interval to choose random numbers from is <tt>[low,low+N-1]</tt>.
+   * @param count the number of elements to be filled into <tt>values</tt> by this call (must be &gt;= 0 and
+   *              &lt;=<tt>n</tt>). Normally, you will set <tt>count=n</tt>.
+   * @param low the interval to choose random numbers from is <tt>[low,low+N-1]</tt>. Hint: If
+   *            <tt>low==0</tt>, then draws random numbers from the interval <tt>[0,N-1]</tt>.
+   * @param values the array into which the random numbers are to be filled; must have a length <tt>&gt;=
+   *               count+fromIndex</tt>.
+   * @param fromIndex the first index within <tt>values</tt> to be filled with numbers (inclusive).
+   * @param randomGenerator a random number generator.
+   */
+  private static void rejectMethodD(long n, long N, int count, long low, long[] values, int fromIndex,
+                                    Random randomGenerator) {
+    /* This algorithm is applicable if a large percentage (90%..100%) of N shall be sampled.
+       In such cases it is more efficient than sampleMethodA() and sampleMethodD().
+       The idea is that it is more efficient to express
+       sample(n,N,count) in terms of reject(N-n,N,count)
+       and then invert the result.
+       For example, sampling 99% turns into sampling 1% plus inversion.
+
+       This algorithm is the same as method sampleMethodD(...) with the exception that sampled elements are rejected,
+       and not sampled elements included in the result set.
+    */
+    if (n == N) {
+      // Nothing is rejected, so the sample is the whole interval. Without this guard the inverted
+      // n below is 0 and ninv is infinite, which drives Vprime to 0. The n==1 special case then
+      // rejects element 'low' and emits low+1 .. low+count instead of low .. low+count-1.
+      for (int i = 0; i < count; i++) {
+        values[fromIndex + i] = low + i;
+      }
+      return;
+    }
+
+    n = N - n; // IMPORTANT !!! From here on, n is the number of elements to reject.
+
+    long chosen = -1 + low;
+
+    double nreal = n;
+    double ninv = 1.0 / nreal;
+    double Nreal = N;
+    double Vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+    long qu1 = -n + 1 + N;
+    double qu1real = -nreal + 1.0 + Nreal;
+
+    long S;
+    while (n > 1 && count > 0) {
+      double nmin1inv = 1.0 / (-1.0 + nreal);
+      double negSreal;
+      while (true) {
+        double X;
+        while (true) { // step D2: generate U and X
+          X = Nreal * (-Vprime + 1.0);
+          S = (long) X;
+          if (S < qu1) {
+            break;
+          }
+          Vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+        }
+        double U = randomGenerator.nextDouble();
+        negSreal = -S;
+
+        // step D3: Accept?
+        double y1 = Math.exp(Math.log(U * Nreal / qu1real) * nmin1inv);
+        Vprime = y1 * (-X / Nreal + 1.0) * qu1real / (negSreal + qu1real);
+        if (Vprime <= 1.0) {
+          break; // accept; break inner loop
+        }
+
+        // step D4: Accept?
+        double top = -1.0 + Nreal;
+        long limit;
+        double bottom;
+        if (n - 1 > S) {
+          bottom = -nreal + Nreal;
+          limit = -S + N;
+        } else {
+          bottom = -1.0 + negSreal + Nreal;
+          limit = qu1;
+        }
+        double y2 = 1.0;
+        for (long t = N - 1; t >= limit; t--) {
+          y2 *= top / bottom;
+          top--;
+          bottom--;
+        }
+        if (Nreal / (-X + Nreal) >= y1 * Math.exp(Math.log(y2) * nmin1inv)) {
+          // accept !
+          Vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * nmin1inv);
+          break; // break inner loop
+        }
+        Vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+      }
+
+      // step D5: reject the (S+1)st record! Emit the S records before it (at most count).
+      int iter = count;
+      if (S < iter) {
+        iter = (int) S;
+      }
+
+      count -= iter;
+      while (--iter >= 0) {
+        values[fromIndex++] = ++chosen;
+      }
+      chosen++; // skip the rejected record
+
+      N -= S + 1;
+      Nreal = negSreal - 1.0 + Nreal;
+      n--;
+      nreal--;
+      ninv = nmin1inv;
+      qu1 = -S + qu1;
+      qu1real = negSreal + qu1real;
+    }
+
+    if (count > 0) { // special case n==1: exactly one record left to reject
+      S = (long) (N * Vprime);
+
+      int iter = count;
+      if (S < iter) {
+        iter = (int) S;
+      }
+
+      count -= iter;
+      while (--iter >= 0) {
+        values[fromIndex++] = ++chosen;
+      }
+
+      chosen++; // skip the rejected record
+
+      // fill the rest: everything after the last rejected record is sampled
+      while (--count >= 0) {
+        values[fromIndex++] = ++chosen;
+      }
+    }
+  }
+
+  /**
+   * Efficiently computes a sorted random set of <tt>count</tt> elements from the interval <tt>[low,low+N-1]</tt>. Since
+   * we are talking about a random set, no element will occur more than once.
+   *
+   * <p>Running time is <tt>O(count)</tt>, on average. Space requirements are zero.
+   *
+   * <p>Numbers are filled into the specified array starting at index <tt>fromIndex</tt> to the right. The array is
+   * returned sorted ascending in the range filled with numbers.
+   *
+   * @param n the total number of elements to choose (must be <tt>n &gt;= 0</tt> and <tt>n &lt;= N</tt>).
+   * @param N the interval to choose random numbers from is <tt>[low,low+N-1]</tt>.
+   * @param count the number of elements to be filled into <tt>values</tt> by this call (must be &gt;= 0 and
+   *              &lt;=<tt>n</tt>). Normally, you will set <tt>count=n</tt>.
+   * @param low the interval to choose random numbers from is <tt>[low,low+N-1]</tt>. Hint: If
+   *            <tt>low==0</tt>, then draws random numbers from the interval <tt>[0,N-1]</tt>.
+   * @param values the array into which the random numbers are to be filled; must have a length <tt>&gt;=
+   *               count+fromIndex</tt>.
+   * @param fromIndex the first index within <tt>values</tt> to be filled with numbers (inclusive).
+   * @param randomGenerator a random number generator. Set this parameter to <tt>null</tt> to use
+   *                        <tt>RandomUtils.getRandom()</tt>.
+   * @throws IllegalArgumentException if <tt>count &gt; n</tt> or <tt>n &gt; N</tt> (checked only when both
+   *                                  <tt>n</tt> and <tt>count</tt> are positive).
+   */
+  public static void sample(long n, long N, int count, long low, long[] values, int fromIndex,
+                            Random randomGenerator) {
+    if (n <= 0 || count <= 0) {
+      return;
+    }
+    if (count > n) {
+      throw new IllegalArgumentException("count must not be greater than n");
+    }
+    if (n > N) {
+      throw new IllegalArgumentException("n must not be greater than N");
+    }
+    if (randomGenerator == null) {
+      randomGenerator = RandomUtils.getRandom();
+    }
+
+    if (n == N) {
+      // The whole interval is sampled, so the sorted sample starts low, low+1, ...
+      // This subsumes the old 'count == N' test (count <= n <= N). It also covers count < N,
+      // which previously fell through to rejectMethodD with nothing to reject and came out
+      // shifted by one.
+      long val = low;
+      int limit = fromIndex + count;
+      for (int i = fromIndex; i < limit; i++) {
+        values[i] = val++;
+      }
+      return;
+    }
+
+    if (n < N * 0.95) {
+      sampleMethodD(n, N, count, low, values, fromIndex, randomGenerator);
+    } else { // More than 95% of all numbers shall be sampled.
+      rejectMethodD(n, N, count, low, values, fromIndex, randomGenerator);
+    }
+  }
+
+  /**
+   * Computes a sorted random set of <tt>count</tt> elements from the interval <tt>[low,low+N-1]</tt> using
+   * Vitter's method A. Each selected element is found by scanning forward and skipping the unselected records
+   * one at a time. No element occurs more than once.
+   *
+   * <p>Running time is <tt>O(N)</tt>, on average. Space requirements are zero.
+   *
+   * <p>Numbers are written into <tt>values</tt> starting at <tt>fromIndex</tt>, ascending.
+   *
+   * @param n the total number of elements to choose (must be &gt;= 0).
+   * @param N the interval to choose random numbers from is <tt>[low,low+N-1]</tt>.
+   * @param count the number of elements to be filled into <tt>values</tt> by this call (must be &gt;= 0 and
+   *              &lt;=<tt>n</tt>). Normally, you will set <tt>count=n</tt>.
+   * @param low the lower bound of the interval <tt>[low,low+N-1]</tt>.
+   * @param values the destination array; must have a length <tt>&gt;= count+fromIndex</tt>.
+   * @param fromIndex the first index within <tt>values</tt> to be filled (inclusive).
+   * @param randomGenerator a random number generator.
+   */
+  private static void sampleMethodA(long n, long N, int count, long low, long[] values, int fromIndex,
+                                    Random randomGenerator) {
+    long last = low - 1;                 // most recently emitted value
+    double unselectedLeft = N - n;       // records still to be skipped overall
+    double recordsLeft = N;              // records not yet scanned
+    int out = fromIndex;
+    int todo = count;
+
+    for (long need = n; need >= 2 && todo > 0; need--, todo--) {
+      double threshold = randomGenerator.nextDouble();
+      long skip = 0;
+      double quot = unselectedLeft / recordsLeft;
+      // Skip records while the probability of skipping at least one more exceeds the draw.
+      while (quot > threshold) {
+        skip++;
+        unselectedLeft--;
+        recordsLeft--;
+        quot *= unselectedLeft / recordsLeft;
+      }
+      last += skip + 1;
+      values[out++] = last;
+      recordsLeft--;
+    }
+
+    if (todo > 0) {
+      // Last element (n == 1): jump uniformly over the remaining records.
+      long skip = (long) (Math.round(recordsLeft) * randomGenerator.nextDouble());
+      last += skip + 1;
+      values[out] = last;
+    }
+  }
+
+ /**
+ * Efficiently computes a sorted random set of <tt>count</tt> elements from the interval <tt>[low,low+N-1]</tt>. Since
+ * we are talking about a random set, no element will occur more than once.
+ *
+ * <p>Instead of testing candidates one by one (as {@code sampleMethodA} does), each iteration draws the length
+ * <tt>S</tt> of the gap before the next chosen element directly, with an acceptance/rejection scheme (steps D2-D4
+ * below; the step labels suggest this follows J. S. Vitter's "Algorithm D" -- NOTE(review): confirm against the
+ * paper). While <tt>13 * n &lt; N</tt> this method is used; after that the remaining work is handed to
+ * {@code sampleMethodA}, or, when only one element is left (<tt>n == 1</tt>), the final gap is drawn directly
+ * from <tt>vprime</tt>.
+ *
+ * <p>Running time is <tt>O(count)</tt>, on average. Space requirements are zero.
+ *
+ * <p>Numbers are filled into the specified array starting at index <tt>fromIndex</tt> to the right. The array is
+ * returned sorted ascending in the range filled with numbers.
+ *
+ * @param n the total number of elements to choose (must be &gt;= 0).
+ * @param N the interval to choose random numbers from is <tt>[low,low+N-1]</tt>.
+ * @param count the number of elements to be filled into <tt>values</tt> by this call (must be &gt;= 0 and
+ * &lt;=<tt>n</tt>). Normally, you will set <tt>count=n</tt>.
+ * @param low the interval to choose random numbers from is <tt>[low,low+N-1]</tt>. Hint: If
+ * <tt>low==0</tt>, then draws random numbers from the interval <tt>[0,N-1]</tt>.
+ * @param values the array into which the random numbers are to be filled; must have a length <tt>&gt;=
+ * count+fromIndex</tt>.
+ * @param fromIndex the first index within <tt>values</tt> to be filled with numbers (inclusive).
+ * @param randomGenerator a random number generator.
+ */
+ private static void sampleMethodD(long n, long N, int count, long low, long[] values, int fromIndex,
+ Random randomGenerator) {
+ // chosen is the last value emitted; starting at low-1 makes a first gap of S land on low+S.
+ long chosen = -1 + low;
+
+ double nreal = n;
+ double ninv = 1.0 / nreal;
+ double Nreal = N;
+ // vprime = U^(1/n) for a fresh uniform U; it is carried between iterations.
+ double vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+ // qu1 = N - n + 1: a gap S is only usable if S < qu1 (checked in step D2).
+ long qu1 = -n + 1 + N;
+ double qu1real = -nreal + 1.0 + Nreal;
+ long negalphainv = -13;
+ //tuning parameter, determines when to switch from method D to method A. Dependent on programming
+ // language, platform, etc.
+ // threshold tracks 13 * n (it drops by 13 whenever n drops by one).
+ long threshold = -negalphainv * n;
+
+ long S;
+ while (n > 1 && count > 0 && threshold < N) {
+ double nmin1inv = 1.0 / (-1.0 + nreal);
+ double negSreal;
+ // Repeat D2-D4 until a gap S is accepted.
+ while (true) {
+ double X;
+ while (true) { // step D2: generate U and X
+ X = Nreal * (-vprime + 1.0);
+ S = (long) X;
+ if (S < qu1) {
+ break;
+ }
+ vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+ }
+ double U = randomGenerator.nextDouble();
+ negSreal = -S;
+
+ //step D3: Accept? (cheap test; on success vprime is already set up for the next iteration)
+ double y1 = Math.exp(Math.log(U * Nreal / qu1real) * nmin1inv);
+ vprime = y1 * (-X / Nreal + 1.0) * qu1real / (negSreal + qu1real);
+ if (vprime <= 1.0) {
+ break;
+ } //break inner loop
+
+ //step D4: Accept? (exact test: y2 is a product of top/bottom ratios over t = N-1 down to limit)
+ double top = -1.0 + Nreal;
+ long limit;
+ double bottom;
+ if (n - 1 > S) {
+ bottom = -nreal + Nreal;
+ limit = -S + N;
+ } else {
+ bottom = -1.0 + negSreal + Nreal;
+ limit = qu1;
+ }
+ double y2 = 1.0;
+ for (long t = N - 1; t >= limit; t--) {
+ y2 *= top / bottom;
+ top--;
+ bottom--;
+ }
+ if (Nreal / (-X + Nreal) >= y1 * Math.exp(Math.log(y2) * nmin1inv)) {
+ // accept ! Draw the next vprime with exponent 1/(n-1), for the reduced problem.
+ vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * nmin1inv);
+ break; //break inner loop
+ }
+ // rejected: draw a new vprime for the same n and retry from D2
+ vprime = Math.exp(Math.log(randomGenerator.nextDouble()) * ninv);
+ }
+
+ //step D5: select the (S+1)st record !
+ chosen += S + 1;
+ values[fromIndex++] = chosen;
+ /*
+ NOTE(review): dead code, disabled upstream; kept for reference only.
+ // invert
+ for (int iter=0; iter<S && count > 0; iter++) {
+ values[fromIndex++] = ++chosen;
+ count--;
+ }
+ chosen++;
+ */
+ count--;
+
+ // Shrink the problem: the S skipped values and the chosen one leave the interval; one fewer to pick.
+ N -= S + 1;
+ Nreal = negSreal - 1.0 + Nreal;
+ n--;
+ nreal--;
+ ninv = nmin1inv;
+ qu1 = -S + qu1;
+ qu1real = negSreal + qu1real;
+ threshold += negalphainv;
+ } //end while
+
+
+ // Finish whatever this call still has to fill (loop left because n <= 1 or 13 * n >= N).
+ if (count > 0) {
+ if (n > 1) { //faster to use method A to finish the sampling
+ sampleMethodA(n, N, count, chosen + 1, values, fromIndex, randomGenerator);
+ } else {
+ //special case n==1: the last gap comes straight from the pending vprime
+ S = (long) (N * vprime);
+ chosen += S + 1;
+ values[fromIndex++] = chosen;
+ }
+ }
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/stat/Gamma.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/stat/Gamma.java b/core/src/main/java/org/apache/mahout/math/jet/stat/Gamma.java
new file mode 100644
index 0000000..3ab61a6
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/stat/Gamma.java
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.stat;
+
+import org.apache.mahout.math.jet.math.Constants;
+import org.apache.mahout.math.jet.math.Polynomial;
+
+/**
+ * Gamma-related special functions: the Gamma and log-Gamma functions, the Beta function, and the
+ * regularized incomplete Gamma and Beta integrals.
+ *
+ * <p>Partially deprecated until unit tests are in place. Until this time, this class/interface is
+ * unsupported.
+ */
+public final class Gamma {
+
+  /** Above this argument {@link #stirlingFormula} splits the {@code Math.pow} call to avoid overflow. */
+  private static final double MAXSTIR = 143.01608;
+
+  private Gamma() {
+  }
+
+  /**
+   * Returns the beta function of the arguments,
+   * {@code B(alpha, beta) = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta)}.
+   *
+   * <p>When both arguments are below 40 the Gamma values are combined directly. Otherwise the result
+   * is computed as {@code exp(logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta))}. Because
+   * {@link #logGamma} returns the log of the absolute value, that path never yields a negative result.
+   *
+   * @param alpha the first argument.
+   * @param beta the second argument.
+   * @return the beta function for the given values of alpha and beta; 1.0 if
+   *     {@code gamma(alpha + beta)} evaluates to exactly zero (NOTE(review): inherited fallback).
+   */
+  public static double beta(double alpha, double beta) {
+    double y;
+    if (alpha < 40 && beta < 40) {
+      y = gamma(alpha + beta);
+      if (y == 0.0) {
+        return 1.0;
+      }
+
+      // Divide before multiplying (presumably to limit intermediate overflow).
+      if (alpha > beta) {
+        y = gamma(alpha) / y;
+        y *= gamma(beta);
+      } else {
+        y = gamma(beta) / y;
+        y *= gamma(alpha);
+      }
+    } else {
+      y = Math.exp(logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta));
+    }
+
+    return y;
+  }
+
+  /**
+   * Returns the Gamma function of the argument.
+   *
+   * <p>For {@code |x| > 33} Stirling's formula is used, combined with the reflection formula for
+   * negative {@code x}. Otherwise the argument is shifted into {@code [2, 3)} with
+   * {@code Gamma(x + 1) = x * Gamma(x)} and a rational approximation is evaluated.
+   *
+   * @param x the argument.
+   * @return {@code Gamma(x)}; positive infinity for arguments beyond {@code Constants.MAXGAM}.
+   * @throws ArithmeticException if {@code x} is zero or a negative integer (a pole of Gamma).
+   */
+  public static double gamma(double x) {
+
+    double[] pCoefficient = {
+        1.60119522476751861407E-4,
+        1.19135147006586384913E-3,
+        1.04213797561761569935E-2,
+        4.76367800457137231464E-2,
+        2.07448227648435975150E-1,
+        4.94214826801497100753E-1,
+        9.99999999999999996796E-1
+    };
+    double[] qCoefficient = {
+        -2.31581873324120129819E-5,
+        5.39605580493303397842E-4,
+        -4.45641913851797240494E-3,
+        1.18139785222060435552E-2,
+        3.58236398605498653373E-2,
+        -2.34591795718243348568E-1,
+        7.14304917030273074085E-2,
+        1.00000000000000000320E0
+    };
+
+    double p;
+    double z;
+
+    double q = Math.abs(x);
+
+    if (q > 33.0) {
+      if (x < 0.0) {
+        p = Math.floor(q);
+        if (p == q) {
+          throw new ArithmeticException("gamma: overflow");
+        }
+        // On (-(k + 1), -k) with k = floor(-x), Gamma(x) has the sign (-1)^(k + 1): negative when k is
+        // even, positive when k is odd. Every double >= 2^52 is an integer and was rejected above, so
+        // this remainder is exact.
+        boolean negative = p % 2.0 == 0.0;
+        z = q - p;
+        if (z > 0.5) {
+          p += 1.0;
+          z = q - p;
+        }
+        // Reflection formula: |Gamma(x)| = pi / (|q * sin(pi * z)| * Gamma(q)) with q = -x.
+        z = q * Math.sin(Math.PI * z);
+        if (z == 0.0) {
+          throw new ArithmeticException("gamma: overflow");
+        }
+        z = Math.abs(z);
+        z = Math.PI / (z * stirlingFormula(q));
+
+        return negative ? -z : z;
+      }
+      return stirlingFormula(x);
+    }
+
+    // Shift x into [2, 3), accumulating the product of the shifted-out factors in z.
+    z = 1.0;
+    while (x >= 3.0) {
+      x -= 1.0;
+      z *= x;
+    }
+
+    while (x < 0.0) {
+      if (x > -1.0e-9) {
+        // Close to the pole at 0: 1 / Gamma(x) ~ x * (1 + EulerGamma * x).
+        return z / ((1.0 + 0.5772156649015329 * x) * x);
+      }
+      z /= x;
+      x += 1.0;
+    }
+
+    while (x < 2.0) {
+      if (x == 0.0) {
+        throw new ArithmeticException("gamma: singular");
+      }
+      if (x < 1.0e-9) {
+        return z / ((1.0 + 0.5772156649015329 * x) * x);
+      }
+      z /= x;
+      x += 1.0;
+    }
+
+    if ((x == 2.0) || (x == 3.0)) {
+      return z;
+    }
+
+    // Rational approximation of Gamma on [2, 3), evaluated at x - 2 in [0, 1).
+    x -= 2.0;
+    p = Polynomial.polevl(x, pCoefficient, 6);
+    q = Polynomial.polevl(x, qCoefficient, 7);
+    return z * p / q;
+
+  }
+
+  /**
+   * Returns the regularized incomplete beta function {@code I_xx(alpha, beta)}, i.e. the beta
+   * density integrated from zero to {@code xx}; formerly named <tt>ibeta</tt>.
+   *
+   * See http://en.wikipedia.org/wiki/Incomplete_beta_function#Incomplete_beta_function
+   *
+   * @param alpha the alpha parameter of the beta distribution; must be positive.
+   * @param beta the beta parameter of the beta distribution; must be positive.
+   * @param xx the integration end point; values {@code <= 0} yield 0 and values {@code >= 1} yield 1.
+   * @return the regularized incomplete beta function.
+   * @throws ArithmeticException if {@code alpha <= 0} or {@code beta <= 0}.
+   */
+  public static double incompleteBeta(double alpha, double beta, double xx) {
+
+    if (alpha <= 0.0) {
+      throw new ArithmeticException("incompleteBeta: Domain error! alpha must be > 0, but was " + alpha);
+    }
+
+    if (beta <= 0.0) {
+      throw new ArithmeticException("incompleteBeta: Domain error! beta must be > 0, but was " + beta);
+    }
+
+    if (xx <= 0.0) {
+      return 0.0;
+    }
+
+    if (xx >= 1.0) {
+      return 1.0;
+    }
+
+    double t;
+    if ((beta * xx) <= 1.0 && xx <= 0.95) {
+      t = powerSeries(alpha, beta, xx);
+      return t;
+    }
+
+    double w = 1.0 - xx;
+
+    // Reverse a and b if x is greater than the mean; the result is then complemented at the end.
+    double xc;
+    double x;
+    double b;
+    double a;
+    boolean swapped = false;
+    if (xx > (alpha / (alpha + beta))) {
+      swapped = true;
+      a = beta;
+      b = alpha;
+      xc = xx;
+      x = w;
+    } else {
+      a = alpha;
+      b = beta;
+      xc = w;
+      x = xx;
+    }
+
+    if (swapped && (b * x) <= 1.0 && x <= 0.95) {
+      t = powerSeries(a, b, x);
+      t = t <= Constants.MACHEP ? 1.0 - Constants.MACHEP : 1.0 - t;
+      return t;
+    }
+
+    // Choose the continued-fraction expansion with better convergence.
+    double y = x * (a + b - 2.0) - (a - 1.0);
+    w = y < 0.0 ? incompleteBetaFraction1(a, b, x) : incompleteBetaFraction2(a, b, x) / xc;
+
+    // Multiply w by the factor x^a * xc^b * Gamma(a+b) / (a * Gamma(a) * Gamma(b)), where xc = 1 - x.
+    y = a * Math.log(x);
+    t = b * Math.log(xc);
+    if ((a + b) < Constants.MAXGAM && Math.abs(y) < Constants.MAXLOG && Math.abs(t) < Constants.MAXLOG) {
+      t = Math.pow(xc, b);
+      t *= Math.pow(x, a);
+      t /= a;
+      t *= w;
+      t *= gamma(a + b) / (gamma(a) * gamma(b));
+      if (swapped) {
+        t = t <= Constants.MACHEP ? 1.0 - Constants.MACHEP : 1.0 - t;
+      }
+      return t;
+    }
+    // Resort to logarithms.
+    y += t + logGamma(a + b) - logGamma(a) - logGamma(b);
+    y += Math.log(w / a);
+    t = y < Constants.MINLOG ? 0.0 : Math.exp(y);
+
+    if (swapped) {
+      t = t <= Constants.MACHEP ? 1.0 - Constants.MACHEP : 1.0 - t;
+    }
+    return t;
+  }
+
+  /**
+   * Continued fraction expansion #1 for the incomplete beta integral; formerly named <tt>incbcf</tt>.
+   * Iterates until the relative change drops below {@code 3 * MACHEP} or 300 iterations have run.
+   */
+  static double incompleteBetaFraction1(double a, double b, double x) {
+
+    double k1 = a;
+    double k2 = a + b;
+    double k3 = a;
+    double k4 = a + 1.0;
+    double k5 = 1.0;
+    double k6 = b - 1.0;
+    double k7 = k4;
+    double k8 = a + 2.0;
+
+    double pkm2 = 0.0;
+    double qkm2 = 1.0;
+    double pkm1 = 1.0;
+    double qkm1 = 1.0;
+    double ans = 1.0;
+    double r = 1.0;
+    int n = 0;
+    double thresh = 3.0 * Constants.MACHEP;
+    do {
+      double xk = -(x * k1 * k2) / (k3 * k4);
+      double pk = pkm1 + pkm2 * xk;
+      double qk = qkm1 + qkm2 * xk;
+      pkm2 = pkm1;
+      pkm1 = pk;
+      qkm2 = qkm1;
+      qkm1 = qk;
+
+      xk = (x * k5 * k6) / (k7 * k8);
+      pk = pkm1 + pkm2 * xk;
+      qk = qkm1 + qkm2 * xk;
+      pkm2 = pkm1;
+      pkm1 = pk;
+      qkm2 = qkm1;
+      qkm1 = qk;
+
+      if (qk != 0) {
+        r = pk / qk;
+      }
+      double t;
+      if (r != 0) {
+        t = Math.abs((ans - r) / r);
+        ans = r;
+      } else {
+        t = 1.0;
+      }
+
+      if (t < thresh) {
+        return ans;
+      }
+
+      k1 += 1.0;
+      k2 += 1.0;
+      k3 += 2.0;
+      k4 += 2.0;
+      k5 += 1.0;
+      k6 -= 1.0;
+      k7 += 2.0;
+      k8 += 2.0;
+
+      // Rescale the recurrence to keep numerators and denominators in range.
+      if ((Math.abs(qk) + Math.abs(pk)) > Constants.BIG) {
+        pkm2 *= Constants.BIG_INVERSE;
+        pkm1 *= Constants.BIG_INVERSE;
+        qkm2 *= Constants.BIG_INVERSE;
+        qkm1 *= Constants.BIG_INVERSE;
+      }
+      if ((Math.abs(qk) < Constants.BIG_INVERSE) || (Math.abs(pk) < Constants.BIG_INVERSE)) {
+        pkm2 *= Constants.BIG;
+        pkm1 *= Constants.BIG;
+        qkm2 *= Constants.BIG;
+        qkm1 *= Constants.BIG;
+      }
+    } while (++n < 300);
+
+    return ans;
+  }
+
+  /**
+   * Continued fraction expansion #2 for the incomplete beta integral; formerly named <tt>incbd</tt>.
+   * Works in terms of {@code z = x / (1 - x)}; same stopping rule as {@link #incompleteBetaFraction1}.
+   */
+  static double incompleteBetaFraction2(double a, double b, double x) {
+
+    double k1 = a;
+    double k2 = b - 1.0;
+    double k3 = a;
+    double k4 = a + 1.0;
+    double k5 = 1.0;
+    double k6 = a + b;
+    double k7 = a + 1.0;
+    double k8 = a + 2.0;
+
+    double pkm2 = 0.0;
+    double qkm2 = 1.0;
+    double pkm1 = 1.0;
+    double qkm1 = 1.0;
+    double z = x / (1.0 - x);
+    double ans = 1.0;
+    double r = 1.0;
+    int n = 0;
+    double thresh = 3.0 * Constants.MACHEP;
+    do {
+      double xk = -(z * k1 * k2) / (k3 * k4);
+      double pk = pkm1 + pkm2 * xk;
+      double qk = qkm1 + qkm2 * xk;
+      pkm2 = pkm1;
+      pkm1 = pk;
+      qkm2 = qkm1;
+      qkm1 = qk;
+
+      xk = (z * k5 * k6) / (k7 * k8);
+      pk = pkm1 + pkm2 * xk;
+      qk = qkm1 + qkm2 * xk;
+      pkm2 = pkm1;
+      pkm1 = pk;
+      qkm2 = qkm1;
+      qkm1 = qk;
+
+      if (qk != 0) {
+        r = pk / qk;
+      }
+      double t;
+      if (r != 0) {
+        t = Math.abs((ans - r) / r);
+        ans = r;
+      } else {
+        t = 1.0;
+      }
+
+      if (t < thresh) {
+        return ans;
+      }
+
+      k1 += 1.0;
+      k2 -= 1.0;
+      k3 += 2.0;
+      k4 += 2.0;
+      k5 += 1.0;
+      k6 += 1.0;
+      k7 += 2.0;
+      k8 += 2.0;
+
+      // Rescale the recurrence to keep numerators and denominators in range.
+      if ((Math.abs(qk) + Math.abs(pk)) > Constants.BIG) {
+        pkm2 *= Constants.BIG_INVERSE;
+        pkm1 *= Constants.BIG_INVERSE;
+        qkm2 *= Constants.BIG_INVERSE;
+        qkm1 *= Constants.BIG_INVERSE;
+      }
+      if ((Math.abs(qk) < Constants.BIG_INVERSE) || (Math.abs(pk) < Constants.BIG_INVERSE)) {
+        pkm2 *= Constants.BIG;
+        pkm1 *= Constants.BIG;
+        qkm2 *= Constants.BIG;
+        qkm1 *= Constants.BIG;
+      }
+    } while (++n < 300);
+
+    return ans;
+  }
+
+  /**
+   * Returns the regularized lower incomplete Gamma function
+   * {@code P(alpha, x) = (1 / Gamma(alpha)) * integral_0^x t^(alpha-1) e^(-t) dt}; formerly named
+   * <tt>igamma</tt>.
+   *
+   * <p>Uses a power series for {@code x <= max(1, alpha)}; otherwise returns
+   * {@code 1 - incompleteGammaComplement(alpha, x)}.
+   *
+   * @param alpha the shape parameter of the gamma distribution.
+   * @param x the integration end point.
+   * @return {@code P(alpha, x)}; 0 if {@code x <= 0} or {@code alpha <= 0}.
+   */
+  public static double incompleteGamma(double alpha, double x) {
+    if (x <= 0 || alpha <= 0) {
+      return 0.0;
+    }
+
+    if (x > 1.0 && x > alpha) {
+      return 1.0 - incompleteGammaComplement(alpha, x);
+    }
+
+    // Compute x^alpha * exp(-x) / Gamma(alpha), in log space first to detect underflow.
+    double ax = alpha * Math.log(x) - x - logGamma(alpha);
+    if (ax < -Constants.MAXLOG) {
+      return 0.0;
+    }
+
+    ax = Math.exp(ax);
+
+    // Power series, summed until the relative contribution drops below MACHEP.
+    double r = alpha;
+    double c = 1.0;
+    double ans = 1.0;
+
+    do {
+      r += 1.0;
+      c *= x / r;
+      ans += c;
+    } while (c / ans > Constants.MACHEP);
+
+    return ans * ax / alpha;
+
+  }
+
+  /**
+   * Returns the regularized upper (complemented) incomplete Gamma function
+   * {@code Q(alpha, x) = 1 - P(alpha, x)}; formerly named <tt>igamc</tt>.
+   *
+   * <p>Uses a continued fraction for {@code x >= max(1, alpha)}; otherwise returns
+   * {@code 1 - incompleteGamma(alpha, x)}.
+   *
+   * @param alpha the shape parameter of the gamma distribution.
+   * @param x the integration start point.
+   * @return {@code Q(alpha, x)}; 1 if {@code x <= 0} or {@code alpha <= 0}.
+   */
+  public static double incompleteGammaComplement(double alpha, double x) {
+
+    if (x <= 0 || alpha <= 0) {
+      return 1.0;
+    }
+
+    if (x < 1.0 || x < alpha) {
+      return 1.0 - incompleteGamma(alpha, x);
+    }
+
+    double ax = alpha * Math.log(x) - x - logGamma(alpha);
+    if (ax < -Constants.MAXLOG) {
+      return 0.0;
+    }
+
+    ax = Math.exp(ax);
+
+    // Continued fraction, iterated until the relative change drops below MACHEP.
+    double y = 1.0 - alpha;
+    double z = x + y + 1.0;
+    double c = 0.0;
+    double pkm2 = 1.0;
+    double qkm2 = x;
+    double pkm1 = x + 1.0;
+    double qkm1 = z * x;
+    double ans = pkm1 / qkm1;
+
+    double t;
+    do {
+      c += 1.0;
+      y += 1.0;
+      z += 2.0;
+      double yc = y * c;
+      double pk = pkm1 * z - pkm2 * yc;
+      double qk = qkm1 * z - qkm2 * yc;
+      if (qk != 0) {
+        double r = pk / qk;
+        t = Math.abs((ans - r) / r);
+        ans = r;
+      } else {
+        t = 1.0;
+      }
+
+      pkm2 = pkm1;
+      pkm1 = pk;
+      qkm2 = qkm1;
+      qkm1 = qk;
+      if (Math.abs(pk) > Constants.BIG) {
+        pkm2 *= Constants.BIG_INVERSE;
+        pkm1 *= Constants.BIG_INVERSE;
+        qkm2 *= Constants.BIG_INVERSE;
+        qkm1 *= Constants.BIG_INVERSE;
+      }
+    } while (t > Constants.MACHEP);
+
+    return ans * ax;
+  }
+
+  /**
+   * Returns the natural logarithm of the absolute value of the gamma function, {@code ln|Gamma(x)|};
+   * formerly named <tt>lgamma</tt>.
+   *
+   * @param x the argument.
+   * @return {@code ln|Gamma(x)|}.
+   * @throws ArithmeticException if {@code x} is zero or a negative integer, or if
+   *     {@code x > 2.556348e305}.
+   */
+  public static double logGamma(double x) {
+    double p;
+    double q;
+    double z;
+
+    double[] aCoefficient = {
+        8.11614167470508450300E-4,
+        -5.95061904284301438324E-4,
+        7.93650340457716943945E-4,
+        -2.77777777730099687205E-3,
+        8.33333333333331927722E-2
+    };
+    double[] bCoefficient = {
+        -1.37825152569120859100E3,
+        -3.88016315134637840924E4,
+        -3.31612992738871184744E5,
+        -1.16237097492762307383E6,
+        -1.72173700820839662146E6,
+        -8.53555664245765465627E5
+    };
+    double[] cCoefficient = {
+        /* 1.00000000000000000000E0, (implicit leading coefficient for p1evl) */
+        -3.51815701436523470549E2,
+        -1.70642106651881159223E4,
+        -2.20528590553854454839E5,
+        -1.13933444367982507207E6,
+        -2.53252307177582951285E6,
+        -2.01889141433532773231E6
+    };
+
+    if (x < -34.0) {
+      // Reflection formula in log space, using q = -x.
+      q = -x;
+      double w = logGamma(q);
+      p = Math.floor(q);
+      if (p == q) {
+        throw new ArithmeticException("lgam: Overflow");
+      }
+      z = q - p;
+      if (z > 0.5) {
+        p += 1.0;
+        z = p - q;
+      }
+      z = q * Math.sin(Math.PI * z);
+      if (z == 0.0) {
+        throw new ArithmeticException("lgamma: Overflow");
+      }
+      z = Constants.LOGPI - Math.log(z) - w;
+      return z;
+    }
+
+    if (x < 13.0) {
+      // Shift x into [2, 3), accumulate the factor in z, then add a rational correction.
+      z = 1.0;
+      while (x >= 3.0) {
+        x -= 1.0;
+        z *= x;
+      }
+      while (x < 2.0) {
+        if (x == 0.0) {
+          throw new ArithmeticException("lgamma: Overflow");
+        }
+        z /= x;
+        x += 1.0;
+      }
+      if (z < 0.0) {
+        z = -z;
+      }
+      if (x == 2.0) {
+        return Math.log(z);
+      }
+      x -= 2.0;
+      p = x * Polynomial.polevl(x, bCoefficient, 5) / Polynomial.p1evl(x, cCoefficient, 6);
+      return Math.log(z) + p;
+    }
+
+    if (x > 2.556348e305) {
+      throw new ArithmeticException("lgamma: Overflow");
+    }
+
+    // Stirling series; 0.9189... is log(sqrt(2 * pi)).
+    q = (x - 0.5) * Math.log(x) - x + 0.91893853320467274178;
+    if (x > 1.0e8) {
+      return q;
+    }
+
+    p = 1.0 / (x * x);
+    if (x >= 1000.0) {
+      q += ((7.9365079365079365079365e-4 * p
+          - 2.7777777777777777777778e-3) * p
+          + 0.0833333333333333333333) / x;
+    } else {
+      q += Polynomial.polevl(p, aCoefficient, 4) / x;
+    }
+    return q;
+  }
+
+  /**
+   * Power series for the incomplete beta integral; formerly named <tt>pseries</tt>. Use when
+   * {@code b * x} is small and {@code x} is not too close to 1.
+   */
+  private static double powerSeries(double a, double b, double x) {
+
+    double ai = 1.0 / a;
+    double u = (1.0 - b) * x;
+    double v = u / (a + 1.0);
+    double t1 = v;
+    double t = u;
+    double n = 2.0;
+    double s = 0.0;
+    double z = Constants.MACHEP * ai;
+    while (Math.abs(v) > z) {
+      u = (n - b) * x / n;
+      t *= u;
+      v = t / (a + n);
+      s += v;
+      n += 1.0;
+    }
+    s += t1;
+    s += ai;
+
+    u = a * Math.log(x);
+    if ((a + b) < Constants.MAXGAM && Math.abs(u) < Constants.MAXLOG) {
+      t = gamma(a + b) / (gamma(a) * gamma(b));
+      s *= t * Math.pow(x, a);
+    } else {
+      t = logGamma(a + b) - logGamma(a) - logGamma(b) + u + Math.log(s);
+      s = t < Constants.MINLOG ? 0.0 : Math.exp(t);
+    }
+    return s;
+  }
+
+  /**
+   * Returns the Gamma function computed by Stirling's formula; formerly named <tt>stirf</tt>. The
+   * polynomial STIR is valid for 33 <= x <= 172. Arguments above {@code Constants.MAXGAM} yield
+   * positive infinity.
+   */
+  static double stirlingFormula(double x) {
+    if (x > Constants.MAXGAM) {
+      // The result overflows a double here anyway. Returning early also avoids Infinity / Infinity = NaN,
+      // which the MAXSTIR branch below produces once both Math.exp(x) and Math.pow(...) overflow.
+      return Double.POSITIVE_INFINITY;
+    }
+
+    double[] coefficients = {
+        7.87311395793093628397E-4,
+        -2.29549961613378126380E-4,
+        -2.68132617805781232825E-3,
+        3.47222221605458667310E-3,
+        8.33333333333482257126E-2
+    };
+
+    double w = 1.0 / x;
+    double y = Math.exp(x);
+
+    w = 1.0 + w * Polynomial.polevl(w, coefficients, 4);
+
+    if (x > MAXSTIR) {
+      // Avoid overflow in Math.pow(): x^(x - 0.5) = v * v with v = x^(x/2 - 1/4).
+      double v = Math.pow(x, 0.5 * x - 0.25);
+      y = v * (v / y);
+    } else {
+      y = Math.pow(x, x - 0.5) / y;
+    }
+    y = Constants.SQTPI * y * w;
+    return y;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/stat/Probability.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/stat/Probability.java b/core/src/main/java/org/apache/mahout/math/jet/stat/Probability.java
new file mode 100644
index 0000000..bcd1a86
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/stat/Probability.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.stat;
+
+import org.apache.mahout.math.jet.random.Normal;
+
+/** Partially deprecated until unit tests are in place. Until this time, this class/interface is unsupported. */
+public final class Probability {
+
+  /** Standard normal distribution; its density is used by {@link #normal(double)}. */
+  private static final Normal UNIT_NORMAL = new Normal(0, 1, null);
+
+  private Probability() {
+  }
+
+  /**
+   * Returns the area from zero to {@code x} under the beta density function,
+   * {@code P(x) = Gamma(a+b) / (Gamma(a) * Gamma(b)) * integral_0^x t^(a-1) * (1-t)^(b-1) dt}.
+   *
+   * <p>This function is identical to the incomplete beta integral function
+   * {@code Gamma.incompleteBeta(a, b, x)}. The complemented function is
+   * {@code 1 - P(1-x) = Gamma.incompleteBeta(b, a, x)}.
+   *
+   * @param a the alpha parameter of the beta distribution.
+   * @param b the beta parameter of the beta distribution.
+   * @param x the integration end point.
+   */
+  public static double beta(double a, double b, double x) {
+    return Gamma.incompleteBeta(a, b, x);
+  }
+
+  /**
+   * Returns the integral from zero to {@code x} of the gamma probability density function, i.e. the
+   * gamma cumulative distribution function
+   * {@code y = beta^alpha / Gamma(alpha) * integral_0^x t^(alpha-1) * exp(-beta * t) dt}.
+   *
+   * <p>The incomplete gamma integral is used, according to the relation
+   * {@code y = Gamma.incompleteGamma(alpha, beta * x)}.
+   *
+   * See http://en.wikipedia.org/wiki/Gamma_distribution#Probability_density_function
+   *
+   * @param alpha the shape parameter of the gamma distribution.
+   * @param beta the rate parameter of the gamma distribution.
+   * @param x integration end point; negative values yield 0.
+   */
+  public static double gamma(double alpha, double beta, double x) {
+    if (x < 0.0) {
+      return 0.0;
+    }
+    return Gamma.incompleteGamma(alpha, beta * x);
+  }
+
+  /**
+   * Returns the sum of the terms {@code 0} through {@code k} of the Negative Binomial Distribution,
+   * {@code sum_{j=0}^{k} C(n+j-1, j) * p^n * (1-p)^j}.
+   *
+   * <p>In a sequence of Bernoulli trials, this is the probability that {@code k} or fewer failures
+   * precede the {@code n}-th success. The terms are not computed individually; instead the incomplete
+   * beta integral is employed, according to the formula
+   * {@code negativeBinomial(k, n, p) = Gamma.incompleteBeta(n, k + 1, p)}.
+   *
+   * @param k end term; negative values yield 0.
+   * @param n the number of successes; must be positive.
+   * @param p the probability of success; must be in {@code [0.0, 1.0]}.
+   * @throws IllegalArgumentException if {@code p} is outside {@code [0.0, 1.0]}.
+   * @throws ArithmeticException if {@code k >= 0} and {@code n <= 0} (raised by
+   *     {@code Gamma.incompleteBeta}).
+   */
+  public static double negativeBinomial(int k, int n, double p) {
+    if (p < 0.0 || p > 1.0) {
+      throw new IllegalArgumentException("p must be in [0, 1], but was " + p);
+    }
+    if (k < 0) {
+      return 0.0;
+    }
+
+    return Gamma.incompleteBeta(n, k + 1, p);
+  }
+
+  /**
+   * Returns the area under the standard Normal (Gaussian) probability density function, integrated
+   * from minus infinity to {@code a} (mean zero, variance one):
+   * {@code normal(a) = 1/sqrt(2*pi) * integral_{-inf}^{a} exp(-t^2 / 2) dt = (1 + erf(a / sqrt(2))) / 2}.
+   *
+   * <p>Computed with the polynomial approximation 26.2.17 from Abramowitz and Stegun (see
+   * http://www.math.sfu.ca/~cbm/aands/page_932.htm and
+   * http://en.wikipedia.org/wiki/Normal_distribution#Numerical_approximations_of_the_normal_cdf).
+   * Negative arguments use the symmetry {@code normal(a) = 1 - normal(-a)}.
+   *
+   * @param a the integration limit.
+   * @return the standard normal cumulative distribution function evaluated at {@code a}.
+   */
+  public static double normal(double a) {
+    if (a < 0) {
+      return 1 - normal(-a);
+    }
+    double b0 = 0.2316419;
+    double b1 = 0.319381530;
+    double b2 = -0.356563782;
+    double b3 = 1.781477937;
+    double b4 = -1.821255978;
+    double b5 = 1.330274429;
+    double t = 1 / (1 + b0 * a);
+    // Horner form of b1*t + b2*t^2 + b3*t^3 + b4*t^4 + b5*t^5.
+    return 1 - UNIT_NORMAL.pdf(a) * t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
+  }
+
+  /**
+   * Returns the area under the Normal (Gaussian) probability density function with the given mean and
+   * variance, integrated from minus infinity to {@code x}:
+   * {@code 1/sqrt(2*pi*v) * integral_{-inf}^{x} exp(-(t - mean)^2 / (2*v)) dt}, where {@code v = variance}.
+   * The argument is standardized and passed to {@link #normal(double)}.
+   *
+   * @param mean the mean of the normal distribution.
+   * @param variance the variance of the normal distribution (expected to be positive).
+   * @param x the integration limit.
+   */
+  public static double normal(double mean, double variance, double x) {
+    return normal((x - mean) / Math.sqrt(variance));
+  }
+
+  /**
+   * Returns the sum of the terms {@code 0} through {@code k} of the Poisson distribution,
+   * {@code sum_{j=0}^{k} exp(-m) * m^j / j!}, i.e. {@code P(X <= k)} for a Poisson variable with mean
+   * {@code m}.
+   *
+   * <p>The terms are not summed directly; instead the incomplete gamma integral is employed, according
+   * to the relation {@code poisson(k, m) = Gamma.incompleteGammaComplement(k + 1, m)}.
+   *
+   * @param k the last term included; negative values yield 0.
+   * @param mean the mean of the poisson distribution; must not be negative.
+   * @throws IllegalArgumentException if {@code mean} is negative.
+   */
+  public static double poisson(int k, double mean) {
+    if (mean < 0) {
+      throw new IllegalArgumentException("mean must not be negative, but was " + mean);
+    }
+    if (k < 0) {
+      return 0.0;
+    }
+    return Gamma.incompleteGammaComplement(k + 1, mean);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/stat/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/stat/package-info.java b/core/src/main/java/org/apache/mahout/math/jet/stat/package-info.java
new file mode 100644
index 0000000..1d4d7bd
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/stat/package-info.java
@@ -0,0 +1,5 @@
+/**
+ * Tools for basic and advanced statistics: Estimators, Gamma functions, Beta functions, Probabilities,
+ * Special integrals, etc.
+ */
+package org.apache.mahout.math.jet.stat;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/list/AbstractList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/list/AbstractList.java b/core/src/main/java/org/apache/mahout/math/list/AbstractList.java
new file mode 100644
index 0000000..c672f40
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/list/AbstractList.java
@@ -0,0 +1,247 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ /*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.list;
+
+import org.apache.mahout.math.PersistentObject;
+
+/**
+ * Abstract base class for resizable lists holding objects or primitive data types such as
+ * {@code int}, {@code float}, etc.
+ * First see the <a href="package-summary.html">package summary</a> and javadoc
+ * <a href="package-tree.html">tree view</a> to get the broad picture.
+ * <p>
+ * Index ranges in this class are <em>inclusive</em> at both ends. The range with {@code to == from - 1} denotes the
+ * empty range and is accepted by {@link #checkRangeFromTo(int, int, int)}.
+ * <p>
+ * <b>Note that this implementation is not synchronized.</b>
+ *
+ * @author ***@cern.ch
+ * @version 1.0, 09/24/99
+ * @see java.util.ArrayList
+ * @see java.util.Vector
+ * @see java.util.Arrays
+ */
+public abstract class AbstractList extends PersistentObject {
+
+  /**
+   * Returns the number of elements in the receiver.
+   *
+   * @return the number of elements in the receiver.
+   */
+  public abstract int size();
+
+  /**
+   * Returns whether the receiver contains no elements.
+   *
+   * @return {@code true} if and only if {@code size() == 0}.
+   */
+  public boolean isEmpty() {
+    return size() == 0;
+  }
+
+  /**
+   * Inserts {@code length} dummy elements before the specified position into the receiver. Shifts the element
+   * currently at that position (if any) and any subsequent elements to the right. <b>This method must set the new size
+   * to be {@code size() + length}</b>.
+   *
+   * @param index index before which to insert dummy elements (must be in [0,size]).
+   * @param length number of dummy elements to be inserted.
+   * @throws IndexOutOfBoundsException if {@code index < 0 || index > size()}.
+   */
+  protected abstract void beforeInsertDummies(int index, int length);
+
+  /**
+   * Checks if the given index is a valid element index for a list of the given size.
+   *
+   * @param index the index to check.
+   * @param theSize the size of the list the index refers to.
+   * @throws IndexOutOfBoundsException if {@code index < 0 || index >= theSize}.
+   */
+  protected static void checkRange(int index, int theSize) {
+    if (index < 0 || index >= theSize) {
+      throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + theSize);
+    }
+  }
+
+  /**
+   * Checks if the inclusive range {@code [from, to]} lies within the bounds of a list of the given size. The empty
+   * range ({@code to == from - 1}) is accepted without further checks.
+   *
+   * @param from the first index of the range (inclusive).
+   * @param to the last index of the range (inclusive).
+   * @param theSize the size of the list the range refers to.
+   * @throws IndexOutOfBoundsException if {@code (from < 0 || from > to || to >= theSize) && to != from - 1}.
+   */
+  protected static void checkRangeFromTo(int from, int to, int theSize) {
+    if (to == from - 1) {
+      return; // empty range
+    }
+    if (from < 0 || from > to || to >= theSize) {
+      throw new IndexOutOfBoundsException("from: " + from + ", to: " + to + ", size=" + theSize);
+    }
+  }
+
+  /**
+   * Removes all elements from the receiver. The receiver will be empty after this call returns; implementations are
+   * expected to keep the current capacity. This default implementation delegates to
+   * {@code removeFromTo(0, size() - 1)}.
+   */
+  public void clear() {
+    removeFromTo(0, size() - 1);
+  }
+
+  /**
+   * Sorts the receiver into ascending order. This sort is guaranteed to be <i>stable</i>: equal elements will not be
+   * reordered as a result of the sort.
+   * <p>
+   * The sorting algorithm is a modified mergesort (in which the merge is omitted if the highest element in the low
+   * sublist is less than the lowest element in the high sublist). This algorithm offers guaranteed n*log(n)
+   * performance, and can approach linear performance on nearly sorted lists.
+   * <p>
+   * <b>You should never call this method unless you are sure that this particular sorting algorithm is the right one
+   * for your data set.</b> It is generally better to call {@code sort()} or {@code sortFromTo(...)} instead, because
+   * those methods automatically choose the best sorting algorithm.
+   */
+  public final void mergeSort() {
+    mergeSortFromTo(0, size() - 1);
+  }
+
+  /**
+   * Sorts the specified range of the receiver into ascending order. This sort is guaranteed to be <i>stable</i>:
+   * equal elements will not be reordered as a result of the sort.
+   * <p>
+   * The sorting algorithm is a modified mergesort (in which the merge is omitted if the highest element in the low
+   * sublist is less than the lowest element in the high sublist). This algorithm offers guaranteed n*log(n)
+   * performance, and can approach linear performance on nearly sorted lists.
+   * <p>
+   * <b>You should never call this method unless you are sure that this particular sorting algorithm is the right one
+   * for your data set.</b> It is generally better to call {@code sort()} or {@code sortFromTo(...)} instead, because
+   * those methods automatically choose the best sorting algorithm.
+   *
+   * @param from the index of the first element (inclusive) to be sorted.
+   * @param to the index of the last element (inclusive) to be sorted.
+   * @throws IndexOutOfBoundsException if {@code (from < 0 || from > to || to >= size()) && to != from - 1}.
+   */
+  public abstract void mergeSortFromTo(int from, int to);
+
+  /**
+   * Sorts the receiver into ascending order. The sorting algorithm is a tuned quicksort, adapted from Jon L. Bentley
+   * and M. Douglas McIlroy's "Engineering a Sort Function", Software-Practice and Experience, Vol. 23(11) P. 1249-1265
+   * (November 1993). This algorithm offers n*log(n) performance on many data sets that cause other quicksorts to
+   * degrade to quadratic performance.
+   * <p>
+   * <b>You should never call this method unless you are sure that this particular sorting algorithm is the right one
+   * for your data set.</b> It is generally better to call {@code sort()} or {@code sortFromTo(...)} instead, because
+   * those methods automatically choose the best sorting algorithm.
+   */
+  public final void quickSort() {
+    quickSortFromTo(0, size() - 1);
+  }
+
+  /**
+   * Sorts the specified range of the receiver into ascending order. The sorting algorithm is a tuned quicksort,
+   * adapted from Jon L. Bentley and M. Douglas McIlroy's "Engineering a Sort Function", Software-Practice and
+   * Experience, Vol. 23(11) P. 1249-1265 (November 1993). This algorithm offers n*log(n) performance on many data sets
+   * that cause other quicksorts to degrade to quadratic performance.
+   * <p>
+   * <b>You should never call this method unless you are sure that this particular sorting algorithm is the right one
+   * for your data set.</b> It is generally better to call {@code sort()} or {@code sortFromTo(...)} instead, because
+   * those methods automatically choose the best sorting algorithm.
+   *
+   * @param from the index of the first element (inclusive) to be sorted.
+   * @param to the index of the last element (inclusive) to be sorted.
+   * @throws IndexOutOfBoundsException if {@code (from < 0 || from > to || to >= size()) && to != from - 1}.
+   */
+  public abstract void quickSortFromTo(int from, int to);
+
+  /**
+   * Removes the element at the specified position from the receiver. Shifts any subsequent elements to the left.
+   * This default implementation delegates to {@code removeFromTo(index, index)}.
+   *
+   * @param index the index of the element to be removed.
+   * @throws IndexOutOfBoundsException if {@code index < 0 || index >= size()}.
+   */
+  public void remove(int index) {
+    removeFromTo(index, index);
+  }
+
+  /**
+   * Removes from the receiver all elements whose index is between {@code fromIndex}, inclusive, and {@code toIndex},
+   * inclusive. Shifts any succeeding elements to the left (reduces their index). This call shortens the list by
+   * {@code (toIndex - fromIndex + 1)} elements.
+   *
+   * @param fromIndex index of first element to be removed.
+   * @param toIndex index of last element to be removed.
+   * @throws IndexOutOfBoundsException if
+   *     {@code (fromIndex < 0 || fromIndex > toIndex || toIndex >= size()) && toIndex != fromIndex - 1}.
+   */
+  public abstract void removeFromTo(int fromIndex, int toIndex);
+
+  /** Reverses the elements of the receiver. Last becomes first, second last becomes second first, and so on. */
+  public abstract void reverse();
+
+  /**
+   * Sets the size of the receiver. If the new size is greater than the current size, new null or zero items are added
+   * to the end of the receiver. If the new size is less than the current size, all components at index newSize and
+   * greater are discarded. This method does not release any superfluous internal memory. Use method
+   * {@code trimToSize} to release superfluous internal memory.
+   *
+   * @param newSize the new size of the receiver.
+   * @throws IndexOutOfBoundsException if {@code newSize < 0}.
+   */
+  public void setSize(int newSize) {
+    if (newSize < 0) {
+      throw new IndexOutOfBoundsException("newSize:" + newSize);
+    }
+
+    int currentSize = size();
+    if (newSize > currentSize) {
+      // Grow: append dummy elements at the end.
+      beforeInsertDummies(currentSize, newSize - currentSize);
+    } else if (newSize < currentSize) {
+      // Shrink: drop the tail [newSize, currentSize - 1].
+      removeFromTo(newSize, currentSize - 1);
+    }
+  }
+
+  /**
+   * Sorts the receiver into ascending order.
+   * <p>
+   * The sorting algorithm is dynamically chosen according to the characteristics of the data set.
+   * <p>
+   * This implementation simply calls {@code sortFromTo(...)}. Override {@code sortFromTo(...)} if you can determine
+   * which sort is most appropriate for the given data set.
+   */
+  public final void sort() {
+    sortFromTo(0, size() - 1);
+  }
+
+  /**
+   * Sorts the specified range of the receiver into ascending order.
+   * <p>
+   * The sorting algorithm is dynamically chosen according to the characteristics of the data set. This default
+   * implementation simply calls {@code quickSortFromTo}. Override this method if you can determine which sort is most
+   * appropriate for the given data set.
+   *
+   * @param from the index of the first element (inclusive) to be sorted.
+   * @param to the index of the last element (inclusive) to be sorted.
+   * @throws IndexOutOfBoundsException if {@code (from < 0 || from > to || to >= size()) && to != from - 1}.
+   */
+  public void sortFromTo(int from, int to) {
+    quickSortFromTo(from, to);
+  }
+
+  /**
+   * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory.
+   * An application can use this operation to minimize the storage of the receiver.
+   * <p>
+   * This default implementation does nothing. Override this method in space efficient implementations.
+   */
+  public void trimToSize() {
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/list/AbstractObjectList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/list/AbstractObjectList.java b/core/src/main/java/org/apache/mahout/math/list/AbstractObjectList.java
new file mode 100644
index 0000000..a1a5899
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/list/AbstractObjectList.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.list;
+
+import java.util.Collection;
+
+/**
+ * Abstract base class for resizable lists holding objects of type {@code T}.
+ * First see the <a href="package-summary.html">package summary</a> and
+ * javadoc <a href="package-tree.html">tree view</a> to get the broad picture.
+ * <p>
+ * <b>Note that this implementation is not synchronized.</b>
+ *
+ * @author ***@cern.ch
+ * @version 1.0, 09/24/99
+ * @see java.util.ArrayList
+ * @see java.util.Vector
+ * @see java.util.Arrays
+ */
+public abstract class AbstractObjectList<T> extends AbstractList {
+
+  /**
+   * Appends all of the elements of the specified collection to the end of the receiver. Equivalent to
+   * {@code beforeInsertAllOf(size(), collection)}.
+   *
+   * @param collection the elements to be appended.
+   * @throws ClassCastException if an element in the collection is not of the same parameter type of the receiver.
+   */
+  public void addAllOf(Collection<T> collection) {
+    beforeInsertAllOf(size(), collection);
+  }
+
+  /**
+   * Inserts all elements of the specified collection before the specified position into the receiver. Shifts the
+   * element currently at that position (if any) and any subsequent elements to the right (increases their indices).
+   *
+   * @param index index before which to insert first element from the specified collection.
+   * @param collection the collection to be inserted.
+   * @throws ClassCastException if an element in the collection is not of the same parameter type of the
+   *     receiver.
+   * @throws IndexOutOfBoundsException if {@code index < 0 || index > size()}.
+   */
+  public void beforeInsertAllOf(int index, Collection<T> collection) {
+    // Open a gap of exactly collection.size() slots, then overwrite those placeholder slots in place.
+    beforeInsertDummies(index, collection.size());
+    replaceFromWith(index, collection);
+  }
+
+  /**
+   * Replaces the part of the receiver starting at {@code from} (inclusive) with all the elements of the specified
+   * collection. Does not alter the size of the receiver. Replaces exactly
+   * {@code Math.max(0, Math.min(size() - from, other.size()))} elements.
+   *
+   * @param from the index at which to copy the first element from the specified collection.
+   * @param other collection to replace part of the receiver.
+   * @throws IndexOutOfBoundsException if {@code from < 0 || from >= size()}.
+   */
+  public abstract void replaceFromWith(int from, Collection<T> other);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/list/ObjectArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/list/ObjectArrayList.java b/core/src/main/java/org/apache/mahout/math/list/ObjectArrayList.java
new file mode 100644
index 0000000..c41141f
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/list/ObjectArrayList.java
@@ -0,0 +1,419 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.list;
+
+import org.apache.mahout.math.function.ObjectProcedure;
+
+import java.util.Collection;
+
+/**
+ * Resizable list holding elements of type {@code T}; implemented with an array.
+ * <p>
+ * Element comparisons ({@link #equals(Object)}, {@link #indexOfFromTo}, {@link #lastIndexOfFromTo}) test for
+ * <em>identity</em> ({@code ==}), not for {@link Object#equals(Object)} equality.
+ * <p>
+ * {@link #removeFromTo}, {@link #replaceFromWith}, {@link #beforeInsertDummies} and the sorting methods are not
+ * supported and always throw {@link UnsupportedOperationException}. So do inherited operations built on them, such as
+ * {@code clear()}, {@code remove(int)} and {@code addAllOf(Collection)}.
+ * <p>
+ * <b>Note that this implementation is not synchronized.</b>
+ */
+public class ObjectArrayList<T> extends AbstractObjectList<T> {
+
+  /**
+   * The array buffer into which the elements of the list are stored. The capacity of the list is the length of this
+   * array buffer; only the first {@code size} slots hold elements of the list.
+   */
+  private Object[] elements;
+
+  /** The number of elements in the list, i.e. the number of valid leading slots of {@code elements}. */
+  private int size;
+
+  /** Constructs an empty list with an initial capacity of 10. */
+  public ObjectArrayList() {
+    this(10);
+  }
+
+  /**
+   * Constructs a list containing the specified elements. The initial size and capacity of the list is the length of the
+   * array.
+   * <p>
+   * <b>WARNING:</b> For efficiency reasons and to keep memory usage low, <b>the array is not copied</b>. So if
+   * subsequently you modify the specified array directly via the [] operator, be sure you know what you're doing.
+   *
+   * @param elements the array to be backed by the constructed list.
+   * @throws NullPointerException if {@code elements} is null.
+   */
+  public ObjectArrayList(T[] elements) {
+    // Assign the fields directly instead of calling the overridable elements(T[]) from a constructor.
+    this.elements = elements;
+    this.size = elements.length;
+  }
+
+  /**
+   * Constructs an empty list with the specified initial capacity.
+   *
+   * @param initialCapacity the number of elements the receiver can hold without auto-expanding itself by allocating new
+   *     internal memory.
+   * @throws NegativeArraySizeException if {@code initialCapacity} is negative.
+   */
+  public ObjectArrayList(int initialCapacity) {
+    elements = new Object[initialCapacity];
+    size = 0;
+  }
+
+  /**
+   * Appends the specified element to the end of this list, growing the backing array if it is full.
+   *
+   * @param element element to be appended to this list.
+   */
+  public void add(T element) {
+    if (size == elements.length) {
+      ensureCapacity(size + 1);
+    }
+    elements[size++] = element;
+  }
+
+  /**
+   * Inserts the specified element before the specified position into the receiver. Shifts the element currently at that
+   * position (if any) and any subsequent elements to the right.
+   *
+   * @param index index before which the specified element is to be inserted (must be in [0,size]).
+   * @param element element to be inserted.
+   * @throws IndexOutOfBoundsException index is out of range ({@code index < 0 || index > size()}).
+   */
+  public void beforeInsert(int index, T element) {
+    if (size == index) {
+      // Inserting at the end is a plain append.
+      add(element);
+      return;
+    }
+    if (index > size || index < 0) {
+      throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
+    }
+    ensureCapacity(size + 1);
+    System.arraycopy(elements, index, elements, index + 1, size - index);
+    elements[index] = element;
+    size++;
+  }
+
+  /**
+   * Returns a copy of the receiver. The copy has its own backing array with the same size and capacity, and holds the
+   * same element references; the elements themselves are not cloned.
+   *
+   * @return a copy of the receiver.
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public Object clone() {
+    // The array constructor sets size to the full array length (the capacity), which would expose the unused
+    // trailing slots as null elements; restore the receiver's real size on the copy.
+    ObjectArrayList<T> copy = new ObjectArrayList<>((T[]) elements.clone());
+    copy.size = size;
+    return copy;
+  }
+
+  /**
+   * Returns a copy of the receiver; uses {@code clone()} and casts the result. The elements themselves are not cloned.
+   *
+   * @return a copy of the receiver.
+   */
+  @SuppressWarnings("unchecked")
+  public ObjectArrayList<T> copy() {
+    return (ObjectArrayList<T>) clone();
+  }
+
+  /**
+   * Returns the elements currently stored, including invalid elements between size and capacity, if any.
+   * <p>
+   * <b>WARNING:</b> For efficiency reasons and to keep memory usage low, <b>the array is not copied</b>. So if
+   * subsequently you modify the returned array directly via the [] operator, be sure you know what you're doing.
+   * <p>
+   * The returned array may be a plain {@code Object[]} (for example when the list was created with the capacity
+   * constructor), so assigning it to a more specific array type can fail with a {@link ClassCastException}.
+   *
+   * @return the elements currently stored.
+   */
+  @SuppressWarnings("unchecked")
+  public <Q> Q[] elements() {
+    return (Q[]) elements;
+  }
+
+  /**
+   * Sets the receiver's elements to be the specified array (not a copy of it).
+   * <p>
+   * The size and capacity of the list is the length of the array. <b>WARNING:</b> For efficiency reasons and to keep
+   * memory usage low, <b>the array is not copied</b>. So if subsequently you modify the specified array directly via
+   * the [] operator, be sure you know what you're doing.
+   *
+   * @param elements the new elements to be stored.
+   */
+  public void elements(T[] elements) {
+    this.elements = elements;
+    this.size = elements.length;
+  }
+
+  /**
+   * Ensures that the receiver can hold at least the specified number of elements without needing to allocate new
+   * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver.
+   *
+   * @param minCapacity the desired minimum capacity.
+   */
+  public void ensureCapacity(int minCapacity) {
+    elements = org.apache.mahout.math.Arrays.ensureCapacity(elements, minCapacity);
+  }
+
+  /**
+   * Compares the specified Object with the receiver. Returns true if and only if the specified Object is also an
+   * ObjectArrayList, both lists have the same size, and all corresponding pairs of elements in the two lists are
+   * identical ({@code ==}). In other words, two lists are equal if they contain the same element references in the same
+   * order. Objects that are not ObjectArrayLists are compared with the inherited {@code equals}.
+   *
+   * @param otherObj the Object to be compared for equality with the receiver.
+   * @return true if the specified Object is equal to the receiver.
+   */
+  @Override
+  public boolean equals(Object otherObj) {
+    if (this == otherObj) {
+      return true;
+    }
+    if (!(otherObj instanceof ObjectArrayList)) {
+      // Also covers null, which is never an instance of ObjectArrayList.
+      return super.equals(otherObj);
+    }
+    ObjectArrayList<?> other = (ObjectArrayList<?>) otherObj;
+    if (size() != other.size()) {
+      return false;
+    }
+
+    Object[] otherElements = other.elements;
+    for (int i = size(); --i >= 0;) {
+      if (elements[i] != otherElements[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Returns a hash code consistent with {@link #equals(Object)}. It combines the identity hash codes of the first
+   * {@code size()} elements in order, so lists that are equal (same size, identical elements in the same order) have
+   * equal hash codes.
+   *
+   * @return a hash code for the receiver.
+   */
+  @Override
+  public int hashCode() {
+    int hash = 1;
+    int theSize = size();
+    for (int i = 0; i < theSize; i++) {
+      // equals() tests identity, so hash by identity too (identityHashCode(null) is 0).
+      hash = 31 * hash + System.identityHashCode(elements[i]);
+    }
+    return hash;
+  }
+
+  /**
+   * Applies a procedure to each element of the receiver, if any. Starts at index 0, moving rightwards.
+   *
+   * @param procedure the procedure to be applied. Stops iteration if the procedure returns {@code false}, otherwise
+   *     continues.
+   * @return {@code false} if the procedure stopped before all elements were iterated over, {@code true} otherwise.
+   */
+  @SuppressWarnings("unchecked")
+  public boolean forEach(ObjectProcedure<T> procedure) {
+    T[] theElements = (T[]) elements;
+    int theSize = size;
+
+    for (int i = 0; i < theSize;) {
+      if (!procedure.apply(theElements[i++])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Returns the element at the specified position in the receiver.
+   *
+   * @param index index of element to return.
+   * @throws IndexOutOfBoundsException index is out of range ({@code index < 0 || index >= size()}).
+   */
+  @SuppressWarnings("unchecked")
+  public T get(int index) {
+    if (index >= size || index < 0) {
+      throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
+    }
+    return (T) elements[index];
+  }
+
+  /**
+   * Returns the element at the specified position in the receiver; <b>WARNING:</b> Does not check preconditions.
+   * Provided with invalid parameters this method may return invalid elements without throwing any exception! <b>You
+   * should only use this method when you are absolutely sure that the index is within bounds.</b> Precondition
+   * (unchecked): {@code index >= 0 && index < size()}.
+   *
+   * @param index index of element to return.
+   */
+  @SuppressWarnings("unchecked")
+  public T getQuick(int index) {
+    return (T) elements[index];
+  }
+
+  /**
+   * Returns the index of the first occurrence of the specified element. Returns {@code -1} if the receiver does
+   * not contain this element. Searches between {@code from}, inclusive and {@code to}, inclusive. Tests for
+   * identity.
+   *
+   * @param element element to search for.
+   * @param from the leftmost search position, inclusive.
+   * @param to the rightmost search position, inclusive.
+   * @return the index of the first occurrence of the element in the receiver; returns {@code -1} if the element is
+   *     not found.
+   * @throws IndexOutOfBoundsException index is out of range
+   *     ({@code size() > 0 && (from < 0 || from > to || to >= size()) && to != from - 1}).
+   */
+  public int indexOfFromTo(T element, int from, int to) {
+    if (size == 0) {
+      return -1;
+    }
+    checkRangeFromTo(from, to, size);
+
+    Object[] theElements = elements;
+    for (int i = from; i <= to; i++) {
+      if (element == theElements[i]) {
+        return i; // found
+      }
+    }
+    return -1; // not found
+  }
+
+  /**
+   * Returns the index of the last occurrence of the specified element. Returns {@code -1} if the receiver does not
+   * contain this element. Searches beginning at {@code to}, inclusive until {@code from}, inclusive. Tests
+   * for identity.
+   *
+   * @param element element to search for.
+   * @param from the leftmost search position, inclusive.
+   * @param to the rightmost search position, inclusive.
+   * @return the index of the last occurrence of the element in the receiver; returns {@code -1} if the element is
+   *     not found.
+   * @throws IndexOutOfBoundsException index is out of range
+   *     ({@code size() > 0 && (from < 0 || from > to || to >= size()) && to != from - 1}).
+   */
+  public int lastIndexOfFromTo(T element, int from, int to) {
+    if (size == 0) {
+      return -1;
+    }
+    checkRangeFromTo(from, to, size);
+
+    Object[] theElements = elements;
+    for (int i = to; i >= from; i--) {
+      if (element == theElements[i]) {
+        return i; // found
+      }
+    }
+    return -1; // not found
+  }
+
+  /**
+   * Returns a new list of the part of the receiver between {@code from}, inclusive, and {@code to},
+   * inclusive. The new list has its own backing array sized exactly to the part.
+   *
+   * @param from the index of the first element (inclusive).
+   * @param to the index of the last element (inclusive).
+   * @return a new list
+   * @throws IndexOutOfBoundsException index is out of range
+   *     ({@code size() > 0 && (from < 0 || from > to || to >= size()) && to != from - 1}).
+   */
+  @SuppressWarnings("unchecked")
+  public AbstractObjectList<T> partFromTo(int from, int to) {
+    if (size == 0) {
+      return new ObjectArrayList<>(0);
+    }
+
+    checkRangeFromTo(from, to, size);
+
+    Object[] part = new Object[to - from + 1];
+    System.arraycopy(elements, from, part, 0, to - from + 1);
+    return new ObjectArrayList<>((T[]) part);
+  }
+
+  /** Reverses the elements of the receiver. Last becomes first, second last becomes second first, and so on. */
+  @Override
+  public void reverse() {
+    int limit = size / 2;
+    int j = size - 1;
+
+    Object[] theElements = elements;
+    for (int i = 0; i < limit;) { // swap the i-th element from the front with the i-th from the back
+      Object tmp = theElements[i];
+      theElements[i++] = theElements[j];
+      theElements[j--] = tmp;
+    }
+  }
+
+  /**
+   * Replaces the element at the specified position in the receiver with the specified element.
+   *
+   * @param index index of element to replace.
+   * @param element element to be stored at the specified position.
+   * @throws IndexOutOfBoundsException index is out of range ({@code index < 0 || index >= size()}).
+   */
+  public void set(int index, T element) {
+    if (index >= size || index < 0) {
+      throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
+    }
+    elements[index] = element;
+  }
+
+  /**
+   * Replaces the element at the specified position in the receiver with the specified element; <b>WARNING:</b> Does not
+   * check preconditions. Provided with invalid parameters this method may access invalid indexes without throwing any
+   * exception! <b>You should only use this method when you are absolutely sure that the index is within bounds.</b>
+   * Precondition (unchecked): {@code index >= 0 && index < size()}.
+   *
+   * @param index index of element to replace.
+   * @param element element to be stored at the specified position.
+   */
+  public void setQuick(int index, T element) {
+    elements[index] = element;
+  }
+
+  /**
+   * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory. An
+   * application can use this operation to minimize the storage of the receiver.
+   */
+  @Override
+  public void trimToSize() {
+    elements = org.apache.mahout.math.Arrays.trimToCapacity(elements, size());
+  }
+
+  /**
+   * Not supported by this list.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  @Override
+  public void removeFromTo(int fromIndex, int toIndex) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Not supported by this list.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  @Override
+  public void replaceFromWith(int from, Collection<T> other) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Not supported by this list.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  @Override
+  protected void beforeInsertDummies(int index, int length) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Not supported by this list.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  @Override
+  public void mergeSortFromTo(int from, int to) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Not supported by this list.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  @Override
+  public void quickSortFromTo(int from, int to) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Returns the number of elements in the receiver.
+   *
+   * @return the number of elements in the receiver.
+   */
+  @Override
+  public int size() {
+    return size;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/list/SimpleLongArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/list/SimpleLongArrayList.java b/core/src/main/java/org/apache/mahout/math/list/SimpleLongArrayList.java
new file mode 100644
index 0000000..1a765eb
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/list/SimpleLongArrayList.java
@@ -0,0 +1,102 @@
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.list;
+
+/**
+ Resizable list holding <code>long</code> elements; implemented with arrays; not efficient; just to
+ demonstrate which methods you must override to implement a fully functional list.
+ */
+public class SimpleLongArrayList extends AbstractLongList {
+
+ /**
+ * The array buffer into which the elements of the list are stored. The capacity of the list is the length of this
+ * array buffer.
+ */
+ private long[] elements;
+
+ /** Constructs an empty list. */
+ public SimpleLongArrayList() {
+ this(10);
+ }
+
+ /**
+ * Constructs a list containing the specified elements. The initial size and capacity of the list is the length of the
+ * array.
+ *
+ * <b>WARNING:</b> For efficiency reasons and to keep memory usage low, <b>the array is not copied</b>. So if
+ * subsequently you modify the specified array directly via the [] operator, be sure you know what you're doing.
+ *
+ * @param elements the array to be backed by the the constructed list
+ */
+ public SimpleLongArrayList(long[] elements) {
+ elements(elements);
+ }
+
+ /**
+ * Constructs an empty list with the specified initial capacity.
+ *
+ * @param initialCapacity the number of elements the receiver can hold without auto-expanding itself by allocating new
+ * internal memory.
+ */
+ private SimpleLongArrayList(int initialCapacity) {
+ if (initialCapacity < 0) {
+ throw new IllegalArgumentException("Illegal Capacity: " + initialCapacity);
+ }
+
+ this.elements(new long[initialCapacity]);
+ size = 0;
+ }
+
+ /**
+ * Ensures that the receiver can hold at least the specified number of elements without needing to allocate new
+ * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver.
+ *
+ * @param minCapacity the desired minimum capacity.
+ */
+ @Override
+ public void ensureCapacity(int minCapacity) {
+ elements = org.apache.mahout.math.Arrays.ensureCapacity(elements, minCapacity);
+ }
+
+ /**
+ * Returns the element at the specified position in the receiver; <b>WARNING:</b> Does not check preconditions.
+ * Provided with invalid parameters this method may return invalid elements without throwing any exception! <b>You
+ * should only use this method when you are absolutely sure that the index is within bounds.</b> Precondition
+ * (unchecked): <tt>index &gt;= 0 && index &lt; size()</tt>.
+ *
+ * @param index index of element to return.
+ */
+ @Override
+ protected long getQuick(int index) {
+ return elements[index];
+ }
+
+ /**
+ * Replaces the element at the specified position in the receiver with the specified element; <b>WARNING:</b> Does not
+ * check preconditions. Provided with invalid parameters this method may access invalid indexes without throwing any
+ * exception! <b>You should only use this method when you are absolutely sure that the index is within bounds.</b>
+ * Precondition (unchecked): <tt>index &gt;= 0 && index &lt; size()</tt>.
+ *
+ * @param index index of element to replace.
+ * @param element element to be stored at the specified position.
+ */
+ @Override
+ protected void setQuick(int index, long element) {
+ elements[index] = element;
+ }
+
+ /**
+ * Trims the capacity of the receiver to be the receiver's current size. An application can use this operation to
+ * minimize the storage of the receiver.
+ */
+ @Override
+ public void trimToSize() {
+ elements = org.apache.mahout.math.Arrays.trimToCapacity(elements, size());
+ }
+}
r***@apache.org
2018-09-08 23:35:07 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Multinomial.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/Multinomial.java b/core/src/main/java/org/apache/mahout/math/random/Multinomial.java
new file mode 100644
index 0000000..d79c32c
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/Multinomial.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+/**
+ * Multinomial sampler that allows updates to element probabilities. The basic idea is that sampling is
+ * done by using a simple balanced tree. Probabilities are kept in the tree so that we can navigate to
+ * any leaf in log N time. Updates are simple because we can just propagate them upwards.
+ * <p/>
+ * In order to facilitate access by value, we maintain an additional map from value to tree node.
+ */
+public final class Multinomial<T> implements Sampler<T>, Iterable<T> {
+ // these lists use heap ordering. Thus, the root is at location 1, first level children at 2 and 3, second level
+ // at 4, 5 and 6, 7.
+ private final DoubleArrayList weight = new DoubleArrayList();
+ private final List<T> values = Lists.newArrayList();
+ private final Map<T, Integer> items = Maps.newHashMap();
+ private Random rand = RandomUtils.getRandom();
+
+ public Multinomial() {
+ weight.add(0);
+ values.add(null);
+ }
+
+ public Multinomial(Multiset<T> counts) {
+ this();
+ Preconditions.checkArgument(!counts.isEmpty(), "Need some data to build sampler");
+ rand = RandomUtils.getRandom();
+ for (T t : counts.elementSet()) {
+ add(t, counts.count(t));
+ }
+ }
+
+ public Multinomial(Iterable<WeightedThing<T>> things) {
+ this();
+ for (WeightedThing<T> thing : things) {
+ add(thing.getValue(), thing.getWeight());
+ }
+ }
+
+ public void add(T value, double w) {
+ Preconditions.checkNotNull(value);
+ Preconditions.checkArgument(!items.containsKey(value));
+
+ int n = this.weight.size();
+ if (n == 1) {
+ weight.add(w);
+ values.add(value);
+ items.put(value, 1);
+ } else {
+ // parent comes down
+ weight.add(weight.get(n / 2));
+ values.add(values.get(n / 2));
+ items.put(values.get(n / 2), n);
+ n++;
+
+ // new item goes in
+ items.put(value, n);
+ this.weight.add(w);
+ values.add(value);
+
+ // parents get incremented all the way to the root
+ while (n > 1) {
+ n /= 2;
+ this.weight.set(n, this.weight.get(n) + w);
+ }
+ }
+ }
+
+ public double getWeight(T value) {
+ if (items.containsKey(value)) {
+ return weight.get(items.get(value));
+ } else {
+ return 0;
+ }
+ }
+
+ public double getProbability(T value) {
+ if (items.containsKey(value)) {
+ return weight.get(items.get(value)) / weight.get(1);
+ } else {
+ return 0;
+ }
+ }
+
+ public double getWeight() {
+ if (weight.size() > 1) {
+ return weight.get(1);
+ } else {
+ return 0;
+ }
+ }
+
+ public void delete(T value) {
+ set(value, 0);
+ }
+
+ public void set(T value, double newP) {
+ Preconditions.checkArgument(items.containsKey(value));
+ int n = items.get(value);
+ if (newP <= 0) {
+ // this makes the iterator not see such an element even though we leave a phantom in the tree
+ // Leaving the phantom behind simplifies tree maintenance and testing, but isn't really necessary.
+ items.remove(value);
+ }
+ double oldP = weight.get(n);
+ while (n > 0) {
+ weight.set(n, weight.get(n) - oldP + newP);
+ n /= 2;
+ }
+ }
+
+ @Override
+ public T sample() {
+ Preconditions.checkArgument(!weight.isEmpty());
+ return sample(rand.nextDouble());
+ }
+
+ public T sample(double u) {
+ u *= weight.get(1);
+
+ int n = 1;
+ while (2 * n < weight.size()) {
+ // children are at 2n and 2n+1
+ double left = weight.get(2 * n);
+ if (u <= left) {
+ n = 2 * n;
+ } else {
+ u -= left;
+ n = 2 * n + 1;
+ }
+ }
+ return values.get(n);
+ }
+
+ /**
+ * Exposed for testing only. Returns a list of the leaf weights. These are in an
+ * order such that probing just before and after the cumulative sum of these weights
+ * will touch every element of the tree twice and thus will make it possible to test
+ * every possible left/right decision in navigating the tree.
+ */
+ List<Double> getWeights() {
+ List<Double> r = Lists.newArrayList();
+ int i = Integer.highestOneBit(weight.size());
+ while (i < weight.size()) {
+ r.add(weight.get(i));
+ i++;
+ }
+ i /= 2;
+ while (i < Integer.highestOneBit(weight.size())) {
+ r.add(weight.get(i));
+ i++;
+ }
+ return r;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ return new AbstractIterator<T>() {
+ Iterator<T> valuesIterator = Iterables.skip(values, 1).iterator();
+ @Override
+ protected T computeNext() {
+ while (valuesIterator.hasNext()) {
+ T next = valuesIterator.next();
+ if (items.containsKey(next)) {
+ return next;
+ }
+ }
+ return endOfData();
+ }
+ };
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Normal.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/Normal.java b/core/src/main/java/org/apache/mahout/math/random/Normal.java
new file mode 100644
index 0000000..c162f26
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/Normal.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+public final class Normal extends AbstractSamplerFunction {
+ private final Random rand = RandomUtils.getRandom();
+ private double mean = 0;
+ private double sd = 1;
+
+ public Normal() {}
+
+ public Normal(double mean, double sd) {
+ this.mean = mean;
+ this.sd = sd;
+ }
+
+ @Override
+ public Double sample() {
+ return rand.nextGaussian() * sd + mean;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java b/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
new file mode 100644
index 0000000..e4e49f8
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+import java.util.List;
+
+/**
+ * Samples from a Poisson distribution. Should probably not be used for lambda > 1000 or so.
+ */
+public final class PoissonSampler extends AbstractSamplerFunction {
+
+ private double limit;
+ private Multinomial<Integer> partial;
+ private final RandomWrapper gen;
+ private final PoissonDistribution pd;
+
+ public PoissonSampler(double lambda) {
+ limit = 1;
+ gen = RandomUtils.getRandom();
+ pd = new PoissonDistribution(gen.getRandomGenerator(),
+ lambda,
+ PoissonDistribution.DEFAULT_EPSILON,
+ PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+ }
+
+ @Override
+ public Double sample() {
+ return sample(gen.nextDouble());
+ }
+
+ double sample(double u) {
+ if (u < limit) {
+ List<WeightedThing<Integer>> steps = Lists.newArrayList();
+ limit = 1;
+ int i = 0;
+ while (u / 20 < limit) {
+ double pdf = pd.probability(i);
+ limit -= pdf;
+ steps.add(new WeightedThing<>(i, pdf));
+ i++;
+ }
+ steps.add(new WeightedThing<>(steps.size(), limit));
+ partial = new Multinomial<>(steps);
+ }
+ return partial.sample(u);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Sampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/Sampler.java b/core/src/main/java/org/apache/mahout/math/random/Sampler.java
new file mode 100644
index 0000000..51460fa
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/Sampler.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
/**
 * A source of random values of some type.
 *
 * @param <T> the type of the values produced
 */
public interface Sampler<T> {

  /**
   * @return the next sampled value
   */
  T sample();
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/WeightedThing.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/WeightedThing.java b/core/src/main/java/org/apache/mahout/math/random/WeightedThing.java
new file mode 100644
index 0000000..20f6df3
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/WeightedThing.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
import java.util.Objects;

import com.google.common.base.Preconditions;
import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Handy for creating multinomial distributions of things.
+ */
+public final class WeightedThing<T> implements Comparable<WeightedThing<T>> {
+ private double weight;
+ private final T value;
+
+ public WeightedThing(T thing, double weight) {
+ this.value = Preconditions.checkNotNull(thing);
+ this.weight = weight;
+ }
+
+ public WeightedThing(double weight) {
+ this.value = null;
+ this.weight = weight;
+ }
+
+ public T getValue() {
+ return value;
+ }
+
+ public double getWeight() {
+ return weight;
+ }
+
+ public void setWeight(double weight) {
+ this.weight = weight;
+ }
+
+ @Override
+ public int compareTo(WeightedThing<T> other) {
+ return Double.compare(this.weight, other.weight);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof WeightedThing) {
+ @SuppressWarnings("unchecked")
+ WeightedThing<T> other = (WeightedThing<T>) o;
+ return weight == other.weight && value.equals(other.value);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * RandomUtils.hashDouble(weight) + value.hashCode();
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java b/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java
new file mode 100644
index 0000000..7691420
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java
@@ -0,0 +1,188 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.set;
+
+import org.apache.mahout.math.PersistentObject;
+import org.apache.mahout.math.map.PrimeFinder;
+
+public abstract class AbstractSet extends PersistentObject {
+ //public static boolean debug = false; // debug only
+
+ /** The number of distinct associations in the map; its "size()". */
+ protected int distinct;
+
+ /**
+ * The table capacity c=table.length always satisfies the invariant <tt>c * minLoadFactor <= s <= c *
+ * maxLoadFactor</tt>, where s=size() is the number of associations currently contained. The term "c * minLoadFactor"
+ * is called the "lowWaterMark", "c * maxLoadFactor" is called the "highWaterMark". In other words, the table capacity
+ * (and proportionally the memory used by this class) oscillates within these constraints. The terms are precomputed
+ * and cached to avoid recalculating them each time put(..) or removeKey(...) is called.
+ */
+ protected int lowWaterMark;
+ protected int highWaterMark;
+
+ /** The minimum load factor for the hashtable. */
+ protected double minLoadFactor;
+
+ /** The maximum load factor for the hashtable. */
+ protected double maxLoadFactor;
+
+ // these are public access for unit tests.
+ public static final int DEFAULT_CAPACITY = 277;
+ public static final double DEFAULT_MIN_LOAD_FACTOR = 0.2;
+ public static final double DEFAULT_MAX_LOAD_FACTOR = 0.5;
+
+ /**
+ * Chooses a new prime table capacity optimized for growing that (approximately) satisfies the invariant <tt>c *
+ * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given size.
+ */
+ protected int chooseGrowCapacity(int size, double minLoad, double maxLoad) {
+ return nextPrime(Math.max(size + 1, (int) ((4 * size / (3 * minLoad + maxLoad)))));
+ }
+
+ /**
+ * Returns new high water mark threshold based on current capacity and maxLoadFactor.
+ *
+ * @return int the new threshold.
+ */
+ protected int chooseHighWaterMark(int capacity, double maxLoad) {
+ return Math.min(capacity - 2, (int) (capacity * maxLoad)); //makes sure there is always at least one FREE slot
+ }
+
+ /**
+ * Returns new low water mark threshold based on current capacity and minLoadFactor.
+ *
+ * @return int the new threshold.
+ */
+ protected int chooseLowWaterMark(int capacity, double minLoad) {
+ return (int) (capacity * minLoad);
+ }
+
+ /**
+ * Chooses a new prime table capacity neither favoring shrinking nor growing, that (approximately) satisfies the
+ * invariant <tt>c * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given
+ * size.
+ */
+ protected int chooseMeanCapacity(int size, double minLoad, double maxLoad) {
+ return nextPrime(Math.max(size + 1, (int) ((2 * size / (minLoad + maxLoad)))));
+ }
+
+ /**
+ * Chooses a new prime table capacity optimized for shrinking that (approximately) satisfies the invariant <tt>c *
+ * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given size.
+ */
+ protected int chooseShrinkCapacity(int size, double minLoad, double maxLoad) {
+ return nextPrime(Math.max(size + 1, (int) ((4 * size / (minLoad + 3 * maxLoad)))));
+ }
+
+ /** Removes all (key,value) associations from the receiver. */
+ public abstract void clear();
+
+ /**
+ * Ensures that the receiver can hold at least the specified number of elements without needing to allocate new
+ * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver. <p> This
+ * method never need be called; it is for performance tuning only. Calling this method before <tt>put()</tt>ing a
+ * large number of associations boosts performance, because the receiver will grow only once instead of potentially
+ * many times. <p> <b>This default implementation does nothing.</b> Override this method if necessary.
+ *
+ * @param minCapacity the desired minimum capacity.
+ */
+ public void ensureCapacity(int minCapacity) {
+ }
+
+ /**
+ * Returns <tt>true</tt> if the receiver contains no (key,value) associations.
+ *
+ * @return <tt>true</tt> if the receiver contains no (key,value) associations.
+ */
+ public boolean isEmpty() {
+ return distinct == 0;
+ }
+
+ /**
+ * Returns a prime number which is <code>&gt;= desiredCapacity</code> and very close to <code>desiredCapacity</code>
+ * (within 11% if <code>desiredCapacity &gt;= 1000</code>).
+ *
+ * @param desiredCapacity the capacity desired by the user.
+ * @return the capacity which should be used for a hashtable.
+ */
+ protected int nextPrime(int desiredCapacity) {
+ return PrimeFinder.nextPrime(desiredCapacity);
+ }
+
+ /**
+ * Initializes the receiver. You will almost certainly need to override this method in subclasses to initialize the
+ * hash table.
+ *
+ * @param initialCapacity the initial capacity of the receiver.
+ * @param minLoadFactor the minLoadFactor of the receiver.
+ * @param maxLoadFactor the maxLoadFactor of the receiver.
+ * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
+ * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
+ * maxLoadFactor)</tt>.
+ */
+ protected void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+ if (initialCapacity < 0) {
+ throw new IllegalArgumentException("Initial Capacity must not be less than zero: " + initialCapacity);
+ }
+ if (minLoadFactor < 0.0 || minLoadFactor >= 1.0) {
+ throw new IllegalArgumentException("Illegal minLoadFactor: " + minLoadFactor);
+ }
+ if (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) {
+ throw new IllegalArgumentException("Illegal maxLoadFactor: " + maxLoadFactor);
+ }
+ if (minLoadFactor >= maxLoadFactor) {
+ throw new IllegalArgumentException(
+ "Illegal minLoadFactor: " + minLoadFactor + " and maxLoadFactor: " + maxLoadFactor);
+ }
+ }
+
+ /**
+ * Returns the number of (key,value) associations currently contained.
+ *
+ * @return the number of (key,value) associations currently contained.
+ */
+ public int size() {
+ return distinct;
+ }
+
+ /**
+ * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory. An
+ * application can use this operation to minimize the storage of the receiver. <p> This default implementation does
+ * nothing. Override this method if necessary.
+ */
+ public void trimToSize() {
+ }
+
+ protected static boolean equalsMindTheNull(Object a, Object b) {
+ if (a == null && b == null) {
+ return true;
+ }
+ if (a == null || b == null) {
+ return false;
+ }
+ return a.equals(b);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/set/HashUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/set/HashUtils.java b/core/src/main/java/org/apache/mahout/math/set/HashUtils.java
new file mode 100644
index 0000000..f5dfeb0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/set/HashUtils.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.set;
+
/**
 * Computes hashes of primitive values. Providing these as statics allows the templated code
 * to compute hashes of sets.
 */
public final class HashUtils {

  private HashUtils() {
  }

  /** @return {@code x} widened to an int */
  public static int hash(byte x) {
    return x;
  }

  /** @return {@code x} widened to an int */
  public static int hash(short x) {
    return x;
  }

  /** @return the code unit of {@code x} as an int */
  public static int hash(char x) {
    return x;
  }

  /** @return {@code x} itself */
  public static int hash(int x) {
    return x;
  }

  /**
   * Hashes a float by adding its bit pattern, shifted right by 3, to the bit pattern of {@code (float) (PI * x)}.
   */
  public static int hash(float x) {
    // The parentheses are essential. Without them '+' binds tighter than '>>>', so the expression becomes a shift
    // by a data-dependent amount (masked to 0..31) that throws away most of the input bits.
    return (Float.floatToIntBits(x) >>> 3) + Float.floatToIntBits((float) (Math.PI * x));
  }

  /** Hashes a double by scaling its bit pattern by 17 and folding it with {@link #hash(long)}. */
  public static int hash(double x) {
    return hash(17 * Double.doubleToLongBits(x));
  }

  /** Folds a long into an int: the high 32 bits of {@code 11 * x} are XORed into {@code x}, then truncated. */
  public static int hash(long x) {
    return (int) (((x * 11) >>> 32) ^ x);
  }
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java b/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java
new file mode 100644
index 0000000..285b5a5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java
@@ -0,0 +1,548 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.math.set;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.mahout.math.MurmurHash;
+import org.apache.mahout.math.function.ObjectProcedure;
+import org.apache.mahout.math.map.PrimeFinder;
+
+/**
+ * Open hashing alternative to java.util.HashSet.
+ **/
+public class OpenHashSet<T> extends AbstractSet implements Set<T> {
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+ protected static final char NO_KEY_VALUE = 0;
+
+ /** The hash table keys. */
+ private Object[] table;
+
+ /** The state of each hash table entry (FREE, FULL, REMOVED). */
+ private byte[] state;
+
+ /** The number of table entries in state==FREE. */
+ private int freeEntries;
+
+
+ /** Constructs an empty map with default capacity and default load factors. */
+ public OpenHashSet() {
+ this(DEFAULT_CAPACITY);
+ }
+
+ /**
+ * Constructs an empty map with the specified initial capacity and default load factors.
+ *
+ * @param initialCapacity the initial capacity of the map.
+ * @throws IllegalArgumentException if the initial capacity is less than zero.
+ */
+ public OpenHashSet(int initialCapacity) {
+ this(initialCapacity, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR);
+ }
+
+ /**
+ * Constructs an empty map with the specified initial capacity and the specified minimum and maximum load factor.
+ *
+ * @param initialCapacity the initial capacity.
+ * @param minLoadFactor the minimum load factor.
+ * @param maxLoadFactor the maximum load factor.
+ * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
+ * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
+ * maxLoadFactor)</tt>.
+ */
+ public OpenHashSet(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+ setUp(initialCapacity, minLoadFactor, maxLoadFactor);
+ }
+
+ /** Removes all values associations from the receiver. Implicitly calls <tt>trimToSize()</tt>. */
+ @Override
+ public void clear() {
+ Arrays.fill(this.state, 0, state.length - 1, FREE);
+ distinct = 0;
+ freeEntries = table.length; // delta
+ trimToSize();
+ }
+
+ /**
+ * Returns a deep copy of the receiver.
+ *
+ * @return a deep copy of the receiver.
+ */
+ @SuppressWarnings("unchecked")
+ @Override
+ public Object clone() {
+ OpenHashSet<T> copy = (OpenHashSet<T>) super.clone();
+ copy.table = copy.table.clone();
+ copy.state = copy.state.clone();
+ return copy;
+ }
+
+ /**
+ * Returns <tt>true</tt> if the receiver contains the specified key.
+ *
+ * @return <tt>true</tt> if the receiver contains the specified key.
+ */
+ @Override
+ @SuppressWarnings("unchecked")
+ public boolean contains(Object key) {
+ return indexOfKey((T)key) >= 0;
+ }
+
+ /**
+ * Ensures that the receiver can hold at least the specified number of associations without needing to allocate new
+ * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver. <p> This
+ * method never need be called; it is for performance tuning only. Calling this method before <tt>add()</tt>ing a
+ * large number of associations boosts performance, because the receiver will grow only once instead of potentially
+ * many times and hash collisions get less probable.
+ *
+ * @param minCapacity the desired minimum capacity.
+ */
+ @Override
+ public void ensureCapacity(int minCapacity) {
+ if (table.length < minCapacity) {
+ int newCapacity = nextPrime(minCapacity);
+ rehash(newCapacity);
+ }
+ }
+
+ /**
+ * Applies a procedure to each key of the receiver, if any. Note: Iterates over the keys in no particular order.
+ * Subclasses can define a particular order, for example, "sorted by key". All methods which <i>can</i> be expressed
+ * in terms of this method (most methods can) <i>must guarantee</i> to use the <i>same</i> order defined by this
+ * method, even if it is no particular order. This is necessary so that, for example, methods <tt>keys</tt> and
+ * <tt>values</tt> will yield association pairs, not two uncorrelated lists.
+ *
+ * @param procedure the procedure to be applied. Stops iteration if the procedure returns <tt>false</tt>, otherwise
+ * continues.
+ * @return <tt>false</tt> if the procedure stopped before all keys where iterated over, <tt>true</tt> otherwise.
+ */
+ @SuppressWarnings("unchecked")
+ public boolean forEachKey(ObjectProcedure<T> procedure) {
+ for (int i = table.length; i-- > 0;) {
+ if (state[i] == FULL) {
+ if (!procedure.apply((T)table[i])) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ /**
+ * @param key the key to be added to the receiver.
+ * @return the index where the key would need to be inserted, if it is not already contained. Returns -index-1 if the
+ * key is already contained at slot index. Therefore, if the returned index < 0, then it is already contained
+ * at slot -index-1. If the returned index >= 0, then it is NOT already contained and should be inserted at
+ * slot index.
+ */
+ protected int indexOfInsertion(T key) {
+ Object[] tab = table;
+ byte[] stat = state;
+ int length = tab.length;
+
+ int hash = key.hashCode() & 0x7FFFFFFF;
+ int i = hash % length;
+ int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
+ //int decrement = (hash / length) % length;
+ if (decrement == 0) {
+ decrement = 1;
+ }
+
+ // stop if we find a removed or free slot, or if we find the key itself
+ // do NOT skip over removed slots (yes, open addressing is like that...)
+ while (stat[i] == FULL && tab[i] != key) {
+ i -= decrement;
+ //hashCollisions++;
+ if (i < 0) {
+ i += length;
+ }
+ }
+
+ if (stat[i] == REMOVED) {
+ // stop if we find a free slot, or if we find the key itself.
+ // do skip over removed slots (yes, open addressing is like that...)
+ // assertion: there is at least one FREE slot.
+ int j = i;
+ while (stat[i] != FREE && (stat[i] == REMOVED || tab[i] != key)) {
+ i -= decrement;
+ //hashCollisions++;
+ if (i < 0) {
+ i += length;
+ }
+ }
+ if (stat[i] == FREE) {
+ i = j;
+ }
+ }
+
+
+ if (stat[i] == FULL) {
+ // key already contained at slot i.
+ // return a negative number identifying the slot.
+ return -i - 1;
+ }
+ // not already contained, should be inserted at slot i.
+ // return a number >= 0 identifying the slot.
+ return i;
+ }
+
+ /**
+ * @param key the key to be searched in the receiver.
+ * @return the index where the key is contained in the receiver, returns -1 if the key was not found.
+ */
+ protected int indexOfKey(T key) {
+ Object[] tab = table;
+ byte[] stat = state;
+ int length = tab.length;
+
+ int hash = key.hashCode() & 0x7FFFFFFF;
+ int i = hash % length;
+ int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
+ //int decrement = (hash / length) % length;
+ if (decrement == 0) {
+ decrement = 1;
+ }
+
+ // stop if we find a free slot, or if we find the key itself.
+ // do skip over removed slots (yes, open addressing is like that...)
+ while (stat[i] != FREE && (stat[i] == REMOVED || (!key.equals(tab[i])))) {
+ i -= decrement;
+ //hashCollisions++;
+ if (i < 0) {
+ i += length;
+ }
+ }
+
+ if (stat[i] == FREE) {
+ return -1;
+ } // not found
+ return i; //found, return index where key is contained
+ }
+
+ /**
+ * Fills all keys contained in the receiver into the specified list. Fills the list, starting at index 0. After this
+ * call returns the specified list has a new size that equals <tt>this.size()</tt>.
+ * This method can be used
+ * to iterate over the keys of the receiver.
+ *
+ * @param list the list to be filled, can have any size.
+ */
+ @SuppressWarnings("unchecked")
+ public void keys(List<T> list) {
+ list.clear();
+
+
+ Object [] tab = table;
+ byte[] stat = state;
+
+ for (int i = tab.length; i-- > 0;) {
+ if (stat[i] == FULL) {
+ list.add((T)tab[i]);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean add(Object key) {
+ int i = indexOfInsertion((T)key);
+ if (i < 0) { //already contained
+ return false;
+ }
+
+ if (this.distinct > this.highWaterMark) {
+ int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor);
+ rehash(newCapacity);
+ return add(key);
+ }
+
+ this.table[i] = key;
+ if (this.state[i] == FREE) {
+ this.freeEntries--;
+ }
+ this.state[i] = FULL;
+ this.distinct++;
+
+ if (this.freeEntries < 1) { //delta
+ int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor);
+ rehash(newCapacity);
+ return add(key);
+ }
+
+ return true;
+ }
+
+ /**
+ * Rehashes the contents of the receiver into a new table with a smaller or larger capacity. This method is called
+ * automatically when the number of keys in the receiver exceeds the high water mark or falls below the low water
+ * mark.
+ */
+ @SuppressWarnings("unchecked")
+ protected void rehash(int newCapacity) {
+ int oldCapacity = table.length;
+ //if (oldCapacity == newCapacity) return;
+
+ Object[] oldTable = table;
+ byte[] oldState = state;
+
+ Object[] newTable = new Object[newCapacity];
+ byte[] newState = new byte[newCapacity];
+
+ this.lowWaterMark = chooseLowWaterMark(newCapacity, this.minLoadFactor);
+ this.highWaterMark = chooseHighWaterMark(newCapacity, this.maxLoadFactor);
+
+ this.table = newTable;
+ this.state = newState;
+ this.freeEntries = newCapacity - this.distinct; // delta
+
+ for (int i = oldCapacity; i-- > 0;) {
+ if (oldState[i] == FULL) {
+ Object element = oldTable[i];
+ int index = indexOfInsertion((T)element);
+ newTable[index] = element;
+ newState[index] = FULL;
+ }
+ }
+ }
+
+ /**
+ * Removes the given key with its associated element from the receiver, if present.
+ *
+ * @param key the key to be removed from the receiver.
+ * @return <tt>true</tt> if the receiver contained the specified key, <tt>false</tt> otherwise.
+ */
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean remove(Object key) {
+ int i = indexOfKey((T)key);
+ if (i < 0) {
+ return false;
+ } // key not contained
+
+ this.state[i] = REMOVED;
+ this.distinct--;
+
+ if (this.distinct < this.lowWaterMark) {
+ int newCapacity = chooseShrinkCapacity(this.distinct, this.minLoadFactor, this.maxLoadFactor);
+ rehash(newCapacity);
+ }
+
+ return true;
+ }
+
+ /**
+ * Initializes the receiver.
+ *
+ * @param initialCapacity the initial capacity of the receiver.
+ * @param minLoadFactor the minLoadFactor of the receiver.
+ * @param maxLoadFactor the maxLoadFactor of the receiver.
+ * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
+ * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
+ * maxLoadFactor)</tt>.
+ */
+ @Override
+ protected final void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+ int capacity = initialCapacity;
+ super.setUp(capacity, minLoadFactor, maxLoadFactor);
+ capacity = nextPrime(capacity);
+ if (capacity == 0) {
+ capacity = 1;
+ } // open addressing needs at least one FREE slot at any time.
+
+ this.table = new Object[capacity];
+ this.state = new byte[capacity];
+
+ // memory will be exhausted long before this pathological case happens, anyway.
+ this.minLoadFactor = minLoadFactor;
+ if (capacity == PrimeFinder.LARGEST_PRIME) {
+ this.maxLoadFactor = 1.0;
+ } else {
+ this.maxLoadFactor = maxLoadFactor;
+ }
+
+ this.distinct = 0;
+ this.freeEntries = capacity; // delta
+
+ // lowWaterMark will be established upon first expansion.
+ // establishing it now (upon instance construction) would immediately make the table shrink upon first put(...).
+ // After all the idea of an "initialCapacity" implies violating lowWaterMarks when an object is young.
+ // See ensureCapacity(...)
+ this.lowWaterMark = 0;
+ this.highWaterMark = chooseHighWaterMark(capacity, this.maxLoadFactor);
+ }
+
+ /**
+ * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory. An
+ * application can use this operation to minimize the storage of the receiver.
+ */
+ @Override
+ public void trimToSize() {
+ // * 1.2 because open addressing's performance exponentially degrades beyond that point
+ // so that even rehashing the table can take very long
+ int newCapacity = nextPrime((int) (1 + 1.2 * size()));
+ if (table.length > newCapacity) {
+ rehash(newCapacity);
+ }
+ }
+
+ /**
+ * Access for unit tests.
+ * @param capacity
+ * @param minLoadFactor
+ * @param maxLoadFactor
+ */
+ void getInternalFactors(int[] capacity,
+ double[] minLoadFactor,
+ double[] maxLoadFactor) {
+ capacity[0] = table.length;
+ minLoadFactor[0] = this.minLoadFactor;
+ maxLoadFactor[0] = this.maxLoadFactor;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return size() == 0;
+ }
+
+ /**
+ * OpenHashSet instances are only equal to other OpenHashSet instances, not to
+ * any other collection. Hypothetically, we should check for and permit
+ * equals on other Sets.
+ */
+ @Override
+ @SuppressWarnings("unchecked")
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+
+ if (!(obj instanceof OpenHashSet)) {
+ return false;
+ }
+ final OpenHashSet<T> other = (OpenHashSet<T>) obj;
+ if (other.size() != size()) {
+ return false;
+ }
+
+ return forEachKey(new ObjectProcedure<T>() {
+ @Override
+ public boolean apply(T key) {
+ return other.contains(key);
+ }
+ });
+ }
+
+ @Override
+ public int hashCode() {
+ ByteBuffer buf = ByteBuffer.allocate(size());
+ for (int i = 0; i < table.length; i++) {
+ Object v = table[i];
+ if (state[i] == FULL) {
+ buf.putInt(v.hashCode());
+ }
+ }
+ return MurmurHash.hash(buf, this.getClass().getName().hashCode());
+ }
+
+ /**
+ * Implement the standard Java Collections iterator. Note that 'remove' is silently
+ * ineffectual here. This method is provided for convenience, only.
+ */
+ @Override
+ public Iterator<T> iterator() {
+ List<T> keyList = new ArrayList<>();
+ keys(keyList);
+ return keyList.iterator();
+ }
+
+ @Override
+ public Object[] toArray() {
+ List<T> keyList = new ArrayList<>();
+ keys(keyList);
+ return keyList.toArray();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends T> c) {
+ boolean anyAdded = false;
+ for (T o : c) {
+ boolean added = add(o);
+ anyAdded |= added;
+ }
+ return anyAdded;
+ }
+
+ @Override
+ public boolean containsAll(Collection<?> c) {
+ for (Object o : c) {
+ if (!contains(o)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> c) {
+ boolean anyRemoved = false;
+ for (Object o : c) {
+ boolean removed = remove(o);
+ anyRemoved |= removed;
+ }
+ return anyRemoved;
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> c) {
+ final Collection<?> finalCollection = c;
+ final boolean[] modified = new boolean[1];
+ modified[0] = false;
+ forEachKey(new ObjectProcedure<T>() {
+ @Override
+ public boolean apply(T element) {
+ if (!finalCollection.contains(element)) {
+ remove(element);
+ modified[0] = true;
+ }
+ return true;
+ }
+ });
+ return modified[0];
+ }
+
+ @Override
+ public <T1> T1[] toArray(T1[] a) {
+ return keys().toArray(a);
+ }
+
+ public List<T> keys() {
+ List<T> keys = new ArrayList<>();
+ keys(keys);
+ return keys;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java b/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
new file mode 100644
index 0000000..02bde9b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
@@ -0,0 +1,213 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both
+ * standard conjugate gradient and pre-conditioned conjugate gradient.
+ *
+ * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive
+ * definite. For convenience, this implementation could be extended relatively easily to handle the
+ * case where the input matrix to be be non-symmetric, in which case the system A'Ax = b would be solved.
+ * Because this requires only one pass through the matrix A, it is faster than explicitly computing A'A,
+ * then passing the results to the solver.
+ *
+ * <p>For inputs that may be ill conditioned (often the case for highly sparse input), this solver
+ * also accepts a parameter, lambda, which adds a scaled identity to the matrix A, solving the system
+ * (A + lambda*I)x = b. This obviously changes the solution, but it will guarantee solvability. The
+ * ridge regression approach to linear regression is a common use of this feature.
+ *
+ * <p>If only an approximate solution is required, the maximum number of iterations or the error threshold
+ * may be specified to end the algorithm early at the expense of accuracy. When the matrix A is ill conditioned,
+ * it may sometimes be necessary to increase the maximum number of iterations above the default of A.numCols()
+ * due to numerical issues.
+ *
+ * <p>By default the solver will run a.numCols() iterations or until the residual falls below 1E-9.
+ *
+ * <p>For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations",
+ * sections 10.2 and 10.3 or the <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient
+ * wikipedia article</a>.
+ */
+
+public class ConjugateGradientSolver {
+
+ public static final double DEFAULT_MAX_ERROR = 1.0e-9;
+
+ private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);
+ private static final PlusMult PLUS_MULT = new PlusMult(1.0);
+
+ private int iterations;
+ private double residualNormSquared;
+
+ public ConjugateGradientSolver() {
+ this.iterations = 0;
+ this.residualNormSquared = Double.NaN;
+ }
+
+ /**
+ * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and positive definite.
+ * Only the squareness of a is checked, since testing for symmetry and positive definiteness are too expensive. If
+ * an invalid matrix is specified, then the algorithm may not yield a valid result.
+ *
+ * @param a The linear operator A.
+ * @param b The vector b.
+ * @return The result x of solving the system.
+ * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
+ *
+ */
+ public Vector solve(VectorIterable a, Vector b) {
+ return solve(a, b, null, b.size() + 2, DEFAULT_MAX_ERROR);
+ }
+
+ /**
+ * Solves the system Ax = b with default termination criteria using the specified preconditioner. A must be
+ * symmetric, square, and positive definite. Only the squareness of a is checked, since testing for symmetry
+ * and positive definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
+ * yield a valid result.
+ *
+ * @param a The linear operator A.
+ * @param b The vector b.
+ * @param precond A preconditioner to use on A during the solution process.
+ * @return The result x of solving the system.
+ * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
+ *
+ */
+ public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
+ return solve(a, b, precond, b.size() + 2, DEFAULT_MAX_ERROR);
+ }
+
+
+ /**
+ * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified preconditioner
+ * to improve numeric stability and possibly speed convergence. This version of solve() allows control over the
+ * termination and iteration parameters.
+ *
+ * @param a The matrix A.
+ * @param b The vector b.
+ * @param preconditioner The preconditioner to apply.
+ * @param maxIterations The maximum number of iterations to run.
+ * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until the residual falls
+ * below this value or until maxIterations are completed.
+ * @return The result x of solving the system.
+ * @throws IllegalArgumentException if the matrix is not square, if the size of b is not equal to the number of
+ * columns of A, if maxError is less than zero, or if maxIterations is not positive.
+ */
+
+ public Vector solve(VectorIterable a,
+ Vector b,
+ Preconditioner preconditioner,
+ int maxIterations,
+ double maxError) {
+
+ if (a.numRows() != a.numCols()) {
+ throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
+ }
+
+ if (a.numCols() != b.size()) {
+ throw new CardinalityException(a.numCols(), b.size());
+ }
+
+ if (maxIterations <= 0) {
+ throw new IllegalArgumentException("Max iterations must be positive.");
+ }
+
+ if (maxError < 0.0) {
+ throw new IllegalArgumentException("Max error must be non-negative.");
+ }
+
+ Vector x = new DenseVector(b.size());
+
+ iterations = 0;
+ Vector residual = b.minus(a.times(x));
+ residualNormSquared = residual.dot(residual);
+
+ log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
+ double previousConditionedNormSqr = 0.0;
+ Vector updateDirection = null;
+ while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
+ Vector conditionedResidual;
+ double conditionedNormSqr;
+ if (preconditioner == null) {
+ conditionedResidual = residual;
+ conditionedNormSqr = residualNormSquared;
+ } else {
+ conditionedResidual = preconditioner.precondition(residual);
+ conditionedNormSqr = residual.dot(conditionedResidual);
+ }
+
+ ++iterations;
+
+ if (iterations == 1) {
+ updateDirection = new DenseVector(conditionedResidual);
+ } else {
+ double beta = conditionedNormSqr / previousConditionedNormSqr;
+
+ // updateDirection = residual + beta * updateDirection
+ updateDirection.assign(Functions.MULT, beta);
+ updateDirection.assign(conditionedResidual, Functions.PLUS);
+ }
+
+ Vector aTimesUpdate = a.times(updateDirection);
+
+ double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);
+
+ // x = x + alpha * updateDirection
+ PLUS_MULT.setMultiplicator(alpha);
+ x.assign(updateDirection, PLUS_MULT);
+
+ // residual = residual - alpha * A * updateDirection
+ PLUS_MULT.setMultiplicator(-alpha);
+ residual.assign(aTimesUpdate, PLUS_MULT);
+
+ previousConditionedNormSqr = conditionedNormSqr;
+ residualNormSquared = residual.dot(residual);
+
+ log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared));
+ }
+ return x;
+ }
+
+ /**
+ * Returns the number of iterations run once the solver is complete.
+ *
+ * @return The number of iterations run.
+ */
+ public int getIterations() {
+ return iterations;
+ }
+
+ /**
+ * Returns the norm of the residual at the completion of the solver. Usually this should be close to zero except in
+ * the case of a non positive definite matrix A, which results in an unsolvable system, or for ill conditioned A, in
+ * which case more iterations than the default may be needed.
+ *
+ * @return The norm of the residual in the solution.
+ */
+ public double getResidualNorm() {
+ return Math.sqrt(residualNormSquared);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java b/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java
new file mode 100644
index 0000000..871ba44
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java
@@ -0,0 +1,892 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Adapted from the public domain Jama code.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Eigenvalues and eigenvectors of a real matrix.
+ * <p/>
+ * If A is symmetric, then A = V*D*V' where the eigenvalue matrix D is diagonal and the eigenvector
+ * matrix V is orthogonal. I.e. A = V.times(D.times(V.transpose())) and V.times(V.transpose())
+ * equals the identity matrix.
+ * <p/>
+ * If A is not symmetric, then the eigenvalue matrix D is block diagonal with the real eigenvalues
+ * in 1-by-1 blocks and any complex eigenvalues, lambda + i*mu, in 2-by-2 blocks, [lambda, mu; -mu,
+ * lambda]. The columns of V represent the eigenvectors in the sense that A*V = V*D, i.e.
+ * A.times(V) equals V.times(D). The matrix V may be badly conditioned, or even singular, so the
+ * validity of the equation A = V*D*inverse(V) depends upon V.cond().
+ */
+public class EigenDecomposition {
+
+ /** Row and column dimension (square matrix). */
+ private final int n;
+ /** Arrays for internal storage of eigenvalues. */
+ private final Vector d;
+ private final Vector e;
+ /** Array for internal storage of eigenvectors. */
+ private final Matrix v;
+
+ public EigenDecomposition(Matrix x) {
+ this(x, isSymmetric(x));
+ }
+
+ public EigenDecomposition(Matrix x, boolean isSymmetric) {
+ n = x.columnSize();
+ d = new DenseVector(n);
+ e = new DenseVector(n);
+ v = new DenseMatrix(n, n);
+
+ if (isSymmetric) {
+ v.assign(x);
+
+ // Tridiagonalize.
+ tred2();
+
+ // Diagonalize.
+ tql2();
+
+ } else {
+ // Reduce to Hessenberg form.
+ // Reduce Hessenberg to real Schur form.
+ hqr2(orthes(x));
+ }
+ }
+
+ /**
+ * Return the eigenvector matrix
+ *
+ * @return V
+ */
+ public Matrix getV() {
+ return v.like().assign(v);
+ }
+
+ /**
+ * Return the real parts of the eigenvalues
+ */
+ public Vector getRealEigenvalues() {
+ return d;
+ }
+
+ /**
+ * Return the imaginary parts of the eigenvalues
+ */
+ public Vector getImagEigenvalues() {
+ return e;
+ }
+
+ /**
+ * Return the block diagonal eigenvalue matrix
+ *
+ * @return D
+ */
+ public Matrix getD() {
+ Matrix x = new DenseMatrix(n, n);
+ x.assign(0);
+ x.viewDiagonal().assign(d);
+ for (int i = 0; i < n; i++) {
+ double v = e.getQuick(i);
+ if (v > 0) {
+ x.setQuick(i, i + 1, v);
+ } else if (v < 0) {
+ x.setQuick(i, i - 1, v);
+ }
+ }
+ return x;
+ }
+
  /**
   * Symmetric Householder reduction to tridiagonal form.
   *
   * <p>Works in place on {@code v}, which must already hold the symmetric input (the constructor
   * copies it in). It uses {@code d} and {@code e} as scratch while reducing.
   *
   * <p>On return:
   * <ul>
   *   <li>{@code d} is loaded from the last row of {@code v}, where the accumulation loop parked
   *       the diagonal entries.</li>
   *   <li>{@code e} holds one value per index i >= 1, written at step i, with {@code e[0]} set
   *       to 0.</li>
   *   <li>{@code v} holds the accumulated transformations.</li>
   * </ul>
   * {@code tql2()} then consumes {@code d}, {@code e} and {@code v}. As in the EISPACK/JAMA
   * tred2 this is derived from, {@code d}/{@code e} are presumably the diagonal/subdiagonal of the
   * tridiagonal matrix.
   */
  private void tred2() {
    // This is derived from the Algol procedures tred2 by
    // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
    // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
    // Fortran subroutine in EISPACK.

    // Seed d with the last column of the input (equal to the last row, by symmetry).
    d.assign(v.viewColumn(n - 1));

    // Householder reduction to tridiagonal form.
    // Rows are processed from the bottom up. At step i, d[0..i-1] holds the part of row i that
    // lies left of the diagonal; the end of each step reloads d with row i - 1 for the next step.

    for (int i = n - 1; i > 0; i--) {

      // Scale to avoid under/overflow.

      double scale = d.viewPart(0, i).norm(1);
      double h = 0.0;


      if (scale == 0.0) {
        // Nothing to annihilate: record e[i], load row i - 1 into d, and zero row/column i.
        e.setQuick(i, d.getQuick(i - 1));
        for (int j = 0; j < i; j++) {
          d.setQuick(j, v.getQuick(i - 1, j));
          v.setQuick(i, j, 0.0);
          v.setQuick(j, i, 0.0);
        }
      } else {

        // Generate Householder vector.
        // Normalize by scale and accumulate h = sum of squares of the scaled entries.

        for (int k = 0; k < i; k++) {
          d.setQuick(k, d.getQuick(k) / scale);
          h += d.getQuick(k) * d.getQuick(k);
        }
        double f = d.getQuick(i - 1);
        double g = Math.sqrt(h);
        // Choose the sign of g opposite to f to avoid cancellation in f - g below.
        if (f > 0) {
          g = -g;
        }
        e.setQuick(i, scale * g);
        h -= f * g;
        d.setQuick(i - 1, f - g);
        // e[0..i-1] is reused as an accumulator in the next loop.
        for (int j = 0; j < i; j++) {
          e.setQuick(j, 0.0);
        }

        // Apply similarity transformation to remaining columns.
        // Stores the Householder vector in column i of v and accumulates e = (matrix * d)
        // using only the lower triangle of v.

        for (int j = 0; j < i; j++) {
          f = d.getQuick(j);
          v.setQuick(j, i, f);
          g = e.getQuick(j) + v.getQuick(j, j) * f;
          for (int k = j + 1; k <= i - 1; k++) {
            g += v.getQuick(k, j) * d.getQuick(k);
            e.setQuick(k, e.getQuick(k) + v.getQuick(k, j) * f);
          }
          e.setQuick(j, g);
        }
        // e /= h, and f = e . d.
        f = 0.0;
        for (int j = 0; j < i; j++) {
          e.setQuick(j, e.getQuick(j) / h);
          f += e.getQuick(j) * d.getQuick(j);
        }
        // e -= (f / 2h) * d.
        double hh = f / (h + h);
        for (int j = 0; j < i; j++) {
          e.setQuick(j, e.getQuick(j) - hh * d.getQuick(j));
        }
        // Rank-2 update of the lower triangle: v[k][j] -= d[j]*e[k] + e[j]*d[k]; then load row i - 1
        // into d for the next step and clear row i of v.
        for (int j = 0; j < i; j++) {
          f = d.getQuick(j);
          g = e.getQuick(j);
          for (int k = j; k <= i - 1; k++) {
            v.setQuick(k, j, v.getQuick(k, j) - (f * e.getQuick(k) + g * d.getQuick(k)));
          }
          d.setQuick(j, v.getQuick(i - 1, j));
          v.setQuick(i, j, 0.0);
        }
      }
      // Stash h in d[i]; the accumulation phase reads it back as d[i + 1] when i' = i - 1.
      d.setQuick(i, h);
    }

    // Accumulate transformations.
    // v[n-1][i] temporarily parks the diagonal entry v[i][i]; it is copied into d at the end.

    for (int i = 0; i < n - 1; i++) {
      v.setQuick(n - 1, i, v.getQuick(i, i));
      v.setQuick(i, i, 1.0);
      double h = d.getQuick(i + 1);
      if (h != 0.0) {
        for (int k = 0; k <= i; k++) {
          d.setQuick(k, v.getQuick(k, i + 1) / h);
        }
        for (int j = 0; j <= i; j++) {
          double g = 0.0;
          for (int k = 0; k <= i; k++) {
            g += v.getQuick(k, i + 1) * v.getQuick(k, j);
          }
          for (int k = 0; k <= i; k++) {
            v.setQuick(k, j, v.getQuick(k, j) - g * d.getQuick(k));
          }
        }
      }
      for (int k = 0; k <= i; k++) {
        v.setQuick(k, i + 1, 0.0);
      }
    }
    // Move the parked diagonal into d and make the last row of v the unit vector e_{n-1}.
    d.assign(v.viewRow(n - 1));
    v.viewRow(n - 1).assign(0);
    v.setQuick(n - 1, n - 1, 1.0);
    e.setQuick(0, 0.0);
  }
+
+ // Symmetric tridiagonal QL algorithm.
+ private void tql2() {
+
+ // This is derived from the Algol procedures tql2, by
+ // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
+ // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutine in EISPACK.
+
+ e.viewPart(0, n - 1).assign(e.viewPart(1, n - 1));
+ e.setQuick(n - 1, 0.0);
+
+ double f = 0.0;
+ double tst1 = 0.0;
+ double eps = Math.pow(2.0, -52.0);
+ for (int l = 0; l < n; l++) {
+
+ // Find small subdiagonal element
+
+ tst1 = Math.max(tst1, Math.abs(d.getQuick(l)) + Math.abs(e.getQuick(l)));
+ int m = l;
+ while (m < n) {
+ if (Math.abs(e.getQuick(m)) <= eps * tst1) {
+ break;
+ }
+ m++;
+ }
+
+ // If m == l, d.getQuick(l) is an eigenvalue,
+ // otherwise, iterate.
+
+ if (m > l) {
+ do {
+ // Compute implicit shift
+
+ double g = d.getQuick(l);
+ double p = (d.getQuick(l + 1) - g) / (2.0 * e.getQuick(l));
+ double r = Math.hypot(p, 1.0);
+ if (p < 0) {
+ r = -r;
+ }
+ d.setQuick(l, e.getQuick(l) / (p + r));
+ d.setQuick(l + 1, e.getQuick(l) * (p + r));
+ double dl1 = d.getQuick(l + 1);
+ double h = g - d.getQuick(l);
+ for (int i = l + 2; i < n; i++) {
+ d.setQuick(i, d.getQuick(i) - h);
+ }
+ f += h;
+
+ // Implicit QL transformation.
+
+ p = d.getQuick(m);
+ double c = 1.0;
+ double c2 = c;
+ double c3 = c;
+ double el1 = e.getQuick(l + 1);
+ double s = 0.0;
+ double s2 = 0.0;
+ for (int i = m - 1; i >= l; i--) {
+ c3 = c2;
+ c2 = c;
+ s2 = s;
+ g = c * e.getQuick(i);
+ h = c * p;
+ r = Math.hypot(p, e.getQuick(i));
+ e.setQuick(i + 1, s * r);
+ s = e.getQuick(i) / r;
+ c = p / r;
+ p = c * d.getQuick(i) - s * g;
+ d.setQuick(i + 1, h + s * (c * g + s * d.getQuick(i)));
+
+ // Accumulate transformation.
+
+ for (int k = 0; k < n; k++) {
+ h = v.getQuick(k, i + 1);
+ v.setQuick(k, i + 1, s * v.getQuick(k, i) + c * h);
+ v.setQuick(k, i, c * v.getQuick(k, i) - s * h);
+ }
+ }
+ p = -s * s2 * c3 * el1 * e.getQuick(l) / dl1;
+ e.setQuick(l, s * p);
+ d.setQuick(l, c * p);
+
+ // Check for convergence.
+
+ } while (Math.abs(e.getQuick(l)) > eps * tst1);
+ }
+ d.setQuick(l, d.getQuick(l) + f);
+ e.setQuick(l, 0.0);
+ }
+
+ // Sort eigenvalues and corresponding vectors.
+
+ for (int i = 0; i < n - 1; i++) {
+ int k = i;
+ double p = d.getQuick(i);
+ for (int j = i + 1; j < n; j++) {
+ if (d.getQuick(j) > p) {
+ k = j;
+ p = d.getQuick(j);
+ }
+ }
+ if (k != i) {
+ d.setQuick(k, d.getQuick(i));
+ d.setQuick(i, p);
+ for (int j = 0; j < n; j++) {
+ p = v.getQuick(j, i);
+ v.setQuick(j, i, v.getQuick(j, k));
+ v.setQuick(j, k, p);
+ }
+ }
+ }
+ }
+
+ // Nonsymmetric reduction to Hessenberg form.
+ private Matrix orthes(Matrix x) {
+ // Working storage for nonsymmetric algorithm.
+ Vector ort = new DenseVector(n);
+ Matrix hessenBerg = new DenseMatrix(n, n).assign(x);
+
+ // This is derived from the Algol procedures orthes and ortran,
+ // by Martin and Wilkinson, Handbook for Auto. Comp.,
+ // Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutines in EISPACK.
+
+ int low = 0;
+ int high = n - 1;
+
+ for (int m = low + 1; m <= high - 1; m++) {
+
+ // Scale column.
+
+ Vector hColumn = hessenBerg.viewColumn(m - 1).viewPart(m, high - m + 1);
+ double scale = hColumn.norm(1);
+
+ if (scale != 0.0) {
+ // Compute Householder transformation.
+
+ ort.viewPart(m, high - m + 1).assign(hColumn, Functions.plusMult(1 / scale));
+ double h = ort.viewPart(m, high - m + 1).getLengthSquared();
+
+ double g = Math.sqrt(h);
+ if (ort.getQuick(m) > 0) {
+ g = -g;
+ }
+ h -= ort.getQuick(m) * g;
+ ort.setQuick(m, ort.getQuick(m) - g);
+
+ // Apply Householder similarity transformation
+ // H = (I-u*u'/h)*H*(I-u*u')/h)
+
+ Vector ortPiece = ort.viewPart(m, high - m + 1);
+ for (int j = m; j < n; j++) {
+ double f = ortPiece.dot(hessenBerg.viewColumn(j).viewPart(m, high - m + 1)) / h;
+ hessenBerg.viewColumn(j).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f));
+ }
+
+ for (int i = 0; i <= high; i++) {
+ double f = ortPiece.dot(hessenBerg.viewRow(i).viewPart(m, high - m + 1)) / h;
+ hessenBerg.viewRow(i).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f));
+ }
+ ort.setQuick(m, scale * ort.getQuick(m));
+ hessenBerg.setQuick(m, m - 1, scale * g);
+ }
+ }
+
+ // Accumulate transformations (Algol's ortran).
+
+ v.assign(0);
+ v.viewDiagonal().assign(1);
+
+ for (int m = high - 1; m >= low + 1; m--) {
+ if (hessenBerg.getQuick(m, m - 1) != 0.0) {
+ ort.viewPart(m + 1, high - m).assign(hessenBerg.viewColumn(m - 1).viewPart(m + 1, high - m));
+ for (int j = m; j <= high; j++) {
+ double g = ort.viewPart(m, high - m + 1).dot(v.viewColumn(j).viewPart(m, high - m + 1));
+ // Double division avoids possible underflow
+ g = g / ort.getQuick(m) / hessenBerg.getQuick(m, m - 1);
+ v.viewColumn(j).viewPart(m, high - m + 1).assign(ort.viewPart(m, high - m + 1), Functions.plusMult(g));
+ }
+ }
+ }
+ return hessenBerg;
+ }
+
+
+ // Complex scalar division.
+ private double cdivr;
+ private double cdivi;
+
+ private void cdiv(double xr, double xi, double yr, double yi) {
+ double r;
+ double d;
+ if (Math.abs(yr) > Math.abs(yi)) {
+ r = yi / yr;
+ d = yr + r * yi;
+ cdivr = (xr + r * xi) / d;
+ cdivi = (xi - r * xr) / d;
+ } else {
+ r = yr / yi;
+ d = yi + r * yr;
+ cdivr = (r * xr + xi) / d;
+ cdivi = (r * xi - xr) / d;
+ }
+ }
+
+
+ // Nonsymmetric reduction from Hessenberg to real Schur form.
+
+ private void hqr2(Matrix h) {
+
+ // This is derived from the Algol procedure hqr2,
+ // by Martin and Wilkinson, Handbook for Auto. Comp.,
+ // Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutine in EISPACK.
+
+ // Initialize
+
+ int nn = this.n;
+ int n = nn - 1;
+ int low = 0;
+ int high = nn - 1;
+ double eps = Math.pow(2.0, -52.0);
+ double exshift = 0.0;
+ double p = 0;
+ double q = 0;
+ double r = 0;
+ double s = 0;
+ double z = 0;
+ double w;
+ double x;
+ double y;
+
+ // Store roots isolated by balanc and compute matrix norm
+
+ double norm = h.aggregate(Functions.PLUS, Functions.ABS);
+
+ // Outer loop over eigenvalue index
+
+ int iter = 0;
+ while (n >= low) {
+
+ // Look for single small sub-diagonal element
+
+ int l = n;
+ while (l > low) {
+ s = Math.abs(h.getQuick(l - 1, l - 1)) + Math.abs(h.getQuick(l, l));
+ if (s == 0.0) {
+ s = norm;
+ }
+ if (Math.abs(h.getQuick(l, l - 1)) < eps * s) {
+ break;
+ }
+ l--;
+ }
+
+ // Check for convergence
+
+ if (l == n) {
+ // One root found
+ h.setQuick(n, n, h.getQuick(n, n) + exshift);
+ d.setQuick(n, h.getQuick(n, n));
+ e.setQuick(n, 0.0);
+ n--;
+ iter = 0;
+
+
+ } else if (l == n - 1) {
+ // Two roots found
+ w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n);
+ p = (h.getQuick(n - 1, n - 1) - h.getQuick(n, n)) / 2.0;
+ q = p * p + w;
+ z = Math.sqrt(Math.abs(q));
+ h.setQuick(n, n, h.getQuick(n, n) + exshift);
+ h.setQuick(n - 1, n - 1, h.getQuick(n - 1, n - 1) + exshift);
+ x = h.getQuick(n, n);
+
+ // Real pair
+ if (q >= 0) {
+ if (p >= 0) {
+ z = p + z;
+ } else {
+ z = p - z;
+ }
+ d.setQuick(n - 1, x + z);
+ d.setQuick(n, d.getQuick(n - 1));
+ if (z != 0.0) {
+ d.setQuick(n, x - w / z);
+ }
+ e.setQuick(n - 1, 0.0);
+ e.setQuick(n, 0.0);
+ x = h.getQuick(n, n - 1);
+ s = Math.abs(x) + Math.abs(z);
+ p = x / s;
+ q = z / s;
+ r = Math.sqrt(p * p + q * q);
+ p /= r;
+ q /= r;
+
+ // Row modification
+
+ for (int j = n - 1; j < nn; j++) {
+ z = h.getQuick(n - 1, j);
+ h.setQuick(n - 1, j, q * z + p * h.getQuick(n, j));
+ h.setQuick(n, j, q * h.getQuick(n, j) - p * z);
+ }
+
+ // Column modification
+
+ for (int i = 0; i <= n; i++) {
+ z = h.getQuick(i, n - 1);
+ h.setQuick(i, n - 1, q * z + p * h.getQuick(i, n));
+ h.setQuick(i, n, q * h.getQuick(i, n) - p * z);
+ }
+
+ // Accumulate transformations
+
+ for (int i = low; i <= high; i++) {
+ z = v.getQuick(i, n - 1);
+ v.setQuick(i, n - 1, q * z + p * v.getQuick(i, n));
+ v.setQuick(i, n, q * v.getQuick(i, n) - p * z);
+ }
+
+ // Complex pair
+
+ } else {
+ d.setQuick(n - 1, x + p);
+ d.setQuick(n, x + p);
+ e.setQuick(n - 1, z);
+ e.setQuick(n, -z);
+ }
+ n -= 2;
+ iter = 0;
+
+ // No convergence yet
+
+ } else {
+
+ // Form shift
+
+ x = h.getQuick(n, n);
+ y = 0.0;
+ w = 0.0;
+ if (l < n) {
+ y = h.getQuick(n - 1, n - 1);
+ w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n);
+ }
+
+ // Wilkinson's original ad hoc shift
+
+ if (iter == 10) {
+ exshift += x;
+ for (int i = low; i <= n; i++) {
+ h.setQuick(i, i, x);
+ }
+ s = Math.abs(h.getQuick(n, n - 1)) + Math.abs(h.getQuick(n - 1, n - 2));
+ x = y = 0.75 * s;
+ w = -0.4375 * s * s;
+ }
+
+ // MATLAB's new ad hoc shift
+
+ if (iter == 30) {
+ s = (y - x) / 2.0;
+ s = s * s + w;
+ if (s > 0) {
+ s = Math.sqrt(s);
+ if (y < x) {
+ s = -s;
+ }
+ s = x - w / ((y - x) / 2.0 + s);
+ for (int i = low; i <= n; i++) {
+ h.setQuick(i, i, h.getQuick(i, i) - s);
+ }
+ exshift += s;
+ x = y = w = 0.964;
+ }
+ }
+
+ iter++; // (Could check iteration count here.)
+
+ // Look for two consecutive small sub-diagonal elements
+
+ int m = n - 2;
+ while (m >= l) {
+ z = h.getQuick(m, m);
+ r = x - z;
+ s = y - z;
+ p = (r * s - w) / h.getQuick(m + 1, m) + h.getQuick(m, m + 1);
+ q = h.getQuick(m + 1, m + 1) - z - r - s;
+ r = h.getQuick(m + 2, m + 1);
+ s = Math.abs(p) + Math.abs(q) + Math.abs(r);
+ p /= s;
+ q /= s;
+ r /= s;
+ if (m == l) {
+ break;
+ }
+ double hmag = Math.abs(h.getQuick(m - 1, m - 1)) + Math.abs(h.getQuick(m + 1, m + 1));
+ double threshold = eps * Math.abs(p) * (Math.abs(z) + hmag);
+ if (Math.abs(h.getQuick(m, m - 1)) * (Math.abs(q) + Math.abs(r)) < threshold) {
+ break;
+ }
+ m--;
+ }
+
+ for (int i = m + 2; i <= n; i++) {
+ h.setQuick(i, i - 2, 0.0);
+ if (i > m + 2) {
+ h.setQuick(i, i - 3, 0.0);
+ }
+ }
+
+ // Double QR step involving rows l:n and columns m:n
+
+ for (int k = m; k <= n - 1; k++) {
+ boolean notlast = k != n - 1;
+ if (k != m) {
+ p = h.getQuick(k, k - 1);
+ q = h.getQuick(k + 1, k - 1);
+ r = notlast ? h.getQuick(k + 2, k - 1) : 0.0;
+ x = Math.abs(p) + Math.abs(q) + Math.abs(r);
+ if (x != 0.0) {
+ p /= x;
+ q /= x;
+ r /= x;
+ }
+ }
+ if (x == 0.0) {
+ break;
+ }
+ s = Math.sqrt(p * p + q * q + r * r);
+ if (p < 0) {
+ s = -s;
+ }
+ if (s != 0) {
+ if (k != m) {
+ h.setQuick(k, k - 1, -s * x);
+ } else if (l != m) {
+ h.setQuick(k, k - 1, -h.getQuick(k, k - 1));
+ }
+ p += s;
+ x = p / s;
+ y = q / s;
+ z = r / s;
+ q /= p;
+ r /= p;
+
+ // Row modification
+
+ for (int j = k; j < nn; j++) {
+ p = h.getQuick(k, j) + q * h.getQuick(k + 1, j);
+ if (notlast) {
+ p += r * h.getQuick(k + 2, j);
+ h.setQuick(k + 2, j, h.getQuick(k + 2, j) - p * z);
+ }
+ h.setQuick(k, j, h.getQuick(k, j) - p * x);
+ h.setQuick(k + 1, j, h.getQuick(k + 1, j) - p * y);
+ }
+
+ // Column modification
+
+ for (int i = 0; i <= Math.min(n, k + 3); i++) {
+ p = x * h.getQuick(i, k) + y * h.getQuick(i, k + 1);
+ if (notlast) {
+ p += z * h.getQuick(i, k + 2);
+ h.setQuick(i, k + 2, h.getQuick(i, k + 2) - p * r);
+ }
+ h.setQuick(i, k, h.getQuick(i, k) - p);
+ h.setQuick(i, k + 1, h.getQuick(i, k + 1) - p * q);
+ }
+
+ // Accumulate transformations
+
+ for (int i = low; i <= high; i++) {
+ p = x * v.getQuick(i, k) + y * v.getQuick(i, k + 1);
+ if (notlast) {
+ p += z * v.getQuick(i, k + 2);
+ v.setQuick(i, k + 2, v.getQuick(i, k + 2) - p * r);
+ }
+ v.setQuick(i, k, v.getQuick(i, k) - p);
+ v.setQuick(i, k + 1, v.getQuick(i, k + 1) - p * q);
+ }
+ } // (s != 0)
+ } // k loop
+ } // check convergence
+ } // while (n >= low)
+
+ // Backsubstitute to find vectors of upper triangular form
+
+ if (norm == 0.0) {
+ return;
+ }
+
+ for (n = nn - 1; n >= 0; n--) {
+ p = d.getQuick(n);
+ q = e.getQuick(n);
+
+ // Real vector
+
+ double t;
+ if (q == 0) {
+ int l = n;
+ h.setQuick(n, n, 1.0);
+ for (int i = n - 1; i >= 0; i--) {
+ w = h.getQuick(i, i) - p;
+ r = 0.0;
+ for (int j = l; j <= n; j++) {
+ r += h.getQuick(i, j) * h.getQuick(j, n);
+ }
+ if (e.getQuick(i) < 0.0) {
+ z = w;
+ s = r;
+ } else {
+ l = i;
+ if (e.getQuick(i) == 0.0) {
+ if (w == 0.0) {
+ h.setQuick(i, n, -r / (eps * norm));
+ } else {
+ h.setQuick(i, n, -r / w);
+ }
+
+ // Solve real equations
+
+ } else {
+ x = h.getQuick(i, i + 1);
+ y = h.getQuick(i + 1, i);
+ q = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i);
+ t = (x * s - z * r) / q;
+ h.setQuick(i, n, t);
+ if (Math.abs(x) > Math.abs(z)) {
+ h.setQuick(i + 1, n, (-r - w * t) / x);
+ } else {
+ h.setQuick(i + 1, n, (-s - y * t) / z);
+ }
+ }
+
+ // Overflow control
+
+ t = Math.abs(h.getQuick(i, n));
+ if (eps * t * t > 1) {
+ for (int j = i; j <= n; j++) {
+ h.setQuick(j, n, h.getQuick(j, n) / t);
+ }
+ }
+ }
+ }
+
+ // Complex vector
+
+ } else if (q < 0) {
+ int l = n - 1;
+
+ // Last vector component imaginary so matrix is triangular
+
+ if (Math.abs(h.getQuick(n, n - 1)) > Math.abs(h.getQuick(n - 1, n))) {
+ h.setQuick(n - 1, n - 1, q / h.getQuick(n, n - 1));
+ h.setQuick(n - 1, n, -(h.getQuick(n, n) - p) / h.getQuick(n, n - 1));
+ } else {
+ cdiv(0.0, -h.getQuick(n - 1, n), h.getQuick(n - 1, n - 1) - p, q);
+ h.setQuick(n - 1, n - 1, cdivr);
+ h.setQuick(n - 1, n, cdivi);
+ }
+ h.setQuick(n, n - 1, 0.0);
+ h.setQuick(n, n, 1.0);
+ for (int i = n - 2; i >= 0; i--) {
+ double ra = 0.0;
+ double sa = 0.0;
+ for (int j = l; j <= n; j++) {
+ ra += h.getQuick(i, j) * h.getQuick(j, n - 1);
+ sa += h.getQuick(i, j) * h.getQuick(j, n);
+ }
+ w = h.getQuick(i, i) - p;
+
+ if (e.getQuick(i) < 0.0) {
+ z = w;
+ r = ra;
+ s = sa;
+ } else {
+ l = i;
+ if (e.getQuick(i) == 0) {
+ cdiv(-ra, -sa, w, q);
+ h.setQuick(i, n - 1, cdivr);
+ h.setQuick(i, n, cdivi);
+ } else {
+
+ // Solve complex equations
+
+ x = h.getQuick(i, i + 1);
+ y = h.getQuick(i + 1, i);
+ double vr = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i) - q * q;
+ double vi = (d.getQuick(i) - p) * 2.0 * q;
+ if (vr == 0.0 && vi == 0.0) {
+ double hmag = Math.abs(x) + Math.abs(y);
+ vr = eps * norm * (Math.abs(w) + Math.abs(q) + hmag + Math.abs(z));
+ }
+ cdiv(x * r - z * ra + q * sa, x * s - z * sa - q * ra, vr, vi);
+ h.setQuick(i, n - 1, cdivr);
+ h.setQuick(i, n, cdivi);
+ if (Math.abs(x) > (Math.abs(z) + Math.abs(q))) {
+ h.setQuick(i + 1, n - 1, (-ra - w * h.getQuick(i, n - 1) + q * h.getQuick(i, n)) / x);
+ h.setQuick(i + 1, n, (-sa - w * h.getQuick(i, n) - q * h.getQuick(i, n - 1)) / x);
+ } else {
+ cdiv(-r - y * h.getQuick(i, n - 1), -s - y * h.getQuick(i, n), z, q);
+ h.setQuick(i + 1, n - 1, cdivr);
+ h.setQuick(i + 1, n, cdivi);
+ }
+ }
+
+ // Overflow control
+
+ t = Math.max(Math.abs(h.getQuick(i, n - 1)), Math.abs(h.getQuick(i, n)));
+ if (eps * t * t > 1) {
+ for (int j = i; j <= n; j++) {
+ h.setQuick(j, n - 1, h.getQuick(j, n - 1) / t);
+ h.setQuick(j, n, h.getQuick(j, n) / t);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Vectors of isolated roots
+
+ for (int i = 0; i < nn; i++) {
+ if (i < low || i > high) {
+ for (int j = i; j < nn; j++) {
+ v.setQuick(i, j, h.getQuick(i, j));
+ }
+ }
+ }
+
+ // Back transformation to get eigenvectors of original matrix
+
+ for (int j = nn - 1; j >= low; j--) {
+ for (int i = low; i <= high; i++) {
+ z = 0.0;
+ for (int k = low; k <= Math.min(j, high); k++) {
+ z += v.getQuick(i, k) * h.getQuick(k, j);
+ }
+ v.setQuick(i, j, z);
+ }
+ }
+ }
+
+ private static boolean isSymmetric(Matrix a) {
+ /*
+ Symmetry flag.
+ */
+ int n = a.columnSize();
+
+ boolean isSymmetric = true;
+ for (int j = 0; (j < n) && isSymmetric; j++) {
+ for (int i = 0; (i < n) && isSymmetric; i++) {
+ isSymmetric = a.getQuick(i, j) == a.getQuick(j, i);
+ }
+ }
+ return isSymmetric;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java b/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
new file mode 100644
index 0000000..7524564
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Implements the Jacobi preconditioner for a matrix A. This is defined as inv(diag(A)).
+ */
+public final class JacobiConditioner implements Preconditioner {
+
+ private final DenseVector inverseDiagonal;
+
+ public JacobiConditioner(Matrix a) {
+ if (a.numCols() != a.numRows()) {
+ throw new IllegalArgumentException("Matrix must be square.");
+ }
+
+ inverseDiagonal = new DenseVector(a.numCols());
+ for (int i = 0; i < a.numCols(); ++i) {
+ inverseDiagonal.setQuick(i, 1.0 / a.getQuick(i, i));
+ }
+ }
+
+ @Override
+ public Vector precondition(Vector v) {
+ return v.times(inverseDiagonal);
+ }
+
+}
r***@apache.org
2018-09-08 23:35:15 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SingularValueDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SingularValueDecomposition.java b/core/src/main/java/org/apache/mahout/math/SingularValueDecomposition.java
new file mode 100644
index 0000000..2abff10
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SingularValueDecomposition.java
@@ -0,0 +1,669 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright 1999 CERN - European Organization for Nuclear Research.
+ * Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+ * is hereby granted without fee, provided that the above copyright notice appear in all copies and
+ * that both that copyright notice and this permission notice appear in supporting documentation.
+ * CERN makes no representations about the suitability of this software for any purpose.
+ * It is provided "as is" without expressed or implied warranty.
+ */
+package org.apache.mahout.math;
+
+public class SingularValueDecomposition implements java.io.Serializable {
+
+ /** Arrays for internal storage of U and V. */
+ private final double[][] u;
+ private final double[][] v;
+
+ /** Array for internal storage of singular values. */
+ private final double[] s;
+
+ /** Row and column dimensions. */
+ private final int m;
+ private final int n;
+
+ /**To handle the case where numRows() < numCols() and to use the fact that SVD(A')=VSU'=> SVD(A')'=SVD(A)**/
+ private boolean transpositionNeeded;
+
+ /**
+ * Constructs and returns a new singular value decomposition object; The
+ * decomposed matrices can be retrieved via instance methods of the returned
+ * decomposition object.
+ *
+ * @param arg
+ * A rectangular matrix.
+ */
+ public SingularValueDecomposition(Matrix arg) {
+ if (arg.numRows() < arg.numCols()) {
+ transpositionNeeded = true;
+ }
+
+ // Derived from LINPACK code.
+ // Initialize.
+ double[][] a;
+ if (transpositionNeeded) {
+ //use the transpose Matrix
+ m = arg.numCols();
+ n = arg.numRows();
+ a = new double[m][n];
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < n; j++) {
+ a[i][j] = arg.get(j, i);
+ }
+ }
+ } else {
+ m = arg.numRows();
+ n = arg.numCols();
+ a = new double[m][n];
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < n; j++) {
+ a[i][j] = arg.get(i, j);
+ }
+ }
+ }
+
+
+ int nu = Math.min(m, n);
+ s = new double[Math.min(m + 1, n)];
+ u = new double[m][nu];
+ v = new double[n][n];
+ double[] e = new double[n];
+ double[] work = new double[m];
+ boolean wantu = true;
+ boolean wantv = true;
+
+ // Reduce A to bidiagonal form, storing the diagonal elements
+ // in s and the super-diagonal elements in e.
+
+ int nct = Math.min(m - 1, n);
+ int nrt = Math.max(0, Math.min(n - 2, m));
+ for (int k = 0; k < Math.max(nct, nrt); k++) {
+ if (k < nct) {
+
+ // Compute the transformation for the k-th column and
+ // place the k-th diagonal in s[k].
+ // Compute 2-norm of k-th column without under/overflow.
+ s[k] = 0;
+ for (int i = k; i < m; i++) {
+ s[k] = Algebra.hypot(s[k], a[i][k]);
+ }
+ if (s[k] != 0.0) {
+ if (a[k][k] < 0.0) {
+ s[k] = -s[k];
+ }
+ for (int i = k; i < m; i++) {
+ a[i][k] /= s[k];
+ }
+ a[k][k] += 1.0;
+ }
+ s[k] = -s[k];
+ }
+ for (int j = k + 1; j < n; j++) {
+ if (k < nct && s[k] != 0.0) {
+
+ // Apply the transformation.
+
+ double t = 0;
+ for (int i = k; i < m; i++) {
+ t += a[i][k] * a[i][j];
+ }
+ t = -t / a[k][k];
+ for (int i = k; i < m; i++) {
+ a[i][j] += t * a[i][k];
+ }
+ }
+
+ // Place the k-th row of A into e for the
+ // subsequent calculation of the row transformation.
+
+ e[j] = a[k][j];
+ }
+ if (wantu && k < nct) {
+
+ // Place the transformation in U for subsequent back
+ // multiplication.
+
+ for (int i = k; i < m; i++) {
+ u[i][k] = a[i][k];
+ }
+ }
+ if (k < nrt) {
+
+ // Compute the k-th row transformation and place the
+ // k-th super-diagonal in e[k].
+ // Compute 2-norm without under/overflow.
+ e[k] = 0;
+ for (int i = k + 1; i < n; i++) {
+ e[k] = Algebra.hypot(e[k], e[i]);
+ }
+ if (e[k] != 0.0) {
+ if (e[k + 1] < 0.0) {
+ e[k] = -e[k];
+ }
+ for (int i = k + 1; i < n; i++) {
+ e[i] /= e[k];
+ }
+ e[k + 1] += 1.0;
+ }
+ e[k] = -e[k];
+ if (k + 1 < m && e[k] != 0.0) {
+
+ // Apply the transformation.
+
+ for (int i = k + 1; i < m; i++) {
+ work[i] = 0.0;
+ }
+ for (int j = k + 1; j < n; j++) {
+ for (int i = k + 1; i < m; i++) {
+ work[i] += e[j] * a[i][j];
+ }
+ }
+ for (int j = k + 1; j < n; j++) {
+ double t = -e[j] / e[k + 1];
+ for (int i = k + 1; i < m; i++) {
+ a[i][j] += t * work[i];
+ }
+ }
+ }
+ if (wantv) {
+
+ // Place the transformation in V for subsequent
+ // back multiplication.
+
+ for (int i = k + 1; i < n; i++) {
+ v[i][k] = e[i];
+ }
+ }
+ }
+ }
+
+ // Set up the final bidiagonal matrix or order p.
+
+ int p = Math.min(n, m + 1);
+ if (nct < n) {
+ s[nct] = a[nct][nct];
+ }
+ if (m < p) {
+ s[p - 1] = 0.0;
+ }
+ if (nrt + 1 < p) {
+ e[nrt] = a[nrt][p - 1];
+ }
+ e[p - 1] = 0.0;
+
+ // If required, generate U.
+
+ if (wantu) {
+ for (int j = nct; j < nu; j++) {
+ for (int i = 0; i < m; i++) {
+ u[i][j] = 0.0;
+ }
+ u[j][j] = 1.0;
+ }
+ for (int k = nct - 1; k >= 0; k--) {
+ if (s[k] != 0.0) {
+ for (int j = k + 1; j < nu; j++) {
+ double t = 0;
+ for (int i = k; i < m; i++) {
+ t += u[i][k] * u[i][j];
+ }
+ t = -t / u[k][k];
+ for (int i = k; i < m; i++) {
+ u[i][j] += t * u[i][k];
+ }
+ }
+ for (int i = k; i < m; i++) {
+ u[i][k] = -u[i][k];
+ }
+ u[k][k] = 1.0 + u[k][k];
+ for (int i = 0; i < k - 1; i++) {
+ u[i][k] = 0.0;
+ }
+ } else {
+ for (int i = 0; i < m; i++) {
+ u[i][k] = 0.0;
+ }
+ u[k][k] = 1.0;
+ }
+ }
+ }
+
+ // If required, generate V.
+
+ if (wantv) {
+ for (int k = n - 1; k >= 0; k--) {
+ if (k < nrt && e[k] != 0.0) {
+ for (int j = k + 1; j < nu; j++) {
+ double t = 0;
+ for (int i = k + 1; i < n; i++) {
+ t += v[i][k] * v[i][j];
+ }
+ t = -t / v[k + 1][k];
+ for (int i = k + 1; i < n; i++) {
+ v[i][j] += t * v[i][k];
+ }
+ }
+ }
+ for (int i = 0; i < n; i++) {
+ v[i][k] = 0.0;
+ }
+ v[k][k] = 1.0;
+ }
+ }
+
+ // Main iteration loop for the singular values.
+
+ int pp = p - 1;
+ int iter = 0;
+ double eps = Math.pow(2.0, -52.0);
+ double tiny = Math.pow(2.0,-966.0);
+ while (p > 0) {
+ int k;
+
+ // Here is where a test for too many iterations would go.
+
+ // This section of the program inspects for
+ // negligible elements in the s and e arrays. On
+ // completion the variables kase and k are set as follows.
+
+ // kase = 1 if s(p) and e[k-1] are negligible and k<p
+ // kase = 2 if s(k) is negligible and k<p
+ // kase = 3 if e[k-1] is negligible, k<p, and
+ // s(k), ..., s(p) are not negligible (qr step).
+ // kase = 4 if e(p-1) is negligible (convergence).
+
+ for (k = p - 2; k >= -1; k--) {
+ if (k == -1) {
+ break;
+ }
+ if (Math.abs(e[k]) <= tiny +eps * (Math.abs(s[k]) + Math.abs(s[k + 1]))) {
+ e[k] = 0.0;
+ break;
+ }
+ }
+ int kase;
+ if (k == p - 2) {
+ kase = 4;
+ } else {
+ int ks;
+ for (ks = p - 1; ks >= k; ks--) {
+ if (ks == k) {
+ break;
+ }
+ double t =
+ (ks != p ? Math.abs(e[ks]) : 0.) +
+ (ks != k + 1 ? Math.abs(e[ks-1]) : 0.);
+ if (Math.abs(s[ks]) <= tiny + eps * t) {
+ s[ks] = 0.0;
+ break;
+ }
+ }
+ if (ks == k) {
+ kase = 3;
+ } else if (ks == p - 1) {
+ kase = 1;
+ } else {
+ kase = 2;
+ k = ks;
+ }
+ }
+ k++;
+
+ // Perform the task indicated by kase.
+
+ switch (kase) {
+
+ // Deflate negligible s(p).
+
+ case 1: {
+ double f = e[p - 2];
+ e[p - 2] = 0.0;
+ for (int j = p - 2; j >= k; j--) {
+ double t = Algebra.hypot(s[j], f);
+ double cs = s[j] / t;
+ double sn = f / t;
+ s[j] = t;
+ if (j != k) {
+ f = -sn * e[j - 1];
+ e[j - 1] = cs * e[j - 1];
+ }
+ if (wantv) {
+ for (int i = 0; i < n; i++) {
+ t = cs * v[i][j] + sn * v[i][p - 1];
+ v[i][p - 1] = -sn * v[i][j] + cs * v[i][p - 1];
+ v[i][j] = t;
+ }
+ }
+ }
+ }
+ break;
+
+ // Split at negligible s(k).
+
+ case 2: {
+ double f = e[k - 1];
+ e[k - 1] = 0.0;
+ for (int j = k; j < p; j++) {
+ double t = Algebra.hypot(s[j], f);
+ double cs = s[j] / t;
+ double sn = f / t;
+ s[j] = t;
+ f = -sn * e[j];
+ e[j] = cs * e[j];
+ if (wantu) {
+ for (int i = 0; i < m; i++) {
+ t = cs * u[i][j] + sn * u[i][k - 1];
+ u[i][k - 1] = -sn * u[i][j] + cs * u[i][k - 1];
+ u[i][j] = t;
+ }
+ }
+ }
+ }
+ break;
+
+ // Perform one qr step.
+
+ case 3: {
+
+ // Calculate the shift.
+
+ double scale = Math.max(Math.max(Math.max(Math.max(
+ Math.abs(s[p - 1]), Math.abs(s[p - 2])), Math.abs(e[p - 2])),
+ Math.abs(s[k])), Math.abs(e[k]));
+ double sp = s[p - 1] / scale;
+ double spm1 = s[p - 2] / scale;
+ double epm1 = e[p - 2] / scale;
+ double sk = s[k] / scale;
+ double ek = e[k] / scale;
+ double b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2.0;
+ double c = sp * epm1 * sp * epm1;
+ double shift = 0.0;
+ if (b != 0.0 || c != 0.0) {
+ shift = Math.sqrt(b * b + c);
+ if (b < 0.0) {
+ shift = -shift;
+ }
+ shift = c / (b + shift);
+ }
+ double f = (sk + sp) * (sk - sp) + shift;
+ double g = sk * ek;
+
+ // Chase zeros.
+
+ for (int j = k; j < p - 1; j++) {
+ double t = Algebra.hypot(f, g);
+ double cs = f / t;
+ double sn = g / t;
+ if (j != k) {
+ e[j - 1] = t;
+ }
+ f = cs * s[j] + sn * e[j];
+ e[j] = cs * e[j] - sn * s[j];
+ g = sn * s[j + 1];
+ s[j + 1] = cs * s[j + 1];
+ if (wantv) {
+ for (int i = 0; i < n; i++) {
+ t = cs * v[i][j] + sn * v[i][j + 1];
+ v[i][j + 1] = -sn * v[i][j] + cs * v[i][j + 1];
+ v[i][j] = t;
+ }
+ }
+ t = Algebra.hypot(f, g);
+ cs = f / t;
+ sn = g / t;
+ s[j] = t;
+ f = cs * e[j] + sn * s[j + 1];
+ s[j + 1] = -sn * e[j] + cs * s[j + 1];
+ g = sn * e[j + 1];
+ e[j + 1] = cs * e[j + 1];
+ if (wantu && j < m - 1) {
+ for (int i = 0; i < m; i++) {
+ t = cs * u[i][j] + sn * u[i][j + 1];
+ u[i][j + 1] = -sn * u[i][j] + cs * u[i][j + 1];
+ u[i][j] = t;
+ }
+ }
+ }
+ e[p - 2] = f;
+ iter = iter + 1;
+ }
+ break;
+
+ // Convergence.
+
+ case 4: {
+
+ // Make the singular values positive.
+
+ if (s[k] <= 0.0) {
+ s[k] = s[k] < 0.0 ? -s[k] : 0.0;
+ if (wantv) {
+ for (int i = 0; i <= pp; i++) {
+ v[i][k] = -v[i][k];
+ }
+ }
+ }
+
+ // Order the singular values.
+
+ while (k < pp) {
+ if (s[k] >= s[k + 1]) {
+ break;
+ }
+ double t = s[k];
+ s[k] = s[k + 1];
+ s[k + 1] = t;
+ if (wantv && k < n - 1) {
+ for (int i = 0; i < n; i++) {
+ t = v[i][k + 1];
+ v[i][k + 1] = v[i][k];
+ v[i][k] = t;
+ }
+ }
+ if (wantu && k < m - 1) {
+ for (int i = 0; i < m; i++) {
+ t = u[i][k + 1];
+ u[i][k + 1] = u[i][k];
+ u[i][k] = t;
+ }
+ }
+ k++;
+ }
+ iter = 0;
+ p--;
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ }
+
+ /**
+ * Returns the two norm condition number, which is <tt>max(S) / min(S)</tt>.
+ */
+ public double cond() {
+ return s[0] / s[Math.min(m, n) - 1];
+ }
+
+ /**
+ * @return the diagonal matrix of singular values.
+ */
+ public Matrix getS() {
+ double[][] s = new double[n][n];
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < n; j++) {
+ s[i][j] = 0.0;
+ }
+ s[i][i] = this.s[i];
+ }
+
+ return new DenseMatrix(s);
+ }
+
+ /**
+ * Returns the diagonal of <tt>S</tt>, which is a one-dimensional array of
+ * singular values
+ *
+ * @return diagonal of <tt>S</tt>.
+ */
+ public double[] getSingularValues() {
+ return s;
+ }
+
+ /**
+ * Returns the left singular vectors <tt>U</tt>.
+ *
+ * @return <tt>U</tt>
+ */
+ public Matrix getU() {
+ if (transpositionNeeded) { //case numRows() < numCols()
+ return new DenseMatrix(v);
+ } else {
+ int numCols = Math.min(m + 1, n);
+ Matrix r = new DenseMatrix(m, numCols);
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < numCols; j++) {
+ r.set(i, j, u[i][j]);
+ }
+ }
+
+ return r;
+ }
+ }
+
+ /**
+ * Returns the right singular vectors <tt>V</tt>.
+ *
+ * @return <tt>V</tt>
+ */
+ public Matrix getV() {
+ if (transpositionNeeded) { //case numRows() < numCols()
+ int numCols = Math.min(m + 1, n);
+ Matrix r = new DenseMatrix(m, numCols);
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < numCols; j++) {
+ r.set(i, j, u[i][j]);
+ }
+ }
+
+ return r;
+ } else {
+ return new DenseMatrix(v);
+ }
+ }
+
+ /** Returns the two norm, which is <tt>max(S)</tt>. */
+ public double norm2() {
+ return s[0];
+ }
+
+ /**
+ * Returns the effective numerical matrix rank, which is the number of
+ * nonnegligible singular values.
+ */
+ public int rank() {
+ double eps = Math.pow(2.0, -52.0);
+ double tol = Math.max(m, n) * s[0] * eps;
+ int r = 0;
+ for (double value : s) {
+ if (value > tol) {
+ r++;
+ }
+ }
+ return r;
+ }
+
+ /**
+ * @param minSingularValue
+ * minSingularValue - value below which singular values are ignored (a 0 or negative
+ * value implies all singular value will be used)
+ * @return Returns the n × n covariance matrix.
+ * The covariance matrix is V × J × Vt where J is the diagonal matrix of the inverse
+ * of the squares of the singular values.
+ */
+ Matrix getCovariance(double minSingularValue) {
+ Matrix j = new DenseMatrix(s.length,s.length);
+ Matrix vMat = new DenseMatrix(this.v);
+ for (int i = 0; i < s.length; i++) {
+ j.set(i, i, s[i] >= minSingularValue ? 1 / (s[i] * s[i]) : 0.0);
+ }
+ return vMat.times(j).times(vMat.transpose());
+ }
+
+ /**
+ * Returns a String with (propertyName, propertyValue) pairs. Useful for
+ * debugging or to quickly get the rough picture. For example,
+ *
+ * <pre>
+ * rank : 3
+ * trace : 0
+ * </pre>
+ */
+ @Override
+ public String toString() {
+ StringBuilder buf = new StringBuilder();
+ buf.append("---------------------------------------------------------------------\n");
+ buf.append("SingularValueDecomposition(A) --> cond(A), rank(A), norm2(A), U, S, V\n");
+ buf.append("---------------------------------------------------------------------\n");
+
+ buf.append("cond = ");
+ String unknown = "Illegal operation or error: ";
+ try {
+ buf.append(String.valueOf(this.cond()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ buf.append("\nrank = ");
+ try {
+ buf.append(String.valueOf(this.rank()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ buf.append("\nnorm2 = ");
+ try {
+ buf.append(String.valueOf(this.norm2()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ buf.append("\n\nU = ");
+ try {
+ buf.append(String.valueOf(this.getU()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ buf.append("\n\nS = ");
+ try {
+ buf.append(String.valueOf(this.getS()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ buf.append("\n\nV = ");
+ try {
+ buf.append(String.valueOf(this.getV()));
+ } catch (IllegalArgumentException exc) {
+ buf.append(unknown).append(exc.getMessage());
+ }
+
+ return buf.toString();
+ }
+}
r***@apache.org
2018-09-08 23:35:10 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/math/Polynomial.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/math/Polynomial.java b/core/src/main/java/org/apache/mahout/math/jet/math/Polynomial.java
new file mode 100644
index 0000000..723e7d0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/math/Polynomial.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.math;
+
+/**
+ * Polynomial evaluation helpers. Both methods use Horner's scheme, with coefficients stored from
+ * the highest power down to the constant term.
+ */
+public final class Polynomial {
+
+  /** Static utility holder; never instantiated. */
+  private Polynomial() {
+  }
+
+  /**
+   * Evaluates a monic polynomial of degree <tt>N</tt> at <tt>x</tt>. The leading coefficient
+   * C_N is implicitly 1.0 and is therefore not stored. Otherwise this is the same as
+   * <tt>polevl()</tt>.
+   * <pre>
+   *   y = x^N + C_(N-1) x^(N-1) + ... + C_1 x + C_0
+   *
+   *   coef[0] = C_(N-1), ..., coef[N-1] = C_0
+   * </pre>
+   * In the interest of speed, there are no bounds checks. <tt>coef[0]</tt> is always read, and
+   * <tt>coef</tt> must hold at least <tt>N</tt> entries.
+   *
+   * @param x argument to the polynomial.
+   * @param coef the coefficients of the polynomial, excluding the implicit leading 1.0.
+   * @param N the degree of the polynomial.
+   * @return the value of the polynomial at <tt>x</tt>.
+   */
+  public static double p1evl(double x, double[] coef, int N) {
+    // First Horner step: the implicit leading 1.0 times x, plus the next coefficient.
+    double sum = x + coef[0];
+    int k = 1;
+    while (k < N) {
+      sum = sum * x + coef[k];
+      k++;
+    }
+    return sum;
+  }
+
+  /**
+   * Evaluates the polynomial of degree <tt>N</tt> at <tt>x</tt>.
+   * <pre>
+   *   y = C_N x^N + ... + C_2 x^2 + C_1 x + C_0
+   *
+   *   coef[0] = C_N, ..., coef[N] = C_0
+   * </pre>
+   * In the interest of speed, there are no bounds checks. <tt>coef</tt> must hold at least
+   * <tt>N + 1</tt> entries.
+   *
+   * @param x argument to the polynomial.
+   * @param coef the coefficients of the polynomial, highest power first.
+   * @param N the degree of the polynomial.
+   * @return the value of the polynomial at <tt>x</tt>.
+   */
+  public static double polevl(double x, double[] coef, int N) {
+    double sum = coef[0];
+    int k = 1;
+    while (k <= N) {
+      sum = sum * x + coef[k];
+      k++;
+    }
+    return sum;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/math/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/math/package-info.java b/core/src/main/java/org/apache/mahout/math/jet/math/package-info.java
new file mode 100644
index 0000000..3cda850
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/math/package-info.java
@@ -0,0 +1,5 @@
+/**
+ * Tools for basic and advanced mathematics: Arithmetic and Algebra, Polynomials and Chebyshev series, Bessel and Airy
+ * functions, Function Objects for generic function evaluation, etc.
+ */
+package org.apache.mahout.math.jet.math;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
new file mode 100644
index 0000000..8ca03d0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+/**
+ * Abstract base class for all continuous distributions. Continuous distributions have a
+ * probability density function and a cumulative distribution function. Subclasses that can
+ * compute them override {@link #pdf(double)} and {@link #cdf(double)}.
+ */
+public abstract class AbstractContinousDistribution extends AbstractDistribution {
+
+  /**
+   * Returns the cumulative distribution function evaluated at <tt>x</tt>. This default
+   * implementation always throws; subclasses override it when they can compute the CDF.
+   *
+   * @param x the point at which to evaluate the CDF.
+   * @return the CDF at <tt>x</tt>.
+   * @throws UnsupportedOperationException unless overridden.
+   */
+  public double cdf(double x) {
+    throw new UnsupportedOperationException("Can't compute cdf for " + this.getClass().getName());
+  }
+
+  /**
+   * Returns the probability density function evaluated at <tt>x</tt>. This default
+   * implementation always throws; subclasses override it when they can compute the density.
+   *
+   * @param x the point at which to evaluate the density.
+   * @return the density at <tt>x</tt>.
+   * @throws UnsupportedOperationException unless overridden.
+   */
+  public double pdf(double x) {
+    throw new UnsupportedOperationException("Can't compute pdf for " + this.getClass().getName());
+  }
+
+  /**
+   * Returns a random integer obtained by rounding a continuous sample. Override this method if
+   * necessary.
+   *
+   * @return <tt>(int) Math.round(nextDouble())</tt>.
+   */
+  @Override
+  public int nextInt() {
+    return (int) Math.round(nextDouble());
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDiscreteDistribution.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDiscreteDistribution.java b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDiscreteDistribution.java
new file mode 100644
index 0000000..d93d76c
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDiscreteDistribution.java
@@ -0,0 +1,27 @@
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+/**
+ * Abstract base class for all discrete distributions.
+ *
+ */
+public abstract class AbstractDiscreteDistribution extends AbstractDistribution {
+
+ /** Makes this class non instantiable, but still let's others inherit from it. */
+ protected AbstractDiscreteDistribution() {
+ }
+
+ /** Returns a random number from the distribution; returns <tt>(double) nextInt()</tt>. */
+ @Override
+ public double nextDouble() {
+ return nextInt();
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
new file mode 100644
index 0000000..8e9cb0e
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import java.util.Random;
+
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.IntFunction;
+
+/**
+ * Base class for random distributions. A distribution can also be used as a function object.
+ * Applying it as a <tt>DoubleFunction</tt> or an <tt>IntFunction</tt> ignores the argument and
+ * returns a fresh sample.
+ */
+public abstract class AbstractDistribution extends DoubleFunction implements IntFunction {
+
+  /** Uniform generator read by {@link #randomDouble()}; <tt>null</tt> until one is set. */
+  private Random randomGenerator;
+
+  /** Prevents direct instantiation while still letting subclasses inherit from this class. */
+  protected AbstractDistribution() {
+  }
+
+  /**
+   * Sets the uniform random generator used internally.
+   *
+   * @param randomGenerator the new PRNG.
+   */
+  public void setRandomGenerator(Random randomGenerator) {
+    this.randomGenerator = randomGenerator;
+  }
+
+  /**
+   * @return the uniform random generator currently in use (possibly <tt>null</tt>).
+   */
+  protected Random getRandomGenerator() {
+    return this.randomGenerator;
+  }
+
+  /**
+   * Draws the next uniform value from the underlying generator.
+   *
+   * @return the result of the generator's <tt>nextDouble()</tt>.
+   * @throws NullPointerException if no generator has been set.
+   */
+  protected double randomDouble() {
+    return this.randomGenerator.nextDouble();
+  }
+
+  /**
+   * Returns a random number from the distribution.
+   *
+   * @return a new sample from this distribution.
+   */
+  public abstract double nextDouble();
+
+  /**
+   * Returns a random integer from the distribution. How the integer is derived is up to the
+   * implementing subclass.
+   *
+   * @return a new integer sample from this distribution.
+   */
+  public abstract int nextInt();
+
+  /**
+   * Function-object view: ignores the argument and returns <tt>nextDouble()</tt>.
+   */
+  @Override
+  public double apply(double dummy) {
+    return nextDouble();
+  }
+
+  /**
+   * Function-object view: ignores the argument and returns <tt>nextInt()</tt>.
+   */
+  @Override
+  public int apply(int dummy) {
+    return nextInt();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/Exponential.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/Exponential.java b/core/src/main/java/org/apache/mahout/math/jet/random/Exponential.java
new file mode 100644
index 0000000..06472c2
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/Exponential.java
@@ -0,0 +1,81 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import java.util.Locale;
+import java.util.Random;
+
+public class Exponential extends AbstractContinousDistribution {
+ // rate parameter for the distribution. Mean is 1/lambda.
+ private double lambda;
+
+ /**
+ * Provides a negative exponential distribution given a rate parameter lambda and an underlying
+ * random number generator. The mean of this distribution will be equal to 1/lambda.
+ *
+ * @param lambda The rate parameter of the distribution.
+ * @param randomGenerator The PRNG that is used to generate values.
+ */
+ public Exponential(double lambda, Random randomGenerator) {
+ setRandomGenerator(randomGenerator);
+ this.lambda = lambda;
+ }
+
+ /**
+ * Returns the cumulative distribution function.
+ * @param x The point at which the cumulative distribution function is to be evaluated.
+ * @return Returns the integral from -infinity to x of the PDF, also known as the cumulative distribution
+ * function.
+ */
+ @Override
+ public double cdf(double x) {
+ if (x <= 0.0) {
+ return 0.0;
+ }
+ return 1.0 - Math.exp(-x * lambda);
+ }
+
+ /**
+ * Returns a random number from the distribution.
+ */
+ @Override
+ public double nextDouble() {
+ return -Math.log1p(-randomDouble()) / lambda;
+ }
+
+ /**
+ * Returns the value of the probability density function at a particular point.
+ * @param x The point at which the probability density function is to be evaluated.
+ * @return The value of the probability density function at the specified point.
+ */
+ @Override
+ public double pdf(double x) {
+ if (x < 0.0) {
+ return 0.0;
+ }
+ return lambda * Math.exp(-x * lambda);
+ }
+
+ /**
+ * Sets the rate parameter.
+ * @param lambda The new value of the rate parameter.
+ */
+ public void setState(double lambda) {
+ this.lambda = lambda;
+ }
+
+ /**
+ * Returns a String representation of the receiver.
+ */
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "%s(%.4f)", this.getClass().getName(), lambda);
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/Gamma.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/Gamma.java b/core/src/main/java/org/apache/mahout/math/jet/random/Gamma.java
new file mode 100644
index 0000000..914157b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/Gamma.java
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import org.apache.mahout.math.jet.stat.Probability;
+
+import java.util.Random;
+
+/**
+ * Gamma distribution parameterized by a shape (alpha) and a rate (= 1 / scale).
+ */
+public class Gamma extends AbstractContinousDistribution {
+  /** Shape parameter. */
+  private final double alpha;
+
+  /** Rate parameter (= 1 / scale). */
+  private final double rate;
+
+  /**
+   * Constructs a Gamma distribution with a given shape (alpha) and rate (beta).
+   * The parameters are not validated here. Sampling through {@link #nextDouble()} rejects
+   * non-positive values with an <tt>IllegalArgumentException</tt>.
+   *
+   * @param alpha The shape parameter; should be positive.
+   * @param rate The rate parameter (= 1 / scale); should be positive.
+   * @param randomGenerator The random number generator that generates bits for us.
+   */
+  public Gamma(double alpha, double rate, Random randomGenerator) {
+    this.alpha = alpha;
+    this.rate = rate;
+    setRandomGenerator(randomGenerator);
+  }
+
+  /**
+   * Returns the cumulative distribution function.
+   * @param x The end-point where the cumulation should end.
+   */
+  @Override
+  public double cdf(double x) {
+    return Probability.gamma(alpha, rate, x);
+  }
+
+  /** Returns a random number from the distribution. */
+  @Override
+  public double nextDouble() {
+    return nextDouble(alpha, rate);
+  }
+
+  /**
+   * Returns a random number from a Gamma distribution with the given parameters, bypassing the
+   * parameters stored in this instance.
+   * <p>
+   * For <tt>alpha &lt; 1</tt> this uses acceptance rejection (algorithm GS). For
+   * <tt>alpha &gt;= 1</tt> it uses acceptance complement (algorithm GD). A standard gamma deviate
+   * is drawn and then divided by <tt>rate</tt>.
+   * <ul>
+   * <li>J.H. Ahrens, U. Dieter (1974): Computer methods for sampling from gamma, beta, Poisson and
+   * binomial distributions, Computing 12, 223-246.</li>
+   * <li>J.H. Ahrens, U. Dieter (1982): Generating gamma variates by a modified rejection technique,
+   * Communications of the ACM 25, 47-54.</li>
+   * </ul>
+   *
+   * @param alpha Shape parameter; must be positive.
+   * @param rate Rate parameter (=1/scale); must be positive.
+   * @return A gamma distributed sample.
+   * @throws IllegalArgumentException if <tt>alpha &lt;= 0.0 || rate &lt;= 0.0</tt>.
+   */
+  public double nextDouble(double alpha, double rate) {
+    if (alpha <= 0.0) {
+      throw new IllegalArgumentException("alpha must be positive: " + alpha);
+    }
+    if (rate <= 0.0) {
+      throw new IllegalArgumentException("rate must be positive: " + rate);
+    }
+
+    if (alpha < 1.0) { // CASE A: Acceptance rejection algorithm gs
+      double b = 1.0 + 0.36788794412 * alpha; // Step 1 (0.36788794412 ~= 1/e)
+      while (true) {
+        double p = b * randomDouble();
+        if (p <= 1.0) { // Step 2. Case gds <= 1
+          double gds = Math.exp(Math.log(p) / alpha);
+          if (Math.log(randomDouble()) <= -gds) {
+            return gds / rate;
+          }
+        } else { // Step 3. Case gds > 1
+          double gds = -Math.log((b - p) / alpha);
+          if (Math.log(randomDouble()) <= (alpha - 1.0) * Math.log(gds)) {
+            return gds / rate;
+          }
+        }
+      }
+    }
+
+    // CASE B: Acceptance complement algorithm gd (alpha >= 1).
+    // Step 1. Preparations
+    double ss = alpha - 0.5;
+    double s = Math.sqrt(ss);
+    double d = 5.656854249 - 12.0 * s; // 5.656854249 ~= sqrt(32)
+
+    // Step 2. Normal deviate t via the polar (Box-Muller) method
+    double v12;
+    double v1;
+    do {
+      v1 = 2.0 * randomDouble() - 1.0;
+      double v2 = 2.0 * randomDouble() - 1.0;
+      v12 = v1 * v1 + v2 * v2;
+    } while (v12 > 1.0);
+    double t = v1 * Math.sqrt(-2.0 * Math.log(v12) / v12);
+    double x = s + 0.5 * t;
+    double gds = x * x;
+    if (t >= 0.0) {
+      return gds / rate; // Immediate acceptance
+    }
+
+    double u = randomDouble();
+    if (d * u <= t * t * t) {
+      return gds / rate; // Squeeze acceptance
+    }
+
+    // Step 4. Set-up for hat case
+    double r = 1.0 / alpha;
+    double q9 = 0.0001710320;
+    double q8 = -0.0004701849;
+    double q7 = 0.0006053049;
+    double q6 = 0.0003340332;
+    double q5 = -0.0003349403;
+    double q4 = 0.0015746717;
+    double q3 = 0.0079849875;
+    double q2 = 0.0208333723;
+    double q1 = 0.0416666664;
+    double q0 = ((((((((q9 * r + q8) * r + q7) * r + q6) * r + q5) * r + q4) * r + q3) * r + q2) * r + q1) * r;
+    double b;
+    double si;
+    double c;
+    if (alpha > 3.686) {
+      if (alpha > 13.022) {
+        b = 1.77;
+        si = 0.75;
+        c = 0.1515 / s;
+      } else {
+        b = 1.654 + 0.0076 * ss;
+        si = 1.68 / s + 0.275;
+        c = 0.062 / s + 0.024;
+      }
+    } else {
+      b = 0.463 + s - 0.178 * ss;
+      si = 1.235;
+      c = 0.195 / s - 0.079 + 0.016 * s;
+    }
+
+    if (x > 0.0) { // Step 5. Calculation of q
+      double q = gdQ(t, s, ss, q0); // Step 6.
+      if (Math.log1p(-u) <= q) { // Step 7. Quotient acceptance
+        return gds / rate;
+      }
+    }
+
+    double e7 = 0.000247453;
+    double e6 = 0.001353826;
+    double e5 = 0.008345522;
+    double e4 = 0.041664508;
+    double e3 = 0.166666848;
+    double e2 = 0.499999994;
+    double e1 = 1.000000000;
+    while (true) { // Step 8. Double exponential deviate t
+      double signU;
+      double e;
+      do {
+        e = -Math.log(randomDouble());
+        u = randomDouble();
+        u = u + u - 1.0;
+        signU = u > 0 ? 1.0 : -1.0;
+        t = b + e * si * signU;
+      } while (t <= -0.71874483771719); // Step 9. Rejection of t
+      double q = gdQ(t, s, ss, q0); // Step 10. New q(t)
+      if (q <= 0.0) {
+        continue;
+      }
+      // Step 11.
+      double w;
+      if (q > 0.5) {
+        w = Math.exp(q) - 1.0;
+      } else {
+        w = ((((((e7 * q + e6) * q + e5) * q + e4) * q + e3) * q + e2) * q + e1) * q;
+      }
+      // Step 12. Hat acceptance (u * signU == |u|)
+      if (c * u * signU <= w * Math.exp(e - 0.5 * t * t)) {
+        x = s + 0.5 * t;
+        return x * x / rate;
+      }
+    }
+  }
+
+  /**
+   * Computes q(t) for algorithm GD (steps 6 and 10). It evaluates the closed form when
+   * <tt>|v| &gt; 0.25</tt> (with <tt>v = t / (2 s)</tt>) and a polynomial approximation otherwise.
+   */
+  private static double gdQ(double t, double s, double ss, double q0) {
+    double v = t / (s + s);
+    if (Math.abs(v) > 0.25) {
+      return q0 - s * t + 0.25 * t * t + (ss + ss) * Math.log1p(v);
+    }
+    double a9 = 0.104089866;
+    double a8 = -0.112750886;
+    double a7 = 0.110368310;
+    double a6 = -0.124385581;
+    double a5 = 0.142873973;
+    double a4 = -0.166677482;
+    double a3 = 0.199999867;
+    double a2 = -0.249999949;
+    double a1 = 0.333333333;
+    return q0 + 0.5 * t * t * ((((((((a9 * v + a8) * v + a7) * v + a6)
+        * v + a5) * v + a4) * v + a3) * v + a2) * v + a1) * v;
+  }
+
+  /** Returns the probability distribution function.
+   * @param x Where to compute the density function.
+   *
+   * @return The value of the gamma density at x.
+   * @throws IllegalArgumentException if <tt>x &lt; 0</tt>.
+   */
+  @Override
+  public double pdf(double x) {
+    if (x < 0) {
+      throw new IllegalArgumentException("x must be non-negative: " + x);
+    }
+    if (x == 0) {
+      if (alpha == 1.0) {
+        return rate;
+      } else if (alpha < 1) {
+        return Double.POSITIVE_INFINITY;
+      } else {
+        return 0;
+      }
+    }
+    if (alpha == 1.0) {
+      return rate * Math.exp(-x * rate);
+    }
+    // rate * (rate x)^(alpha - 1) * exp(-rate x) / Gamma(alpha), evaluated in log space.
+    return rate * Math.exp((alpha - 1.0) * Math.log(x * rate) - x * rate - logGamma(alpha));
+  }
+
+  /** Returns the class name followed by <tt>(rate,alpha)</tt>; note the rate is listed first. */
+  @Override
+  public String toString() {
+    return this.getClass().getName() + '(' + rate + ',' + alpha + ')';
+  }
+
+  /**
+   * Returns a quick approximation of <tt>log(gamma(x))</tt>.
+   *
+   * @param x the argument; must be positive.
+   * @return the approximation, or the sentinel <tt>-999</tt> for <tt>x &lt;= 0</tt>.
+   */
+  public static double logGamma(double x) {
+
+    if (x <= 0.0 /* || x > 1.3e19 */) {
+      return -999;
+    }
+
+    // Shift x up to at least 11 via Gamma(x) = Gamma(x + n) / (x (x + 1) ... (x + n - 1)), so the
+    // asymptotic series below is accurate; z accumulates the product that is divided out at the end.
+    double z = 1.0;
+    while (x < 11.0) {
+      z *= x;
+      x++;
+    }
+
+    // Stirling series: ln Gamma(x) ~ (x - 1/2) ln x - x + ln(2 pi) / 2
+    //     + sum_k B_2k / (2k (2k - 1) x^(2k - 1)), with r = 1 / x^2.
+    double r = 1.0 / (x * x);
+    double c6 = -1.9175269175269175e-03; // -691/360360
+    double c5 = 8.4175084175084175e-04; // 1/1188
+    double c4 = -5.9523809523809524e-04; // -1/1680
+    double c3 = 7.9365079365079365e-04; // 1/1260
+    double c2 = -2.7777777777777777e-03; // -1/360
+    double c1 = 8.3333333333333333e-02; // 1/12
+    double g = c1 + r * (c2 + r * (c3 + r * (c4 + r * (c5 + r * c6))));
+    double c0 = 9.1893853320467274e-01; // ln(2 pi) / 2
+    g = (x - 0.5) * Math.log(x) - x + c0 + g / x;
+    if (z == 1.0) {
+      return g;
+    }
+    return g - Math.log(z);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java b/core/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java
new file mode 100644
index 0000000..1e577eb
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import org.apache.mahout.math.jet.math.Arithmetic;
+import org.apache.mahout.math.jet.stat.Probability;
+
+import java.util.Random;
+
+/** Mostly deprecated until unit tests are in place. Until this time, this class/interface is unsupported. */
+public final class NegativeBinomial extends AbstractDiscreteDistribution {
+
+  /** Required number of positive trials. */
+  private final int r;
+  /** Probability of success of each trial. */
+  private final double p;
+
+  // Samples are drawn as a Gamma-Poisson mixture (see nextInt(int, double)); these two samplers
+  // are what actually consume randomness.
+  private final Gamma gamma;
+  private final Poisson poisson;
+
+  /**
+   * Constructs a Negative Binomial distribution which describes the probability of getting
+   * a particular number of negative trials (k) before getting a fixed number of positive
+   * trials (r) where each positive trial has probability (p) of being successful.
+   *
+   * @param r the required number of positive trials.
+   * @param p the probability of success.
+   * @param randomGenerator a uniform random number generator.
+   */
+  public NegativeBinomial(int r, double p, Random randomGenerator) {
+    setRandomGenerator(randomGenerator);
+    this.r = r;
+    this.p = p;
+    this.gamma = new Gamma(r, 1, randomGenerator);
+    this.poisson = new Poisson(0.0, randomGenerator);
+  }
+
+  /**
+   * Sets the uniform random generator. The generator is also handed to the internal Gamma and
+   * Poisson samplers, so later draws actually use it.
+   *
+   * @param randomGenerator the new PRNG.
+   */
+  @Override
+  public void setRandomGenerator(Random randomGenerator) {
+    super.setRandomGenerator(randomGenerator);
+    // The constructor calls this before the samplers are created, so they may still be null.
+    if (gamma != null) {
+      gamma.setRandomGenerator(randomGenerator);
+    }
+    if (poisson != null) {
+      poisson.setRandomGenerator(randomGenerator);
+    }
+  }
+
+  /**
+   * Returns the cumulative distribution function evaluated at <tt>k</tt>.
+   */
+  public double cdf(int k) {
+    return Probability.negativeBinomial(k, r, p);
+  }
+
+  /**
+   * Returns the probability of exactly <tt>k</tt> negative trials:
+   * <tt>C(k + r - 1, r - 1) * p^r * (1 - p)^k</tt>.
+   */
+  public double pdf(int k) {
+    return Arithmetic.binomial(k + r - 1, r - 1) * Math.pow(p, r) * Math.pow(1.0 - p, k);
+  }
+
+  @Override
+  public int nextInt() {
+    return nextInt(r, p);
+  }
+
+  /**
+   * Returns a sample from this distribution. The value returned will
+   * be the number of negative samples required before achieving r
+   * positive samples. Each successive sample is taken independently
+   * from a Bernoulli process with probability p of success.
+   *
+   * The algorithm used is taken from J.H. Ahrens, U. Dieter (1974):
+   * Computer methods for sampling from gamma, beta, Poisson and
+   * binomial distributions, Computing 12, 223--246.
+   *
+   * This algorithm is essentially the same as described at
+   * http://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma.E2.80.93Poisson_mixture
+   * except that the notion of positive and negative outcomes is uniformly
+   * inverted. Because the inversion is complete and consistent, this
+   * definition is effectively identical to that defined on wikipedia.
+   */
+  public int nextInt(int r, double p) {
+    // lambda ~ Gamma(shape r, rate p / (1 - p)), then k ~ Poisson(lambda).
+    return this.poisson.nextInt(gamma.nextDouble(r, p / (1.0 - p)));
+  }
+
+  /**
+   * Returns a String representation of the receiver.
+   */
+  @Override
+  public String toString() {
+    return this.getClass().getName() + '(' + r + ',' + p + ')';
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/Normal.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/Normal.java b/core/src/main/java/org/apache/mahout/math/jet/random/Normal.java
new file mode 100644
index 0000000..7ceac22
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/Normal.java
@@ -0,0 +1,110 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import org.apache.mahout.math.jet.stat.Probability;
+
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Implements a normal distribution with a specified mean and standard deviation.
+ */
+public class Normal extends AbstractContinousDistribution {
+
+  private double mean;
+  private double variance;
+  private double standardDeviation;
+
+  // The polar Box-Muller method yields two deviates per step; the second is kept here.
+  private double cache;
+  private boolean cacheFilled;
+
+  // 1 / sqrt(2 pi variance), the density's normalizing constant, cached for performance.
+  private double normalizer;
+
+  /**
+   * @param mean The mean of the resulting distribution.
+   * @param standardDeviation The standard deviation of the distribution.
+   * @param randomGenerator The random number generator to use. This can be null if you don't
+   * need to generate any numbers.
+   */
+  public Normal(double mean, double standardDeviation, Random randomGenerator) {
+    setRandomGenerator(randomGenerator);
+    setState(mean, standardDeviation);
+  }
+
+  /**
+   * Returns the cumulative distribution function.
+   */
+  @Override
+  public double cdf(double x) {
+    return Probability.normal(mean, variance, x);
+  }
+
+  /** Returns the probability density function. */
+  @Override
+  public double pdf(double x) {
+    double diff = x - mean;
+    return normalizer * Math.exp(-(diff * diff) / (2.0 * variance));
+  }
+
+  /**
+   * Returns a random number from the distribution.
+   */
+  @Override
+  public double nextDouble() {
+    // Uses the polar Box-Muller transformation, which yields two deviates at a time.
+    if (cacheFilled) {
+      cacheFilled = false;
+      return cache;
+    }
+
+    double x;
+    double y;
+    double r;
+    do {
+      x = 2.0 * randomDouble() - 1.0;
+      y = 2.0 * randomDouble() - 1.0;
+      r = x * x + y * y;
+      // Reject points outside the unit disc and the origin (log(0) / 0 would produce NaN).
+    } while (r >= 1.0 || r == 0.0);
+
+    double z = Math.sqrt(-2.0 * Math.log(r) / r);
+    cache = this.mean + this.standardDeviation * x * z;
+    cacheFilled = true;
+    return this.mean + this.standardDeviation * y * z;
+  }
+
+  /** Sets the uniform random generator internally used and discards any cached deviate. */
+  @Override
+  public final void setRandomGenerator(Random randomGenerator) {
+    super.setRandomGenerator(randomGenerator);
+    this.cacheFilled = false;
+  }
+
+  /**
+   * Sets the mean and standard deviation. The derived values are recomputed, and any cached
+   * deviate is discarded, only when either value changes.
+   * @param mean The new value for the mean.
+   * @param standardDeviation The new value for the standard deviation.
+   */
+  public final void setState(double mean, double standardDeviation) {
+    if (mean != this.mean || standardDeviation != this.standardDeviation) {
+      this.mean = mean;
+      this.standardDeviation = standardDeviation;
+      this.variance = standardDeviation * standardDeviation;
+      this.cacheFilled = false;
+
+      this.normalizer = 1.0 / Math.sqrt(2.0 * Math.PI * variance);
+    }
+  }
+
+  /** Returns a String representation of the receiver. */
+  @Override
+  public String toString() {
+    return String.format(Locale.ENGLISH, "%s(m=%f, sd=%f)", this.getClass().getName(), mean, standardDeviation);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/Poisson.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/Poisson.java b/core/src/main/java/org/apache/mahout/math/jet/random/Poisson.java
new file mode 100644
index 0000000..497691e
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/Poisson.java
@@ -0,0 +1,296 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import org.apache.mahout.math.jet.math.Arithmetic;
+
+import java.util.Random;
+
+/**
+ * Poisson distribution with a mean fixed at construction.
+ * <p>
+ * Sampling uses table-based inversion for means below {@code SWITCH_MEAN} and patchwork rejection above it;
+ * see {@link #nextInt(double)}.
+ * <p>
+ * Partially deprecated until unit tests are in place. Until this time, this class/interface is unsupported.
+ * <p>
+ * NOTE(review): sampling mutates the cached set-up fields below without synchronization, so an instance
+ * presumably must not be shared between threads without external locking.
+ */
+public final class Poisson extends AbstractDiscreteDistribution {
+
+ // Mean used by nextInt(); fixed at construction.
+ private final double mean;
+
+ // precomputed and cached values (for performance only)
+ // Cache for the inversion branch (mean < SWITCH_MEAN); valid while the requested mean equals myOld.
+ private double myOld = -1.0;
+ // p: probability of the last tabulated k; q: cumulative probability up to that k; p0: P(X = 0).
+ private double p;
+ private double q;
+ private double p0;
+ // pp[k] holds the cumulative probability P(X <= k) for 1 <= k <= llll; pp[0] is never written.
+ private final double[] pp = new double[36];
+ // Highest k whose cumulative probability is stored in pp (0 = table empty, at most 35).
+ private int llll;
+
+ // Cache for the patchwork-rejection branch (mean >= SWITCH_MEAN); valid while the requested mean equals myLast.
+ // The meaning of each value is documented where it is computed in nextInt(double).
+ private double myLast = -1.0;
+ private double ll;
+ private int k2;
+ private int k4;
+ private int k1;
+ private int k5;
+ private double dl;
+ private double dr;
+ private double r1;
+ private double r2;
+ private double r4;
+ private double r5;
+ private double lr;
+ private double lMy;
+ private double cPm;
+ private double f1;
+ private double f2;
+ private double f4;
+ private double f5;
+ private double p1;
+ private double p2;
+ private double p3;
+ private double p4;
+ private double p5;
+ private double p6;
+
+
+
+ // Means at or above this bound are not sampled: nextInt(double) returns (int) theMean instead.
+ private static final double MEAN_MAX = Integer.MAX_VALUE;
+ private static final double SWITCH_MEAN = 10.0; // below: case B (table inversion); at or above: case A (patchwork rejection)
+
+
+ /**
+ * Constructs a poisson distribution. Example: mean=1.0.
+ *
+ * @param mean the mean used by {@link #nextInt()}
+ * @param randomGenerator uniform generator handed to {@code setRandomGenerator} and later read through
+ * {@code getRandomGenerator()}
+ */
+ public Poisson(double mean, Random randomGenerator) {
+ setRandomGenerator(randomGenerator);
+ this.mean = mean;
+ }
+
+ /**
+ * Returns exp(k * lNu - log(k!) - cPm). With lNu = log(mean) and cPm = m * log(mean) - log(m!), as set up in
+ * nextInt(double), this is the ratio f(k) = p(k) / p(m) of Poisson probabilities.
+ */
+ private static double f(int k, double lNu, double cPm) {
+ return Math.exp(k * lNu - Arithmetic.logFactorial(k) - cPm);
+ }
+
+ /** Returns a random number from the distribution, using the mean given at construction. */
+ @Override
+ public int nextInt() {
+ return nextInt(mean);
+ }
+
+ /**
+ * Returns a Poisson-distributed random number with mean {@code theMean}, ignoring the mean given at
+ * construction. The set-up values cached in this instance are reused while consecutive calls pass the same
+ * mean and are recomputed otherwise.
+ * <ul>
+ * <li>{@code theMean < SWITCH_MEAN}: table-based inversion; results never exceed 35.</li>
+ * <li>{@code SWITCH_MEAN <= theMean < MEAN_MAX}: patchwork rejection.</li>
+ * <li>otherwise: no sampling, returns {@code (int) theMean}. NaN fails both comparisons and ends up here,
+ * giving 0.</li>
+ * </ul>
+ * NOTE(review): a negative {@code theMean} makes p0 = exp(-theMean) exceed 1, so the inversion branch always
+ * returns 0; callers presumably pass theMean >= 0 -- confirm.
+ */
+ public int nextInt(double theMean) {
+ /******************************************************************
+ * *
+ * Poisson Distribution - Patchwork Rejection/Inversion *
+ * *
+ ******************************************************************
+ * *
+ * For parameter my < 10 Tabulated Inversion is applied. *
+ * For my >= 10 Patchwork Rejection is employed: *
+ * The area below the histogram function f(x) is rearranged in *
+ * its body by certain point reflections. Within a large center *
+ * interval variates are sampled efficiently by rejection from *
+ * uniform hats. Rectangular immediate acceptance regions speed *
+ * up the generation. The remaining tails are covered by *
+ * exponential functions. *
+ * *
+ *****************************************************************/
+ Random gen = getRandomGenerator();
+
+ int m;
+ if (theMean < SWITCH_MEAN) { // CASE B: Inversion- start new table and calculate p0
+ if (theMean != myOld) {
+ myOld = theMean;
+ llll = 0;
+ p = Math.exp(-theMean);
+ q = p;
+ p0 = p;
+ // pp need not be cleared: entries above llll are rewritten before they are read.
+ }
+ m = theMean > 1.0 ? (int) theMean : 1;
+ while (true) {
+ double u = gen.nextDouble();
+ int k = 0;
+ if (u <= p0) {
+ return k;
+ }
+ if (llll != 0) { // Step T. Table comparison
+ // Shortcut from the original algorithm: for u > 0.458 the scan starts at min(llll, m) instead of 1.
+ int i = u > 0.458 ? Math.min(llll, m) : 1;
+ for (k = i; k <= llll; k++) {
+ if (u <= pp[k]) {
+ return k;
+ }
+ }
+ // Table complete and u above pp[35]: draw again (values above 35 are never returned here).
+ if (llll == 35) {
+ continue;
+ }
+ }
+ for (k = llll + 1; k <= 35; k++) { // Step C. Creation of new prob.
+ p *= theMean / k;
+ q += p;
+ pp[k] = q;
+ if (u <= q) {
+ llll = k;
+ return k;
+ }
+ }
+ llll = 35;
+ }
+ // end my < SWITCH_MEAN
+ } else if (theMean < MEAN_MAX) { // CASE A: patchwork rejection (see header comment)
+
+ m = (int) theMean;
+ if (theMean != myLast) { // set-up
+ myLast = theMean;
+
+ // approximate deviation of reflection points k2, k4 from my - 1/2
+ double Ds = Math.sqrt(theMean + 0.25);
+
+ // mode m, reflection points k2 and k4, and points k1 and k5, which
+ // delimit the centre region of h(x)
+ k2 = (int) Math.ceil(theMean - 0.5 - Ds);
+ k4 = (int) (theMean - 0.5 + Ds);
+ k1 = k2 + k2 - m + 1;
+ k5 = k4 + k4 - m;
+
+ // range width of the critical left and right centre region
+ dl = k2 - k1;
+ dr = k5 - k4;
+
+ // recurrence constants r(k) = p(k)/p(k-1) at k = k1, k2, k4+1, k5+1
+ r1 = theMean / k1;
+ r2 = theMean / k2;
+ r4 = theMean / (k4 + 1);
+ r5 = theMean / (k5 + 1);
+
+ // reciprocal values of the scale parameters of expon. tail envelopes
+ ll = Math.log(r1); // expon. tail left
+ lr = -Math.log(r5); // expon. tail right
+
+ // Poisson constants, necessary for computing function values f(k)
+ lMy = Math.log(theMean);
+ cPm = m * lMy - Arithmetic.logFactorial(m);
+
+ // function values f(k) = p(k)/p(m) at k = k2, k4, k1, k5
+ f2 = f(k2, lMy, cPm);
+ f4 = f(k4, lMy, cPm);
+ f1 = f(k1, lMy, cPm);
+ f5 = f(k5, lMy, cPm);
+
+ // area of the two centre and the two exponential tail regions
+ // area of the two immediate acceptance regions between k2, k4
+ p1 = f2 * (dl + 1.0); // immed. left
+ p2 = f2 * dl + p1; // centre left
+ p3 = f4 * (dr + 1.0) + p2; // immed. right
+ p4 = f4 * dr + p3; // centre right
+ p5 = f1 / ll + p4; // expon. tail left
+ p6 = f5 / lr + p5; // expon. tail right
+ } // end set-up
+
+ while (true) {
+ // generate uniform number U -- U(0, p6)
+ // case distinction corresponding to U
+ double W;
+ double V;
+ double U;
+ int Y;
+ int X;
+ int Dk;
+ if ((U = gen.nextDouble() * p6) < p2) { // centre left
+
+ // immediate acceptance region R2 = [k2, m) *[0, f2), X = k2, ... m -1
+ if ((V = U - p1) < 0.0) {
+ return k2 + (int) (U / f2);
+ }
+ // immediate acceptance region R1 = [k1, k2)*[0, f1), X = k1, ... k2-1
+ if ((W = V / dl) < f1) {
+ return k1 + (int) (V / f1);
+ }
+
+ // computation of candidate X < k2, and its counterpart Y > k2
+ // either squeeze-acceptance of X or acceptance-rejection of Y
+ // Dk is uniform in 1..dl. NOTE(review): Random.nextInt(bound) throws unless bound >= 1, so this relies
+ // on dl >= 1 for every theMean >= SWITCH_MEAN -- confirm.
+ Dk = gen.nextInt((int) dl) + 1;
+ if (W <= f2 - Dk * (f2 - f2 / r2)) { // quick accept of
+ return k2 - Dk; // X = k2 - Dk
+ }
+ if ((V = f2 + f2 - W) < 1.0) { // quick reject of Y
+ Y = k2 + Dk;
+ if (V <= f2 + Dk * (1.0 - f2) / (dl + 1.0)) { // quick accept of
+ return Y; // Y = k2 + Dk
+ }
+ if (V <= f(Y, lMy, cPm)) {
+ return Y;
+ } // final accept of Y
+ }
+ X = k2 - Dk;
+ } else if (U < p4) { // centre right
+ // immediate acceptance region R3 = [m, k4+1)*[0, f4), X = m, ... k4
+ if ((V = U - p3) < 0.0) {
+ return k4 - (int) ((U - p2) / f4);
+ }
+ // immediate acceptance region R4 = [k4+1, k5+1)*[0, f5)
+ if ((W = V / dr) < f5) {
+ return k5 - (int) (V / f5);
+ }
+
+ // computation of candidate X > k4, and its counterpart Y < k4
+ // either squeeze-acceptance of X or acceptance-rejection of Y
+ // Dk is uniform in 1..dr; same dr >= 1 assumption as for dl above.
+ Dk = gen.nextInt((int) dr) + 1;
+ if (W <= f4 - Dk * (f4 - f4 * r4)) { // quick accept of
+ return k4 + Dk; // X = k4 + Dk
+ }
+ if ((V = f4 + f4 - W) < 1.0) { // quick reject of Y
+ Y = k4 - Dk;
+ if (V <= f4 + Dk * (1.0 - f4) / dr) { // quick accept of
+ return Y; // Y = k4 - Dk
+ }
+ if (V <= f(Y, lMy, cPm)) {
+ return Y;
+ } // final accept of Y
+ }
+ X = k4 + Dk;
+ } else {
+ W = gen.nextDouble();
+ if (U < p5) { // expon. tail left
+ Dk = (int) (1.0 - Math.log(W) / ll);
+ if ((X = k1 - Dk) < 0) {
+ continue;
+ } // 0 <= X <= k1 - 1
+ W *= (U - p4) * ll; // W -- U(0, h(x))
+ if (W <= f1 - Dk * (f1 - f1 / r1)) {
+ return X;
+ } // quick accept of X
+ } else { // expon. tail right
+ Dk = (int) (1.0 - Math.log(W) / lr);
+ X = k5 + Dk; // X >= k5 + 1
+ W *= (U - p5) * lr; // W -- U(0, h(x))
+ if (W <= f5 - Dk * (f5 - f5 * r5)) {
+ return X;
+ } // quick accept of X
+ }
+ }
+
+ // acceptance-rejection test of candidate X from the original area
+ // test, whether W <= f(k), with W = U*h(x) and U -- U(0, 1)
+ // log f(X) = (X - m)*log(my) - log X! + log m!
+ if (Math.log(W) <= X * lMy - Arithmetic.logFactorial(X) - cPm) {
+ return X;
+ }
+ }
+ } else { // mean is too large (theMean >= MEAN_MAX, or NaN): return it truncated instead of sampling
+ return (int) theMean;
+ }
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/Uniform.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/Uniform.java b/core/src/main/java/org/apache/mahout/math/jet/random/Uniform.java
new file mode 100644
index 0000000..32c8b90
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/Uniform.java
@@ -0,0 +1,164 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random;
+
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+/**
+ * Continuous uniform distribution between {@code min} and {@code max}, plus helpers that draw uniformly
+ * distributed {@code boolean}, {@code double}, {@code float}, {@code int} and {@code long} values from arbitrary
+ * ranges. Every draw is derived from {@code randomDouble()} of the underlying distribution.
+ * <p>
+ * NOTE(review): the open/closed interval statements below assume {@code randomDouble()} returns values in
+ * (0, 1) -- confirm against the superclass.
+ */
+public class Uniform extends AbstractContinousDistribution {
+
+  private double min;
+  private double max;
+
+  /**
+   * Creates a uniform distribution over the given bounds that draws from the generator returned by
+   * {@code RandomUtils.getRandom(seed)}.
+   *
+   * @param min one bound of the distribution
+   * @param max the other bound; bounds given in the wrong order are swapped by {@link #setState(double, double)}
+   * @param seed seed handed to {@code RandomUtils.getRandom}
+   */
+  public Uniform(double min, double max, int seed) {
+    this(min, max, RandomUtils.getRandom(seed));
+  }
+
+  /**
+   * Creates a uniform distribution over the given bounds that draws from {@code randomGenerator}.
+   *
+   * @param min one bound of the distribution
+   * @param max the other bound; bounds given in the wrong order are swapped by {@link #setState(double, double)}
+   * @param randomGenerator the uniform source of randomness
+   */
+  public Uniform(double min, double max, Random randomGenerator) {
+    setRandomGenerator(randomGenerator);
+    setState(min, max);
+  }
+
+  /**
+   * Creates the standard uniform distribution, {@code min=0.0} and {@code max=1.0}.
+   *
+   * @param randomGenerator the uniform source of randomness
+   */
+  public Uniform(Random randomGenerator) {
+    this(0, 1, randomGenerator);
+  }
+
+  /**
+   * Cumulative distribution function of the continuous uniform distribution: 0 at or below {@code min},
+   * 1 at or above {@code max}, and linear in between.
+   */
+  @Override
+  public double cdf(double x) {
+    if (x <= min) {
+      return 0.0;
+    }
+    return x >= max ? 1.0 : (x - min) / (max - min);
+  }
+
+  /** Returns {@code true} when a uniform draw exceeds one half, {@code false} otherwise. */
+  public boolean nextBoolean() {
+    double u = randomDouble();
+    return u > 0.5;
+  }
+
+  /**
+   * Returns {@code min + (max - min) * u} for a fresh {@code u = randomDouble()}, i.e. a uniform number between
+   * {@code min} and {@code max} (exclusive of both when {@code u} lies in the open unit interval).
+   */
+  @Override
+  public double nextDouble() {
+    double width = max - min;
+    return min + width * randomDouble();
+  }
+
+  /**
+   * Returns a uniform number between {@code from} and {@code to}, computed as {@code from + (to - from) * u}
+   * with {@code u = randomDouble()}. Precondition: {@code from <= to}.
+   */
+  public double nextDoubleFromTo(double from, double to) {
+    double width = to - from;
+    return from + width * randomDouble();
+  }
+
+  /**
+   * Single-precision variant of {@link #nextDoubleFromTo(double, double)}; the double result is narrowed to
+   * {@code float}. Precondition: {@code from <= to}.
+   */
+  public float nextFloatFromTo(float from, float to) {
+    double drawn = nextDoubleFromTo(from, to);
+    return (float) drawn;
+  }
+
+  /**
+   * Returns a uniform integer in the closed interval {@code [from, to]}. Precondition: {@code from <= to}.
+   * The count of candidate values is computed in {@code long} arithmetic so full-range requests do not overflow.
+   */
+  public int nextIntFromTo(int from, int to) {
+    long span = 1L + to - from; // number of integers in [from, to]
+    long offset = (long) (span * randomDouble());
+    return (int) (from + offset);
+  }
+
+  /**
+   * Returns a uniform long in the closed interval {@code [from, to]}. Precondition: {@code from <= to}.
+   * <p>
+   * A plain {@code (long) nextDoubleFromTo(from, to)} would be wrong: casts from double truncate towards zero
+   * ({@code (long) -0.7 == 0 == (long) 0.7}), and the range size can overflow a long. The cases are therefore
+   * handled separately, from the fastest to the slowest.
+   */
+  public long nextLongFromTo(long from, long to) {
+    // Fast path: with a non-negative lower bound and to < Long.MAX_VALUE, to - from + 1 cannot overflow.
+    if (from >= 0 && to < Long.MAX_VALUE) {
+      return from + (long) nextDoubleFromTo(0.0, to - from + 1);
+    }
+
+    // The size of the range, computed in double, still fits in a long.
+    double size = (double) to - (double) from + 1.0;
+    if (size <= Long.MAX_VALUE) {
+      return from + (long) nextDoubleFromTo(0.0, size);
+    }
+
+    // Pathological boundary cases; slower, and rounding is used instead of truncation.
+    if (from == Long.MIN_VALUE && to == Long.MAX_VALUE) {
+      // Every long is allowed: glue two uniform 32-bit halves together.
+      int high = nextIntFromTo(Integer.MIN_VALUE, Integer.MAX_VALUE);
+      int low = nextIntFromTo(Integer.MIN_VALUE, Integer.MAX_VALUE);
+      return ((high & 0xFFFFFFFFL) << 32) | (low & 0xFFFFFFFFL);
+    }
+    if (from == Long.MIN_VALUE) {
+      long rounded = Math.round(nextDoubleFromTo(Long.MIN_VALUE, to + 1));
+      // A result past the upper bound is folded onto Long.MIN_VALUE.
+      return rounded > to ? Long.MIN_VALUE : rounded;
+    }
+    long roundedDraw = Math.round(nextDoubleFromTo(from - 1, to));
+    // A result below the lower bound is folded onto the upper bound.
+    return roundedDraw < from ? to : roundedDraw;
+  }
+
+  /**
+   * Probability density function of the continuous uniform distribution: {@code 1 / (max - min)} strictly
+   * between the bounds and 0 at or outside them.
+   */
+  @Override
+  public double pdf(double x) {
+    return (x <= min || x >= max) ? 0.0 : 1.0 / (max - min);
+  }
+
+  /**
+   * Replaces the bounds of the distribution. Bounds given in descending order are swapped by re-invoking this
+   * method with the arguments exchanged.
+   */
+  public void setState(double min, double max) {
+    if (max < min) {
+      setState(max, min);
+    } else {
+      this.min = min;
+      this.max = max;
+    }
+  }
+
+
+  /** Describes this distribution as {@code ClassName(min,max)}. */
+  @Override
+  public String toString() {
+    return new StringBuilder(this.getClass().getName())
+        .append('(')
+        .append(min)
+        .append(',')
+        .append(max)
+        .append(')')
+        .toString();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/engine/MersenneTwister.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/engine/MersenneTwister.java b/core/src/main/java/org/apache/mahout/math/jet/random/engine/MersenneTwister.java
new file mode 100644
index 0000000..8bca895
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/engine/MersenneTwister.java
@@ -0,0 +1,275 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.jet.random.engine;
+
+import java.util.Date;
+/**
+ * MersenneTwister (MT19937) is one of the strongest uniform pseudo-random number generators
+ * known so far; at the same time it is quick.
+ * <p>
+ * It produces uniformly distributed {@code int}s in the closed interval
+ * {@code [Integer.MIN_VALUE, Integer.MAX_VALUE]} ({@link #nextInt()}). Through the methods inherited from
+ * {@link RandomEngine} it also produces {@code long}s in {@code [Long.MIN_VALUE, Long.MAX_VALUE]} and
+ * {@code float}s and {@code double}s in the open unit intervals {@code (0.0f, 1.0f)} and {@code (0.0, 1.0)}.
+ * Any {@code int}, including 0, may be used as a seed.
+ * <p>
+ * <b>Quality:</b> MersenneTwister is designed to pass the k-distribution test. It has an
+ * astronomically large period of 2<sup>19937</sup>-1 (=10<sup>6001</sup>) and 623-dimensional
+ * equidistribution up to 32-bit accuracy.
+ * It passes many stringent statistical tests, including the
+ * <a href="http://stat.fsu.edu/~geo/diehard.html">diehard</a> test of G. Marsaglia
+ * and the load test of P. Hellekalek and S. Wegenkittl.
+ * <p>
+ * <b>Performance:</b> Its speed is comparable to other modern generators (in particular,
+ * as fast as {@code java.util.Random.nextFloat()}).
+ * 2.5 million calls to {@code raw()} per second (Pentium Pro 200 Mhz, JDK 1.2, NT).
+ * Be aware, however, that initializing the 624-word state is not free. Code like
+ * <pre>{@code
+ * double sum = 0.0;
+ * for (int i = 0; i < 100000; ++i) {
+ *   RandomEngine twister = new MersenneTwister(new Date());
+ *   sum += twister.raw();
+ * }
+ * }</pre>
+ * will be wildly inefficient. Consider using
+ * <pre>{@code
+ * double sum = 0.0;
+ * RandomEngine twister = new MersenneTwister(new Date());
+ * for (int i = 0; i < 100000; ++i) {
+ *   sum += twister.raw();
+ * }
+ * }</pre>
+ * instead. This allows the cost of constructing the MersenneTwister object
+ * to be borne only once, rather than once for each iteration in the loop.
+ * <p>
+ * <b>Implementation:</b> After M. Matsumoto and T. Nishimura,
+ * "Mersenne Twister: A 623-Dimensionally Equidistributed Uniform Pseudo-Random Number Generator",
+ * ACM Transactions on Modeling and Computer Simulation,
+ * Vol. 8, No. 1, January 1998, pp 3--30.
+ * More info on <a href="http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html">Matsumoto's homepage</a>.
+ * <p>
+ * The output sequence of this implementation was checked against the published output {@code mt19937-2.out}
+ * of the 1999 C implementation {@code mt19937-2.c}; {@code setReferenceSeed(int)} seeds the state the same way
+ * that reference implementation does.
+ * <p>
+ * Note that this implementation is <b>not synchronized</b>.
+ * <p>
+ * <b>Details:</b> MersenneTwister is designed with consideration of the flaws of various existing generators in mind.
+ * It is an improved version of TT800, a very successful generator.
+ * MersenneTwister is based on linear recurrences modulo 2.
+ * Such generators are very fast, have extremely long periods, and appear quite robust.
+ * MersenneTwister produces 32-bit numbers, and every {@code k}-dimensional vector of such
+ * numbers appears the same number of times as {@code k} successive values over the
+ * period length, for each {@code k <= 623} (except for the zero vector, which appears one time less).
+ * If one looks at only the first {@code n <= 16} bits of each number, then the property holds
+ * for even larger {@code k}, as shown in the following table (taken from the publication cited above):
+ * <table border="1" summary="property table">
+ * <tr>
+ * <td>n</td><td>1</td><td>2</td><td>3</td><td>4</td><td>5</td><td>6</td><td>7</td><td>8</td><td>9</td>
+ * <td>10</td><td>11</td><td>12 .. 16</td><td>17 .. 32</td>
+ * </tr>
+ * <tr>
+ * <td>k</td><td>19937</td><td>9968</td><td>6240</td><td>4984</td><td>3738</td><td>3115</td><td>2493</td>
+ * <td>2492</td><td>1869</td><td>1869</td><td>1248</td><td>1246</td><td>623</td>
+ * </tr>
+ * </table>
+ * <p>
+ * MersenneTwister generates random numbers in batches of 624 numbers at a time, so
+ * the caching and pipelining of modern systems is exploited.
+ * The generator is implemented to generate the output by using the fastest arithmetic
+ * operations only: 32-bit additions and bit operations (no division, no multiplication, no mod).
+ * These operations generate sequences of 32 random bits ({@code int}s).
+ * {@code long}s are formed by concatenating two 32 bit {@code int}s.
+ * {@code float}s are formed by dividing the interval {@code [0.0,1.0]} into 2<sup>32</sup>
+ * sub intervals, then randomly choosing one subinterval.
+ * {@code double}s are formed by dividing the interval {@code [0.0,1.0]} into 2<sup>64</sup>
+ * sub intervals, then randomly choosing one subinterval.
+ *
+ * @author ***@cern.ch
+ * @version 1.0, 09/24/99
+ * @see java.util.Random
+ */
+public final class MersenneTwister extends RandomEngine {
+
+  /* Period parameters */
+  private static final int N = 624;
+  private static final int M = 397;
+  private static final int MATRIX_A = 0x9908b0df; /* constant vector a */
+  private static final int UPPER_MASK = 0x80000000; /* most significant w-r bits */
+  private static final int LOWER_MASK = 0x7fffffff; /* least significant r bits */
+
+  /* Tempering masks */
+  private static final int TEMPERING_MASK_B = 0x9d2c5680;
+  private static final int TEMPERING_MASK_C = 0xefc60000;
+
+  /* mag01[x] = x * MATRIX_A for x = 0, 1, unrolled into two constants */
+  private static final int MAG0 = 0x0;
+  private static final int MAG1 = MATRIX_A;
+
+  private static final int DEFAULT_SEED = 4357;
+
+  /** Index of the next state word to temper and return; {@code N} means a new block must be generated first. */
+  private int mti;
+  /** Generator state: N = 624 words. */
+  private final int[] mt = new int[N];
+
+  /**
+   * Constructs a generator with a default seed, which is a <b>constant</b>. Thus using this constructor will
+   * yield generators that always produce exactly the same sequence. It is mainly intended to ease testing and
+   * debugging.
+   */
+  public MersenneTwister() {
+    this(DEFAULT_SEED);
+  }
+
+  /**
+   * Constructs a generator with the given seed.
+   *
+   * @param seed a number that is used to initialize the internal state of the generator
+   */
+  public MersenneTwister(int seed) {
+    setSeed(seed);
+  }
+
+  /**
+   * Constructs a generator seeded with the low 32 bits of {@code d.getTime()}.
+   *
+   * @param d typically {@code new Date()}
+   */
+  public MersenneTwister(Date d) {
+    this((int) d.getTime());
+  }
+
+  /**
+   * One step of the MT19937 recurrence: joins the upper bit of {@code current} with the lower 31 bits of
+   * {@code next}, shifts the result right by one and XORs it with {@code far} and, for odd values, MATRIX_A.
+   */
+  private static int twist(int current, int next, int far) {
+    int y = (current & UPPER_MASK) | (next & LOWER_MASK);
+    return far ^ (y >>> 1) ^ ((y & 0x1) == 0 ? MAG0 : MAG1);
+  }
+
+  /** Generates the next N state words in place and rewinds {@link #mti} to the start of the block. */
+  void nextBlock() {
+    int kk = 0;
+    // The three phases avoid a modulo in the index: kk + M, then kk + M - N, then the wrap-around to mt[0].
+    for (; kk < N - M; kk++) {
+      mt[kk] = twist(mt[kk], mt[kk + 1], mt[kk + M]);
+    }
+    for (; kk < N - 1; kk++) {
+      mt[kk] = twist(mt[kk], mt[kk + 1], mt[kk + (M - N)]);
+    }
+    mt[N - 1] = twist(mt[N - 1], mt[0], mt[M - 1]);
+
+    this.mti = 0;
+  }
+
+  /**
+   * Returns a 32 bit uniformly distributed random number in the closed interval
+   * {@code [Integer.MIN_VALUE, Integer.MAX_VALUE]} (including both bounds). Every bit, including the sign bit,
+   * is random.
+   */
+  @Override
+  public int nextInt() {
+    if (mti == N) {
+      nextBlock(); // generate N ints at one time
+    }
+
+    int y = mt[mti++];
+    // Tempering: TEMPERING_SHIFT_U, _S, _T and _L of the reference implementation.
+    y ^= y >>> 11;
+    y ^= (y << 7) & TEMPERING_MASK_B;
+    y ^= (y << 15) & TEMPERING_MASK_C;
+    y ^= y >>> 18;
+
+    return y;
+  }
+
+  /**
+   * Sets the receiver's seed. This method resets the receiver's entire internal state.
+   *
+   * @param seed an integer that is used to reset the internal state of the generator
+   */
+  void setSeed(int seed) {
+    mt[0] = seed;
+    for (int i = 1; i < N; i++) {
+      int previous = mt[i - 1];
+      // See Knuth TAOCP Vol2. 3rd Ed. P.106 for the multiplier (2002/01/09 revision by Makoto Matsumoto).
+      // Unlike earlier versions, the MSBs of the seed no longer affect only the MSBs of mt[].
+      // int overflow supplies the mod 2^32 arithmetic of the 32-bit reference.
+      // NOTE(review): '>>' sign-extends, whereas the unsigned C reference shifts in zeros ('>>>' in Java).
+      // Kept as-is because changing it would alter every seeded sequence -- confirm intent.
+      mt[i] = 1812433253 * (previous ^ (previous >> 30)) + i;
+    }
+    mti = N;
+  }
+
+  /**
+   * Sets the receiver's seed in a fashion compatible with the
+   * reference C implementation. See
+   * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/VERSIONS/C-LANG/980409/mt19937int.c
+   * <p>
+   * This method isn't as good as the default method due to poor distribution of the
+   * resulting states.
+   *
+   * @param seed an integer that is used to reset the internal state in the same way as
+   * done in the 1999 reference implementation. Should only be used for testing, not
+   * actual coding.
+   */
+  void setReferenceSeed(int seed) {
+    int s = seed;
+    for (int i = 0; i < N; i++) {
+      // Upper 16 bits from the current LCG value, lower 16 bits from the next one.
+      int high = s & 0xffff0000;
+      s = 69069 * s + 1;
+      mt[i] = high | ((s & 0xffff0000) >>> 16);
+      s = 69069 * s + 1;
+    }
+    mti = N;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/engine/RandomEngine.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/engine/RandomEngine.java b/core/src/main/java/org/apache/mahout/math/jet/random/engine/RandomEngine.java
new file mode 100644
index 0000000..f832b1d
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/engine/RandomEngine.java
@@ -0,0 +1,169 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.random.engine;
+
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.IntFunction;
+
+/**
+ * Abstract base class for uniform pseudo-random number generating engines.
+ * <p>
+ * Most probability distributions are obtained by taking numbers from a <b>uniform</b> pseudo-random engine and
+ * transforming them into the desired distribution, which puts subclasses of this class at the core of
+ * computational statistics, simulations, Monte Carlo methods, etc.
+ * <p>
+ * A subclass only has to implement {@link #nextInt()}, which must return uniformly distributed {@code int}s
+ * over {@code [Integer.MIN_VALUE, Integer.MAX_VALUE]}. Everything else is layered on top of it:
+ * <ul>
+ * <li>{@link #nextLong()} concatenates two {@code int}s into a {@code long} over the full {@code long} range;</li>
+ * <li>{@link #raw()} and {@link #nextFloat()} split the unit interval into 2<sup>32</sup> sub intervals and
+ * return values in {@code (0.0, 1.0)} and {@code (0.0f, 1.0f)} respectively;</li>
+ * <li>{@link #nextDouble()} splits it into 2<sup>64</sup> sub intervals and returns values in {@code (0.0, 1.0)}.</li>
+ * </ul>
+ * Engines can also be used as function objects through the two {@code apply} overloads.
+ * <p>
+ * Note that this implementation is <b>not synchronized</b>.
+ *
+ * @see MersenneTwister
+ * @see java.util.Random
+ */
+public abstract class RandomEngine extends DoubleFunction implements IntFunction {
+
+  /**
+   * Function-object form of {@link #raw()}: the argument is ignored and a fresh uniform value in
+   * {@code (0.0, 1.0)} is returned on every evaluation.
+   */
+  @Override
+  public double apply(double dummy) {
+    return raw();
+  }
+
+  /**
+   * Function-object form of {@link #nextInt()}: the argument is ignored and a fresh uniform {@code int} is
+   * returned on every evaluation.
+   */
+  @Override
+  public int apply(int dummy) {
+    return nextInt();
+  }
+
+  /**
+   * @return a 64 bit uniformly distributed random number in the open unit interval {@code (0.0,1.0)} (excluding
+   * 0.0 and 1.0).
+   */
+  public double nextDouble() {
+    // Shift nextLong() by 2^63 (-9.223372036854776E18 == (double) Long.MIN_VALUE) and scale by
+    // 2^-64 (5.421010862427522E-20). Because the long -> double conversion rounds, values near
+    // Long.MIN_VALUE land exactly on 0.0 and values near Long.MAX_VALUE exactly on 1.0;
+    // those are redrawn so the result stays strictly inside the unit interval.
+    while (true) {
+      double candidate = (nextLong() - -9.223372036854776E18) * 5.421010862427522E-20;
+      if (candidate > 0.0 && candidate < 1.0) {
+        return candidate;
+      }
+    }
+  }
+
+  /**
+   * @return a 32 bit uniformly distributed random number in the open unit interval {@code (0.0f, 1.0f)} (excluding
+   * 0.0f and 1.0f).
+   */
+  public float nextFloat() {
+    while (true) {
+      float candidate = (float) raw();
+      // Narrowing to float can round values just below 1.0 up to 1.0f; draw again in that case.
+      if (candidate >= 1.0f) {
+        continue;
+      }
+      return candidate;
+    }
+  }
+
+  /**
+   * @return a 32 bit uniformly distributed random number in the closed interval
+   * <tt>[Integer.MIN_VALUE,Integer.MAX_VALUE]</tt>
+   * (including <tt>Integer.MIN_VALUE</tt> and <tt>Integer.MAX_VALUE</tt>);
+   */
+  public abstract int nextInt();
+
+  /**
+   * @return a 64 bit uniformly distributed random number in the closed interval
+   * <tt>[Long.MIN_VALUE,Long.MAX_VALUE]</tt>
+   * (including <tt>Long.MIN_VALUE</tt> and <tt>Long.MAX_VALUE</tt>).
+   */
+  public long nextLong() {
+    // The first int supplies the high 32 bits, the second one the low 32 bits.
+    long high = nextInt() & 0xFFFFFFFFL;
+    long low = nextInt() & 0xFFFFFFFFL;
+    return (high << 32) | low;
+  }
+
+  /**
+   * @return a 32 bit uniformly distributed random number in the open unit interval {@code (0.0, 1.0)} (excluding
+   * 0.0 and 1.0).
+   */
+  public double raw() {
+    int bits = nextInt();
+    // Zero would map to 0.0, so it is rejected.
+    while (bits == 0) {
+      bits = nextInt();
+    }
+    // Read the bits as an unsigned value in [1, 2^32 - 1] and scale by 2^-32 (2.3283064365386963E-10).
+    // For example 1 -> 2.3283064365386963E-10, Integer.MIN_VALUE -> 0.5, -1 -> 0.9999999997671694.
+    return (bits & 0xFFFFFFFFL) * 2.3283064365386963E-10;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/random/engine/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/random/engine/package-info.java b/core/src/main/java/org/apache/mahout/math/jet/random/engine/package-info.java
new file mode 100644
index 0000000..e092010
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/random/engine/package-info.java
@@ -0,0 +1,7 @@
+/**
+ * Engines generating strong, uniformly distributed pseudo-random numbers.
+ * They are needed by all JET probability distributions, since those distributions rely on uniform random
+ * numbers to generate random numbers from their own distribution.
+ * Thus, the classes of this package are at the core of computational statistics, simulations, Monte Carlo methods, etc.
+ */
+package org.apache.mahout.math.jet.random.engine;
r***@apache.org
2018-09-08 23:35:11 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/Functions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/Functions.java b/core/src/main/java/org/apache/mahout/math/function/Functions.java
new file mode 100644
index 0000000..f08c328
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/Functions.java
@@ -0,0 +1,1730 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+
+import java.util.Date;
+
+
+/**
+ * Function objects to be passed to generic methods. Contains the functions of {@link java.lang.Math} as function
+ * objects, as well as a few more basic functions. <p>Function objects conveniently allow to express arbitrary functions
+ * in a generic manner. Essentially, a function object is an object that can perform a function on some arguments. It
+ * has a minimal interface: a method <tt>apply</tt> that takes the arguments, computes something and returns some result
+ * value. Function objects are comparable to function pointers in C used for call-backs. <p>Unary functions are of type
+ * {@link org.apache.mahout.math.function.DoubleFunction}, binary functions of type {@link
+ * org.apache.mahout.math.function.DoubleDoubleFunction}. All can be retrieved via <tt>public static final</tt>
+ * variables named after the function. Unary predicates are of type
+ * {@link DoubleProcedure},
+ * binary predicates of type {@link org.apache.mahout.math.function.DoubleDoubleProcedure}. All can be retrieved via
+ * <tt>public static final</tt> variables named <tt>isXXX</tt>.
+ *
+ * <p> Binary functions and predicates also exist as unary functions with the second argument being fixed to a constant.
+ * These are generated and retrieved via factory methods (again with the same name as the function). Example: <ul>
+ * <li><tt>Functions.POW</tt> gives the function <tt>a<sup>b</sup></tt>. <li><tt>Functions.POW.apply(2,3)==8</tt>.
+ * <li><tt>Functions.pow(3)</tt> gives the function <tt>a<sup>3</sup></tt>. <li><tt>Functions.pow(3).apply(2)==8</tt>.
+ * </ul> More generally, any binary function can be made a unary function by fixing either the first or the second
+ * argument. See methods {@link #bindArg1(org.apache.mahout.math.function.DoubleDoubleFunction ,double)} and {@link
+ * #bindArg2(org.apache.mahout.math.function.DoubleDoubleFunction ,double)}. The order of arguments can
+ * be swapped so that the first argument becomes the
+ * second and vice-versa. See method {@link #swapArgs(org.apache.mahout.math.function.DoubleDoubleFunction)}.
+ * Example: <ul> <li><tt>Functions.POW</tt>
+ * gives the function <tt>a<sup>b</sup></tt>. <li><tt>Functions.bindArg2(Functions.POW,3)</tt> gives the function
+ * <tt>x<sup>3</sup></tt>. <li><tt>Functions.bindArg1(Functions.POW,3)</tt> gives the function <tt>3<sup>x</sup></tt>.
+ * <li><tt>Functions.swapArgs(Functions.POW)</tt> gives the function <tt>b<sup>a</sup></tt>. </ul> <p> Even more
+ * general, functions can be chained (composed, assembled). Assume we have two unary functions <tt>g</tt> and
+ * <tt>h</tt>. The unary function <tt>g(h(a))</tt> applying both in sequence can be generated via {@link
+ * #chain(org.apache.mahout.math.function.DoubleFunction , org.apache.mahout.math.function.DoubleFunction)}:
+ * <ul> <li><tt>Functions.chain(g,h);</tt> </ul> Assume further we have a binary
+ * function <tt>f</tt>. The binary function <tt>g(f(a,b))</tt> can be generated via {@link
+ * #chain(org.apache.mahout.math.function.DoubleFunction , org.apache.mahout.math.function.DoubleDoubleFunction)}:
+ * <ul> <li><tt>Functions.chain(g,f);</tt> </ul> The binary function
+ * <tt>f(g(a),h(b))</tt> can be generated via
+ * {@link #chain(org.apache.mahout.math.function.DoubleDoubleFunction , org.apache.mahout.math.function.DoubleFunction ,
+ * org.apache.mahout.math.function.DoubleFunction)}: <ul>
+ * <li><tt>Functions.chain(f,g,h);</tt> </ul> Arbitrarily complex functions can be composed from these building blocks.
+ * For example <tt>sin(a) + cos<sup>2</sup>(b)</tt> can be specified as follows: <ul>
+ * <li><tt>chain(PLUS,SIN,chain(SQUARE,COS));</tt> </ul> or, of course, as
+ * <pre>
+ * new DoubleDoubleFunction() {
+ * &nbsp;&nbsp;&nbsp;public final double apply(double a, double b) { return Math.sin(a) + Math.pow(Math.cos(b),2); }
+ * }
+ * </pre>
+ * <p> For aliasing see functions. Try this <table> <tr><td class="PRE">
+ * <pre>
+ * // should yield 1.4399560356056456 in all cases
+ * double a = 0.5;
+ * double b = 0.2;
+ * double v = Math.sin(a) + Math.pow(Math.cos(b),2);
+ * log.info(v);
+ * DoubleDoubleFunction f = Functions.chain(Functions.PLUS, Functions.SIN,
+ * Functions.chain(Functions.SQUARE, Functions.COS));
+ * log.info(f.apply(a,b));
+ * DoubleDoubleFunction g = new DoubleDoubleFunction() {
+ * &nbsp;&nbsp;&nbsp;public double apply(double a, double b) { return Math.sin(a) + Math.pow(Math.cos(b),2); }
+ * };
+ * log.info(g.apply(a,b));
+ * </pre>
+ * </td></tr></table>
+ *
+ * <p> <H3>Performance</H3>
+ *
+ * Surprise. Using modern non-adaptive JITs such as SunJDK 1.2.2 (java -classic) there seems to be no or only moderate
+ * performance penalty in using function objects in a loop over traditional code in a loop. For complex nested function
+ * objects (e.g. <tt>F.chain(F.abs,F.chain(F.plus,F.sin,F.chain(F.square,F.cos)))</tt>) the penalty is zero, for trivial
+ * functions (e.g. <tt>F.plus</tt>) the penalty is often acceptable. <center> <table border cellpadding="3"
+ * cellspacing="0" align="center">
+ * <tr valign="middle" bgcolor="#33CC66" align="center"> <td nowrap colspan="7">
+ * <font size="+2">Iteration Performance [million function evaluations per second]</font><br> <font size="-1">Pentium
+ * Pro 200 Mhz, SunJDK 1.2.2, NT, java -classic, </font></td> </tr>
+ * <tr valign="middle" bgcolor="#66CCFF" align="center"> <td nowrap bgcolor="#FF9966" rowspan="2">&nbsp;</td> <td bgcolor="#FF9966" colspan="2"> <p> 30000000
+ * iterations</p> </td> <td bgcolor="#FF9966" colspan="2"> 3000000 iterations (10 times less)</td> <td bgcolor="#FF9966"
+ * colspan="2">&nbsp;</td> </tr>
+ * <tr valign="middle" bgcolor="#66CCFF" align="center"> <td nowrap bgcolor="#FF9966">
+ * <tt>F.plus</tt></td> <td bgcolor="#FF9966"><tt>a+b</tt></td> <td bgcolor="#FF9966">
+ * <tt>F.chain(F.abs,F.chain(F.plus,F.sin,F.chain(F.square,F.cos)))</tt></td> <td bgcolor="#FF9966">
+ * <tt>Math.abs(Math.sin(a) + Math.pow(Math.cos(b),2))</tt></td> <td bgcolor="#FF9966">&nbsp;</td> <td
+ * bgcolor="#FF9966">&nbsp;</td> </tr>
+ * <tr valign="middle" bgcolor="#66CCFF" align="center"> <td nowrap
+ * bgcolor="#FF9966">&nbsp;</td> <td nowrap>10.8</td> <td nowrap>29.6</td> <td nowrap>0.43</td> <td nowrap>0.35</td> <td
+ * nowrap>&nbsp;</td> <td nowrap>&nbsp;</td> </tr>
+ * </table></center>
+ */
+public final class Functions {
+
+ /*
+ * <H3>Unary functions</H3>
+ */
+ /** Unary function yielding the absolute value <tt>Math.abs(a)</tt>. */
+ public static final DoubleFunction ABS = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.abs(x);
+ }
+ };
+
+ /** Unary function yielding the arc cosine <tt>Math.acos(a)</tt>. */
+ public static final DoubleFunction ACOS = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.acos(x);
+ }
+ };
+
+ /** Unary function yielding the arc sine <tt>Math.asin(a)</tt>. */
+ public static final DoubleFunction ASIN = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.asin(x);
+ }
+ };
+
+ /** Unary function yielding the arc tangent <tt>Math.atan(a)</tt>. */
+ public static final DoubleFunction ATAN = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.atan(x);
+ }
+ };
+
+ /** Unary function rounding towards positive infinity, <tt>Math.ceil(a)</tt>. */
+ public static final DoubleFunction CEIL = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.ceil(x);
+ }
+ };
+
+ /** Unary function yielding the cosine <tt>Math.cos(a)</tt> of an angle in radians. */
+ public static final DoubleFunction COS = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.cos(x);
+ }
+ };
+
+ /** Unary function yielding <i>e</i> raised to the argument, <tt>Math.exp(a)</tt>. */
+ public static final DoubleFunction EXP = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.exp(x);
+ }
+ };
+
+ /** Unary function rounding towards negative infinity, <tt>Math.floor(a)</tt>. */
+ public static final DoubleFunction FLOOR = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.floor(x);
+ }
+ };
+
+ /** Unary function handing back its argument unchanged. */
+ public static final DoubleFunction IDENTITY = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return x;
+ }
+ };
+
+ /** Unary function yielding the reciprocal <tt>1.0 / a</tt>. */
+ public static final DoubleFunction INV = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return 1.0 / x;
+ }
+ };
+
+ /** Unary function yielding the natural logarithm <tt>Math.log(a)</tt>. */
+ public static final DoubleFunction LOGARITHM = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.log(x);
+ }
+ };
+
+ /** Unary function yielding the base-2 logarithm <tt>Math.log(a) / Math.log(2)</tt>. */
+ public static final DoubleFunction LOG2 = new DoubleFunction() {
+ /** 1 / ln(2): multiplying by it replaces a division by Math.log(2) on every call. */
+ private static final double INV_LN2 = 1.4426950408889634;
+
+ @Override
+ public double apply(double x) {
+ return Math.log(x) * INV_LN2;
+ }
+ };
+
+ /** Unary function yielding the negation <tt>-a</tt>. */
+ public static final DoubleFunction NEGATE = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return -x;
+ }
+ };
+
+ /** Unary function rounding to the nearest integral value (ties to even), <tt>Math.rint(a)</tt>. */
+ public static final DoubleFunction RINT = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.rint(x);
+ }
+ };
+
+ /**
+ * Unary function yielding the sign of its argument: {@code -1} for negative input, {@code 1} for positive
+ * input and {@code 0} otherwise (zero or NaN).
+ */
+ public static final DoubleFunction SIGN = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ if (x < 0) {
+ return -1;
+ }
+ if (x > 0) {
+ return 1;
+ }
+ return 0;
+ }
+ };
+
+ /** Unary function yielding the sine <tt>Math.sin(a)</tt> of an angle in radians. */
+ public static final DoubleFunction SIN = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.sin(x);
+ }
+ };
+
+ /** Unary function yielding the square root <tt>Math.sqrt(a)</tt>. */
+ public static final DoubleFunction SQRT = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return Math.sqrt(x);
+ }
+ };
+
+ /** Unary function yielding the square <tt>a * a</tt>. */
+ public static final DoubleFunction SQUARE = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return x * x;
+ }
+ };
+
+ /** Function that returns <tt>1 / (1 + exp(-a))</tt>, the logistic sigmoid of <tt>a</tt>. */
+ public static final DoubleFunction SIGMOID = new DoubleFunction() {
+ @Override
+ public double apply(double a) {
+ return 1.0 / (1.0 + Math.exp(-a));
+ }
+ };
+
+ /**
+ * Function that returns <tt>a * (1 - a)</tt>. When <tt>a</tt> is already a sigmoid output
+ * <tt>s = SIGMOID(z)</tt>, this equals the derivative of the sigmoid at <tt>z</tt>, namely <tt>s * (1 - s)</tt>.
+ */
+ public static final DoubleFunction SIGMOIDGRADIENT = new DoubleFunction() {
+ @Override
+ public double apply(double a) {
+ return a * (1.0 - a);
+ }
+ };
+
+ /** Function that returns <tt>Math.tan(a)</tt>. */
+ public static final DoubleFunction TAN = new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return Math.tan(a);
+ }
+ };
+
+ /*
+ * <H3>Binary functions</H3>
+ */
+
+ /** Binary function yielding <tt>Math.atan2(a,b)</tt>, the polar angle of the point <tt>(b, a)</tt>. */
+ public static final DoubleDoubleFunction ATAN2 = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double y, double x) {
+ return Math.atan2(y, x);
+ }
+ };
+
+ /**
+ * Binary three-way comparison <tt>a &lt; b ? -1 : a &gt; b ? 1 : 0</tt>; NaN operands compare as 0.
+ */
+ public static final DoubleDoubleFunction COMPARE = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double a, double b) {
+ if (a < b) {
+ return -1;
+ }
+ return a > b ? 1 : 0;
+ }
+ };
+
+ /** Binary function yielding the quotient <tt>a / b</tt>. */
+ public static final DoubleDoubleFunction DIV = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double numerator, double denominator) {
+ return numerator / denominator;
+ }
+
+ /**
+ * Dividing by zero yields an infinity or NaN, not the numerator.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * 0 / y is 0 for non-zero y, but 0 / 0 is NaN, so the property fails for y = 0.
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * x / 0 is an infinity or NaN, never 0.
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * Swapping operands inverts the quotient (1 / 2 differs from 2 / 1).
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * (x / y) / z = x / (y * z), whereas x / (y / z) = x * z / y.
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ /** Binary indicator function <tt>a == b ? 1 : 0</tt>. */
+ public static final DoubleDoubleFunction EQUALS = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double a, double b) {
+ if (a == b) {
+ return 1;
+ }
+ return 0;
+ }
+
+ /**
+ * Equality is symmetric: a == b exactly when b == a.
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+ };
+
+ /** Binary indicator function <tt>a &gt; b ? 1 : 0</tt>. */
+ public static final DoubleDoubleFunction GREATER = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double a, double b) {
+ return b < a ? 1 : 0;
+ }
+ };
+
+ /** Binary function yielding the IEEE 754 remainder <tt>Math.IEEEremainder(a,b)</tt>. */
+ public static final DoubleDoubleFunction IEEE_REMAINDER = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double dividend, double divisor) {
+ return Math.IEEEremainder(dividend, divisor);
+ }
+ };
+
+ /** Binary predicate testing <tt>a == b</tt>. */
+ public static final DoubleDoubleProcedure IS_EQUAL = new DoubleDoubleProcedure() {
+ @Override
+ public boolean apply(double a, double b) {
+ return b == a;
+ }
+ };
+
+ /** Binary predicate testing {@code a < b}. */
+ public static final DoubleDoubleProcedure IS_LESS = new DoubleDoubleProcedure() {
+ @Override
+ public boolean apply(double a, double b) {
+ return b > a;
+ }
+ };
+
+ /** Binary predicate testing {@code a > b}. */
+ public static final DoubleDoubleProcedure IS_GREATER = new DoubleDoubleProcedure() {
+ @Override
+ public boolean apply(double a, double b) {
+ return b < a;
+ }
+ };
+
+ /**
+ * Function that returns <tt>a &lt; b ? 1 : 0</tt>.
+ */
+ public static final DoubleDoubleFunction LESS = new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return a < b ? 1 : 0;
+ }
+ };
+
+ /** Function that returns <tt>Math.log(a) / Math.log(b)</tt>. */
+ public static final DoubleDoubleFunction LG = new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return Math.log(a) / Math.log(b);
+ }
+ };
+
+ /** Binary function yielding the larger operand, <tt>Math.max(a,b)</tt>. */
+ public static final DoubleDoubleFunction MAX = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return Math.max(x, y);
+ }
+
+ /**
+ * max(x, 0) is 0 for negative x, so it does not always equal x.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * max(0, y) is y for positive y, so it is not always 0.
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * max(x, 0) is x for positive x, so it is not always 0.
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * Maxima nest freely: max(x, max(y, z)) = max(max(x, y), z).
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+
+ /**
+ * Operand order is irrelevant: max(x, y) = max(y, x).
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+ };
+
+ /** Function that returns <tt>Math.max(Math.abs(a), Math.abs(b))</tt>. */
+ public static final DoubleDoubleFunction MAX_ABS = new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return Math.max(Math.abs(a), Math.abs(b));
+ }
+
+ /**
+ * max(|x|, 0) = |x|, which differs from x whenever x is negative (e.g. f(-2, 0) = 2). The function
+ * therefore must not be treated as leaving its left operand unchanged when the right operand is 0.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * max(0, |y|) = |y|, which is not 0 unless y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * max(|x|, 0) = |x|, which is not 0 unless x = 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * max(|x|, max(|y|, |z|)) = max(max(|x|, |y|), |z|)
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+
+ /**
+ * max(|x|, |y|) = max(|y|, |x|)
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+ };
+
+ /** Binary function yielding the smaller operand, <tt>Math.min(a,b)</tt>. */
+ public static final DoubleDoubleFunction MIN = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return Math.min(x, y);
+ }
+
+ /**
+ * min(x, 0) is 0 for positive x, so it does not always equal x.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * min(0, y) is y for negative y, so it is not always 0.
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * min(x, 0) is x for negative x, so it is not always 0.
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * Minima nest freely: min(x, min(y, z)) = min(min(x, y), z).
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+
+ /**
+ * Operand order is irrelevant: min(x, y) = min(y, x).
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+ };
+
+ /** Function that returns <tt>a - b</tt>. */
+ public static final DoubleDoubleFunction MINUS = plusMult(-1);
+
+ public static final DoubleDoubleFunction MINUS_SQUARED = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return (x - y) * (x - y);
+ }
+
+ /**
+ * (x - 0)^2 = x^2 != x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * (0 - y)^2 != 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * (x - 0)^2 != x
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * (x - y)^2 = (y - x)^2
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+
+ /**
+ * (x - (y - z)^2)^2 != ((x - y)^2 - z)^2
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ /** Binary function yielding the Java remainder <tt>a % b</tt>; the result takes the sign of <tt>a</tt>. */
+ public static final DoubleDoubleFunction MOD = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double dividend, double divisor) {
+ return dividend % divisor;
+ }
+ };
+
+ /** Binary function yielding the product <tt>a * b</tt>, backed by {@code TimesFunction}. */
+ public static final DoubleDoubleFunction MULT = new TimesFunction();
+
+ /** Binary function yielding the sum <tt>a + b</tt>, obtained from {@code plusMult(1)}. */
+ public static final DoubleDoubleFunction PLUS = plusMult(1);
+
+ /** Function that returns <tt>Math.abs(a) + Math.abs(b)</tt>. */
+ public static final DoubleDoubleFunction PLUS_ABS = new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return Math.abs(a) + Math.abs(b);
+ }
+
+ /**
+ * abs(x) + abs(0) = abs(x) != x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * abs(0) + abs(y) = abs(y) != 0 unless y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * abs(x) + abs(0) = abs(x) != 0 unless x = 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * abs(x) + abs(abs(y) + abs(z)) = abs(x) + abs(y) + abs(z)
+ * abs(abs(x) + abs(y)) + abs(z) = abs(x) + abs(y) + abs(z)
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+
+ /**
+ * abs(x) + abs(y) = abs(y) + abs(x)
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+ };
+
+ /** Function that returns <tt>Math.abs(a - b)</tt>, the distance between its arguments. */
+ public static final DoubleDoubleFunction MINUS_ABS = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return Math.abs(x - y);
+ }
+
+ /**
+ * |x - 0| = |x|, which differs from x for negative x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * |0 - y| = |y|, which is not 0 unless y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * |x - 0| = |x|, which is not 0 unless x = 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * |x - y| = |y - x|
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+
+ /**
+ * |x - |y - z|| != ||x - y| - z| (|5 - |4 - 3|| = |5 - 1| = 4; ||5 - 4| - 3| = |1 - 3| = 2)
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ /** Binary function raising <tt>a</tt> to the power <tt>b</tt>, <tt>Math.pow(a,b)</tt>. */
+ public static final DoubleDoubleFunction POW = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double base, double exponent) {
+ return Math.pow(base, exponent);
+ }
+
+ /**
+ * Math.pow(x, 0) is 1.0 for every x, not x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * Math.pow(0, y) is 0 only for positive y: it is 1.0 for y = 0 and infinite for negative y
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * Math.pow(x, 0) is 1.0 for every x (even x = 0), never 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * Exponentiation is not symmetric: 2^3 = 8 while 3^2 = 9
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * (2^3)^4 = 8^4 = 2^12 whereas 2^(3^4) = 2^81
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ /** Function that returns its second argument and ignores the first: <tt>f(a, b) = b</tt>. */
+ public static final DoubleDoubleFunction SECOND = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return y;
+ }
+
+ /**
+ * f(x, 0) = 0, which differs from x unless x = 0
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * f(0, y) = y, which is not 0 unless y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * f(x, 0) = 0 for any x
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return true;
+ }
+
+ /**
+ * f(x, y) = y != x = f(y, x) for any x, y unless x = y
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * f(x, f(y, z)) = f(x, z) = z
+ * f(f(x, y), z) = z
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+ };
+
+ /**
+ * This function is specifically designed to be used when assigning a vector to one that is all zeros (created
+ * by like()). It enables iteration only through the nonzeros of the right hand side by declaring isLikeRightPlus
+ * to be true. This is NOT generally true for SECOND (hence the other function above).
+ * <p>
+ * The zero left operand is enforced: applying the function with a non-zero left operand fails the
+ * {@code Preconditions.checkArgument} check.
+ */
+ public static final DoubleDoubleFunction SECOND_LEFT_ZERO = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ Preconditions.checkArgument(x == 0, "This special version of SECOND needs x == 0");
+ return y;
+ }
+
+ /**
+ * f(x, 0) = 0 = x, since the left operand is required to be strictly 0
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return true;
+ }
+
+ /**
+ * f(0, y) = y, which is not 0 unless y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * f(x, 0) = 0 for any x
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return true;
+ }
+
+ /**
+ * f(x, y) = y != x = f(y, x) for any x, y unless x = y
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * f(x, f(y, z)) = f(x, z) = z
+ * f(f(x, y), z) = z
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+ };
+ public static final DoubleDoubleFunction MULT_SQUARE_LEFT = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return x * x * y;
+ }
+
+ /**
+ * x * x * 0 = 0
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * 0 * 0 * y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return true;
+ }
+
+ /**
+ * x * x * 0 = 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return true;
+ }
+
+ /**
+ * x * x * y != y * y * x
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * x * x * y * y * z != x * x * y * x * x * y * z
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ public static final DoubleDoubleFunction MULT_RIGHT_PLUS1 = new DoubleDoubleFunction() {
+
+ /**
+ * Apply the function to the arguments and return the result
+ *
+ * @param x a double for the first argument
+ * @param y a double for the second argument
+ * @return the result of applying the function
+ */
+ @Override
+ public double apply(double x, double y) {
+ return x * (y + 1);
+ }
+
+ /**
+ * x * 1 = x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return true;
+ }
+
+ /**
+ * 0 * y = 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return true;
+ }
+
+ /**
+ * x * 1 = x != 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * x * (y + 1) != y * (x + 1)
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return false;
+ }
+
+ /**
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+
+ /**
+ * Constructs the weighted average <tt>(wx * x + wy * y) / (wx + wy)</tt> of its two arguments.
+ *
+ * @param wx weight applied to the first argument.
+ * @param wy weight applied to the second argument.
+ * @return the binary weighted-average function. Its result is not finite (infinite or NaN) when
+ * <tt>wx + wy == 0</tt>.
+ */
+ public static DoubleDoubleFunction reweigh(final double wx, final double wy) {
+ final double tw = wx + wy;
+ return new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return (wx * x + wy * y) / tw;
+ }
+
+ /**
+ * f(x, 0) = wx * x / tw, which reduces to x when tw == wx (i.e. wy is 0 or negligible next to wx).
+ * tw must also be non-zero: with wx == wy == 0 the quotient is 0 / 0 = NaN, not x.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return wx == tw && tw != 0;
+ }
+
+ /**
+ * f(0, y) = wy * y / tw, which is not 0 for every y in general; conservatively reported as false
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * f(x, 0) = wx * x / tw, which is not 0 for every x in general; conservatively reported as false
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * wx * x + wy * y = wx * y + wy * x for all x, y iff wx = wy
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return wx == wy;
+ }
+
+ /**
+ * Nested weighted averages regroup their weights (e.g. with wx = wy, x, y and z end up weighted
+ * 1/2, 1/4, 1/4 versus 1/4, 1/4, 1/2); conservatively reported as false.
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+ }
+
+ private Functions() {
+ }
+
+ /**
+ * Constructs a function that returns {@code (from<=a && a<=to) ? 1 : 0}.
+ * <tt>a</tt> is a variable, <tt>from</tt> and <tt>to</tt> are fixed.
+ */
+ public static DoubleFunction between(final double from, final double to) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return from <= a && a <= to ? 1 : 0;
+ }
+ };
+ }
+
+ /**
+ * Fixes the first operand of a binary function to the constant <tt>c</tt>, leaving a unary function of the
+ * second operand.
+ *
+ * @param function a binary function taking operands in the form <tt>function.apply(c,var)</tt>.
+ * @param c the constant bound to the first operand.
+ * @return the unary function <tt>function(c,var)</tt>.
+ */
+ public static DoubleFunction bindArg1(final DoubleDoubleFunction function, final double c) {
+ return new DoubleFunction() {
+ @Override
+ public double apply(double second) {
+ return function.apply(c, second);
+ }
+ };
+ }
+
+ /**
+ * Fixes the second operand of a binary function to the constant <tt>c</tt>, leaving a unary function of the
+ * first operand.
+ *
+ * @param function a binary function taking operands in the form <tt>function.apply(var,c)</tt>.
+ * @param c the constant bound to the second operand.
+ * @return the unary function <tt>function(var,c)</tt>.
+ */
+ public static DoubleFunction bindArg2(final DoubleDoubleFunction function, final double c) {
+ return new DoubleFunction() {
+ @Override
+ public double apply(double first) {
+ return function.apply(first, c);
+ }
+ };
+ }
+
+ /**
+ * Constructs the function <tt>f( g(a), h(b) )</tt>.
+ *
+ * @param f a binary function.
+ * @param g a unary function.
+ * @param h a unary function.
+ * @return the binary function <tt>f( g(a), h(b) )</tt>.
+ */
+ public static DoubleDoubleFunction chain(final DoubleDoubleFunction f, final DoubleFunction g,
+ final DoubleFunction h) {
+ return new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return f.apply(g.apply(a), h.apply(b));
+ }
+
+ /**
+ * fx(c, 0) = f(g(x), h(0)) = f(g(x), 0) = g(x) = x if h(0) = 0 and f isLikeRightPlus and g(x) = x
+ * Impossible to check whether g(x) = x for any x, so we return false.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * fc(0, y) = f(g(0), h(y)) = f(0, h(y)) = 0 if g(0) = 0 and f isLikeLeftMult
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return g.apply(0) == 0 && f.isLikeLeftMult();
+ }
+
+ /**
+ * fc(x, 0) = f(g(x), h(0)) = f(g(x), 0) = 0 if h(0) = 0 and f isLikeRightMult
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return h.apply(0) == 0 && f.isLikeRightMult();
+ }
+
+ /**
+ * fc(x, y) = f(g(x), h(y)) = f(h(y), g(x))
+ * fc(y, x) = f(g(y), h(x)) = f(h(x), g(y))
+ * Either g(x) = g(y) for any x, y and h(x) = h(y) for any x, y or g = h and f isCommutative.
+ * Can only check if g = h (reference equality, assuming they're both the same static function in
+ * this file) and f isCommutative. There are however other scenarios when this might happen that are NOT
+ * covered by this definition.
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return g.equals(h) && f.isCommutative();
+ }
+
+ /**
+ * fc(x, fc(y, z)) = f(g(x), h(f(g(y), h(z))))
+ * fc(fc(x, y), z) = f(g(f(g(x), h(y))), h(z))
+ * Impossible to check.
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+ }
+
+ /**
+ * Constructs the function <tt>g( h(a,b) )</tt>.
+ *
+ * @param g a unary function.
+ * @param h a binary function.
+ * @return the binary function <tt>g( h(a,b) )</tt>.
+ */
+ public static DoubleDoubleFunction chain(final DoubleFunction g, final DoubleDoubleFunction h) {
+ return new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double a, double b) {
+ return g.apply(h.apply(a, b));
+ }
+
+ /**
+ * g(h(x, 0)) = g(x) = x for any x iff g(x) = x and h isLikeRightPlus
+ * Impossible to check.
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * g(h(0, y)) = g(0) = 0 for any y iff g(0) = 0 and h isLikeLeftMult
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return !g.isDensifying() && h.isLikeLeftMult();
+ }
+
+ /**
+ * g(h(x, 0)) = g(0) = 0 for any x iff g(0) = 0 and h isLikeRightMult
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return !g.isDensifying() && h.isLikeRightMult();
+ }
+
+ /**
+ * fc(x, y) = g(h(x, y)) = g(h(y, x)) = fc(y, x) iff h isCommutative
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return h.isCommutative();
+ }
+
+ /**
+ * fc(x, fc(y, z)) = g(h(x, g(h(y, z)))
+ * fc(fc(x, y), z) = g(h(g(h(x, y)), z))
+ * Impossible to check.
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+ }
+
+ /**
+ * Constructs the function <tt>g( h(a) )</tt>.
+ *
+ * @param g a unary function.
+ * @param h a unary function.
+ * @return the unary function <tt>g( h(a) )</tt>.
+ */
+ public static DoubleFunction chain(final DoubleFunction g, final DoubleFunction h) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return g.apply(h.apply(a));
+ }
+ };
+ }
+
+ /**
+ * Constructs the function <tt>g( h(a) )</tt>.
+ *
+ * @param g a unary function.
+ * @param h an {@link IntIntFunction} function.
+ * @return the unary function <tt>g( h(a) )</tt>.
+ */
+ public static IntIntFunction chain(final DoubleFunction g, final IntIntFunction h) {
+ return new IntIntFunction() {
+
+ @Override
+ public double apply(int first, int second) {
+ return g.apply(h.apply(first, second));
+ }
+ };
+ }
+
+
+ /**
+ * Constructs a function that returns {@code a < b ? -1 : a > b ? 1 : 0}. <tt>a</tt> is a variable, <tt>b</tt> is
+ * fixed.
+ */
+ public static DoubleFunction compare(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a < b ? -1 : a > b ? 1 : 0;
+ }
+ };
+ }
+
+ /** Constructs a function that returns the constant <tt>c</tt>. */
+ public static DoubleFunction constant(final double c) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return c;
+ }
+ };
+ }
+
+
+ /** Constructs a function that returns <tt>a / b</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction div(double b) {
+ return mult(1 / b);
+ }
+
+ /** Constructs a function that returns <tt>a == b ? 1 : 0</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction equals(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a == b ? 1 : 0;
+ }
+ };
+ }
+
+ /** Constructs a function that returns <tt>a != b ? 1 : 0</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction notEqual(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a != b ? 1 : 0;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>a &gt; b ? 1 : 0</tt>. <tt>a</tt>
+ * is a variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleFunction greater(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a > b ? 1 : 0;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>Math.IEEEremainder(a,b)</tt>. <tt>a</tt> is a variable, <tt>b</tt> is
+ * fixed.
+ */
+ public static DoubleFunction mathIEEEremainder(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return Math.IEEEremainder(a, b);
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns {@code from<=a && a<=to}. <tt>a</tt>
+ * is a variable, <tt>from</tt> and
+ * <tt>to</tt> are fixed.
+ *
+ * Note that DoubleProcedure is generated code and thus looks like an invalid reference unless you can see
+ * the generated stuff.
+ */
+ public static DoubleProcedure isBetween(final double from, final double to) {
+ return new DoubleProcedure() {
+
+ @Override
+ public boolean apply(double a) {
+ return from <= a && a <= to;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>a == b</tt>. <tt>a</tt> is a
+ * variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleProcedure isEqual(final double b) {
+ return new DoubleProcedure() {
+
+ @Override
+ public boolean apply(double a) {
+ return a == b;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>a &gt; b</tt>. <tt>a</tt> is a
+ * variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleProcedure isGreater(final double b) {
+ return new DoubleProcedure() {
+
+ @Override
+ public boolean apply(double a) {
+ return a > b;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns {@code a < b}. <tt>a</tt> is a
+ * variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleProcedure isLess(final double b) {
+ return new DoubleProcedure() {
+
+ @Override
+ public boolean apply(double a) {
+ return a < b;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>a &lt; b ? 1 : 0</tt>. <tt>a</tt> is a
+ * variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleFunction less(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a < b ? 1 : 0;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>Math.log(a) / Math.log(b)</tt>.
+ * <tt>a</tt> is a variable, <tt>b</tt> is fixed.
+ */
+ public static DoubleFunction lg(final double b) {
+ return new DoubleFunction() {
+ private final double logInv = 1 / Math.log(b); // cached for speed
+
+
+ @Override
+ public double apply(double a) {
+ return Math.log(a) * logInv;
+ }
+ };
+ }
+
+ /** Constructs a function that returns <tt>Math.max(a,b)</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction max(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return Math.max(a, b);
+ }
+ };
+ }
+
+ /** Constructs a function that returns <tt>Math.min(a,b)</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction min(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return Math.min(a, b);
+ }
+ };
+ }
+
+ /** Constructs a function that returns <tt>a - b</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction minus(double b) {
+ return plus(-b);
+ }
+
+ /**
+ * Constructs a function that returns <tt>a - b*constant</tt>. <tt>a</tt> and <tt>b</tt> are variables,
+ * <tt>constant</tt> is fixed.
+ */
+ public static DoubleDoubleFunction minusMult(double constant) {
+ return plusMult(-constant);
+ }
+
+ /** Constructs a function that returns <tt>a % b</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction mod(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a % b;
+ }
+ };
+ }
+
+ /** Constructs a function that returns <tt>a * b</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction mult(double b) {
+ return new Mult(b);
+ /*
+ return new DoubleFunction() {
+ public final double apply(double a) { return a * b; }
+ };
+ */
+ }
+
+ /** Constructs a function that returns <tt>a + b</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction plus(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ return a + b;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>a + b*constant</tt>. <tt>a</tt> and <tt>b</tt> are variables,
+ * <tt>constant</tt> is fixed.
+ */
+ public static DoubleDoubleFunction plusMult(double constant) {
+ return new PlusMult(constant);
+ }
+
+ /** Constructs a function that returns <tt>Math.pow(a,b)</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */
+ public static DoubleFunction pow(final double b) {
+ return new DoubleFunction() {
+
+ @Override
+ public double apply(double a) {
+ if (b == 2) {
+ return a * a;
+ } else {
+ return Math.pow(a, b);
+ }
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns a new uniform random number in the open unit interval {@code (0.0,1.0)}
+ * (excluding 0.0 and 1.0). Currently the engine is {@link MersenneTwister} and is
+ * seeded with the current time. <p> Note that any random engine derived from {@link
+ * org.apache.mahout.math.jet.random.engine.RandomEngine} and any random distribution derived from {@link
+ * org.apache.mahout.math.jet.random.AbstractDistribution} are function objects, because they implement the proper
+ * interfaces. Thus, if you are not happy with the default, just pass your favourite random generator to function
+ * evaluating methods.
+ */
+ public static DoubleFunction random() {
+ return new MersenneTwister(new Date());
+ }
+
+ /**
+ * Constructs a function that returns the number rounded to the given precision;
+ * <tt>Math.rint(a/precision)*precision</tt>. Examples:
+ * {@code
+ * precision = 0.01 rounds 0.012 --> 0.01, 0.018 --> 0.02
+ * precision = 10 rounds 123 --> 120 , 127 --> 130
+ * }
+ */
+ public static DoubleFunction round(final double precision) {
+ return new DoubleFunction() {
+ @Override
+ public double apply(double a) {
+ return Math.rint(a / precision) * precision;
+ }
+ };
+ }
+
+ /**
+ * Constructs a function that returns <tt>function.apply(b,a)</tt>, i.e. applies the function with the first operand
+ * as second operand and the second operand as first operand.
+ *
+ * @param function a function taking operands in the form <tt>function.apply(a,b)</tt>.
+ * @return the binary function <tt>function(b,a)</tt>.
+ */
+ public static DoubleDoubleFunction swapArgs(final DoubleDoubleFunction function) {
+ return new DoubleDoubleFunction() {
+ @Override
+ public double apply(double a, double b) {
+ return function.apply(b, a);
+ }
+ };
+ }
+
+ public static DoubleDoubleFunction minusAbsPow(final double exponent) {
+ return new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x, double y) {
+ return Math.pow(Math.abs(x - y), exponent);
+ }
+
+ /**
+ * |x - 0|^p = |x|^p != x unless x > 0 and p = 1
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * |0 - y|^p = |y|^p
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * |x - 0|^p = |x|^p
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * |x - y|^p = |y - x|^p
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+
+ /**
+ * |x - |y - z|^p|^p != ||x - y|^p - z|^p
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return false;
+ }
+ };
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/IntFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/IntFunction.java b/core/src/main/java/org/apache/mahout/math/function/IntFunction.java
new file mode 100644
index 0000000..b91fe18
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/IntFunction.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
/**
 * Function object over primitive {@code int} values: maps one {@code int} argument to one {@code int} result.
 */
public interface IntFunction {

  /**
   * Evaluates this function.
   *
   * @param argument the input value.
   * @return the value of the function at {@code argument}.
   */
  int apply(int argument);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/IntIntDoubleFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/IntIntDoubleFunction.java b/core/src/main/java/org/apache/mahout/math/function/IntIntDoubleFunction.java
new file mode 100644
index 0000000..b08f08b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/IntIntDoubleFunction.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
/**
 * Function object of three arguments, two {@code int}s followed by a {@code double}, that yields a
 * {@code double}.
 */
public interface IntIntDoubleFunction {

  /**
   * Evaluates the function.
   *
   * @param first the first {@code int} argument.
   * @param second the second {@code int} argument.
   * @param third the {@code double} argument.
   * @return the function value for the given arguments.
   */
  double apply(int first, int second, double third);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/IntIntFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/IntIntFunction.java b/core/src/main/java/org/apache/mahout/math/function/IntIntFunction.java
new file mode 100644
index 0000000..f08bb28
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/IntIntFunction.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.function;
+
/**
 * A function that takes two integer arguments and returns a {@code double}.
 */
public interface IntIntFunction {

  /**
   * Applies the function to two arguments.
   *
   * @param first first argument passed to the function.
   * @param second second argument passed to the function.
   * @return the result of the function.
   */
  double apply(int first, int second);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/Mult.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/Mult.java b/core/src/main/java/org/apache/mahout/math/function/Mult.java
new file mode 100644
index 0000000..9bbc5ec
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/Mult.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
+/**
+ * Only for performance tuning of compute intensive linear algebraic computations.
+ * Constructs functions that return one of
+ * <ul>
+ * <li><tt>a * constant</tt>
+ * <li><tt>a / constant</tt>
+ * </ul>
+ * <tt>a</tt> is variable, <tt>constant</tt> is fixed, but for performance reasons publicly accessible.
+ * Intended to be passed to <tt>matrix.assign(function)</tt> methods.
+ */
+
+public final class Mult extends DoubleFunction {
+
+ private double multiplicator;
+
+ Mult(double multiplicator) {
+ this.multiplicator = multiplicator;
+ }
+
+ /** Returns the result of the function evaluation. */
+ @Override
+ public double apply(double a) {
+ return a * multiplicator;
+ }
+
+ /** <tt>a / constant</tt>. */
+ public static Mult div(double constant) {
+ return mult(1 / constant);
+ }
+
+ /** <tt>a * constant</tt>. */
+ public static Mult mult(double constant) {
+ return new Mult(constant);
+ }
+
+ public double getMultiplicator() {
+ return multiplicator;
+ }
+
+ public void setMultiplicator(double multiplicator) {
+ this.multiplicator = multiplicator;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/ObjectObjectProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/ObjectObjectProcedure.java b/core/src/main/java/org/apache/mahout/math/function/ObjectObjectProcedure.java
new file mode 100644
index 0000000..46ad8d0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/ObjectObjectProcedure.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.math.function;
+
/**
 * Interface that represents a procedure object:
 * a procedure that takes two arguments (a key and a value) and returns a 'continue' flag.
 *
 * @param <K> type of the key argument.
 * @param <V> type of the value argument.
 */
public interface ObjectObjectProcedure<K,V> {

  /**
   * Applies a procedure to a key/value pair and returns a boolean flag to inform the object calling the
   * procedure.
   *
   * <p>Example: forEach() methods often use procedure objects. To signal to a forEach() method whether iteration should
   * continue normally or terminate (because for example a matching element has been found), a procedure can return
   * <tt>false</tt> to indicate termination and <tt>true</tt> to indicate continuation.
   *
   * @param key key passed to the procedure.
   * @param value value passed to the procedure.
   * @return a flag to inform the object calling the procedure.
   */
  boolean apply(K key, V value);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/ObjectProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/ObjectProcedure.java b/core/src/main/java/org/apache/mahout/math/function/ObjectProcedure.java
new file mode 100644
index 0000000..8c1b1c8
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/ObjectProcedure.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+
+package org.apache.mahout.math.function;
+
/**
 * Interface that represents a procedure object: a procedure that takes a single argument and returns a boolean
 * 'continue' flag to its caller.
 *
 * @param <T> type of the argument.
 */
public interface ObjectProcedure<T> {

  /**
   * Applies a procedure to an argument and returns a boolean flag to inform the object calling the procedure.
   *
   * <p>Example: forEach() methods often use procedure objects. To signal to a forEach() method whether iteration should
   * continue normally or terminate (because for example a matching element has been found), a procedure can return
   * <tt>false</tt> to indicate termination and <tt>true</tt> to indicate continuation.
   *
   * @param element element passed to the procedure.
   * @return a flag to inform the object calling the procedure.
   */
  boolean apply(T element);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/PlusMult.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/PlusMult.java b/core/src/main/java/org/apache/mahout/math/function/PlusMult.java
new file mode 100644
index 0000000..ff99a70
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/PlusMult.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
+import org.apache.mahout.math.jet.math.Constants;
+
+/**
+ * Only for performance tuning of compute intensive linear algebraic computations.
+ * Constructs functions that return one of
+ * <ul>
+ * <li><tt>a + b*constant</tt>
+ * <li><tt>a - b*constant</tt>
+ * <li><tt>a + b/constant</tt>
+ * <li><tt>a - b/constant</tt>
+ * </ul>
+ * <tt>a</tt> and <tt>b</tt> are variables, <tt>constant</tt> is fixed, but for performance reasons publicly accessible.
+ * Intended to be passed to <tt>matrix.assign(otherMatrix,function)</tt> methods.
+ */
+
+public final class PlusMult extends DoubleDoubleFunction {
+
+ private double multiplicator;
+
+ public PlusMult(double multiplicator) {
+ this.multiplicator = multiplicator;
+ }
+
+ /** Returns the result of the function evaluation. */
+ @Override
+ public double apply(double a, double b) {
+ return a + b * multiplicator;
+ }
+
+ /** <tt>a - b*constant</tt>. */
+ public static PlusMult minusMult(double constant) {
+ return new PlusMult(-constant);
+ }
+
+ /** <tt>a + b*constant</tt>. */
+ public static PlusMult plusMult(double constant) {
+ return new PlusMult(constant);
+ }
+
+ public double getMultiplicator() {
+ return multiplicator;
+ }
+
+ /**
+ * x + 0 * c = x
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return true;
+ }
+
+ /**
+ * 0 + y * c = y * c != 0
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return false;
+ }
+
+ /**
+ * x + 0 * c = x != 0
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return false;
+ }
+
+ /**
+ * x + y * c = y + x * c iff c = 1
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return Math.abs(multiplicator - 1.0) < Constants.EPSILON;
+ }
+
+ /**
+ * f(x, f(y, z)) = x + c * (y + c * z) = x + c * y + c^2 * z
+ * f(f(x, y), z) = (x + c * y) + c * z = x + c * y + c * z
+ * true only for c = 0 or c = 1
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return Math.abs(multiplicator - 0.0) < Constants.EPSILON
+ || Math.abs(multiplicator - 1.0) < Constants.EPSILON;
+ }
+
+ public void setMultiplicator(double multiplicator) {
+ this.multiplicator = multiplicator;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/SquareRootFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/SquareRootFunction.java b/core/src/main/java/org/apache/mahout/math/function/SquareRootFunction.java
new file mode 100644
index 0000000..5eebea0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/SquareRootFunction.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.function;
+
+public final class SquareRootFunction extends DoubleFunction {
+
+ @Override
+ public double apply(double arg1) {
+ return Math.sqrt(arg1);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/TimesFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/TimesFunction.java b/core/src/main/java/org/apache/mahout/math/function/TimesFunction.java
new file mode 100644
index 0000000..e4e27b4
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/TimesFunction.java
@@ -0,0 +1,77 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.function;
+
+public final class TimesFunction extends DoubleDoubleFunction {
+
+ /**
+ * Computes the product of two numbers.
+ *
+ * @param x first argument
+ * @param y second argument
+ * @return the product
+ */
+ @Override
+ public double apply(double x, double y) {
+ return x * y;
+ }
+
+ /**
+ * x * 0 = y only if y = 0
+ * @return true iff f(x, 0) = x for any x
+ */
+ @Override
+ public boolean isLikeRightPlus() {
+ return false;
+ }
+
+ /**
+ * 0 * y = 0 for any y
+ * @return true iff f(0, y) = 0 for any y
+ */
+ @Override
+ public boolean isLikeLeftMult() {
+ return true;
+ }
+
+ /**
+ * x * 0 = 0 for any x
+ * @return true iff f(x, 0) = 0 for any x
+ */
+ @Override
+ public boolean isLikeRightMult() {
+ return true;
+ }
+
+ /**
+ * x * y = y * x for any x, y
+ * @return true iff f(x, y) = f(y, x) for any x, y
+ */
+ @Override
+ public boolean isCommutative() {
+ return true;
+ }
+
+ /**
+ * x * (y * z) = (x * y) * z for any x, y, z
+ * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z
+ */
+ @Override
+ public boolean isAssociative() {
+ return true;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/VectorFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/VectorFunction.java b/core/src/main/java/org/apache/mahout/math/function/VectorFunction.java
new file mode 100644
index 0000000..3b5af77
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/VectorFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.function;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Defines a function of a vector that returns a double.
+ */
+public interface VectorFunction {
+ double apply(Vector f);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/package-info.java b/core/src/main/java/org/apache/mahout/math/function/package-info.java
new file mode 100644
index 0000000..47ceace
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/package-info.java
@@ -0,0 +1,4 @@
+/**
+ * Core interfaces for functions, comparisons and procedures on objects and primitive data types.
+ */
+package org.apache.mahout.math.function;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/math/Arithmetic.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/math/Arithmetic.java b/core/src/main/java/org/apache/mahout/math/jet/math/Arithmetic.java
new file mode 100644
index 0000000..83d512b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/math/Arithmetic.java
@@ -0,0 +1,328 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.math;
+
/**
 * Arithmetic functions: binomial coefficients and (log-)factorials backed by lookup tables.
 */
public final class Arithmetic {

  // for method logFactorial(...)
  // log(k!) for k = 0, ..., 29
  private static final double[] LOG_FACTORIAL_TABLE = {
      0.00000000000000000, 0.00000000000000000, 0.69314718055994531,
      1.79175946922805500, 3.17805383034794562, 4.78749174278204599,
      6.57925121201010100, 8.52516136106541430, 10.60460290274525023,
      12.80182748008146961, 15.10441257307551530, 17.50230784587388584,
      19.98721449566188615, 22.55216385312342289, 25.19122118273868150,
      27.89927138384089157, 30.67186010608067280, 33.50507345013688888,
      36.39544520803305358, 39.33988418719949404, 42.33561646075348503,
      45.38013889847690803, 48.47118135183522388, 51.60667556776437357,
      54.78472939811231919, 58.00360522298051994, 61.26170176100200198,
      64.55753862700633106, 67.88974313718153498, 71.25703896716800901
  };

  // k! for k = 0, ..., 20 (exact; 20! is the largest factorial that fits in a long)
  private static final long[] FACTORIAL_TABLE = {
      1L,
      1L,
      2L,
      6L,
      24L,
      120L,
      720L,
      5040L,
      40320L,
      362880L,
      3628800L,
      39916800L,
      479001600L,
      6227020800L,
      87178291200L,
      1307674368000L,
      20922789888000L,
      355687428096000L,
      6402373705728000L,
      121645100408832000L,
      2432902008176640000L
  };

  // k! for k = 21, ..., 170 (rounded to double; 171! overflows a double)
  private static final double[] LARGE_FACTORIAL_TABLE = {
      5.109094217170944E19,
      1.1240007277776077E21,
      2.585201673888498E22,
      6.204484017332394E23,
      1.5511210043330984E25,
      4.032914611266057E26,
      1.0888869450418352E28,
      3.048883446117138E29,
      8.841761993739701E30,
      2.652528598121911E32,
      8.222838654177924E33,
      2.6313083693369355E35,
      8.68331761881189E36,
      2.952327990396041E38,
      1.0333147966386144E40,
      3.719933267899013E41,
      1.3763753091226346E43,
      5.23022617466601E44,
      2.0397882081197447E46,
      8.15915283247898E47,
      3.34525266131638E49,
      1.4050061177528801E51,
      6.041526306337384E52,
      2.6582715747884495E54,
      1.196222208654802E56,
      5.502622159812089E57,
      2.5862324151116827E59,
      1.2413915592536068E61,
      6.082818640342679E62,
      3.0414093201713376E64,
      1.5511187532873816E66,
      8.06581751709439E67,
      4.274883284060024E69,
      2.308436973392413E71,
      1.2696403353658264E73,
      7.109985878048632E74,
      4.052691950487723E76,
      2.350561331282879E78,
      1.386831185456898E80,
      8.32098711274139E81,
      5.075802138772246E83,
      3.146997326038794E85,
      1.9826083154044396E87,
      1.2688693218588414E89,
      8.247650592082472E90,
      5.443449390774432E92,
      3.6471110918188705E94,
      2.48003554243683E96,
      1.7112245242814127E98,
      1.1978571669969892E100,
      8.504785885678624E101,
      6.123445837688612E103,
      4.470115461512686E105,
      3.307885441519387E107,
      2.4809140811395404E109,
      1.8854947016660506E111,
      1.451830920282859E113,
      1.1324281178206295E115,
      8.94618213078298E116,
      7.15694570462638E118,
      5.797126020747369E120,
      4.7536433370128435E122,
      3.94552396972066E124,
      3.314240134565354E126,
      2.8171041143805494E128,
      2.4227095383672744E130,
      2.107757298379527E132,
      1.854826422573984E134,
      1.6507955160908465E136,
      1.4857159644817605E138,
      1.3520015276784033E140,
      1.2438414054641305E142,
      1.156772507081641E144,
      1.0873661566567426E146,
      1.0329978488239061E148,
      9.916779348709491E149,
      9.619275968248216E151,
      9.426890448883248E153,
      9.332621544394415E155,
      9.332621544394418E157,
      9.42594775983836E159,
      9.614466715035125E161,
      9.902900716486178E163,
      1.0299016745145631E166,
      1.0813967582402912E168,
      1.1462805637347086E170,
      1.2265202031961373E172,
      1.324641819451829E174,
      1.4438595832024942E176,
      1.5882455415227423E178,
      1.7629525510902457E180,
      1.974506857221075E182,
      2.2311927486598138E184,
      2.543559733472186E186,
      2.925093693493014E188,
      3.393108684451899E190,
      3.96993716080872E192,
      4.6845258497542896E194,
      5.574585761207606E196,
      6.689502913449135E198,
      8.094298525273444E200,
      9.875044200833601E202,
      1.2146304367025332E205,
      1.506141741511141E207,
      1.882677176888926E209,
      2.3721732428800483E211,
      3.0126600184576624E213,
      3.856204823625808E215,
      4.974504222477287E217,
      6.466855489220473E219,
      8.471580690878813E221,
      1.1182486511960037E224,
      1.4872707060906847E226,
      1.99294274616152E228,
      2.690472707318049E230,
      3.6590428819525483E232,
      5.0128887482749884E234,
      6.917786472619482E236,
      9.615723196941089E238,
      1.3462012475717523E241,
      1.8981437590761713E243,
      2.6953641378881633E245,
      3.8543707171800694E247,
      5.550293832739308E249,
      8.047926057471989E251,
      1.1749972043909107E254,
      1.72724589045464E256,
      2.5563239178728637E258,
      3.8089226376305687E260,
      5.7133839564458575E262,
      8.627209774233244E264,
      1.3113358856834527E267,
      2.0063439050956838E269,
      3.0897696138473515E271,
      4.789142901463393E273,
      7.471062926282892E275,
      1.1729568794264134E278,
      1.8532718694937346E280,
      2.946702272495036E282,
      4.714723635992061E284,
      7.590705053947223E286,
      1.2296942187394494E289,
      2.0044015765453032E291,
      3.287218585534299E293,
      5.423910666131583E295,
      9.003691705778434E297,
      1.5036165148649983E300,
      2.5260757449731988E302,
      4.2690680090047056E304,
      7.257415615308004E306
  };

  private Arithmetic() {
  }

  /**
   * Efficiently returns the binomial coefficient, often also referred to as "n over k" or "n choose k". The binomial
   * coefficient is defined as <ul>
   * <li><tt>k&lt;0</tt>: <tt>0</tt>.</li>
   * <li><tt>0&lt;=n&lt;k</tt>: <tt>0</tt>.</li>
   * <li><tt>k==0 || k==n</tt>: <tt>1</tt>.</li>
   * <li><tt>k==1 || k==n-1</tt>: <tt>n</tt>.</li>
   * <li>else: <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )</tt>, which also covers negative
   * <tt>n</tt> (generalized binomial coefficient).</li>
   * </ul>
   * Results above <tt>20!</tt> are computed in floating point and are accurate to double precision,
   * not necessarily exact integers.
   *
   * @return the binomial coefficient.
   */
  public static double binomial(long n, long k) {
    if (k < 0) {
      return 0;
    }
    // With non-negative n there are no ways to choose more than n items. Answering here avoids
    // a k-step loop (k may be huge) and the -0.0 that loop produced when k > n + 1.
    if (n >= 0 && k > n) {
      return 0;
    }
    if (k == 0 || k == n) {
      return 1;
    }
    if (k == 1 || k == n - 1) {
      return n;
    }

    // Quick path: while n! is finite as a double (n <= 170) the coefficient is a ratio of
    // table lookups, so no loop is needed.
    if (n > k) {
      int max = FACTORIAL_TABLE.length + LARGE_FACTORIAL_TABLE.length;
      if (n < max) { // n! < inf, hence k! and (n-k)! are finite as well
        double nFactorial = factorial((int) n);
        double kFactorial = factorial((int) k);
        double nMinusKFactorial = factorial((int) (n - k));
        double nk = nMinusKFactorial * kFactorial;
        if (nk != Double.POSITIVE_INFINITY) { // no numeric overflow?
          return nFactorial / nk;
        }
      }
      // Symmetry C(n, k) == C(n, n - k) shortens the product loop below.
      if (k > n / 2) {
        k = n - k;
      }
    }

    // binomial(n,k) = (n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )
    long a = n - k + 1;
    long b = 1;
    double binomial = 1;
    for (long i = k; i-- > 0;) {
      binomial *= (double) a++ / b++;
    }
    return binomial;
  }

  /**
   * Instantly returns the factorial <tt>k!</tt> from the lookup tables.
   *
   * @param k must hold <tt>k &gt;= 0</tt>.
   * @return <tt>k!</tt>, or positive infinity when <tt>k &gt; 170</tt>.
   * @throws IllegalArgumentException if <tt>k &lt; 0</tt>.
   */
  private static double factorial(int k) {
    if (k < 0) {
      throw new IllegalArgumentException("k must be non-negative: " + k);
    }

    int length1 = FACTORIAL_TABLE.length;
    if (k < length1) {
      return FACTORIAL_TABLE[k];
    }

    int length2 = LARGE_FACTORIAL_TABLE.length;
    if (k < length1 + length2) {
      return LARGE_FACTORIAL_TABLE[k - length1];
    } else {
      return Double.POSITIVE_INFINITY;
    }
  }

  /**
   * Returns <tt>log(k!)</tt>. Tries to avoid overflows. For {@code k<30} simply
   * looks up a table in <tt>O(1)</tt>. For {@code k>=30} uses Stirling's series
   * approximation.
   *
   * @param k must hold <tt>k &gt;= 0</tt>.
   * @throws IllegalArgumentException if <tt>k &lt; 0</tt>.
   */
  public static double logFactorial(int k) {
    if (k < 0) {
      throw new IllegalArgumentException("k must be non-negative: " + k);
    }
    if (k < LOG_FACTORIAL_TABLE.length) {
      return LOG_FACTORIAL_TABLE[k];
    }

    // Stirling series: (k + 1/2) ln k - k + ln(sqrt(2 pi))
    //   + 1/(12k) - 1/(360k^3) + 1/(1260k^5) - 1/(1680k^7)
    double r = 1.0 / k;
    double rr = r * r;
    double c7 = -5.95238095238095238e-04;
    double c5 = 7.93650793650793651e-04;
    double c3 = -2.77777777777777778e-03;
    double c1 = 8.33333333333333333e-02;
    double c0 = 9.18938533204672742e-01;
    return (k + 0.5) * Math.log(k) - k + c0 + r * (c1 + rr * (c3 + rr * (c5 + rr * c7)));
  }

}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/jet/math/Constants.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/jet/math/Constants.java b/core/src/main/java/org/apache/mahout/math/jet/math/Constants.java
new file mode 100644
index 0000000..b99340d
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/jet/math/Constants.java
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.jet.math;
+
/**
 * Numerical constants shared by the math code.
 */
public final class Constants {

  /** Unit roundoff of a {@code double}: 2<sup>-53</sup>. */
  public static final double MACHEP = 1.11022302462515654042E-16;

  /** ln({@link Double#MAX_VALUE}); {@code exp} of any larger argument overflows. */
  public static final double MAXLOG = 7.09782712893383996732E2;

  /** ln(2<sup>-1075</sup>); {@code exp} of any smaller argument underflows to zero. */
  public static final double MINLOG = -7.451332191019412076235E2;

  /** Argument above which the gamma function overflows a {@code double}. */
  public static final double MAXGAM = 171.624376956302725;

  /** sqrt(2 * pi). */
  public static final double SQTPI = 2.50662827463100050242E0;

  /** ln(pi). */
  public static final double LOGPI = 1.14472988584940017414;

  /** 2<sup>52</sup>. */
  public static final double BIG = 4.503599627370496e15;

  /** 2<sup>-52</sup>, the reciprocal of {@link #BIG}. */
  public static final double BIG_INVERSE = 2.22044604925031308085e-16;

  /** Absolute tolerance (10<sup>-6</sup>) for approximate floating-point comparisons. */
  public static final double EPSILON = 1.0E-6;

  /** Not instantiable: holder for constants only. */
  private Constants() {
  }
}
r***@apache.org
2018-09-08 23:35:12 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorView.java b/core/src/main/java/org/apache/mahout/math/VectorView.java
new file mode 100644
index 0000000..62c5490
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorView.java
@@ -0,0 +1,238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+import com.google.common.collect.AbstractIterator;
+
+/** Implements subset view of a Vector */
+public class VectorView extends AbstractVector {
+
+ protected Vector vector;
+
+ // the offset into the Vector
+ protected int offset;
+
+ /** For serialization purposes only */
+ public VectorView() {
+ super(0);
+ }
+
+ public VectorView(Vector vector, int offset, int cardinality) {
+ super(cardinality);
+ this.vector = vector;
+ this.offset = offset;
+ }
+
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return ((AbstractVector) vector).matrixLike(rows, columns);
+ }
+
+ @Override
+ public Vector clone() {
+ VectorView r = (VectorView) super.clone();
+ r.vector = vector.clone();
+ r.offset = offset;
+ return r;
+ }
+
+ @Override
+ public boolean isDense() {
+ return vector.isDense();
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return vector.isSequentialAccess();
+ }
+
+ @Override
+ public VectorView like() {
+ return new VectorView(vector.like(), offset, size());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return vector.like(cardinality);
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return vector.getQuick(offset + index);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ vector.setQuick(offset + index, value);
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return size();
+ }
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ if (offset < 0) {
+ throw new IndexException(offset, size());
+ }
+ if (offset + length > size()) {
+ throw new IndexException(offset + length, size());
+ }
+ return new VectorView(vector, offset + this.offset, length);
+ }
+
+ /** @return true if index is a valid index in the underlying Vector */
+ private boolean isInView(int index) {
+ return index >= offset && index < offset + size();
+ }
+
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return new NonZeroIterator();
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ return new AllIterator();
+ }
+
+ public final class NonZeroIterator extends AbstractIterator<Element> {
+
+ private final Iterator<Element> it;
+
+ private NonZeroIterator() {
+ it = vector.nonZeroes().iterator();
+ }
+
+ @Override
+ protected Element computeNext() {
+ while (it.hasNext()) {
+ Element el = it.next();
+ if (isInView(el.index()) && el.get() != 0) {
+ Element decorated = el; /* vector.getElement(el.index()); */
+ return new DecoratorElement(decorated);
+ }
+ }
+ return endOfData();
+ }
+
+ }
+
+ public final class AllIterator extends AbstractIterator<Element> {
+
+ private final Iterator<Element> it;
+
+ private AllIterator() {
+ it = vector.all().iterator();
+ }
+
+ @Override
+ protected Element computeNext() {
+ while (it.hasNext()) {
+ Element el = it.next();
+ if (isInView(el.index())) {
+ Element decorated = vector.getElement(el.index());
+ return new DecoratorElement(decorated);
+ }
+ }
+ return endOfData(); // No element was found
+ }
+
+ }
+
+ private final class DecoratorElement implements Element {
+
+ private final Element decorated;
+
+ private DecoratorElement(Element decorated) {
+ this.decorated = decorated;
+ }
+
+ @Override
+ public double get() {
+ return decorated.get();
+ }
+
+ @Override
+ public int index() {
+ return decorated.index() - offset;
+ }
+
+ @Override
+ public void set(double value) {
+ decorated.set(value);
+ }
+ }
+
+ @Override
+ public double getLengthSquared() {
+ double result = 0.0;
+ int size = size();
+ for (int i = 0; i < size; i++) {
+ double value = getQuick(i);
+ result += value * value;
+ }
+ return result;
+ }
+
+ @Override
+ public double getDistanceSquared(Vector v) {
+ double result = 0.0;
+ int size = size();
+ for (int i = 0; i < size; i++) {
+ double delta = getQuick(i) - v.getQuick(i);
+ result += delta * delta;
+ }
+ return result;
+ }
+
+ @Override
+ public double getLookupCost() {
+ return vector.getLookupCost();
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ // TODO: remove the 2x after fixing the Element iterator
+ return 2 * vector.getIteratorAdvanceCost();
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return vector.isAddConstantTime();
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p>
+ * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge in the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ for (int i = 0; i < updates.getNumMappings(); ++i) {
+ updates.setIndexAt(i, updates.indexAt(i) + offset);
+ }
+ vector.mergeUpdates(updates);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/WeightedVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/WeightedVector.java b/core/src/main/java/org/apache/mahout/math/WeightedVector.java
new file mode 100644
index 0000000..c8fdfac
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/WeightedVector.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+/**
+ * Decorates a vector with a floating point weight and an index.
+ */
+public class WeightedVector extends DelegatingVector {
+ private static final int INVALID_INDEX = -1;
+ private double weight;
+ private int index;
+
+ protected WeightedVector(double weight, int index) {
+ super();
+ this.weight = weight;
+ this.index = index;
+ }
+
+ public WeightedVector(Vector v, double weight, int index) {
+ super(v);
+ this.weight = weight;
+ this.index = index;
+ }
+
+ public WeightedVector(Vector v, Vector projection, int index) {
+ super(v);
+ this.index = index;
+ this.weight = v.dot(projection);
+ }
+
+ public static WeightedVector project(Vector v, Vector projection) {
+ return project(v, projection, INVALID_INDEX);
+ }
+
+ public static WeightedVector project(Vector v, Vector projection, int index) {
+ return new WeightedVector(v, projection, index);
+ }
+
+ public double getWeight() {
+ return weight;
+ }
+
+ public int getIndex() {
+ return index;
+ }
+
+ public void setWeight(double newWeight) {
+ this.weight = newWeight;
+ }
+
+ public void setIndex(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public Vector like() {
+ return new WeightedVector(getVector().like(), weight, index);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
+ }
+
+ @Override
+ public WeightedVector clone() {
+ WeightedVector v = (WeightedVector)super.clone();
+ v.weight = weight;
+ v.index = index;
+ return v;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java b/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java
new file mode 100644
index 0000000..9fdd621
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.Serializable;
+import java.util.Comparator;
+
+/**
+ * Orders {@link WeightedVector} by {@link WeightedVector#getWeight()}.
+ */
+public final class WeightedVectorComparator implements Comparator<WeightedVector>, Serializable {
+
+ private static final double DOUBLE_EQUALITY_ERROR = 1.0e-8;
+
+ @Override
+ public int compare(WeightedVector a, WeightedVector b) {
+ if (a == b) {
+ return 0;
+ }
+ double aWeight = a.getWeight();
+ double bWeight = b.getWeight();
+ int r = Double.compare(aWeight, bWeight);
+ if (r != 0 && Math.abs(aWeight - bWeight) >= DOUBLE_EQUALITY_ERROR) {
+ return r;
+ }
+ double diff = a.minus(b).norm(1);
+ if (diff < 1.0e-12) {
+ return 0;
+ }
+ for (Vector.Element element : a.all()) {
+ r = Double.compare(element.get(), b.get(element.index()));
+ if (r != 0) {
+ return r;
+ }
+ }
+ return 0;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java b/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
new file mode 100644
index 0000000..dbe1f8b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
@@ -0,0 +1,116 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+
+/**
+ * See
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * this paper.</a>
+ */
+public final class AlternatingLeastSquaresSolver {
+
+ private AlternatingLeastSquaresSolver() {}
+
+ //TODO make feature vectors a simple array
+ public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) {
+
+ Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
+ Preconditions.checkArgument(!Iterables.isEmpty(featureVectors));
+ Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");
+ Preconditions.checkArgument(ratingVector.getNumNondefaultElements() > 0, "Rating Vector cannot be empty");
+ Preconditions.checkArgument(Iterables.size(featureVectors) == ratingVector.getNumNondefaultElements());
+
+ int nui = ratingVector.getNumNondefaultElements();
+
+ Matrix MiIi = createMiIi(featureVectors, numFeatures);
+ Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);
+
+ /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
+ Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui);
+ /* compute Vi = MiIi * t(R(i,Ii)) */
+ Matrix Vi = MiIi.times(RiIiMaybeTransposed);
+ /* compute Ai * ui = Vi */
+ return solve(Ai, Vi);
+ }
+
+ private static Vector solve(Matrix Ai, Matrix Vi) {
+ return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
+ }
+
+ static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
+ Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
+ double lambdaTimesNui = lambda * nui;
+ int numCols = matrix.numCols();
+ for (int n = 0; n < numCols; n++) {
+ matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
+ }
+ return matrix;
+ }
+
+ private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) {
+
+ double lambdaTimesNui = lambda * nui;
+ int rows = MiIi.numRows();
+
+ double[][] result = new double[rows][rows];
+
+ for (int i = 0; i < rows; i++) {
+ for (int j = i; j < rows; j++) {
+ double dot = MiIi.viewRow(i).dot(MiIi.viewRow(j));
+ if (i != j) {
+ result[i][j] = dot;
+ result[j][i] = dot;
+ } else {
+ result[i][i] = dot + lambdaTimesNui;
+ }
+ }
+ }
+ return new DenseMatrix(result, true);
+ }
+
+
+ static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) {
+ double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)];
+ int n = 0;
+ for (Vector featureVector : featureVectors) {
+ for (int m = 0; m < numFeatures; m++) {
+ MiIi[m][n] = featureVector.getQuick(m);
+ }
+ n++;
+ }
+ return new DenseMatrix(MiIi, true);
+ }
+
+ static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
+ Preconditions.checkArgument(ratingVector.isSequentialAccess(), "Ratings should be iterable in Index or Sequential Order");
+
+ double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1];
+ int index = 0;
+ for (Vector.Element elem : ratingVector.nonZeroes()) {
+ RiIiMaybeTransposed[index++][0] = elem.get();
+ }
+ return new DenseMatrix(RiIiMaybeTransposed, true);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java b/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
new file mode 100644
index 0000000..28bf4b4
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/** see <a href="http://research.yahoo.com/pub/2433">Collaborative Filtering for Implicit Feedback Datasets</a> */
+public class ImplicitFeedbackAlternatingLeastSquaresSolver {
+
+ private final int numFeatures;
+ private final double alpha;
+ private final double lambda;
+ private final int numTrainingThreads;
+
+ private final OpenIntObjectHashMap<Vector> Y;
+ private final Matrix YtransposeY;
+
+ private static final Logger log = LoggerFactory.getLogger(ImplicitFeedbackAlternatingLeastSquaresSolver.class);
+
+ public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double lambda, double alpha,
+ OpenIntObjectHashMap<Vector> Y, int numTrainingThreads) {
+ this.numFeatures = numFeatures;
+ this.lambda = lambda;
+ this.alpha = alpha;
+ this.Y = Y;
+ this.numTrainingThreads = numTrainingThreads;
+ YtransposeY = getYtransposeY(Y);
+ }
+
+ public Vector solve(Vector ratings) {
+ return solve(YtransposeY.plus(getYtransponseCuMinusIYPlusLambdaI(ratings)), getYtransponseCuPu(ratings));
+ }
+
+ private static Vector solve(Matrix A, Matrix y) {
+ return new QRDecomposition(A).solve(y).viewColumn(0);
+ }
+
+ double confidence(double rating) {
+ return 1 + alpha * rating;
+ }
+
+ /* Y' Y */
+ public Matrix getYtransposeY(final OpenIntObjectHashMap<Vector> Y) {
+
+ ExecutorService queue = Executors.newFixedThreadPool(numTrainingThreads);
+ if (log.isInfoEnabled()) {
+ log.info("Starting the computation of Y'Y");
+ }
+ long startTime = System.nanoTime();
+ final IntArrayList indexes = Y.keys();
+ final int numIndexes = indexes.size();
+
+ final double[][] YtY = new double[numFeatures][numFeatures];
+
+ // Compute Y'Y by dot products between the 'columns' of Y
+ for (int i = 0; i < numFeatures; i++) {
+ for (int j = i; j < numFeatures; j++) {
+
+ final int ii = i;
+ final int jj = j;
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ double dot = 0;
+ for (int k = 0; k < numIndexes; k++) {
+ Vector row = Y.get(indexes.getQuick(k));
+ dot += row.getQuick(ii) * row.getQuick(jj);
+ }
+ YtY[ii][jj] = dot;
+ if (ii != jj) {
+ YtY[jj][ii] = dot;
+ }
+ }
+ });
+
+ }
+ }
+ queue.shutdown();
+ try {
+ queue.awaitTermination(1, TimeUnit.DAYS);
+ } catch (InterruptedException e) {
+ log.error("Error during Y'Y queue shutdown", e);
+ throw new RuntimeException("Error during Y'Y queue shutdown");
+ }
+ if (log.isInfoEnabled()) {
+ log.info("Computed Y'Y in " + (System.nanoTime() - startTime) / 1000000.0 + " ms" );
+ }
+ return new DenseMatrix(YtY, true);
+ }
+
+ /** Y' (Cu - I) Y + λ I */
+ private Matrix getYtransponseCuMinusIYPlusLambdaI(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!");
+
+ /* (Cu -I) Y */
+ OpenIntObjectHashMap<Vector> CuMinusIY = new OpenIntObjectHashMap<>(userRatings.getNumNondefaultElements());
+ for (Element e : userRatings.nonZeroes()) {
+ CuMinusIY.put(e.index(), Y.get(e.index()).times(confidence(e.get()) - 1));
+ }
+
+ Matrix YtransponseCuMinusIY = new DenseMatrix(numFeatures, numFeatures);
+
+ /* Y' (Cu -I) Y by outer products */
+ for (Element e : userRatings.nonZeroes()) {
+ for (Vector.Element feature : Y.get(e.index()).all()) {
+ Vector partial = CuMinusIY.get(e.index()).times(feature.get());
+ YtransponseCuMinusIY.viewRow(feature.index()).assign(partial, Functions.PLUS);
+ }
+ }
+
+ /* Y' (Cu - I) Y + λ I add lambda on the diagonal */
+ for (int feature = 0; feature < numFeatures; feature++) {
+ YtransponseCuMinusIY.setQuick(feature, feature, YtransponseCuMinusIY.getQuick(feature, feature) + lambda);
+ }
+
+ return YtransponseCuMinusIY;
+ }
+
+ /** Y' Cu p(u) */
+ private Matrix getYtransponseCuPu(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!");
+
+ Vector YtransponseCuPu = new DenseVector(numFeatures);
+
+ for (Element e : userRatings.nonZeroes()) {
+ YtransponseCuPu.assign(Y.get(e.index()).times(confidence(e.get())), Functions.PLUS);
+ }
+
+ return columnVectorAsMatrix(YtransponseCuPu);
+ }
+
+ private Matrix columnVectorAsMatrix(Vector v) {
+ double[][] matrix = new double[numFeatures][1];
+ for (Vector.Element e : v.all()) {
+ matrix[e.index()][0] = e.get();
+ }
+ return new DenseMatrix(matrix, true);
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java b/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
new file mode 100644
index 0000000..0233848
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+import java.io.Closeable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+/**
+ * A {@link SimpleEigenVerifier} that performs verification on a single background thread, so {@link #verify}
+ * never blocks on the computation.
+ * <p>
+ * When no verification is running, {@link #verify} starts one on a copy of {@code vector} and returns a
+ * placeholder status (eigenvalue -1, cosine 0, in progress). The placeholder keeps being returned until the
+ * background run completes. The first call after completion returns its result, and the call after that starts a
+ * new verification.
+ * <p>
+ * The worker thread is a daemon, so a verifier that is never closed does not keep the JVM alive; call
+ * {@link #close()} to stop it explicitly.
+ */
+public class AsyncEigenVerifier extends SimpleEigenVerifier implements Closeable {
+
+  private final ExecutorService threadPool;
+  private EigenStatus status;
+  private boolean finished;
+  private boolean started;
+
+  public AsyncEigenVerifier() {
+    threadPool = Executors.newFixedThreadPool(1, new DaemonThreadFactory());
+    status = new EigenStatus(-1, 0);
+  }
+
+  @Override
+  public synchronized EigenStatus verify(VectorIterable corpus, Vector vector) {
+    if (!finished && !started) { // not yet started or finished, so start!
+      status = new EigenStatus(-1, 0);
+      // Copy so the caller may keep mutating its vector while verification runs.
+      Vector vectorCopy = vector.clone();
+      threadPool.execute(new VerifierRunnable(corpus, vectorCopy));
+      started = true;
+    }
+    if (finished) {
+      // Hand out the finished result once; the next call starts a fresh verification.
+      finished = false;
+    }
+    return status;
+  }
+
+  /** Stops the background thread, interrupting any verification in progress. */
+  @Override
+  public void close() {
+    this.threadPool.shutdownNow();
+  }
+
+  /** Runs the synchronous {@link SimpleEigenVerifier} check; invoked on the background thread. */
+  protected EigenStatus innerVerify(VectorIterable corpus, Vector vector) {
+    return super.verify(corpus, vector);
+  }
+
+  /**
+   * Wraps the default thread factory and marks every thread as a daemon. Callers may never close the verifier
+   * (HebbianSolver's convenience constructors create one internally), and a non-daemon pool thread would then
+   * prevent the JVM from exiting.
+   */
+  private static final class DaemonThreadFactory implements java.util.concurrent.ThreadFactory {
+    private final java.util.concurrent.ThreadFactory delegate = Executors.defaultThreadFactory();
+
+    @Override
+    public Thread newThread(Runnable runnable) {
+      Thread thread = delegate.newThread(runnable);
+      thread.setDaemon(true);
+      return thread;
+    }
+  }
+
+  /** Background task: verifies one vector and publishes the result under the verifier's lock. */
+  private class VerifierRunnable implements Runnable {
+    private final VectorIterable corpus;
+    private final Vector vector;
+
+    protected VerifierRunnable(VectorIterable corpus, Vector vector) {
+      this.corpus = corpus;
+      this.vector = vector;
+    }
+
+    @Override
+    public void run() {
+      EigenStatus result = innerVerify(corpus, vector);
+      synchronized (AsyncEigenVerifier.this) {
+        AsyncEigenVerifier.this.status = result;
+        finished = true;
+        started = false;
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java b/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
new file mode 100644
index 0000000..a284f50
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+/**
+ * Result of checking a candidate eigenvector: the estimated eigenvalue, the cosine-angle measure reported by the
+ * verifier, and whether the check is still in progress.
+ */
+public class EigenStatus {
+  private final double eigenValue;
+  private final double cosAngle;
+  // Primitive boolean: a boxed Boolean only added autoboxing on every read and write. Volatile so that updates
+  // through setInProgress are visible across threads.
+  private volatile boolean inProgress;
+
+  /** Creates a status that is still in progress. */
+  public EigenStatus(double eigenValue, double cosAngle) {
+    this(eigenValue, cosAngle, true);
+  }
+
+  public EigenStatus(double eigenValue, double cosAngle, boolean inProgress) {
+    this.eigenValue = eigenValue;
+    this.cosAngle = cosAngle;
+    this.inProgress = inProgress;
+  }
+
+  public double getCosAngle() {
+    return cosAngle;
+  }
+
+  public double getEigenValue() {
+    return eigenValue;
+  }
+
+  public boolean inProgress() {
+    return inProgress;
+  }
+
+  void setInProgress(boolean status) {
+    inProgress = status;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java b/core/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
new file mode 100644
index 0000000..71aaa30
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.decomposer;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+/** Synchronous {@link SingularVectorVerifier}. */
+public class SimpleEigenVerifier implements SingularVectorVerifier {
+
+  /**
+   * Computes {@code w = corpus.timesSquared(vector)} (presumably A'A times the vector — confirm against
+   * {@code VectorIterable}) and compares it with {@code vector}.
+   *
+   * @return a completed status whose eigenvalue is {@code |w| / |vector|} and whose cosine is
+   *     {@code (w . vector) / (|w| * |vector|)}; if either norm is zero, eigenvalue 1.0 and cosine 0.0
+   */
+  @Override
+  public EigenStatus verify(VectorIterable corpus, Vector vector) {
+    Vector resultantVector = corpus.timesSquared(vector);
+    double newNorm = resultantVector.norm(2);
+    double oldNorm = vector.norm(2);
+    double eigenValue;
+    double cosAngle;
+    if (newNorm > 0 && oldNorm > 0) {
+      eigenValue = newNorm / oldNorm;
+      // The parentheses matter: without them this parses as (dot / newNorm) * oldNorm, which equals the cosine
+      // only when vector has unit length.
+      cosAngle = resultantVector.dot(vector) / (newNorm * oldNorm);
+    } else {
+      eigenValue = 1.0;
+      cosAngle = 0.0;
+    }
+    return new EigenStatus(eigenValue, cosAngle, false);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java b/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
new file mode 100644
index 0000000..a9a7af8
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+/**
+ * Estimates how close a candidate vector is to a singular (eigen) vector of a corpus.
+ */
+public interface SingularVectorVerifier {
+
+  /**
+   * Checks {@code vector} against {@code eigenMatrix}.
+   *
+   * @param eigenMatrix the corpus the candidate is checked against
+   * @param vector the candidate vector
+   * @return the resulting status (eigenvalue estimate, cosine-angle measure, in-progress flag)
+   */
+  EigenStatus verify(VectorIterable eigenMatrix, Vector vector);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
new file mode 100644
index 0000000..ac9cc41
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import org.apache.mahout.math.Vector;
+
+
+/**
+ * Strategy for refining a candidate eigenvector using one training vector at a time.
+ */
+public interface EigenUpdater {
+
+  /**
+   * Applies one update step.
+   *
+   * @param pseudoEigen the current estimate; expected to be modified in place (the method returns nothing and
+   *     HebbianSolver keeps using the same vector instance)
+   * @param trainingVector the corpus row being presented
+   * @param currentState solver state shared across update steps
+   */
+  void update(Vector pseudoEigen, Vector trainingVector, TrainingState currentState);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
new file mode 100644
index 0000000..5b5cc9b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
@@ -0,0 +1,342 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
+import org.apache.mahout.math.decomposer.EigenStatus;
+import org.apache.mahout.math.decomposer.SingularVectorVerifier;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.function.TimesFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The Hebbian solver is an iterative, sparse, singular value decomposition solver, based on the paper
+ * <a href="http://www.dcs.shef.ac.uk/~genevieve/gorrell_webb.pdf">Generalized Hebbian Algorithm for
+ * Latent Semantic Analysis</a> (2005) by Genevieve Gorrell and Brandyn Webb (a.k.a. Simon Funk).
+ * TODO: more description here! For now: read the inline comments, and the comments for the constructors.
+ */
+public class HebbianSolver {
+
+ // Logger for progress messages.
+ private static final Logger log = LoggerFactory.getLogger(HebbianSolver.class);
+ // Compile-time switch for extra diagnostics; being false, the guarded code is compiled out.
+ private static final boolean DEBUG = false;
+
+ // Performs the per-corpus-row update of the current eigenvector estimate.
+ private final EigenUpdater updater;
+ // Checks how close the current estimate is to convergence.
+ private final SingularVectorVerifier verifier;
+ // Convergence tolerance (see the constructor documentation).
+ private final double convergenceTarget;
+ // Upper bound on convergence checks per eigenvector before the solver moves on.
+ private final int maxPassesPerEigen;
+ // Source of randomness for picking the starting corpus row of each pass.
+ private final Random rng = RandomUtils.getRandom();
+
+ // Convergence checks made for the current eigenvector; incremented in hasNotConverged, reset in solve.
+ private int numPasses = 0;
+
+  /**
+   * Creates a solver from explicit components.
+   *
+   * @param updater {@link EigenUpdater} that refines the current "best guess" singular vector one data point at
+   *     a time
+   * @param verifier {@link SingularVectorVerifier} that checks how close the current singular vector is to
+   *     convergence; typically an {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier}, which checks on a
+   *     background thread while the main thread keeps converging
+   * @param convergenceTarget small "epsilon" tolerance for the angle between a proposed eigenvector and that
+   *     vector after multiplication by the (square of the) input corpus
+   * @param maxPassesPerEigen how many convergence checks (done by the verifier) to make before moving on to the
+   *     next vector, even if {@code convergenceTarget} has not been reached
+   */
+  public HebbianSolver(EigenUpdater updater, SingularVectorVerifier verifier, double convergenceTarget,
+      int maxPassesPerEigen) {
+    this.maxPassesPerEigen = maxPassesPerEigen;
+    this.convergenceTarget = convergenceTarget;
+    this.verifier = verifier;
+    this.updater = updater;
+  }
+
+  /**
+   * Like {@link #HebbianSolver(EigenUpdater, SingularVectorVerifier, double, int)} with an unlimited number of
+   * passes per vector ({@code Integer.MAX_VALUE}), i.e. iterate until {@code convergenceTarget} is reached.
+   * <b>Not recommended</b> unless only the first few (5, maybe 10?) singular vectors are needed: small errors
+   * that compound early on put a floor on the error of later vectors.
+   *
+   * @param updater {@link EigenUpdater} that refines the current singular vector one data point at a time
+   * @param verifier {@link org.apache.mahout.math.decomposer.SingularVectorVerifier} that checks convergence,
+   *     typically an {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} running in the background
+   * @param convergenceTarget small "epsilon" tolerance for the angle between a proposed eigenvector and that
+   *     vector after multiplication by the (square of the) input corpus
+   */
+  public HebbianSolver(EigenUpdater updater, SingularVectorVerifier verifier, double convergenceTarget) {
+    this(updater, verifier, convergenceTarget, Integer.MAX_VALUE);
+  }
+
+  /**
+   * <b>This is the recommended constructor to use if you're not sure.</b>
+   * Uses a default {@link HebbianUpdater} for the updates and a default
+   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} to check convergence on a single background
+   * thread.
+   *
+   * @param convergenceTarget small "epsilon" tolerance for the angle between a proposed eigenvector and that
+   *     vector after multiplication by the (square of the) input corpus
+   * @param maxPassesPerEigen how many convergence checks (done by the verifier) to make before moving on, even
+   *     if {@code convergenceTarget} has not been reached
+   */
+  public HebbianSolver(double convergenceTarget, int maxPassesPerEigen) {
+    this(new HebbianUpdater(), new AsyncEigenVerifier(), convergenceTarget, maxPassesPerEigen);
+  }
+
+  /**
+   * Uses a default {@link HebbianUpdater} and a default
+   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} (single background thread), with an unlimited
+   * number of passes per vector ({@code Integer.MAX_VALUE}). <b>Not recommended</b> unless only the first few
+   * (5, maybe 10?) singular vectors are needed: small errors that compound early on put a floor on the error of
+   * later vectors.
+   *
+   * @param convergenceTarget small "epsilon" tolerance for the angle between a proposed eigenvector and that
+   *     vector after multiplication by the (square of the) input corpus
+   */
+  public HebbianSolver(double convergenceTarget) {
+    this(convergenceTarget, Integer.MAX_VALUE);
+  }
+
+  /**
+   * Uses a default {@link HebbianUpdater} and a default
+   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} (single background thread), with
+   * {@code convergenceTarget = 0}. Convergence then does not act as the loop-exit criterion, but it is still
+   * checked, so progress is logged and singular values are recorded.
+   *
+   * @param numPassesPerEigen the exact number of convergence checks the verifier makes before the solver moves on
+   *     to the next eigenvector
+   */
+  public HebbianSolver(int numPassesPerEigen) {
+    this(0.0, numPassesPerEigen);
+  }
+
+  /**
+   * Primary singular vector solving method.
+   * <p>
+   * For each of {@code desiredRank} vectors, this repeatedly sweeps the corpus until {@link #hasNotConverged}
+   * returns false. Each sweep starts from a randomly chosen row and then presents every other row in order to
+   * the {@link EigenUpdater}. The resulting vector is normalized to unit length and stored as the next row of
+   * the eigen matrix, together with the eigenvalue from the most recent verification status.
+   *
+   * @param corpus input matrix to find singular vectors of. It need not be symmetric and should probably be
+   *     sparse: its rows are not mutated and are accessed only via dot products and sums, so they should be
+   *     {@link org.apache.mahout.math.SequentialAccessSparseVector}s
+   * @param desiredRank the number of singular vectors to find (in roughly decreasing order by singular value)
+   * @return the final {@link TrainingState} of the solver, after {@code desiredRank} singular vectors (and
+   *     approximate singular values) have been found
+   */
+  public TrainingState solve(Matrix corpus,
+                             int desiredRank) {
+    int cols = corpus.numCols();
+    Matrix eigens = new DenseMatrix(desiredRank, cols);
+    List<Double> eigenValues = new ArrayList<>();
+    log.info("Finding {} singular vectors of matrix with {} rows, via Hebbian", desiredRank, corpus.numRows());
+    /*
+     * The corpusProjections matrix is a running cache of the residual projection of each corpus vector against all
+     * of the previously found singular vectors. Without this, if multiple passes over the data are made (per
+     * singular vector), recalculating these projections eventually dominates the computational complexity of the
+     * solver.
+     */
+    Matrix corpusProjections = new DenseMatrix(corpus.numRows(), desiredRank);
+    TrainingState state = new TrainingState(eigens, corpusProjections);
+    for (int i = 0; i < desiredRank; i++) {
+      Vector currentEigen = new DenseVector(cols);
+      while (hasNotConverged(currentEigen, corpus, state)) {
+        int randomStartingIndex = getRandomStartingIndex(corpus, eigens);
+        Vector initialTrainingVector = corpus.viewRow(randomStartingIndex);
+        state.setTrainingIndex(randomStartingIndex);
+        updater.update(currentEigen, initialTrainingVector, state);
+        for (int corpusRow = 0; corpusRow < corpus.numRows(); corpusRow++) {
+          state.setTrainingIndex(corpusRow);
+          // The starting row was already presented above.
+          if (corpusRow != randomStartingIndex) {
+            updater.update(currentEigen, corpus.viewRow(corpusRow), state);
+          }
+        }
+        state.setFirstPass(false);
+      }
+      // Converged (or out of passes): take the eigenvalue from the most recent verification status.
+      double eigenValue = state.getStatusProgress().get(state.getStatusProgress().size() - 1).getEigenValue();
+      // Normalizing in place is cheaper than currentEigen = currentEigen.normalize(), which would clone.
+      currentEigen.assign(new TimesFunction(), 1 / currentEigen.norm(2));
+      eigens.assignRow(i, currentEigen);
+      eigenValues.add(eigenValue);
+      state.setCurrentEigenValues(eigenValues);
+      log.info("Found eigenvector {}, eigenvalue: {}", i, eigenValue);
+
+      // TODO: Persist intermediate output!
+      // Reset the per-eigenvector state before solving for the next vector.
+      state.setFirstPass(true);
+      state.setNumEigensProcessed(state.getNumEigensProcessed() + 1);
+      state.setActivationDenominatorSquared(0);
+      state.setActivationNumerator(0);
+      state.getStatusProgress().clear();
+      numPasses = 0;
+    }
+    return state;
+  }
+
+ /**
+ * You have to start somewhere...
+ * TODO: start instead wherever you find a vector with maximum residual length after subtracting off the projection
+ * TODO: onto all previous eigenvectors.
+ *
+ * @param corpus the corpus matrix
+ * @param eigens not currently used, but should be (see above TODO)
+ * @return the index into the corpus where the "starting seed" input vector lies.
+ */
+ private int getRandomStartingIndex(Matrix corpus, Matrix eigens) {
+ int index;
+ Vector v;
+ do {
+ double r = rng.nextDouble();
+ index = (int) (r * corpus.numRows());
+ v = corpus.viewRow(index);
+ } while (v == null || v.norm(2) == 0 || v.getNumNondefaultElements() < 5);
+ return index;
+ }
+
+ /**
+ * Uses the {@link SingularVectorVerifier } to check for convergence
+ *
+ * @param currentPseudoEigen the purported singular vector whose convergence is being checked
+ * @param corpus the corpus to check against
+ * @param state contains the previous eigens, various other solving state {@link TrainingState}
+ * @return true if <em>either</em> we have converged, <em>or</em> maxPassesPerEigen has been exceeded.
+ */
+ protected boolean hasNotConverged(Vector currentPseudoEigen,
+ Matrix corpus,
+ TrainingState state) {
+ numPasses++;
+ if (state.isFirstPass()) {
+ log.info("First pass through the corpus, no need to check convergence...");
+ return true;
+ }
+ Matrix previousEigens = state.getCurrentEigens();
+ log.info("Have made {} passes through the corpus, checking convergence...", numPasses);
+ /*
+ * Step 1: orthogonalize currentPseudoEigen by subtracting off eigen(i) * helper.get(i)
+ * Step 2: zero-out the helper vector because it has already helped.
+ */
+ for (int i = 0; i < state.getNumEigensProcessed(); i++) {
+ Vector previousEigen = previousEigens.viewRow(i);
+ currentPseudoEigen.assign(previousEigen, new PlusMult(-state.getHelperVector().get(i)));
+ state.getHelperVector().set(i, 0);
+ }
+ if (currentPseudoEigen.norm(2) > 0) {
+ for (int i = 0; i < state.getNumEigensProcessed(); i++) {
+ Vector previousEigen = previousEigens.viewRow(i);
+ log.info("dot with previous: {}", previousEigen.dot(currentPseudoEigen) / currentPseudoEigen.norm(2));
+ }
+ }
+ /*
+ * Step 3: verify how eigen-like the prospective eigen is. This is potentially asynchronous.
+ */
+ EigenStatus status = verify(corpus, currentPseudoEigen);
+ if (status.inProgress()) {
+ log.info("Verifier not finished, making another pass...");
+ } else {
+ log.info("Has 1 - cosAngle: {}, convergence target is: {}", 1.0 - status.getCosAngle(), convergenceTarget);
+ state.getStatusProgress().add(status);
+ }
+ return
+ state.getStatusProgress().size() <= maxPassesPerEigen
+ && 1.0 - status.getCosAngle() > convergenceTarget;
+ }
+
+ protected EigenStatus verify(Matrix corpus, Vector currentPseudoEigen) {
+ return verifier.verify(corpus, currentPseudoEigen);
+ }
+
+ public static void main(String[] args) {
+ Properties props = new Properties();
+ String propertiesFile = args.length > 0 ? args[0] : "config/solver.properties";
+ // props.load(new FileInputStream(propertiesFile));
+
+ String corpusDir = props.getProperty("solver.input.dir");
+ String outputDir = props.getProperty("solver.output.dir");
+ if (corpusDir == null || corpusDir.isEmpty() || outputDir == null || outputDir.isEmpty()) {
+ log.error("{} must contain values for solver.input.dir and solver.output.dir", propertiesFile);
+ return;
+ }
+ //int inBufferSize = Integer.parseInt(props.getProperty("solver.input.bufferSize"));
+ int rank = Integer.parseInt(props.getProperty("solver.output.desiredRank"));
+ double convergence = Double.parseDouble(props.getProperty("solver.convergence"));
+ int maxPasses = Integer.parseInt(props.getProperty("solver.maxPasses"));
+ //int numThreads = Integer.parseInt(props.getProperty("solver.verifier.numThreads"));
+
+ HebbianUpdater updater = new HebbianUpdater();
+ SingularVectorVerifier verifier = new AsyncEigenVerifier();
+ HebbianSolver solver = new HebbianSolver(updater, verifier, convergence, maxPasses);
+ Matrix corpus = null;
+ /*
+ if (numThreads <= 1) {
+ // corpus = new DiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize);
+ } else {
+ // corpus = new ParallelMultiplyingDiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize, numThreads);
+ }
+ */
+ long now = System.currentTimeMillis();
+ TrainingState finalState = solver.solve(corpus, rank);
+ long time = (System.currentTimeMillis() - now) / 1000;
+ log.info("Solved {} eigenVectors in {} seconds. Persisted to {}",
+ finalState.getCurrentEigens().rowSize(), time, outputDir);
+ }
+
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java
new file mode 100644
index 0000000..2080c3a
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.PlusMult;
+
+public class HebbianUpdater implements EigenUpdater {
+
+ @Override
+ public void update(Vector pseudoEigen,
+ Vector trainingVector,
+ TrainingState currentState) {
+ double trainingVectorNorm = trainingVector.norm(2);
+ int numPreviousEigens = currentState.getNumEigensProcessed();
+ if (numPreviousEigens > 0 && currentState.isFirstPass()) {
+ updateTrainingProjectionsVector(currentState, trainingVector, numPreviousEigens - 1);
+ }
+ if (currentState.getActivationDenominatorSquared() == 0 || trainingVectorNorm == 0) {
+ if (currentState.getActivationDenominatorSquared() == 0) {
+ pseudoEigen.assign(trainingVector, new PlusMult(1));
+ currentState.setHelperVector(currentState.currentTrainingProjection().clone());
+ double helperNorm = currentState.getHelperVector().norm(2);
+ currentState.setActivationDenominatorSquared(trainingVectorNorm * trainingVectorNorm - helperNorm * helperNorm);
+ }
+ return;
+ }
+ currentState.setActivationNumerator(pseudoEigen.dot(trainingVector));
+ currentState.setActivationNumerator(
+ currentState.getActivationNumerator()
+ - currentState.getHelperVector().dot(currentState.currentTrainingProjection()));
+
+ double activation = currentState.getActivationNumerator()
+ / Math.sqrt(currentState.getActivationDenominatorSquared());
+ currentState.setActivationDenominatorSquared(
+ currentState.getActivationDenominatorSquared()
+ + 2 * activation * currentState.getActivationNumerator()
+ + activation * activation
+ * (trainingVector.getLengthSquared() - currentState.currentTrainingProjection().getLengthSquared()));
+ if (numPreviousEigens > 0) {
+ currentState.getHelperVector().assign(currentState.currentTrainingProjection(), new PlusMult(activation));
+ }
+ pseudoEigen.assign(trainingVector, new PlusMult(activation));
+ }
+
+ private static void updateTrainingProjectionsVector(TrainingState state,
+ Vector trainingVector,
+ int previousEigenIndex) {
+ Vector previousEigen = state.mostRecentEigen();
+ Vector currentTrainingVectorProjection = state.currentTrainingProjection();
+ double projection = previousEigen.dot(trainingVector);
+ currentTrainingVectorProjection.set(previousEigenIndex, projection);
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java
new file mode 100644
index 0000000..af6c2ef
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java
@@ -0,0 +1,143 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.decomposer.EigenStatus;
+
+public class TrainingState {
+
+ private Matrix currentEigens;
+ private int numEigensProcessed;
+ private List<Double> currentEigenValues;
+ private Matrix trainingProjections;
+ private int trainingIndex;
+ private Vector helperVector;
+ private boolean firstPass;
+ private List<EigenStatus> statusProgress;
+ private double activationNumerator;
+ private double activationDenominatorSquared;
+
+ TrainingState(Matrix eigens, Matrix projections) {
+ currentEigens = eigens;
+ trainingProjections = projections;
+ trainingIndex = 0;
+ helperVector = new DenseVector(eigens.numRows());
+ firstPass = true;
+ statusProgress = new ArrayList<>();
+ activationNumerator = 0;
+ activationDenominatorSquared = 0;
+ numEigensProcessed = 0;
+ }
+
+ public Vector mostRecentEigen() {
+ return currentEigens.viewRow(numEigensProcessed - 1);
+ }
+
+ public Vector currentTrainingProjection() {
+ if (trainingProjections.viewRow(trainingIndex) == null) {
+ trainingProjections.assignRow(trainingIndex, new DenseVector(currentEigens.numCols()));
+ }
+ return trainingProjections.viewRow(trainingIndex);
+ }
+
+ public Matrix getCurrentEigens() {
+ return currentEigens;
+ }
+
+ public void setCurrentEigens(Matrix currentEigens) {
+ this.currentEigens = currentEigens;
+ }
+
+ public int getNumEigensProcessed() {
+ return numEigensProcessed;
+ }
+
+ public void setNumEigensProcessed(int numEigensProcessed) {
+ this.numEigensProcessed = numEigensProcessed;
+ }
+
+ public List<Double> getCurrentEigenValues() {
+ return currentEigenValues;
+ }
+
+ public void setCurrentEigenValues(List<Double> currentEigenValues) {
+ this.currentEigenValues = currentEigenValues;
+ }
+
+ public Matrix getTrainingProjections() {
+ return trainingProjections;
+ }
+
+ public void setTrainingProjections(Matrix trainingProjections) {
+ this.trainingProjections = trainingProjections;
+ }
+
+ public int getTrainingIndex() {
+ return trainingIndex;
+ }
+
+ public void setTrainingIndex(int trainingIndex) {
+ this.trainingIndex = trainingIndex;
+ }
+
+ public Vector getHelperVector() {
+ return helperVector;
+ }
+
+ public void setHelperVector(Vector helperVector) {
+ this.helperVector = helperVector;
+ }
+
+ public boolean isFirstPass() {
+ return firstPass;
+ }
+
+ public void setFirstPass(boolean firstPass) {
+ this.firstPass = firstPass;
+ }
+
+ public List<EigenStatus> getStatusProgress() {
+ return statusProgress;
+ }
+
+ public void setStatusProgress(List<EigenStatus> statusProgress) {
+ this.statusProgress = statusProgress;
+ }
+
+ public double getActivationNumerator() {
+ return activationNumerator;
+ }
+
+ public void setActivationNumerator(double activationNumerator) {
+ this.activationNumerator = activationNumerator;
+ }
+
+ public double getActivationDenominatorSquared() {
+ return activationDenominatorSquared;
+ }
+
+ public void setActivationDenominatorSquared(double activationDenominatorSquared) {
+ this.activationDenominatorSquared = activationDenominatorSquared;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java
new file mode 100644
index 0000000..61a77db
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java
@@ -0,0 +1,213 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.lanczos;
+
+
+import java.util.EnumMap;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.solver.EigenDecomposition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Simple implementation of the <a href="http://en.wikipedia.org/wiki/Lanczos_algorithm">Lanczos algorithm</a> for
+ * finding eigenvalues of a symmetric matrix, applied to non-symmetric matrices by applying Matrix.timesSquared(vector)
+ * as the "matrix-multiplication" method.<p>
+ *
+ * See the SSVD code for a better option
+ * {@link org.apache.mahout.math.ssvd.SequentialBigSvd}
+ * See also the docs on
+ * <a href=https://mahout.apache.org/users/dim-reduction/ssvd.html>stochastic
+ * projection SVD</a>
+ * <p>
+ * To avoid floating point overflow problems which arise in power-methods like Lanczos, an initial pass is made
+ * through the input matrix to
+ * <ul>
+ * <li>generate a good starting seed vector by summing all the rows of the input matrix, and</li>
+ * <li>compute the trace(inputMatrix<sup>t</sup>*matrix)
+ * </ul>
+ * <p>
+ * This latter value, being the sum of all of the singular values, is used to rescale the entire matrix, effectively
+ * forcing the largest singular value to be strictly less than one, and transforming floating point <em>overflow</em>
+ * problems into floating point <em>underflow</em> (ie, very small singular values will become invisible, as they
+ * will appear to be zero and the algorithm will terminate).
+ * <p>This implementation uses {@link EigenDecomposition} to do the
+ * eigenvalue extraction from the small (desiredRank x desiredRank) tridiagonal matrix. Numerical stability is
+ * achieved via brute-force: re-orthogonalization against all previous eigenvectors is computed after every pass.
+ * This can be made smarter if (when!) this proves to be a major bottleneck. Of course, this step can be parallelized
+ * as well.
+ * @see org.apache.mahout.math.ssvd.SequentialBigSvd
+ */
+@Deprecated
+public class LanczosSolver {
+
+ private static final Logger log = LoggerFactory.getLogger(LanczosSolver.class);
+
+ public static final double SAFE_MAX = 1.0e150;
+
+ public enum TimingSection {
+ ITERATE, ORTHOGANLIZE, TRIDIAG_DECOMP, FINAL_EIGEN_CREATE
+ }
+
+ private final Map<TimingSection, Long> startTimes = new EnumMap<>(TimingSection.class);
+ private final Map<TimingSection, Long> times = new EnumMap<>(TimingSection.class);
+
+ private static final class Scale extends DoubleFunction {
+ private final double d;
+
+ private Scale(double d) {
+ this.d = d;
+ }
+
+ @Override
+ public double apply(double arg1) {
+ return arg1 * d;
+ }
+ }
+
+ public void solve(LanczosState state,
+ int desiredRank) {
+ solve(state, desiredRank, false);
+ }
+
+ public void solve(LanczosState state,
+ int desiredRank,
+ boolean isSymmetric) {
+ VectorIterable corpus = state.getCorpus();
+ log.info("Finding {} singular vectors of matrix with {} rows, via Lanczos",
+ desiredRank, corpus.numRows());
+ int i = state.getIterationNumber();
+ Vector currentVector = state.getBasisVector(i - 1);
+ Vector previousVector = state.getBasisVector(i - 2);
+ double beta = 0;
+ Matrix triDiag = state.getDiagonalMatrix();
+ while (i < desiredRank) {
+ startTime(TimingSection.ITERATE);
+ Vector nextVector = isSymmetric ? corpus.times(currentVector) : corpus.timesSquared(currentVector);
+ log.info("{} passes through the corpus so far...", i);
+ if (state.getScaleFactor() <= 0) {
+ state.setScaleFactor(calculateScaleFactor(nextVector));
+ }
+ nextVector.assign(new Scale(1.0 / state.getScaleFactor()));
+ if (previousVector != null) {
+ nextVector.assign(previousVector, new PlusMult(-beta));
+ }
+ // now orthogonalize
+ double alpha = currentVector.dot(nextVector);
+ nextVector.assign(currentVector, new PlusMult(-alpha));
+ endTime(TimingSection.ITERATE);
+ startTime(TimingSection.ORTHOGANLIZE);
+ orthoganalizeAgainstAllButLast(nextVector, state);
+ endTime(TimingSection.ORTHOGANLIZE);
+ // and normalize
+ beta = nextVector.norm(2);
+ if (outOfRange(beta) || outOfRange(alpha)) {
+ log.warn("Lanczos parameters out of range: alpha = {}, beta = {}. Bailing out early!",
+ alpha, beta);
+ break;
+ }
+ nextVector.assign(new Scale(1 / beta));
+ state.setBasisVector(i, nextVector);
+ previousVector = currentVector;
+ currentVector = nextVector;
+ // save the projections and norms!
+ triDiag.set(i - 1, i - 1, alpha);
+ if (i < desiredRank - 1) {
+ triDiag.set(i - 1, i, beta);
+ triDiag.set(i, i - 1, beta);
+ }
+ state.setIterationNumber(++i);
+ }
+ startTime(TimingSection.TRIDIAG_DECOMP);
+
+ log.info("Lanczos iteration complete - now to diagonalize the tri-diagonal auxiliary matrix.");
+ // at this point, have tridiag all filled out, and basis is all filled out, and orthonormalized
+ EigenDecomposition decomp = new EigenDecomposition(triDiag);
+
+ Matrix eigenVects = decomp.getV();
+ Vector eigenVals = decomp.getRealEigenvalues();
+ endTime(TimingSection.TRIDIAG_DECOMP);
+ startTime(TimingSection.FINAL_EIGEN_CREATE);
+ for (int row = 0; row < i; row++) {
+ Vector realEigen = null;
+
+ Vector ejCol = eigenVects.viewColumn(row);
+ int size = Math.min(ejCol.size(), state.getBasisSize());
+ for (int j = 0; j < size; j++) {
+ double d = ejCol.get(j);
+ Vector rowJ = state.getBasisVector(j);
+ if (realEigen == null) {
+ realEigen = rowJ.like();
+ }
+ realEigen.assign(rowJ, new PlusMult(d));
+ }
+
+ Preconditions.checkState(realEigen != null);
+ assert realEigen != null;
+
+ realEigen = realEigen.normalize();
+ state.setRightSingularVector(row, realEigen);
+ double e = eigenVals.get(row) * state.getScaleFactor();
+ if (!isSymmetric) {
+ e = Math.sqrt(e);
+ }
+ log.info("Eigenvector {} found with eigenvalue {}", row, e);
+ state.setSingularValue(row, e);
+ }
+ log.info("LanczosSolver finished.");
+ endTime(TimingSection.FINAL_EIGEN_CREATE);
+ }
+
+ protected static double calculateScaleFactor(Vector nextVector) {
+ return nextVector.norm(2);
+ }
+
+ private static boolean outOfRange(double d) {
+ return Double.isNaN(d) || d > SAFE_MAX || -d > SAFE_MAX;
+ }
+
+ protected static void orthoganalizeAgainstAllButLast(Vector nextVector, LanczosState state) {
+ for (int i = 0; i < state.getIterationNumber(); i++) {
+ Vector basisVector = state.getBasisVector(i);
+ double alpha;
+ if (basisVector == null || (alpha = nextVector.dot(basisVector)) == 0.0) {
+ continue;
+ }
+ nextVector.assign(basisVector, new PlusMult(-alpha));
+ }
+ }
+
+ private void startTime(TimingSection section) {
+ startTimes.put(section, System.nanoTime());
+ }
+
+ private void endTime(TimingSection section) {
+ if (!times.containsKey(section)) {
+ times.put(section, 0L);
+ }
+ times.put(section, times.get(section) + System.nanoTime() - startTimes.get(section));
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java
new file mode 100644
index 0000000..2ba34bd
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.lanczos;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+import java.util.Map;
+
+@Deprecated
+public class LanczosState {
+
+ protected Matrix diagonalMatrix;
+ protected final VectorIterable corpus;
+ protected double scaleFactor;
+ protected int iterationNumber;
+ protected final int desiredRank;
+ protected Map<Integer, Vector> basis;
+ protected final Map<Integer, Double> singularValues;
+ protected Map<Integer, Vector> singularVectors;
+
+ public LanczosState(VectorIterable corpus, int desiredRank, Vector initialVector) {
+ this.corpus = corpus;
+ this.desiredRank = desiredRank;
+ intitializeBasisAndSingularVectors();
+ setBasisVector(0, initialVector);
+ scaleFactor = 0;
+ diagonalMatrix = new DenseMatrix(desiredRank, desiredRank);
+ singularValues = Maps.newHashMap();
+ iterationNumber = 1;
+ }
+
+ private void intitializeBasisAndSingularVectors() {
+ basis = Maps.newHashMap();
+ singularVectors = Maps.newHashMap();
+ }
+
+ public Matrix getDiagonalMatrix() {
+ return diagonalMatrix;
+ }
+
+ public int getIterationNumber() {
+ return iterationNumber;
+ }
+
+ public double getScaleFactor() {
+ return scaleFactor;
+ }
+
+ public VectorIterable getCorpus() {
+ return corpus;
+ }
+
+ public Vector getRightSingularVector(int i) {
+ return singularVectors.get(i);
+ }
+
+ public Double getSingularValue(int i) {
+ return singularValues.get(i);
+ }
+
+ public Vector getBasisVector(int i) {
+ return basis.get(i);
+ }
+
+ public int getBasisSize() {
+ return basis.size();
+ }
+
+ public void setBasisVector(int i, Vector basisVector) {
+ basis.put(i, basisVector);
+ }
+
+ public void setScaleFactor(double scale) {
+ scaleFactor = scale;
+ }
+
+ public void setIterationNumber(int i) {
+ iterationNumber = i;
+ }
+
+ public void setRightSingularVector(int i, Vector vector) {
+ singularVectors.put(i, vector);
+ }
+
+ public void setSingularValue(int i, double value) {
+ singularValues.put(i, value);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java b/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
new file mode 100644
index 0000000..1782f04
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
/**
 * Enumerates the backends a matrix can be held by (see {@code MatrixFlavor#getBacking()}).
 */
public enum BackEnum {
  /** JVM memory backend; used by all default flavors declared in {@code MatrixFlavor}. */
  JVMMEM,
  /** Presumably a native netlib-BLAS backed implementation (no usage visible here — TODO confirm). */
  NETLIB_BLAS
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java b/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
new file mode 100644
index 0000000..e1d93f2
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/** A set of matrix structure properties that I denote as "flavor" (by analogy to quarks) */
+public interface MatrixFlavor {
+
+ /**
+ * Whether matrix is backed by a native system -- such as java memory, lapack/atlas, Magma etc.
+ */
+ BackEnum getBacking();
+
+ /**
+ * Structure flavors
+ */
+ TraversingStructureEnum getStructure() ;
+
+ boolean isDense();
+
+ /**
+ * This default for {@link org.apache.mahout.math.DenseMatrix}-like structures
+ */
+ MatrixFlavor DENSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, true);
+ /**
+ * This is default flavor for {@link org.apache.mahout.math.SparseRowMatrix}-like.
+ */
+ MatrixFlavor SPARSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, false);
+
+ /**
+ * This is default flavor for {@link org.apache.mahout.math.SparseMatrix}-like structures, i.e. sparse matrix blocks,
+ * where few, perhaps most, rows may be missing entirely.
+ */
+ MatrixFlavor SPARSEROWLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.SPARSEROWWISE, false);
+
+ /**
+ * This is default flavor for {@link org.apache.mahout.math.DiagonalMatrix} and the likes.
+ */
+ MatrixFlavor DIAGONALLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, false);
+
+ final class FlavorImpl implements MatrixFlavor {
+ private BackEnum pBacking;
+ private TraversingStructureEnum pStructure;
+ private boolean pDense;
+
+ public FlavorImpl(BackEnum backing, TraversingStructureEnum structure, boolean dense) {
+ pBacking = backing;
+ pStructure = structure;
+ pDense = dense;
+ }
+
+ @Override
+ public BackEnum getBacking() {
+ return pBacking;
+ }
+
+ @Override
+ public TraversingStructureEnum getStructure() {
+ return pStructure;
+ }
+
+ @Override
+ public boolean isDense() {
+ return pDense;
+ }
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java b/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
new file mode 100644
index 0000000..13c2cf4
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
/**
 * Structure hint describing how a matrix's data is laid out and is best traversed
 * (see {@code MatrixFlavor#getStructure()}).
 */
public enum TraversingStructureEnum {

  /** Presumably: no structure hint is available. */
  UNKNOWN,

  /**
   * Backing vectors are directly available as row views.
   */
  ROWWISE,

  /**
   * Column vectors are directly available as column views.
   */
  COLWISE,

  /**
   * Only some row-wise vectors are really present (can use iterateNonEmpty). Corresponds to
   * {@link org.apache.mahout.math.SparseMatrix}.
   */
  SPARSEROWWISE,

  /** Presumably the column-wise counterpart of {@link #SPARSEROWWISE} — TODO confirm. */
  SPARSECOLWISE,

  /** Presumably hash-based sparse storage — TODO confirm. */
  SPARSEHASH,

  /** Structure of {@code MatrixFlavor.DIAGONALLIKE}, i.e. {@link org.apache.mahout.math.DiagonalMatrix}-like matrices. */
  VECTORBACKED,

  /** Presumably a matrix assembled from blocks — TODO confirm. */
  BLOCKIFIED
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java b/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java
new file mode 100644
index 0000000..466ddd6
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
+package org.apache.mahout.math.function;
+
/**
 * Abstract function object: a function that takes two {@code double} arguments and returns a
 * single {@code double} value.
 *
 * <p>Besides {@link #apply(double, double)}, the class exposes algebraic property hints. Apart
 * from {@link #isDensifying()}, which is computed, the base implementations conservatively
 * report {@code false}. A {@code true} answer therefore means "this property is known to hold"
 * and only appears when a subclass overrides the hint. A {@code false} answer does not prove the
 * property fails.
 */
public abstract class DoubleDoubleFunction {

  /**
   * Apply the function to the arguments and return the result.
   *
   * @param arg1 a double for the first argument
   * @param arg2 a double for the second argument
   * @return the result of applying the function
   */
  public abstract double apply(double arg1, double arg2);

  /**
   * Hint that f(x, 0) = x for any x.
   *
   * @return {@code true} only if the property is known to hold; the default returns {@code false}
   */
  public boolean isLikeRightPlus() {
    return false;
  }

  /**
   * Hint that f(0, y) = 0 for any y.
   *
   * @return {@code true} only if the property is known to hold; the default returns {@code false}
   */
  public boolean isLikeLeftMult() {
    return false;
  }

  /**
   * Hint that f(x, 0) = 0 for any x.
   *
   * @return {@code true} only if the property is known to hold; the default returns {@code false}
   */
  public boolean isLikeRightMult() {
    return false;
  }

  /**
   * Hint that f(x, 0) = f(0, y) = 0 for any x, y.
   *
   * @return {@code true} when both {@link #isLikeLeftMult()} and {@link #isLikeRightMult()} do
   */
  public boolean isLikeMult() {
    return isLikeLeftMult() && isLikeRightMult();
  }

  /**
   * Hint that f(x, y) = f(y, x) for any x, y.
   *
   * @return {@code true} only if the property is known to hold; the default returns {@code false}
   */
  public boolean isCommutative() {
    return false;
  }

  /**
   * Hint that f(x, f(y, z)) = f(f(x, y), z) for any x, y, z.
   *
   * @return {@code true} only if the property is known to hold; the default returns {@code false}
   */
  public boolean isAssociative() {
    return false;
  }

  /**
   * Hint that the function is both associative and commutative.
   *
   * @return {@code true} when both {@link #isAssociative()} and {@link #isCommutative()} do
   */
  public boolean isAssociativeAndCommutative() {
    return isAssociative() && isCommutative();
  }

  /**
   * Whether applying the function to two zeros yields a non-zero value. Such a function would
   * turn implicit zeros into stored values.
   *
   * @return {@code true} iff f(0, 0) != 0; a NaN result counts as non-zero
   */
  public boolean isDensifying() {
    return apply(0.0, 0.0) != 0.0;
  }
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java b/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java
new file mode 100644
index 0000000..7545154
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.math.function;
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+
/**
 * Abstract function object: a function that takes a single {@code double} argument and returns a
 * single {@code double} value.
 * @see org.apache.mahout.math.map
 */
public abstract class DoubleFunction {

  /**
   * Apply the function to the argument and return the result.
   *
   * @param x double for the argument
   * @return the result of applying the function
   */
  public abstract double apply(double x);

  /**
   * Whether applying the function to zero yields a non-zero value. Such a function would turn
   * implicit zeros into stored values.
   *
   * @return {@code true} iff f(0) != 0; a NaN result counts as non-zero, -0.0 counts as zero
   */
  public boolean isDensifying() {
    // Direct comparison is equivalent to the former Math.abs(...) != 0.0 (including NaN and -0.0)
    // and matches DoubleDoubleFunction.isDensifying().
    return apply(0.0) != 0.0;
  }
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java b/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java
new file mode 100644
index 0000000..94dfe32
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.math.function;
+
+
/**
 * Function object mapping one {@code float} argument to one {@code float} result.
 */
public interface FloatFunction {

  /**
   * Evaluates this function.
   *
   * @param argument the value to evaluate the function at.
   * @return the function's value at {@code argument}.
   */
  float apply(float argument);
}
r***@apache.org
2018-09-08 23:35:14 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Sorting.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Sorting.java b/core/src/main/java/org/apache/mahout/math/Sorting.java
new file mode 100644
index 0000000..93293ac
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Sorting.java
@@ -0,0 +1,2297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.Serializable;
+import java.util.Comparator;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.function.ByteComparator;
+import org.apache.mahout.math.function.CharComparator;
+import org.apache.mahout.math.function.DoubleComparator;
+import org.apache.mahout.math.function.FloatComparator;
+import org.apache.mahout.math.function.IntComparator;
+import org.apache.mahout.math.function.LongComparator;
+import org.apache.mahout.math.function.ShortComparator;
+
+public final class Sorting {
+
+ /* Specifies when to switch to insertion sort */
+ private static final int SIMPLE_LENGTH = 7;
+ static final int SMALL = 7;
+
+ private Sorting() {}
+
+ private static <T> int med3(T[] array, int a, int b, int c, Comparator<T> comp) {
+ T x = array[a];
+ T y = array[b];
+ T z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(byte[] array, int a, int b, int c, ByteComparator comp) {
+ byte x = array[a];
+ byte y = array[b];
+ byte z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(char[] array, int a, int b, int c, CharComparator comp) {
+ char x = array[a];
+ char y = array[b];
+ char z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(double[] array, int a, int b, int c,
+ DoubleComparator comp) {
+ double x = array[a];
+ double y = array[b];
+ double z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(float[] array, int a, int b, int c,
+ FloatComparator comp) {
+ float x = array[a];
+ float y = array[b];
+ float z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(int[] array, int a, int b, int c, IntComparator comp) {
+ int x = array[a];
+ int y = array[b];
+ int z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ /**
+ * This is used for 'external' sorting. The comparator takes <em>indices</em>,
+ * not values, and compares the external values found at those indices.
+ * @param a
+ * @param b
+ * @param c
+ * @param comp
+ * @return
+ */
+ private static int med3(int a, int b, int c, IntComparator comp) {
+ int comparisonab = comp.compare(a, b);
+ int comparisonac = comp.compare(a, c);
+ int comparisonbc = comp.compare(b, c);
+ return comparisonab < 0
+ ? (comparisonbc < 0 ? b : (comparisonac < 0 ? c : a))
+ : (comparisonbc > 0 ? b : (comparisonac > 0 ? c : a));
+ }
+
+ private static int med3(long[] array, int a, int b, int c, LongComparator comp) {
+ long x = array[a];
+ long y = array[b];
+ long z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ private static int med3(short[] array, int a, int b, int c,
+ ShortComparator comp) {
+ short x = array[a];
+ short y = array[b];
+ short z = array[c];
+ int comparisonxy = comp.compare(x, y);
+ int comparisonxz = comp.compare(x, z);
+ int comparisonyz = comp.compare(y, z);
+ return comparisonxy < 0 ? (comparisonyz < 0 ? b
+ : (comparisonxz < 0 ? c : a)) : (comparisonyz > 0 ? b
+ : (comparisonxz > 0 ? c : a));
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code byte} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparison that determines the sort.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(byte[] array, int start, int end,
+ ByteComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void checkBounds(int arrLength, int start, int end) {
+ if (start > end) {
+ // K0033=Start index ({0}) is greater than end index ({1})
+ throw new IllegalArgumentException("Start index " + start
+ + " is greater than end index " + end);
+ }
+ if (start < 0) {
+ throw new ArrayIndexOutOfBoundsException("Array index out of range "
+ + start);
+ }
+ if (end > arrLength) {
+ throw new ArrayIndexOutOfBoundsException("Array index out of range "
+ + end);
+ }
+ }
+
+ private static void quickSort0(int start, int end, byte[] array, ByteComparator comp) {
+ byte temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ byte partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) <= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+
+ /**
+ * Sorts some external data with QuickSort.
+ *
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparator.
+ * @param swap an object that can exchange the positions of two items.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(int start, int end, IntComparator comp, Swapper swap) {
+ checkBounds(end + 1, start, end);
+ quickSort0(start, end, comp, swap);
+ }
+
+ private static void quickSort0(int start, int end, IntComparator comp, Swapper swap) {
+ int length = end - start;
+ if (length < 7) {
+ insertionSort(start, end, comp, swap);
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ // for lots of data, bottom, middle and top are medians near the beginning, middle or end of the data
+ int skosh = length / 8;
+ bottom = med3(bottom, bottom + skosh, bottom + (2 * skosh), comp);
+ middle = med3(middle - skosh, middle, middle + skosh, comp);
+ top = med3(top - (2 * skosh), top - skosh, top, comp);
+ }
+ middle = med3(bottom, middle, top, comp);
+ }
+
+ int partitionIndex = middle; // an index, not a value.
+
+ // regions from a to b and from c to d are what we will recursively sort
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (b <= c) {
+ // copy all values equal to the partition value to before a..b. In the process, advance b
+ // as long as values less than the partition or equal are found, also stop when a..b collides with c..d
+ int comparison;
+ while (b <= c && (comparison = comp.compare(b, partitionIndex)) <= 0) {
+ if (comparison == 0) {
+ if (a == partitionIndex) {
+ partitionIndex = b;
+ } else if (b == partitionIndex) {
+ partitionIndex = a;
+ }
+ swap.swap(a, b);
+ a++;
+ }
+ b++;
+ }
+ // at this point [start..a) has partition values, [a..b) has values < partition
+ // also, either b>c or v[b] > partition value
+
+ while (c >= b && (comparison = comp.compare(c, partitionIndex)) >= 0) {
+ if (comparison == 0) {
+ if (c == partitionIndex) {
+ partitionIndex = d;
+ } else if (d == partitionIndex) {
+ partitionIndex = c;
+ }
+ swap.swap(c, d);
+
+ d--;
+ }
+ c--;
+ }
+ // now we also know that [d..end] contains partition values,
+ // [c..d) contains values > partition value
+ // also, either b>c or (v[b] > partition OR v[c] < partition)
+
+ if (b <= c) {
+ // v[b] > partition OR v[c] < partition
+ // swapping will let us continue to grow the two regions
+ if (c == partitionIndex) {
+ partitionIndex = b;
+ } else if (b == partitionIndex) {
+ partitionIndex = d;
+ }
+ swap.swap(b, c);
+ b++;
+ c--;
+ }
+ }
+ // now we know
+ // b = c+1
+ // [start..a) and [d..end) contain partition value
+ // all of [a..b) are less than partition
+ // all of [c..d) are greater than partition
+
+ // shift [a..b) to beginning
+ length = Math.min(a - start, b - a);
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ swap.swap(l, h);
+ l++;
+ h++;
+ }
+
+ // shift [c..d) to end
+ length = Math.min(d - c, end - 1 - d);
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ swap.swap(l, h);
+ l++;
+ h++;
+ }
+
+ // recurse left and right
+ length = b - a;
+ if (length > 0) {
+ quickSort0(start, start + length, comp, swap);
+ }
+
+ length = d - c;
+ if (length > 0) {
+ quickSort0(end - length, end, comp, swap);
+ }
+ }
+
+ /**
+ * In-place insertion sort that is fast for pre-sorted data.
+ *
+ * @param start Where to start sorting (inclusive)
+ * @param end Where to stop (exclusive)
+ * @param comp Sort order.
+ * @param swap How to swap items.
+ */
+ private static void insertionSort(int start, int end, IntComparator comp, Swapper swap) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(j - 1, j) > 0; j--) {
+ swap.swap(j - 1, j);
+ }
+ }
+ }
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code char} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(char[] array, int start, int end, CharComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, char[] array, CharComparator comp) {
+ char temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ char partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) <= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code double} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparison.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ * @see Double#compareTo(Double)
+ */
+ public static void quickSort(double[] array, int start, int end, DoubleComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, double[] array, DoubleComparator comp) {
+ double temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j], array[j - 1]) < 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ double partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(partionValue, array[b])) >= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code float} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparator.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(float[] array, int start, int end, FloatComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, float[] array, FloatComparator comp) {
+ float temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j], array[j - 1]) < 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ float partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(partionValue, array[b])) >= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code int} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparator.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(int[] array, int start, int end, IntComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, int[] array, IntComparator comp) {
+ int temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ int partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) <= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the {@code long} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparator.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(long[] array, int start, int end, LongComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, long[] array, LongComparator comp) {
+ long temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ long partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) <= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in a specified order.
+ *
+ * @param array
+ * the array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @param comp
+ * the comparator.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static <T> void quickSort(T[] array, int start, int end, Comparator<T> comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static final class ComparableAdaptor<T extends Comparable<? super T>>
+ implements Comparator<T>, Serializable {
+
+ @Override
+ public int compare(T o1, T o2) {
+ return o1.compareTo(o2);
+ }
+
+ }
+
+ /**
+ * Sort the specified range of an array of object that implement the Comparable
+ * interface.
+ * @param <T> The type of object.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static <T extends Comparable<? super T>> void quickSort(T[] array, int start, int end) {
+ quickSort(array, start, end, new ComparableAdaptor<T>());
+ }
+
+ private static <T> void quickSort0(int start, int end, T[] array, Comparator<T> comp) {
+ T temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ T partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) <= 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) >= 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Sorts the specified range in the array in ascending numerical order.
+ *
+ * @param array
+ * the {@code short} array to be sorted.
+ * @param start
+ * the start index to sort.
+ * @param end
+ * the last + 1 index to sort.
+ * @throws IllegalArgumentException
+ * if {@code start > end}.
+ * @throws ArrayIndexOutOfBoundsException
+ * if {@code start < 0} or {@code end > array.length}.
+ */
+ public static void quickSort(short[] array, int start, int end, ShortComparator comp) {
+ Preconditions.checkNotNull(array);
+ checkBounds(array.length, start, end);
+ quickSort0(start, end, array, comp);
+ }
+
+ private static void quickSort0(int start, int end, short[] array, ShortComparator comp) {
+ short temp;
+ int length = end - start;
+ if (length < 7) {
+ for (int i = start + 1; i < end; i++) {
+ for (int j = i; j > start && comp.compare(array[j - 1], array[j]) > 0; j--) {
+ temp = array[j];
+ array[j] = array[j - 1];
+ array[j - 1] = temp;
+ }
+ }
+ return;
+ }
+ int middle = (start + end) / 2;
+ if (length > 7) {
+ int bottom = start;
+ int top = end - 1;
+ if (length > 40) {
+ length /= 8;
+ bottom = med3(array, bottom, bottom + length, bottom + (2 * length),
+ comp);
+ middle = med3(array, middle - length, middle, middle + length, comp);
+ top = med3(array, top - (2 * length), top - length, top, comp);
+ }
+ middle = med3(array, bottom, middle, top, comp);
+ }
+ short partionValue = array[middle];
+ int a = start;
+ int b = a;
+ int c = end - 1;
+ int d = c;
+ while (true) {
+ int comparison;
+ while (b <= c && (comparison = comp.compare(array[b], partionValue)) < 0) {
+ if (comparison == 0) {
+ temp = array[a];
+ array[a++] = array[b];
+ array[b] = temp;
+ }
+ b++;
+ }
+ while (c >= b && (comparison = comp.compare(array[c], partionValue)) > 0) {
+ if (comparison == 0) {
+ temp = array[c];
+ array[c] = array[d];
+ array[d--] = temp;
+ }
+ c--;
+ }
+ if (b > c) {
+ break;
+ }
+ temp = array[b];
+ array[b++] = array[c];
+ array[c--] = temp;
+ }
+ length = a - start < b - a ? a - start : b - a;
+ int l = start;
+ int h = b - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ length = d - c < end - 1 - d ? d - c : end - 1 - d;
+ l = b;
+ h = end - length;
+ while (length-- > 0) {
+ temp = array[l];
+ array[l++] = array[h];
+ array[h++] = temp;
+ }
+ if ((length = b - a) > 0) {
+ quickSort0(start, start + length, array, comp);
+ }
+ if ((length = d - c) > 0) {
+ quickSort0(end - length, end, array, comp);
+ }
+ }
+
+ /**
+ * Perform a merge sort on the specified range of an array.
+ *
+ * @param <T> the type of object in the array.
+ * @param array the array.
+ * @param start first index.
+ * @param end last index (exclusive).
+ * @param comp comparator object.
+ */
+ @SuppressWarnings("unchecked") // required to make the temp array work, afaict.
+ public static <T> void mergeSort(T[] array, int start, int end, Comparator<T> comp) {
+ checkBounds(array.length, start, end);
+ int length = end - start;
+ if (length <= 0) {
+ return;
+ }
+
+ T[] out = (T[]) new Object[array.length];
+ System.arraycopy(array, start, out, start, length);
+ mergeSort(out, array, start, end, comp);
+ }
+
+ /**
+ * Perform a merge sort of the specific range of an array of objects that implement
+ * Comparable.
+ * @param <T> the type of the objects in the array.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static <T extends Comparable<? super T>> void mergeSort(T[] array, int start, int end) {
+ mergeSort(array, start, end, new ComparableAdaptor<T>());
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(T[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * over the range before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly
 * in {@code out}. The merge is skipped when the halves are already in order. Otherwise runs are located with
 * {@code find} (exponential, then binary, search) and block-copied. Ties are taken from the left half first,
 * which keeps the sort stable.
 *
 * @param in
 *          - buffer that receives the sorted halves during recursion.
 * @param out
 *          - buffer that receives the sorted range.
 * @param start
 *          the start index
 * @param end
 *          the end index + 1
 * @param c
 *          - the comparator to determine the order of the array.
 */
private static <T> void mergeSort(T[] in, T[] out, int start, int end, Comparator<T> c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      T current = out[i];
      T prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    T fromVal = in[start];
    T rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ /**
+ * Finds the place of specified range of specified sorted array, where the
+ * element should be inserted for getting sorted array. Uses exponential
+ * search algorithm.
+ *
+ * @param arr
+ * - the array with already sorted range
+ * @param val
+ * - object to be inserted
+ * @param l
+ * - the start index
+ * @param r
+ * - the end index
+ * @param bnd
+ * - possible values 0,-1. "-1" - val is located at index more then
+ * elements equals to val. "0" - val is located at index less then
+ * elements equals to val.
+ * @param c
+ * - the comparator used to compare Objects
+ */
+ private static <T> int find(T[] arr, T val, int bnd, int l, int r, Comparator<T> c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final ByteComparator NATURAL_BYTE_COMPARISON = new ByteComparator() {
+ @Override
+ public int compare(byte o1, byte o2) {
+ return o1 - o2;
+ }
+ };
+
+ /**
+ * Perform a merge sort on a range of a byte array, using numerical order.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(byte[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_BYTE_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a byte array using a specified ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(byte[] array, int start, int end, ByteComparator comp) {
+ checkBounds(array.length, start, end);
+ byte[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(byte[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(byte[] in, byte[] out, int start, int end, ByteComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      byte current = out[i];
      byte prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    byte fromVal = in[start];
    byte rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(byte[] arr, byte val, int bnd, int l, int r, ByteComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final CharComparator NATURAL_CHAR_COMPARISON = new CharComparator() {
+ @Override
+ public int compare(char o1, char o2) {
+ return o1 - o2;
+ }
+ };
+
+ /**
+ * Perform a merge sort on a range of a char array, using numerical order.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(char[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_CHAR_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a char array using a specified ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(char[] array, int start, int end, CharComparator comp) {
+ checkBounds(array.length, start, end);
+ char[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(char[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(char[] in, char[] out, int start, int end, CharComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      char current = out[i];
      char prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    char fromVal = in[start];
    char rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(char[] arr, char val, int bnd, int l, int r, CharComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final ShortComparator NATURAL_SHORT_COMPARISON = new ShortComparator() {
+ @Override
+ public int compare(short o1, short o2) {
+ return o1 - o2;
+ }
+ };
+
+ /**
+ * Perform a merge sort on a range of a short array, using numerical order.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(short[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_SHORT_COMPARISON);
+ }
+
+ public static void mergeSort(short[] array, int start, int end, ShortComparator comp) {
+ checkBounds(array.length, start, end);
+ short[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
+
/**
 * Recursive worker for the public {@code mergeSort(short[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(short[] in, short[] out, int start, int end, ShortComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      short current = out[i];
      short prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    short fromVal = in[start];
    short rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(short[] arr, short val, int bnd, int l, int r, ShortComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final IntComparator NATURAL_INT_COMPARISON = new IntComparator() {
+ @Override
+ public int compare(int o1, int o2) {
+ return o1 < o2 ? -1 : o1 > o2 ? 1 : 0;
+ }
+ };
+
+ public static void mergeSort(int[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_INT_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a int array using numerical order.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(int[] array, int start, int end, IntComparator comp) {
+ checkBounds(array.length, start, end);
+ int[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(int[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator object.
 */
private static void mergeSort(int[] in, int[] out, int start, int end, IntComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      int current = out[i];
      int prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    int fromVal = in[start];
    int rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(int[] arr, int val, int bnd, int l, int r, IntComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+
+ private static final LongComparator NATURAL_LONG_COMPARISON = new LongComparator() {
+ @Override
+ public int compare(long o1, long o2) {
+ return o1 < o2 ? -1 : o1 > o2 ? 1 : 0;
+ }
+ };
+
+ /**
+ * Perform a merge sort on a range of a long array using numerical order.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(long[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_LONG_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a long array using a specified ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(long[] array, int start, int end, LongComparator comp) {
+ checkBounds(array.length, start, end);
+ long[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(long[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(long[] in, long[] out, int start, int end, LongComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      long current = out[i];
      long prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    long fromVal = in[start];
    long rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(long[] arr, long val, int bnd, int l, int r, LongComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final FloatComparator NATURAL_FLOAT_COMPARISON = new FloatComparator() {
+ @Override
+ public int compare(float o1, float o2) {
+ return Float.compare(o1, o2);
+ }
+ };
+
+ /**
+ * Perform a merge sort on a range of a float array using Float.compare for an ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(float[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_FLOAT_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a float array using a specified ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(float[] array, int start, int end, FloatComparator comp) {
+ checkBounds(array.length, start, end);
+ float[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(float[], ...)} entry points. It sorts {@code [start, end)} and
 * writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(float[] in, float[] out, int start, int end, FloatComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      float current = out[i];
      float prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    float fromVal = in[start];
    float rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(float[] arr, float val, int bnd, int l, int r, FloatComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
+ private static final DoubleComparator NATURAL_DOUBLE_COMPARISON = new DoubleComparator() {
+ @Override
+ public int compare(double o1, double o2) {
+ return Double.compare(o1, o2);
+ }
+ };
+
+
+ /**
+ * Perform a merge sort on a range of a double array using a Double.compare as an ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ */
+ public static void mergeSort(double[] array, int start, int end) {
+ mergeSort(array, start, end, NATURAL_DOUBLE_COMPARISON);
+ }
+
+ /**
+ * Perform a merge sort on a range of a double array using a specified ordering.
+ * @param array the array.
+ * @param start the first index.
+ * @param end the last index (exclusive).
+ * @param comp the comparator object.
+ */
+ public static void mergeSort(double[] array, int start, int end, DoubleComparator comp) {
+ checkBounds(array.length, start, end);
+ double[] out = Arrays.copyOf(array, array.length);
+ mergeSort(out, array, start, end, comp);
+ }
+
/**
 * Recursive worker for the public {@code mergeSort(double[], ...)} entry points. It sorts {@code [start, end)}
 * and writes the sorted run into {@code out}.
 *
 * <p>The two buffers swap roles at each level of recursion: the recursive calls leave the sorted halves in
 * {@code in}, which are then merged into {@code out}. The public callers fill both buffers with the same values
 * before calling. Ranges of at most {@code SIMPLE_LENGTH} elements are insertion-sorted directly in {@code out}.
 * The merge is skipped when the halves are already in order. Otherwise runs are located with {@code find}
 * (exponential, then binary, search) and block-copied. Ties are taken from the left half first.
 *
 * @param in buffer that receives the sorted halves during recursion.
 * @param out buffer that receives the sorted range.
 * @param start the first index.
 * @param end the last index (exclusive).
 * @param c the comparator defining the order.
 */
private static void mergeSort(double[] in, double[] out, int start, int end, DoubleComparator c) {
  int len = end - start;
  // use insertion sort for small arrays
  if (len <= SIMPLE_LENGTH) {
    for (int i = start + 1; i < end; i++) {
      double current = out[i];
      double prev = out[i - 1];
      if (c.compare(prev, current) > 0) {
        int j = i;
        do {
          out[j--] = prev;
        } while (j > start && (c.compare(prev = out[j - 1], current) > 0));
        out[j] = current;
      }
    }
    return;
  }
  int med = (end + start) >>> 1;
  // sort each half into 'in', using 'out' as the scratch buffer
  mergeSort(out, in, start, med, c);
  mergeSort(out, in, med, end, c);

  // merging

  // if arrays are already sorted - no merge
  if (c.compare(in[med - 1], in[med]) <= 0) {
    System.arraycopy(in, start, out, start, len);
    return;
  }
  int r = med;
  int i = start;

  // use merging with exponential search
  do {
    double fromVal = in[start];
    double rVal = in[r];
    if (c.compare(fromVal, rVal) <= 0) {
      // copy the left run of elements <= rVal (equal ones stay ahead of rVal), then rVal
      int l_1 = find(in, rVal, -1, start + 1, med - 1, c);
      int toCopy = l_1 - start + 1;
      System.arraycopy(in, start, out, i, toCopy);
      i += toCopy;
      out[i++] = rVal;
      r++;
      start = l_1 + 1;
    } else {
      // copy the right run of elements strictly < fromVal, then fromVal
      int r_1 = find(in, fromVal, 0, r + 1, end - 1, c);
      int toCopy = r_1 - r + 1;
      System.arraycopy(in, r, out, i, toCopy);
      i += toCopy;
      out[i++] = fromVal;
      start++;
      r = r_1 + 1;
    }
  } while ((end - r) > 0 && (med - start) > 0);

  // copy rest of array
  if ((end - r) <= 0) {
    System.arraycopy(in, start, out, i, med - start);
  } else {
    System.arraycopy(in, r, out, i, end - r);
  }
}
+
+ private static int find(double[] arr, double val, int bnd, int l, int r, DoubleComparator c) {
+ int m = l;
+ int d = 1;
+ while (m <= r) {
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ break;
+ }
+ m += d;
+ d <<= 1;
+ }
+ while (l <= r) {
+ m = (l + r) >>> 1;
+ if (c.compare(val, arr[m]) > bnd) {
+ l = m + 1;
+ } else {
+ r = m - 1;
+ }
+ }
+ return l - 1;
+ }
+
/**
 * Transforms two consecutive sorted ranges into a single sorted range. The initial ranges are
 * {@code [first, middle)} and {@code [middle, last)}, and the resulting range is {@code [first, last)}. Elements in
 * the first input range will precede equal elements in the second.
 *
 * <p>Elements are addressed only by index. {@code comp} is applied to pairs of indices, presumably comparing the
 * elements stored there. {@code swapper} exchanges the elements at two indices. No scratch buffer is used: the
 * longer range is cut at its midpoint, and the matching cut in the other range is found by binary search
 * ({@code lowerBound}/{@code upperBound}). The two inner pieces are then rotated, and both halves are merged
 * recursively.
 *
 * @param first index of the first element of the first range.
 * @param middle index of the first element of the second range.
 * @param last index one past the end of the second range.
 * @param comp comparison function applied to pairs of indices.
 * @param swapper exchanges the elements at two indices.
 */
static void inplaceMerge(int first, int middle, int last, IntComparator comp, Swapper swapper) {
  // nothing to merge if either range is empty
  if (first >= middle || middle >= last) {
    return;
  }
  // two single elements: swap if out of order
  if (last - first == 2) {
    if (comp.compare(middle, first) < 0) {
      swapper.swap(first, middle);
    }
    return;
  }
  int firstCut;
  int secondCut;
  if (middle - first > last - middle) {
    firstCut = first + (middle - first) / 2;
    secondCut = lowerBound(middle, last, firstCut, comp);
  } else {
    secondCut = middle + (last - middle) / 2;
    firstCut = upperBound(first, middle, secondCut, comp);
  }

  // rotate(firstCut, middle, secondCut, swapper);
  // is manually inlined for speed (jitter inlining seems to work only for small call depths, even if methods
  // are "static private")
  // speedup = 1.7
  // The rotation is done with three reversals: reverse [first2, middle2), reverse [middle2, last2), then
  // reverse [first2, last2).
  // begin inline
  int first2 = firstCut;
  int middle2 = middle;
  int last2 = secondCut;
  if (middle2 != first2 && middle2 != last2) {
    int first1 = first2;
    int last1 = middle2;
    while (first1 < --last1) {
      swapper.swap(first1++, last1);
    }
    first1 = middle2;
    last1 = last2;
    while (first1 < --last1) {
      swapper.swap(first1++, last1);
    }
    first1 = first2;
    last1 = last2;
    while (first1 < --last1) {
      swapper.swap(first1++, last1);
    }
  }
  // end inline

  // new boundary between the two rotated pieces, then merge each side
  middle = firstCut + (secondCut - middle);
  inplaceMerge(first, firstCut, middle, comp, swapper);
  inplaceMerge(middle, secondCut, last, comp, swapper);
}
+
+ /**
+ * Performs a binary search on an already-sorted range: finds the first position where an element can be inserted
+ * without violating the ordering. Sorting is by a user-supplied comparison function.
+ *
+ * @param first Beginning of the range.
+ * @param last One past the end of the range.
+ * @param x Element to be searched for.
+ * @param comp Comparison function.
+ * @return The largest index i such that, for every j in the range <code>[first, i)</code>,
+ * <code></code></codeA>{@code comp.apply(array[j], x)</code> is {@code true}.
+ * @see Sorting#upperBound
+ */
+ static int lowerBound(int first, int last, int x, IntComparator comp) {
+ int len = last - first;
+ while (len > 0) {
+ int half = len / 2;
+ int middle = first + half;
+ if (comp.compare(middle, x) < 0) {
+ first = middle + 1;
+ len -= half + 1;
+ } else {
+ len = half;
+ }
+ }
+ return first;
+ }
+
+ /**
+ * Sorts the specified range of elements according to the order induced by the specified comparator. All elements in
+ * the range must be <i>mutually comparable</i> by the specified comparator (that is, <tt>c.compare(a, b)</tt> must
+ * not throw an exception for any indexes <tt>a</tt> and <tt>b</tt> in the range).<p>
+ *
+ * This sort is guaranteed to be <i>stable</i>: equal elements will not be reordered as a result of the sort.<p>
+ *
+ * The sorting algorithm is a modified mergesort (in which the merge is omitted if the highest element in the low
+ * sublist is less than the lowest element in the high sublist). This algorithm offers guaranteed n*log(n)
+ * performance, and can approach linear performance on nearly sorted lists.
+ *
+ * @param fromIndex the index of the first element (inclusive) to be sorted.
+ * @param toIndex the index of the last element (exclusive) to be sorted.
+ * @param c the comparator to determine the order of the generic data.
+ * @param swapper an object that knows how to swap the elements at any two indexes (a,b).
+ * @see IntComparator
+ * @see Swapper
+ */
+ public static void mergeSort(int fromIndex, int toIndex, IntComparator c, Swapper swapper) {
+ /*
+ We retain the same method signature as quickSort.
+ Given only a comparator and swapper we do not know how to copy and move elements from/to temporary arrays.
+ Hence, in contrast to the JDK mergesorts this is an "in-place" mergesort, i.e. does not allocate any temporary
+ arrays.
+ A non-inplace mergesort would perhaps be faster in most cases, but would require non-intuitive delegate objects...
+ */
+ int length = toIndex - fromIndex;
+
+ // Insertion sort on smallest arrays
+ if (length < SMALL) {
+ for (int i = fromIndex; i < toIndex; i++) {
+ for (int j = i; j > fromIndex && (c.compare(j - 1, j) > 0); j--) {
+ swapper.swap(j, j - 1);
+ }
+ }
+ return;
+ }
+
+ // Recursively sort halves
+ int mid = (fromIndex + toIndex) / 2;
+ mergeSort(fromIndex, mid, c, swapper);
+ mergeSort(mid, toIndex, c, swapper);
+
+ // If list is already sorted, nothing left to do. This is an
+ // optimization that results in faster sorts for nearly ordered lists.
+ if (c.compare(mid - 1, mid) <= 0) {
+ return;
+ }
+
+ // Merge sorted halves
+ inplaceMerge(fromIndex, mid, toIndex, c, swapper);
+ }
+
+ /**
+ * Performs a binary search on an already-sorted range: finds the last position where an element can be inserted
+ * without violating the ordering. Sorting is by a user-supplied comparison function.
+ *
+ * @param first Beginning of the range.
+ * @param last One past the end of the range.
+ * @param x Element to be searched for.
+ * @param comp Comparison function.
+ * @return The largest index i such that, for every j in the range <code>[first, i)</code>, {@code comp.apply(x,}
+ * array[j])</code> is {@code false}.
+ * @see Sorting#lowerBound
+ */
+ static int upperBound(int first, int last, int x, IntComparator comp) {
+ int len = last - first;
+ while (len > 0) {
+ int half = len / 2;
+ int middle = first + half;
+ if (comp.compare(x, middle) < 0) {
+ len = half;
+ } else {
+ first = middle + 1;
+ len -= half + 1;
+ }
+ }
+ return first;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java b/core/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
new file mode 100644
index 0000000..eeffc78
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
+/**
+ * sparse matrix with general element values whose columns are accessible quickly. Implemented as a column array of
+ * SparseVectors.
+ *
+ * @deprecated tons of inconsistences. Use transpose view of SparseRowMatrix for fast column-wise iteration.
+ */
+public class SparseColumnMatrix extends AbstractMatrix {
+
+ private Vector[] columnVectors;
+
+ /**
+ * Construct a matrix of the given cardinality with the given data columns
+ *
+ * @param columns a RandomAccessSparseVector[] array of columns
+ * @param columnVectors
+ */
+ public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors) {
+ this(rows, columns, columnVectors, false);
+ }
+
+ public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors, boolean shallow) {
+ super(rows, columns);
+ if (shallow) {
+ this.columnVectors = columnVectors;
+ } else {
+ this.columnVectors = columnVectors.clone();
+ for (int col = 0; col < columnSize(); col++) {
+ this.columnVectors[col] = this.columnVectors[col].clone();
+ }
+ }
+ }
+
+ /**
+ * Construct a matrix of the given cardinality
+ *
+ * @param rows # of rows
+ * @param columns # of columns
+ */
+ public SparseColumnMatrix(int rows, int columns) {
+ super(rows, columns);
+ this.columnVectors = new RandomAccessSparseVector[columnSize()];
+ for (int col = 0; col < columnSize(); col++) {
+ this.columnVectors[col] = new RandomAccessSparseVector(rowSize());
+ }
+ }
+
+ @Override
+ public Matrix clone() {
+ SparseColumnMatrix clone = (SparseColumnMatrix) super.clone();
+ clone.columnVectors = new Vector[columnVectors.length];
+ for (int i = 0; i < columnVectors.length; i++) {
+ clone.columnVectors[i] = columnVectors[i].clone();
+ }
+ return clone;
+ }
+
+ /**
+ * Abstracted out for the iterator
+ *
+ * @return {@link #numCols()}
+ */
+ @Override
+ public int numSlices() {
+ return numCols();
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return columnVectors[column] == null ? 0.0 : columnVectors[column].getQuick(row);
+ }
+
+ @Override
+ public Matrix like() {
+ return new SparseColumnMatrix(rowSize(), columnSize());
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new SparseColumnMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ if (columnVectors[column] == null) {
+ columnVectors[column] = new RandomAccessSparseVector(rowSize());
+ }
+ columnVectors[column].setQuick(row, value);
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ int[] result = new int[2];
+ result[COL] = columnVectors.length;
+ for (int col = 0; col < columnSize(); col++) {
+ result[ROW] = Math.max(result[ROW], columnVectors[col]
+ .getNumNondefaultElements());
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ if (offset[ROW] < 0) {
+ throw new IndexException(offset[ROW], columnVectors[COL].size());
+ }
+ if (offset[ROW] + size[ROW] > columnVectors[COL].size()) {
+ throw new IndexException(offset[ROW] + size[ROW], columnVectors[COL].size());
+ }
+ if (offset[COL] < 0) {
+ throw new IndexException(offset[COL], columnVectors.length);
+ }
+ if (offset[COL] + size[COL] > columnVectors.length) {
+ throw new IndexException(offset[COL] + size[COL], columnVectors.length);
+ }
+ return new MatrixView(this, offset, size);
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (rowSize() != other.size()) {
+ throw new CardinalityException(rowSize(), other.size());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ columnVectors[column].assign(other);
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new CardinalityException(columnSize(), other.size());
+ }
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ for (int col = 0; col < columnSize(); col++) {
+ columnVectors[col].setQuick(row, other.getQuick(col));
+ }
+ return this;
+ }
+
+ @Override
+ public Vector viewColumn(int column) {
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ return columnVectors[column];
+ }
+
+ @Override
+ public Matrix transpose() {
+ SparseRowMatrix srm = new SparseRowMatrix(columns, rows);
+ for (int i = 0; i < columns; i++) {
+ Vector col = columnVectors[i];
+ if (col.getNumNonZeroElements() > 0)
+ // this should already be optimized
+ srm.assignRow(i, col);
+ }
+ return srm;
+ }
+
+ @Override
+ public String toString() {
+ int row = 0;
+ int maxRowsToDisplay = 10;
+ int maxColsToDisplay = 20;
+ int colsToDisplay = maxColsToDisplay;
+
+ if(maxColsToDisplay > columnSize()){
+ colsToDisplay = columnSize();
+ }
+
+ StringBuilder s = new StringBuilder("{\n");
+ for (MatrixSlice next : this.transpose()) {
+ if (row < maxRowsToDisplay) {
+ s.append(" ")
+ .append(next.index())
+ .append(" =>\t")
+ .append(new VectorView(next.vector(), 0, colsToDisplay))
+ .append('\n');
+ row++;
+ }
+ }
+
+ String returnString = s.toString();
+ if (maxColsToDisplay <= columnSize()) {
+ returnString = returnString.replace("}", " ... }");
+ }
+
+ if (maxRowsToDisplay <= rowSize()) {
+ return returnString + "... }";
+ } else {
+ return returnString + "}";
+ }
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SparseMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SparseMatrix.java b/core/src/main/java/org/apache/mahout/math/SparseMatrix.java
new file mode 100644
index 0000000..a75ac55
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SparseMatrix.java
@@ -0,0 +1,245 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap.Entry;
+import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+
+import java.util.Iterator;
+import java.util.Map;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
+
+import com.google.common.collect.AbstractIterator;
+
+/** Doubly sparse matrix. Implemented as a Map of RandomAccessSparseVector rows */
+public class SparseMatrix extends AbstractMatrix {
+
+ private Int2ObjectOpenHashMap<Vector> rowVectors;
+
+ /**
+ * Construct a matrix of the given cardinality with the given row map
+ *
+ * @param rows no of rows
+ * @param columns no of columns
+ * @param rowVectors a {@code Map<Integer, RandomAccessSparseVector>} of rows
+ */
+ public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors) {
+ this(rows, columns, rowVectors, false);
+ }
+
+ public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors, boolean shallow) {
+
+ // Why this is passing in a map? iterating it is pretty inefficient as opposed to simple lists...
+ super(rows, columns);
+ this.rowVectors = new Int2ObjectOpenHashMap<>();
+ if (shallow) {
+ for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) {
+ this.rowVectors.put(entry.getKey().intValue(), entry.getValue());
+ }
+ } else {
+ for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) {
+ this.rowVectors.put(entry.getKey().intValue(), entry.getValue().clone());
+ }
+ }
+ }
+
+ /**
+ * Construct a matrix with specified number of rows and columns.
+ */
+ public SparseMatrix(int rows, int columns) {
+ super(rows, columns);
+ this.rowVectors = new Int2ObjectOpenHashMap<>();
+ }
+
+ @Override
+ public Matrix clone() {
+ SparseMatrix clone = new SparseMatrix(numRows(), numCols());
+ for (MatrixSlice slice : this) {
+ clone.rowVectors.put(slice.index(), slice.clone());
+ }
+ return clone;
+ }
+
+ @Override
+ public int numSlices() {
+ return rowVectors.size();
+ }
+
+ public Iterator<MatrixSlice> iterateNonEmpty() {
+ final int[] keys = rowVectors.keySet().toIntArray();
+ return new AbstractIterator<MatrixSlice>() {
+ private int slice;
+ @Override
+ protected MatrixSlice computeNext() {
+ if (slice >= rowVectors.size()) {
+ return endOfData();
+ }
+ int i = keys[slice];
+ Vector row = rowVectors.get(i);
+ slice++;
+ return new MatrixSlice(row, i);
+ }
+ };
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ Vector r = rowVectors.get(row);
+ return r == null ? 0.0 : r.getQuick(column);
+ }
+
+ @Override
+ public Matrix like() {
+ return new SparseMatrix(rowSize(), columnSize());
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new SparseMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ Vector r = rowVectors.get(row);
+ if (r == null) {
+ r = new RandomAccessSparseVector(columnSize());
+ rowVectors.put(row, r);
+ }
+ r.setQuick(column, value);
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ int[] result = new int[2];
+ result[ROW] = rowVectors.size();
+ for (Vector row : rowVectors.values()) {
+ result[COL] = Math.max(result[COL], row.getNumNondefaultElements());
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ if (offset[ROW] < 0) {
+ throw new IndexException(offset[ROW], rowSize());
+ }
+ if (offset[ROW] + size[ROW] > rowSize()) {
+ throw new IndexException(offset[ROW] + size[ROW], rowSize());
+ }
+ if (offset[COL] < 0) {
+ throw new IndexException(offset[COL], columnSize());
+ }
+ if (offset[COL] + size[COL] > columnSize()) {
+ throw new IndexException(offset[COL] + size[COL], columnSize());
+ }
+ return new MatrixView(this, offset, size);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ //TODO generalize to other kinds of functions
+ if (Functions.PLUS.equals(function) && other instanceof SparseMatrix) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+
+ SparseMatrix otherSparse = (SparseMatrix) other;
+ for(ObjectIterator<Entry<Vector>> fastIterator = otherSparse.rowVectors.int2ObjectEntrySet().fastIterator();
+ fastIterator.hasNext();) {
+ final Entry<Vector> entry = fastIterator.next();
+ final int rowIndex = entry.getIntKey();
+ Vector row = rowVectors.get(rowIndex);
+ if (row == null) {
+ rowVectors.put(rowIndex, entry.getValue().clone());
+ } else {
+ row.assign(entry.getValue(), Functions.PLUS);
+ }
+ }
+ return this;
+ } else {
+ return super.assign(other, function);
+ }
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (rowSize() != other.size()) {
+ throw new CardinalityException(rowSize(), other.size());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ for (int row = 0; row < rowSize(); row++) {
+ double val = other.getQuick(row);
+ if (val != 0.0) {
+ Vector r = rowVectors.get(row);
+ if (r == null) {
+ r = new RandomAccessSparseVector(columnSize());
+ rowVectors.put(row, r);
+ }
+ r.setQuick(column, val);
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new CardinalityException(columnSize(), other.size());
+ }
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ rowVectors.put(row, other);
+ return this;
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ Vector res = rowVectors.get(row);
+ if (res == null) {
+ res = new RandomAccessSparseVector(columnSize());
+ rowVectors.put(row, res);
+ }
+ return res;
+ }
+
+ /** special method necessary for efficient serialization */
+ public IntArrayList nonZeroRowIndices() {
+ return new IntArrayList(rowVectors.keySet().toIntArray());
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.SPARSEROWLIKE;
+ }
+}
r***@apache.org
2018-09-08 23:35:17 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/DiagonalMatrix.java b/core/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
new file mode 100644
index 0000000..070fad2
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
+ private final Vector diagonal;
+
+ public DiagonalMatrix(Vector values) {
+ super(values.size(), values.size());
+ this.diagonal = values;
+ }
+
+ public DiagonalMatrix(Matrix values) {
+ this(values.viewDiagonal());
+ }
+
+ public DiagonalMatrix(double value, int size) {
+ this(new ConstantVector(value, size));
+ }
+
+ public DiagonalMatrix(double[] values) {
+ super(values.length, values.length);
+ this.diagonal = new DenseVector(values);
+ }
+
+ public static DiagonalMatrix identity(int size) {
+ return new DiagonalMatrix(1, size);
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ throw new UnsupportedOperationException("Can't assign a column to a diagonal matrix");
+ }
+
+ /**
+ * Assign the other vector values to the row of the receiver
+ *
+ * @param row the int row to assign
+ * @param other a Vector
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ throw new UnsupportedOperationException("Can't assign a row to a diagonal matrix");
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ return new SingleElementVector(row);
+ }
+
+ @Override
+ public Vector viewColumn(int row) {
+ return new SingleElementVector(row);
+ }
+
+ /**
+ * Special class to implement views of rows and columns of a diagonal matrix.
+ */
+ public class SingleElementVector extends AbstractVector {
+ private int index;
+
+ public SingleElementVector(int index) {
+ super(diagonal.size());
+ this.index = index;
+ }
+
+ @Override
+ public double getQuick(int index) {
+ if (index == this.index) {
+ return diagonal.get(index);
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public void set(int index, double value) {
+ if (index == this.index) {
+ diagonal.set(index, value);
+ } else {
+ throw new IllegalArgumentException("Can't set off-diagonal element of diagonal matrix");
+ }
+ }
+
+ @Override
+ protected Iterator<Element> iterateNonZero() {
+ return new Iterator<Element>() {
+ boolean more = true;
+
+ @Override
+ public boolean hasNext() {
+ return more;
+ }
+
+ @Override
+ public Element next() {
+ if (more) {
+ more = false;
+ return new Element() {
+ @Override
+ public double get() {
+ return diagonal.get(index);
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ diagonal.set(index, value);
+ }
+ };
+ } else {
+ throw new NoSuchElementException("Only one non-zero element in a row or column of a diagonal matrix");
+ }
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException("Can't remove from vector view");
+ }
+ };
+ }
+
+ @Override
+ protected Iterator<Element> iterator() {
+ return new Iterator<Element>() {
+ int i = 0;
+
+ Element r = new Element() {
+ @Override
+ public double get() {
+ if (i == index) {
+ return diagonal.get(index);
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public int index() {
+ return i;
+ }
+
+ @Override
+ public void set(double value) {
+ if (i == index) {
+ diagonal.set(index, value);
+ } else {
+ throw new IllegalArgumentException("Can't set any element but diagonal");
+ }
+ }
+ };
+
+ @Override
+ public boolean hasNext() {
+ return i < diagonal.size() - 1;
+ }
+
+ @Override
+ public Element next() {
+ if (i < SingleElementVector.this.size() - 1) {
+ i++;
+ return r;
+ } else {
+ throw new NoSuchElementException("Attempted to access passed last element of vector");
+ }
+ }
+
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException("Default operation");
+ }
+ };
+ }
+
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return new DiagonalMatrix(rows, columns);
+ }
+
+ @Override
+ public boolean isDense() {
+ return false;
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return true;
+ }
+
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ throw new UnsupportedOperationException("Default operation");
+ }
+
+ @Override
+ public Vector like() {
+ return new DenseVector(size());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ if (index == this.index) {
+ diagonal.set(this.index, value);
+ } else {
+ throw new IllegalArgumentException("Can't set off-diagonal element of DiagonalMatrix");
+ }
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return 1;
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 0;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1;
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return false;
+ }
+ }
+
+ /**
+ * Provides a view of the diagonal of a matrix.
+ */
+ @Override
+ public Vector viewDiagonal() {
+ return this.diagonal;
+ }
+
+ /**
+ * Return the value at the given location, without checking bounds
+ *
+ * @param row an int row index
+ * @param column an int column index
+ * @return the double at the index
+ */
+ @Override
+ public double getQuick(int row, int column) {
+ if (row == column) {
+ return diagonal.get(row);
+ } else {
+ return 0;
+ }
+ }
+
+ /**
+ * Return an empty matrix of the same underlying class as the receiver
+ *
+ * @return a Matrix
+ */
+ @Override
+ public Matrix like() {
+ return new SparseRowMatrix(rowSize(), columnSize());
+ }
+
+ /**
+ * Returns an empty matrix of the same underlying class as the receiver and of the specified
+ * size.
+ *
+ * @param rows the int number of rows
+ * @param columns the int number of columns
+ */
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new SparseRowMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ if (row == column) {
+ diagonal.set(row, value);
+ } else {
+ throw new UnsupportedOperationException("Can't set off-diagonal element");
+ }
+ }
+
+ /**
+ * Return the number of values in the recipient
+ *
+ * @return an int[2] containing [row, column] count
+ */
+ @Override
+ public int[] getNumNondefaultElements() {
+ throw new UnsupportedOperationException("Don't understand how to implement this");
+ }
+
+ /**
+ * Return a new matrix containing the subset of the recipient
+ *
+ * @param offset an int[2] offset into the receiver
+ * @param size the int[2] size of the desired result
+ * @return a new Matrix that is a view of the original
+ * @throws CardinalityException if the length is greater than the cardinality of the receiver
+ * @throws IndexException if the offset is negative or the offset+length is outside of the
+ * receiver
+ */
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ return new MatrixView(this, offset, size);
+ }
+
+ @Override
+ public Matrix times(Matrix other) {
+ return timesRight(other);
+ }
+
+ @Override
+ public Matrix timesRight(Matrix that) {
+ if (that.numRows() != diagonal.size()) {
+ throw new IllegalArgumentException("Incompatible number of rows in the right operand of matrix multiplication.");
+ }
+ Matrix m = that.like();
+ for (int row = 0; row < diagonal.size(); row++) {
+ m.assignRow(row, that.viewRow(row).times(diagonal.getQuick(row)));
+ }
+ return m;
+ }
+
+ @Override
+ public Matrix timesLeft(Matrix that) {
+ if (that.numCols() != diagonal.size()) {
+ throw new IllegalArgumentException(
+ "Incompatible number of rows in the left operand of matrix-matrix multiplication.");
+ }
+ Matrix m = that.like();
+ for (int col = 0; col < diagonal.size(); col++) {
+ m.assignColumn(col, that.viewColumn(col).times(diagonal.getQuick(col)));
+ }
+ return m;
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.DIAGONALLIKE;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/FileBasedMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/FileBasedMatrix.java b/core/src/main/java/org/apache/mahout/math/FileBasedMatrix.java
new file mode 100644
index 0000000..3a19318
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/FileBasedMatrix.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.List;
+
+/**
+ * Provides a way to get data from a file and treat it as if it were a matrix, but avoids putting all that
+ * data onto the Java heap. Instead, the file is mapped into non-heap memory as a DoubleBuffer and we access
+ * that instead.
+ *
+ * <p>The file holds the matrix row by row as raw 8-byte doubles with no header (see
+ * {@link #writeMatrix(File, Matrix)}). A single mapping may not exceed {@code Integer.MAX_VALUE}
+ * bytes, so the rows are split into blocks of {@code rowsPerBlock} rows and each block is mapped
+ * separately. The matrix is read-only.
+ */
+public final class FileBasedMatrix extends AbstractMatrix {
+  /** Number of rows held by each mapped block; always at least 1. */
+  private final int rowsPerBlock;
+  /** One DoubleBuffer per mapped block, in row order. */
+  private final List<DoubleBuffer> content = Lists.newArrayList();
+
+  /**
+   * Constructs an empty matrix of the given size. Call {@link #setData(File, boolean)} before
+   * reading any values.
+   *
+   * @param rows The number of rows in the result.
+   * @param columns The number of columns in the result.
+   */
+  public FileBasedMatrix(int rows, int columns) {
+    super(rows, columns);
+    // Long arithmetic: 8 * columns can overflow an int, and zero columns must not divide by zero.
+    long bytesPerRow = 8L * columns;
+    long maxRows = bytesPerRow == 0 ? rows : Integer.MAX_VALUE / bytesPerRow;
+    // Keep at least one row per block so the block arithmetic below never divides by zero.
+    rowsPerBlock = (int) Math.max(1L, Math.min(rows, maxRows));
+  }
+
+  private void addData(DoubleBuffer content) {
+    this.content.add(content);
+  }
+
+  /**
+   * Memory-maps the contents of {@code f} as the values of this matrix, replacing any data
+   * attached by an earlier call.
+   *
+   * @param f a file in the layout produced by {@link #writeMatrix(File, Matrix)}; its length must be
+   *     exactly {@code rows * columns * 8} bytes
+   * @param loadNow if true, call {@link MappedByteBuffer#load()} on every mapped block
+   * @throws IOException if the file cannot be opened or mapped
+   * @throws IllegalArgumentException if the file has the wrong length
+   */
+  public void setData(File f, boolean loadNow) throws IOException {
+    // rows * columns alone can overflow an int for large matrices, so widen first.
+    long expectedLength = (long) rows * columns * 8L;
+    Preconditions.checkArgument(f.length() == expectedLength, "File " + f + " is wrong length");
+
+    long blockBytes = (long) rowsPerBlock * columns * 8L;
+    long blocks = ((long) rows + rowsPerBlock - 1) / rowsPerBlock;
+    content.clear();
+    FileInputStream input = new FileInputStream(f);
+    try {
+      FileChannel channel = input.getChannel();
+      for (long i = 0; i < blocks; i++) {
+        // i is a long here, so the start offset cannot overflow for files larger than 16 GB.
+        long start = i * blockBytes;
+        MappedByteBuffer buf = channel.map(FileChannel.MapMode.READ_ONLY, start,
+            Math.min(expectedLength - start, blockBytes));
+        if (loadNow) {
+          buf.load();
+        }
+        addData(buf.asDoubleBuffer());
+      }
+    } finally {
+      // A mapping stays valid after its channel is closed, so the descriptor can be released now.
+      input.close();
+    }
+  }
+
+  /**
+   * Writes {@code m} to {@code f} in the layout read by {@link #setData(File, boolean)}: each row
+   * in iteration order, as {@code columnSize()} 8-byte doubles.
+   *
+   * <p>NOTE: {@link File#canWrite()} is false for a file that does not exist yet, so {@code f} must
+   * already exist. NOTE(review): assumes iterating {@code m} yields every row in index order.
+   *
+   * @param f the destination file
+   * @param m the matrix to write
+   * @throws IOException if writing fails
+   * @throws IllegalArgumentException if {@code f} is not writable
+   */
+  public static void writeMatrix(File f, Matrix m) throws IOException {
+    Preconditions.checkArgument(f.canWrite(), "Can't write to output file");
+    FileOutputStream fos = new FileOutputStream(f);
+    try {
+      ByteBuffer buf = ByteBuffer.allocate(m.columnSize() * 8);
+      for (MatrixSlice row : m) {
+        buf.clear();
+        for (Vector.Element element : row.vector().all()) {
+          buf.putDouble(element.get());
+        }
+        buf.flip();
+        fos.write(buf.array());
+      }
+    } finally {
+      fos.close();
+    }
+  }
+
+  /**
+   * Assign the other vector values to the column of the receiver
+   *
+   * @param column the int column to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws org.apache.mahout.math.CardinalityException
+   *           if the cardinalities differ
+   */
+  @Override
+  public Matrix assignColumn(int column, Vector other) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Assign the other vector values to the row of the receiver
+   *
+   * @param row the int row to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws org.apache.mahout.math.CardinalityException
+   *           if the cardinalities differ
+   */
+  @Override
+  public Matrix assignRow(int row, Vector other) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Return the value at the given indexes, without checking bounds
+   *
+   * @param row an int row index
+   * @param column an int column index
+   * @return the double at the index
+   */
+  @Override
+  public double getQuick(int row, int column) {
+    int block = row / rowsPerBlock;
+    // Offset within one block fits in an int because each block is at most Integer.MAX_VALUE bytes.
+    return content.get(block).get((row % rowsPerBlock) * columns + column);
+  }
+
+  /**
+   * Return an empty matrix of the same underlying class as the receiver
+   *
+   * @return a Matrix
+   */
+  @Override
+  public Matrix like() {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Returns an empty matrix of the specified size. The result is an in-heap {@link DenseMatrix}.
+   *
+   * @param rows the int number of rows
+   * @param columns the int number of columns
+   */
+  @Override
+  public Matrix like(int rows, int columns) {
+    return new DenseMatrix(rows, columns);
+  }
+
+  /**
+   * Set the value at the given index, without checking bounds. Not supported: the matrix is
+   * read-only.
+   *
+   * @param row an int row index into the receiver
+   * @param column an int column index into the receiver
+   * @param value a double value to set
+   */
+  @Override
+  public void setQuick(int row, int column, double value) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Return a view into part of a matrix. Not supported by this implementation.
+   *
+   * @param offset an int[2] offset into the receiver
+   * @param size the int[2] size of the desired result
+   * @return never returns normally
+   * @throws UnsupportedOperationException always
+   */
+  @Override
+  public Matrix viewPart(int[] offset, int[] size) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java b/core/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
new file mode 100644
index 0000000..0b0c25e
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
@@ -0,0 +1,535 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
+import java.nio.channels.FileChannel;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+/**
+ * Provides a way to get data from a file and treat it as if it were a matrix, but avoids putting
+ * all that data onto the Java heap. Instead, the file is mapped into non-heap memory and read
+ * through IntBuffer views. The interesting aspect of this is that the values in
+ * the matrix are binary and sparse so we don't need to store the actual data, just the location of
+ * non-zero values.
+ * <p>
+ * Currently the file is a sequence of 32-bit ints (as written by {@link DataOutputStream}) laid out
+ * as follows:
+ * <ul>
+ * <li>A magic number to indicate the file format.</li>
+ * <li>The size of the matrix: the number of rows, then the number of columns.</li>
+ * <li>The number of non-zeros in each row, one int per row.</li>
+ * <li>For each row in turn, its non-zero column indices in ascending order. There is no per-row
+ * count here; the counts come from the previous section.</li>
+ * </ul>
+ * <p>
+ * Every listed column has value 1 and every other entry is 0. The matrix is read-only.
+ * <p>
+ * It would be preferable to use something like protobufs to define the format so that we can use
+ * different row formats for different kinds of data. For instance, Golay coding of column numbers
+ * or compressed bit vectors might be good representations for some purposes.
+ */
+public final class FileBasedSparseBinaryMatrix extends AbstractMatrix {
+  private static final int MAGIC_NUMBER_V0 = 0x12d7067d;
+
+  /**
+   * Largest single mapping. FileChannel.map accepts at most Integer.MAX_VALUE bytes; rounding down
+   * to a multiple of 4 keeps every segment a whole number of ints.
+   */
+  private static final long MAX_SEGMENT_BYTES = Integer.MAX_VALUE & ~3L;
+
+  /** Int views of the mapped file segments; segment 0 starts at byte 0 of the file. */
+  private final List<IntBuffer> data = Lists.newArrayList();
+  /** For each row, the index in {@link #data} of the segment holding its column list. */
+  private int[] bufferIndex;
+  /** For each row, the int offset of its column list within its segment. */
+  private int[] rowOffset;
+  /** For each row, the number of non-zero columns. */
+  private int[] rowSize;
+
+  /**
+   * Constructs an empty matrix of the given size.
+   *
+   * @param rows The number of rows in the result.
+   * @param columns The number of columns in the result.
+   */
+  public FileBasedSparseBinaryMatrix(int rows, int columns) {
+    super(rows, columns);
+  }
+
+  /**
+   * Memory-maps {@code f} as the contents of this matrix, replacing data from any earlier call.
+   *
+   * <p>Data larger than one mapping is split into several segments. Each segment starts at the
+   * first row that does not fit in the previous one, so a row's column list never straddles two
+   * segments.
+   *
+   * @param f a file in the layout written by {@link #writeMatrix(File, Matrix)}
+   * @throws IOException if the file cannot be opened or mapped
+   * @throws IllegalArgumentException if the file is not in this format, its dimensions differ
+   *     from this matrix, or a row lies beyond the end of the file
+   */
+  public void setData(File f) throws IOException {
+    FileInputStream stream = new FileInputStream(f);
+    try {
+      FileChannel input = stream.getChannel();
+      long fileLength = f.length();
+      data.clear();
+
+      // Segment 0 always starts at byte 0 and also serves as the header reader.
+      ByteBuffer header = input.map(FileChannel.MapMode.READ_ONLY, 0, Math.min(MAX_SEGMENT_BYTES, fileLength));
+      data.add(header.asIntBuffer());
+      Preconditions.checkArgument(header.getInt() == MAGIC_NUMBER_V0, "Wrong type of file");
+
+      int fileRows = header.getInt();
+      int fileCols = header.getInt();
+      Preconditions.checkArgument(fileRows == rowSize());
+      Preconditions.checkArgument(fileCols == columnSize());
+
+      rowOffset = new int[fileRows];
+      rowSize = new int[fileRows];
+      bufferIndex = new int[fileRows];
+
+      // Absolute byte positions in the file; the column lists follow the 3-int header and the
+      // per-row counts.
+      long segmentStart = 0;
+      long segmentEnd = header.capacity();
+      long position = 12 + 4L * fileRows;
+      for (int i = 0; i < fileRows; i++) {
+        int size = header.getInt();
+        long end = position + 4L * size;
+        if (end > segmentEnd) {
+          // This row does not fit in the current segment: start a new one at this row.
+          segmentStart = position;
+          segmentEnd = position + Math.min(MAX_SEGMENT_BYTES, fileLength - position);
+          Preconditions.checkArgument(end <= segmentEnd, "Row %s does not fit in the mapped file", i);
+          data.add(input.map(FileChannel.MapMode.READ_ONLY, segmentStart, segmentEnd - segmentStart)
+              .asIntBuffer());
+        }
+        rowOffset[i] = (int) ((position - segmentStart) / 4);
+        rowSize[i] = size;
+        bufferIndex[i] = data.size() - 1;
+        position = end;
+      }
+    } finally {
+      // Mappings remain valid after the channel is closed, so release the descriptor now.
+      stream.close();
+    }
+  }
+
+  /**
+   * Writes {@code m} to {@code f} in the layout read by {@link #setData(File)}. Every non-zero
+   * value of {@code m} is stored as a 1.
+   *
+   * <p>NOTE: {@link File#canWrite()} is false for a file that does not exist yet, so {@code f} must
+   * already exist. NOTE(review): assumes iterating {@code m} yields every row in index order, since
+   * the header promises {@code m.rowSize()} rows.
+   *
+   * @param f the destination file
+   * @param m the matrix to write
+   * @throws IOException if writing fails
+   * @throws IllegalArgumentException if {@code f} is not writable
+   */
+  public static void writeMatrix(File f, Matrix m) throws IOException {
+    Preconditions.checkArgument(f.canWrite(), "Can't write to output file");
+    DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(f)));
+    try {
+      // write header
+      out.writeInt(MAGIC_NUMBER_V0);
+      out.writeInt(m.rowSize());
+      out.writeInt(m.columnSize());
+
+      // Write the row counts. They are counted from nonZeroes(), the same iteration that produces
+      // the column lists. getNumNondefaultElements() may disagree with it, for example on dense
+      // vectors, and would corrupt the file.
+      for (MatrixSlice row : m) {
+        out.writeInt(countNonZeroes(row.vector()));
+      }
+
+      // write rows
+      for (MatrixSlice row : m) {
+        for (Integer column : sortedNonZeroColumns(row.vector())) {
+          out.writeInt(column);
+        }
+      }
+    } finally {
+      out.close();
+    }
+  }
+
+  /** Counts the elements produced by {@code vector.nonZeroes()}. */
+  private static int countNonZeroes(Vector vector) {
+    int count = 0;
+    for (Vector.Element ignored : vector.nonZeroes()) {
+      count++;
+    }
+    return count;
+  }
+
+  /** Returns the indices produced by {@code vector.nonZeroes()}, sorted ascending. */
+  private static List<Integer> sortedNonZeroColumns(Vector vector) {
+    List<Integer> columns = Lists.newArrayList(Iterables.transform(vector.nonZeroes(),
+        new Function<Vector.Element, Integer>() {
+          @Override
+          public Integer apply(Vector.Element element) {
+            return element.index();
+          }
+        }));
+    Collections.sort(columns);
+    return columns;
+  }
+
+  /**
+   * Assign the other vector values to the column of the receiver
+   *
+   * @param column the int column to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws org.apache.mahout.math.CardinalityException
+   *           if the cardinalities differ
+   */
+  @Override
+  public Matrix assignColumn(int column, Vector other) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Assign the other vector values to the row of the receiver
+   *
+   * @param row the int row to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws org.apache.mahout.math.CardinalityException
+   *           if the cardinalities differ
+   */
+  @Override
+  public Matrix assignRow(int row, Vector other) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Return the value at the given indexes, without checking bounds
+   *
+   * @param rowIndex an int row index
+   * @param columnIndex an int column index
+   * @return 1 if {@code columnIndex} is listed for the row, otherwise 0
+   */
+  @Override
+  public double getQuick(int rowIndex, int columnIndex) {
+    return searchForIndex(rowBuffer(rowIndex), columnIndex);
+  }
+
+  /** Returns a read-only buffer containing exactly the sorted column list of {@code rowIndex}. */
+  private IntBuffer rowBuffer(int rowIndex) {
+    IntBuffer tmp = data.get(bufferIndex[rowIndex]).asReadOnlyBuffer();
+    tmp.position(rowOffset[rowIndex]);
+    // The limit is an absolute index, so it is offset + size (not just size).
+    tmp.limit(rowOffset[rowIndex] + rowSize[rowIndex]);
+    return tmp.slice();
+  }
+
+  /**
+   * Binary-searches a row's sorted column list.
+   *
+   * @param row the row's column indices in ascending order, occupying [0, limit)
+   * @param columnIndex the column to look for
+   * @return 1 if {@code columnIndex} is present, otherwise 0
+   */
+  private static double searchForIndex(IntBuffer row, int columnIndex) {
+    int high = row.limit();
+    if (high == 0) {
+      return 0;
+    }
+    int low = 0;
+    while (high > low) {
+      int mid = (low + high) >>> 1;
+      if (row.get(mid) < columnIndex) {
+        low = mid + 1;
+      } else {
+        high = mid;
+      }
+    }
+    // low is now the first position whose value is >= columnIndex (or limit if none is).
+    return low < row.limit() && row.get(low) == columnIndex ? 1 : 0;
+  }
+
+  /**
+   * Return an empty matrix of the same underlying class as the receiver
+   *
+   * @return a Matrix
+   */
+  @Override
+  public Matrix like() {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Returns an empty matrix of the specified size. The result is an in-heap {@link DenseMatrix}.
+   *
+   * @param rows the int number of rows
+   * @param columns the int number of columns
+   */
+  @Override
+  public Matrix like(int rows, int columns) {
+    return new DenseMatrix(rows, columns);
+  }
+
+  /**
+   * Set the value at the given index, without checking bounds. Not supported: the matrix is
+   * read-only.
+   *
+   * @param row an int row index into the receiver
+   * @param column an int column index into the receiver
+   * @param value a double value to set
+   */
+  @Override
+  public void setQuick(int row, int column, double value) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Return a view into part of a matrix. Not supported by this implementation.
+   *
+   * @param offset an int[2] offset into the receiver
+   * @param size the int[2] size of the desired result
+   * @return never returns normally
+   * @throws UnsupportedOperationException always
+   */
+  @Override
+  public Matrix viewPart(int[] offset, int[] size) {
+    throw new UnsupportedOperationException("Default operation");
+  }
+
+  /**
+   * Returns a read-only view of a row, backed by the mapped file.
+   *
+   * @param rowIndex Which row to return.
+   * @return A vector that references the desired row.
+   */
+  @Override
+  public Vector viewRow(int rowIndex) {
+    return new SparseBinaryVector(rowBuffer(rowIndex), columnSize());
+  }
+
+  /** Read-only binary vector whose non-zero indices are the sorted ints in {@code buffer}. */
+  private static class SparseBinaryVector extends AbstractVector {
+    private final IntBuffer buffer;
+    private final int maxIndex;
+
+    private SparseBinaryVector(IntBuffer buffer, int maxIndex) {
+      super(maxIndex);
+      this.buffer = buffer;
+      this.maxIndex = maxIndex;
+    }
+
+    /**
+     * Builds a vector over {@code size} ints starting at byte {@code offset} of {@code row}.
+     * Not used within this class at present.
+     */
+    SparseBinaryVector(ByteBuffer row, int maxIndex, int offset, int size) {
+      super(maxIndex);
+      ByteBuffer view = row.asReadOnlyBuffer();
+      view.position(offset);
+      view.limit(offset + size * 4);
+      this.buffer = view.slice().asIntBuffer();
+      this.maxIndex = maxIndex;
+    }
+
+    /**
+     * Subclasses must override to return an appropriately sparse or dense result
+     *
+     * @param rows the row cardinality
+     * @param columns the column cardinality
+     * @return a Matrix
+     */
+    @Override
+    protected Matrix matrixLike(int rows, int columns) {
+      throw new UnsupportedOperationException("Default operation");
+    }
+
+    /**
+     * Used internally by assign() to update multiple indices and values at once. Not supported
+     * because this vector is read-only.
+     *
+     * @param updates a mapping of indices to values to merge in the vector.
+     */
+    @Override
+    public void mergeUpdates(OrderedIntDoubleMapping updates) {
+      throw new UnsupportedOperationException("Cannot mutate SparseBinaryVector");
+    }
+
+    /**
+     * @return false: only non-zero indices are stored
+     */
+    @Override
+    public boolean isDense() {
+      return false;
+    }
+
+    /**
+     * @return true: indices are stored in ascending order, so iteration is in index order
+     */
+    @Override
+    public boolean isSequentialAccess() {
+      return true;
+    }
+
+    /**
+     * Iterates over all elements, zero and non-zero, in index order. Each element looks up its
+     * value with a binary search.
+     *
+     * @return An {@link java.util.Iterator} over all elements
+     */
+    @Override
+    public Iterator<Element> iterator() {
+      return new AbstractIterator<Element>() {
+        int i = 0;
+
+        @Override
+        protected Element computeNext() {
+          if (i < maxIndex) {
+            return new Element() {
+              int index = i++;
+
+              /**
+               * @return the value of this vector element.
+               */
+              @Override
+              public double get() {
+                return getQuick(index);
+              }
+
+              /**
+               * @return the index of this vector element.
+               */
+              @Override
+              public int index() {
+                return index;
+              }
+
+              /**
+               * @param value Set the current element to value.
+               */
+              @Override
+              public void set(double value) {
+                throw new UnsupportedOperationException("Default operation");
+              }
+            };
+          } else {
+            return endOfData();
+          }
+        }
+      };
+    }
+
+    /**
+     * Iterates over the stored (non-zero) elements in ascending index order; each has value 1.
+     *
+     * @return An {@link java.util.Iterator} over all non-zero elements
+     */
+    @Override
+    public Iterator<Element> iterateNonZero() {
+      return new AbstractIterator<Element>() {
+        int i = 0;
+
+        @Override
+        protected Element computeNext() {
+          if (i < buffer.limit()) {
+            return new BinaryReadOnlyElement(buffer.get(i++));
+          } else {
+            return endOfData();
+          }
+        }
+      };
+    }
+
+    /**
+     * Return the value at the given index, without checking bounds
+     *
+     * @param index an int index
+     * @return 1 if {@code index} is stored, otherwise 0
+     */
+    @Override
+    public double getQuick(int index) {
+      return searchForIndex(buffer, index);
+    }
+
+    /**
+     * Return an empty vector of the same size; a {@link RandomAccessSparseVector}.
+     *
+     * @return a Vector
+     */
+    @Override
+    public Vector like() {
+      return new RandomAccessSparseVector(size());
+    }
+
+    @Override
+    public Vector like(int cardinality) {
+      return new RandomAccessSparseVector(cardinality);
+    }
+
+    /**
+     * Copy the vector into a mutable {@link RandomAccessSparseVector} for fast operations.
+     *
+     * @return a Vector
+     */
+    @Override
+    protected Vector createOptimizedCopy() {
+      return new RandomAccessSparseVector(size()).assign(this);
+    }
+
+    /**
+     * Not supported: the vector is a read-only view.
+     *
+     * @param index an int index into the receiver
+     * @param value a double value to set
+     */
+    @Override
+    public void setQuick(int index, double value) {
+      throw new UnsupportedOperationException("Read-only view");
+    }
+
+    /**
+     * Not supported: the vector is a read-only view.
+     *
+     * @param index an int index into the receiver
+     * @param increment a double value to add
+     */
+    @Override
+    public void incrementQuick(int index, double increment) {
+      throw new UnsupportedOperationException("Read-only view");
+    }
+
+    /**
+     * Return the number of stored (non-zero) indices.
+     *
+     * @return an int
+     */
+    @Override
+    public int getNumNondefaultElements() {
+      return buffer.limit();
+    }
+
+    @Override
+    public double getLookupCost() {
+      return 1;
+    }
+
+    @Override
+    public double getIteratorAdvanceCost() {
+      return 1;
+    }
+
+    @Override
+    public boolean isAddConstantTime() {
+      throw new UnsupportedOperationException("Can't add binary value");
+    }
+  }
+
+  /** Immutable element of a binary vector: a stored index whose value is always 1. */
+  public static class BinaryReadOnlyElement implements Vector.Element {
+    private final int index;
+
+    public BinaryReadOnlyElement(int index) {
+      this.index = index;
+    }
+
+    /**
+     * @return the value of this vector element (always 1).
+     */
+    @Override
+    public double get() {
+      return 1;
+    }
+
+    /**
+     * @return the index of this vector element.
+     */
+    @Override
+    public int index() {
+      return index;
+    }
+
+    /**
+     * Not supported: binary elements are read-only.
+     *
+     * @param value ignored
+     */
+    @Override
+    public void set(double value) {
+      throw new UnsupportedOperationException("Can't set binary value");
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java b/core/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
new file mode 100644
index 0000000..9028e23
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.IntIntFunction;
+
+/**
+ * Read-only matrix view whose cell values are computed on demand by an {@link IntIntFunction}.
+ */
+class FunctionalMatrixView extends AbstractMatrix {
+
+  /** View generator function; supplies the value of each (row, column) cell. */
+  private final IntIntFunction gf;
+  /** Whether {@link #like()} should create a dense (true) or sparse (false) matrix. */
+  private final boolean denseLike;
+  /** Flavor reported by {@link #getFlavor()}; fixed at construction. */
+  private final MatrixFlavor flavor;
+
+  /**
+   * Creates a view whose {@link #like()} produces sparse matrices.
+   *
+   * @param rows number of rows in the view
+   * @param columns number of columns in the view
+   * @param gf generator function
+   */
+  public FunctionalMatrixView(int rows, int columns, IntIntFunction gf) {
+    this(rows, columns, gf, false);
+  }
+
+  /**
+   * @param rows number of rows in the view
+   * @param columns number of columns in the view
+   * @param gf generator function
+   * @param denseLike whether like() should create Dense or Sparse matrix.
+   */
+  public FunctionalMatrixView(int rows, int columns, IntIntFunction gf, boolean denseLike) {
+    super(rows, columns);
+    this.gf = gf;
+    this.denseLike = denseLike;
+    flavor = new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.BLOCKIFIED, denseLike);
+  }
+
+  @Override
+  public Matrix assignColumn(int column, Vector other) {
+    throw new UnsupportedOperationException("Assignment to a matrix view not supported");
+  }
+
+  @Override
+  public Matrix assignRow(int row, Vector other) {
+    throw new UnsupportedOperationException("Assignment to a matrix view not supported");
+  }
+
+  /** Computes the cell value by calling the generator function. */
+  @Override
+  public double getQuick(int row, int column) {
+    return gf.apply(row, column);
+  }
+
+  @Override
+  public Matrix like() {
+    return like(rows, columns);
+  }
+
+  /** Returns a new {@link DenseMatrix} if this view is dense-like, else a {@link SparseMatrix}. */
+  @Override
+  public Matrix like(int rows, int columns) {
+    if (denseLike) {
+      return new DenseMatrix(rows, columns);
+    } else {
+      return new SparseMatrix(rows, columns);
+    }
+  }
+
+  @Override
+  public void setQuick(int row, int column, double value) {
+    throw new UnsupportedOperationException("Assignment to a matrix view not supported");
+  }
+
+  @Override
+  public Vector viewRow(int row) {
+    return new MatrixVectorView(this, row, 0, 0, 1, denseLike);
+  }
+
+  @Override
+  public Vector viewColumn(int column) {
+    return new MatrixVectorView(this, 0, column, 1, 0, denseLike);
+  }
+
+  @Override
+  public MatrixFlavor getFlavor() {
+    return flavor;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/IndexException.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/IndexException.java b/core/src/main/java/org/apache/mahout/math/IndexException.java
new file mode 100644
index 0000000..489d536
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/IndexException.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+/**
+ * Exception thrown when a matrix or vector is accessed at an index, or dimension,
+ * which does not logically exist in the entity.
+ */
+public class IndexException extends IllegalArgumentException {
+
+  /**
+   * Creates an exception describing an out-of-range access.
+   *
+   * @param index the offending index
+   * @param cardinality the size of the indexed dimension; the message reports the valid range
+   *        as {@code [0,cardinality)}
+   */
+  public IndexException(int index, int cardinality) {
+    super(describe(index, cardinality));
+  }
+
+  /** Builds the message text for an out-of-range index. */
+  private static String describe(int index, int cardinality) {
+    return new StringBuilder("Index ")
+        .append(index)
+        .append(" is outside allowable range of [0,")
+        .append(cardinality)
+        .append(')')
+        .toString();
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/LengthCachingVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/LengthCachingVector.java b/core/src/main/java/org/apache/mahout/math/LengthCachingVector.java
new file mode 100644
index 0000000..770ccc4
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/LengthCachingVector.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+/**
+ * Implemented by vectors that may cache their squared length. This is not a marker interface:
+ * implementors provide the cached value and a way to invalidate it.
+ */
+interface LengthCachingVector {
+  /**
+   * Gets the currently cached squared length or, if there is none, recalculates
+   * the value and returns that.
+   *
+   * @return The sum of the squares of all elements in the vector.
+   */
+  double getLengthSquared();
+
+  /**
+   * Invalidates the length cache. This should be called by all mutators of the vector.
+   */
+  void invalidateCachedLength();
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Matrices.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Matrices.java b/core/src/main/java/org/apache/mahout/math/Matrices.java
new file mode 100644
index 0000000..5d8b5c5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Matrices.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.IntIntFunction;
+
+import java.util.Random;
+
+/**
+ * Static factories for matrix views whose values come from generator functions, including
+ * seeded random (Gaussian and uniform) views, and for read-only transposed views.
+ */
+public final class Matrices {
+
+  /**
+   * Create a matrix view based on a function generator.
+   * <p>
+   * The generator needs to be deterministic (idempotent), i.e. return the same value
+   * for each combination of (row, column) argument sent to the generator's
+   * {@link IntIntFunction#apply(int, int)} call.
+   *
+   * @param rows Number of rows in a view
+   * @param columns Number of columns in a view.
+   * @param gf view generator
+   * @param denseLike type of matrix returned by {@link org.apache.mahout.math.Matrix#like()}:
+   *        dense if true, sparse otherwise.
+   * @return new matrix view.
+   */
+  public static Matrix functionalMatrixView(final int rows,
+                                            final int columns,
+                                            final IntIntFunction gf,
+                                            final boolean denseLike) {
+    return new FunctionalMatrixView(rows, columns, gf, denseLike);
+  }
+
+  /**
+   * Shorter form of {@link Matrices#functionalMatrixView(int, int,
+   * org.apache.mahout.math.function.IntIntFunction, boolean)} whose {@code like()} is sparse.
+   */
+  public static Matrix functionalMatrixView(final int rows,
+                                            final int columns,
+                                            final IntIntFunction gf) {
+    return new FunctionalMatrixView(rows, columns, gf);
+  }
+
+  /**
+   * A read-only transposed view of a matrix argument. Transposing a transposed view returns the
+   * original matrix instead of wrapping it again.
+   *
+   * @param m original matrix
+   * @return transposed view of original matrix
+   * @throws IllegalArgumentException if {@code m} is a {@link SparseColumnMatrix}
+   */
+  public static Matrix transposedView(final Matrix m) {
+
+    Preconditions.checkArgument(!(m instanceof SparseColumnMatrix));
+
+    if (m instanceof TransposedMatrixView) {
+      return ((TransposedMatrixView) m).getDelegate();
+    } else {
+      return new TransposedMatrixView(m);
+    }
+  }
+
+  /**
+   * Random Gaussian matrix view.
+   *
+   * @param seed generator seed
+   */
+  public static Matrix gaussianView(final int rows,
+                                    final int columns,
+                                    long seed) {
+    return functionalMatrixView(rows, columns, gaussianGenerator(seed), true);
+  }
+
+
+  /**
+   * Matrix view based on uniform [-1,1) distribution.
+   *
+   * @param seed generator seed
+   */
+  public static Matrix symmetricUniformView(final int rows,
+                                            final int columns,
+                                            int seed) {
+    return functionalMatrixView(rows, columns, uniformSymmetricGenerator(seed), true);
+  }
+
+  /**
+   * Matrix view based on uniform [0,1) distribution.
+   *
+   * @param seed generator seed
+   */
+  public static Matrix uniformView(final int rows,
+                                   final int columns,
+                                   int seed) {
+    return functionalMatrixView(rows, columns, uniformGenerator(seed), true);
+  }
+
+  /**
+   * Generator for a matrix populated by random Gaussian values (Gaussian matrix view).
+   * <p>
+   * Each call re-seeds a single shared {@link Random} from {@code seed} and the cell coordinates
+   * before drawing a value. Because that {@code Random} is shared, the returned function is not
+   * safe for concurrent use; create one per thread (with the same seed) if needed.
+   *
+   * @param seed The seed for the matrix.
+   * @return Gaussian {@link IntIntFunction} generating matrix view with normal values
+   */
+  public static IntIntFunction gaussianGenerator(final long seed) {
+    final Random rnd = RandomUtils.getRandom(seed);
+    return new IntIntFunction() {
+      @Override
+      public double apply(int first, int second) {
+        rnd.setSeed(seed ^ (((long) first << 32) | (second & 0xffffffffL)));
+        return rnd.nextGaussian();
+      }
+    };
+  }
+
+  /**
+   * Maps the top 53 bits of a signed 64-bit hash onto [-1, 1). {@code hash >> 11} lies in
+   * [-2^52, 2^52), and multiplying by 2^-52 is exact in double precision, so 1.0 is never produced.
+   * (Dividing the raw hash by 2^64 would only cover [-0.5, 0.5).)
+   */
+  private static final double UNIFORM_SCALE = 0x1.0p-52;
+
+  /**
+   * Uniform [-1,1) matrix generator function.
+   * <p>
+   * WARNING: to keep things performant, it is stateful and so not thread-safe.
+   * You'd need to create a copy per thread (with same seed) if shared between threads.
+   *
+   * @param seed - random seed initializer
+   * @return Uniform {@link IntIntFunction} generator
+   */
+  public static IntIntFunction uniformSymmetricGenerator(final int seed) {
+    return new IntIntFunction() {
+      /** Reused scratch space holding the hash key; this is the state that makes it thread-unsafe. */
+      private final byte[] data = new byte[8];
+
+      @Override
+      public double apply(int row, int column) {
+        // Pack (row, column) into one long and store it low byte first as the hash key.
+        long d = ((long) row << Integer.SIZE) | (column & 0xffffffffL);
+        for (int i = 0; i < 8; i++, d >>>= 8) {
+          data[i] = (byte) d;
+        }
+        long hash = MurmurHash.hash64A(data, seed);
+        // Arithmetic shift keeps the sign, so the result is symmetric around zero.
+        return (hash >> 11) * UNIFORM_SCALE;
+      }
+    };
+  }
+
+  /**
+   * Uniform [0,1) matrix generator function, obtained by mapping
+   * {@link #uniformSymmetricGenerator(int)} through {@code x -> (x + 1) / 2}. It inherits that
+   * generator's lack of thread safety.
+   *
+   * @param seed generator seed
+   */
+  public static IntIntFunction uniformGenerator(final int seed) {
+    return Functions.chain(new DoubleFunction() {
+      @Override
+      public double apply(double x) {
+        return (x + 1.0) / 2.0;
+      }
+    }, uniformSymmetricGenerator(seed));
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Matrix.java b/core/src/main/java/org/apache/mahout/math/Matrix.java
new file mode 100644
index 0000000..57fab78
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Matrix.java
@@ -0,0 +1,413 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.VectorFunction;
+
+import java.util.Map;
+
+/** The basic matrix interface, including numerous convenience functions. */
+public interface Matrix extends Cloneable, VectorIterable {
+
+  /** @return a formatted String suitable for output */
+  String asFormatString();
+
+  /**
+   * Assign the value to all elements of the receiver
+   *
+   * @param value a double value
+   * @return the modified receiver
+   */
+  Matrix assign(double value);
+
+  /**
+   * Assign the values to the receiver
+   *
+   * @param values a double[][] of values
+   * @return the modified receiver
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix assign(double[][] values);
+
+  /**
+   * Assign the values of the other matrix to the receiver
+   *
+   * @param other a Matrix
+   * @return the modified receiver
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix assign(Matrix other);
+
+  /**
+   * Apply the function to each element of the receiver
+   *
+   * @param function a DoubleFunction to apply
+   * @return the modified receiver
+   */
+  Matrix assign(DoubleFunction function);
+
+  /**
+   * Apply the function to each element of the receiver and the corresponding element of the other argument
+   *
+   * @param other a Matrix containing the second arguments to the function
+   * @param function a DoubleDoubleFunction to apply
+   * @return the modified receiver
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix assign(Matrix other, DoubleDoubleFunction function);
+
+  /**
+   * Assign the other vector values to the column of the receiver
+   *
+   * @param column the int column to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix assignColumn(int column, Vector other);
+
+  /**
+   * Assign the other vector values to the row of the receiver
+   *
+   * @param row the int row to assign
+   * @param other a Vector
+   * @return the modified receiver
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix assignRow(int row, Vector other);
+
+  /**
+   * Collects the results of a function applied to each row of a matrix.
+   * @param f The function to be applied to each row.
+   * @return The vector of results.
+   */
+  Vector aggregateRows(VectorFunction f);
+
+  /**
+   * Collects the results of a function applied to each column of a matrix.
+   * @param f The function to be applied to each column.
+   * @return The vector of results.
+   */
+  Vector aggregateColumns(VectorFunction f);
+
+  /**
+   * Collects the results of a function applied to each element of a matrix and then
+   * aggregated.
+   * @param combiner A function that combines the results of the mapper.
+   * @param mapper A function to apply to each element.
+   * @return The result.
+   */
+  double aggregate(DoubleDoubleFunction combiner, DoubleFunction mapper);
+
+  /**
+   * @return The number of columns in the matrix.
+   */
+  int columnSize();
+
+  /**
+   * @return The number of rows in the matrix.
+   */
+  int rowSize();
+
+  /**
+   * Return a copy of the recipient
+   *
+   * @return a new Matrix
+   */
+  Matrix clone();
+
+  /**
+   * Returns the matrix determinant using the Laplace (cofactor) expansion theorem
+   *
+   * @return the determinant of the receiver
+   */
+  double determinant();
+
+  /**
+   * Return a new matrix containing the values of the recipient divided by the argument
+   *
+   * @param x a double value
+   * @return a new Matrix
+   */
+  Matrix divide(double x);
+
+  /**
+   * Return the value at the given indexes
+   *
+   * @param row an int row index
+   * @param column an int column index
+   * @return the double at the index
+   * @throws IndexException if the index is out of bounds
+   */
+  double get(int row, int column);
+
+  /**
+   * Return the value at the given indexes, without checking bounds
+   *
+   * @param row an int row index
+   * @param column an int column index
+   * @return the double at the index
+   */
+  double getQuick(int row, int column);
+
+  /**
+   * Return an empty matrix of the same underlying class as the receiver
+   *
+   * @return a Matrix
+   */
+  Matrix like();
+
+  /**
+   * Returns an empty matrix of the same underlying class as the receiver and of the specified size.
+   *
+   * @param rows the int number of rows
+   * @param columns the int number of columns
+   * @return a new, empty Matrix of the requested size
+   */
+  Matrix like(int rows, int columns);
+
+  /**
+   * Return a new matrix containing the element by element difference of the recipient and the argument
+   *
+   * @param x a Matrix
+   * @return a new Matrix
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix minus(Matrix x);
+
+  /**
+   * Return a new matrix containing the sum of each value of the recipient and the argument
+   *
+   * @param x a double
+   * @return a new Matrix
+   */
+  Matrix plus(double x);
+
+  /**
+   * Return a new matrix containing the element by element sum of the recipient and the argument
+   *
+   * @param x a Matrix
+   * @return a new Matrix
+   * @throws CardinalityException if the cardinalities differ
+   */
+  Matrix plus(Matrix x);
+
+  /**
+   * Set the value at the given index
+   *
+   * @param row an int row index into the receiver
+   * @param column an int column index into the receiver
+   * @param value a double value to set
+   * @throws IndexException if the index is out of bounds
+   */
+  void set(int row, int column, double value);
+
+  /**
+   * Sets the values of the row at the given row index
+   *
+   * @param row an int row index into the receiver
+   * @param data a double[] array of row data
+   */
+  void set(int row, double[] data);
+
+  /**
+   * Set the value at the given index, without checking bounds
+   *
+   * @param row an int row index into the receiver
+   * @param column an int column index into the receiver
+   * @param value a double value to set
+   */
+  void setQuick(int row, int column, double value);
+
+  /**
+   * Return the number of values in the recipient
+   *
+   * @return an int[2] containing [row, column] count
+   */
+  int[] getNumNondefaultElements();
+
+  /**
+   * Return a new matrix containing the product of each value of the recipient and the argument
+   *
+   * @param x a double argument
+   * @return a new Matrix
+   */
+  Matrix times(double x);
+
+  /**
+   * Return a new matrix containing the product of the recipient and the argument
+   *
+   * @param x a Matrix argument
+   * @return a new Matrix
+   * @throws CardinalityException if the cardinalities are incompatible
+   */
+  Matrix times(Matrix x);
+
+  /**
+   * Return a new matrix that is the transpose of the receiver
+   *
+   * @return the transpose
+   */
+  Matrix transpose();
+
+  /**
+   * Return the sum of all the elements of the receiver
+   *
+   * @return a double
+   */
+  double zSum();
+
+  /**
+   * Return a map of the current column label bindings of the receiver
+   *
+   * @return a {@code Map<String, Integer>}
+   */
+  Map<String, Integer> getColumnLabelBindings();
+
+  /**
+   * Return a map of the current row label bindings of the receiver
+   *
+   * @return a {@code Map<String, Integer>}
+   */
+  Map<String, Integer> getRowLabelBindings();
+
+  /**
+   * Sets a map of column label bindings in the receiver
+   *
+   * @param bindings a {@code Map<String, Integer>} of label bindings
+   */
+  void setColumnLabelBindings(Map<String, Integer> bindings);
+
+  /**
+   * Sets a map of row label bindings in the receiver
+   *
+   * @param bindings a {@code Map<String, Integer>} of label bindings
+   */
+  void setRowLabelBindings(Map<String, Integer> bindings);
+
+  /**
+   * Return the value at the given labels
+   *
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @return the double at the index
+   *
+   * @throws IndexException if the index is out of bounds
+   */
+  double get(String rowLabel, String columnLabel);
+
+  /**
+   * Set the value at the given labels
+   *
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @param value a double value to set
+   * @throws IndexException if the index is out of bounds
+   */
+  void set(String rowLabel, String columnLabel, double value);
+
+  /**
+   * Set the value at the given index, updating the row and column label bindings
+   *
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @param row an int row index
+   * @param column an int column index
+   * @param value a double value
+   */
+  void set(String rowLabel, String columnLabel, int row, int column, double value);
+
+  /**
+   * Sets the row values at the given row label
+   *
+   * @param rowLabel a String row label
+   * @param rowData a double[] array of row data
+   */
+  void set(String rowLabel, double[] rowData);
+
+  /**
+   * Sets the row values at the given row index and updates the row labels
+   *
+   * @param rowLabel the String row label
+   * @param row an int the row index
+   * @param rowData a double[] array of row data
+   */
+  void set(String rowLabel, int row, double[] rowData);
+
+  /*
+   * Need stories for these but keeping them here for now.
+   *
+   */
+  // void getNonZeros(IntArrayList jx, DoubleArrayList values);
+  // void foreachNonZero(IntDoubleFunction f);
+  // double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map);
+  // double aggregate(Matrix other, DoubleDoubleFunction aggregator,
+  // DoubleDoubleFunction map);
+  // NewMatrix assign(Matrix y, DoubleDoubleFunction function, IntArrayList
+  // nonZeroIndexes);
+
+  /**
+   * Return a view into part of a matrix. Changes to the view will change the
+   * original matrix.
+   *
+   * @param offset an int[2] offset into the receiver
+   * @param size the int[2] size of the desired result
+   * @return a matrix that shares storage with part of the original matrix.
+   * @throws CardinalityException if the length is greater than the cardinality of the receiver
+   * @throws IndexException if the offset is negative or the offset+length is outside of the receiver
+   */
+  Matrix viewPart(int[] offset, int[] size);
+
+  /**
+   * Return a view into part of a matrix. Changes to the view will change the
+   * original matrix.
+   *
+   * @param rowOffset The first row of the view
+   * @param rowsRequested The number of rows in the view
+   * @param columnOffset The first column in the view
+   * @param columnsRequested The number of columns in the view
+   * @return a matrix that shares storage with part of the original matrix.
+   * @throws CardinalityException if the length is greater than the cardinality of the receiver
+   * @throws IndexException if the offset is negative or the offset+length is outside of the
+   * receiver
+   */
+  Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested);
+
+  /**
+   * Return a reference to a row. Changes to the view will change the original matrix.
+   * @param row The index of the row to return.
+   * @return A vector that shares storage with the original.
+   */
+  Vector viewRow(int row);
+
+  /**
+   * Return a reference to a column. Changes to the view will change the original matrix.
+   * @param column The index of the column to return.
+   * @return A vector that shares storage with the original.
+   */
+  Vector viewColumn(int column);
+
+  /**
+   * Returns a reference to the diagonal of a matrix. Changes to the view will change
+   * the original matrix.
+   * @return A vector that shares storage with the original matrix.
+   */
+  Vector viewDiagonal();
+
+  /**
+   * Get matrix structural flavor (operations performance hints). This is an optional operation
+   * and may throw {@link java.lang.UnsupportedOperationException}.
+   *
+   * @return the structural flavor of this matrix
+   */
+  MatrixFlavor getFlavor();
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MatrixSlice.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MatrixSlice.java b/core/src/main/java/org/apache/mahout/math/MatrixSlice.java
new file mode 100644
index 0000000..51378c1
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MatrixSlice.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+/**
+ * A {@link Vector} paired with the index it occupies in some matrix (for example a row
+ * number). The values are held by the wrapped vector passed to {@link DelegatingVector}.
+ */
+public class MatrixSlice extends DelegatingVector {
+  /** Position of this slice within its matrix; fixed at construction. */
+  private final int index;
+
+  /**
+   * @param v the vector holding the slice's values
+   * @param index the slice's position within its matrix
+   */
+  public MatrixSlice(Vector v, int index) {
+    super(v);
+    this.index = index;
+  }
+
+  /** @return the wrapped vector, as returned by {@link #getVector()} */
+  public Vector vector() {
+    return getVector();
+  }
+
+  /** @return the slice's position within its matrix */
+  public int index() {
+    return index;
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MatrixTimesOps.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MatrixTimesOps.java b/core/src/main/java/org/apache/mahout/math/MatrixTimesOps.java
new file mode 100644
index 0000000..30d2afb
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MatrixTimesOps.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+/**
+ * Optional mix-in for concrete {@link Matrix} implementations that provide their own
+ * optimized matrix multiplication routines.
+ */
+public interface MatrixTimesOps {
+
+  /**
+   * Multiplies this matrix on the right by {@code that}.
+   *
+   * @param that the right-hand operand
+   * @return the matrix product {@code this * that}
+   */
+  Matrix timesRight(Matrix that);
+
+  /**
+   * Multiplies this matrix on the left by {@code that}.
+   *
+   * @param that the left-hand operand
+   * @return the matrix product {@code that * this}
+   */
+  Matrix timesLeft(Matrix that);
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MatrixVectorView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MatrixVectorView.java b/core/src/main/java/org/apache/mahout/math/MatrixVectorView.java
new file mode 100644
index 0000000..6ad44b5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MatrixVectorView.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a virtual vector that is really a row or column or diagonal of a matrix.
+ * <p>
+ * Element {@code i} of the view is the matrix cell
+ * {@code (row + i * rowStride, column + i * columnStride)}; reads and writes go straight
+ * through to the underlying matrix.
+ */
+public class MatrixVectorView extends AbstractVector {
+  private Matrix matrix;
+  private int row;
+  private int column;
+  private int rowStride;
+  private int columnStride;
+  private boolean isDense = true;
+
+  /**
+   * Creates a view and records whether it should report itself as dense.
+   *
+   * @param matrix the underlying matrix
+   * @param row row index of the first element of the view
+   * @param column column index of the first element of the view
+   * @param rowStride amount added to the row index for each successive element
+   * @param columnStride amount added to the column index for each successive element
+   * @param isDense the value {@link #isDense()} will return
+   * @throws IndexException if (row, column) lies outside {@code matrix}
+   */
+  public MatrixVectorView(Matrix matrix, int row, int column, int rowStride, int columnStride, boolean isDense) {
+    this(matrix, row, column, rowStride, columnStride);
+    this.isDense = isDense;
+  }
+
+  /**
+   * Creates a view starting at (row, column) and advancing by the given strides.
+   *
+   * @param matrix the underlying matrix
+   * @param row row index of the first element of the view
+   * @param column column index of the first element of the view
+   * @param rowStride amount added to the row index for each successive element
+   * @param columnStride amount added to the column index for each successive element
+   * @throws IndexException if (row, column) lies outside {@code matrix}
+   */
+  public MatrixVectorView(Matrix matrix, int row, int column, int rowStride, int columnStride) {
+    super(viewSize(matrix, row, column, rowStride, columnStride));
+    if (row < 0 || row >= matrix.rowSize()) {
+      throw new IndexException(row, matrix.rowSize());
+    }
+    if (column < 0 || column >= matrix.columnSize()) {
+      throw new IndexException(column, matrix.columnSize());
+    }
+
+    this.matrix = matrix;
+    this.row = row;
+    this.column = column;
+    this.rowStride = rowStride;
+    this.columnStride = columnStride;
+  }
+
+  /**
+   * Number of elements in the view: how many steps fit before the walk leaves the matrix.
+   * A zero stride means that coordinate stays fixed and does not limit the length.
+   */
+  private static int viewSize(Matrix matrix, int row, int column, int rowStride, int columnStride) {
+    if (rowStride != 0 && columnStride != 0) {
+      int n1 = stepsWithin(matrix.numRows(), row, rowStride);
+      int n2 = stepsWithin(matrix.numCols(), column, columnStride);
+      return Math.min(n1, n2);
+    } else if (rowStride > 0) {
+      return stepsWithin(matrix.numRows(), row, rowStride);
+    } else {
+      return stepsWithin(matrix.numCols(), column, columnStride);
+    }
+  }
+
+  /**
+   * Counts the indices {@code start, start + stride, start + 2 * stride, ...} that are below
+   * {@code extent}. For a positive stride this is ceil((extent - start) / stride); plain
+   * (floor) division would drop the last element whenever a stride above 1 does not divide the
+   * span, e.g. 5 rows with stride 2 visit rows 0, 2 and 4, but 5 / 2 == 2. Other combinations
+   * keep plain integer division, so a zero stride still throws ArithmeticException.
+   */
+  private static int stepsWithin(int extent, int start, int stride) {
+    int span = extent - start;
+    if (stride > 0 && span > 0) {
+      // Equivalent to ceil(span / stride) without the int overflow of (span + stride - 1).
+      return (span - 1) / stride + 1;
+    }
+    return span / stride;
+  }
+
+  /**
+   * @return true iff the {@link Vector} implementation should be considered
+   * dense -- that it explicitly represents every value
+   */
+  @Override
+  public boolean isDense() {
+    return isDense;
+  }
+
+  /**
+   * @return true iff {@link Vector} should be considered to be iterable in
+   * index order in an efficient way. In particular this implies that {@link #iterator()} and
+   * {@link #iterateNonZero()} return elements in ascending order by index.
+   */
+  @Override
+  public boolean isSequentialAccess() {
+    return true;
+  }
+
+  /**
+   * Iterates over all elements <p>
+   * NOTE: Implementations may choose to reuse the Element returned
+   * for performance reasons, so if you need a copy of it, you should call {@link #getElement(int)} for
+   * the given index
+   *
+   * @return An {@link java.util.Iterator} over all elements
+   */
+  @Override
+  public Iterator<Element> iterator() {
+    // A single element object is reused for every position.
+    final LocalElement r = new LocalElement(0);
+    return new Iterator<Element>() {
+      private int i;
+
+      @Override
+      public boolean hasNext() {
+        return i < size();
+      }
+
+      @Override
+      public Element next() {
+        if (i >= size()) {
+          throw new NoSuchElementException();
+        }
+        r.index = i++;
+        return r;
+      }
+
+      @Override
+      public void remove() {
+        throw new UnsupportedOperationException("Can't remove from a view");
+      }
+    };
+  }
+
+  /**
+   * Iterates over all non-zero elements. <p>
+   * NOTE: Implementations may choose to reuse the Element
+   * returned for performance reasons, so if you need a copy of it, you should call {@link
+   * #getElement(int)} for the given index
+   *
+   * @return An {@link java.util.Iterator} over all non-zero elements
+   */
+  @Override
+  public Iterator<Element> iterateNonZero() {
+
+    return new Iterator<Element>() {
+      class NonZeroElement implements Element {
+        int index;
+
+        @Override
+        public double get() {
+          return getQuick(index);
+        }
+
+        @Override
+        public int index() {
+          return index;
+        }
+
+        @Override
+        public void set(double value) {
+          invalidateCachedLength();
+          setQuick(index, value);
+        }
+      }
+
+      private final NonZeroElement element = new NonZeroElement();
+      // index: position last returned by next(); lookAheadIndex: next non-zero position found
+      // by hasNext(). They are equal when no look-ahead is pending.
+      private int index = -1;
+      private int lookAheadIndex = -1;
+
+      @Override
+      public boolean hasNext() {
+        if (lookAheadIndex == index) { // User calls hasNext() after a next()
+          lookAhead();
+        } // else user called hasNext() repeatedly.
+        return lookAheadIndex < size();
+      }
+
+      private void lookAhead() {
+        lookAheadIndex++;
+        while (lookAheadIndex < size() && getQuick(lookAheadIndex) == 0.0) {
+          lookAheadIndex++;
+        }
+      }
+
+      @Override
+      public Element next() {
+        if (lookAheadIndex == index) { // If user called next() without checking hasNext().
+          lookAhead();
+        }
+
+        index = lookAheadIndex;
+
+        if (index >= size()) { // If the end is reached.
+          throw new NoSuchElementException();
+        }
+
+        element.index = index;
+        return element;
+      }
+
+      @Override
+      public void remove() {
+        throw new UnsupportedOperationException();
+      }
+    };
+  }
+
+  /**
+   * Return the value at the given index, without checking bounds
+   *
+   * @param index an int index
+   * @return the double at the index
+   */
+  @Override
+  public double getQuick(int index) {
+    return matrix.getQuick(row + rowStride * index, column + columnStride * index);
+  }
+
+  /**
+   * Return an empty vector of the same underlying class as the receiver
+   *
+   * @return a Vector
+   */
+  @Override
+  public Vector like() {
+    return matrix.like(size(), 1).viewColumn(0);
+  }
+
+  @Override
+  public Vector like(int cardinality) {
+    return matrix.like(cardinality, 1).viewColumn(0);
+  }
+
+  /**
+   * Set the value at the given index, without checking bounds
+   *
+   * @param index an int index into the receiver
+   * @param value a double value to set
+   */
+  @Override
+  public void setQuick(int index, double value) {
+    matrix.setQuick(row + rowStride * index, column + columnStride * index, value);
+  }
+
+  /**
+   * Return the number of values in the recipient
+   *
+   * @return an int
+   */
+  @Override
+  public int getNumNondefaultElements() {
+    return size();
+  }
+
+  @Override
+  public double getLookupCost() {
+    // TODO: what is a genuine value here?
+    return 1;
+  }
+
+  @Override
+  public double getIteratorAdvanceCost() {
+    // TODO: what is a genuine value here?
+    return 1;
+  }
+
+  @Override
+  public boolean isAddConstantTime() {
+    // TODO: what is a genuine value here?
+    return true;
+  }
+
+  @Override
+  protected Matrix matrixLike(int rows, int columns) {
+    return matrix.like(rows, columns);
+  }
+
+  /**
+   * Copies the view. The underlying matrix is cloned too, so writes through the copy do not
+   * reach the original matrix.
+   */
+  @Override
+  public Vector clone() {
+    MatrixVectorView r = (MatrixVectorView) super.clone();
+    r.matrix = matrix.clone();
+    r.row = row;
+    r.column = column;
+    r.rowStride = rowStride;
+    r.columnStride = columnStride;
+    return r;
+  }
+
+  /**
+   * Used internally by assign() to update multiple indices and values at once.
+   * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+   * <p>
+   * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+   *
+   * @param updates a mapping of indices to values to merge in the vector.
+   */
+  @Override
+  public void mergeUpdates(OrderedIntDoubleMapping updates) {
+    int[] indices = updates.getIndices();
+    double[] values = updates.getValues();
+    for (int i = 0; i < updates.getNumMappings(); ++i) {
+      matrix.setQuick(row + rowStride * indices[i], column + columnStride * indices[i], values[i]);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MatrixView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MatrixView.java b/core/src/main/java/org/apache/mahout/math/MatrixView.java
new file mode 100644
index 0000000..951515b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MatrixView.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+
+/**
+ * Implements a subset view of a Matrix: a rectangular window onto another matrix. Reads and
+ * writes are forwarded to the underlying matrix, shifted by the view's offset.
+ */
+public class MatrixView extends AbstractMatrix {
+
+  private Matrix matrix;
+
+  // the {row, column} offset of this view's (0, 0) cell in the underlying Matrix
+  private int[] offset;
+
+  /**
+   * Construct a view of the matrix with given offset and cardinality
+   *
+   * @param matrix an underlying Matrix
+   * @param offset the int[2] offset into the underlying matrix; the array is copied
+   * @param size the int[2] size of the view
+   * @throws IndexException if the requested region does not lie inside {@code matrix}
+   */
+  public MatrixView(Matrix matrix, int[] offset, int[] size) {
+    super(size[ROW], size[COL]);
+    int rowOffset = offset[ROW];
+    if (rowOffset < 0) {
+      throw new IndexException(rowOffset, matrix.rowSize());
+    }
+
+    int rowsRequested = size[ROW];
+    if (rowOffset + rowsRequested > matrix.rowSize()) {
+      throw new IndexException(rowOffset + rowsRequested, matrix.rowSize());
+    }
+
+    int columnOffset = offset[COL];
+    if (columnOffset < 0) {
+      throw new IndexException(columnOffset, matrix.columnSize());
+    }
+
+    int columnsRequested = size[COL];
+    if (columnOffset + columnsRequested > matrix.columnSize()) {
+      throw new IndexException(columnOffset + columnsRequested, matrix.columnSize());
+    }
+    this.matrix = matrix;
+    // Defensive copy: the view must not move if the caller later mutates its array.
+    this.offset = offset.clone();
+  }
+
+  @Override
+  public Matrix clone() {
+    MatrixView clone = (MatrixView) super.clone();
+    clone.matrix = matrix.clone();
+    clone.offset = offset.clone();
+    return clone;
+  }
+
+  @Override
+  public double getQuick(int row, int column) {
+    return matrix.getQuick(offset[ROW] + row, offset[COL] + column);
+  }
+
+  @Override
+  public Matrix like() {
+    return matrix.like(rowSize(), columnSize());
+  }
+
+  @Override
+  public Matrix like(int rows, int columns) {
+    return matrix.like(rows, columns);
+  }
+
+  @Override
+  public void setQuick(int row, int column, double value) {
+    matrix.setQuick(offset[ROW] + row, offset[COL] + column, value);
+  }
+
+  @Override
+  public int[] getNumNondefaultElements() {
+    return new int[]{rowSize(), columnSize()};
+  }
+
+  @Override
+  public Matrix viewPart(int[] offset, int[] size) {
+    if (offset[ROW] < 0) {
+      throw new IndexException(offset[ROW], 0);
+    }
+    if (offset[ROW] + size[ROW] > rowSize()) {
+      throw new IndexException(offset[ROW] + size[ROW], rowSize());
+    }
+    if (offset[COL] < 0) {
+      throw new IndexException(offset[COL], 0);
+    }
+    if (offset[COL] + size[COL] > columnSize()) {
+      throw new IndexException(offset[COL] + size[COL], columnSize());
+    }
+    // Compose offsets so the new view addresses the underlying matrix directly.
+    int[] origin = this.offset.clone();
+    origin[ROW] += offset[ROW];
+    origin[COL] += offset[COL];
+    return new MatrixView(matrix, origin, size);
+  }
+
+  /**
+   * Assigns {@code other} to a column of this view.
+   *
+   * @throws CardinalityException if {@code other} does not have {@code rowSize()} elements
+   * @throws IndexException if {@code column} is outside this view; without this check the write
+   *     would silently land in the underlying matrix outside the view
+   */
+  @Override
+  public Matrix assignColumn(int column, Vector other) {
+    if (rowSize() != other.size()) {
+      throw new CardinalityException(rowSize(), other.size());
+    }
+    if (column < 0 || column >= columnSize()) {
+      throw new IndexException(column, columnSize());
+    }
+    for (int row = 0; row < rowSize(); row++) {
+      matrix.setQuick(row + offset[ROW], column + offset[COL], other.getQuick(row));
+    }
+    return this;
+  }
+
+  /**
+   * Assigns {@code other} to a row of this view.
+   *
+   * @throws CardinalityException if {@code other} does not have {@code columnSize()} elements
+   * @throws IndexException if {@code row} is outside this view; without this check the write
+   *     would silently land in the underlying matrix outside the view
+   */
+  @Override
+  public Matrix assignRow(int row, Vector other) {
+    if (columnSize() != other.size()) {
+      throw new CardinalityException(columnSize(), other.size());
+    }
+    if (row < 0 || row >= rowSize()) {
+      throw new IndexException(row, rowSize());
+    }
+    for (int col = 0; col < columnSize(); col++) {
+      matrix.setQuick(row + offset[ROW], col + offset[COL], other.getQuick(col));
+    }
+    return this;
+  }
+
+  @Override
+  public Vector viewColumn(int column) {
+    if (column < 0 || column >= columnSize()) {
+      throw new IndexException(column, columnSize());
+    }
+    return matrix.viewColumn(column + offset[COL]).viewPart(offset[ROW], rowSize());
+  }
+
+  @Override
+  public Vector viewRow(int row) {
+    if (row < 0 || row >= rowSize()) {
+      throw new IndexException(row, rowSize());
+    }
+    return matrix.viewRow(row + offset[ROW]).viewPart(offset[COL], columnSize());
+  }
+
+  @Override
+  public MatrixFlavor getFlavor() {
+    return matrix.getFlavor();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MurmurHash.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MurmurHash.java b/core/src/main/java/org/apache/mahout/math/MurmurHash.java
new file mode 100644
index 0000000..13f3a07
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MurmurHash.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.primitives.Ints;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * <p>This is a very fast, non-cryptographic hash suitable for general hash-based
+ * lookup. It was originally published at http://murmurhash.googlepages.com/; the reference
+ * implementation is now maintained at https://github.com/aappleby/smhasher.
+ * </p>
+ * <p>The C version of MurmurHash 2.0 found at that site was ported
+ * to Java by Andrzej Bialecki (ab at getopt org).</p>
+ * <p>The {@link ByteBuffer} variants read from the buffer's current position to its limit and
+ * leave the position at the limit; the buffer's byte order is restored before returning.</p>
+ */
+public final class MurmurHash {
+
+ // Static utility class; not instantiable.
+ private MurmurHash() {}
+
+ /**
+ * Hashes an int.
+ * The int is first turned into the 4-byte array produced by Guava's
+ * {@code Ints.toByteArray} (big-endian), and those bytes are hashed.
+ * @param data The int to hash.
+ * @param seed The seed for the hash.
+ * @return The 32 bit hash of the bytes in question.
+ */
+ public static int hash(int data, int seed) {
+ return hash(ByteBuffer.wrap(Ints.toByteArray(data)), seed);
+ }
+
+ /**
+ * Hashes bytes in an array.
+ * @param data The bytes to hash.
+ * @param seed The seed for the hash.
+ * @return The 32 bit hash of the bytes in question.
+ */
+ public static int hash(byte[] data, int seed) {
+ return hash(ByteBuffer.wrap(data), seed);
+ }
+
+ /**
+ * Hashes bytes in part of an array.
+ * @param data The data to hash.
+ * @param offset Where to start munging.
+ * @param length How many bytes to process.
+ * @param seed The seed to start with.
+ * @return The 32-bit hash of the data in question.
+ */
+ public static int hash(byte[] data, int offset, int length, int seed) {
+ return hash(ByteBuffer.wrap(data, offset, length), seed);
+ }
+
+ /**
+ * Hashes the bytes in a buffer from the current position to the limit.
+ * The buffer is consumed: on return its position equals its limit. Its byte order is
+ * temporarily switched to little-endian and restored before returning.
+ * @param buf The bytes to hash.
+ * @param seed The seed for the hash.
+ * @return The 32 bit murmur hash of the bytes in the buffer.
+ */
+ public static int hash(ByteBuffer buf, int seed) {
+ // save byte order for later restoration
+ ByteOrder byteOrder = buf.order();
+ // The algorithm reads its input as little-endian words regardless of the caller's order.
+ buf.order(ByteOrder.LITTLE_ENDIAN);
+
+ // MurmurHash2 multiplier and shift constants.
+ int m = 0x5bd1e995;
+ int r = 24;
+
+ // Mix the input length into the seed.
+ int h = seed ^ buf.remaining();
+
+ // Body: mix in one 4-byte word at a time (>>> is used as Java has no unsigned int shift).
+ while (buf.remaining() >= 4) {
+ int k = buf.getInt();
+
+ k *= m;
+ k ^= k >>> r;
+ k *= m;
+
+ h *= m;
+ h ^= k;
+ }
+
+ // Tail: fold in the last 1-3 bytes as a zero-padded little-endian int.
+ if (buf.remaining() > 0) {
+ ByteBuffer finish = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
+ // for big-endian version, use this first:
+ // finish.position(4-buf.remaining());
+ finish.put(buf).rewind();
+ h ^= finish.getInt();
+ h *= m;
+ }
+
+ // Final mixing (avalanche) step.
+ h ^= h >>> 13;
+ h *= m;
+ h ^= h >>> 15;
+
+ buf.order(byteOrder);
+ return h;
+ }
+
+
+ /**
+ * Hashes all bytes in an array with the 64-bit MurmurHash2 variant (MurmurHash64A).
+ * @param data The bytes to hash.
+ * @param seed The seed for the hash; sign-extended to 64 bits before use.
+ * @return The 64 bit hash of the bytes in question.
+ */
+ public static long hash64A(byte[] data, int seed) {
+ return hash64A(ByteBuffer.wrap(data), seed);
+ }
+
+ /**
+ * Hashes bytes in part of an array with the 64-bit MurmurHash64A variant.
+ * @param data The data to hash.
+ * @param offset Where to start reading.
+ * @param length How many bytes to process.
+ * @param seed The seed for the hash; sign-extended to 64 bits before use.
+ * @return The 64 bit hash of the data in question.
+ */
+ public static long hash64A(byte[] data, int offset, int length, int seed) {
+ return hash64A(ByteBuffer.wrap(data, offset, length), seed);
+ }
+
+ /**
+ * Hashes the bytes in a buffer from the current position to the limit with the 64-bit
+ * MurmurHash64A variant. The buffer is consumed (position ends at the limit) and its byte
+ * order is restored before returning.
+ * @param buf The bytes to hash.
+ * @param seed The seed for the hash; sign-extended to 64 bits before use.
+ * @return The 64 bit hash of the bytes in the buffer.
+ */
+ public static long hash64A(ByteBuffer buf, int seed) {
+ ByteOrder byteOrder = buf.order();
+ buf.order(ByteOrder.LITTLE_ENDIAN);
+
+ // MurmurHash64A multiplier and shift constants.
+ long m = 0xc6a4a7935bd1e995L;
+ int r = 47;
+
+ // Mix the input length (times m) into the seed.
+ long h = seed ^ (buf.remaining() * m);
+
+ // Body: mix in one 8-byte word at a time.
+ while (buf.remaining() >= 8) {
+ long k = buf.getLong();
+
+ k *= m;
+ k ^= k >>> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ // Tail: fold in the last 1-7 bytes as a zero-padded little-endian long.
+ if (buf.remaining() > 0) {
+ ByteBuffer finish = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);
+ // for big-endian version, do this first:
+ // finish.position(8-buf.remaining());
+ finish.put(buf).rewind();
+ h ^= finish.getLong();
+ h *= m;
+ }
+
+ // Final mixing (avalanche) step.
+ h ^= h >>> r;
+ h *= m;
+ h ^= h >>> r;
+
+ buf.order(byteOrder);
+ return h;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/MurmurHash3.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/MurmurHash3.java b/core/src/main/java/org/apache/mahout/math/MurmurHash3.java
new file mode 100644
index 0000000..bd0bb6b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/MurmurHash3.java
@@ -0,0 +1,84 @@
+/*
+ * This code is public domain.
+ *
+ * The MurmurHash3 algorithm was created by Austin Appleby and put into the public domain.
+ * See http://code.google.com/p/smhasher/
+ *
+ * This java port was authored by
+ * Yonik Seeley and was placed into the public domain per
+ * https://github.com/yonik/java_util/blob/master/src/util/hash/MurmurHash3.java.
+ */
+
+package org.apache.mahout.math;
+
/**
 * Java port of the 32-bit x86 variant of Austin Appleby's MurmurHash3.
 * <p>
 * This produces exactly the same hash values as the final C++ version of MurmurHash3 and is thus
 * suitable for producing the same hash values across platforms.
 * <p>
 * The 32 bit x86 version of this hash should be the fastest variant for relatively short keys like ids.
 * <p>
 * Note - The x86 and x64 versions do <em>not</em> produce the same results, as the
 * algorithms are optimized for their respective platforms.
 * <p>
 * See also http://github.com/yonik/java_util for future updates to this file.
 */
public final class MurmurHash3 {

  /** Block-mixing multiplier applied before the rotate. */
  private static final int C1 = 0xcc9e2d51;
  /** Block-mixing multiplier applied after the rotate. */
  private static final int C2 = 0x1b873593;

  private MurmurHash3() {}

  /**
   * Returns the MurmurHash3_x86_32 hash of {@code len} bytes of {@code data} starting at {@code offset}.
   *
   * @param data the backing array
   * @param offset index of the first byte to hash
   * @param len number of bytes to hash
   * @param seed the hash seed
   * @return the 32-bit hash value
   */
  public static int murmurhash3x8632(byte[] data, int offset, int len, int seed) {
    int h1 = seed;
    int roundedEnd = offset + (len & 0xfffffffc); // round down to a whole number of 4-byte blocks

    for (int i = offset; i < roundedEnd; i += 4) {
      // Little-endian load; the top byte needs no mask because the << 24 discards its sign extension.
      int k1 = (data[i] & 0xff)
          | ((data[i + 1] & 0xff) << 8)
          | ((data[i + 2] & 0xff) << 16)
          | (data[i + 3] << 24);
      h1 ^= mixK1(k1);
      h1 = Integer.rotateLeft(h1, 13);
      h1 = h1 * 5 + 0xe6546b64;
    }

    // Tail: the up to three bytes that do not fill a whole block, combined little-endian.
    int k1 = 0;
    switch (len & 0x03) {
      case 3:
        k1 = (data[roundedEnd + 2] & 0xff) << 16;
        // fallthrough
      case 2:
        k1 |= (data[roundedEnd + 1] & 0xff) << 8;
        // fallthrough
      case 1:
        k1 |= data[roundedEnd] & 0xff;
        h1 ^= mixK1(k1);
        break;
      default:
        break;
    }

    // Finalization: fold in the length, then avalanche.
    h1 ^= len;
    return fmix32(h1);
  }

  /** Scrambles one 32-bit block (or the tail word): multiply, ROTL32 by 15, multiply. */
  private static int mixK1(int k1) {
    k1 *= C1;
    k1 = Integer.rotateLeft(k1, 15);
    return k1 * C2;
  }

  /** MurmurHash3 32-bit finalization mix (fmix32): forces all bits of the state to avalanche. */
  private static int fmix32(int h) {
    h ^= h >>> 16;
    h *= 0x85ebca6b;
    h ^= h >>> 13;
    h *= 0xc2b2ae35;
    h ^= h >>> 16;
    return h;
  }

}
r***@apache.org
2018-09-08 23:35:16 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/NamedVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/NamedVector.java b/core/src/main/java/org/apache/mahout/math/NamedVector.java
new file mode 100644
index 0000000..d4fa609
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/NamedVector.java
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+public class NamedVector implements Vector {
+
+ private Vector delegate;
+ private String name;
+
+ public NamedVector() {
+ }
+
+ public NamedVector(NamedVector other) {
+ this.delegate = other.getDelegate();
+ this.name = other.getName();
+ }
+
+ public NamedVector(Vector delegate, String name) {
+ if (delegate == null || name == null) {
+ throw new IllegalArgumentException();
+ }
+ this.delegate = delegate;
+ this.name = name;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public Vector getDelegate() {
+ return delegate;
+ }
+
+ @Override
+ public int hashCode() {
+ return delegate.hashCode();
+ }
+
+ /**
+ * To not break transitivity with other {@link Vector}s, this does not compare name.
+ */
+ @SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
+ @Override
+ public boolean equals(Object other) {
+ return delegate.equals(other);
+ }
+
+ @SuppressWarnings("CloneDoesntCallSuperClone")
+ @Override
+ public NamedVector clone() {
+ return new NamedVector(delegate.clone(), name);
+ }
+
+ @Override
+ public Iterable<Element> all() {
+ return delegate.all();
+ }
+
+ @Override
+ public Iterable<Element> nonZeroes() {
+ return delegate.nonZeroes();
+ }
+
+ @Override
+ public String asFormatString() {
+ return toString();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder bldr = new StringBuilder();
+ bldr.append(name).append(':').append(delegate.toString());
+ return bldr.toString();
+ }
+
+ @Override
+ public Vector assign(double value) {
+ return delegate.assign(value);
+ }
+
+ @Override
+ public Vector assign(double[] values) {
+ return delegate.assign(values);
+ }
+
+ @Override
+ public Vector assign(Vector other) {
+ return delegate.assign(other);
+ }
+
+ @Override
+ public Vector assign(DoubleFunction function) {
+ return delegate.assign(function);
+ }
+
+ @Override
+ public Vector assign(Vector other, DoubleDoubleFunction function) {
+ return delegate.assign(other, function);
+ }
+
+ @Override
+ public Vector assign(DoubleDoubleFunction f, double y) {
+ return delegate.assign(f, y);
+ }
+
+ @Override
+ public int size() {
+ return delegate.size();
+ }
+
+ @Override
+ public boolean isDense() {
+ return delegate.isDense();
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return delegate.isSequentialAccess();
+ }
+
+ @Override
+ public Element getElement(int index) {
+ return delegate.getElement(index);
+ }
+
+ /**
+ * Merge a set of (index, value) pairs into the vector.
+ *
+ * @param updates an ordered mapping of indices to values to be merged in.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ delegate.mergeUpdates(updates);
+ }
+
+ @Override
+ public Vector divide(double x) {
+ return delegate.divide(x);
+ }
+
+ @Override
+ public double dot(Vector x) {
+ return delegate.dot(x);
+ }
+
+ @Override
+ public double get(int index) {
+ return delegate.get(index);
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return delegate.getQuick(index);
+ }
+
+ @Override
+ public NamedVector like() {
+ return new NamedVector(delegate.like(), name);
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new NamedVector(delegate.like(cardinality), name);
+ }
+
+ @Override
+ public Vector minus(Vector x) {
+ return delegate.minus(x);
+ }
+
+ @Override
+ public Vector normalize() {
+ return delegate.normalize();
+ }
+
+ @Override
+ public Vector normalize(double power) {
+ return delegate.normalize(power);
+ }
+
+ @Override
+ public Vector logNormalize() {
+ return delegate.logNormalize();
+ }
+
+ @Override
+ public Vector logNormalize(double power) {
+ return delegate.logNormalize(power);
+ }
+
+ @Override
+ public double norm(double power) {
+ return delegate.norm(power);
+ }
+
+ @Override
+ public double maxValue() {
+ return delegate.maxValue();
+ }
+
+ @Override
+ public int maxValueIndex() {
+ return delegate.maxValueIndex();
+ }
+
+ @Override
+ public double minValue() {
+ return delegate.minValue();
+ }
+
+ @Override
+ public int minValueIndex() {
+ return delegate.minValueIndex();
+ }
+
+ @Override
+ public Vector plus(double x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public Vector plus(Vector x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public void set(int index, double value) {
+ delegate.set(index, value);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ delegate.setQuick(index, value);
+ }
+
+ @Override
+ public void incrementQuick(int index, double increment) {
+ delegate.incrementQuick(index, increment);
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ return delegate.getNumNonZeroElements();
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return delegate.getNumNondefaultElements();
+ }
+
+ @Override
+ public Vector times(double x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public Vector times(Vector x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ return delegate.viewPart(offset, length);
+ }
+
+ @Override
+ public double zSum() {
+ return delegate.zSum();
+ }
+
+ @Override
+ public Matrix cross(Vector other) {
+ return delegate.cross(other);
+ }
+
+ @Override
+ public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
+ return delegate.aggregate(aggregator, map);
+ }
+
+ @Override
+ public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
+ return delegate.aggregate(other, aggregator, combiner);
+ }
+
+ @Override
+ public double getLengthSquared() {
+ return delegate.getLengthSquared();
+ }
+
+ @Override
+ public double getDistanceSquared(Vector v) {
+ return delegate.getDistanceSquared(v);
+ }
+
+ @Override
+ public double getLookupCost() {
+ return delegate.getLookupCost();
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return delegate.getIteratorAdvanceCost();
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return delegate.isAddConstantTime();
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/OldQRDecomposition.java b/core/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
new file mode 100644
index 0000000..e1552e4
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright 1999 CERN - European Organization for Nuclear Research.
+ * Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+ * is hereby granted without fee, provided that the above copyright notice appear in all copies and
+ * that both that copyright notice and this permission notice appear in supporting documentation.
+ * CERN makes no representations about the suitability of this software for any purpose.
+ * It is provided "as is" without expressed or implied warranty.
+ */
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.Functions;
+
+import java.util.Locale;
+
+
+/**
+ For an <tt>m x n</tt> matrix <tt>A</tt> with <tt>m >= n</tt>, the QR decomposition is an <tt>m x n</tt>
+ orthogonal matrix <tt>Q</tt> and an <tt>n x n</tt> upper triangular matrix <tt>R</tt> so that
+ <tt>A = Q*R</tt>.
+ <P>
+ The QR decompostion always exists, even if the matrix does not have
+ full rank, so the constructor will never fail. The primary use of the
+ QR decomposition is in the least squares solution of nonsquare systems
+ of simultaneous linear equations. This will fail if <tt>isFullRank()</tt>
+ returns <tt>false</tt>.
+ */
+
+/** partially deprecated until unit tests are in place. Until this time, this class/interface is unsupported. */
+public class OldQRDecomposition implements QR {
+
+ /** Array for internal storage of decomposition. */
+ private final Matrix qr;
+
+ /** Row and column dimensions. */
+ private final int originalRows;
+ private final int originalColumns;
+
+ /** Array for internal storage of diagonal of R. */
+ private final Vector rDiag;
+
+ /**
+ * Constructs and returns a new QR decomposition object; computed by Householder reflections; The decomposed matrices
+ * can be retrieved via instance methods of the returned decomposition object.
+ *
+ * @param a A rectangular matrix.
+ * @throws IllegalArgumentException if {@code A.rows() < A.columns()}
+ */
+
+ public OldQRDecomposition(Matrix a) {
+
+ // Initialize.
+ qr = a.clone();
+ originalRows = a.numRows();
+ originalColumns = a.numCols();
+ rDiag = new DenseVector(originalColumns);
+
+ // precompute and cache some views to avoid regenerating them time and again
+ Vector[] QRcolumnsPart = new Vector[originalColumns];
+ for (int k = 0; k < originalColumns; k++) {
+ QRcolumnsPart[k] = qr.viewColumn(k).viewPart(k, originalRows - k);
+ }
+
+ // Main loop.
+ for (int k = 0; k < originalColumns; k++) {
+ //DoubleMatrix1D QRcolk = QR.viewColumn(k).viewPart(k,m-k);
+ // Compute 2-norm of k-th column without under/overflow.
+ double nrm = 0;
+ //if (k<m) nrm = QRcolumnsPart[k].aggregate(hypot,F.identity);
+
+ for (int i = k; i < originalRows; i++) { // fixes bug reported by ***@osu.edu
+ nrm = Algebra.hypot(nrm, qr.getQuick(i, k));
+ }
+
+
+ if (nrm != 0.0) {
+ // Form k-th Householder vector.
+ if (qr.getQuick(k, k) < 0) {
+ nrm = -nrm;
+ }
+ QRcolumnsPart[k].assign(Functions.div(nrm));
+ /*
+ for (int i = k; i < m; i++) {
+ QR[i][k] /= nrm;
+ }
+ */
+
+ qr.setQuick(k, k, qr.getQuick(k, k) + 1);
+
+ // Apply transformation to remaining columns.
+ for (int j = k + 1; j < originalColumns; j++) {
+ Vector QRcolj = qr.viewColumn(j).viewPart(k, originalRows - k);
+ double s = QRcolumnsPart[k].dot(QRcolj);
+ /*
+ // fixes bug reported by John Chambers
+ DoubleMatrix1D QRcolj = QR.viewColumn(j).viewPart(k,m-k);
+ double s = QRcolumnsPart[k].zDotProduct(QRcolumns[j]);
+ double s = 0.0;
+ for (int i = k; i < m; i++) {
+ s += QR[i][k]*QR[i][j];
+ }
+ */
+ s = -s / qr.getQuick(k, k);
+ //QRcolumnsPart[j].assign(QRcolumns[k], F.plusMult(s));
+
+ for (int i = k; i < originalRows; i++) {
+ qr.setQuick(i, j, qr.getQuick(i, j) + s * qr.getQuick(i, k));
+ }
+
+ }
+ }
+ rDiag.setQuick(k, -nrm);
+ }
+ }
+
+ /**
+ * Generates and returns the (economy-sized) orthogonal factor <tt>Q</tt>.
+ *
+ * @return <tt>Q</tt>
+ */
+ @Override
+ public Matrix getQ() {
+ int columns = Math.min(originalColumns, originalRows);
+ Matrix q = qr.like(originalRows, columns);
+ for (int k = columns - 1; k >= 0; k--) {
+ Vector QRcolk = qr.viewColumn(k).viewPart(k, originalRows - k);
+ q.set(k, k, 1);
+ for (int j = k; j < columns; j++) {
+ if (qr.get(k, k) != 0) {
+ Vector Qcolj = q.viewColumn(j).viewPart(k, originalRows - k);
+ double s = -QRcolk.dot(Qcolj) / qr.get(k, k);
+ Qcolj.assign(QRcolk, Functions.plusMult(s));
+ }
+ }
+ }
+ return q;
+ }
+
+ /**
+ * Returns the upper triangular factor, <tt>R</tt>.
+ *
+ * @return <tt>R</tt>
+ */
+ @Override
+ public Matrix getR() {
+ int rows = Math.min(originalRows, originalColumns);
+ Matrix r = qr.like(rows, originalColumns);
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < originalColumns; j++) {
+ if (i < j) {
+ r.setQuick(i, j, qr.getQuick(i, j));
+ } else if (i == j) {
+ r.setQuick(i, j, rDiag.getQuick(i));
+ } else {
+ r.setQuick(i, j, 0);
+ }
+ }
+ }
+ return r;
+ }
+
+ /**
+ * Returns whether the matrix <tt>A</tt> has full rank.
+ *
+ * @return true if <tt>R</tt>, and hence <tt>A</tt>, has full rank.
+ */
+ @Override
+ public boolean hasFullRank() {
+ for (int j = 0; j < originalColumns; j++) {
+ if (rDiag.getQuick(j) == 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Least squares solution of <tt>A*X = B</tt>; <tt>returns X</tt>.
+ *
+ * @param B A matrix with as many rows as <tt>A</tt> and any number of columns.
+ * @return <tt>X</tt> that minimizes the two norm of <tt>Q*R*X - B</tt>.
+ * @throws IllegalArgumentException if <tt>B.rows() != A.rows()</tt>.
+ */
+ @Override
+ public Matrix solve(Matrix B) {
+ if (B.numRows() != originalRows) {
+ throw new IllegalArgumentException("Matrix row dimensions must agree.");
+ }
+
+ int columns = B.numCols();
+ Matrix x = B.like(originalColumns, columns);
+
+ // this can all be done a bit more efficiently if we don't actually
+ // form explicit versions of Q^T and R but this code isn't soo bad
+ // and it is much easier to understand
+ Matrix qt = getQ().transpose();
+ Matrix y = qt.times(B);
+
+ Matrix r = getR();
+ for (int k = Math.min(originalColumns, originalRows) - 1; k >= 0; k--) {
+ // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
+ x.viewRow(k).assign(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
+
+ // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
+ Vector rColumn = r.viewColumn(k).viewPart(0, k);
+ for (int c = 0; c < columns; c++) {
+ y.viewColumn(c).viewPart(0, k).assign(rColumn, Functions.plusMult(-x.get(k, c)));
+ }
+ }
+ return x;
+ }
+
+ /**
+ * Returns a rough string rendition of a QR.
+ */
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "QR(%d,%d,fullRank=%s)", originalColumns, originalRows, hasFullRank());
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/OrderedIntDoubleMapping.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/OrderedIntDoubleMapping.java b/core/src/main/java/org/apache/mahout/math/OrderedIntDoubleMapping.java
new file mode 100644
index 0000000..7c6ad11
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/OrderedIntDoubleMapping.java
@@ -0,0 +1,265 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.Serializable;
+
+public final class OrderedIntDoubleMapping implements Serializable, Cloneable {
+
+ static final double DEFAULT_VALUE = 0.0;
+
+ private int[] indices;
+ private double[] values;
+ private int numMappings;
+
+ // If true, doesn't allow DEFAULT_VALUEs in the mapping (adding a zero discards it). Otherwise, a DEFAULT_VALUE is
+ // treated like any other value.
+ private boolean noDefault = true;
+
+ OrderedIntDoubleMapping(boolean noDefault) {
+ this();
+ this.noDefault = noDefault;
+ }
+
+ OrderedIntDoubleMapping() {
+ // no-arg constructor for deserializer
+ this(11);
+ }
+
+ OrderedIntDoubleMapping(int capacity) {
+ indices = new int[capacity];
+ values = new double[capacity];
+ numMappings = 0;
+ }
+
+ OrderedIntDoubleMapping(int[] indices, double[] values, int numMappings) {
+ this.indices = indices;
+ this.values = values;
+ this.numMappings = numMappings;
+ }
+
+ public int[] getIndices() {
+ return indices;
+ }
+
+ public int indexAt(int offset) {
+ return indices[offset];
+ }
+
+ public void setIndexAt(int offset, int index) {
+ indices[offset] = index;
+ }
+
+ public double[] getValues() {
+ return values;
+ }
+
+ public void setValueAt(int offset, double value) {
+ values[offset] = value;
+ }
+
+
+ public int getNumMappings() {
+ return numMappings;
+ }
+
+ private void growTo(int newCapacity) {
+ if (newCapacity > indices.length) {
+ int[] newIndices = new int[newCapacity];
+ System.arraycopy(indices, 0, newIndices, 0, numMappings);
+ indices = newIndices;
+ double[] newValues = new double[newCapacity];
+ System.arraycopy(values, 0, newValues, 0, numMappings);
+ values = newValues;
+ }
+ }
+
+ private int find(int index) {
+ int low = 0;
+ int high = numMappings - 1;
+ while (low <= high) {
+ int mid = low + (high - low >>> 1);
+ int midVal = indices[mid];
+ if (midVal < index) {
+ low = mid + 1;
+ } else if (midVal > index) {
+ high = mid - 1;
+ } else {
+ return mid;
+ }
+ }
+ return -(low + 1);
+ }
+
+ public double get(int index) {
+ int offset = find(index);
+ return offset >= 0 ? values[offset] : DEFAULT_VALUE;
+ }
+
+ public void set(int index, double value) {
+ if (numMappings == 0 || index > indices[numMappings - 1]) {
+ if (!noDefault || value != DEFAULT_VALUE) {
+ if (numMappings >= indices.length) {
+ growTo(Math.max((int) (1.2 * numMappings), numMappings + 1));
+ }
+ indices[numMappings] = index;
+ values[numMappings] = value;
+ ++numMappings;
+ }
+ } else {
+ int offset = find(index);
+ if (offset >= 0) {
+ insertOrUpdateValueIfPresent(offset, value);
+ } else {
+ insertValueIfNotDefault(index, offset, value);
+ }
+ }
+ }
+
+ /**
+ * Merges the updates in linear time by allocating new arrays and iterating through the existing indices and values
+ * and the updates' indices and values at the same time while selecting the minimum index to set at each step.
+ * @param updates another list of mappings to be merged in.
+ */
+ public void merge(OrderedIntDoubleMapping updates) {
+ int[] updateIndices = updates.getIndices();
+ double[] updateValues = updates.getValues();
+
+ int newNumMappings = numMappings + updates.getNumMappings();
+ int newCapacity = Math.max((int) (1.2 * newNumMappings), newNumMappings + 1);
+ int[] newIndices = new int[newCapacity];
+ double[] newValues = new double[newCapacity];
+
+ int k = 0;
+ int i = 0, j = 0;
+ for (; i < numMappings && j < updates.getNumMappings(); ++k) {
+ if (indices[i] < updateIndices[j]) {
+ newIndices[k] = indices[i];
+ newValues[k] = values[i];
+ ++i;
+ } else if (indices[i] > updateIndices[j]) {
+ newIndices[k] = updateIndices[j];
+ newValues[k] = updateValues[j];
+ ++j;
+ } else {
+ newIndices[k] = updateIndices[j];
+ newValues[k] = updateValues[j];
+ ++i;
+ ++j;
+ }
+ }
+
+ for (; i < numMappings; ++i, ++k) {
+ newIndices[k] = indices[i];
+ newValues[k] = values[i];
+ }
+ for (; j < updates.getNumMappings(); ++j, ++k) {
+ newIndices[k] = updateIndices[j];
+ newValues[k] = updateValues[j];
+ }
+
+ indices = newIndices;
+ values = newValues;
+ numMappings = k;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = 0;
+ for (int i = 0; i < numMappings; i++) {
+ result = 31 * result + indices[i];
+ result = 31 * result + (int) Double.doubleToRawLongBits(values[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof OrderedIntDoubleMapping) {
+ OrderedIntDoubleMapping other = (OrderedIntDoubleMapping) o;
+ if (numMappings == other.numMappings) {
+ for (int i = 0; i < numMappings; i++) {
+ if (indices[i] != other.indices[i] || values[i] != other.values[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(10 * numMappings);
+ for (int i = 0; i < numMappings; i++) {
+ result.append('(');
+ result.append(indices[i]);
+ result.append(',');
+ result.append(values[i]);
+ result.append(')');
+ }
+ return result.toString();
+ }
+
+ @SuppressWarnings("CloneDoesntCallSuperClone")
+ @Override
+ public OrderedIntDoubleMapping clone() {
+ return new OrderedIntDoubleMapping(indices.clone(), values.clone(), numMappings);
+ }
+
+ public void increment(int index, double increment) {
+ int offset = find(index);
+ if (offset >= 0) {
+ double newValue = values[offset] + increment;
+ insertOrUpdateValueIfPresent(offset, newValue);
+ } else {
+ insertValueIfNotDefault(index, offset, increment);
+ }
+ }
+
+ private void insertValueIfNotDefault(int index, int offset, double value) {
+ if (!noDefault || value != DEFAULT_VALUE) {
+ if (numMappings >= indices.length) {
+ growTo(Math.max((int) (1.2 * numMappings), numMappings + 1));
+ }
+ int at = -offset - 1;
+ if (numMappings > at) {
+ for (int i = numMappings - 1, j = numMappings; i >= at; i--, j--) {
+ indices[j] = indices[i];
+ values[j] = values[i];
+ }
+ }
+ indices[at] = index;
+ values[at] = value;
+ numMappings++;
+ }
+ }
+
+ private void insertOrUpdateValueIfPresent(int offset, double newValue) {
+ if (noDefault && newValue == DEFAULT_VALUE) {
+ for (int i = offset + 1, j = offset; i < numMappings; i++, j++) {
+ indices[j] = indices[i];
+ values[j] = values[i];
+ }
+ numMappings--;
+ } else {
+ values[offset] = newValue;
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/OrthonormalityVerifier.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/OrthonormalityVerifier.java b/core/src/main/java/org/apache/mahout/math/OrthonormalityVerifier.java
new file mode 100644
index 0000000..e8dd2b1
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/OrthonormalityVerifier.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+
+public final class OrthonormalityVerifier {
+
+ private OrthonormalityVerifier() {
+ }
+
+ public static VectorIterable pairwiseInnerProducts(Iterable<MatrixSlice> basis) {
+ DenseMatrix out = null;
+ for (MatrixSlice slice1 : basis) {
+ List<Double> dots = Lists.newArrayList();
+ for (MatrixSlice slice2 : basis) {
+ dots.add(slice1.vector().dot(slice2.vector()));
+ }
+ if (out == null) {
+ out = new DenseMatrix(dots.size(), dots.size());
+ }
+ for (int i = 0; i < dots.size(); i++) {
+ out.set(slice1.index(), i, dots.get(i));
+ }
+ }
+ return out;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/PermutedVectorView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/PermutedVectorView.java b/core/src/main/java/org/apache/mahout/math/PermutedVectorView.java
new file mode 100644
index 0000000..e46f326
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/PermutedVectorView.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+import com.google.common.collect.AbstractIterator;
+
+/**
+ * Provides a permuted view of a vector.
+ */
+public class PermutedVectorView extends AbstractVector {
+ private final Vector vector; // the vector containing the data
+ private final int[] pivot; // convert from external index to internal
+ private final int[] unpivot; // convert from internal index to external
+
+ public PermutedVectorView(Vector vector, int[] pivot, int[] unpivot) {
+ super(vector.size());
+ this.vector = vector;
+ this.pivot = pivot;
+ this.unpivot = unpivot;
+ }
+
+ public PermutedVectorView(Vector vector, int[] pivot) {
+ this(vector, pivot, reversePivotPermutation(pivot));
+ }
+
+ private static int[] reversePivotPermutation(int[] pivot) {
+ int[] unpivot1 = new int[pivot.length];
+ for (int i = 0; i < pivot.length; i++) {
+ unpivot1[pivot[i]] = i;
+ }
+ return unpivot1;
+ }
+
+ /**
+ * Subclasses must override to return an appropriately sparse or dense result
+ *
+ * @param rows the row cardinality
+ * @param columns the column cardinality
+ * @return a Matrix
+ */
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ if (vector.isDense()) {
+ return new DenseMatrix(rows, columns);
+ } else {
+ return new SparseRowMatrix(rows, columns);
+ }
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p>
+ * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge in the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ for (int i = 0; i < updates.getNumMappings(); ++i) {
+ updates.setIndexAt(i, pivot[updates.indexAt(i)]);
+ }
+ vector.mergeUpdates(updates);
+ }
+
+ /**
+ * @return true iff this implementation should be considered dense -- that it explicitly
+ * represents every value
+ */
+ @Override
+ public boolean isDense() {
+ return vector.isDense();
+ }
+
+ /**
+ * If the view is permuted, the elements cannot be accessed in the same order.
+ *
+ * @return true iff this implementation should be considered to be iterable in index order in an
+ * efficient way. In particular this implies that {@link #iterator()} and {@link
+ * #iterateNonZero()} return elements in ascending order by index.
+ */
+ @Override
+ public boolean isSequentialAccess() {
+ return false;
+ }
+
+ /**
+ * Iterates over all elements <p> * NOTE: Implementations may choose to reuse the Element
+ * returned for performance reasons, so if you need a copy of it, you should call {@link
+ * #getElement(int)} for the given index
+ *
+ * @return An {@link java.util.Iterator} over all elements
+ */
+ @Override
+ public Iterator<Element> iterator() {
+ return new AbstractIterator<Element>() {
+ private final Iterator<Element> i = vector.all().iterator();
+
+ @Override
+ protected Vector.Element computeNext() {
+ if (i.hasNext()) {
+ final Element x = i.next();
+ return new Element() {
+ private final int index = unpivot[x.index()];
+
+ @Override
+ public double get() {
+ return x.get();
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ x.set(value);
+ }
+ };
+ } else {
+ return endOfData();
+ }
+ }
+ };
+ }
+
+ /**
+ * Iterates over all non-zero elements. <p> NOTE: Implementations may choose to reuse the Element
+ * returned for performance reasons, so if you need a copy of it, you should call {@link
+ * #getElement(int)} for the given index
+ *
+ * @return An {@link java.util.Iterator} over all non-zero elements
+ */
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return new AbstractIterator<Element>() {
+ private final Iterator<Element> i = vector.nonZeroes().iterator();
+
+ @Override
+ protected Vector.Element computeNext() {
+ if (i.hasNext()) {
+ final Element x = i.next();
+ return new Element() {
+ private final int index = unpivot[x.index()];
+
+ @Override
+ public double get() {
+ return x.get();
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ x.set(value);
+ }
+ };
+ } else {
+ return endOfData();
+ }
+ }
+ };
+ }
+
+ /**
+ * Return the value at the given index, without checking bounds
+ *
+ * @param index an int index
+ * @return the double at the index
+ */
+ @Override
+ public double getQuick(int index) {
+ return vector.getQuick(pivot[index]);
+ }
+
+ /**
+ * Return an empty vector of the same underlying class as the receiver
+ *
+ * @return a Vector
+ */
+ @Override
+ public Vector like() {
+ return vector.like();
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return vector.like(cardinality);
+ }
+
+ /**
+ * Set the value at the given index, without checking bounds
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ */
+ @Override
+ public void setQuick(int index, double value) {
+ vector.setQuick(pivot[index], value);
+ }
+
+ /** Return the number of values in the recipient */
+ @Override
+ public int getNumNondefaultElements() {
+ return vector.getNumNondefaultElements();
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ // Return the number of nonzeros in the recipient,
+ // so potentially don't have to go through our iterator
+ return vector.getNumNonZeroElements();
+ }
+
+ @Override
+ public double getLookupCost() {
+ return vector.getLookupCost();
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return vector.getIteratorAdvanceCost();
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return vector.isAddConstantTime();
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/PersistentObject.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/PersistentObject.java b/core/src/main/java/org/apache/mahout/math/PersistentObject.java
new file mode 100644
index 0000000..f1d4293
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/PersistentObject.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math;
+
+/**
+ * This empty class is the common root for all persistent capable classes.
+ * If this class inherits from <tt>java.lang.Object</tt> then all subclasses are serializable with
+ * the standard Java serialization mechanism.
+ * If this class inherits from <tt>com.objy.db.app.ooObj</tt> then all subclasses are
+ * <i>additionally</i> serializable with the Objectivity ODBMS persistance mechanism.
+ * Thus, by modifying the inheritance of this class the entire tree of subclasses can
+ * be switched to Objectivity compatibility (and back) with minimum effort.
+ */
+public abstract class PersistentObject implements java.io.Serializable, Cloneable {
+
+ /** Not yet commented. */
+ protected PersistentObject() {
+ }
+
+ /**
+ * Returns a copy of the receiver. This default implementation does not nothing except making the otherwise
+ * <tt>protected</tt> clone method <tt>public</tt>.
+ *
+ * @return a copy of the receiver.
+ */
+ @Override
+ public Object clone() {
+ try {
+ return super.clone();
+ } catch (CloneNotSupportedException exc) {
+ throw new InternalError(); //should never happen since we are cloneable
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/PivotedMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/PivotedMatrix.java b/core/src/main/java/org/apache/mahout/math/PivotedMatrix.java
new file mode 100644
index 0000000..fba1e98
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/PivotedMatrix.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Matrix that allows transparent row and column permutation.
+ * <p>
+ * This is a view: reads and writes go to the wrapped base matrix. Element {@code (i, j)} of this
+ * matrix is element {@code (rowPivot[i], columnPivot[j])} of the base. The swap methods only
+ * update the pivot arrays; they never move data in the base matrix.
+ */
+public class PivotedMatrix extends AbstractMatrix {
+
+ private Matrix base;
+ private int[] rowPivot;
+ private int[] rowUnpivot;
+ private int[] columnPivot;
+ private int[] columnUnpivot;
+
+ /**
+ * Creates a view of {@code base} that starts with the same permutation on rows and columns.
+ * The row permutation uses {@code pivot} itself (row swaps update it in place) while the column
+ * permutation starts from a copy, so later row and column swaps are independent.
+ *
+ * @param base the matrix to permute
+ * @param pivot entry {@code k} is the base index shown at position {@code k}
+ */
+ public PivotedMatrix(Matrix base, int[] pivot) {
+ this(base, pivot, java.util.Arrays.copyOf(pivot, pivot.length));
+ }
+
+ /**
+ * Creates a view of {@code base} with independent row and column permutations. The arrays are
+ * not copied: they are retained and updated in place by the swap methods. Each array is assumed
+ * to be a permutation of {@code 0..n-1} for its dimension; this is not checked.
+ *
+ * @param base the matrix to permute
+ * @param rowPivot entry {@code i} is the base row shown as row {@code i}
+ * @param columnPivot entry {@code j} is the base column shown as column {@code j}
+ */
+ public PivotedMatrix(Matrix base, int[] rowPivot, int[] columnPivot) {
+ super(base.rowSize(), base.columnSize());
+
+ this.base = base;
+ this.rowPivot = rowPivot;
+ rowUnpivot = invert(rowPivot);
+
+ this.columnPivot = columnPivot;
+ columnUnpivot = invert(columnPivot);
+ }
+
+ /**
+ * Creates a view of {@code base} with identity row and column permutations.
+ *
+ * @param base the matrix to permute
+ */
+ public PivotedMatrix(Matrix base) {
+ this(base, identityPivot(base.rowSize()), identityPivot(base.columnSize()));
+ }
+
+ /**
+ * Swaps indexes i and j. This does both row and column permutation.
+ *
+ * @param i First index to swap.
+ * @param j Second index to swap.
+ */
+ public void swap(int i, int j) {
+ swapRows(i, j);
+ swapColumns(i, j);
+ }
+
+ /**
+ * Swaps indexes i and j. This does just row permutation.
+ *
+ * @param i First index to swap.
+ * @param j Second index to swap.
+ */
+ public void swapRows(int i, int j) {
+ swap(rowPivot, rowUnpivot, i, j);
+ }
+
+ /**
+ * Swaps indexes i and j. This does just column permutation.
+ *
+ * @param i First index to swap.
+ * @param j Second index to swap.
+ */
+ public void swapColumns(int i, int j) {
+ swap(columnPivot, columnUnpivot, i, j);
+ }
+
+ /**
+ * Swaps positions {@code i} and {@code j} of {@code pivot} and updates {@code unpivot} so that
+ * it stays the inverse permutation.
+ *
+ * @throws IndexOutOfBoundsException if {@code i} or {@code j} is not a valid index into {@code pivot}
+ */
+ private static void swap(int[] pivot, int[] unpivot, int i, int j) {
+ // element-index checks: unlike a position index, pivot.length itself is not a valid slot
+ Preconditions.checkElementIndex(i, pivot.length);
+ Preconditions.checkElementIndex(j, pivot.length);
+ if (i != j) {
+ int tmp = pivot[i];
+ pivot[i] = pivot[j];
+ pivot[j] = tmp;
+
+ unpivot[pivot[i]] = i;
+ unpivot[pivot[j]] = j;
+ }
+ }
+
+ /**
+ * Assign the other vector values to the column of the receiver
+ *
+ * @param column the int column to assign
+ * @param other a Vector
+ * @return the modified receiver
+ * @throws org.apache.mahout.math.CardinalityException
+ * if the cardinalities differ
+ */
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ // note the reversed pivoting for other
+ base.assignColumn(columnPivot[column], new PermutedVectorView(other, rowUnpivot, rowPivot));
+ // return the pivoted view itself, not the un-pivoted base
+ return this;
+ }
+
+ /**
+ * Assign the other vector values to the row of the receiver
+ *
+ * @param row the int row to assign
+ * @param other a Vector
+ * @return the modified receiver
+ * @throws org.apache.mahout.math.CardinalityException
+ * if the cardinalities differ
+ */
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ // note the reversed pivoting for other
+ base.assignRow(rowPivot[row], new PermutedVectorView(other, columnUnpivot, columnPivot));
+ // return the pivoted view itself, not the un-pivoted base
+ return this;
+ }
+
+ /**
+ * Return the column at the given index, as a permuted view backed by the base matrix
+ *
+ * @param column an int column index
+ * @return a Vector at the index
+ * @throws org.apache.mahout.math.IndexException
+ * if the index is out of bounds
+ */
+ @Override
+ public Vector viewColumn(int column) {
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ return new PermutedVectorView(base.viewColumn(columnPivot[column]), rowPivot, rowUnpivot);
+ }
+
+ /**
+ * Return the row at the given index, as a permuted view backed by the base matrix
+ *
+ * @param row an int row index
+ * @return a Vector at the index
+ * @throws org.apache.mahout.math.IndexException
+ * if the index is out of bounds
+ */
+ @Override
+ public Vector viewRow(int row) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ return new PermutedVectorView(base.viewRow(rowPivot[row]), columnPivot, columnUnpivot);
+ }
+
+ /**
+ * Return the value at the given indexes, without checking bounds
+ *
+ * @param row an int row index
+ * @param column an int column index
+ * @return the double at the index
+ */
+ @Override
+ public double getQuick(int row, int column) {
+ return base.getQuick(rowPivot[row], columnPivot[column]);
+ }
+
+ /**
+ * Return an empty matrix of the same underlying class as the receiver, wrapped with identity
+ * pivots
+ *
+ * @return a Matrix
+ */
+ @Override
+ public Matrix like() {
+ return new PivotedMatrix(base.like());
+ }
+
+ /**
+ * Returns an independent copy: the base matrix and all pivot arrays are cloned into the copy,
+ * and the receiver keeps its own references unchanged.
+ *
+ * @return a deep copy of this pivoted matrix
+ */
+ @Override
+ public Matrix clone() {
+ PivotedMatrix clone = (PivotedMatrix) super.clone();
+
+ // super.clone() shares every reference; replace the clone's (not the receiver's) with copies
+ clone.base = base.clone();
+ clone.rowPivot = rowPivot.clone();
+ clone.rowUnpivot = rowUnpivot.clone();
+ clone.columnPivot = columnPivot.clone();
+ clone.columnUnpivot = columnUnpivot.clone();
+
+ return clone;
+ }
+
+ /**
+ * Returns an empty matrix of the same underlying class as the receiver and of the specified
+ * size, wrapped with identity pivots.
+ *
+ * @param rows the int number of rows
+ * @param columns the int number of columns
+ */
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new PivotedMatrix(base.like(rows, columns));
+ }
+
+ /**
+ * Set the value at the given index, without checking bounds
+ *
+ * @param row an int row index into the receiver
+ * @param column an int column index into the receiver
+ * @param value a double value to set
+ */
+ @Override
+ public void setQuick(int row, int column, double value) {
+ base.setQuick(rowPivot[row], columnPivot[column], value);
+ }
+
+ /**
+ * Return the number of values in the recipient, as reported by the base matrix
+ *
+ * @return an int[2] containing [row, column] count
+ */
+ @Override
+ public int[] getNumNondefaultElements() {
+ return base.getNumNondefaultElements();
+ }
+
+ /**
+ * Return a new matrix containing the subset of the recipient
+ *
+ * @param offset an int[2] offset into the receiver
+ * @param size the int[2] size of the desired result
+ * @return a new Matrix that is a view of the original
+ * @throws org.apache.mahout.math.CardinalityException
+ * if the length is greater than the cardinality of the receiver
+ * @throws org.apache.mahout.math.IndexException
+ * if the offset is negative or the offset+length is outside of the receiver
+ */
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ return new MatrixView(this, offset, size);
+ }
+
+ /** @return the position in this matrix at which base row {@code k} appears */
+ public int rowUnpivot(int k) {
+ return rowUnpivot[k];
+ }
+
+ /** @return the position in this matrix at which base column {@code k} appears */
+ public int columnUnpivot(int k) {
+ return columnUnpivot[k];
+ }
+
+ /** @return the live row pivot array (not a copy; later swaps modify it) */
+ public int[] getRowPivot() {
+ return rowPivot;
+ }
+
+ /** @return the live inverse row pivot array (not a copy; later swaps modify it) */
+ public int[] getInverseRowPivot() {
+ return rowUnpivot;
+ }
+
+ /** @return the live column pivot array (not a copy; later swaps modify it) */
+ public int[] getColumnPivot() {
+ return columnPivot;
+ }
+
+ /** @return the live inverse column pivot array (not a copy; later swaps modify it) */
+ public int[] getInverseColumnPivot() {
+ return columnUnpivot;
+ }
+
+ /** @return the wrapped, un-permuted matrix */
+ public Matrix getBase() {
+ return base;
+ }
+
+ /** Returns the identity permutation {@code [0, 1, ..., n-1]}. */
+ private static int[] identityPivot(int n) {
+ int[] pivot = new int[n];
+ for (int i = 0; i < n; i++) {
+ pivot[i] = i;
+ }
+ return pivot;
+ }
+
+ /** Returns the inverse permutation: {@code result[pivot[i]] == i}. */
+ private static int[] invert(int[] pivot) {
+ int[] x = new int[pivot.length];
+ for (int i = 0; i < pivot.length; i++) {
+ x[pivot[i]] = i;
+ }
+ return x;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/QR.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/QR.java b/core/src/main/java/org/apache/mahout/math/QR.java
new file mode 100644
index 0000000..5992224
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/QR.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+/**
+ * A QR factorization of a matrix {@code A} into factors {@code Q} and {@code R}.
+ */
+public interface QR {
+
+ /**
+ * @return the factor {@code Q}
+ */
+ Matrix getQ();
+
+ /**
+ * @return the factor {@code R}
+ */
+ Matrix getR();
+
+ /**
+ * @return whether the factored matrix was found to have full rank
+ */
+ boolean hasFullRank();
+
+ /**
+ * Solves {@code A * X = B} using this factorization; see implementations for the exact
+ * semantics (e.g. least squares) and failure behavior.
+ *
+ * @param B the right-hand side
+ * @return the solution {@code X}
+ */
+ Matrix solve(Matrix B);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/QRDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/QRDecomposition.java b/core/src/main/java/org/apache/mahout/math/QRDecomposition.java
new file mode 100644
index 0000000..ab5b3d2
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/QRDecomposition.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Copyright 1999 CERN - European Organization for Nuclear Research.
+ * Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+ * is hereby granted without fee, provided that the above copyright notice appear in all copies and
+ * that both that copyright notice and this permission notice appear in supporting documentation.
+ * CERN makes no representations about the suitability of this software for any purpose.
+ * It is provided "as is" without expressed or implied warranty.
+ */
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.Functions;
+
+import java.util.Locale;
+
+/**
+ * QR decomposition of an <tt>m x n</tt> matrix <tt>A</tt>. With <tt>k = min(m, n)</tt>,
+ * <tt>Q</tt> is an <tt>m x k</tt> matrix whose columns are orthonormal (for full-rank input) and
+ * <tt>R</tt> is a <tt>k x n</tt> upper triangular matrix so that <tt>A = Q*R</tt>.
+ * <P>
+ * The QR decomposition always exists, even if the matrix does not have
+ * full rank, so the constructor does not fail on rank-deficient input. The primary use of the
+ * QR decomposition is in the least squares solution of non-square systems
+ * of simultaneous linear equations. That solution is not meaningful when
+ * <tt>hasFullRank()</tt> returns <tt>false</tt>.
+ */
+
+public class QRDecomposition implements QR {
+ private final Matrix q;
+ private final Matrix r;
+ /** A 1x1 matrix of the input's type, used only as a factory for results of the same type. */
+ private final Matrix mType;
+ private final boolean fullRank;
+ private final int rows;
+ private final int columns;
+
+ /**
+ * Computes the QR decomposition of {@code a} by orthonormalizing its columns one at a time
+ * (Gram-Schmidt). The input matrix is cloned and left unmodified.
+ *
+ * @param a A rectangular matrix.
+ * @throws ArithmeticException if an intermediate column norm is infinite or NaN.
+ */
+ public QRDecomposition(Matrix a) {
+
+ rows = a.rowSize();
+ int min = Math.min(a.rowSize(), a.columnSize());
+ columns = a.columnSize();
+ mType = a.like(1,1);
+
+ Matrix qTmp = a.clone();
+
+ boolean isFullRank = true;
+
+ r = new DenseMatrix(min, columns);
+
+ for (int i = 0; i < min; i++) {
+ Vector qi = qTmp.viewColumn(i);
+ double alpha = qi.norm(2);
+ // check before dividing: an infinite norm would otherwise pass the size test below
+ checkFinite(alpha);
+ if (Math.abs(alpha) > Double.MIN_VALUE) {
+ qi.assign(Functions.div(alpha));
+ } else {
+ // a zero column after orthogonalization: R(i,i) is zero, so A is rank deficient
+ isFullRank = false;
+ }
+ r.set(i, i, alpha);
+
+ for (int j = i + 1; j < columns; j++) {
+ Vector qj = qTmp.viewColumn(j);
+ double norm = qj.norm(2);
+ checkFinite(norm);
+ if (Math.abs(norm) > Double.MIN_VALUE) {
+ double beta = qi.dot(qj);
+ r.set(i, j, beta);
+ // only the first min columns become part of Q, so only those need orthogonalizing
+ if (j < min) {
+ qj.assign(qi, Functions.plusMult(-beta));
+ }
+ }
+ }
+ }
+ if (columns > min) {
+ q = qTmp.viewPart(0, rows, 0, min).clone();
+ } else {
+ q = qTmp;
+ }
+ this.fullRank = isFullRank;
+ }
+
+ /** Throws if {@code value} is infinite or NaN; such a norm means the input overflowed or held NaN. */
+ private static void checkFinite(double value) {
+ if (Double.isInfinite(value) || Double.isNaN(value)) {
+ throw new ArithmeticException("Invalid intermediate result");
+ }
+ }
+
+ /**
+ * Returns the (economy-sized) orthogonal factor <tt>Q</tt>, computed by the constructor.
+ *
+ * @return <tt>Q</tt>
+ */
+ @Override
+ public Matrix getQ() {
+ return q;
+ }
+
+ /**
+ * Returns the upper triangular factor, <tt>R</tt>.
+ *
+ * @return <tt>R</tt>
+ */
+ @Override
+ public Matrix getR() {
+ return r;
+ }
+
+ /**
+ * Returns whether the matrix <tt>A</tt> has full rank, i.e. no diagonal entry of <tt>R</tt>
+ * had magnitude at most {@link Double#MIN_VALUE}.
+ *
+ * @return true if <tt>R</tt>, and hence <tt>A</tt>, has full rank.
+ */
+ @Override
+ public boolean hasFullRank() {
+ return fullRank;
+ }
+
+ /**
+ * Least squares solution of <tt>A*X = B</tt>; <tt>returns X</tt>. No rank check is made: if
+ * <tt>R</tt> has a zero diagonal entry the result will likely contain infinite or NaN values.
+ * When <tt>A</tt> has fewer rows than columns, rows of <tt>X</tt> beyond <tt>A.rows()</tt> are
+ * left at zero.
+ *
+ * @param B A matrix with as many rows as <tt>A</tt> and any number of columns.
+ * @return <tt>X</tt> that minimizes the two norm of <tt>Q*R*X - B</tt>.
+ * @throws IllegalArgumentException if <tt>B.rows() != A.rows()</tt>.
+ */
+ @Override
+ public Matrix solve(Matrix B) {
+ if (B.numRows() != rows) {
+ throw new IllegalArgumentException("Matrix row dimensions must agree.");
+ }
+
+ int cols = B.numCols();
+ Matrix x = mType.like(columns, cols);
+
+ // this can all be done a bit more efficiently if we don't actually
+ // form explicit versions of Q^T and R but this code isn't so bad
+ // and it is much easier to understand
+ Matrix qt = getQ().transpose();
+ Matrix y = qt.times(B);
+
+ // back substitution against the upper triangular R, last row first
+ for (int k = Math.min(columns, rows) - 1; k >= 0; k--) {
+ // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
+ x.viewRow(k).assign(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
+
+ // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
+ Vector rColumn = r.viewColumn(k).viewPart(0, k);
+ for (int c = 0; c < cols; c++) {
+ y.viewColumn(c).viewPart(0, k).assign(rColumn, Functions.plusMult(-x.get(k, c)));
+ }
+ }
+ return x;
+ }
+
+ /**
+ * Returns a rough string rendition of a QR.
+ */
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "QR(%d x %d,fullRank=%s)", rows, columns, hasFullRank());
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/core/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
new file mode 100644
index 0000000..c325078
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
@@ -0,0 +1,303 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import it.unimi.dsi.fastutil.doubles.DoubleIterator;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap.Entry;
+import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import org.apache.mahout.math.set.AbstractSet;
+
+/** Implements vector that only stores non-zero doubles */
+public class RandomAccessSparseVector extends AbstractVector {
+
+ private static final int INITIAL_CAPACITY = 11;
+
+ private Int2DoubleOpenHashMap values;
+
+ /** For serialization purposes only. */
+ public RandomAccessSparseVector() {
+ super(0);
+ }
+
+ public RandomAccessSparseVector(int cardinality) {
+ this(cardinality, Math.min(cardinality, INITIAL_CAPACITY)); // arbitrary estimate of 'sparseness'
+ }
+
+ public RandomAccessSparseVector(int cardinality, int initialCapacity) {
+ super(cardinality);
+ values = new Int2DoubleOpenHashMap(initialCapacity, .5f);
+ }
+
+ public RandomAccessSparseVector(Vector other) {
+ this(other.size(), other.getNumNondefaultElements());
+ for (Element e : other.nonZeroes()) {
+ values.put(e.index(), e.get());
+ }
+ }
+
+ private RandomAccessSparseVector(int cardinality, Int2DoubleOpenHashMap values) {
+ super(cardinality);
+ this.values = values;
+ }
+
+ public RandomAccessSparseVector(RandomAccessSparseVector other, boolean shallowCopy) {
+ super(other.size());
+ values = shallowCopy ? other.values : other.values.clone();
+ }
+
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return new SparseMatrix(rows, columns);
+ }
+
+ @Override
+ public RandomAccessSparseVector clone() {
+ return new RandomAccessSparseVector(size(), values.clone());
+ }
+
+ @Override
+ public String toString() {
+ return sparseVectorToString();
+ }
+
+ @Override
+ public Vector assign(Vector other) {
+ if (size() != other.size()) {
+ throw new CardinalityException(size(), other.size());
+ }
+ values.clear();
+ for (Element e : other.nonZeroes()) {
+ setQuick(e.index(), e.get());
+ }
+ return this;
+ }
+
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ for (int i = 0; i < updates.getNumMappings(); ++i) {
+ values.put(updates.getIndices()[i], updates.getValues()[i]);
+ }
+ }
+
+ /**
+ * @return false
+ */
+ @Override
+ public boolean isDense() {
+ return false;
+ }
+
+ /**
+ * @return false
+ */
+ @Override
+ public boolean isSequentialAccess() {
+ return false;
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return values.get(index);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ invalidateCachedLength();
+ if (value == 0.0) {
+ values.remove(index);
+ } else {
+ values.put(index, value);
+ }
+ }
+
+ @Override
+ public void incrementQuick(int index, double increment) {
+ invalidateCachedLength();
+ values.addTo( index, increment);
+ }
+
+
+ @Override
+ public RandomAccessSparseVector like() {
+ return new RandomAccessSparseVector(size(), values.size());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new RandomAccessSparseVector(cardinality, values.size());
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return values.size();
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ final DoubleIterator iterator = values.values().iterator();
+ int numNonZeros = 0;
+ for( int i = values.size(); i-- != 0; ) if ( iterator.nextDouble() != 0 ) numNonZeros++;
+ return numNonZeros;
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 1;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1 + (AbstractSet.DEFAULT_MAX_LOAD_FACTOR + AbstractSet.DEFAULT_MIN_LOAD_FACTOR) / 2;
+ }
+
+ /**
+ * This is "sort of" constant, but really it might resize the array.
+ */
+ @Override
+ public boolean isAddConstantTime() {
+ return true;
+ }
+
+ /*
+ @Override
+ public Element getElement(int index) {
+ // TODO: this should return a MapElement so as to avoid hashing for both getQuick and setQuick.
+ return super.getElement(index);
+ }
+ */
+
+ private final class NonZeroIterator implements Iterator<Element> {
+ final ObjectIterator<Int2DoubleMap.Entry> fastIterator = values.int2DoubleEntrySet().fastIterator();
+ final RandomAccessElement element = new RandomAccessElement( fastIterator );
+
+ @Override
+ public boolean hasNext() {
+ return fastIterator.hasNext();
+ }
+
+ @Override
+ public Element next() {
+ if ( ! hasNext() ) throw new NoSuchElementException();
+ element.entry = fastIterator.next();
+ return element;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ final class RandomAccessElement implements Element {
+ Int2DoubleMap.Entry entry;
+ final ObjectIterator<Int2DoubleMap.Entry> fastIterator;
+
+ public RandomAccessElement( ObjectIterator<Entry> fastIterator ) {
+ super();
+ this.fastIterator = fastIterator;
+ }
+
+ @Override
+ public double get() {
+ return entry.getDoubleValue();
+ }
+
+ @Override
+ public int index() {
+ return entry.getIntKey();
+ }
+
+ @Override
+ public void set( double value ) {
+ invalidateCachedLength();
+ if (value == 0.0) fastIterator.remove();
+ else entry.setValue( value );
+ }
+ }
+ /**
+ * NOTE: this implementation reuses the Vector.Element instance for each call of next(). If you need to preserve the
+ * instance, you need to make a copy of it
+ *
+ * @return an {@link Iterator} over the Elements.
+ * @see #getElement(int)
+ */
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return new NonZeroIterator();
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ return new AllIterator();
+ }
+
+ final class GeneralElement implements Element {
+ int index;
+ double value;
+
+ @Override
+ public double get() {
+ return value;
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set( double value ) {
+ invalidateCachedLength();
+ if (value == 0.0) values.remove( index );
+ else values.put( index, value );
+ }
+}
+
+ private final class AllIterator implements Iterator<Element> {
+ private final GeneralElement element = new GeneralElement();
+
+ private AllIterator() {
+ element.index = -1;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return element.index + 1 < size();
+ }
+
+ @Override
+ public Element next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ element.value = values.get( ++element.index );
+ return element;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/RandomTrinaryMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/RandomTrinaryMatrix.java b/core/src/main/java/org/apache/mahout/math/RandomTrinaryMatrix.java
new file mode 100644
index 0000000..85de0cd
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/RandomTrinaryMatrix.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Read-only pseudo-random matrix whose entries are computed on demand from a hash of the
+ * coordinates, so no values are stored. Such a hash is not usable where real randomness is
+ * required, but suffices nicely for random projection methods.
+ * <p>
+ * NOTE(review): despite the class name, the values produced are not drawn from {-1, 0, 1}:
+ * <ul>
+ * <li>by default ({@code highQuality == false}) an entry is -1 or 1, selected by bit 3 of a simple
+ * polynomial hash of the coordinates; the seed is not used, so every such instance has the
+ * same value at a given (row, column);</li>
+ * <li>in the optional high quality mode an entry is a murmur hash of the coordinates and the seed,
+ * scaled into [0, 1).</li>
+ * </ul>
+ * TODO confirm which distribution callers rely on before changing either formula.
+ */
+public class RandomTrinaryMatrix extends AbstractMatrix {
+ private static final AtomicInteger ID = new AtomicInteger();
+ private static final int PRIME1 = 104047;
+ private static final int PRIME2 = 101377;
+ private static final int PRIME3 = 64661;
+ private static final long SCALE = 1L << 32;
+
+ private final int seed;
+
+ // true selects the murmur-hash based values
+ private final boolean highQuality;
+
+ /**
+ * @param seed hash seed (only used in high quality mode)
+ * @param rows number of rows
+ * @param columns number of columns
+ * @param highQuality whether to use the murmur hash
+ */
+ public RandomTrinaryMatrix(int seed, int rows, int columns, boolean highQuality) {
+ super(rows, columns);
+
+ this.highQuality = highQuality;
+ this.seed = seed;
+ }
+
+ /**
+ * Creates a matrix in the default (simple hash) mode with a fresh seed from a global counter.
+ *
+ * @param rows number of rows
+ * @param columns number of columns
+ */
+ public RandomTrinaryMatrix(int rows, int columns) {
+ this(ID.incrementAndGet(), rows, columns, false);
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ throw new UnsupportedOperationException("Can't assign to read-only matrix");
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ throw new UnsupportedOperationException("Can't assign to read-only matrix");
+ }
+
+ /**
+ * Return the value at the given indexes, without checking bounds
+ *
+ * @param row an int row index
+ * @param column an int column index
+ * @return the double at the index: -1 or 1 by default, a value in [0, 1) in high quality mode
+ */
+ @Override
+ public double getQuick(int row, int column) {
+ if (highQuality) {
+ ByteBuffer buf = ByteBuffer.allocate(8);
+ buf.putInt(row);
+ buf.putInt(column);
+ buf.flip();
+ return (MurmurHash.hash64A(buf, seed) & (SCALE - 1)) / (double) SCALE;
+ } else {
+ // this isn't a fantastic random number generator, but it is just fine for random projections
+ // (hash & 8) is 0 or 8, so the result is -1 or 1
+ return ((((row * PRIME1) + column * PRIME2 + row * column * PRIME3) & 8) * 0.25) - 1;
+ }
+ }
+
+ /**
+ * Return an empty matrix (a writable DenseMatrix of the same shape)
+ *
+ * @return a Matrix
+ */
+ @Override
+ public Matrix like() {
+ return new DenseMatrix(rowSize(), columnSize());
+ }
+
+ /**
+ * Returns an empty writable DenseMatrix of the specified size.
+ *
+ * @param rows the int number of rows
+ * @param columns the int number of columns
+ */
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ /**
+ * Always throws: this matrix is read-only.
+ *
+ * @param row an int row index into the receiver
+ * @param column an int column index into the receiver
+ * @param value a double value to set
+ * @throws UnsupportedOperationException always
+ */
+ @Override
+ public void setQuick(int row, int column, double value) {
+ throw new UnsupportedOperationException("Can't assign to read-only matrix");
+ }
+
+ /**
+ * Not supported: entries are computed on demand, so there is no stored element count.
+ *
+ * @throws UnsupportedOperationException always
+ */
+ @Override
+ public int[] getNumNondefaultElements() {
+ throw new UnsupportedOperationException("Can't count non-default elements of a computed random matrix");
+ }
+
+ /**
+ * Return a new matrix containing the subset of the recipient
+ *
+ * @param offset an int[2] offset into the receiver
+ * @param size the int[2] size of the desired result
+ * @return a new Matrix that is a view of the original
+ * @throws org.apache.mahout.math.CardinalityException
+ * if the length is greater than the cardinality of the receiver
+ * @throws org.apache.mahout.math.IndexException
+ * if the offset is negative or the offset+length is outside of the receiver
+ */
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ return new MatrixView(this, offset, size);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java b/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
new file mode 100644
index 0000000..f7d67a7
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
@@ -0,0 +1,379 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import com.google.common.primitives.Doubles;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * <p>
+ * Implements vector that only stores non-zero doubles as a pair of parallel arrays (OrderedIntDoubleMapping),
+ * one int[], one double[]. If there are <b>k</b> non-zero elements in the vector, this implementation has
+ * O(log(k)) random-access read performance, and O(k) random-access write performance, which is far below that
+ * of the hashmap based {@link org.apache.mahout.math.RandomAccessSparseVector RandomAccessSparseVector}. This
+ * class is primarily used for operations where all the elements will be accessed in a read-only fashion
+ * sequentially: methods which operate not via get() or set(), but via iterateNonZero(), such as (but not limited
+ * to):</p>
+ * <ul>
+ * <li>dot(Vector)</li>
+ * <li>addTo(Vector)</li>
+ * </ul>
+ *
+ * See {@link OrderedIntDoubleMapping}
+ */
+public class SequentialAccessSparseVector extends AbstractVector {
+
+  /** The explicitly stored (index, value) pairs, kept in ascending index order. */
+  private OrderedIntDoubleMapping values;
+
+  /** For serialization purposes only; the backing mapping is left unset. */
+  public SequentialAccessSparseVector() {
+    super(0);
+  }
+
+  /**
+   * Creates an empty vector whose backing arrays are pre-sized from an arbitrary sparseness
+   * estimate: one stored entry per thousand positions, clamped to the range [10, 100].
+   *
+   * @param cardinality the size of the vector
+   */
+  public SequentialAccessSparseVector(int cardinality) {
+    this(cardinality, Math.min(100, Math.max(10, cardinality / 1000)));
+  }
+
+  /**
+   * Creates an empty vector.
+   *
+   * @param cardinality the size of the vector
+   * @param size the initial capacity of the backing arrays
+   */
+  public SequentialAccessSparseVector(int cardinality, int size) {
+    super(cardinality);
+    values = new OrderedIntDoubleMapping(size);
+  }
+
+  /**
+   * Creates a copy of an arbitrary vector.
+   *
+   * @param other the vector whose non-zero entries are copied
+   */
+  public SequentialAccessSparseVector(Vector other) {
+    this(other.size(), other.getNumNondefaultElements());
+
+    if (other.isSequentialAccess()) {
+      for (Element e : other.nonZeroes()) {
+        set(e.index(), e.get());
+      }
+    } else {
+      // If the incoming Vector to copy is random, then adding items
+      // from the Iterator can degrade performance dramatically if
+      // the number of elements is large as this Vector tries to stay
+      // in order as items are added, so it's better to sort the other
+      // Vector's elements by index and then add them to this
+      copySortedRandomAccessSparseVector(other);
+    }
+  }
+
+  /**
+   * Replaces the backing mapping with the entries yielded by {@code other.nonZeroes()}, sorted by index.
+   *
+   * @param other a vector whose non-zero iteration order is not necessarily by index
+   * @return the number of entries copied
+   */
+  private int copySortedRandomAccessSparseVector(Vector other) {
+    OrderedElement[] sortableElements = new OrderedElement[other.getNumNondefaultElements()];
+    int count = 0;
+    for (Element e : other.nonZeroes()) {
+      sortableElements[count++] = new OrderedElement(e.index(), e.get());
+    }
+    // Guard against nonZeroes() yielding fewer elements than getNumNondefaultElements() reported:
+    // only the filled prefix is sorted and copied, so unused null slots never reach compareTo.
+    Arrays.sort(sortableElements, 0, count);
+    int[] sortedIndices = new int[count];
+    double[] sortedValues = new double[count];
+    for (int i = 0; i < count; i++) {
+      sortedIndices[i] = sortableElements[i].index;
+      sortedValues[i] = sortableElements[i].value;
+    }
+    values = new OrderedIntDoubleMapping(sortedIndices, sortedValues, count);
+    return count;
+  }
+
+  /**
+   * Copies another sequential-access sparse vector, optionally sharing its storage.
+   *
+   * @param other the vector to copy
+   * @param shallowCopy if true, both vectors reference the same backing mapping; otherwise the
+   *                    mapping is cloned
+   */
+  public SequentialAccessSparseVector(SequentialAccessSparseVector other, boolean shallowCopy) {
+    super(other.size());
+    values = shallowCopy ? other.values : other.values.clone();
+  }
+
+  /**
+   * Deep-copies another sequential-access sparse vector.
+   *
+   * @param other the vector to copy
+   */
+  public SequentialAccessSparseVector(SequentialAccessSparseVector other) {
+    // Clone straight into the new instance instead of first allocating a mapping that would be
+    // immediately discarded.
+    this(other.size(), other.values.clone());
+  }
+
+  /** Wraps an existing mapping without copying it. */
+  private SequentialAccessSparseVector(int cardinality, OrderedIntDoubleMapping values) {
+    super(cardinality);
+    this.values = values;
+  }
+
+  /** Returns an empty {@link SparseMatrix} of the given shape. */
+  @Override
+  protected Matrix matrixLike(int rows, int columns) {
+    return new SparseMatrix(rows, columns);
+  }
+
+  /** Returns a deep copy: the backing mapping is cloned. */
+  @SuppressWarnings("CloneDoesntCallSuperClone")
+  @Override
+  public SequentialAccessSparseVector clone() {
+    return new SequentialAccessSparseVector(size(), values.clone());
+  }
+
+  /**
+   * Merges a batch of (index, value) updates into the stored entries via
+   * {@link OrderedIntDoubleMapping#merge}.
+   *
+   * @param updates the updates to merge
+   */
+  @Override
+  public void mergeUpdates(OrderedIntDoubleMapping updates) {
+    // Stored values change here, so invalidate any cached length just as setQuick/incrementQuick do.
+    invalidateCachedLength();
+    values.merge(updates);
+  }
+
+  @Override
+  public String toString() {
+    return sparseVectorToString();
+  }
+
+  /**
+   * @return false
+   */
+  @Override
+  public boolean isDense() {
+    return false;
+  }
+
+  /**
+   * @return true
+   */
+  @Override
+  public boolean isSequentialAccess() {
+    return true;
+  }
+
+  /**
+   * Warning! This takes O(log n) time as it does a binary search behind the scenes!
+   * Only use it when STRICTLY necessary.
+   * @param index an int index.
+   * @return the value at that position in the vector.
+   */
+  @Override
+  public double getQuick(int index) {
+    return values.get(index);
+  }
+
+  /**
+   * Warning! This takes O(log n) time as it does a binary search behind the scenes!
+   * Only use it when STRICTLY necessary.
+   * @param index an int index.
+   * @param value the value to store at that index.
+   */
+  @Override
+  public void setQuick(int index, double value) {
+    invalidateCachedLength();
+    values.set(index, value);
+  }
+
+  /** Adds {@code increment} to the value stored at {@code index}. */
+  @Override
+  public void incrementQuick(int index, double increment) {
+    invalidateCachedLength();
+    values.increment(index, increment);
+  }
+
+  /** Returns an empty vector of the same size, pre-sized for as many entries as this one stores. */
+  @Override
+  public SequentialAccessSparseVector like() {
+    return new SequentialAccessSparseVector(size(), values.getNumMappings());
+  }
+
+  /** Returns an empty vector of the given cardinality. */
+  @Override
+  public Vector like(int cardinality) {
+    return new SequentialAccessSparseVector(cardinality);
+  }
+
+  /** Returns the number of explicitly stored entries, which may include stored zeros. */
+  @Override
+  public int getNumNondefaultElements() {
+    return values.getNumMappings();
+  }
+
+  /** Returns the number of stored entries whose value is not zero. */
+  @Override
+  public int getNumNonZeroElements() {
+    double[] elementValues = values.getValues();
+    int numMappedElements = values.getNumMappings();
+    int numNonZeros = 0;
+    for (int index = 0; index < numMappedElements; index++) {
+      if (elementValues[index] != 0) {
+        numNonZeros++;
+      }
+    }
+    return numNonZeros;
+  }
+
+  /** Lookup cost is log2 of the number of stored entries (binary search), rounded, at least 1. */
+  @Override
+  public double getLookupCost() {
+    return Math.max(1, Math.round(Functions.LOG2.apply(getNumNondefaultElements())));
+  }
+
+  @Override
+  public double getIteratorAdvanceCost() {
+    return 1;
+  }
+
+  @Override
+  public boolean isAddConstantTime() {
+    return false;
+  }
+
+  /**
+   * Iterates over the stored entries in index order. The same mutable {@link Element} instance is
+   * returned by every call to {@code next()}.
+   */
+  @Override
+  public Iterator<Element> iterateNonZero() {
+
+    // TODO: this is a bug, since nonDefaultIterator doesn't hold to non-zero contract.
+    return new NonDefaultIterator();
+  }
+
+  /**
+   * Iterates over every index from 0 to {@code size() - 1}, yielding the default value for
+   * positions with no stored entry. The same mutable {@link Element} instance is reused.
+   */
+  @Override
+  public Iterator<Element> iterator() {
+    return new AllIterator();
+  }
+
+  /** Walks the stored entries by offset, including any explicitly stored zeros. */
+  private final class NonDefaultIterator implements Iterator<Element> {
+    private final NonDefaultElement element = new NonDefaultElement();
+
+    @Override
+    public boolean hasNext() {
+      return element.getNextOffset() < values.getNumMappings();
+    }
+
+    @Override
+    public Element next() {
+      if (!hasNext()) {
+        throw new NoSuchElementException();
+      }
+      element.advanceOffset();
+      return element;
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  /** Walks every index of the vector, stored or not. */
+  private final class AllIterator implements Iterator<Element> {
+    private final AllElement element = new AllElement();
+
+    @Override
+    public boolean hasNext() {
+      return element.getNextIndex() < SequentialAccessSparseVector.this.size();
+    }
+
+    @Override
+    public Element next() {
+      if (!hasNext()) {
+        throw new NoSuchElementException();
+      }
+
+      element.advanceIndex();
+      return element;
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  /** Cursor over the stored entries; {@code offset} is a position in the backing arrays. */
+  private final class NonDefaultElement implements Element {
+    private int offset = -1;
+
+    void advanceOffset() {
+      offset++;
+    }
+
+    int getNextOffset() {
+      return offset + 1;
+    }
+
+    @Override
+    public double get() {
+      return values.getValues()[offset];
+    }
+
+    @Override
+    public int index() {
+      return values.getIndices()[offset];
+    }
+
+    @Override
+    public void set(double value) {
+      invalidateCachedLength();
+      values.setValueAt(offset, value);
+    }
+  }
+
+  /**
+   * Cursor over every index. {@code nextOffset} is kept at the first stored entry whose index is
+   * not below the current index, so lookups need no binary search.
+   */
+  private final class AllElement implements Element {
+    private int index = -1;
+    private int nextOffset;
+
+    void advanceIndex() {
+      index++;
+      if (nextOffset < values.getNumMappings() && index > values.getIndices()[nextOffset]) {
+        nextOffset++;
+      }
+    }
+
+    int getNextIndex() {
+      return index + 1;
+    }
+
+    @Override
+    public double get() {
+      if (nextOffset < values.getNumMappings() && index == values.getIndices()[nextOffset]) {
+        return values.getValues()[nextOffset];
+      } else {
+        return OrderedIntDoubleMapping.DEFAULT_VALUE;
+      }
+    }
+
+    @Override
+    public int index() {
+      return index;
+    }
+
+    @Override
+    public void set(double value) {
+      invalidateCachedLength();
+      if (nextOffset < values.getNumMappings() && index == values.indexAt(nextOffset)) {
+        values.setValueAt(nextOffset, value);
+      } else {
+        // Yes, this works; the offset into indices of the new value's index will still be nextOffset
+        values.set(index, value);
+      }
+    }
+  }
+
+  /**
+   * Immutable (index, value) pair used to sort entries by index before bulk insertion. Note that
+   * {@link #compareTo} looks at the index only, while {@link #equals} also compares the value.
+   */
+  private static final class OrderedElement implements Comparable<OrderedElement> {
+    private final int index;
+    private final double value;
+
+    OrderedElement(int index, double value) {
+      this.index = index;
+      this.value = value;
+    }
+
+    @Override
+    public int compareTo(OrderedElement that) {
+      // Explicit comparison instead of subtraction, so correctness never depends on index signs.
+      return this.index < that.index ? -1 : (this.index == that.index ? 0 : 1);
+    }
+
+    @Override
+    public int hashCode() {
+      return index ^ Doubles.hashCode(value);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof OrderedElement)) {
+        return false;
+      }
+      OrderedElement other = (OrderedElement) o;
+      // Double.compare agrees with Doubles.hashCode (NaN equals itself, 0.0 != -0.0), which '=='
+      // does not, so equal elements always have equal hash codes.
+      return index == other.index && Double.compare(value, other.value) == 0;
+    }
+  }
+}
r***@apache.org
2018-09-08 23:35:18 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Arrays.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Arrays.java b/core/src/main/java/org/apache/mahout/math/Arrays.java
new file mode 100644
index 0000000..802ffb7
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Arrays.java
@@ -0,0 +1,662 @@
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math;
+
+/**
+ * Array manipulations; complements <tt>java.util.Arrays</tt>.
+ *
+ * @see java.util.Arrays
+ * @see org.apache.mahout.math.Sorting
+ *
+ */
+public final class Arrays {
+
+  private Arrays() {
+  }
+
+  /**
+   * Computes the capacity to grow to: 1.5 times the old capacity plus one, or {@code minCapacity}
+   * if that is larger. Only called when {@code minCapacity > oldCapacity}; if the multiplication
+   * overflows, the wrapped result is always below {@code minCapacity}, so {@code minCapacity} wins.
+   */
+  private static int grownCapacity(int oldCapacity, int minCapacity) {
+    int newCapacity = (oldCapacity * 3) / 2 + 1;
+    return newCapacity < minCapacity ? minCapacity : newCapacity;
+  }
+
+  /**
+   * Allocates an array of the given length with the same runtime component type as {@code array}
+   * (e.g. a {@code String[]} for a {@code String[]} argument).
+   */
+  private static Object[] newArrayLike(Object[] array, int length) {
+    return (Object[]) java.lang.reflect.Array.newInstance(array.getClass().getComponentType(), length);
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static byte[] ensureCapacity(byte[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    byte[] newArray = new byte[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static char[] ensureCapacity(char[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    char[] newArray = new char[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static double[] ensureCapacity(double[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    double[] newArray = new double[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static float[] ensureCapacity(float[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    float[] newArray = new float[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static int[] ensureCapacity(int[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    int[] newArray = new int[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static long[] ensureCapacity(long[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    long[] newArray = new long[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>. The copy has the
+   * same runtime component type as {@code array}, so e.g. a grown {@code String[]} is still a {@code String[]}.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static Object[] ensureCapacity(Object[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    Object[] newArray = newArrayLike(array, grownCapacity(oldCapacity, minCapacity));
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static short[] ensureCapacity(short[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    short[] newArray = new short[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Ensures that a given array can hold up to <tt>minCapacity</tt> elements. Returns the identical array if it
+   * is already large enough; otherwise returns a copy grown to at least <tt>minCapacity</tt>.
+   *
+   * @param array the array to check
+   * @param minCapacity the desired minimum capacity.
+   * @return {@code array} itself, or a larger array starting with the same elements
+   */
+  public static boolean[] ensureCapacity(boolean[] array, int minCapacity) {
+    int oldCapacity = array.length;
+    if (minCapacity <= oldCapacity) {
+      return array;
+    }
+    boolean[] newArray = new boolean[grownCapacity(oldCapacity, minCapacity)];
+    System.arraycopy(array, 0, newArray, 0, oldCapacity);
+    return newArray;
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(byte[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(char[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(double[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(float[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(int[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(long[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements (rendered with
+   * {@code String.valueOf}) enclosed in square brackets (<tt>"[]"</tt>) and separated by <tt>", "</tt>
+   * (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(Object[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(short[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Returns a string representation of the specified array: the elements enclosed in square brackets
+   * (<tt>"[]"</tt>) and separated by <tt>", "</tt> (comma and space). A {@code null} array yields <tt>"null"</tt>.
+   *
+   * @param array the array to render
+   * @return a string representation of the specified array.
+   */
+  public static String toString(boolean[] array) {
+    return java.util.Arrays.toString(array);
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static byte[] trimToCapacity(byte[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    byte[] trimmed = new byte[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static char[] trimToCapacity(char[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    char[] trimmed = new char[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static double[] trimToCapacity(double[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    double[] trimmed = new double[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static float[] trimToCapacity(float[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    float[] trimmed = new float[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static int[] trimToCapacity(int[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    int[] trimmed = new int[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static long[] trimToCapacity(long[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    long[] trimmed = new long[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   * The copy has the same runtime component type as {@code array}.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static Object[] trimToCapacity(Object[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    Object[] trimmed = newArrayLike(array, maxCapacity);
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static short[] trimToCapacity(short[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    short[] trimmed = new short[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * Ensures that the specified array cannot hold more than <tt>maxCapacity</tt> elements, e.g. to minimize
+   * array storage. Returns the identical array if <tt>array.length &lt;= maxCapacity</tt>; otherwise returns a
+   * new array of length <tt>maxCapacity</tt> holding the first <tt>maxCapacity</tt> elements of <tt>array</tt>.
+   *
+   * @param array the array to trim
+   * @param maxCapacity the desired maximum capacity.
+   * @return {@code array} itself, or a truncated copy of it
+   */
+  public static boolean[] trimToCapacity(boolean[] array, int maxCapacity) {
+    if (array.length <= maxCapacity) {
+      return array;
+    }
+    boolean[] trimmed = new boolean[maxCapacity];
+    System.arraycopy(array, 0, trimmed, 0, maxCapacity);
+    return trimmed;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static byte[] copyOf(byte[] src, int length) {
+    byte[] result = new byte[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with {@code '\0'} as needed.
+   */
+  public static char[] copyOf(char[] src, int length) {
+    char[] result = new char[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static short[] copyOf(short[] src, int length) {
+    short[] result = new short[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static int[] copyOf(int[] src, int length) {
+    int[] result = new int[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static float[] copyOf(float[] src, int length) {
+    float[] result = new float[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static double[] copyOf(double[] src, int length) {
+    double[] result = new double[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+
+  /**
+   * {@link java.util.Arrays#copyOf} compatibility with Java 1.5: copies {@code src} into a new array of
+   * {@code length}, truncating or padding with zeros as needed.
+   */
+  public static long[] copyOf(long[] src, int length) {
+    long[] result = new long[length];
+    System.arraycopy(src, 0, result, 0, Math.min(length, src.length));
+    return result;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/BinarySearch.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/BinarySearch.java b/core/src/main/java/org/apache/mahout/math/BinarySearch.java
new file mode 100644
index 0000000..ddb04a7
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/BinarySearch.java
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Comparator;
+
+public final class BinarySearch {
+
+ private BinarySearch() {}
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code byte} array to search.
+ * @param value
+ * the {@code byte} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(byte[] array, byte value, int from, int to) {
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (value > array[mid]) {
+ from = mid + 1;
+ } else if (value == array[mid]) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+
+ return -mid - (value < array[mid] ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code char} array to search.
+ * @param value
+ * the {@code char} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(char[] array, char value, int from, int to) {
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (value > array[mid]) {
+ from = mid + 1;
+ } else if (value == array[mid]) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (value < array[mid] ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code double} array to search.
+ * @param value
+ * the {@code double} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(double[] array, double value, int from, int to) {
+ long longBits = Double.doubleToLongBits(value);
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (lessThan(array[mid], value)) {
+ from = mid + 1;
+ } else if (longBits == Double.doubleToLongBits(array[mid])) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (lessThan(value, array[mid]) ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code float} array to search.
+ * @param value
+ * the {@code float} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(float[] array, float value, int from, int to) {
+ int intBits = Float.floatToIntBits(value);
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (lessThan(array[mid], value)) {
+ from = mid + 1;
+ } else if (intBits == Float.floatToIntBits(array[mid])) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (lessThan(value, array[mid]) ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code int} array to search.
+ * @param value
+ * the {@code int} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(int[] array, int value, int from, int to) {
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (value > array[mid]) {
+ from = mid + 1;
+ } else if (value == array[mid]) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (value < array[mid] ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code long} array to search.
+ * @param value
+ * the {@code long} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(long[] array, long value, int from, int to) {
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (value > array[mid]) {
+ from = mid + 1;
+ } else if (value == array[mid]) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (value < array[mid] ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code Object} array to search.
+ * @param object
+ * the {@code Object} element to find
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ *
+ */
+ public static <T extends Comparable<T>> int binarySearchFromTo(T[] array, T object, int from, int to) {
+ if (array.length == 0) {
+ return -1;
+ }
+
+ int mid = 0;
+ int result = 0;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if ((result = array[mid].compareTo(object)) < 0) {
+ from = mid + 1;
+ } else if (result == 0) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ return -mid - (result >= 0 ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array using the {@code Comparator} to compare elements.
+ * Searching in an unsorted array has an undefined result. It's also undefined
+ * which element is found if there are multiple occurrences of the same
+ * element.
+ *
+ * @param array
+ * the sorted array to search
+ * @param object
+ * the element to find
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @param comparator
+ * the {@code Comparator} used to compare the elements.
+ * @return the non-negative index of the element, or a negative index which
+ */
+ public static <T> int binarySearchFromTo(T[] array, T object, int from, int to, Comparator<? super T> comparator) {
+ int mid = 0;
+ int result = 0;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if ((result = comparator.compare(array[mid], object)) < 0) {
+ from = mid + 1;
+ } else if (result == 0) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ return -mid - (result >= 0 ? 1 : 2);
+ }
+
+ /**
+ * Performs a binary search for the specified element in the specified
+ * ascending sorted array. Searching in an unsorted array has an undefined
+ * result. It's also undefined which element is found if there are multiple
+ * occurrences of the same element.
+ *
+ * @param array
+ * the sorted {@code short} array to search.
+ * @param value
+ * the {@code short} element to find.
+ * @param from
+ * the first index to sort, inclusive.
+ * @param to
+ * the last index to sort, inclusive.
+ * @return the non-negative index of the element, or a negative index which is
+ * {@code -index - 1} where the element would be inserted.
+ */
+ public static int binarySearchFromTo(short[] array, short value, int from, int to) {
+ int mid = -1;
+ while (from <= to) {
+ mid = (from + to) >>> 1;
+ if (value > array[mid]) {
+ from = mid + 1;
+ } else if (value == array[mid]) {
+ return mid;
+ } else {
+ to = mid - 1;
+ }
+ }
+ if (mid < 0) {
+ return -1;
+ }
+ return -mid - (value < array[mid] ? 1 : 2);
+ }
+
+ private static boolean lessThan(double double1, double double2) {
+ // A slightly specialized version of
+ // Double.compare(double1, double2) < 0.
+
+ // Non-zero and non-NaN checking.
+ if (double1 < double2) {
+ return true;
+ }
+ if (double1 > double2) {
+ return false;
+ }
+ if (double1 == double2 && double1 != 0.0) {
+ return false;
+ }
+
+ // NaNs are equal to other NaNs and larger than any other double.
+ if (Double.isNaN(double1)) {
+ return false;
+ }
+ if (Double.isNaN(double2)) {
+ return true;
+ }
+
+ // Deal with +0.0 and -0.0.
+ long d1 = Double.doubleToRawLongBits(double1);
+ long d2 = Double.doubleToRawLongBits(double2);
+ return d1 < d2;
+ }
+
+ private static boolean lessThan(float float1, float float2) {
+ // A slightly specialized version of Float.compare(float1, float2) < 0.
+
+ // Non-zero and non-NaN checking.
+ if (float1 < float2) {
+ return true;
+ }
+ if (float1 > float2) {
+ return false;
+ }
+ if (float1 == float2 && float1 != 0.0f) {
+ return false;
+ }
+
+ // NaNs are equal to other NaNs and larger than any other float
+ if (Float.isNaN(float1)) {
+ return false;
+ }
+ if (Float.isNaN(float2)) {
+ return true;
+ }
+
+ // Deal with +0.0 and -0.0
+ int f1 = Float.floatToRawIntBits(float1);
+ int f2 = Float.floatToRawIntBits(float2);
+ return f1 < f2;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/CardinalityException.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/CardinalityException.java b/core/src/main/java/org/apache/mahout/math/CardinalityException.java
new file mode 100644
index 0000000..04e7602
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/CardinalityException.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
/**
 * Signals that the vectors or matrices taking part in an operation have incompatible
 * cardinalities. For example, two vectors of different sizes cannot be added.
 */
public class CardinalityException extends IllegalArgumentException {

  /**
   * @param expected the cardinality the operation required
   * @param cardinality the cardinality that was actually supplied
   */
  public CardinalityException(int expected, int cardinality) {
    super(describe(expected, cardinality));
  }

  /** Builds the human-readable message reported by {@link #getMessage()}. */
  private static String describe(int expected, int actual) {
    return new StringBuilder("Required cardinality ")
        .append(expected)
        .append(" but got ")
        .append(actual)
        .toString();
  }

}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Centroid.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Centroid.java b/core/src/main/java/org/apache/mahout/math/Centroid.java
new file mode 100644
index 0000000..dceffe1
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Centroid.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * A centroid is a weighted vector. We have it delegate to the vector itself for lots of operations
+ * to make it easy to use vector search classes and such.
+ */
+public class Centroid extends WeightedVector {
+ public Centroid(WeightedVector original) {
+ super(original.getVector().like().assign(original), original.getWeight(), original.getIndex());
+ }
+
+ public Centroid(int key, Vector initialValue) {
+ super(initialValue, 1, key);
+ }
+
+ public Centroid(int key, Vector initialValue, double weight) {
+ super(initialValue, weight, key);
+ }
+
+ public static Centroid create(int key, Vector initialValue) {
+ if (initialValue instanceof WeightedVector) {
+ return new Centroid(key, new DenseVector(initialValue), ((WeightedVector) initialValue).getWeight());
+ } else {
+ return new Centroid(key, new DenseVector(initialValue), 1);
+ }
+ }
+
+ public void update(Vector v) {
+ if (v instanceof Centroid) {
+ Centroid c = (Centroid) v;
+ update(c.delegate, c.getWeight());
+ } else {
+ update(v, 1);
+ }
+ }
+
+ public void update(Vector other, final double wy) {
+ final double wx = getWeight();
+ delegate.assign(other, Functions.reweigh(wx, wy));
+ setWeight(wx + wy);
+ }
+
+ @Override
+ public Centroid like() {
+ return new Centroid(getIndex(), getVector().like(), getWeight());
+ }
+
+ /**
+ * Gets the index of this centroid. Use getIndex instead to maintain standard names.
+ */
+ @Deprecated
+ public int getKey() {
+ return getIndex();
+ }
+
+ public void addWeight(double newWeight) {
+ setWeight(getWeight() + newWeight);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("key = %d, weight = %.2f, vector = %s", getIndex(), getWeight(), delegate);
+ }
+
+ @SuppressWarnings("CloneDoesntCallSuperClone")
+ @Override
+ public Centroid clone() {
+ return new Centroid(this);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/CholeskyDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/CholeskyDecomposition.java b/core/src/main/java/org/apache/mahout/math/CholeskyDecomposition.java
new file mode 100644
index 0000000..5cea8e5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/CholeskyDecomposition.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.function.Functions;
+
/**
 * Cholesky decomposition shamelessly ported from JAMA.
 * <p>
 * A Cholesky decomposition of a positive semi-definite matrix A is a lower triangular matrix L such
 * that L L^* = A. If A is full rank, L is unique. If A is real, then it must be symmetric and L
 * will also be real.
 * <p>
 * By default the factorization pivots: at each step it swaps the largest remaining diagonal element
 * into place. The resulting permutation is available from {@link #getPivot()} and
 * {@link #getInversePivot()}. A column whose diagonal is within a small relative tolerance of zero
 * is zeroed, and {@link #isPositiveDefinite()} then returns false.
 */
public class CholeskyDecomposition {
  // The factor. decomposeWithPivoting reorders it through PivotedMatrix.swap(k, pivot).
  private final PivotedMatrix L;
  // Cleared when a degenerate column is encountered during decomposition.
  private boolean isPositiveDefinite = true;

  /**
   * Decomposes {@code a} with pivoting enabled.
   *
   * @param a square matrix to decompose
   * @throws IllegalArgumentException if {@code a} is not square or is found not to be positive
   *         semi-definite
   */
  public CholeskyDecomposition(Matrix a) {
    this(a, true);
  }

  /**
   * Decomposes {@code a}. The input is only read; the factor is computed in a separate matrix.
   *
   * @param a     square matrix to decompose
   * @param pivot true to use diagonal pivoting, false for the plain column-by-column algorithm
   * @throws IllegalArgumentException if {@code a} is not square, or (pivoting only) if a clearly
   *         negative diagonal element is encountered
   */
  public CholeskyDecomposition(Matrix a, boolean pivot) {
    int rows = a.rowSize();
    L = new PivotedMatrix(new DenseMatrix(rows, rows));

    // must be square
    Preconditions.checkArgument(rows == a.columnSize(), "Must be a Square Matrix");

    if (pivot) {
      decomposeWithPivoting(a);
    } else {
      decompose(a);
    }
  }

  // Column-by-column factorization that first swaps the largest remaining positive diagonal
  // element into position k. A diagonal below -epsilon throws; one within [-epsilon, epsilon]
  // is treated as degenerate.
  private void decomposeWithPivoting(Matrix a) {
    int n = a.rowSize();
    L.assign(a);

    // pivoted column-wise submatrix cholesky with simple pivoting
    // uberMax tracks the largest |diagonal| seen so far; it scales the degeneracy tolerance.
    double uberMax = L.viewDiagonal().aggregate(Functions.MAX, Functions.ABS);
    for (int k = 0; k < n; k++) {
      // pick the largest positive diagonal among k..n-1 (stays k if none is positive)
      double max = 0;
      int pivot = k;
      for (int j = k; j < n; j++) {
        if (L.get(j, j) > max) {
          max = L.get(j, j);
          pivot = j;
          if (uberMax < Math.abs(max)) {
            uberMax = Math.abs(max);
          }
        }
      }
      L.swap(k, pivot);

      double akk = L.get(k, k);
      // relative tolerance: 1e-10 times the larger of uberMax and the largest |entry| of column k
      double epsilon = 1.0e-10 * Math.max(uberMax, L.viewColumn(k).aggregate(Functions.MAX, Functions.ABS));

      if (akk < -epsilon) {
        // can't have decidedly negative element on diagonal
        throw new IllegalArgumentException("Matrix is not positive semi-definite");
      } else if (akk <= epsilon) {
        // degenerate column case. Set all to zero
        L.viewColumn(k).assign(0);
        isPositiveDefinite = false;

        // no need to subtract from remaining sub-matrix
      } else {
        // normalize column by diagonal element; dividing L(k,k) = akk by sqrt(akk) leaves sqrt(akk)
        // on the diagonal
        akk = Math.sqrt(Math.max(0, akk));
        L.viewColumn(k).viewPart(k, n - k).assign(Functions.div(akk));
        // entries above the diagonal are zero in a lower-triangular factor
        L.viewColumn(k).viewPart(0, k).assign(0);

        // subtract off scaled version of this column to the right
        // (row k of later columns is also touched here; it is zeroed when that column's turn comes)
        for (int j = k + 1; j < n; j++) {
          Vector columnJ = L.viewColumn(j).viewPart(k, n - k);
          Vector columnK = L.viewColumn(k).viewPart(k, n - k);
          columnJ.assign(columnK, Functions.minusMult(columnK.get(j - k)));
        }

      }
    }
  }

  // Column-by-column factorization without pivoting. Any diagonal <= epsilon (including negative
  // ones) is treated as degenerate instead of raising an exception.
  private void decompose(Matrix a) {
    int n = a.rowSize();
    L.assign(a);

    // column-wise submatrix cholesky without pivoting
    for (int k = 0; k < n; k++) {

      double akk = L.get(k, k);

      // set upper part of column to 0.
      L.viewColumn(k).viewPart(0, k).assign(0);

      // tolerance relative to the largest |entry| remaining in column k
      double epsilon = 1.0e-10 * L.viewColumn(k).aggregate(Functions.MAX, Functions.ABS);
      if (akk <= epsilon) {
        // degenerate column case: zero the diagonal and everything below it (the part above
        // was already zeroed), so the whole column becomes zero
        L.viewColumn(k).viewPart(k, n - k).assign(0);

        isPositiveDefinite = false;

        // no need to subtract from remaining sub-matrix
      } else {
        // normalize column by diagonal element
        akk = Math.sqrt(Math.max(0, akk));
        L.set(k, k, akk);
        L.viewColumn(k).viewPart(k + 1, n - k - 1).assign(Functions.div(akk));

        // now subtract scaled version of column; only rows j..n-1 of each later column j are updated
        for (int j = k + 1; j < n; j++) {
          Vector columnJ = L.viewColumn(j).viewPart(j, n - j);
          Vector columnK = L.viewColumn(k).viewPart(j, n - j);
          columnJ.assign(columnK, Functions.minusMult(L.get(j, k)));
        }
      }
    }
  }

  /**
   * @return false if any column was found to be degenerate during decomposition
   */
  public boolean isPositiveDefinite() {
    return isPositiveDefinite;
  }

  /**
   * @return the base matrix underlying the pivoted factor (i.e. {@code getPermutedL().getBase()})
   */
  public Matrix getL() {
    return L.getBase();
  }

  /**
   * @return the factor as the PivotedMatrix used during decomposition
   */
  public PivotedMatrix getPermutedL() {
    return L;
  }

  /**
   * @return Returns the permutation of rows and columns that was applied to L
   */
  public int[] getPivot() {
    return L.getRowPivot();
  }

  /**
   * @return the inverse of the row permutation returned by {@link #getPivot()}
   */
  public int[] getInversePivot() {
    return L.getInverseRowPivot();
  }

  /**
   * Compute inv(L) * z efficiently.
   * <p>
   * A component whose diagonal entry in L is zero (a degenerate column) is set to 0 instead of
   * being divided by zero.
   *
   * @param z right-hand side; presumably must have L.columnSize() rows (not checked here)
   * @return a new DenseMatrix X with L X = z in the pivoted sense
   */
  public Matrix solveLeft(Matrix z) {
    int n = L.columnSize();
    int nx = z.columnSize();

    Matrix X = new DenseMatrix(n, z.columnSize());
    X.assign(z);

    // Solve L*Y = Z by triangular (forward) substitution: each row uses only rows solved before it.
    // note that k and i have to go in a funny order because L is pivoted
    for (int internalK = 0; internalK < n; internalK++) {
      int k = L.rowUnpivot(internalK);
      for (int j = 0; j < nx; j++) {
        for (int internalI = 0; internalI < internalK; internalI++) {
          int i = L.rowUnpivot(internalI);
          X.set(k, j, X.get(k, j) - X.get(i, j) * L.get(k, i));
        }
        if (L.get(k, k) != 0) {
          X.set(k, j, X.get(k, j) / L.get(k, k));
        } else {
          X.set(k, j, 0);
        }
      }
    }
    return X;
  }

  /**
   * Compute z * inv(L') efficiently
   * <p>
   * A component whose diagonal entry in L is zero is set to 0 instead of being divided by zero.
   *
   * @param z left-hand side, one row per system
   * @return a new DenseMatrix with the same shape as z
   * @throws IllegalStateException if an infinite or NaN value is produced
   */
  public Matrix solveRight(Matrix z) {
    int n = z.columnSize();
    int nx = z.rowSize();

    Matrix x = new DenseMatrix(z.rowSize(), z.columnSize());
    x.assign(z);

    // Solve Y*L' = Z using back-substitution
    for (int internalK = 0; internalK < n; internalK++) {
      int k = L.rowUnpivot(internalK);
      for (int j = 0; j < nx; j++) {
        // NOTE(review): this bound is the unpivoted index k, while solveLeft bounds its inner loop
        // by internalK. The two agree only when no pivoting took place -- confirm this is intended.
        for (int internalI = 0; internalI < k; internalI++) {
          int i = L.rowUnpivot(internalI);
          x.set(j, k, x.get(j, k) - x.get(j, i) * L.get(k, i));
          if (Double.isInfinite(x.get(j, k)) || Double.isNaN(x.get(j, k))) {
            throw new IllegalStateException(
                String.format("Invalid value found at %d,%d (should not be possible)", j, k));
          }
        }
        if (L.get(k, k) != 0) {
          x.set(j, k, x.get(j, k) / L.get(k, k));
        } else {
          x.set(j, k, 0);
        }
        if (Double.isInfinite(x.get(j, k)) || Double.isNaN(x.get(j, k))) {
          throw new IllegalStateException(String.format("Invalid value found at %d,%d (should not be possible)", j, k));
        }
      }
    }
    return x;
  }

}
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/ConstantVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/ConstantVector.java b/core/src/main/java/org/apache/mahout/math/ConstantVector.java
new file mode 100644
index 0000000..f10f631
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/ConstantVector.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+import com.google.common.collect.AbstractIterator;
+
+/**
+ * Implements a vector with all the same values.
+ */
+public class ConstantVector extends AbstractVector {
+ private final double value;
+
+ public ConstantVector(double value, int size) {
+ super(size);
+ this.value = value;
+ }
+
+ /**
+ * Subclasses must override to return an appropriately sparse or dense result
+ *
+ * @param rows the row cardinality
+ * @param columns the column cardinality
+ * @return a Matrix
+ */
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p>
+ * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge in the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ throw new UnsupportedOperationException("Cannot mutate a ConstantVector");
+ }
+
+ /**
+ * @return true iff this implementation should be considered dense -- that it explicitly represents
+ * every value
+ */
+ @Override
+ public boolean isDense() {
+ return true;
+ }
+
+ /**
+ * @return true iff this implementation should be considered to be iterable in index order in an
+ * efficient way. In particular this implies that {@link #iterator()} and {@link
+ * #iterateNonZero()} return elements in ascending order by index.
+ */
+ @Override
+ public boolean isSequentialAccess() {
+ return true;
+ }
+
+ /**
+ * Iterates over all elements <p>
+ * NOTE: Implementations may choose to reuse the Element returned
+ * for performance reasons, so if you need a copy of it, you should call {@link #getElement(int)}
+ * for the given index
+ *
+ * @return An {@link java.util.Iterator} over all elements
+ */
+ @Override
+ public Iterator<Element> iterator() {
+ return new AbstractIterator<Element>() {
+ private int i = 0;
+ private final int n = size();
+ @Override
+ protected Element computeNext() {
+ if (i < n) {
+ return new LocalElement(i++);
+ } else {
+ return endOfData();
+ }
+ }
+ };
+ }
+
+ /**
+ * Iterates over all non-zero elements.<p>
+ * NOTE: Implementations may choose to reuse the Element
+ * returned for performance reasons, so if you need a copy of it, you should call {@link
+ * #getElement(int)} for the given index
+ *
+ * @return An {@link java.util.Iterator} over all non-zero elements
+ */
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return iterator();
+ }
+
+ /**
+ * Return the value at the given index, without checking bounds
+ *
+ * @param index an int index
+ * @return the double at the index
+ */
+ @Override
+ public double getQuick(int index) {
+ return value;
+ }
+
+ /**
+ * Return an empty vector of the same underlying class as the receiver
+ *
+ * @return a Vector
+ */
+ @Override
+ public Vector like() {
+ return new DenseVector(size());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ /**
+ * Set the value at the given index, without checking bounds
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ */
+ @Override
+ public void setQuick(int index, double value) {
+ throw new UnsupportedOperationException("Can't set a value in a constant matrix");
+ }
+
+ /**
+ * Return the number of values in the recipient
+ *
+ * @return an int
+ */
+ @Override
+ public int getNumNondefaultElements() {
+ return size();
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 1;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1;
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ throw new UnsupportedOperationException("Cannot mutate a ConstantVector");
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/DelegatingVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/DelegatingVector.java b/core/src/main/java/org/apache/mahout/math/DelegatingVector.java
new file mode 100644
index 0000000..0b2e36b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/DelegatingVector.java
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * A delegating vector provides an easy way to decorate vectors with weights or id's and such while
+ * keeping all of the Vector functionality.
+ *
+ * This vector implements LengthCachingVector because almost all delegates cache the length and
+ * the cost of false positives is very low.
+ */
+public class DelegatingVector implements Vector, LengthCachingVector {
+ protected Vector delegate;
+
+ public DelegatingVector(Vector v) {
+ delegate = v;
+ }
+
+ protected DelegatingVector() {
+ }
+
+ public Vector getVector() {
+ return delegate;
+ }
+
+ @Override
+ public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
+ return delegate.aggregate(aggregator, map);
+ }
+
+ @Override
+ public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
+ return delegate.aggregate(other, aggregator, combiner);
+ }
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ return delegate.viewPart(offset, length);
+ }
+
+ @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
+ @Override
+ public Vector clone() {
+ DelegatingVector r;
+ try {
+ r = (DelegatingVector) super.clone();
+ } catch (CloneNotSupportedException e) {
+ throw new RuntimeException("Clone not supported for DelegatingVector, shouldn't be possible");
+ }
+ // delegate points to original without this
+ r.delegate = delegate.clone();
+ return r;
+ }
+
+ @Override
+ public Iterable<Element> all() {
+ return delegate.all();
+ }
+
+ @Override
+ public Iterable<Element> nonZeroes() {
+ return delegate.nonZeroes();
+ }
+
+ @Override
+ public Vector divide(double x) {
+ return delegate.divide(x);
+ }
+
+ @Override
+ public double dot(Vector x) {
+ return delegate.dot(x);
+ }
+
+ @Override
+ public double get(int index) {
+ return delegate.get(index);
+ }
+
+ @Override
+ public Element getElement(int index) {
+ return delegate.getElement(index);
+ }
+
+ /**
+ * Merge a set of (index, value) pairs into the vector.
+ *
+ * @param updates an ordered mapping of indices to values to be merged in.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ delegate.mergeUpdates(updates);
+ }
+
+ @Override
+ public Vector minus(Vector that) {
+ return delegate.minus(that);
+ }
+
+ @Override
+ public Vector normalize() {
+ return delegate.normalize();
+ }
+
+ @Override
+ public Vector normalize(double power) {
+ return delegate.normalize(power);
+ }
+
+ @Override
+ public Vector logNormalize() {
+ return delegate.logNormalize();
+ }
+
+ @Override
+ public Vector logNormalize(double power) {
+ return delegate.logNormalize(power);
+ }
+
+ @Override
+ public double norm(double power) {
+ return delegate.norm(power);
+ }
+
+ @Override
+ public double getLengthSquared() {
+ return delegate.getLengthSquared();
+ }
+
+ @Override
+ public void invalidateCachedLength() {
+ if (delegate instanceof LengthCachingVector) {
+ ((LengthCachingVector) delegate).invalidateCachedLength();
+ }
+ }
+
+ @Override
+ public double getDistanceSquared(Vector v) {
+ return delegate.getDistanceSquared(v);
+ }
+
+ @Override
+ public double getLookupCost() {
+ return delegate.getLookupCost();
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return delegate.getIteratorAdvanceCost();
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return delegate.isAddConstantTime();
+ }
+
+ @Override
+ public double maxValue() {
+ return delegate.maxValue();
+ }
+
+ @Override
+ public int maxValueIndex() {
+ return delegate.maxValueIndex();
+ }
+
+ @Override
+ public double minValue() {
+ return delegate.minValue();
+ }
+
+ @Override
+ public int minValueIndex() {
+ return delegate.minValueIndex();
+ }
+
+ @Override
+ public Vector plus(double x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public Vector plus(Vector x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public void set(int index, double value) {
+ delegate.set(index, value);
+ }
+
+ @Override
+ public Vector times(double x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public Vector times(Vector x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public double zSum() {
+ return delegate.zSum();
+ }
+
+ @Override
+ public Vector assign(double value) {
+ delegate.assign(value);
+ return this;
+ }
+
+ @Override
+ public Vector assign(double[] values) {
+ delegate.assign(values);
+ return this;
+ }
+
+ @Override
+ public Vector assign(Vector other) {
+ delegate.assign(other);
+ return this;
+ }
+
+ @Override
+ public Vector assign(DoubleDoubleFunction f, double y) {
+ delegate.assign(f, y);
+ return this;
+ }
+
+ @Override
+ public Vector assign(DoubleFunction function) {
+ delegate.assign(function);
+ return this;
+ }
+
+ @Override
+ public Vector assign(Vector other, DoubleDoubleFunction function) {
+ delegate.assign(other, function);
+ return this;
+ }
+
+ @Override
+ public Matrix cross(Vector other) {
+ return delegate.cross(other);
+ }
+
+ @Override
+ public int size() {
+ return delegate.size();
+ }
+
+ @Override
+ public String asFormatString() {
+ return delegate.asFormatString();
+ }
+
+ @Override
+ public int hashCode() {
+ return delegate.hashCode();
+ }
+
+ @SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
+ @Override
+ public boolean equals(Object o) {
+ return delegate.equals(o);
+ }
+
+ @Override
+ public String toString() {
+ return delegate.toString();
+ }
+
+ @Override
+ public boolean isDense() {
+ return delegate.isDense();
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return delegate.isSequentialAccess();
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return delegate.getQuick(index);
+ }
+
+ @Override
+ public Vector like() {
+ return new DelegatingVector(delegate.like());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new DelegatingVector(delegate.like(cardinality));
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ delegate.setQuick(index, value);
+ }
+
+ @Override
+ public void incrementQuick(int index, double increment) {
+ delegate.incrementQuick(index, increment);
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return delegate.getNumNondefaultElements();
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ return delegate.getNumNonZeroElements();
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/DenseMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/DenseMatrix.java b/core/src/main/java/org/apache/mahout/math/DenseMatrix.java
new file mode 100644
index 0000000..eac449a
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/DenseMatrix.java
@@ -0,0 +1,193 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+
+import java.util.Arrays;
+
+/** Matrix of doubles implemented using a 2-d array */
+public class DenseMatrix extends AbstractMatrix {
+
+ private double[][] values;
+
+ /**
+ * Construct a matrix from the given values
+ *
+ * @param values
+ * a double[][]
+ */
+ public DenseMatrix(double[][] values) {
+ this(values, false);
+ }
+
+ /**
+ * Construct a matrix from the given values
+ *
+ * @param values
+ * a double[][]
+ * @param shallowCopy directly use the supplied array?
+ */
+ public DenseMatrix(double[][] values, boolean shallowCopy) {
+ super(values.length, values[0].length);
+ if (shallowCopy) {
+ this.values = values;
+ } else {
+ this.values = new double[values.length][];
+ for (int i = 0; i < values.length; i++) {
+ this.values[i] = values[i].clone();
+ }
+ }
+ }
+
+ /**
+ * Constructs an empty matrix of the given size.
+ * @param rows The number of rows in the result.
+ * @param columns The number of columns in the result.
+ */
+ public DenseMatrix(int rows, int columns) {
+ super(rows, columns);
+ this.values = new double[rows][columns];
+ }
+
+ /**
+ * Returns the backing array
+ * @return double[][]
+ */
+ public double[][] getBackingStructure() {
+ return this.values;
+ }
+
+ @Override
+ public Matrix clone() {
+ DenseMatrix clone = (DenseMatrix) super.clone();
+ clone.values = new double[values.length][];
+ for (int i = 0; i < values.length; i++) {
+ clone.values[i] = values[i].clone();
+ }
+ return clone;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return values[row][column];
+ }
+
+ @Override
+ public Matrix like() {
+ return like(rowSize(), columnSize());
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ values[row][column] = value;
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ int rowOffset = offset[ROW];
+ int rowsRequested = size[ROW];
+ int columnOffset = offset[COL];
+ int columnsRequested = size[COL];
+
+ return viewPart(rowOffset, rowsRequested, columnOffset, columnsRequested);
+ }
+
+ @Override
+ public Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested) {
+ if (rowOffset < 0) {
+ throw new IndexException(rowOffset, rowSize());
+ }
+ if (rowOffset + rowsRequested > rowSize()) {
+ throw new IndexException(rowOffset + rowsRequested, rowSize());
+ }
+ if (columnOffset < 0) {
+ throw new IndexException(columnOffset, columnSize());
+ }
+ if (columnOffset + columnsRequested > columnSize()) {
+ throw new IndexException(columnOffset + columnsRequested, columnSize());
+ }
+ return new MatrixView(this, new int[]{rowOffset, columnOffset}, new int[]{rowsRequested, columnsRequested});
+ }
+
+ @Override
+ public Matrix assign(double value) {
+ for (int row = 0; row < rowSize(); row++) {
+ Arrays.fill(values[row], value);
+ }
+ return this;
+ }
+
+ public Matrix assign(DenseMatrix matrix) {
+ // make sure the data field has the correct length
+ if (matrix.values[0].length != this.values[0].length || matrix.values.length != this.values.length) {
+ this.values = new double[matrix.values.length][matrix.values[0].length];
+ }
+ // now copy the values
+ for (int i = 0; i < this.values.length; i++) {
+ System.arraycopy(matrix.values[i], 0, this.values[i], 0, this.values[0].length);
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (rowSize() != other.size()) {
+ throw new CardinalityException(rowSize(), other.size());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ for (int row = 0; row < rowSize(); row++) {
+ values[row][column] = other.getQuick(row);
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new CardinalityException(columnSize(), other.size());
+ }
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ for (int col = 0; col < columnSize(); col++) {
+ values[row][col] = other.getQuick(col);
+ }
+ return this;
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ return new DenseVector(values[row], true);
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.DENSELIKE;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java b/core/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
new file mode 100644
index 0000000..7252b9b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
+/**
+ * Economy packaging for a dense symmetric in-core matrix.
+ */
+public class DenseSymmetricMatrix extends UpperTriangular {
+ public DenseSymmetricMatrix(int n) {
+ super(n);
+ }
+
+ public DenseSymmetricMatrix(double[] data, boolean shallow) {
+ super(data, shallow);
+ }
+
+ public DenseSymmetricMatrix(Vector data) {
+ super(data);
+ }
+
+ public DenseSymmetricMatrix(UpperTriangular mx) {
+ super(mx);
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ if (column < row) {
+ int swap = row;
+ row = column;
+ column = swap;
+ }
+ return super.getQuick(row, column);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ if (column < row) {
+ int swap = row;
+ row = column;
+ column = swap;
+ }
+ super.setQuick(row, column, value);
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/DenseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/DenseVector.java b/core/src/main/java/org/apache/mahout/math/DenseVector.java
new file mode 100644
index 0000000..3961966
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/DenseVector.java
@@ -0,0 +1,442 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import com.google.common.base.Preconditions;
+
+/** Implements vector as an array of doubles */
+public class DenseVector extends AbstractVector {
+
+ private double[] values;
+
+ /** For serialization purposes only */
+ public DenseVector() {
+ super(0);
+ }
+
+ /** Construct a new instance using provided values
+ * @param values - array of values
+ */
+ public DenseVector(double[] values) {
+ this(values, false);
+ }
+
+ public DenseVector(double[] values, boolean shallowCopy) {
+ super(values.length);
+ this.values = shallowCopy ? values : values.clone();
+ }
+
+ public DenseVector(DenseVector values, boolean shallowCopy) {
+ this(values.values, shallowCopy);
+ }
+
+ /** Construct a new instance of the given cardinality
+ * @param cardinality - number of values in the vector
+ */
+ public DenseVector(int cardinality) {
+ super(cardinality);
+ this.values = new double[cardinality];
+ }
+
+ /**
+ * Copy-constructor (for use in turning a sparse vector into a dense one, for example)
+ * @param vector The vector to copy
+ */
+ public DenseVector(Vector vector) {
+ super(vector.size());
+ values = new double[vector.size()];
+ for (Element e : vector.nonZeroes()) {
+ values[e.index()] = e.get();
+ }
+ }
+
+ @Override
+ public double dot(Vector x) {
+ if (!x.isDense()) {
+ return super.dot(x);
+ } else {
+
+ int size = x.size();
+ if (values.length != size) {
+ throw new CardinalityException(values.length, size);
+ }
+
+ double sum = 0;
+ for (int n = 0; n < size; n++) {
+ sum += values[n] * x.getQuick(n);
+ }
+ return sum;
+ }
+ }
+
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ @SuppressWarnings("CloneDoesntCallSuperClone")
+ @Override
+ public DenseVector clone() {
+ return new DenseVector(values.clone());
+ }
+
+ /**
+ * @return true
+ */
+ @Override
+ public boolean isDense() {
+ return true;
+ }
+
+ /**
+ * @return true
+ */
+ @Override
+ public boolean isSequentialAccess() {
+ return true;
+ }
+
+ @Override
+ protected double dotSelf() {
+ double result = 0.0;
+ int max = size();
+ for (int i = 0; i < max; i++) {
+ result += values[i] * values[i];
+ }
+ return result;
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return values[index];
+ }
+
+ @Override
+ public DenseVector like() {
+ return new DenseVector(size());
+ }
+
+ @Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ invalidateCachedLength();
+ values[index] = value;
+ }
+
+ @Override
+ public void incrementQuick(int index, double increment) {
+ invalidateCachedLength();
+ values[index] += increment;
+ }
+
+ @Override
+ public Vector assign(double value) {
+ invalidateCachedLength();
+ Arrays.fill(values, value);
+ return this;
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ return values.length;
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ int numNonZeros = 0;
+ for (int index = 0; index < values.length; index++) {
+ if (values[index] != 0) {
+ numNonZeros++;
+ }
+ }
+ return numNonZeros;
+ }
+
+ public Vector assign(DenseVector vector) {
+ // make sure the data field has the correct length
+ if (vector.values.length != this.values.length) {
+ this.values = new double[vector.values.length];
+ }
+ // now copy the values
+ System.arraycopy(vector.values, 0, this.values, 0, this.values.length);
+ return this;
+ }
+
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ int numUpdates = updates.getNumMappings();
+ int[] indices = updates.getIndices();
+ double[] values = updates.getValues();
+ for (int i = 0; i < numUpdates; ++i) {
+ this.values[indices[i]] = values[i];
+ }
+ }
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ if (offset < 0) {
+ throw new IndexException(offset, size());
+ }
+ if (offset + length > size()) {
+ throw new IndexException(offset + length, size());
+ }
+ return new DenseVectorView(this, offset, length);
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 1;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1;
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return true;
+ }
+
+ /**
+ * Returns an iterator that traverses this Vector from 0 to cardinality-1, in that order.
+ */
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return new NonDefaultIterator();
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ return new AllIterator();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof DenseVector) {
+ // Speedup for DenseVectors
+ return Arrays.equals(values, ((DenseVector) o).values);
+ }
+ return super.equals(o);
+ }
+
+ public void addAll(Vector v) {
+ if (size() != v.size()) {
+ throw new CardinalityException(size(), v.size());
+ }
+
+ for (Element element : v.nonZeroes()) {
+ values[element.index()] += element.get();
+ }
+ }
+
+ private final class NonDefaultIterator implements Iterator<Element> {
+ private final DenseElement element = new DenseElement();
+ private int index = -1;
+ private int lookAheadIndex = -1;
+
+ @Override
+ public boolean hasNext() {
+ if (lookAheadIndex == index) { // User calls hasNext() after a next()
+ lookAhead();
+ } // else user called hasNext() repeatedly.
+ return lookAheadIndex < size();
+ }
+
+ private void lookAhead() {
+ lookAheadIndex++;
+ while (lookAheadIndex < size() && values[lookAheadIndex] == 0.0) {
+ lookAheadIndex++;
+ }
+ }
+
+ @Override
+ public Element next() {
+ if (lookAheadIndex == index) { // If user called next() without checking hasNext().
+ lookAhead();
+ }
+
+ Preconditions.checkState(lookAheadIndex > index);
+ index = lookAheadIndex;
+
+ if (index >= size()) { // If the end is reached.
+ throw new NoSuchElementException();
+ }
+
+ element.index = index;
+ return element;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ private final class AllIterator implements Iterator<Element> {
+ private final DenseElement element = new DenseElement();
+
+ private AllIterator() {
+ element.index = -1;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return element.index + 1 < size();
+ }
+
+ @Override
+ public Element next() {
+ if (element.index + 1 >= size()) { // If the end is reached.
+ throw new NoSuchElementException();
+ }
+ element.index++;
+ return element;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ private final class DenseElement implements Element {
+ int index;
+
+ @Override
+ public double get() {
+ return values[index];
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ invalidateCachedLength();
+ values[index] = value;
+ }
+ }
+
+ private final class DenseVectorView extends VectorView {
+
+ public DenseVectorView(Vector vector, int offset, int cardinality) {
+ super(vector, offset, cardinality);
+ }
+
+ @Override
+ public double dot(Vector x) {
+
+ // Apply custom dot kernels for pairs of dense vectors or their views to reduce
+ // view indirection.
+ if (x instanceof DenseVectorView) {
+
+ if (size() != x.size())
+ throw new IllegalArgumentException("Cardinality mismatch during dot(x,y).");
+
+ DenseVectorView xv = (DenseVectorView) x;
+ double[] thisValues = ((DenseVector) vector).values;
+ double[] thatValues = ((DenseVector) xv.vector).values;
+ int untilOffset = offset + size();
+
+ int i, j;
+ double sum = 0.0;
+
+ // Provoking SSE
+ int until4 = offset + (size() & ~3);
+ for (
+ i = offset, j = xv.offset;
+ i < until4;
+ i += 4, j += 4
+ ) {
+ sum += thisValues[i] * thatValues[j] +
+ thisValues[i + 1] * thatValues[j + 1] +
+ thisValues[i + 2] * thatValues[j + 2] +
+ thisValues[i + 3] * thatValues[j + 3];
+ }
+
+ // Picking up the slack
+ for (
+ i = offset, j = xv.offset;
+ i < untilOffset;
+ ) {
+ sum += thisValues[i++] * thatValues[j++];
+ }
+ return sum;
+
+ } else if (x instanceof DenseVector ) {
+
+ if (size() != x.size())
+ throw new IllegalArgumentException("Cardinality mismatch during dot(x,y).");
+
+ DenseVector xv = (DenseVector) x;
+ double[] thisValues = ((DenseVector) vector).values;
+ double[] thatValues = xv.values;
+ int untilOffset = offset + size();
+
+ int i, j;
+ double sum = 0.0;
+
+ // Provoking SSE
+ int until4 = offset + (size() & ~3);
+ for (
+ i = offset, j = 0;
+ i < until4;
+ i += 4, j += 4
+ ) {
+ sum += thisValues[i] * thatValues[j] +
+ thisValues[i + 1] * thatValues[j + 1] +
+ thisValues[i + 2] * thatValues[j + 2] +
+ thisValues[i + 3] * thatValues[j + 3];
+ }
+
+ // Picking up slack
+ for ( ;
+ i < untilOffset;
+ ) {
+ sum += thisValues[i++] * thatValues[j++];
+ }
+ return sum;
+
+ } else {
+ return super.dot(x);
+ }
+ }
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ if (offset < 0) {
+ throw new IndexException(offset, size());
+ }
+ if (offset + length > size()) {
+ throw new IndexException(offset + length, size());
+ }
+ return new DenseVectorView(vector, offset + this.offset, length);
+ }
+ }
+}
r***@apache.org
2018-09-08 23:35:13 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
new file mode 100644
index 0000000..ee54ad0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -0,0 +1,289 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
+
+/**
+ * sparse matrix with general element values whose rows are accessible quickly. Implemented as a row
+ * array of either SequentialAccessSparseVectors or RandomAccessSparseVectors.
+ */
+public class SparseRowMatrix extends AbstractMatrix {
+ private Vector[] rowVectors;
+
+ private final boolean randomAccessRows;
+
+ private static final Logger log = LoggerFactory.getLogger(SparseRowMatrix.class);
+
+ /**
+ * Construct a sparse matrix starting with the provided row vectors.
+ *
+ * @param rows The number of rows in the result
+ * @param columns The number of columns in the result
+ * @param rowVectors a Vector[] array of rows
+ */
+ public SparseRowMatrix(int rows, int columns, Vector[] rowVectors) {
+ this(rows, columns, rowVectors, false, rowVectors instanceof RandomAccessSparseVector[]);
+ }
+
+ public SparseRowMatrix(int rows, int columns, boolean randomAccess) {
+ this(rows, columns, randomAccess
+ ? new RandomAccessSparseVector[rows]
+ : new SequentialAccessSparseVector[rows],
+ true,
+ randomAccess);
+ }
+
+ public SparseRowMatrix(int rows, int columns, Vector[] vectors, boolean shallowCopy, boolean randomAccess) {
+ super(rows, columns);
+ this.randomAccessRows = randomAccess;
+ this.rowVectors = vectors.clone();
+ for (int row = 0; row < rows; row++) {
+ if (vectors[row] == null) {
+ // TODO: this can't be right to change the argument
+ vectors[row] = randomAccess
+ ? new RandomAccessSparseVector(numCols(), 10)
+ : new SequentialAccessSparseVector(numCols(), 10);
+ }
+ this.rowVectors[row] = shallowCopy ? vectors[row] : vectors[row].clone();
+ }
+ }
+
+ /**
+ * Construct a matrix of the given cardinality, with rows defaulting to RandomAccessSparseVector
+ * implementation
+ *
+ * @param rows Number of rows in result
+ * @param columns Number of columns in result
+ */
+ public SparseRowMatrix(int rows, int columns) {
+ this(rows, columns, true);
+ }
+
+ @Override
+ public Matrix clone() {
+ SparseRowMatrix clone = (SparseRowMatrix) super.clone();
+ clone.rowVectors = new Vector[rowVectors.length];
+ for (int i = 0; i < rowVectors.length; i++) {
+ clone.rowVectors[i] = rowVectors[i].clone();
+ }
+ return clone;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return rowVectors[row] == null ? 0.0 : rowVectors[row].getQuick(column);
+ }
+
+ @Override
+ public Matrix like() {
+ return new SparseRowMatrix(rowSize(), columnSize(), randomAccessRows);
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new SparseRowMatrix(rows, columns, randomAccessRows);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ rowVectors[row].setQuick(column, value);
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ int[] result = new int[2];
+ result[ROW] = rowVectors.length;
+ for (int row = 0; row < rowSize(); row++) {
+ result[COL] = Math.max(result[COL], rowVectors[row].getNumNondefaultElements());
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ if (offset[ROW] < 0) {
+ throw new IndexException(offset[ROW], rowVectors.length);
+ }
+ if (offset[ROW] + size[ROW] > rowVectors.length) {
+ throw new IndexException(offset[ROW] + size[ROW], rowVectors.length);
+ }
+ if (offset[COL] < 0) {
+ throw new IndexException(offset[COL], rowVectors[ROW].size());
+ }
+ if (offset[COL] + size[COL] > rowVectors[ROW].size()) {
+ throw new IndexException(offset[COL] + size[COL], rowVectors[ROW].size());
+ }
+ return new MatrixView(this, offset, size);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ for (int row = 0; row < rows; row++) {
+ try {
+ Iterator<Vector.Element> sparseRowIterator = ((SequentialAccessSparseVector) this.rowVectors[row])
+ .iterateNonZero();
+ if (function.isLikeMult()) { // TODO: is this a sufficient test?
+ // TODO: this may cause an exception if the row type is not compatible but it is currently guaranteed to be
+ // a SequentialAccessSparseVector, should "try" here just in case and Warn
+ // TODO: can we use iterateNonZero on both rows until the index is the same to get better speedup?
+
+ // TODO: SASVs have an iterateNonZero that returns zeros, this should not hurt but is far from optimal
+ // this might perform much better if SparseRowMatrix were backed by RandomAccessSparseVectors, which
+ // are backed by fastutil hashmaps and the iterateNonZero actually does only return nonZeros.
+ while (sparseRowIterator.hasNext()) {
+ Vector.Element element = sparseRowIterator.next();
+ int col = element.index();
+ setQuick(row, col, function.apply(element.get(), other.getQuick(row, col)));
+ }
+ } else {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+ }
+ }
+
+ } catch (ClassCastException e) {
+ // Warn and use default implementation
+ log.warn("Error casting the row to SequentialAccessSparseVector, this should never happen because" +
+ "SparseRomMatrix is always made of SequentialAccessSparseVectors. Proceeding with non-optimzed" +
+ "implementation.");
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+ }
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (rowSize() != other.size()) {
+ throw new CardinalityException(rowSize(), other.size());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ for (int row = 0; row < rowSize(); row++) {
+ rowVectors[row].setQuick(column, other.getQuick(row));
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new CardinalityException(columnSize(), other.size());
+ }
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ rowVectors[row].assign(other);
+ return this;
+ }
+
+ /**
+ * @param row an int row index
+ * @return a shallow view of the Vector at specified row (ie you may mutate the original matrix
+ * using this row)
+ */
+ @Override
+ public Vector viewRow(int row) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ return rowVectors[row];
+ }
+
+ @Override
+ public Matrix transpose() {
+ SparseColumnMatrix scm = new SparseColumnMatrix(columns, rows);
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ if (row.getNumNonZeroElements() > 0) {
+ scm.assignColumn(i, row);
+ }
+ }
+ return scm;
+ }
+
+ @Override
+ public Matrix times(Matrix other) {
+ if (columnSize() != other.rowSize()) {
+ throw new CardinalityException(columnSize(), other.rowSize());
+ }
+
+ if (other instanceof SparseRowMatrix) {
+ SparseRowMatrix y = (SparseRowMatrix) other;
+ SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ for (Vector.Element element : row.nonZeroes()) {
+ result.rowVectors[i].assign(y.rowVectors[element.index()], Functions.plusMult(element.get()));
+ }
+ }
+ return result;
+ } else {
+ if (other.viewRow(0).isDense()) {
+ // result is dense, but can be computed relatively cheaply
+ Matrix result = other.like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ Vector r = new DenseVector(other.columnSize());
+ for (Vector.Element element : row.nonZeroes()) {
+ r.assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+ }
+ result.viewRow(i).assign(r);
+ }
+ return result;
+ } else {
+ // other is sparse, but not something we understand intimately
+ SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ for (Vector.Element element : row.nonZeroes()) {
+ result.rowVectors[i].assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+ }
+ }
+ return result;
+ }
+ }
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.SPARSELIKE;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Swapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Swapper.java b/core/src/main/java/org/apache/mahout/math/Swapper.java
new file mode 100644
index 0000000..1ca3744
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Swapper.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math;
+
/**
 * Callback that exchanges the elements stored at two positions (a, b).
 *
 * <p>The implementation decides what the underlying data is; callers only supply the two
 * positions.
 */
public interface Swapper {

  /**
   * Exchanges the element at position {@code a} with the element at position {@code b}.
   *
   * @param a position of the first element
   * @param b position of the second element
   */
  void swap(int a, int b);
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java b/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
new file mode 100644
index 0000000..ede6f35
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Matrix View backed by an {@link org.apache.mahout.math.function.IntIntFunction}
+ */
+public class TransposedMatrixView extends AbstractMatrix {
+
+ private Matrix m;
+
+ public TransposedMatrixView(Matrix m) {
+ super(m.numCols(), m.numRows());
+ this.m = m;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ m.assignRow(column,other);
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ m.assignColumn(row,other);
+ return this;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return m.getQuick(column,row);
+ }
+
+ @Override
+ public Matrix like() {
+ return m.like(rows, columns);
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return m.like(rows,columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ m.setQuick(column, row, value);
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ return m.viewColumn(row);
+ }
+
+ @Override
+ public Vector viewColumn(int column) {
+ return m.viewRow(column);
+ }
+
+ @Override
+ public Matrix assign(double value) {
+ return m.assign(value);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ if (other instanceof TransposedMatrixView) {
+ m.assign(((TransposedMatrixView) other).m, function);
+ } else {
+ m.assign(new TransposedMatrixView(other), function);
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(Matrix other) {
+ if (other instanceof TransposedMatrixView) {
+ return m.assign(((TransposedMatrixView) other).m);
+ } else {
+ return m.assign(new TransposedMatrixView(other));
+ }
+ }
+
+ @Override
+ public Matrix assign(DoubleFunction function) {
+ return m.assign(function);
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return flavor;
+ }
+
+ private MatrixFlavor flavor = new MatrixFlavor() {
+ @Override
+ public BackEnum getBacking() {
+ return m.getFlavor().getBacking();
+ }
+
+ @Override
+ public TraversingStructureEnum getStructure() {
+ TraversingStructureEnum flavor = m.getFlavor().getStructure();
+ switch (flavor) {
+ case COLWISE:
+ return TraversingStructureEnum.ROWWISE;
+ case SPARSECOLWISE:
+ return TraversingStructureEnum.SPARSEROWWISE;
+ case ROWWISE:
+ return TraversingStructureEnum.COLWISE;
+ case SPARSEROWWISE:
+ return TraversingStructureEnum.SPARSECOLWISE;
+ default:
+ return flavor;
+ }
+ }
+
+ @Override
+ public boolean isDense() {
+ return m.getFlavor().isDense();
+ }
+ };
+
+ Matrix getDelegate() {
+ return m;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/UpperTriangular.java b/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
new file mode 100644
index 0000000..29fa6a0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
+/**
+ *
+ * Quick and dirty implementation of some {@link org.apache.mahout.math.Matrix} methods
+ * over packed upper triangular matrix.
+ *
+ */
+public class UpperTriangular extends AbstractMatrix {
+
+ private static final double EPSILON = 1.0e-12; // assume anything less than
+ // that to be 0 during
+ // non-upper assignments
+
+ private double[] values;
+
+ /**
+ * represents n x n upper triangular matrix
+ *
+ * @param n
+ */
+
+ public UpperTriangular(int n) {
+ super(n, n);
+ values = new double[n * (n + 1) / 2];
+ }
+
+ public UpperTriangular(double[] data, boolean shallow) {
+ this(elementsToMatrixSize(data != null ? data.length : 0));
+ if (data == null) {
+ throw new IllegalArgumentException("data");
+ }
+ values = shallow ? data : data.clone();
+ }
+
+ public UpperTriangular(Vector data) {
+ this(elementsToMatrixSize(data.size()));
+
+ for (Vector.Element el:data.nonZeroes()) {
+ values[el.index()] = el.get();
+ }
+ }
+
+ private static int elementsToMatrixSize(int dataSize) {
+ return (int) Math.round((-1 + Math.sqrt(1 + 8 * dataSize)) / 2);
+ }
+
+ // copy-constructor
+ public UpperTriangular(UpperTriangular mx) {
+ this(mx.values, false);
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new IndexException(columnSize(), other.size());
+ }
+ if (other.viewPart(column + 1, other.size() - column - 1).norm(1) > 1.0e-14) {
+ throw new IllegalArgumentException("Cannot set lower portion of triangular matrix to non-zero");
+ }
+ for (Vector.Element element : other.viewPart(0, column).all()) {
+ setQuick(element.index(), column, element.get());
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new IndexException(numCols(), other.size());
+ }
+ for (int i = 0; i < row; i++) {
+ if (Math.abs(other.getQuick(i)) > EPSILON) {
+ throw new IllegalArgumentException("non-triangular source");
+ }
+ }
+ for (int i = row; i < rows; i++) {
+ setQuick(row, i, other.get(i));
+ }
+ return this;
+ }
+
+ public Matrix assignNonZeroElementsInRow(int row, double[] other) {
+ System.arraycopy(other, row, values, getL(row, row), rows - row);
+ return this;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ if (row > column) {
+ return 0;
+ }
+ int i = getL(row, column);
+ return values[i];
+ }
+
+ private int getL(int row, int col) {
+ /*
+ * each row starts with some zero elements that we don't store. this
+ * accumulates an offset of (row+1)*row/2
+ */
+ return col + row * numCols() - (row + 1) * row / 2;
+ }
+
+ @Override
+ public Matrix like() {
+ return like(rowSize(), columnSize());
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ values[getL(row, column)] = value;
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ return new MatrixView(this, offset, size);
+ }
+
+ public double[] getData() {
+ return values;
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ // We kind of consider ourselves a vector-backed but dense matrix for mmul, etc. purposes.
+ return new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, true);
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Vector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Vector.java b/core/src/main/java/org/apache/mahout/math/Vector.java
new file mode 100644
index 0000000..c3b1dc9
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Vector.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * The basic interface including numerous convenience functions <p> NOTE: All implementing classes must have a
+ * constructor that takes an int for cardinality and a no-arg constructor that can be used for marshalling the Writable
+ * instance <p> NOTE: Implementations may choose to reuse the Vector.Element in the Iterable methods
+ */
+public interface Vector extends Cloneable {
+
+ /** @return a formatted String suitable for output */
+ String asFormatString();
+
+ /**
+ * Assign the value to all elements of the receiver
+ *
+ * @param value a double value
+ * @return the modified receiver
+ */
+ Vector assign(double value);
+
+ /**
+ * Assign the values to the receiver
+ *
+ * @param values a double[] of values
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(double[] values);
+
+ /**
+ * Assign the other vector values to the receiver
+ *
+ * @param other a Vector
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(Vector other);
+
+ /**
+ * Apply the function to each element of the receiver
+ *
+ * @param function a DoubleFunction to apply
+ * @return the modified receiver
+ */
+ Vector assign(DoubleFunction function);
+
+ /**
+ * Apply the function to each element of the receiver and the corresponding element of the other argument
+ *
+ * @param other a Vector containing the second arguments to the function
+ * @param function a DoubleDoubleFunction to apply
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(Vector other, DoubleDoubleFunction function);
+
+ /**
+ * Apply the function to each element of the receiver, using the y value as the second argument of the
+ * DoubleDoubleFunction
+ *
+ * @param f a DoubleDoubleFunction to be applied
+ * @param y a double value to be argument to the function
+ * @return the modified receiver
+ */
+ Vector assign(DoubleDoubleFunction f, double y);
+
+ /**
+ * Return the cardinality of the recipient (the maximum number of values)
+ *
+ * @return an int
+ */
+ int size();
+
+ /**
+ * true if this implementation should be considered dense -- that it explicitly
+ * represents every value
+ *
+ * @return true or false
+ */
+ boolean isDense();
+
+ /**
+ * true if this implementation should be considered to be iterable in index order in an efficient way.
+ * In particular this implies that {@link #all()} and {@link #nonZeroes()} ()} return elements
+ * in ascending order by index.
+ *
+ * @return true iff this implementation should be considered to be iterable in index order in an efficient way.
+ */
+ boolean isSequentialAccess();
+
+ /**
+ * Return a copy of the recipient
+ *
+ * @return a new Vector
+ */
+ @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
+ Vector clone();
+
+ Iterable<Element> all();
+
+ Iterable<Element> nonZeroes();
+
+ /**
+ * Return an object of Vector.Element representing an element of this Vector. Useful when designing new iterator
+ * types.
+ *
+ * @param index Index of the Vector.Element required
+ * @return The Vector.Element Object
+ */
+ Element getElement(int index);
+
+ /**
+ * Merge a set of (index, value) pairs into the vector.
+ * @param updates an ordered mapping of indices to values to be merged in.
+ */
+ void mergeUpdates(OrderedIntDoubleMapping updates);
+
+ /**
+ * A holder for information about a specific item in the Vector. <p>
+ * When using with an Iterator, the implementation
+ * may choose to reuse this element, so you may need to make a copy if you want to keep it
+ */
+ interface Element {
+
+ /** @return the value of this vector element. */
+ double get();
+
+ /** @return the index of this vector element. */
+ int index();
+
+ /** @param value Set the current element to value. */
+ void set(double value);
+ }
+
+ /**
+ * Return a new vector containing the values of the recipient divided by the argument
+ *
+ * @param x a double value
+ * @return a new Vector
+ */
+ Vector divide(double x);
+
+ /**
+ * Return the dot product of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ double dot(Vector x);
+
+ /**
+ * Return the value at the given index
+ *
+ * @param index an int index
+ * @return the double at the index
+ * @throws IndexException if the index is out of bounds
+ */
+ double get(int index);
+
+ /**
+ * Return the value at the given index, without checking bounds
+ *
+ * @param index an int index
+ * @return the double at the index
+ */
+ double getQuick(int index);
+
+ /**
+ * Return an empty vector of the same underlying class as the receiver
+ *
+ * @return a Vector
+ */
+ Vector like();
+
+ /**
+ * Return a new empty vector of the same underlying class as the receiver with given cardinality
+ *
+ * @param cardinality - size of vector
+ * @return {@link Vector}
+ */
+ Vector like(int cardinality);
+
+ /**
+ * Return a new vector containing the element by element difference of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector minus(Vector x);
+
+ /**
+ * Return a new vector containing the normalized (L_2 norm) values of the recipient
+ *
+ * @return a new Vector
+ */
+ Vector normalize();
+
+ /**
+ * Return a new Vector containing the normalized (L_power norm) values of the recipient. <p>
+ * See
+ * http://en.wikipedia.org/wiki/Lp_space <p>
+ * Technically, when {@code 0 < power < 1}, we don't have a norm, just a metric,
+ * but we'll overload this here. <p>
+ * Also supports {@code power == 0} (number of non-zero elements) and power = {@link
+ * Double#POSITIVE_INFINITY} (max element). Again, see the Wikipedia page for more info
+ *
+ * @param power The power to use. Must be >= 0. May also be {@link Double#POSITIVE_INFINITY}. See the Wikipedia link
+ * for more on this.
+ * @return a new Vector x such that norm(x, power) == 1
+ */
+ Vector normalize(double power);
+
+ /**
+ * Return a new vector containing the log(1 + entry)/ L_2 norm values of the recipient
+ *
+ * @return a new Vector
+ */
+ Vector logNormalize();
+
+ /**
+ * Return a new Vector with a normalized value calculated as log_power(1 + entry)/ L_power norm. <p>
+ *
+ * @param power The power to use. Must be > 1. Cannot be {@link Double#POSITIVE_INFINITY}.
+ * @return a new Vector
+ */
+ Vector logNormalize(double power);
+
+ /**
+ * Return the k-norm of the vector. <p/> See http://en.wikipedia.org/wiki/Lp_space <p>
+ * Technically, when {@code 0 > power < 1}, we don't have a norm, just a metric, but we'll overload this here. Also supports power == 0 (number of
+ * non-zero elements) and power = {@link Double#POSITIVE_INFINITY} (max element). Again, see the Wikipedia page for
+ * more info.
+ *
+ * @param power The power to use.
+ * @see #normalize(double)
+ */
+ double norm(double power);
+
+ /** @return The minimum value in the Vector */
+ double minValue();
+
+ /** @return The index of the minimum value */
+ int minValueIndex();
+
+ /** @return The maximum value in the Vector */
+ double maxValue();
+
+ /** @return The index of the maximum value */
+ int maxValueIndex();
+
+ /**
+ * Return a new vector containing the sum of each value of the recipient and the argument
+ *
+ * @param x a double
+ * @return a new Vector
+ */
+ Vector plus(double x);
+
+ /**
+ * Return a new vector containing the element by element sum of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector plus(Vector x);
+
+ /**
+ * Set the value at the given index
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ * @throws IndexException if the index is out of bounds
+ */
+ void set(int index, double value);
+
+ /**
+ * Set the value at the given index, without checking bounds
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ */
+ void setQuick(int index, double value);
+
+ /**
+ * Increment the value at the given index by the given value.
+ *
+ * @param index an int index into the receiver
+ * @param increment sets the value at the given index to value + increment;
+ */
+ void incrementQuick(int index, double increment);
+
+ /**
+ * Return the number of values in the recipient which are not the default value. For instance, for a
+ * sparse vector, this would be the number of non-zero values.
+ *
+ * @return an int
+ */
+ int getNumNondefaultElements();
+
+ /**
+ * Return the number of non zero elements in the vector.
+ *
+ * @return an int
+ */
+ int getNumNonZeroElements();
+
+ /**
+ * Return a new vector containing the product of each value of the recipient and the argument
+ *
+ * @param x a double argument
+ * @return a new Vector
+ */
+ Vector times(double x);
+
+ /**
+ * Return a new vector containing the element-wise product of the recipient and the argument
+ *
+ * @param x a Vector argument
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector times(Vector x);
+
+ /**
+ * Return a new vector containing the subset of the recipient
+ *
+ * @param offset an int offset into the receiver
+ * @param length the cardinality of the desired result
+ * @return a new Vector
+ * @throws CardinalityException if the length is greater than the cardinality of the receiver
+ * @throws IndexException if the offset is negative or the offset+length is outside of the receiver
+ */
+ Vector viewPart(int offset, int length);
+
+ /**
+ * Return the sum of all the elements of the receiver
+ *
+ * @return a double
+ */
+ double zSum();
+
+ /**
+ * Return the cross product of the receiver and the other vector
+ *
+ * @param other another Vector
+ * @return a Matrix
+ */
+ Matrix cross(Vector other);
+
+ /*
+ * Need stories for these but keeping them here for now.
+ */
+ // void getNonZeros(IntArrayList jx, DoubleArrayList values);
+ // void foreachNonZero(IntDoubleFunction f);
+ // DoubleDoubleFunction map);
+ // NewVector assign(Vector y, DoubleDoubleFunction function, IntArrayList
+ // nonZeroIndexes);
+
+ /**
+ * Examples speak louder than words: aggregate(plus, pow(2)) is another way to say
+ * getLengthSquared(), aggregate(max, abs) is norm(Double.POSITIVE_INFINITY). To sum all of the positive values,
+ * aggregate(plus, max(0)).
+ * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue)
+ * @param map a function to apply to each element of the vector in turn before passing to the aggregator
+ * @return the final aggregation
+ */
+ double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map);
+
+ /**
+ * <p>Generalized inner product - take two vectors, iterate over them both, using the combiner to combine together
+ * (and possibly map in some way) each pair of values, which are then aggregated with the previous accumulated
+ * value in the combiner.</p>
+ * <p>
+ * Example: dot(other) could be expressed as aggregate(other, Plus, Times), and kernelized inner products (which
+ * are symmetric on the indices) work similarly.
+ * @param other a vector to aggregate in combination with
+ * @param aggregator function we're aggregating with; fa
+ * @param combiner function we're combining with; fc
+ * @return the final aggregation; {@code if r0 = fc(this[0], other[0]), ri = fa(r_{i-1}, fc(this[i], other[i]))
+ * for all i > 0}
+ */
+ double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner);
+
+ /**
+ * Return the sum of squares of all elements in the vector. Square root of
+ * this value is the length of the vector.
+ */
+ double getLengthSquared();
+
+ /**
+ * Get the square of the distance between this vector and the other vector.
+ */
+ double getDistanceSquared(Vector v);
+
+ /**
+ * Gets an estimate of the cost (in number of operations) it takes to lookup a random element in this vector.
+ */
+ double getLookupCost();
+
+ /**
+ * Gets an estimate of the cost (in number of operations) it takes to advance an iterator through the nonzero
+ * elements of this vector.
+ */
+ double getIteratorAdvanceCost();
+
+ /**
+ * Return true iff adding a new (nonzero) element takes constant time for this vector.
+ */
+ boolean isAddConstantTime();
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java b/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
new file mode 100644
index 0000000..4d3a80f
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
@@ -0,0 +1,481 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.util.Iterator;
+
+/**
+ * Abstract class encapsulating different algorithms that perform the Vector operation aggregate().
+ * x.aggregate(y, fa, fc), for x and y Vectors and fa, fc DoubleDouble functions:
+ * - applies the function fc to every pair of corresponding elements of x and y, fc(xi, yi)
+ * - constructs a result iteratively, r0 = fc(x0, y0), ri = fa(r_{i-1}, fc(xi, yi)) for i > 0.
+ * This works essentially like a map/reduce functional combo.
+ *
+ * The names of variables, methods and classes used here follow the following conventions:
+ * The left hand side vector (the one aggregate() is called on) is called this or x.
+ * The right hand side is called that or y.
+ * The aggregating (reducing) function to be applied is called fa.
+ * The combining (mapping) function to be applied is called fc.
+ *
+ * The different algorithms take into account the different characteristics of vector classes:
+ * - whether the vectors support sequential iteration (isSequentialAccess())
+ * - what the lookup cost is (getLookupCost())
+ * - what the iterator advancement cost is (getIteratorAdvanceCost())
+ * Unless fa is associative and commutative (fa.isAssociativeAndCommutative()), the terms must be combined in index
+ * order, which restricts the algorithms that may be used.
+ *
+ * The names of the actual classes (they're nested in VectorBinaryAggregate) describe the algorithm used for the
+ * aggregation.
+ * The most important optimization is iterating just through the nonzeros. This is only possible when the skipped
+ * terms cannot change the result: fc must map the skipped positions to 0 and fa must satisfy fa(r, 0) = r
+ * (fa.isLikeRightPlus()).
+ * There are 4 main possibilities:
+ * - iterating through the nonzeros of just one vector and looking up the corresponding elements in the other
+ * - iterating through the intersection of nonzeros (those indices where both vectors have nonzero values)
+ * - iterating through the union of nonzeros (those indices where at least one of the vectors has a nonzero value)
+ * - iterating through all the elements in some way (either through both at the same time, through one while
+ * looking up the other, or looking up both).
+ * The zero-skipping algorithms return 0 when they find no term to aggregate.
+ *
+ * NOTE(review): skipping a zero term that would have been r0 additionally assumes fa(0, r) = r. That follows from
+ * fa(r, 0) = r when fa is commutative, but not in general, so for an fa that is like right-plus yet not commutative
+ * the zero-skipping algorithms may differ from the definition above — confirm which functions can reach them.
+ *
+ * The internal details are not important and a particular algorithm should generally not be called explicitly.
+ * The best one will be selected through aggregateBest().
+ *
+ * See https://docs.google.com/document/d/1g1PjUuvjyh2LBdq2_rKLIcUiDbeOORA1sCJiSsz-JVU/edit# for a more detailed
+ * explanation.
+ */
+public abstract class VectorBinaryAggregate {
+  /**
+   * All available aggregation algorithms. getBestOperation() picks the cheapest valid one; when two valid algorithms
+   * have the same estimated cost, the one listed first wins, so the order below matters.
+   */
+  public static final VectorBinaryAggregate[] OPERATIONS = {
+    new AggregateNonzerosIterateThisLookupThat(),
+    new AggregateNonzerosIterateThatLookupThis(),
+
+    new AggregateIterateIntersection(),
+
+    new AggregateIterateUnionSequential(),
+    new AggregateIterateUnionRandom(),
+
+    new AggregateAllIterateSequential(),
+    new AggregateAllIterateThisLookupThat(),
+    new AggregateAllIterateThatLookupThis(),
+    new AggregateAllLoop(),
+  };
+
+  /**
+   * Returns true iff we can use this algorithm to apply fc to x and y component-wise and aggregate the result using fa.
+   *
+   * @param x the left hand side vector
+   * @param y the right hand side vector
+   * @param fa the aggregating (reducing) function
+   * @param fc the combining (mapping) function
+   * @return whether this algorithm computes the correct aggregate for these arguments
+   */
+  public abstract boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+  /**
+   * Estimates the cost of using this algorithm to compute the aggregation. The algorithm is assumed to be valid.
+   * getBestOperation() compares these estimates, which are built from the vectors' getLookupCost() and
+   * getIteratorAdvanceCost() values.
+   *
+   * @return the estimated cost of aggregate(x, y, fa, fc)
+   */
+  public abstract double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+  /**
+   * Main method that applies fc to x and y component-wise aggregating the results with fa. It returns the result of
+   * the aggregation.
+   *
+   * @return the aggregated value
+   */
+  public abstract double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+  /**
+   * The best operation is the least expensive valid one. Ties are resolved in favor of the algorithm listed first in
+   * OPERATIONS.
+   *
+   * @return the cheapest algorithm whose isValid() accepts the arguments
+   * @throws IllegalStateException if no valid algorithm reports a finite cost estimate
+   */
+  public static VectorBinaryAggregate getBestOperation(Vector x, Vector y, DoubleDoubleFunction fa,
+                                                       DoubleDoubleFunction fc) {
+    int bestOperationIndex = -1;
+    double bestCost = Double.POSITIVE_INFINITY;
+    for (int i = 0; i < OPERATIONS.length; ++i) {
+      if (OPERATIONS[i].isValid(x, y, fa, fc)) {
+        double cost = OPERATIONS[i].estimateCost(x, y, fa, fc);
+        if (cost < bestCost) {
+          bestCost = cost;
+          bestOperationIndex = i;
+        }
+      }
+    }
+    if (bestOperationIndex < 0) {
+      // Only reachable when every valid algorithm reports an infinite or NaN cost; fail clearly instead of with an
+      // ArrayIndexOutOfBoundsException on OPERATIONS[-1].
+      throw new IllegalStateException("No valid aggregate algorithm with a finite cost estimate for vectors of sizes "
+          + x.size() + " and " + y.size());
+    }
+    return OPERATIONS[bestOperationIndex];
+  }
+
+  /**
+   * This is the method that should be used when aggregating. It selects the best algorithm and applies it.
+   *
+   * @return the result of aggregating fc(x[i], y[i]) over all indices with fa
+   */
+  public static double aggregateBest(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+    return getBestOperation(x, y, fa, fc).aggregate(x, y, fa, fc);
+  }
+
+  /**
+   * If fc(0, y) = 0 (fc.isLikeLeftMult()) and fa(r, 0) = r, the zeros of x contribute nothing, so we iterate through
+   * the nonzeros of x and look up the corresponding elements of y. Unless fa is associative and commutative, x must
+   * support sequential access so the terms are combined in index order.
+   */
+  public static class AggregateNonzerosIterateThisLookupThat extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return fa.isLikeRightPlus() && (fa.isAssociativeAndCommutative() || x.isSequentialAccess())
+          && fc.isLikeLeftMult();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost();
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+      if (!xi.hasNext()) {
+        return 0;
+      }
+      Vector.Element xe = xi.next();
+      double result = fc.apply(xe.get(), y.getQuick(xe.index()));
+      while (xi.hasNext()) {
+        xe = xi.next();
+        result = fa.apply(result, fc.apply(xe.get(), y.getQuick(xe.index())));
+      }
+      return result;
+    }
+  }
+
+  /**
+   * If fc(x, 0) = 0 (fc.isLikeRightMult()) and fa(r, 0) = r, the zeros of y contribute nothing, so we iterate through
+   * the nonzeros of y and look up the corresponding elements of x. Unless fa is associative and commutative, y must
+   * support sequential access so the terms are combined in index order.
+   */
+  public static class AggregateNonzerosIterateThatLookupThis extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return fa.isLikeRightPlus() && (fa.isAssociativeAndCommutative() || y.isSequentialAccess())
+          && fc.isLikeRightMult();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      // Each nonzero of y costs one iterator step plus a single read (getQuick) from x; nothing is written back, so
+      // this mirrors AggregateNonzerosIterateThisLookupThat.
+      return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost();
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+      if (!yi.hasNext()) {
+        return 0;
+      }
+      Vector.Element ye = yi.next();
+      double result = fc.apply(x.getQuick(ye.index()), ye.get());
+      while (yi.hasNext()) {
+        ye = yi.next();
+        result = fa.apply(result, fc.apply(x.getQuick(ye.index()), ye.get()));
+      }
+      return result;
+    }
+  }
+
+  /**
+   * If fc(x, 0) = fc(0, y) = 0 (fc.isLikeMult()) and fa(r, 0) = r, only indices where both x and y are nonzero
+   * contribute. The nonzeros of both vectors are walked in parallel, which requires sequential access on both and
+   * keeps the terms in index order.
+   */
+  public static class AggregateIterateIntersection extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return fa.isLikeRightPlus() && fc.isLikeMult() && x.isSequentialAccess() && y.isSequentialAccess();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return Math.min(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+          y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+      Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+      Vector.Element xe = null;
+      Vector.Element ye = null;
+      boolean advanceThis = true;
+      boolean advanceThat = true;
+      boolean validResult = false;
+      double result = 0;
+      while (true) {
+        // Step the iterator(s) consumed by the previous comparison; stop as soon as either side is exhausted.
+        if (advanceThis) {
+          if (xi.hasNext()) {
+            xe = xi.next();
+          } else {
+            break;
+          }
+        }
+        if (advanceThat) {
+          if (yi.hasNext()) {
+            ye = yi.next();
+          } else {
+            break;
+          }
+        }
+        if (xe.index() == ye.index()) {
+          double thisResult = fc.apply(xe.get(), ye.get());
+          if (validResult) {
+            result = fa.apply(result, thisResult);
+          } else {
+            result = thisResult;
+            validResult = true;
+          }
+          advanceThis = true;
+          advanceThat = true;
+        } else {
+          if (xe.index() < ye.index()) { // y is zero here and fc(x, 0) = 0
+            advanceThis = true;
+            advanceThat = false;
+          } else { // x is zero here and fc(0, y) = 0
+            advanceThis = false;
+            advanceThat = true;
+          }
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * If fc(0, 0) = 0 (fc is not densifying) and fa(r, 0) = r, only indices where at least one vector is nonzero
+   * contribute. The nonzeros of both vectors are walked in parallel in index order, which requires sequential access
+   * on both.
+   */
+  public static class AggregateIterateUnionSequential extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return fa.isLikeRightPlus() && !fc.isDensifying()
+          && x.isSequentialAccess() && y.isSequentialAccess();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+          y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+      Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+      Vector.Element xe = null;
+      Vector.Element ye = null;
+      boolean advanceThis = true;
+      boolean advanceThat = true;
+      boolean validResult = false;
+      double result = 0;
+      while (true) {
+        // Step the iterator(s) consumed by the previous iteration; a null element marks an exhausted side.
+        if (advanceThis) {
+          if (xi.hasNext()) {
+            xe = xi.next();
+          } else {
+            xe = null;
+          }
+        }
+        if (advanceThat) {
+          if (yi.hasNext()) {
+            ye = yi.next();
+          } else {
+            ye = null;
+          }
+        }
+        double thisResult;
+        if (xe != null && ye != null) { // both vectors have nonzero elements
+          if (xe.index() == ye.index()) {
+            thisResult = fc.apply(xe.get(), ye.get());
+            advanceThis = true;
+            advanceThat = true;
+          } else {
+            if (xe.index() < ye.index()) { // y is zero here: fc(x, 0)
+              thisResult = fc.apply(xe.get(), 0);
+              advanceThis = true;
+              advanceThat = false;
+            } else { // x is zero here: fc(0, y)
+              thisResult = fc.apply(0, ye.get());
+              advanceThis = false;
+              advanceThat = true;
+            }
+          }
+        } else if (xe != null) { // just the first one still has nonzeros
+          thisResult = fc.apply(xe.get(), 0);
+          advanceThis = true;
+          advanceThat = false;
+        } else if (ye != null) { // just the second one has nonzeros
+          thisResult = fc.apply(0, ye.get());
+          advanceThis = false;
+          advanceThat = true;
+        } else { // we're done, both are empty
+          break;
+        }
+        if (validResult) {
+          result = fa.apply(result, thisResult);
+        } else {
+          result = thisResult;
+          validResult = true;
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * Same preconditions on fc and fa as AggregateIterateUnionSequential, but without walking both vectors in
+   * lockstep. First all nonzeros of x are visited (looking up y), then the nonzeros of y that were not visited yet
+   * (looking up x). The terms are therefore NOT combined in index order, so fa must be associative and commutative
+   * no matter how the vectors can be accessed.
+   */
+  public static class AggregateIterateUnionRandom extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      // Sequential access on x and y does not help here: the second pass over y revisits indices out of order.
+      return fa.isLikeRightPlus() && !fc.isDensifying() && fa.isAssociativeAndCommutative();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+          y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      OpenIntHashSet visited = new OpenIntHashSet();
+      Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+      boolean validResult = false;
+      double result = 0;
+      double thisResult;
+      while (xi.hasNext()) {
+        Vector.Element xe = xi.next();
+        thisResult = fc.apply(xe.get(), y.getQuick(xe.index()));
+        if (validResult) {
+          result = fa.apply(result, thisResult);
+        } else {
+          result = thisResult;
+          validResult = true;
+        }
+        visited.add(xe.index());
+      }
+      Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+      while (yi.hasNext()) {
+        Vector.Element ye = yi.next();
+        if (!visited.contains(ye.index())) { // indices already covered by x's pass must not be counted twice
+          thisResult = fc.apply(x.getQuick(ye.index()), ye.get());
+          if (validResult) {
+            result = fa.apply(result, thisResult);
+          } else {
+            result = thisResult;
+            validResult = true;
+          }
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * Combines every index (zeros included) by walking all elements of x and y in lockstep. Valid for any fa and fc,
+   * but only for vectors that support sequential access; dense vectors are excluded.
+   */
+  public static class AggregateAllIterateSequential extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return x.isSequentialAccess() && y.isSequentialAccess() && !x.isDense() && !y.isDense();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> xi = x.all().iterator();
+      Iterator<Vector.Element> yi = y.all().iterator();
+      boolean validResult = false;
+      double result = 0;
+      while (xi.hasNext() && yi.hasNext()) {
+        Vector.Element xe = xi.next();
+        double thisResult = fc.apply(xe.get(), yi.next().get());
+        if (validResult) {
+          result = fa.apply(result, thisResult);
+        } else {
+          result = thisResult;
+          validResult = true;
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * Combines every index (zeros included) by iterating over all elements of x and looking up the corresponding
+   * element of y. Unless fa is associative and commutative, x must support sequential access; dense x is excluded.
+   */
+  public static class AggregateAllIterateThisLookupThat extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return (fa.isAssociativeAndCommutative() || x.isSequentialAccess())
+          && !x.isDense();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> xi = x.all().iterator();
+      boolean validResult = false;
+      double result = 0;
+      while (xi.hasNext()) {
+        Vector.Element xe = xi.next();
+        double thisResult = fc.apply(xe.get(), y.getQuick(xe.index()));
+        if (validResult) {
+          result = fa.apply(result, thisResult);
+        } else {
+          result = thisResult;
+          validResult = true;
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * Combines every index (zeros included) by iterating over all elements of y and looking up the corresponding
+   * element of x. Unless fa is associative and commutative, y must support sequential access; dense y is excluded.
+   */
+  public static class AggregateAllIterateThatLookupThis extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return (fa.isAssociativeAndCommutative() || y.isSequentialAccess())
+          && !y.isDense();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      Iterator<Vector.Element> yi = y.all().iterator();
+      boolean validResult = false;
+      double result = 0;
+      while (yi.hasNext()) {
+        Vector.Element ye = yi.next();
+        double thisResult = fc.apply(x.getQuick(ye.index()), ye.get());
+        if (validResult) {
+          result = fa.apply(result, thisResult);
+        } else {
+          result = thisResult;
+          validResult = true;
+        }
+      }
+      return result;
+    }
+  }
+
+  /**
+   * Fallback that is always valid: loops over every index 0..x.size()-1 and looks up both vectors.
+   * Assumes y is at least as large as x — presumably checked by the caller (TODO confirm).
+   */
+  public static class AggregateAllLoop extends VectorBinaryAggregate {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return true;
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      return x.size() * x.getLookupCost() * y.getLookupCost();
+    }
+
+    @Override
+    public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+      int s = x.size();
+      if (s == 0) {
+        // There is no first term to seed the aggregation with; return 0 like the iterating algorithms do when they
+        // find nothing to aggregate, instead of looking up index 0 of an empty vector.
+        return 0;
+      }
+      double result = fc.apply(x.getQuick(0), y.getQuick(0));
+      for (int i = 1; i < s; ++i) {
+        result = fa.apply(result, fc.apply(x.getQuick(i), y.getQuick(i)));
+      }
+      return result;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java b/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
new file mode 100644
index 0000000..f24d552
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.util.Iterator;
+
+/**
+ * Abstract class encapsulating different algorithms that perform the Vector operations assign().
+ * x.assign(y, f), for x and y Vectors and f a DoubleDouble function:
+ * - applies the function f to every element in x and y, f(xi, yi)
+ * - assigns xi = f(xi, yi) for all indices i
+ *
+ * The names of variables, methods and classes used here follow the following conventions:
+ * The vector being assigned to (the left hand side) is called this or x.
+ * The right hand side is called that or y.
+ * The function to be applied is called f.
+ *
+ * The different algorithms take into account the different characteristics of vector classes:
+ * - whether the vectors support sequential iteration (isSequentialAccess())
+ * - whether the vectors support constant-time additions (isAddConstantTime())
+ * - what the lookup cost is (getLookupCost())
+ * - what the iterator advancement cost is (getIteratorAdvanceCost())
+ *
+ * The names of the actual classes (they're nested in VectorBinaryAssign) describe the algorithm used for assignment.
+ * The most important optimization is iterating just through the nonzeros (only possible if f(0, 0) = 0).
+ * There are 4 main possibilities:
+ * - iterating through the nonzeros of just one vector and looking up the corresponding elements in the other
+ * - iterating through the intersection of nonzeros (those indices where both vectors have nonzero values)
+ * - iterating through the union of nonzeros (those indices where at least one of the vectors has a nonzero value)
+ * - iterating through all the elements in some way (either through both at the same time, both one after the other,
+ * looking up both, looking up just one).
+ * Then, there are two additional sub-possibilities:
+ * - if a new value can be added to x in constant time (isAddConstantTime()), the *Inplace updates are used
+ * - otherwise (really just for SequentialAccessSparseVectors right now), the *Merge updates are used, where
+ * a sorted list of (index, value) pairs is merged into the vector at the end.
+ *
+ * The internal details are not important and a particular algorithm should generally not be called explicitly.
+ * The best one will be selected through assignBest(), which is itself called through Vector.assign().
+ *
+ * See https://docs.google.com/document/d/1g1PjUuvjyh2LBdq2_rKLIcUiDbeOORA1sCJiSsz-JVU/edit# for a more detailed
+ * explanation.
+ */
+public abstract class VectorBinaryAssign {
+ /**
+ * All available assignment algorithms. getBestOperation() picks the cheapest valid one; when two valid
+ * algorithms have the same estimated cost, the one listed first wins, so the order below matters.
+ */
+ public static final VectorBinaryAssign[] OPERATIONS = {
+ new AssignNonzerosIterateThisLookupThat(),
+ new AssignNonzerosIterateThatLookupThisMergeUpdates(),
+ new AssignNonzerosIterateThatLookupThisInplaceUpdates(),
+
+ new AssignIterateIntersection(),
+
+ new AssignIterateUnionSequentialMergeUpdates(),
+ new AssignIterateUnionSequentialInplaceUpdates(),
+ new AssignIterateUnionRandomMergeUpdates(),
+ new AssignIterateUnionRandomInplaceUpdates(),
+
+ new AssignAllIterateSequentialMergeUpdates(),
+ new AssignAllIterateSequentialInplaceUpdates(),
+ new AssignAllIterateThisLookupThatMergeUpdates(),
+ new AssignAllIterateThisLookupThatInplaceUpdates(),
+ new AssignAllIterateThatLookupThisMergeUpdates(),
+ new AssignAllIterateThatLookupThisInplaceUpdates(),
+ new AssignAllLoopMergeUpdates(),
+ new AssignAllLoopInplaceUpdates(),
+ };
+
+ /**
+ * Returns true iff we can use this algorithm to apply f to x and y component-wise and assign the result to x.
+ *
+ * @param x the vector being assigned to
+ * @param y the right hand side vector
+ * @param f the function combining x[i] and y[i]
+ * @return whether this algorithm produces the correct result for these arguments
+ */
+ public abstract boolean isValid(Vector x, Vector y, DoubleDoubleFunction f);
+
+ /**
+ * Estimates the cost of using this algorithm to compute the assignment. The algorithm is assumed to be valid.
+ * getBestOperation() compares these estimates, which are built from the vectors' getLookupCost() and
+ * getIteratorAdvanceCost() values.
+ *
+ * @return the estimated cost of assign(x, y, f)
+ */
+ public abstract double estimateCost(Vector x, Vector y, DoubleDoubleFunction f);
+
+ /**
+ * Main method that applies f to x and y component-wise assigning the results to x. It returns the modified vector,
+ * x.
+ *
+ * @return x, after every x[i] has been replaced by f(x[i], y[i])
+ */
+ public abstract Vector assign(Vector x, Vector y, DoubleDoubleFunction f);
+
+  /**
+   * The best operation is the least expensive valid one. Ties are resolved in favor of the algorithm listed first in
+   * OPERATIONS.
+   *
+   * @param x the vector being assigned to
+   * @param y the right hand side vector
+   * @param f the function combining x[i] and y[i]
+   * @return the cheapest algorithm whose isValid() accepts (x, y, f)
+   * @throws IllegalStateException if no valid algorithm reports a finite cost estimate
+   */
+  public static VectorBinaryAssign getBestOperation(Vector x, Vector y, DoubleDoubleFunction f) {
+    int bestOperationIndex = -1;
+    double bestCost = Double.POSITIVE_INFINITY;
+    for (int i = 0; i < OPERATIONS.length; ++i) {
+      if (OPERATIONS[i].isValid(x, y, f)) {
+        double cost = OPERATIONS[i].estimateCost(x, y, f);
+        if (cost < bestCost) {
+          bestCost = cost;
+          bestOperationIndex = i;
+        }
+      }
+    }
+    if (bestOperationIndex < 0) {
+      // Without this check the lookup below fails with an opaque ArrayIndexOutOfBoundsException (index -1).
+      throw new IllegalStateException("No valid assign algorithm with a finite cost estimate for vectors of sizes "
+          + x.size() + " and " + y.size());
+    }
+    return OPERATIONS[bestOperationIndex];
+  }
+
+ /**
+ * This is the method that should be used when assigning. It selects the best algorithm and applies it.
+ * Note that it does NOT invalidate the cached length of the Vector and should only be used through the wrapprs
+ * in AbstractVector.
+ */
+ public static Vector assignBest(Vector x, Vector y, DoubleDoubleFunction f) {
+ return getBestOperation(x, y, f).assign(x, y, f);
+ }
+
+ /**
+ * If f(0, y) = 0, the zeros in x don't matter and we can simply iterate through the nonzeros of x.
+ * To get the corresponding element of y, we perform a lookup.
+ * There are no *Merge or *Inplace versions because in this case x cannot become more dense because of f, meaning
+ * all changes will occur at indices whose values are already nonzero.
+ */
+ public static class AssignNonzerosIterateThisLookupThat extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return f.isLikeLeftMult();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ return x;
+ }
+ }
+
+ /**
+ * If f(x, 0) = x, the zeros in y don't matter and we can simply iterate through the nonzeros of y.
+ * We get the corresponding element of x through a lookup and update x inplace.
+ */
+ public static class AssignNonzerosIterateThatLookupThisInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return f.isLikeRightPlus();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element ye : y.nonZeroes()) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ return x;
+ }
+ }
+
+  /**
+   * If f(x, 0) = x, the zeros in y don't matter and we can simply iterate through the nonzeros of y.
+   * The corresponding element of x is looked up. Because x cannot add new elements in constant time, the results
+   * are collected and merged into x in one step at the end. The merge works on a sorted list of updates, which is
+   * why y, the source of their order, must support sequential access.
+   */
+  public static class AssignNonzerosIterateThatLookupThisMergeUpdates extends VectorBinaryAssign {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+      return f.isLikeRightPlus() && y.isSequentialAccess() && !x.isAddConstantTime();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+      // Each nonzero of y costs one iterator step on y plus one lookup (getQuick) into x; y itself is never looked
+      // up, so its lookup cost is irrelevant here.
+      return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost();
+    }
+
+    @Override
+    public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+      OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+      for (Element ye : y.nonZeroes()) {
+        updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+      }
+      x.mergeUpdates(updates);
+      return x;
+    }
+  }
+
+ /**
+ * If f(x, 0) = x and f(0, y) = 0, x can only change at indices where both x and y are nonzero: where y is zero,
+ * x keeps its value, and where x is zero, it stays zero. So we walk the nonzeros of x and y in parallel and
+ * update x only at the intersection of their nonzeros.
+ * Walking both vectors in parallel requires that x and y both support sequential access (the comparisons below
+ * assume nonzeros are visited in increasing index order). x never gains new nonzeros, so no *Merge/*Inplace
+ * variants are needed.
+ */
+ public static class AssignIterateIntersection extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return f.isLikeLeftMult() && f.isLikeRightPlus() && x.isSequentialAccess() && y.isSequentialAccess();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ // NOTE(review): the loop below can step through up to the sum of both nonzero counts before one side runs
+ // out; the min is an optimistic estimate.
+ return Math.min(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ Vector.Element xe = null;
+ Vector.Element ye = null;
+ boolean advanceThis = true;
+ boolean advanceThat = true;
+ while (true) {
+ // Step the iterator(s) consumed by the previous comparison; stop as soon as either side is exhausted,
+ // since no further index can be in the intersection.
+ if (advanceThis) {
+ if (xi.hasNext()) {
+ xe = xi.next();
+ } else {
+ break;
+ }
+ }
+ if (advanceThat) {
+ if (yi.hasNext()) {
+ ye = yi.next();
+ } else {
+ break;
+ }
+ }
+ if (xe.index() == ye.index()) {
+ xe.set(f.apply(xe.get(), ye.get()));
+ advanceThis = true;
+ advanceThat = true;
+ } else {
+ if (xe.index() < ye.index()) { // y is zero here and f(x, 0) = x, so x is unchanged
+ advanceThis = true;
+ advanceThat = false;
+ } else { // x is zero here and f(0, y) = 0, so x stays zero
+ advanceThis = false;
+ advanceThat = true;
+ }
+ }
+ }
+ return x;
+ }
+ }
+
+ /**
+ * If f(0, 0) = 0, only the indices where x or y is nonzero (the union of their nonzeros) can change x.
+ * In this case we iterate through both sets of nonzeros in parallel. Positions where x already has a nonzero
+ * are updated through its element; new nonzeros (where only y is nonzero) are buffered and merged into x at the
+ * end, because x cannot add elements in constant time. Because we're iterating through both vectors at the same
+ * time, x and y need to support sequential access.
+ */
+ public static class AssignIterateUnionSequentialMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && x.isSequentialAccess() && y.isSequentialAccess() && !x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ Vector.Element xe = null;
+ Vector.Element ye = null;
+ boolean advanceThis = true;
+ boolean advanceThat = true;
+ // New entries for indices where x is currently zero; filled in y's iteration order.
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ while (true) {
+ // Step the iterator(s) consumed by the previous iteration; a null element marks an exhausted side.
+ if (advanceThis) {
+ if (xi.hasNext()) {
+ xe = xi.next();
+ } else {
+ xe = null;
+ }
+ }
+ if (advanceThat) {
+ if (yi.hasNext()) {
+ ye = yi.next();
+ } else {
+ ye = null;
+ }
+ }
+ if (xe != null && ye != null) { // both vectors have nonzero elements
+ if (xe.index() == ye.index()) {
+ xe.set(f.apply(xe.get(), ye.get()));
+ advanceThis = true;
+ advanceThat = true;
+ } else {
+ if (xe.index() < ye.index()) { // y is zero here: x[i] = f(x[i], 0)
+ xe.set(f.apply(xe.get(), 0));
+ advanceThis = true;
+ advanceThat = false;
+ } else { // x is zero here and has no element to write through: buffer f(0, y[i])
+ updates.set(ye.index(), f.apply(0, ye.get()));
+ advanceThis = false;
+ advanceThat = true;
+ }
+ }
+ } else if (xe != null) { // just the first one still has nonzeros
+ xe.set(f.apply(xe.get(), 0));
+ advanceThis = true;
+ advanceThat = false;
+ } else if (ye != null) { // just the second one has nonzeros
+ updates.set(ye.index(), f.apply(0, ye.get()));
+ advanceThis = false;
+ advanceThat = true;
+ } else { // we're done, both are empty
+ break;
+ }
+ }
+ // Insert all buffered new entries into x in a single merge.
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ /**
+ * If f(0, 0) = 0, only the indices where x or y is nonzero (the union of their nonzeros) can change x.
+ * In this case we iterate through both sets of nonzeros in parallel and update x inplace. New nonzeros are
+ * written with setQuick(), which is acceptable because x can add elements in constant time. Because we're
+ * iterating through both vectors at the same time, x and y need to support sequential access.
+ */
+ public static class AssignIterateUnionSequentialInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && x.isSequentialAccess() && y.isSequentialAccess() && x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ Vector.Element xe = null;
+ Vector.Element ye = null;
+ boolean advanceThis = true;
+ boolean advanceThat = true;
+ while (true) {
+ // Step the iterator(s) consumed by the previous iteration; a null element marks an exhausted side.
+ if (advanceThis) {
+ if (xi.hasNext()) {
+ xe = xi.next();
+ } else {
+ xe = null;
+ }
+ }
+ if (advanceThat) {
+ if (yi.hasNext()) {
+ ye = yi.next();
+ } else {
+ ye = null;
+ }
+ }
+ if (xe != null && ye != null) { // both vectors have nonzero elements
+ if (xe.index() == ye.index()) {
+ xe.set(f.apply(xe.get(), ye.get()));
+ advanceThis = true;
+ advanceThat = true;
+ } else {
+ if (xe.index() < ye.index()) { // y is zero here: x[i] = f(x[i], 0)
+ xe.set(f.apply(xe.get(), 0));
+ advanceThis = true;
+ advanceThat = false;
+ } else {
+ // x is zero here: write f(0, y[i]) as a new element of x.
+ // NOTE(review): this inserts into x while x's nonzero iterator is still live; it relies on the
+ // iterators of vectors that pass isValid() tolerating that — confirm per vector type.
+ x.setQuick(ye.index(), f.apply(0, ye.get()));
+ advanceThis = false;
+ advanceThat = true;
+ }
+ }
+ } else if (xe != null) { // just the first one still has nonzeros
+ xe.set(f.apply(xe.get(), 0));
+ advanceThis = true;
+ advanceThat = false;
+ } else if (ye != null) { // just the second one has nonzeros
+ x.setQuick(ye.index(), f.apply(0, ye.get()));
+ advanceThis = false;
+ advanceThat = true;
+ } else { // we're done, both are empty
+ break;
+ }
+ }
+ return x;
+ }
+ }
+
+ /**
+ * If f(0, 0) = 0 we can iterate through the nonzeros in either x or y.
+ * In this case, we iterate through the nozeros of x and y alternatively (this works even when one of them
+ * doesn't support sequential access). Since we're merging the results into x, when iterating through y, the
+ * order of iteration matters and y must support sequential access.
+ */
+ public static class AssignIterateUnionRandomMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && !x.isAddConstantTime() && y.isSequentialAccess();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OpenIntHashSet visited = new OpenIntHashSet();
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ visited.add(xe.index());
+ }
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element ye : y.nonZeroes()) {
+ if (!visited.contains(ye.index())) {
+ updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ /**
+ * If f(0, 0) = 0 we can iterate through the nonzeros in either x or y.
+ * In this case, we iterate through the nozeros of x and y alternatively (this works even when one of them
+ * doesn't support sequential access). Because updates to x are inplace, neither x, nor y need to support
+ * sequential access.
+ */
+ public static class AssignIterateUnionRandomInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+ }
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OpenIntHashSet visited = new OpenIntHashSet();
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ visited.add(xe.index());
+ }
+ for (Element ye : y.nonZeroes()) {
+ if (!visited.contains(ye.index())) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateSequentialMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isSequentialAccess() && y.isSequentialAccess() && !x.isAddConstantTime() && !x.isDense() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ Iterator<Vector.Element> yi = y.all().iterator();
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ while (xi.hasNext() && yi.hasNext()) {
+ Element xe = xi.next();
+ updates.set(xe.index(), f.apply(xe.get(), yi.next().get()));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateSequentialInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isSequentialAccess() && y.isSequentialAccess() && x.isAddConstantTime()
+ && !x.isDense() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ Iterator<Vector.Element> yi = y.all().iterator();
+ while (xi.hasNext() && yi.hasNext()) {
+ Element xe = xi.next();
+ x.setQuick(xe.index(), f.apply(xe.get(), yi.next().get()));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThisLookupThatMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime() && !x.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element xe : x.all()) {
+ updates.set(xe.index(), f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThisLookupThatInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime() && !x.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element xe : x.all()) {
+ x.setQuick(xe.index(), f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThatLookupThisMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element ye : y.all()) {
+ updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThatLookupThisInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element ye : y.all()) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllLoopMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getLookupCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (int i = 0; i < x.size(); ++i) {
+ updates.set(i, f.apply(x.getQuick(i), y.getQuick(i)));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllLoopInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getLookupCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (int i = 0; i < x.size(); ++i) {
+ x.setQuick(i, f.apply(x.getQuick(i), y.getQuick(i)));
+ }
+ return x;
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorIterable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorIterable.java b/core/src/main/java/org/apache/mahout/math/VectorIterable.java
new file mode 100644
index 0000000..8414fdb
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorIterable.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+public interface VectorIterable extends Iterable<MatrixSlice> {
+
+  /** Iterates over all rows, in order. */
+  Iterator<MatrixSlice> iterateAll();
+
+  /** Iterates over all non-empty rows, in arbitrary order. */
+  Iterator<MatrixSlice> iterateNonEmpty();
+
+  /**
+   * Returns the number of slices. NOTE(review): presumably the number of rows that
+   * {@link #iterateAll()} yields; confirm against implementations.
+   */
+  int numSlices();
+
+  /** Returns the number of rows. */
+  int numRows();
+
+  /** Returns the number of columns. */
+  int numCols();
+
+  /**
+   * Returns a new vector, with cardinality equal to {@link #numRows()} of this matrix, that is the
+   * matrix product of the recipient and the argument.
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector)
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector times(Vector v);
+
+  /**
+   * Convenience method for producing {@code this.transpose().times(this.times(v))}. It can be
+   * implemented with only one pass over the matrix, without making the transpose() call (which can
+   * be expensive if the matrix is sparse).
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector) with cardinality equal to that of the argument
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector timesSquared(Vector v);
+
+}
r***@apache.org
2018-09-08 23:35:19 UTC
Permalink
NO-JIRA Trevors updates


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/545648f6
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/545648f6
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/545648f6

Branch: refs/heads/branch-0.14.0
Commit: 545648f6a8f69139757f0328e924c6b16a839441
Parents: 49ad8cb
Author: Trevor a.k.a @rawkintrevo <***@gmail.com>
Authored: Sat Sep 8 18:34:43 2018 -0500
Committer: Trevor a.k.a @rawkintrevo <***@gmail.com>
Committed: Sat Sep 8 18:34:43 2018 -0500

----------------------------------------------------------------------
.../apache/mahout/collections/Arithmetic.java | 489 ++++
.../apache/mahout/collections/Constants.java | 75 +
.../org/apache/mahout/common/RandomUtils.java | 100 +
.../org/apache/mahout/common/RandomWrapper.java | 105 +
.../org/apache/mahout/math/AbstractMatrix.java | 834 +++++++
.../org/apache/mahout/math/AbstractVector.java | 684 ++++++
.../java/org/apache/mahout/math/Algebra.java | 73 +
.../java/org/apache/mahout/math/Arrays.java | 662 +++++
.../org/apache/mahout/math/BinarySearch.java | 403 +++
.../mahout/math/CardinalityException.java | 30 +
.../java/org/apache/mahout/math/Centroid.java | 89 +
.../mahout/math/CholeskyDecomposition.java | 227 ++
.../org/apache/mahout/math/ConstantVector.java | 177 ++
.../apache/mahout/math/DelegatingVector.java | 336 +++
.../org/apache/mahout/math/DenseMatrix.java | 193 ++
.../mahout/math/DenseSymmetricMatrix.java | 62 +
.../org/apache/mahout/math/DenseVector.java | 442 ++++
.../org/apache/mahout/math/DiagonalMatrix.java | 378 +++
.../org/apache/mahout/math/FileBasedMatrix.java | 185 ++
.../math/FileBasedSparseBinaryMatrix.java | 535 ++++
.../mahout/math/FunctionalMatrixView.java | 99 +
.../org/apache/mahout/math/IndexException.java | 30 +
.../apache/mahout/math/LengthCachingVector.java | 35 +
.../java/org/apache/mahout/math/Matrices.java | 167 ++
.../java/org/apache/mahout/math/Matrix.java | 413 ++++
.../org/apache/mahout/math/MatrixSlice.java | 36 +
.../org/apache/mahout/math/MatrixTimesOps.java | 35 +
.../apache/mahout/math/MatrixVectorView.java | 292 +++
.../java/org/apache/mahout/math/MatrixView.java | 160 ++
.../java/org/apache/mahout/math/MurmurHash.java | 158 ++
.../org/apache/mahout/math/MurmurHash3.java | 84 +
.../org/apache/mahout/math/NamedVector.java | 328 +++
.../apache/mahout/math/OldQRDecomposition.java | 234 ++
.../mahout/math/OrderedIntDoubleMapping.java | 265 ++
.../mahout/math/OrthonormalityVerifier.java | 46 +
.../apache/mahout/math/PermutedVectorView.java | 250 ++
.../apache/mahout/math/PersistentObject.java | 58 +
.../org/apache/mahout/math/PivotedMatrix.java | 288 +++
.../main/java/org/apache/mahout/math/QR.java | 27 +
.../org/apache/mahout/math/QRDecomposition.java | 181 ++
.../mahout/math/RandomAccessSparseVector.java | 303 +++
.../apache/mahout/math/RandomTrinaryMatrix.java | 146 ++
.../math/SequentialAccessSparseVector.java | 379 +++
.../mahout/math/SingularValueDecomposition.java | 669 +++++
.../java/org/apache/mahout/math/Sorting.java | 2297 ++++++++++++++++++
.../apache/mahout/math/SparseColumnMatrix.java | 220 ++
.../org/apache/mahout/math/SparseMatrix.java | 245 ++
.../org/apache/mahout/math/SparseRowMatrix.java | 289 +++
.../java/org/apache/mahout/math/Swapper.java | 35 +
.../mahout/math/TransposedMatrixView.java | 147 ++
.../org/apache/mahout/math/UpperTriangular.java | 160 ++
.../java/org/apache/mahout/math/Vector.java | 434 ++++
.../mahout/math/VectorBinaryAggregate.java | 481 ++++
.../apache/mahout/math/VectorBinaryAssign.java | 667 +++++
.../org/apache/mahout/math/VectorIterable.java | 56 +
.../java/org/apache/mahout/math/VectorView.java | 238 ++
.../org/apache/mahout/math/WeightedVector.java | 87 +
.../mahout/math/WeightedVectorComparator.java | 54 +
.../math/als/AlternatingLeastSquaresSolver.java | 116 +
...itFeedbackAlternatingLeastSquaresSolver.java | 171 ++
.../math/decomposer/AsyncEigenVerifier.java | 80 +
.../mahout/math/decomposer/EigenStatus.java | 50 +
.../math/decomposer/SimpleEigenVerifier.java | 41 +
.../math/decomposer/SingularVectorVerifier.java | 25 +
.../math/decomposer/hebbian/EigenUpdater.java | 25 +
.../math/decomposer/hebbian/HebbianSolver.java | 342 +++
.../math/decomposer/hebbian/HebbianUpdater.java | 71 +
.../math/decomposer/hebbian/TrainingState.java | 143 ++
.../math/decomposer/lanczos/LanczosSolver.java | 213 ++
.../math/decomposer/lanczos/LanczosState.java | 107 +
.../org/apache/mahout/math/flavor/BackEnum.java | 26 +
.../apache/mahout/math/flavor/MatrixFlavor.java | 82 +
.../math/flavor/TraversingStructureEnum.java | 48 +
.../math/function/DoubleDoubleFunction.java | 98 +
.../mahout/math/function/DoubleFunction.java | 48 +
.../mahout/math/function/FloatFunction.java | 36 +
.../apache/mahout/math/function/Functions.java | 1730 +++++++++++++
.../mahout/math/function/IntFunction.java | 41 +
.../math/function/IntIntDoubleFunction.java | 43 +
.../mahout/math/function/IntIntFunction.java | 25 +
.../org/apache/mahout/math/function/Mult.java | 71 +
.../math/function/ObjectObjectProcedure.java | 40 +
.../mahout/math/function/ObjectProcedure.java | 47 +
.../apache/mahout/math/function/PlusMult.java | 123 +
.../math/function/SquareRootFunction.java | 26 +
.../mahout/math/function/TimesFunction.java | 77 +
.../mahout/math/function/VectorFunction.java | 27 +
.../mahout/math/function/package-info.java | 4 +
.../apache/mahout/math/jet/math/Arithmetic.java | 328 +++
.../apache/mahout/math/jet/math/Constants.java | 49 +
.../apache/mahout/math/jet/math/Polynomial.java | 98 +
.../mahout/math/jet/math/package-info.java | 5 +
.../random/AbstractContinousDistribution.java | 51 +
.../random/AbstractDiscreteDistribution.java | 27 +
.../math/jet/random/AbstractDistribution.java | 87 +
.../mahout/math/jet/random/Exponential.java | 81 +
.../apache/mahout/math/jet/random/Gamma.java | 302 +++
.../math/jet/random/NegativeBinomial.java | 106 +
.../apache/mahout/math/jet/random/Normal.java | 110 +
.../apache/mahout/math/jet/random/Poisson.java | 296 +++
.../apache/mahout/math/jet/random/Uniform.java | 164 ++
.../math/jet/random/engine/MersenneTwister.java | 275 +++
.../math/jet/random/engine/RandomEngine.java | 169 ++
.../math/jet/random/engine/package-info.java | 7 +
.../math/jet/random/sampling/RandomSampler.java | 503 ++++
.../org/apache/mahout/math/jet/stat/Gamma.java | 681 ++++++
.../mahout/math/jet/stat/Probability.java | 203 ++
.../mahout/math/jet/stat/package-info.java | 5 +
.../apache/mahout/math/list/AbstractList.java | 247 ++
.../mahout/math/list/AbstractObjectList.java | 80 +
.../mahout/math/list/ObjectArrayList.java | 419 ++++
.../mahout/math/list/SimpleLongArrayList.java | 102 +
.../apache/mahout/math/list/package-info.java | 144 ++
.../apache/mahout/math/map/HashFunctions.java | 115 +
.../org/apache/mahout/math/map/OpenHashMap.java | 652 +++++
.../org/apache/mahout/math/map/PrimeFinder.java | 145 ++
.../mahout/math/map/QuickOpenIntIntHashMap.java | 215 ++
.../apache/mahout/math/map/package-info.java | 250 ++
.../org/apache/mahout/math/package-info.java | 4 +
.../math/random/AbstractSamplerFunction.java | 39 +
.../mahout/math/random/ChineseRestaurant.java | 111 +
.../apache/mahout/math/random/Empirical.java | 124 +
.../apache/mahout/math/random/IndianBuffet.java | 157 ++
.../org/apache/mahout/math/random/Missing.java | 59 +
.../apache/mahout/math/random/MultiNormal.java | 118 +
.../apache/mahout/math/random/Multinomial.java | 202 ++
.../org/apache/mahout/math/random/Normal.java | 40 +
.../mahout/math/random/PoissonSampler.java | 67 +
.../org/apache/mahout/math/random/Sampler.java | 25 +
.../mahout/math/random/WeightedThing.java | 71 +
.../org/apache/mahout/math/set/AbstractSet.java | 188 ++
.../org/apache/mahout/math/set/HashUtils.java | 56 +
.../org/apache/mahout/math/set/OpenHashSet.java | 548 +++++
.../math/solver/ConjugateGradientSolver.java | 213 ++
.../mahout/math/solver/EigenDecomposition.java | 892 +++++++
.../mahout/math/solver/JacobiConditioner.java | 47 +
.../org/apache/mahout/math/solver/LSMR.java | 565 +++++
.../mahout/math/solver/Preconditioner.java | 36 +
.../mahout/math/ssvd/SequentialBigSvd.java | 69 +
.../apache/mahout/math/stats/LogLikelihood.java | 220 ++
.../math/stats/OnlineExponentialAverage.java | 62 +
.../mahout/math/stats/OnlineSummarizer.java | 93 +
.../apache/mahout/math/QRDecompositionTest.java | 280 +++
.../math/TestSingularValueDecomposition.java | 327 +++
.../als/AlternatingLeastSquaresSolverTest.java | 151 ++
.../mahout/math/decomposer/SolverTest.java | 177 ++
.../decomposer/hebbian/TestHebbianSolver.java | 207 ++
.../decomposer/lanczos/TestLanczosSolver.java | 97 +
.../apache/mahout/math/jet/stat/GammaTest.java | 138 ++
.../mahout/math/jet/stat/ProbabilityTest.java | 196 ++
.../math/random/ChineseRestaurantTest.java | 158 ++
.../mahout/math/randomized/RandomBlasting.java | 355 +++
.../mahout/math/ssvd/SequentialBigSvdTest.java | 86 +
.../mahout/math/stats/OnlineSummarizerTest.java | 108 +
154 files changed, 32350 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/collections/Arithmetic.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/collections/Arithmetic.java b/core/src/main/java/org/apache/mahout/collections/Arithmetic.java
new file mode 100644
index 0000000..18e3200
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/collections/Arithmetic.java
@@ -0,0 +1,489 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.collections;
+
+/**
+ * Arithmetic functions.
+ */
+public final class Arithmetic extends Constants {
+  // Table of Stirling correction terms for k = 0, ..., 30; used by stirlingCorrection(int).
+  private static final double[] STIRLING_CORRECTION = {
+    0.0,
+    8.106146679532726e-02, 4.134069595540929e-02,
+    2.767792568499834e-02, 2.079067210376509e-02,
+    1.664469118982119e-02, 1.387612882307075e-02,
+    1.189670994589177e-02, 1.041126526197209e-02,
+    9.255462182712733e-03, 8.330563433362871e-03,
+    7.573675487951841e-03, 6.942840107209530e-03,
+    6.408994188004207e-03, 5.951370112758848e-03,
+    5.554733551962801e-03, 5.207655919609640e-03,
+    4.901395948434738e-03, 4.629153749334029e-03,
+    4.385560249232324e-03, 4.166319691996922e-03,
+    3.967954218640860e-03, 3.787618068444430e-03,
+    3.622960224683090e-03, 3.472021382978770e-03,
+    3.333155636728090e-03, 3.204970228055040e-03,
+    3.086278682608780e-03, 2.976063983550410e-03,
+    2.873449362352470e-03, 2.777674929752690e-03,
+  };
+
+  // log(k!) for k = 0, ..., 29; used by logFactorial(int).
+  private static final double[] LOG_FACTORIALS = {
+    0.00000000000000000, 0.00000000000000000, 0.69314718055994531,
+    1.79175946922805500, 3.17805383034794562, 4.78749174278204599,
+    6.57925121201010100, 8.52516136106541430, 10.60460290274525023,
+    12.80182748008146961, 15.10441257307551530, 17.50230784587388584,
+    19.98721449566188615, 22.55216385312342289, 25.19122118273868150,
+    27.89927138384089157, 30.67186010608067280, 33.50507345013688888,
+    36.39544520803305358, 39.33988418719949404, 42.33561646075348503,
+    45.38013889847690803, 48.47118135183522388, 51.60667556776437357,
+    54.78472939811231919, 58.00360522298051994, 61.26170176100200198,
+    64.55753862700633106, 67.88974313718153498, 71.25703896716800901
+  };
+
+  // k! for k = 0, ..., 20 (21! no longer fits in a long).
+  private static final long[] LONG_FACTORIALS = {
+    1L,
+    1L,
+    2L,
+    6L,
+    24L,
+    120L,
+    720L,
+    5040L,
+    40320L,
+    362880L,
+    3628800L,
+    39916800L,
+    479001600L,
+    6227020800L,
+    87178291200L,
+    1307674368000L,
+    20922789888000L,
+    355687428096000L,
+    6402373705728000L,
+    121645100408832000L,
+    2432902008176640000L
+  };
+
+  // k! for k = 21, ..., 170 as doubles; factorial(int) returns +Infinity beyond this table.
+  private static final double[] DOUBLE_FACTORIALS = {
+    5.109094217170944E19,
+    1.1240007277776077E21,
+    2.585201673888498E22,
+    6.204484017332394E23,
+    1.5511210043330984E25,
+    4.032914611266057E26,
+    1.0888869450418352E28,
+    3.048883446117138E29,
+    8.841761993739701E30,
+    2.652528598121911E32,
+    8.222838654177924E33,
+    2.6313083693369355E35,
+    8.68331761881189E36,
+    2.952327990396041E38,
+    1.0333147966386144E40,
+    3.719933267899013E41,
+    1.3763753091226346E43,
+    5.23022617466601E44,
+    2.0397882081197447E46,
+    8.15915283247898E47,
+    3.34525266131638E49,
+    1.4050061177528801E51,
+    6.041526306337384E52,
+    2.6582715747884495E54,
+    1.196222208654802E56,
+    5.502622159812089E57,
+    2.5862324151116827E59,
+    1.2413915592536068E61,
+    6.082818640342679E62,
+    3.0414093201713376E64,
+    1.5511187532873816E66,
+    8.06581751709439E67,
+    4.274883284060024E69,
+    2.308436973392413E71,
+    1.2696403353658264E73,
+    7.109985878048632E74,
+    4.052691950487723E76,
+    2.350561331282879E78,
+    1.386831185456898E80,
+    8.32098711274139E81,
+    5.075802138772246E83,
+    3.146997326038794E85,
+    1.9826083154044396E87,
+    1.2688693218588414E89,
+    8.247650592082472E90,
+    5.443449390774432E92,
+    3.6471110918188705E94,
+    2.48003554243683E96,
+    1.7112245242814127E98,
+    1.1978571669969892E100,
+    8.504785885678624E101,
+    6.123445837688612E103,
+    4.470115461512686E105,
+    3.307885441519387E107,
+    2.4809140811395404E109,
+    1.8854947016660506E111,
+    1.451830920282859E113,
+    1.1324281178206295E115,
+    8.94618213078298E116,
+    7.15694570462638E118,
+    5.797126020747369E120,
+    4.7536433370128435E122,
+    3.94552396972066E124,
+    3.314240134565354E126,
+    2.8171041143805494E128,
+    2.4227095383672744E130,
+    2.107757298379527E132,
+    1.854826422573984E134,
+    1.6507955160908465E136,
+    1.4857159644817605E138,
+    1.3520015276784033E140,
+    1.2438414054641305E142,
+    1.156772507081641E144,
+    1.0873661566567426E146,
+    1.0329978488239061E148,
+    9.916779348709491E149,
+    9.619275968248216E151,
+    9.426890448883248E153,
+    9.332621544394415E155,
+    9.332621544394418E157,
+    9.42594775983836E159,
+    9.614466715035125E161,
+    9.902900716486178E163,
+    1.0299016745145631E166,
+    1.0813967582402912E168,
+    1.1462805637347086E170,
+    1.2265202031961373E172,
+    1.324641819451829E174,
+    1.4438595832024942E176,
+    1.5882455415227423E178,
+    1.7629525510902457E180,
+    1.974506857221075E182,
+    2.2311927486598138E184,
+    2.543559733472186E186,
+    2.925093693493014E188,
+    3.393108684451899E190,
+    3.96993716080872E192,
+    4.6845258497542896E194,
+    5.574585761207606E196,
+    6.689502913449135E198,
+    8.094298525273444E200,
+    9.875044200833601E202,
+    1.2146304367025332E205,
+    1.506141741511141E207,
+    1.882677176888926E209,
+    2.3721732428800483E211,
+    3.0126600184576624E213,
+    3.856204823625808E215,
+    4.974504222477287E217,
+    6.466855489220473E219,
+    8.471580690878813E221,
+    1.1182486511960037E224,
+    1.4872707060906847E226,
+    1.99294274616152E228,
+    2.690472707318049E230,
+    3.6590428819525483E232,
+    5.0128887482749884E234,
+    6.917786472619482E236,
+    9.615723196941089E238,
+    1.3462012475717523E241,
+    1.8981437590761713E243,
+    2.6953641378881633E245,
+    3.8543707171800694E247,
+    5.550293832739308E249,
+    8.047926057471989E251,
+    1.1749972043909107E254,
+    1.72724589045464E256,
+    2.5563239178728637E258,
+    3.8089226376305687E260,
+    5.7133839564458575E262,
+    8.627209774233244E264,
+    1.3113358856834527E267,
+    2.0063439050956838E269,
+    3.0897696138473515E271,
+    4.789142901463393E273,
+    7.471062926282892E275,
+    1.1729568794264134E278,
+    1.8532718694937346E280,
+    2.946702272495036E282,
+    4.714723635992061E284,
+    7.590705053947223E286,
+    1.2296942187394494E289,
+    2.0044015765453032E291,
+    3.287218585534299E293,
+    5.423910666131583E295,
+    9.003691705778434E297,
+    1.5036165148649983E300,
+    2.5260757449731988E302,
+    4.2690680090047056E304,
+    7.257415615308004E306
+  };
+
+  /**
+   * Package-private constructor: prevents instantiation from outside this package. The class is
+   * {@code final}, so it cannot be subclassed; all functionality is exposed through static methods.
+   */
+  Arithmetic() {
+  }
+
+  /**
+   * Efficiently returns the binomial coefficient, often also referred to as
+   * "n over k" or "n choose k". The binomial coefficient is defined as
+   * <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )</tt>.
+   * <ul> <li><tt>k&lt;0</tt>: <tt>0</tt>.</li>
+   * <li><tt>k==0</tt>: <tt>1</tt>.</li>
+   * <li><tt>k==1</tt>: <tt>n</tt>.</li>
+   * <li>else: <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k)</tt>.</li>
+   * </ul>
+   * The general case multiplies k factors, so its running time is linear in k.
+   *
+   * @param n the upper argument; need not be an integer
+   * @param k the lower argument
+   * @return the binomial coefficient.
+   */
+  public static double binomial(double n, long k) {
+    if (k < 0) {
+      return 0;
+    }
+    if (k == 0) {
+      return 1;
+    }
+    if (k == 1) {
+      return n;
+    }
+
+    // binomial(n,k) = (n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )
+    double a = n - k + 1;
+    double b = 1;
+    double binomial = 1;
+    for (long i = k; i-- > 0;) {
+      binomial *= (a++) / (b++);
+    }
+    return binomial;
+  }
+
+  /**
+   * Efficiently returns the binomial coefficient, often also referred to as "n over k" or "n choose k". The binomial
+   * coefficient is defined as <ul> <li><tt>k&lt;0</tt>: <tt>0</tt>. <li><tt>k==0 || k==n</tt>: <tt>1</tt>.
+   * <li><tt>k==1 || k==n-1</tt>: <tt>n</tt>. <li><tt>0 &lt;= n &lt; k</tt>: <tt>0</tt>.
+   * <li>else: <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )</tt>. </ul>
+   * When {@code n < 171} the result is computed from the factorial tables if the intermediate product does not
+   * overflow; otherwise a product loop over {@code min(k, n - k)} factors is used.
+   *
+   * @param n the upper argument
+   * @param k the lower argument
+   * @return the binomial coefficient.
+   */
+  public static double binomial(long n, long k) {
+    if (k < 0) {
+      return 0;
+    }
+    if (k == 0 || k == n) {
+      return 1;
+    }
+    if (k == 1 || k == n - 1) {
+      return n;
+    }
+    // There is no way to choose more than n items from a non-negative n. Returning early avoids
+    // running the product loop below k times, which could also yield -0.0 (e.g. n = 1, k = 3).
+    if (n >= 0 && k > n) {
+      return 0;
+    }
+
+    // try quick version and see whether we get numeric overflows.
+    // factorial(..) is O(1); requires no loop; only a table lookup.
+    if (n > k) {
+      int max = LONG_FACTORIALS.length + DOUBLE_FACTORIALS.length;
+      if (n < max) { // if (n! < inf && k! < inf)
+        double n_fac = factorial((int) n);
+        double k_fac = factorial((int) k);
+        double n_minus_k_fac = factorial((int) (n - k));
+        double nk = n_minus_k_fac * k_fac;
+        if (nk != Double.POSITIVE_INFINITY) { // no numeric overflow?
+          // now this is completely safe and accurate
+          return n_fac / nk;
+        }
+      }
+      if (k > n / 2) {
+        k = n - k;
+      } // quicker: binomial(n, k) == binomial(n, n - k)
+    }
+
+    // binomial(n,k) = (n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )
+    long a = n - k + 1;
+    long b = 1;
+    double binomial = 1;
+    for (long i = k; i-- > 0;) {
+      binomial *= (double) a++ / (b++);
+    }
+    return binomial;
+  }
+
+  /**
+   * Returns the smallest <code>long &gt;= value</code>.
+   * <dl><dt>Examples: {@code 1.0 -> 1, 1.2 -> 2, 1.9 -> 2}.</dt></dl>
+   * Computed as {@code Math.round(Math.ceil(value))}; since {@code Math.ceil} already yields an integral double,
+   * the rounding step only converts it to a {@code long}.
+   */
+  public static long ceil(double value) {
+    return Math.round(Math.ceil(value));
+  }
+
+  /**
+   * Evaluates the series of Chebyshev polynomials Ti at argument x/2. The series is given by
+   * <pre>
+   *        N-1
+   *         - '
+   *  y  =   &gt;   coef[i] T (x/2)
+   *         -            i
+   *        i=0
+   * </pre>
+   * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note N is the number of
+   * coefficients, not the order. <p> If coefficients are for the interval a to b, x must have been transformed to
+   * {@code x -> 2(2x - b - a)/(b-a)} before entering the routine. This maps x from (a, b) to (-1, 1), over which the
+   * Chebyshev polynomials are defined. <p> If the coefficients are for the inverted interval, in which (a, b) is mapped
+   * to (1/b, 1/a), the transformation required is {@code x -> 2(2ab/x - b - a)/(b-a)}. If b is infinity, this becomes
+   * {@code x -> 4a/x - 1}.
+   * <p> SPEED: <p> Taking advantage of the recurrence properties of the Chebyshev polynomials, the routine requires one
+   * more addition per loop than evaluating a nested polynomial of the same degree.
+   * <p> The loop body always runs at least once and reads {@code coef[1]}, so N must be at least 2.
+   *
+   * @param x argument to the polynomial.
+   * @param coef the coefficients of the polynomial.
+   * @param N the number of coefficients; must be at least 2.
+   */
+  public static double chbevl(double x, double[] coef, int N) {
+
+    int p = 0;
+
+    double b0 = coef[p++];
+    double b1 = 0.0;
+    int i = N - 1;
+
+    double b2;
+    do {
+      b2 = b1;
+      b1 = b0;
+      b0 = x * b1 - b2 + coef[p++];
+    } while (--i > 0);
+
+    return 0.5 * (b0 - b2);
+  }
+
+  /**
+   * Instantly returns the factorial <tt>k!</tt> by table lookup. Returns {@code Double.POSITIVE_INFINITY} for
+   * {@code k > 170}, where the result no longer fits in a double.
+   *
+   * @param k must hold <tt>k &gt;= 0</tt>.
+   * @throws IllegalArgumentException if {@code k < 0}
+   */
+  private static double factorial(int k) {
+    if (k < 0) {
+      throw new IllegalArgumentException();
+    }
+
+    int length1 = LONG_FACTORIALS.length;
+    if (k < length1) {
+      return LONG_FACTORIALS[k];
+    }
+
+    int length2 = DOUBLE_FACTORIALS.length;
+    if (k < length1 + length2) {
+      return DOUBLE_FACTORIALS[k - length1];
+    } else {
+      return Double.POSITIVE_INFINITY;
+    }
+  }
+
+  /**
+   * Returns the largest <code>long &lt;= value</code>.
+   * <dl><dt>Examples: {@code 1.0 -> 1, 1.2 -> 1, 1.9 -> 1, 2.0 -> 2, 2.2 -> 2, 2.9 -> 2}.</dt></dl>
+   * Computed as {@code Math.round(Math.floor(value))}; since {@code Math.floor} already yields an integral double,
+   * the rounding step only converts it to a {@code long}.
+   */
+  public static long floor(double value) {
+    return Math.round(Math.floor(value));
+  }
+
+  /** Returns <tt>log<sub>base</sub>value</tt>. */
+  public static double log(double base, double value) {
+    return Math.log(value) / Math.log(base);
+  }
+
+  /** Returns <tt>log<sub>10</sub>value</tt>. */
+  public static double log10(double value) {
+    // 1.0 / Math.log(10) == 0.43429448190325176
+    return Math.log(value) * 0.43429448190325176;
+  }
+
+  /** Returns <tt>log<sub>2</sub>value</tt>. */
+  public static double log2(double value) {
+    // 1.0 / Math.log(2) == 1.4426950408889634
+    return Math.log(value) * 1.4426950408889634;
+  }
+
+  /**
+   * Returns <tt>log(k!)</tt>. Tries to avoid overflows. For <tt>k&lt;30</tt> simply looks up a table in O(1). For
+   * <tt>k&gt;=30</tt> uses Stirling's approximation.
+   *
+   * @param k must hold <tt>k &gt;= 0</tt>.
+   * @throws IllegalArgumentException if {@code k < 0}
+   */
+  public static double logFactorial(int k) {
+    if (k < 0) {
+      throw new IllegalArgumentException("Negative k: " + k);
+    }
+    if (k >= 30) {
+
+      double r = 1.0 / k;
+      double rr = r * r;
+      double C7 = -5.95238095238095238e-04; // -1/1680
+      double C5 = 7.93650793650793651e-04; // +1/1260
+      double C3 = -2.77777777777777778e-03; // -1/360
+      double C1 = 8.33333333333333333e-02; // +1/12
+      double C0 = 9.18938533204672742e-01; // (1/2) log(2 pi)
+      return (k + 0.5) * Math.log(k) - k + C0 + r * (C1 + rr * (C3 + rr * (C5 + rr * C7)));
+    } else {
+      return LOG_FACTORIALS[k];
+    }
+  }
+
+  /**
+   * Instantly returns the factorial <tt>k!</tt>.
+   *
+   * @param k must hold {@code k >= 0 && k < 21}
+   * @throws IllegalArgumentException if {@code k < 0} or {@code k >= 21} (the result would overflow a long)
+   */
+  public static long longFactorial(int k) {
+    if (k < 0) {
+      throw new IllegalArgumentException("Negative k");
+    }
+
+    if (k < LONG_FACTORIALS.length) {
+      return LONG_FACTORIALS[k];
+    }
+    throw new IllegalArgumentException("Overflow");
+  }
+
+  /**
+   * Returns the StirlingCorrection. <p> Correction term of the Stirling approximation for <tt>log(k!)</tt> (series in
+   * 1/k, or table values for small k) with int parameter k. </p> <tt> log k! = (k + 1/2)log(k + 1) - (k + 1) +
+   * (1/2)log(2Pi) + STIRLING_CORRECTION(k + 1) log k! = (k + 1/2)log(k) - k + (1/2)log(2Pi) +
+   * STIRLING_CORRECTION(k) </tt>
+   * <p> For {@code k <= 30} the value is taken from a table; for larger k the series is evaluated.
+   *
+   * @param k must hold <tt>k &gt;= 0</tt>.
+   * @throws IllegalArgumentException if {@code k < 0}
+   */
+  public static double stirlingCorrection(int k) {
+    if (k < 0) {
+      throw new IllegalArgumentException("Negative k: " + k);
+    }
+
+    if (k > 30) {
+      double r = 1.0 / k;
+      double rr = r * r;
+      double C7 = -5.95238095238095238e-04; // -1/1680
+      double C5 = 7.93650793650793651e-04; // +1/1260
+      double C3 = -2.77777777777777778e-03; // -1/360
+      double C1 = 8.33333333333333333e-02; // +1/12
+      return r * (C1 + rr * (C3 + rr * (C5 + rr * C7)));
+    } else {
+      return STIRLING_CORRECTION[k];
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/collections/Constants.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/collections/Constants.java b/core/src/main/java/org/apache/mahout/collections/Constants.java
new file mode 100644
index 0000000..007bd3f
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/collections/Constants.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.collections;
+
+/**
+ * Defines some useful constants.
+ */
+public class Constants {
+ /*
+ * machine constants
+ */
+ protected static final double MACHEP = 1.11022302462515654042E-16;
+ protected static final double MAXLOG = 7.09782712893383996732E2;
+ protected static final double MINLOG = -7.451332191019412076235E2;
+ protected static final double MAXGAM = 171.624376956302725;
+ protected static final double SQTPI = 2.50662827463100050242E0;
+ protected static final double SQRTH = 7.07106781186547524401E-1;
+ protected static final double LOGPI = 1.14472988584940017414;
+
+ protected static final double BIG = 4.503599627370496e15;
+ protected static final double BIGINV = 2.22044604925031308085e-16;
+
+
+ /*
+ * MACHEP = 1.38777878078144567553E-17 2**-56
+ * MAXLOG = 8.8029691931113054295988E1 log(2**127)
+ * MINLOG = -8.872283911167299960540E1 log(2**-128)
+ * MAXNUM = 1.701411834604692317316873e38 2**127
+ *
+ * For IEEE arithmetic (IBMPC):
+ * MACHEP = 1.11022302462515654042E-16 2**-53
+ * MAXLOG = 7.09782712893383996843E2 log(2**1024)
+ * MINLOG = -7.08396418532264106224E2 log(2**-1022)
+ * MAXNUM = 1.7976931348623158E308 2**1024
+ *
+ * The global symbols for mathematical constants are
+ * PI = 3.14159265358979323846 pi
+ * PIO2 = 1.57079632679489661923 pi/2
+ * PIO4 = 7.85398163397448309616E-1 pi/4
+ * SQRT2 = 1.41421356237309504880 sqrt(2)
+ * SQRTH = 7.07106781186547524401E-1 sqrt(2)/2
+ * LOG2E = 1.4426950408889634073599 1/log(2)
+ * SQ2OPI = 7.9788456080286535587989E-1 sqrt( 2/pi )
+ * LOGE2 = 6.93147180559945309417E-1 log(2)
+ * LOGSQ2 = 3.46573590279972654709E-1 log(2)/2
+ * THPIO4 = 2.35619449019234492885 3*pi/4
+ * TWOOPI = 6.36619772367581343075535E-1 2/pi
+ */
+ protected Constants() {}
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/common/RandomUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/common/RandomUtils.java b/core/src/main/java/org/apache/mahout/common/RandomUtils.java
new file mode 100644
index 0000000..ba71292
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/common/RandomUtils.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Random;
+import java.util.WeakHashMap;
+
+import com.google.common.primitives.Longs;
+import org.apache.commons.math3.primes.Primes;
+
+/**
+ * <p>
+ * The source of random stuff for the whole project. This lets us make all randomness in the project
+ * predictable, if desired, for when we run unit tests, which should be repeatable.
+ * </p>
+ */
+public final class RandomUtils {
+
+ /** The largest prime less than 2<sup>31</sup>-1 that is the smaller of a twin prime pair. */
+ public static final int MAX_INT_SMALLER_TWIN_PRIME = 2147482949;
+
+ private static final Map<RandomWrapper,Boolean> INSTANCES =
+ Collections.synchronizedMap(new WeakHashMap<RandomWrapper,Boolean>());
+
+ private static boolean testSeed = false;
+
+ private RandomUtils() { }
+
+ public static void useTestSeed() {
+ testSeed = true;
+ synchronized (INSTANCES) {
+ for (RandomWrapper rng : INSTANCES.keySet()) {
+ rng.resetToTestSeed();
+ }
+ }
+ }
+
+ public static RandomWrapper getRandom() {
+ RandomWrapper random = new RandomWrapper();
+ if (testSeed) {
+ random.resetToTestSeed();
+ }
+ INSTANCES.put(random, Boolean.TRUE);
+ return random;
+ }
+
+ public static Random getRandom(long seed) {
+ RandomWrapper random = new RandomWrapper(seed);
+ INSTANCES.put(random, Boolean.TRUE);
+ return random;
+ }
+
+ /** @return what {@link Double#hashCode()} would return for the same value */
+ public static int hashDouble(double value) {
+ return Longs.hashCode(Double.doubleToLongBits(value));
+ }
+
+ /** @return what {@link Float#hashCode()} would return for the same value */
+ public static int hashFloat(float value) {
+ return Float.floatToIntBits(value);
+ }
+
+ /**
+ * <p>
+ * Finds next-largest "twin primes": numbers p and p+2 such that both are prime. Finds the smallest such p
+ * such that the smaller twin, p, is greater than or equal to n. Returns p+2, the larger of the two twins.
+ * </p>
+ */
+ public static int nextTwinPrime(int n) {
+ if (n > MAX_INT_SMALLER_TWIN_PRIME) {
+ throw new IllegalArgumentException();
+ }
+ if (n <= 3) {
+ return 5;
+ }
+ int next = Primes.nextPrime(n);
+ while (!Primes.isPrime(next + 2)) {
+ next = Primes.nextPrime(next + 4);
+ }
+ return next + 2;
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/common/RandomWrapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/common/RandomWrapper.java b/core/src/main/java/org/apache/mahout/common/RandomWrapper.java
new file mode 100644
index 0000000..802291b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/common/RandomWrapper.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.commons.math3.random.MersenneTwister;
+import org.apache.commons.math3.random.RandomGenerator;
+
+import java.util.Random;
+
+public final class RandomWrapper extends Random {
+
+ private static final long STANDARD_SEED = 0xCAFEDEADBEEFBABEL;
+
+ private final RandomGenerator random;
+
+ RandomWrapper() {
+ random = new MersenneTwister();
+ random.setSeed(System.currentTimeMillis() + System.identityHashCode(random));
+ }
+
+ RandomWrapper(long seed) {
+ random = new MersenneTwister(seed);
+ }
+
+ @Override
+ public void setSeed(long seed) {
+ // Since this will be called by the java.util.Random() constructor before we construct
+ // the delegate... and because we don't actually care about the result of this for our
+ // purpose:
+ if (random != null) {
+ random.setSeed(seed);
+ }
+ }
+
+ void resetToTestSeed() {
+ setSeed(STANDARD_SEED);
+ }
+
+ public RandomGenerator getRandomGenerator() {
+ return random;
+ }
+
+ @Override
+ protected int next(int bits) {
+ // Ugh, can't delegate this method -- it's protected
+ // Callers can't use it and other methods are delegated, so shouldn't matter
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void nextBytes(byte[] bytes) {
+ random.nextBytes(bytes);
+ }
+
+ @Override
+ public int nextInt() {
+ return random.nextInt();
+ }
+
+ @Override
+ public int nextInt(int n) {
+ return random.nextInt(n);
+ }
+
+ @Override
+ public long nextLong() {
+ return random.nextLong();
+ }
+
+ @Override
+ public boolean nextBoolean() {
+ return random.nextBoolean();
+ }
+
+ @Override
+ public float nextFloat() {
+ return random.nextFloat();
+ }
+
+ @Override
+ public double nextDouble() {
+ return random.nextDouble();
+ }
+
+ @Override
+ public double nextGaussian() {
+ return random.nextGaussian();
+ }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java b/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java
new file mode 100644
index 0000000..eaaa397
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java
@@ -0,0 +1,834 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Maps;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.function.VectorFunction;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * A few universal implementations of convenience functions for a JVM-backed matrix.
+ */
+public abstract class AbstractMatrix implements Matrix {
+
+ protected Map<String, Integer> columnLabelBindings;
+ protected Map<String, Integer> rowLabelBindings;
+ protected int rows;
+ protected int columns;
+
+ protected AbstractMatrix(int rows, int columns) {
+ this.rows = rows;
+ this.columns = columns;
+ }
+
+ @Override
+ public int columnSize() {
+ return columns;
+ }
+
+ @Override
+ public int rowSize() {
+ return rows;
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterator() {
+ return iterateAll();
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterateAll() {
+ return new AbstractIterator<MatrixSlice>() {
+ private int row;
+
+ @Override
+ protected MatrixSlice computeNext() {
+ if (row >= numRows()) {
+ return endOfData();
+ }
+ int i = row++;
+ return new MatrixSlice(viewRow(i), i);
+ }
+ };
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterateNonEmpty() {
+ return iterator();
+ }
+
+ /**
+ * Abstracted out for the iterator
+ *
+ * @return numRows() for row-based iterator, numColumns() for column-based.
+ */
+ @Override
+ public int numSlices() {
+ return numRows();
+ }
+
+ @Override
+ public double get(String rowLabel, String columnLabel) {
+ if (columnLabelBindings == null || rowLabelBindings == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+ Integer row = rowLabelBindings.get(rowLabel);
+ Integer col = columnLabelBindings.get(columnLabel);
+ if (row == null || col == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+
+ return get(row, col);
+ }
+
+ @Override
+ public Map<String, Integer> getColumnLabelBindings() {
+ return columnLabelBindings;
+ }
+
+ @Override
+ public Map<String, Integer> getRowLabelBindings() {
+ return rowLabelBindings;
+ }
+
+ @Override
+ public void set(String rowLabel, double[] rowData) {
+ if (columnLabelBindings == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+ Integer row = rowLabelBindings.get(rowLabel);
+ if (row == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+ set(row, rowData);
+ }
+
+ @Override
+ public void set(String rowLabel, int row, double[] rowData) {
+ if (rowLabelBindings == null) {
+ rowLabelBindings = new HashMap<>();
+ }
+ rowLabelBindings.put(rowLabel, row);
+ set(row, rowData);
+ }
+
+ @Override
+ public void set(String rowLabel, String columnLabel, double value) {
+ if (columnLabelBindings == null || rowLabelBindings == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+ Integer row = rowLabelBindings.get(rowLabel);
+ Integer col = columnLabelBindings.get(columnLabel);
+ if (row == null || col == null) {
+ throw new IllegalStateException("Unbound label");
+ }
+ set(row, col, value);
+ }
+
+ @Override
+ public void set(String rowLabel, String columnLabel, int row, int column, double value) {
+ if (rowLabelBindings == null) {
+ rowLabelBindings = new HashMap<>();
+ }
+ rowLabelBindings.put(rowLabel, row);
+ if (columnLabelBindings == null) {
+ columnLabelBindings = new HashMap<>();
+ }
+ columnLabelBindings.put(columnLabel, column);
+
+ set(row, column, value);
+ }
+
+ @Override
+ public void setColumnLabelBindings(Map<String, Integer> bindings) {
+ columnLabelBindings = bindings;
+ }
+
+ @Override
+ public void setRowLabelBindings(Map<String, Integer> bindings) {
+ rowLabelBindings = bindings;
+ }
+
+ // index into int[2] for column value
+ public static final int COL = 1;
+
+ // index into int[2] for row value
+ public static final int ROW = 0;
+
+ @Override
+ public int numRows() {
+ return rowSize();
+ }
+
+ @Override
+ public int numCols() {
+ return columnSize();
+ }
+
+ @Override
+ public String asFormatString() {
+ return toString();
+ }
+
+ @Override
+ public Matrix assign(double value) {
+ int rows = rowSize();
+ int columns = columnSize();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, value);
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(double[][] values) {
+ int rows = rowSize();
+ if (rows != values.length) {
+ throw new CardinalityException(rows, values.length);
+ }
+ int columns = columnSize();
+ for (int row = 0; row < rows; row++) {
+ if (columns == values[row].length) {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, values[row][col]);
+ }
+ } else {
+ throw new CardinalityException(columns, values[row].length);
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(
+ row, col)));
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(Matrix other) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, other.getQuick(row, col));
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(DoubleFunction function) {
+ int rows = rowSize();
+ int columns = columnSize();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col)));
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Collects the results of a function applied to each row of a matrix.
+ *
+ * @param f The function to be applied to each row.
+ * @return The vector of results.
+ */
+ @Override
+ public Vector aggregateRows(VectorFunction f) {
+ Vector r = new DenseVector(numRows());
+ int n = numRows();
+ for (int row = 0; row < n; row++) {
+ r.set(row, f.apply(viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a view of a row. Changes to the view will affect the original.
+ *
+ * @param row Which row to return.
+ * @return A vector that references the desired row.
+ */
+ @Override
+ public Vector viewRow(int row) {
+ return new MatrixVectorView(this, row, 0, 0, 1);
+ }
+
+
+ /**
+ * Returns a view of a row. Changes to the view will affect the original.
+ *
+ * @param column Which column to return.
+ * @return A vector that references the desired column.
+ */
+ @Override
+ public Vector viewColumn(int column) {
+ return new MatrixVectorView(this, 0, column, 1, 0);
+ }
+
+ /**
+ * Provides a view of the diagonal of a matrix.
+ */
+ @Override
+ public Vector viewDiagonal() {
+ return new MatrixVectorView(this, 0, 0, 1, 1);
+ }
+
+ /**
+ * Collects the results of a function applied to each element of a matrix and then aggregated.
+ *
+ * @param combiner A function that combines the results of the mapper.
+ * @param mapper A function to apply to each element.
+ * @return The result.
+ */
+ @Override
+ public double aggregate(final DoubleDoubleFunction combiner, final DoubleFunction mapper) {
+ return aggregateRows(new VectorFunction() {
+ @Override
+ public double apply(Vector v) {
+ return v.aggregate(combiner, mapper);
+ }
+ }).aggregate(combiner, Functions.IDENTITY);
+ }
+
+ /**
+ * Collects the results of a function applied to each column of a matrix.
+ *
+ * @param f The function to be applied to each column.
+ * @return The vector of results.
+ */
+ @Override
+ public Vector aggregateColumns(VectorFunction f) {
+ Vector r = new DenseVector(numCols());
+ for (int col = 0; col < numCols(); col++) {
+ r.set(col, f.apply(viewColumn(col)));
+ }
+ return r;
+ }
+
+ @Override
+ public double determinant() {
+ int rows = rowSize();
+ int columns = columnSize();
+ if (rows != columns) {
+ throw new CardinalityException(rows, columns);
+ }
+
+ if (rows == 2) {
+ return getQuick(0, 0) * getQuick(1, 1) - getQuick(0, 1) * getQuick(1, 0);
+ } else {
+ // TODO: this really should just be one line:
+ // TODO: new CholeskyDecomposition(this).getL().viewDiagonal().aggregate(Functions.TIMES)
+ int sign = 1;
+ double ret = 0;
+
+ for (int i = 0; i < columns; i++) {
+ Matrix minor = new DenseMatrix(rows - 1, columns - 1);
+ for (int j = 1; j < rows; j++) {
+ boolean flag = false; /* column offset flag */
+ for (int k = 0; k < columns; k++) {
+ if (k == i) {
+ flag = true;
+ continue;
+ }
+ minor.set(j - 1, flag ? k - 1 : k, getQuick(j, k));
+ }
+ }
+ ret += getQuick(0, i) * sign * minor.determinant();
+ sign *= -1;
+
+ }
+
+ return ret;
+ }
+
+ }
+
+ @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
+ @Override
+ public Matrix clone() {
+ AbstractMatrix clone;
+ try {
+ clone = (AbstractMatrix) super.clone();
+ } catch (CloneNotSupportedException cnse) {
+ throw new IllegalStateException(cnse); // can't happen
+ }
+ if (rowLabelBindings != null) {
+ clone.rowLabelBindings = Maps.newHashMap(rowLabelBindings);
+ }
+ if (columnLabelBindings != null) {
+ clone.columnLabelBindings = Maps.newHashMap(columnLabelBindings);
+ }
+ return clone;
+ }
+
+ @Override
+ public Matrix divide(double x) {
+ Matrix result = like();
+ for (int row = 0; row < rowSize(); row++) {
+ for (int col = 0; col < columnSize(); col++) {
+ result.setQuick(row, col, getQuick(row, col) / x);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public double get(int row, int column) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ return getQuick(row, column);
+ }
+
+ @Override
+ public Matrix minus(Matrix other) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ Matrix result = like();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ result.setQuick(row, col, getQuick(row, col)
+ - other.getQuick(row, col));
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix plus(double x) {
+ Matrix result = like();
+ int rows = rowSize();
+ int columns = columnSize();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ result.setQuick(row, col, getQuick(row, col) + x);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix plus(Matrix other) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ Matrix result = like();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ result.setQuick(row, col, getQuick(row, col)
+ + other.getQuick(row, col));
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public void set(int row, int column, double value) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ setQuick(row, column, value);
+ }
+
+ @Override
+ public void set(int row, double[] data) {
+ int columns = columnSize();
+ if (columns < data.length) {
+ throw new CardinalityException(columns, data.length);
+ }
+ int rows = rowSize();
+ if (row < 0 || row >= rows) {
+ throw new IndexException(row, rowSize());
+ }
+ for (int i = 0; i < columns; i++) {
+ setQuick(row, i, data[i]);
+ }
+ }
+
+ @Override
+ public Matrix times(double x) {
+ Matrix result = like();
+ int rows = rowSize();
+ int columns = columnSize();
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ result.setQuick(row, col, getQuick(row, col) * x);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix times(Matrix other) {
+ int columns = columnSize();
+ if (columns != other.rowSize()) {
+ throw new CardinalityException(columns, other.rowSize());
+ }
+ int rows = rowSize();
+ int otherColumns = other.columnSize();
+ Matrix result = like(rows, otherColumns);
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < otherColumns; col++) {
+ double sum = 0.0;
+ for (int k = 0; k < columns; k++) {
+ sum += getQuick(row, k) * other.getQuick(k, col);
+ }
+ result.setQuick(row, col, sum);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Vector times(Vector v) {
+ int columns = columnSize();
+ if (columns != v.size()) {
+ throw new CardinalityException(columns, v.size());
+ }
+ int rows = rowSize();
+ Vector w = new DenseVector(rows);
+ for (int row = 0; row < rows; row++) {
+ w.setQuick(row, v.dot(viewRow(row)));
+ }
+ return w;
+ }
+
+ @Override
+ public Vector timesSquared(Vector v) {
+ int columns = columnSize();
+ if (columns != v.size()) {
+ throw new CardinalityException(columns, v.size());
+ }
+ int rows = rowSize();
+ Vector w = new DenseVector(columns);
+ for (int i = 0; i < rows; i++) {
+ Vector xi = viewRow(i);
+ double d = xi.dot(v);
+ if (d != 0.0) {
+ w.assign(xi, new PlusMult(d));
+ }
+
+ }
+ return w;
+ }
+
+ @Override
+ public Matrix transpose() {
+ int rows = rowSize();
+ int columns = columnSize();
+ Matrix result = like(columns, rows);
+ for (int row = 0; row < rows; row++) {
+ for (int col = 0; col < columns; col++) {
+ result.setQuick(col, row, getQuick(row, col));
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested) {
+ return viewPart(new int[]{rowOffset, columnOffset}, new int[]{rowsRequested, columnsRequested});
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+
+ if (offset[ROW] < 0) {
+ throw new IndexException(offset[ROW], 0);
+ }
+ if (offset[ROW] + size[ROW] > rowSize()) {
+ throw new IndexException(offset[ROW] + size[ROW], rowSize());
+ }
+ if (offset[COL] < 0) {
+ throw new IndexException(offset[COL], 0);
+ }
+ if (offset[COL] + size[COL] > columnSize()) {
+ throw new IndexException(offset[COL] + size[COL], columnSize());
+ }
+
+ return new MatrixView(this, offset, size);
+ }
+
+
+ @Override
+ public double zSum() {
+ double result = 0;
+ for (int row = 0; row < rowSize(); row++) {
+ for (int col = 0; col < columnSize(); col++) {
+ result += getQuick(row, col);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ return new int[]{rowSize(), columnSize()};
+ }
+
+ protected static class TransposeViewVector extends AbstractVector {
+
+ private final Matrix matrix;
+ private final int transposeOffset;
+ private final int numCols;
+ private final boolean rowToColumn;
+
+ protected TransposeViewVector(Matrix m, int offset) {
+ this(m, offset, true);
+ }
+
+ protected TransposeViewVector(Matrix m, int offset, boolean rowToColumn) {
+ super(rowToColumn ? m.numRows() : m.numCols());
+ matrix = m;
+ this.transposeOffset = offset;
+ this.rowToColumn = rowToColumn;
+ numCols = rowToColumn ? m.numCols() : m.numRows();
+ }
+
+ @SuppressWarnings("CloneDoesntCallSuperClone")
+ @Override
+ public Vector clone() {
+ Vector v = new DenseVector(size());
+ v.assign(this, Functions.PLUS);
+ return v;
+ }
+
+ @Override
+ public boolean isDense() {
+ return true;
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return true;
+ }
+
+ @Override
+ protected Matrix matrixLike(int rows, int columns) {
+ return matrix.like(rows, columns);
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ return new AbstractIterator<Element>() {
+ private int i;
+
+ @Override
+ protected Element computeNext() {
+ if (i >= size()) {
+ return endOfData();
+ }
+ return getElement(i++);
+ }
+ };
+ }
+
+ /**
+ * Currently delegates to {@link #iterator()}.
+ * TODO: This could be optimized to at least skip empty rows if there are many of them.
+ *
+ * @return an iterator (currently dense).
+ */
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return iterator();
+ }
+
+ @Override
+ public Element getElement(final int i) {
+ return new Element() {
+ @Override
+ public double get() {
+ return getQuick(i);
+ }
+
+ @Override
+ public int index() {
+ return i;
+ }
+
+ @Override
+ public void set(double value) {
+ setQuick(i, value);
+ }
+ };
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p>
+ * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge in the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ throw new UnsupportedOperationException("Cannot mutate TransposeViewVector");
+ }
+
+ @Override
+ public double getQuick(int index) {
+ Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index);
+ return v == null ? 0.0 : v.getQuick(transposeOffset);
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index);
+ if (v == null) {
+ v = newVector(numCols);
+ if (rowToColumn) {
+ matrix.assignColumn(index, v);
+ } else {
+ matrix.assignRow(index, v);
+ }
+ }
+ v.setQuick(transposeOffset, value);
+ }
+
+ protected Vector newVector(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ @Override
+ public Vector like() {
+ return new DenseVector(size());
+ }
+
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ /**
+ * TODO: currently I don't know of an efficient way to getVector this value correctly.
+ *
+ * @return the number of nonzero entries
+ */
+ @Override
+ public int getNumNondefaultElements() {
+ return size();
+ }
+
+ @Override
+ public double getLookupCost() {
+ return (rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0)).getLookupCost();
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return (rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0)).getIteratorAdvanceCost();
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return (rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0)).isAddConstantTime();
+ }
+ }
+
+ @Override
+ public String toString() {
+ int row = 0;
+ int maxRowsToDisplay = 10;
+ int maxColsToDisplay = 20;
+ int colsToDisplay = maxColsToDisplay;
+
+ if(maxColsToDisplay > columnSize()){
+ colsToDisplay = columnSize();
+ }
+
+
+ StringBuilder s = new StringBuilder("{\n");
+ Iterator<MatrixSlice> it = iterator();
+ while ((it.hasNext()) && (row < maxRowsToDisplay)) {
+ MatrixSlice next = it.next();
+ s.append(" ").append(next.index())
+ .append(" =>\t")
+ .append(new VectorView(next.vector(), 0, colsToDisplay))
+ .append('\n');
+ row ++;
+ }
+ String returnString = s.toString();
+ if (maxColsToDisplay <= columnSize()) {
+ returnString = returnString.replace("}", " ... } ");
+ }
+ if(maxRowsToDisplay <= rowSize())
+ return returnString + ("... }");
+ else{
+ return returnString + ("}");
+ }
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ throw new UnsupportedOperationException("Flavor support not implemented for this matrix.");
+ }
+
+ ////////////// Matrix flavor trait ///////////////////
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/AbstractVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/AbstractVector.java b/core/src/main/java/org/apache/mahout/math/AbstractVector.java
new file mode 100644
index 0000000..27eddbc
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/AbstractVector.java
@@ -0,0 +1,684 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/** Implementations of generic capabilities like sum of elements and dot products */
+public abstract class AbstractVector implements Vector, LengthCachingVector {
+
+ private int size;
+ protected double lengthSquared = -1.0;
+
+ protected AbstractVector(int size) {
+ this.size = size;
+ }
+
+ @Override
+ public Iterable<Element> all() {
+ return new Iterable<Element>() {
+ @Override
+ public Iterator<Element> iterator() {
+ return AbstractVector.this.iterator();
+ }
+ };
+ }
+
+ @Override
+ public Iterable<Element> nonZeroes() {
+ return new Iterable<Element>() {
+ @Override
+ public Iterator<Element> iterator() {
+ return iterateNonZero();
+ }
+ };
+ }
+
+ /**
+ * Iterates over all elements <p>
+ * NOTE: Implementations may choose to reuse the Element returned for performance
+ * reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
+ *
+ * @return An {@link Iterator} over all elements
+ */
+ protected abstract Iterator<Element> iterator();
+
+ /**
+ * Iterates over all non-zero elements. <p>
+ * NOTE: Implementations may choose to reuse the Element returned for
+ * performance reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
+ *
+ * @return An {@link Iterator} over all non-zero elements
+ */
+ protected abstract Iterator<Element> iterateNonZero();
+ /**
+ * Aggregates a vector by applying a mapping function fm(x) to every component and aggregating
+ * the results with an aggregating function fa(x, y).
+ *
+ * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue)
+ * @param map a function to apply to each element of the vector in turn before passing to the aggregator
+ * @return the result of the aggregation
+ */
+ @Override
+ public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
+ if (size == 0) {
+ return 0;
+ }
+
+ // If the aggregator is associative and commutative and it's likeLeftMult (fa(0, y) = 0), and there is
+ // at least one zero in the vector (size > getNumNondefaultElements) and applying fm(0) = 0, the result
+ // gets cascaded through the aggregation and the final result will be 0.
+ if (aggregator.isAssociativeAndCommutative() && aggregator.isLikeLeftMult()
+ && size > getNumNondefaultElements() && !map.isDensifying()) {
+ return 0;
+ }
+
+ double result;
+ if (isSequentialAccess() || aggregator.isAssociativeAndCommutative()) {
+ Iterator<Element> iterator;
+ // If fm(0) = 0 and fa(x, 0) = x, we can skip all zero values.
+ if (!map.isDensifying() && aggregator.isLikeRightPlus()) {
+ iterator = iterateNonZero();
+ if (!iterator.hasNext()) {
+ return 0;
+ }
+ } else {
+ iterator = iterator();
+ }
+ Element element = iterator.next();
+ result = map.apply(element.get());
+ while (iterator.hasNext()) {
+ element = iterator.next();
+ result = aggregator.apply(result, map.apply(element.get()));
+ }
+ } else {
+ result = map.apply(getQuick(0));
+ for (int i = 1; i < size; i++) {
+ result = aggregator.apply(result, map.apply(getQuick(i)));
+ }
+ }
+
+ return result;
+ }
+
+ @Override
+ public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
+ Preconditions.checkArgument(size == other.size(), "Vector sizes differ");
+ if (size == 0) {
+ return 0;
+ }
+ return VectorBinaryAggregate.aggregateBest(this, other, aggregator, combiner);
+ }
+
+ /**
+ * Subclasses must override to return an appropriately sparse or dense result
+ *
+ * @param rows the row cardinality
+ * @param columns the column cardinality
+ * @return a Matrix
+ */
+ protected abstract Matrix matrixLike(int rows, int columns);
+
+ @Override
+ public Vector viewPart(int offset, int length) {
+ if (offset < 0) {
+ throw new IndexException(offset, size);
+ }
+ if (offset + length > size) {
+ throw new IndexException(offset + length, size);
+ }
+ return new VectorView(this, offset, length);
+ }
+
+ @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
+ @Override
+ public Vector clone() {
+ try {
+ AbstractVector r = (AbstractVector) super.clone();
+ r.size = size;
+ r.lengthSquared = lengthSquared;
+ return r;
+ } catch (CloneNotSupportedException e) {
+ throw new IllegalStateException("Can't happen");
+ }
+ }
+
+ @Override
+ public Vector divide(double x) {
+ if (x == 1.0) {
+ return clone();
+ }
+ Vector result = createOptimizedCopy();
+ for (Element element : result.nonZeroes()) {
+ element.set(element.get() / x);
+ }
+ return result;
+ }
+
+ @Override
+ public double dot(Vector x) {
+ if (size != x.size()) {
+ throw new CardinalityException(size, x.size());
+ }
+ if (this == x) {
+ return getLengthSquared();
+ }
+ return aggregate(x, Functions.PLUS, Functions.MULT);
+ }
+
+ protected double dotSelf() {
+ return aggregate(Functions.PLUS, Functions.pow(2));
+ }
+
+ @Override
+ public double get(int index) {
+ if (index < 0 || index >= size) {
+ throw new IndexException(index, size);
+ }
+ return getQuick(index);
+ }
+
+ @Override
+ public Element getElement(int index) {
+ return new LocalElement(index);
+ }
+
+ @Override
+ public Vector normalize() {
+ return divide(Math.sqrt(getLengthSquared()));
+ }
+
+ @Override
+ public Vector normalize(double power) {
+ return divide(norm(power));
+ }
+
+ @Override
+ public Vector logNormalize() {
+ return logNormalize(2.0, Math.sqrt(getLengthSquared()));
+ }
+
+ @Override
+ public Vector logNormalize(double power) {
+ return logNormalize(power, norm(power));
+ }
+
+ public Vector logNormalize(double power, double normLength) {
+ // we can special case certain powers
+ if (Double.isInfinite(power) || power <= 1.0) {
+ throw new IllegalArgumentException("Power must be > 1 and < infinity");
+ } else {
+ double denominator = normLength * Math.log(power);
+ Vector result = createOptimizedCopy();
+ for (Element element : result.nonZeroes()) {
+ element.set(Math.log1p(element.get()) / denominator);
+ }
+ return result;
+ }
+ }
+
+ @Override
+ public double norm(double power) {
+ if (power < 0.0) {
+ throw new IllegalArgumentException("Power must be >= 0");
+ }
+ // We can special case certain powers.
+ if (Double.isInfinite(power)) {
+ return aggregate(Functions.MAX, Functions.ABS);
+ } else if (power == 2.0) {
+ return Math.sqrt(getLengthSquared());
+ } else if (power == 1.0) {
+ double result = 0.0;
+ Iterator<Element> iterator = this.iterateNonZero();
+ while (iterator.hasNext()) {
+ result += Math.abs(iterator.next().get());
+ }
+ return result;
+ // TODO: this should ideally be used, but it's slower.
+ // return aggregate(Functions.PLUS, Functions.ABS);
+ } else if (power == 0.0) {
+ return getNumNonZeroElements();
+ } else {
+ return Math.pow(aggregate(Functions.PLUS, Functions.pow(power)), 1.0 / power);
+ }
+ }
+
+ @Override
+ public double getLengthSquared() {
+ if (lengthSquared >= 0.0) {
+ return lengthSquared;
+ }
+ return lengthSquared = dotSelf();
+ }
+
+ @Override
+ public void invalidateCachedLength() {
+ lengthSquared = -1;
+ }
+
+ @Override
+ public double getDistanceSquared(Vector that) {
+ if (size != that.size()) {
+ throw new CardinalityException(size, that.size());
+ }
+ double thisLength = getLengthSquared();
+ double thatLength = that.getLengthSquared();
+ double dot = dot(that);
+ double distanceEstimate = thisLength + thatLength - 2 * dot;
+ if (distanceEstimate > 1.0e-3 * (thisLength + thatLength)) {
+ // The vectors are far enough from each other that the formula is accurate.
+ return Math.max(distanceEstimate, 0);
+ } else {
+ return aggregate(that, Functions.PLUS, Functions.MINUS_SQUARED);
+ }
+ }
+
+ @Override
+ public double maxValue() {
+ if (size == 0) {
+ return Double.NEGATIVE_INFINITY;
+ }
+ return aggregate(Functions.MAX, Functions.IDENTITY);
+ }
+
+ @Override
+ public int maxValueIndex() {
+ int result = -1;
+ double max = Double.NEGATIVE_INFINITY;
+ int nonZeroElements = 0;
+ Iterator<Element> iter = this.iterateNonZero();
+ while (iter.hasNext()) {
+ nonZeroElements++;
+ Element element = iter.next();
+ double tmp = element.get();
+ if (tmp > max) {
+ max = tmp;
+ result = element.index();
+ }
+ }
+ // if the maxElement is negative and the vector is sparse then any
+ // unfilled element(0.0) could be the maxValue hence we need to
+ // find one of those elements
+ if (nonZeroElements < size && max < 0.0) {
+ for (Element element : all()) {
+ if (element.get() == 0.0) {
+ return element.index();
+ }
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public double minValue() {
+ if (size == 0) {
+ return Double.POSITIVE_INFINITY;
+ }
+ return aggregate(Functions.MIN, Functions.IDENTITY);
+ }
+
+ @Override
+ public int minValueIndex() {
+ int result = -1;
+ double min = Double.POSITIVE_INFINITY;
+ int nonZeroElements = 0;
+ Iterator<Element> iter = this.iterateNonZero();
+ while (iter.hasNext()) {
+ nonZeroElements++;
+ Element element = iter.next();
+ double tmp = element.get();
+ if (tmp < min) {
+ min = tmp;
+ result = element.index();
+ }
+ }
+ // if the maxElement is positive and the vector is sparse then any
+ // unfilled element(0.0) could be the maxValue hence we need to
+ // find one of those elements
+ if (nonZeroElements < size && min > 0.0) {
+ for (Element element : all()) {
+ if (element.get() == 0.0) {
+ return element.index();
+ }
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Vector plus(double x) {
+ Vector result = createOptimizedCopy();
+ if (x == 0.0) {
+ return result;
+ }
+ return result.assign(Functions.plus(x));
+ }
+
+ @Override
+ public Vector plus(Vector that) {
+ if (size != that.size()) {
+ throw new CardinalityException(size, that.size());
+ }
+ return createOptimizedCopy().assign(that, Functions.PLUS);
+ }
+
+ @Override
+ public Vector minus(Vector that) {
+ if (size != that.size()) {
+ throw new CardinalityException(size, that.size());
+ }
+ return createOptimizedCopy().assign(that, Functions.MINUS);
+ }
+
+ @Override
+ public void set(int index, double value) {
+ if (index < 0 || index >= size) {
+ throw new IndexException(index, size);
+ }
+ setQuick(index, value);
+ }
+
+ @Override
+ public void incrementQuick(int index, double increment) {
+ setQuick(index, getQuick(index) + increment);
+ }
+
+ @Override
+ public Vector times(double x) {
+ if (x == 0.0) {
+ return like();
+ }
+ return createOptimizedCopy().assign(Functions.mult(x));
+ }
+
+ /**
+ * Copy the current vector in the most optimum fashion. Used by immutable methods like plus(), minus().
+ * Use this instead of vector.like().assign(vector). Sub-class can choose to override this method.
+ *
+ * @return a copy of the current vector.
+ */
+ protected Vector createOptimizedCopy() {
+ return createOptimizedCopy(this);
+ }
+
+ private static Vector createOptimizedCopy(Vector vector) {
+ Vector result;
+ if (vector.isDense()) {
+ result = vector.like().assign(vector, Functions.SECOND_LEFT_ZERO);
+ } else {
+ result = vector.clone();
+ }
+ return result;
+ }
+
+ @Override
+ public Vector times(Vector that) {
+ if (size != that.size()) {
+ throw new CardinalityException(size, that.size());
+ }
+
+ if (this.getNumNondefaultElements() <= that.getNumNondefaultElements()) {
+ return createOptimizedCopy(this).assign(that, Functions.MULT);
+ } else {
+ return createOptimizedCopy(that).assign(this, Functions.MULT);
+ }
+ }
+
+ @Override
+ public double zSum() {
+ return aggregate(Functions.PLUS, Functions.IDENTITY);
+ }
+
+ @Override
+ public int getNumNonZeroElements() {
+ int count = 0;
+ Iterator<Element> it = iterateNonZero();
+ while (it.hasNext()) {
+ if (it.next().get() != 0.0) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public Vector assign(double value) {
+ Iterator<Element> it;
+ if (value == 0.0) {
+ // Make all the non-zero values 0.
+ it = iterateNonZero();
+ while (it.hasNext()) {
+ it.next().set(value);
+ }
+ } else {
+ if (isSequentialAccess() && !isAddConstantTime()) {
+ // Update all the non-zero values and queue the updates for the zero vaues.
+ // The vector will become dense.
+ it = iterator();
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
+ while (it.hasNext()) {
+ Element element = it.next();
+ if (element.get() == 0.0) {
+ updates.set(element.index(), value);
+ } else {
+ element.set(value);
+ }
+ }
+ mergeUpdates(updates);
+ } else {
+ for (int i = 0; i < size; ++i) {
+ setQuick(i, value);
+ }
+ }
+ }
+ invalidateCachedLength();
+ return this;
+ }
+
+ @Override
+ public Vector assign(double[] values) {
+ if (size != values.length) {
+ throw new CardinalityException(size, values.length);
+ }
+ if (isSequentialAccess() && !isAddConstantTime()) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
+ Iterator<Element> it = iterator();
+ while (it.hasNext()) {
+ Element element = it.next();
+ int index = element.index();
+ if (element.get() == 0.0) {
+ updates.set(index, values[index]);
+ } else {
+ element.set(values[index]);
+ }
+ }
+ mergeUpdates(updates);
+ } else {
+ for (int i = 0; i < size; ++i) {
+ setQuick(i, values[i]);
+ }
+ }
+ invalidateCachedLength();
+ return this;
+ }
+
+ @Override
+ public Vector assign(Vector other) {
+ return assign(other, Functions.SECOND);
+ }
+
+ @Override
+ public Vector assign(DoubleDoubleFunction f, double y) {
+ Iterator<Element> iterator = f.apply(0, y) == 0 ? iterateNonZero() : iterator();
+ while (iterator.hasNext()) {
+ Element element = iterator.next();
+ element.set(f.apply(element.get(), y));
+ }
+ invalidateCachedLength();
+ return this;
+ }
+
+ @Override
+ public Vector assign(DoubleFunction f) {
+ Iterator<Element> iterator = !f.isDensifying() ? iterateNonZero() : iterator();
+ while (iterator.hasNext()) {
+ Element element = iterator.next();
+ element.set(f.apply(element.get()));
+ }
+ invalidateCachedLength();
+ return this;
+ }
+
+ @Override
+ public Vector assign(Vector other, DoubleDoubleFunction function) {
+ if (size != other.size()) {
+ throw new CardinalityException(size, other.size());
+ }
+ VectorBinaryAssign.assignBest(this, other, function);
+ invalidateCachedLength();
+ return this;
+ }
+
+ @Override
+ public Matrix cross(Vector other) {
+ Matrix result = matrixLike(size, other.size());
+ Iterator<Vector.Element> it = iterateNonZero();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ int row = e.index();
+ result.assignRow(row, other.times(getQuick(row)));
+ }
+ return result;
+ }
+
+ @Override
+ public final int size() {
+ return size;
+ }
+
+ @Override
+ public String asFormatString() {
+ return toString();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = size;
+ Iterator<Element> iter = iterateNonZero();
+ while (iter.hasNext()) {
+ Element ele = iter.next();
+ result += ele.index() * RandomUtils.hashDouble(ele.get());
+ }
+ return result;
+ }
+
+ /**
+ * Determines whether this {@link Vector} represents the same logical vector as another
+ * object. Two {@link Vector}s are equal (regardless of implementation) if the value at
+ * each index is the same, and the cardinalities are the same.
+ */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof Vector)) {
+ return false;
+ }
+ Vector that = (Vector) o;
+ return size == that.size() && aggregate(that, Functions.PLUS, Functions.MINUS_ABS) == 0.0;
+ }
+
+ @Override
+ public String toString() {
+ return toString(null);
+ }
+
+ public String toString(String[] dictionary) {
+ StringBuilder result = new StringBuilder();
+ result.append('{');
+ for (int index = 0; index < size; index++) {
+ double value = getQuick(index);
+ if (value != 0.0) {
+ result.append(dictionary != null && dictionary.length > index ? dictionary[index] : index);
+ result.append(':');
+ result.append(value);
+ result.append(',');
+ }
+ }
+ if (result.length() > 1) {
+ result.setCharAt(result.length() - 1, '}');
+ } else {
+ result.append('}');
+ }
+ return result.toString();
+ }
+
+ /**
+ * toString() implementation for sparse vectors via {@link #nonZeroes()} method
+ * @return String representation of the vector
+ */
+ public String sparseVectorToString() {
+ Iterator<Element> it = iterateNonZero();
+ if (!it.hasNext()) {
+ return "{}";
+ }
+ else {
+ StringBuilder result = new StringBuilder();
+ result.append('{');
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ result.append(e.index());
+ result.append(':');
+ result.append(e.get());
+ result.append(',');
+ }
+ result.setCharAt(result.length() - 1, '}');
+ return result.toString();
+ }
+ }
+
+ protected final class LocalElement implements Element {
+ int index;
+
+ LocalElement(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public double get() {
+ return getQuick(index);
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ setQuick(index, value);
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Algebra.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Algebra.java b/core/src/main/java/org/apache/mahout/math/Algebra.java
new file mode 100644
index 0000000..3049057
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Algebra.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+public final class Algebra {
+
+ private Algebra() {
+ }
+
+ public static Vector mult(Matrix m, Vector v) {
+ if (m.numRows() != v.size()) {
+ throw new CardinalityException(m.numRows(), v.size());
+ }
+ // Use a Dense Vector for the moment,
+ Vector result = new DenseVector(m.numRows());
+
+ for (int i = 0; i < m.numRows(); i++) {
+ result.set(i, m.viewRow(i).dot(v));
+ }
+
+ return result;
+ }
+
+ /** Returns sqrt(a^2 + b^2) without under/overflow. */
+ public static double hypot(double a, double b) {
+ double r;
+ if (Math.abs(a) > Math.abs(b)) {
+ r = b / a;
+ r = Math.abs(a) * Math.sqrt(1 + r * r);
+ } else if (b != 0) {
+ r = a / b;
+ r = Math.abs(b) * Math.sqrt(1 + r * r);
+ } else {
+ r = 0.0;
+ }
+ return r;
+ }
+
+ /**
+ * Compute Maximum Absolute Row Sum Norm of input Matrix m
+ * http://mathworld.wolfram.com/MaximumAbsoluteRowSumNorm.html
+ */
+ public static double getNorm(Matrix m) {
+ double max = 0.0;
+ for (int i = 0; i < m.numRows(); i++) {
+ int sum = 0;
+ Vector cv = m.viewRow(i);
+ for (int j = 0; j < cv.size(); j++) {
+ sum += (int) Math.abs(cv.getQuick(j));
+ }
+ if (sum > max) {
+ max = sum;
+ }
+ }
+ return max;
+ }
+
+}
r***@apache.org
2018-09-08 23:35:06 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/LSMR.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/solver/LSMR.java b/core/src/main/java/org/apache/mahout/math/solver/LSMR.java
new file mode 100644
index 0000000..1f3e706
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/solver/LSMR.java
@@ -0,0 +1,565 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Solves sparse least-squares using the LSMR algorithm.
+ * <p/>
+ * LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves
+ * the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where
+ * all cases are allowed: m=n, m>n, or m&lt;n. B is a vector of length m. The matrix A may be dense
+ * or sparse (usually sparse).
+ * <p/>
+ * Some additional configurable properties adjust the behavior of the algorithm.
+ * <p/>
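+ * If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem
+ * min || [B; 0] - [A; lambda*I] X ||_2, i.e. B stacked over a zero vector is fitted by A stacked
+ * over lambda times the identity, which is equivalent to minimizing ||B - A*X||_2^2 +
+ * lambda^2 ||X||_2^2, where LAMBDA is a scalar. If LAMBDA is not set, the system is solved without
+ * regularization.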
+ * <p/>
+ * You can also set aTolerance and bTolerance. These cause LSMR to iterate until a certain backward
+ * error estimate is smaller than some quantity depending on ATOL and BTOL. Let RES = B - A*X be
+ * the residual vector for the current approximate solution X. If A*X = B seems to be consistent,
+ * LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates
+ * when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final
+ * NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct
+ * digits, depending on cond(A) and the size of LAMBDA.)
+ * <p/>
+ * The default value for ATOL and BTOL is 1e-6.
+ * <p/>
+ * Ideally, they should be estimates of the relative error in the entries of A and B respectively.
+ * For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the
+ * algorithm from doing unnecessary work beyond the uncertainty of the input data.
+ * <p/>
+ * You can also set conditionLimit. In that case, LSMR terminates if an estimate of cond(A) exceeds
+ * conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say).
+ * For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not
+ * set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance =
+ * bTolerance = conditionLimit = 0, but the number of iterations may then be excessive.
+ * <p/>
+ * Setting iterationLimit causes LSMR to terminate if the number of iterations reaches
+ * iterationLimit. The default is iterationLimit = min(m,n). For ill-conditioned systems, a
+ * larger value of ITNLIM may be needed.
+ * <p/>
+ * Setting localSize causes LSMR to run with reorthogonalization on the last localSize v_k's
+ * (the v-vectors generated by Golub-Kahan bidiagonalization). If localSize is not set, LSMR runs
+ * without reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's.
+ * Reorthogonalizing only u_k, or both u_k and v_k, is not an option here. Details are discussed in
+ * the SIAM paper.
+ * <p/>
+ * getTerminationReason() gives the reason for termination. ISTOP = 0 means X=0 is a solution. = 1
+ * means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X
+ * approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be
+ * greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL
+ * = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other
+ * stopping conditions were satisfied.
+ * <p/>
+ * getIterationCount() gives ITN = the number of LSMR iterations.
+ * <p/>
+ * getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X).
+ * <p/>
+ * getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR =
+ * NORM(A'*(B-A*X)).
+ * <p/>
+ * getANorm() gives an estimate of the Frobenius norm of A.
+ * <p/>
+ * getCondition() gives an estimate of the condition number of A.
+ * <p/>
+ * getXNorm() gives an estimate of NORM(X).
+ * <p/>
+ * LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders
+ * LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to
+ * SISC.
+ * <p/>
+ * David Chin-lung Fong ***@stanford.edu Institute for Computational and Mathematical
+ * Engineering Stanford University
+ * <p/>
+ * Michael Saunders ***@stanford.edu Systems Optimization Laboratory Dept of
+ * MS&E, Stanford University. -----------------------------------------------------------------------
+ */
+public final class LSMR {
+
+ private static final Logger log = LoggerFactory.getLogger(LSMR.class);
+
+ private final double lambda;
+ private int localSize;
+ private int iterationLimit;
+ private double conditionLimit;
+ private double bTolerance;
+ private double aTolerance;
+ private int localPointer;
+ private Vector[] localV;
+ private double residualNorm;
+ private double normalEquationResidual;
+ private double xNorm;
+ private int iteration;
+ private double normA;
+ private double condA;
+
+ public int getIterationCount() {
+ return iteration;
+ }
+
+ public double getResidualNorm() {
+ return residualNorm;
+ }
+
+ public double getNormalEquationResidual() {
+ return normalEquationResidual;
+ }
+
+ public double getANorm() {
+ return normA;
+ }
+
+ public double getCondition() {
+ return condA;
+ }
+
+ public double getXNorm() {
+ return xNorm;
+ }
+
+ /**
+ * LSMR uses an iterative method to solve a linear system. For further information, see D. C.-L.
+ * Fong and M. A. Saunders LSMR: An iterative algorithm for least-square problems Draft of 03 Apr
+ * 2010, to be submitted to SISC.
+ * <p/>
+ * 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated documentation and default
+ * parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR with local
+ * reorthogonalization (full reorthogonalization is also implemented)
+ * <p/>
+ * David Chin-lung Fong ***@stanford.edu Institute for Computational and
+ * Mathematical Engineering Stanford University
+ * <p/>
+ * Michael Saunders ***@stanford.edu Systems Optimization Laboratory Dept of
+ * MS&E, Stanford University. -----------------------------------------------------------------------
+ */
+
+ public LSMR() {
+ // Set default parameters.
+ lambda = 0;
+ aTolerance = 1.0e-6;
+ bTolerance = 1.0e-6;
+ conditionLimit = 1.0e8;
+ iterationLimit = -1;
+ localSize = 0;
+ }
+
+ public Vector solve(Matrix A, Vector b) {
+ /*
+ % Initialize.
+
+
+ hdg1 = ' itn x(1) norm r norm A''r';
+ hdg2 = ' compatible LS norm A cond A';
+ pfreq = 20; % print frequency (for repeating the heading)
+ pcount = 0; % print counter
+
+ % Determine dimensions m and n, and
+ % form the first vectors u and v.
+ % These satisfy beta*u = b, alpha*v = A'u.
+ */
+ log.debug(" itn x(1) norm r norm A'r");
+ log.debug(" compatible LS norm A cond A");
+
+ Matrix transposedA = A.transpose();
+ Vector u = b;
+
+ double beta = u.norm(2);
+ if (beta > 0) {
+ u = u.divide(beta);
+ }
+
+ Vector v = transposedA.times(u);
+ int m = A.numRows();
+ int n = A.numCols();
+
+ int minDim = Math.min(m, n);
+ if (iterationLimit == -1) {
+ iterationLimit = minDim;
+ }
+
+ if (log.isDebugEnabled()) {
+ log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, "
+ + "Mahout version {}", getClass().getPackage().getImplementationVersion());
+ log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g",
+ m, n, lambda, aTolerance, bTolerance));
+ }
+
+ double alpha = v.norm(2);
+ if (alpha > 0) {
+ v.assign(Functions.div(alpha));
+ }
+
+
+ // Initialization for local reorthogonalization
+ localPointer = 0;
+
+ // Preallocate storage for storing the last few v_k. Since with
+ // orthogonal v_k's, Krylov subspace method would converge in not
+ // more iterations than the number of singular values, more
+ // space is not necessary.
+ localV = new Vector[Math.min(localSize, minDim)];
+ boolean localOrtho = false;
+ if (localSize > 0) {
+ localOrtho = true;
+ localV[0] = v;
+ }
+
+
+ // Initialize variables for 1st iteration.
+
+ iteration = 0;
+ double zetabar = alpha * beta;
+ double alphabar = alpha;
+
+ Vector h = v;
+ Vector hbar = zeros(n);
+ Vector x = zeros(n);
+
+ // Initialize variables for estimation of ||r||.
+
+ double betadd = beta;
+
+ // Initialize variables for estimation of ||A|| and cond(A)
+
+ double aNorm = alpha * alpha;
+
+ // Items for use in stopping rules.
+ double normb = beta;
+
+ double ctol = 0;
+ if (conditionLimit > 0) {
+ ctol = 1 / conditionLimit;
+ }
+ residualNorm = beta;
+
+ // Exit if b=0 or A'b = 0.
+
+ normalEquationResidual = alpha * beta;
+ if (normalEquationResidual == 0) {
+ return x;
+ }
+
+ // Heading for iteration log.
+
+
+ if (log.isDebugEnabled()) {
+ double test2 = alpha / beta;
+// log.debug('{} {}', hdg1, hdg2);
+ log.debug("{} {}", iteration, x.get(0));
+ log.debug("{} {}", residualNorm, normalEquationResidual);
+ double test1 = 1;
+ log.debug("{} {}", test1, test2);
+ }
+
+
+ //------------------------------------------------------------------
+ // Main iteration loop.
+ //------------------------------------------------------------------
+ double rho = 1;
+ double rhobar = 1;
+ double cbar = 1;
+ double sbar = 0;
+ double betad = 0;
+ double rhodold = 1;
+ double tautildeold = 0;
+ double thetatilde = 0;
+ double zeta = 0;
+ double d = 0;
+ double maxrbar = 0;
+ double minrbar = 1.0e+100;
+ StopCode stop = StopCode.CONTINUE;
+ while (iteration <= iterationLimit && stop == StopCode.CONTINUE) {
+
+ iteration++;
+
+ // Perform the next step of the bidiagonalization to obtain the
+ // next beta, u, alpha, v. These satisfy the relations
+ // beta*u = A*v - alpha*u,
+ // alpha*v = A'*u - beta*v.
+
+ u = A.times(v).minus(u.times(alpha));
+ beta = u.norm(2);
+ if (beta > 0) {
+ u.assign(Functions.div(beta));
+
+ // store data for local-reorthogonalization of V
+ if (localOrtho) {
+ localVEnqueue(v);
+ }
+ v = transposedA.times(u).minus(v.times(beta));
+ // local-reorthogonalization of V
+ if (localOrtho) {
+ v = localVOrtho(v);
+ }
+ alpha = v.norm(2);
+ if (alpha > 0) {
+ v.assign(Functions.div(alpha));
+ }
+ }
+
+ // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
+
+ // Construct rotation Qhat_{k,2k+1}.
+
+ double alphahat = Math.hypot(alphabar, lambda);
+ double chat = alphabar / alphahat;
+ double shat = lambda / alphahat;
+
+ // Use a plane rotation (Q_i) to turn B_i to R_i
+
+ double rhoold = rho;
+ rho = Math.hypot(alphahat, beta);
+ double c = alphahat / rho;
+ double s = beta / rho;
+ double thetanew = s * alpha;
+ alphabar = c * alpha;
+
+ // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
+
+ double rhobarold = rhobar;
+ double zetaold = zeta;
+ double thetabar = sbar * rho;
+ double rhotemp = cbar * rho;
+ rhobar = Math.hypot(cbar * rho, thetanew);
+ cbar = cbar * rho / rhobar;
+ sbar = thetanew / rhobar;
+ zeta = cbar * zetabar;
+ zetabar = -sbar * zetabar;
+
+
+ // Update h, h_hat, x.
+
+ hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
+
+ x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
+ h = v.minus(h.times(thetanew / rho));
+
+ // Estimate of ||r||.
+
+ // Apply rotation Qhat_{k,2k+1}.
+ double betaacute = chat * betadd;
+ double betacheck = -shat * betadd;
+
+ // Apply rotation Q_{k,k+1}.
+ double betahat = c * betaacute;
+ betadd = -s * betaacute;
+
+ // Apply rotation Qtilde_{k-1}.
+ // betad = betad_{k-1} here.
+
+ double thetatildeold = thetatilde;
+ double rhotildeold = Math.hypot(rhodold, thetabar);
+ double ctildeold = rhodold / rhotildeold;
+ double stildeold = thetabar / rhotildeold;
+ thetatilde = stildeold * rhobar;
+ rhodold = ctildeold * rhobar;
+ betad = -stildeold * betad + ctildeold * betahat;
+
+ // betad = betad_k here.
+ // rhodold = rhod_k here.
+
+ tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
+ double taud = (zeta - thetatilde * tautildeold) / rhodold;
+ d += betacheck * betacheck;
+ residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd);
+
+ // Estimate ||A||.
+ aNorm += beta * beta;
+ normA = Math.sqrt(aNorm);
+ aNorm += alpha * alpha;
+
+ // Estimate cond(A).
+ maxrbar = Math.max(maxrbar, rhobarold);
+ if (iteration > 1) {
+ minrbar = Math.min(minrbar, rhobarold);
+ }
+ condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);
+
+ // Test for convergence.
+
+ // Compute norms for convergence testing.
+ normalEquationResidual = Math.abs(zetabar);
+ xNorm = x.norm(2);
+
+ // Now use these norms to estimate certain other quantities,
+ // some of which will be small near a solution.
+
+ double test1 = residualNorm / normb;
+ double test2 = normalEquationResidual / (normA * residualNorm);
+ double test3 = 1 / condA;
+ double t1 = test1 / (1 + normA * xNorm / normb);
+ double rtol = bTolerance + aTolerance * normA * xNorm / normb;
+
+ // The following tests guard against extremely small values of
+ // atol, btol or ctol. (The user may have set any or all of
+ // the parameters atol, btol, conlim to 0.)
+ // The effect is equivalent to the normAl tests using
+ // atol = eps, btol = eps, conlim = 1/eps.
+
+ if (iteration > iterationLimit) {
+ stop = StopCode.ITERATION_LIMIT;
+ }
+ if (1 + test3 <= 1) {
+ stop = StopCode.CONDITION_MACHINE_TOLERANCE;
+ }
+ if (1 + test2 <= 1) {
+ stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
+ }
+ if (1 + t1 <= 1) {
+ stop = StopCode.CONVERGED_MACHINE_TOLERANCE;
+ }
+
+ // Allow for tolerances set by the user.
+
+ if (test3 <= ctol) {
+ stop = StopCode.CONDITION;
+ }
+ if (test2 <= aTolerance) {
+ stop = StopCode.CONVERGED;
+ }
+ if (test1 <= rtol) {
+ stop = StopCode.TRIVIAL;
+ }
+
+ // See if it is time to print something.
+ if (log.isDebugEnabled()) {
+ if ((n <= 40) || (iteration <= 10) || (iteration >= iterationLimit - 10) || ((iteration % 10) == 0)
+ || (test3 <= 1.1 * ctol) || (test2 <= 1.1 * aTolerance) || (test1 <= 1.1 * rtol)
+ || (stop != StopCode.CONTINUE)) {
+ statusDump(x, normA, condA, test1, test2);
+ }
+ }
+ } // iteration loop
+
+ // Print the stopping condition.
+ log.debug("Finished: {}", stop.getMessage());
+
+ return x;
+ /*
+
+
+ if show
+ fprintf('\n\nLSMR finished')
+ fprintf('\n%s', msg(istop+1,:))
+ fprintf('\nistop =%8g normr =%8.1e' , istop, normr )
+ fprintf(' normA =%8.1e normAr =%8.1e', normA, normAr)
+ fprintf('\nitn =%8g condA =%8.1e' , itn , condA )
+ fprintf(' normx =%8.1e\n', normx)
+ end
+ */
+ }
+
+ private void statusDump(Vector x, double normA, double condA, double test1, double test2) {
+ log.debug("{} {}", residualNorm, normalEquationResidual);
+ log.debug("{} {}", iteration, x.get(0));
+ log.debug("{} {}", test1, test2);
+ log.debug("{} {}", normA, condA);
+ }
+
+ private static Vector zeros(int n) {
+ return new DenseVector(n);
+ }
+
+ //-----------------------------------------------------------------------
+ // stores v into the circular buffer localV
+ //-----------------------------------------------------------------------
+
+ private void localVEnqueue(Vector v) {
+ if (localV.length > 0) {
+ localV[localPointer] = v;
+ localPointer = (localPointer + 1) % localV.length;
+ }
+ }
+
+ //-----------------------------------------------------------------------
+ // Perform local reorthogonalization of V
+ //-----------------------------------------------------------------------
+
+ private Vector localVOrtho(Vector v) {
+ for (Vector old : localV) {
+ if (old != null) {
+ double x = v.dot(old);
+ v = v.minus(old.times(x));
+ }
+ }
+ return v;
+ }
+
+ private enum StopCode {
+ CONTINUE("Not done"),
+ TRIVIAL("The exact solution is x = 0"),
+ CONVERGED("Ax - b is small enough, given atol, btol"),
+ LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"),
+ CONDITION("The estimate of cond(Abar) has exceeded condition limit"),
+ CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"),
+ LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"),
+ CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"),
+ ITERATION_LIMIT("The iteration limit has been reached");
+
+ private final String message;
+
+ StopCode(String message) {
+ this.message = message;
+ }
+
+ public String getMessage() {
+ return message;
+ }
+ }
+
+ public void setAtolerance(double aTolerance) {
+ this.aTolerance = aTolerance;
+ }
+
+ public void setBtolerance(double bTolerance) {
+ this.bTolerance = bTolerance;
+ }
+
+ public void setConditionLimit(double conditionLimit) {
+ this.conditionLimit = conditionLimit;
+ }
+
+ public void setIterationLimit(int iterationLimit) {
+ this.iterationLimit = iterationLimit;
+ }
+
+ public void setLocalSize(int localSize) {
+ this.localSize = localSize;
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public double getAtolerance() {
+ return aTolerance;
+ }
+
+ public double getBtolerance() {
+ return bTolerance;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java b/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
new file mode 100644
index 0000000..91528fc
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Interface for defining preconditioners used for improving the performance and/or stability of linear
+ * system solvers.
+ */
+public interface Preconditioner {
+
+ /**
+ * Preconditions the specified vector.
+ *
+ * @param v The vector to precondition.
+ * @return The preconditioned vector.
+ */
+ Vector precondition(Vector v);
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java b/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java
new file mode 100644
index 0000000..46354da
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.ssvd;
+
+import org.apache.mahout.math.CholeskyDecomposition;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomTrinaryMatrix;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Implements an in-memory version of stochastic projection based SVD. See SequentialOutOfCoreSvd
+ * for algorithm notes.
+ */
+public class SequentialBigSvd {
+ private final Matrix y;
+ private final CholeskyDecomposition cd1;
+ private final CholeskyDecomposition cd2;
+ private final SingularValueDecomposition svd;
+ private final Matrix b;
+
+
+ public SequentialBigSvd(Matrix A, int p) {
+ // Y = A * \Omega
+ y = A.times(new RandomTrinaryMatrix(A.columnSize(), p));
+
+ // R'R = Y' Y
+ cd1 = new CholeskyDecomposition(y.transpose().times(y));
+
+ // B = Q" A = (Y R^{-1} )' A
+ b = cd1.solveRight(y).transpose().times(A);
+
+ // L L' = B B'
+ cd2 = new CholeskyDecomposition(b.times(b.transpose()));
+
+ // U_0 D V_0' = L
+ svd = new SingularValueDecomposition(cd2.getL());
+ }
+
+ public Vector getSingularValues() {
+ return new DenseVector(svd.getSingularValues());
+ }
+
+ public Matrix getU() {
+ // U = (Y inv(R)) U_0
+ return cd1.solveRight(y).times(svd.getU());
+ }
+
+ public Matrix getV() {
+ // V = (B' inv(L')) V_0
+ return cd2.solveRight(b.transpose()).times(svd.getV());
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java b/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
new file mode 100644
index 0000000..d2c8434
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.Queue;
+
+/**
+ * Utility methods for working with log-likelihood
+ */
+public final class LogLikelihood {
+
+ private LogLikelihood() {
+ }
+
+ /**
+ * Calculates the unnormalized Shannon entropy. This is
+ *
+ * -sum x_i log x_i / N = -N sum x_i/N log x_i/N
+ *
+ * where N = sum x_i
+ *
+ * If the x's sum to 1, then this is the same as the normal
+ * expression. Leaving this un-normalized makes working with
+ * counts and computing the LLR easier.
+ *
+ * @return The entropy value for the elements
+ */
+ public static double entropy(long... elements) {
+ long sum = 0;
+ double result = 0.0;
+ for (long element : elements) {
+ Preconditions.checkArgument(element >= 0);
+ result += xLogX(element);
+ sum += element;
+ }
+ return xLogX(sum) - result;
+ }
+
+ private static double xLogX(long x) {
+ return x == 0 ? 0.0 : x * Math.log(x);
+ }
+
+ /**
+ * Merely an optimization for the common two argument case of {@link #entropy(long...)}
+ * @see #logLikelihoodRatio(long, long, long, long)
+ */
+ private static double entropy(long a, long b) {
+ return xLogX(a + b) - xLogX(a) - xLogX(b);
+ }
+
+ /**
+ * Merely an optimization for the common four argument case of {@link #entropy(long...)}
+ * @see #logLikelihoodRatio(long, long, long, long)
+ */
+ private static double entropy(long a, long b, long c, long d) {
+ return xLogX(a + b + c + d) - xLogX(a) - xLogX(b) - xLogX(c) - xLogX(d);
+ }
+
+ /**
+ * Calculates the Raw Log-likelihood ratio for two events, call them A and B. Then we have:
+ * <p/>
+ * <table border="1" cellpadding="5" cellspacing="0">
+ * <tbody><tr><td>&nbsp;</td><td>Event A</td><td>Everything but A</td></tr>
+ * <tr><td>Event B</td><td>A and B together (k_11)</td><td>B, but not A (k_12)</td></tr>
+ * <tr><td>Everything but B</td><td>A without B (k_21)</td><td>Neither A nor B (k_22)</td></tr></tbody>
+ * </table>
+ *
+ * @param k11 The number of times the two events occurred together
+ * @param k12 The number of times the second event occurred WITHOUT the first event
+ * @param k21 The number of times the first event occurred WITHOUT the second event
+ * @param k22 The number of times something else occurred (i.e. was neither of these events
+ * @return The raw log-likelihood ratio
+ *
+ * <p/>
+ * Credit to http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html for the table and the descriptions.
+ */
+ public static double logLikelihoodRatio(long k11, long k12, long k21, long k22) {
+ Preconditions.checkArgument(k11 >= 0 && k12 >= 0 && k21 >= 0 && k22 >= 0);
+ // note that we have counts here, not probabilities, and that the entropy is not normalized.
+ double rowEntropy = entropy(k11 + k12, k21 + k22);
+ double columnEntropy = entropy(k11 + k21, k12 + k22);
+ double matrixEntropy = entropy(k11, k12, k21, k22);
+ if (rowEntropy + columnEntropy < matrixEntropy) {
+ // round off error
+ return 0.0;
+ }
+ return 2.0 * (rowEntropy + columnEntropy - matrixEntropy);
+ }
+
+ /**
+ * Calculates the root log-likelihood ratio for two events.
+ * See {@link #logLikelihoodRatio(long, long, long, long)}.
+
+ * @param k11 The number of times the two events occurred together
+ * @param k12 The number of times the second event occurred WITHOUT the first event
+ * @param k21 The number of times the first event occurred WITHOUT the second event
+ * @param k22 The number of times something else occurred (i.e. was neither of these events
+ * @return The root log-likelihood ratio
+ *
+ * <p/>
+ * There is some more discussion here: http://s.apache.org/CGL
+ *
+ * And see the response to Wataru's comment here:
+ * http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html
+ */
+ public static double rootLogLikelihoodRatio(long k11, long k12, long k21, long k22) {
+ double llr = logLikelihoodRatio(k11, k12, k21, k22);
+ double sqrt = Math.sqrt(llr);
+ if ((double) k11 / (k11 + k12) < (double) k21 / (k21 + k22)) {
+ sqrt = -sqrt;
+ }
+ return sqrt;
+ }
+
+ /**
+ * Compares two sets of counts to see which items are interestingly over-represented in the first
+ * set.
+ * @param a The first counts.
+ * @param b The reference counts.
+ * @param maxReturn The maximum number of items to return. Use maxReturn >= a.elementSet.size() to return all
+ * scores above the threshold.
+ * @param threshold The minimum score for items to be returned. Use 0 to return all items more common
+ * in a than b. Use -Double.MAX_VALUE (not Double.MIN_VALUE !) to not use a threshold.
+ * @return A list of scored items with their scores.
+ */
+ public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a,
+ Multiset<T> b,
+ int maxReturn,
+ double threshold) {
+ int totalA = a.size();
+ int totalB = b.size();
+
+ Ordering<ScoredItem<T>> byScoreAscending = new Ordering<ScoredItem<T>>() {
+ @Override
+ public int compare(ScoredItem<T> tScoredItem, ScoredItem<T> tScoredItem1) {
+ return Double.compare(tScoredItem.score, tScoredItem1.score);
+ }
+ };
+ Queue<ScoredItem<T>> best = new PriorityQueue<>(maxReturn + 1, byScoreAscending);
+
+ for (T t : a.elementSet()) {
+ compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
+ }
+
+ // if threshold >= 0 we only iterate through a because anything not there can't be as or more common than in b.
+ if (threshold < 0) {
+ for (T t : b.elementSet()) {
+ // only items missing from a need be scored
+ if (a.count(t) == 0) {
+ compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
+ }
+ }
+ }
+
+ List<ScoredItem<T>> r = new ArrayList<>(best);
+ Collections.sort(r, byScoreAscending.reverse());
+ return r;
+ }
+
+ private static <T> void compareAndAdd(Multiset<T> a,
+ Multiset<T> b,
+ int maxReturn,
+ double threshold,
+ int totalA,
+ int totalB,
+ Queue<ScoredItem<T>> best,
+ T t) {
+ int kA = a.count(t);
+ int kB = b.count(t);
+ double score = rootLogLikelihoodRatio(kA, totalA - kA, kB, totalB - kB);
+ if (score >= threshold) {
+ ScoredItem<T> x = new ScoredItem<>(t, score);
+ best.add(x);
+ while (best.size() > maxReturn) {
+ best.poll();
+ }
+ }
+ }
+
+ public static final class ScoredItem<T> {
+ private final T item;
+ private final double score;
+
+ public ScoredItem(T item, double score) {
+ this.item = item;
+ this.score = score;
+ }
+
+ public double getScore() {
+ return score;
+ }
+
+ public T getItem() {
+ return item;
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java b/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java
new file mode 100644
index 0000000..54a0ec7
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+/**
+ * Computes an online average that is exponentially weighted toward recent time-embedded samples.
+ */
+public class OnlineExponentialAverage {
+
+ private final double alpha;
+ private double lastT;
+ private double s;
+ private double w;
+ private double t;
+
+ /**
+ * Creates an averager that has a specified time constant for discounting old data. The time
+ * constant, alpha, is the time at which an older sample is discounted to 1/e relative to current
+ * data. Roughly speaking, data that is more than 3*alpha old doesn't matter any more and data
+ * that is more recent than alpha/3 is about as important as current data.
+ *
+ * See http://tdunning.blogspot.com/2011/03/exponential-weighted-averages-with.html for a
+ * derivation. See http://tdunning.blogspot.com/2011/03/exponentially-weighted-averaging-for.html
+ * for the rate method.
+ *
+ * @param alpha The time constant for discounting old data and state.
+ */
+ public OnlineExponentialAverage(double alpha) {
+ this.alpha = alpha;
+ }
+
+ public void add(double t, double x) {
+ double pi = Math.exp(-(t - lastT) / alpha);
+ s = x + pi * s;
+ w = 1.0 + pi * w;
+ this.t = t - lastT + pi * this.t;
+ lastT = t;
+ }
+
+ public double mean() {
+ return s / w;
+ }
+
+ public double meanRate() {
+ return s / t;
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java b/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
new file mode 100644
index 0000000..793aa71
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+//import com.tdunning.math.stats.TDigest;
+
+/**
+ * Computes on-line estimates of mean, variance and all five quartiles (notably including the
+ * median). Since this is done in a completely incremental fashion (that is what is meant by
+ * on-line) estimates are available at any time and the amount of memory used is constant. Somewhat
+ * surprisingly, the quantile estimates are about as good as you would get if you actually kept all
+ * of the samples.
+ * <p/>
+ * The method used for mean and variance is Welford's method. See
+ * <p/>
+ * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm
+ * <p/>
+ * The method used for computing the quartiles is a simplified form of the stochastic approximation
+ * method described in the article "Incremental Quantile Estimation for Massive Tracking" by Chen,
+ * Lambert and Pinheiro
+ * <p/>
+ * See
+ * <p/>
+ * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.105.1580
+ */
+public class OnlineSummarizer {
+
+// private TDigest quantiles = TDigest.createDigest(100.0);
+
+ // mean and variance estimates
+ private double mean;
+ private double variance;
+
+ // number of samples seen so far
+ private int n;
+
+ public void add(double sample) {
+ n++;
+ double oldMean = mean;
+ mean += (sample - mean) / n;
+ double diff = (sample - mean) * (sample - oldMean);
+ variance += (diff - variance) / n;
+
+// quantiles.add(sample);
+ }
+
+ public int getCount() {
+ return n;
+ }
+
+ public double getMean() {
+ return mean;
+ }
+
+ public double getSD() {
+ return Math.sqrt(variance);
+ }
+
+// public double getMin() {
+// return getQuartile(0);
+// }
+//
+// public double getMax() {
+// return getQuartile(4);
+// }
+
+// public double getQuartile(int i) {
+// return quantiles.quantile(0.25 * i);
+// }
+//
+// public double quantile(double q) {
+// return quantiles.quantile(q);
+// }
+
+// public double getMedian() {
+// return getQuartile(2);
+// }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java b/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
new file mode 100644
index 0000000..b85458e
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public final class QRDecompositionTest extends MahoutTestCase {
+ @Test
+ public void randomMatrix() {
+ Matrix a = new DenseMatrix(60, 60).assign(Functions.random());
+ QRDecomposition qr = new QRDecomposition(a);
+
+ // how close is Q to actually being orthornormal?
+ double maxIdent = qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1);
+ assertEquals(0, maxIdent, 1.0e-13);
+
+ // how close is Q R to the original value of A?
+ Matrix z = qr.getQ().times(qr.getR()).minus(a);
+ double maxError = z.aggregate(Functions.MIN, Functions.ABS);
+ assertEquals(0, maxError, 1.0e-13);
+ }
+
+ @Test
+ public void rank1() {
+ Matrix x = new DenseMatrix(3, 3);
+ x.viewRow(0).assign(new double[]{1, 2, 3});
+ x.viewRow(1).assign(new double[]{2, 4, 6});
+ x.viewRow(2).assign(new double[]{3, 6, 9});
+
+ QRDecomposition qr = new QRDecomposition(x);
+ assertFalse(qr.hasFullRank());
+ assertEquals(0, new DenseVector(new double[]{3.741657, 7.483315, 11.22497}).aggregate(qr.getR().viewRow(0), Functions.PLUS, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double arg1, double arg2) {
+ return Math.abs(arg1) - Math.abs(arg2);
+ }
+ }), 1.0e-5);
+ }
+
+ @Test
+ public void fullRankTall() {
+ Matrix x = matrix();
+ QRDecomposition qr = new QRDecomposition(x);
+ assertTrue(qr.hasFullRank());
+ Matrix rRef = reshape(new double[]{
+ -2.99129686445138, 0, 0, 0, 0,
+ -0.0282260628674372, -2.38850244769059, 0, 0, 0,
+ 0.733739310355871, 1.48042000631646, 2.29051263117895, 0, 0,
+ -0.0394082168269326, 0.282829484207801, -0.00438521041803086, -2.90823198084203, 0,
+ 0.923669647838536, 1.76679276072492, 0.637690104222683, -0.225890909498753, -1.35732293800944},
+ 5, 5);
+ Matrix r = qr.getR();
+
+ // check identity down to sign
+ assertEquals(0, r.clone().assign(Functions.ABS).minus(rRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12);
+
+ Matrix qRef = reshape(new double[]{
+ -0.165178287646573, 0.0510035857637869, 0.13985915987379, -0.120173729496501,
+ -0.453198314345324, 0.644400679630493, -0.503117990820608, 0.24968739845381,
+ 0.323968339146224, -0.465266080134262, 0.276508948773268, -0.687909700644343,
+ 0.0544048888907195, -0.0166677718378263, 0.171309755790717, 0.310339001630029,
+ 0.674790532821663, 0.0058166082200493, -0.381707516461884, 0.300504956413142,
+ -0.105751091334003, 0.410450870871096, 0.31113446615821, 0.179338172684956,
+ 0.361951807617901, 0.763921725548796, 0.380327892605634, -0.287274944594054,
+ 0.0311604042556675, 0.0386096858143961, 0.0387156960650472, -0.232975755728917,
+ 0.0358178276684149, 0.173105775703199, 0.327321867815603, 0.328671945345279,
+ -0.36015879836344, -0.444261660176044, 0.09438499563253, 0.646216148583769
+ }, 8, 5);
+
+ printMatrix("qRef", qRef);
+
+ Matrix q = qr.getQ();
+ printMatrix("q", q);
+
+ assertEquals(0, q.clone().assign(Functions.ABS).minus(qRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12);
+
+ Matrix x1 = qr.solve(reshape(new double[]{
+ -0.0178247686747641, 0.68631714634098, -0.335464858468858, 1.50249941751569,
+ -0.669901640772149, -0.977025038942455, -1.18857546169856, -1.24792900492054
+ }, 8, 1));
+ Matrix xref = reshape(new double[]{
+ -0.0127440093664874, 0.655825940180799, -0.100755415991702, -0.0349559562697406,
+ -0.190744297762028
+ }, 5, 1);
+
+ printMatrix("x1", x1);
+ printMatrix("xref", xref);
+
+ assertEquals(xref, x1, 1.0e-8);
+ }
+
+ @Test
+ public void fullRankWide() {
+ Matrix x = matrix().transpose();
+ QRDecomposition qr = new QRDecomposition(x);
+ assertTrue(qr.hasFullRank());
+ Matrix rActual = qr.getR();
+
+ Matrix rRef = reshape(new double[]{
+ -2.42812464965842, 0, 0, 0, 0,
+ 0.303587286111356, -2.91663643494775, 0, 0, 0,
+ -0.201812474153156, -0.765485720168378, 1.09989373598954, 0, 0,
+ 1.47980701097885, -0.637545820524326, -1.55519859337935, 0.844655127991726, 0,
+ 0.0248883129453161, 0.00115010570270549, -0.236340588891252, -0.092924118200147, 1.42910099545547,
+ -1.1678472412429, 0.531245845248056, 0.351978196071514, -1.03241474816555, -2.20223861735426,
+ -0.887809959067632, 0.189731251982918, -0.504321849233586, 0.490484123999836, 1.21266692336743,
+ -0.633888169775463, 1.04738559065986, 0.284041239547031, 0.578183510077156, -0.942314870832456
+ }, 5, 8);
+ printMatrix("rRef", rRef);
+ printMatrix("rActual", rActual);
+ assertEquals(0, rActual.clone().assign(Functions.ABS).minus(rRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12);
+// assertEquals(rRef, rActual, 1.0e-8);
+
+ Matrix qRef = reshape(new double[]{
+ -0.203489262374627, 0.316761677948356, -0.784155643293468, 0.394321494579, -0.29641971170211,
+ 0.0311283614803723, -0.34755265020736, 0.137138511478328, 0.848579887681972, 0.373287266507375,
+ -0.39603700561249, -0.787812566647329, -0.377864833067864, -0.275080943427399, 0.0636764674878229,
+ 0.0763976893309043, -0.318551137554327, 0.286407036668598, 0.206004127289883, -0.876482672226889,
+ 0.89159476695423, -0.238213616975551, -0.376141107880836, -0.0794701657055114, 0.0227025098210165
+ }, 5, 5);
+
+ Matrix q = qr.getQ();
+
+ printMatrix("qRef", qRef);
+ printMatrix("q", q);
+
+ assertEquals(0, q.clone().assign(Functions.ABS).minus(qRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12);
+// assertEquals(qRef, q, 1.0e-8);
+
+ Matrix x1 = qr.solve(b());
+ Matrix xRef = reshape(new double[]{
+ -0.182580239668147, -0.437233627652114, 0.138787653097464, 0.672934739896228, -0.131420217069083, 0, 0, 0
+ }, 8, 1);
+
+ printMatrix("xRef", xRef);
+ printMatrix("x", x1);
+ assertEquals(xRef, x1, 1.0e-8);
+
+ assertEquals(x, qr.getQ().times(qr.getR()), 1.0e-15);
+ }
+
+ // TODO: the speedup constant should be checked and oddly, the times don't increase as the counts increase
+ @Ignore
+ public void fasterThanBefore() {
+
+ OnlineSummarizer s1 = new OnlineSummarizer();
+ OnlineSummarizer s2 = new OnlineSummarizer();
+
+ Matrix a = new DenseMatrix(60, 60).assign(Functions.random());
+
+ decompositionSpeedCheck(new Decomposer() {
+ @Override
+ public QR decompose(Matrix a) {
+ return new QRDecomposition(a);
+ }
+ }, s1, a, "new");
+
+ decompositionSpeedCheck(new Decomposer() {
+ @Override
+ public QR decompose(Matrix a) {
+ return new OldQRDecomposition(a);
+ }
+ }, s2, a, "old");
+
+ // should be much more than twice as fast. (originally was on s2.getMedian, but we factored out com.tdunning )
+ System.out.printf("Speedup is about %.1f times\n", s2.getMean() / s1.getMean());
+ assertTrue(s1.getMean() < 0.5 * s2.getMean());
+ }
+
+ private interface Decomposer {
+ QR decompose(Matrix a);
+ }
+
+ private static void decompositionSpeedCheck(Decomposer qrf, OnlineSummarizer s1, Matrix a, String label) {
+ int n = 0;
+ List<Integer> counts = Lists.newArrayList(10, 20, 50, 100, 200, 500);
+ for (int k : counts) {
+ double warmup = 0;
+ double other = 0;
+
+ n += k;
+ for (int i = 0; i < k; i++) {
+ QR qr = qrf.decompose(a);
+ warmup = Math.max(warmup, qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
+ Matrix z = qr.getQ().times(qr.getR()).minus(a);
+ other = Math.max(other, z.aggregate(Functions.MIN, Functions.ABS));
+ }
+
+ double maxIdent = 0;
+ double maxError = 0;
+
+ long t0 = System.nanoTime();
+ for (int i = 0; i < n; i++) {
+ QR qr = qrf.decompose(a);
+
+ maxIdent = Math.max(maxIdent, qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
+ Matrix z = qr.getQ().times(qr.getR()).minus(a);
+ maxError = Math.max(maxError, z.aggregate(Functions.MIN, Functions.ABS));
+ }
+ long t1 = System.nanoTime();
+ if (k > 100) {
+ s1.add(t1 - t0);
+ }
+ System.out.printf("%s %d\t%.1f\t%g\t%g\t%g\n", label, n, (t1 - t0) / 1.0e3 / n, maxIdent, maxError, warmup);
+ }
+ }
+
+ private static void assertEquals(Matrix ref, Matrix actual, double epsilon) {
+ assertEquals(0, ref.minus(actual).aggregate(Functions.MAX, Functions.ABS), epsilon);
+ }
+
+ private static void printMatrix(String name, Matrix m) {
+ int rows = m.numRows();
+ int columns = m.numCols();
+ System.out.printf("%s - %d x %d\n", name, rows, columns);
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < columns; j++) {
+ System.out.printf("%10.5f", m.get(i, j));
+ }
+ System.out.printf("\n");
+ }
+ System.out.printf("\n");
+ System.out.printf("\n");
+ }
+
+ private static Matrix matrix() {
+ double[] values = {
+ 0.494097293912641, -0.152566866170993, -0.418360266395271, 0.359475300232312,
+ 1.35565069667582, -1.92759373242903, 1.50497526839076, -0.746889132087904,
+ -0.769136838293565, 1.10984954080986, -0.664389974392489, 1.6464660350229,
+ -0.11715420616969, 0.0216221197371269, -0.394972730980765, -0.748293157213142,
+ 1.90402764664962, -0.638042862848559, -0.362336344669668, -0.418261074380526,
+ -0.494211543128429, 1.38828971158414, 0.597110366867923, 1.05341387608687,
+ -0.957461740877418, -2.35528802598249, -1.03171458944128, 0.644319090271635,
+ -0.0569108993041965, -0.14419465550881, -0.0456801828174936,
+ 0.754694392571835, 0.719744008628535, -1.17873249802301, -0.155887528905918,
+ -1.5159868405466, 0.0918931582603128, 1.42179027361583, -0.100495054250176,
+ 0.0687986548485584
+ };
+ return reshape(values, 8, 5);
+ }
+
+ private static Matrix reshape(double[] values, int rows, int columns) {
+ Matrix m = new DenseMatrix(rows, columns);
+ int i = 0;
+ for (double v : values) {
+ m.set(i % rows, i / rows, v);
+ i++;
+ }
+ return m;
+ }
+
+ private static Matrix b() {
+ return reshape(new double[]
+ {-0.0178247686747641, 0.68631714634098, -0.335464858468858, 1.50249941751569, -0.669901640772149}, 5, 1);
+ }
+}
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java b/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java
new file mode 100644
index 0000000..c9e4026
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+
+//To launch this test only : mvn test -Dtest=org.apache.mahout.math.TestSingularValueDecomposition
+public final class TestSingularValueDecomposition extends MahoutTestCase {
+
+  private final double[][] testSquare = {
+      { 24.0 / 25.0, 43.0 / 25.0 },
+      { 57.0 / 25.0, 24.0 / 25.0 }
+  };
+
+  private final double[][] testNonSquare = {
+      { -540.0 / 625.0, 963.0 / 625.0, -216.0 / 625.0 },
+      { -1730.0 / 625.0, -744.0 / 625.0, 1008.0 / 625.0 },
+      { -720.0 / 625.0, 1284.0 / 625.0, -288.0 / 625.0 },
+      { -360.0 / 625.0, 192.0 / 625.0, 1756.0 / 625.0 },
+  };
+
+  /** Tolerance used for the {@code Algebra.getNorm} comparisons below (1e-13). */
+  private static final double NORM_TOLERANCE = 1.0e-13;
+
+  @Test
+  public void testMoreRows() {
+    double[] singularValues = { 123.456, 2.3, 1.001, 0.999 };
+    checkSingularValuesRecovered(singularValues, singularValues.length + 2, singularValues.length);
+  }
+
+  @Test
+  public void testMoreColumns() {
+    double[] singularValues = { 123.456, 2.3, 1.001, 0.999 };
+    checkSingularValuesRecovered(singularValues, singularValues.length, singularValues.length + 2);
+  }
+
+  /**
+   * Builds a random rows x columns matrix with the given singular values and checks that the SVD
+   * returns exactly that many singular values, matching them in order to within 1e-10.
+   */
+  private static void checkSingularValuesRecovered(double[] singularValues, int rows, int columns) {
+    Random r = RandomUtils.getRandom();
+    SingularValueDecomposition svd =
+        new SingularValueDecomposition(createTestMatrix(r, rows, columns, singularValues));
+    double[] computedSV = svd.getSingularValues();
+    assertEquals(singularValues.length, computedSV.length);
+    for (int i = 0; i < singularValues.length; ++i) {
+      assertEquals(singularValues[i], computedSV[i], 1.0e-10);
+    }
+  }
+
+  /** test dimensions of U (m x m), S (m x n) and V (n x n) for a square input */
+  @Test
+  public void testDimensions() {
+    Matrix matrix = new DenseMatrix(testSquare);
+    int m = matrix.numRows();
+    int n = matrix.numCols();
+    SingularValueDecomposition svd = new SingularValueDecomposition(matrix);
+    assertEquals(m, svd.getU().numRows());
+    assertEquals(m, svd.getU().numCols());
+    assertEquals(m, svd.getS().numRows());
+    assertEquals(n, svd.getS().numCols());
+    assertEquals(n, svd.getV().numRows());
+    assertEquals(n, svd.getV().numCols());
+
+  }
+
+  /**
+   * Test based on a dimension 4 Hadamard-derived matrix with known singular values 16, 8, 4, 2;
+   * also checks {@code getCovariance} for the arguments 0.0 and 6.0.
+   */
+  @Test
+  public void testHadamard() {
+    Matrix matrix = new DenseMatrix(new double[][] {
+        {15.0 / 2.0, 5.0 / 2.0, 9.0 / 2.0, 3.0 / 2.0 },
+        { 5.0 / 2.0, 15.0 / 2.0, 3.0 / 2.0, 9.0 / 2.0 },
+        { 9.0 / 2.0, 3.0 / 2.0, 15.0 / 2.0, 5.0 / 2.0 },
+        { 3.0 / 2.0, 9.0 / 2.0, 5.0 / 2.0, 15.0 / 2.0 }
+    });
+    SingularValueDecomposition svd = new SingularValueDecomposition(matrix);
+    assertEquals(16.0, svd.getSingularValues()[0], 1.0e-14);
+    assertEquals( 8.0, svd.getSingularValues()[1], 1.0e-14);
+    assertEquals( 4.0, svd.getSingularValues()[2], 1.0e-14);
+    assertEquals( 2.0, svd.getSingularValues()[3], 1.0e-14);
+
+    Matrix fullCovariance = new DenseMatrix(new double[][] {
+        { 85.0 / 1024, -51.0 / 1024, -75.0 / 1024, 45.0 / 1024 },
+        { -51.0 / 1024, 85.0 / 1024, 45.0 / 1024, -75.0 / 1024 },
+        { -75.0 / 1024, 45.0 / 1024, 85.0 / 1024, -51.0 / 1024 },
+        { 45.0 / 1024, -75.0 / 1024, -51.0 / 1024, 85.0 / 1024 }
+    });
+
+    assertEquals(0.0,Algebra.getNorm(fullCovariance.minus(svd.getCovariance(0.0))),1.0e-14);
+
+
+    Matrix halfCovariance = new DenseMatrix(new double[][] {
+        { 5.0 / 1024, -3.0 / 1024, 5.0 / 1024, -3.0 / 1024 },
+        { -3.0 / 1024, 5.0 / 1024, -3.0 / 1024, 5.0 / 1024 },
+        { 5.0 / 1024, -3.0 / 1024, 5.0 / 1024, -3.0 / 1024 },
+        { -3.0 / 1024, 5.0 / 1024, -3.0 / 1024, 5.0 / 1024 }
+    });
+    assertEquals(0.0,Algebra.getNorm(halfCovariance.minus(svd.getCovariance(6.0))),1.0e-14);
+
+  }
+
+  /** test A = USVt */
+  @Test
+  public void testAEqualUSVt() {
+    checkAEqualUSVt(new DenseMatrix(testSquare));
+    checkAEqualUSVt(new DenseMatrix(testNonSquare));
+    checkAEqualUSVt(new DenseMatrix(testNonSquare).transpose());
+  }
+
+  /** Asserts that U * S * V' reconstructs {@code matrix} to within {@link #NORM_TOLERANCE}. */
+  public static void checkAEqualUSVt(Matrix matrix) {
+    SingularValueDecomposition svd = new SingularValueDecomposition(matrix);
+    Matrix u = svd.getU();
+    Matrix s = svd.getS();
+    Matrix v = svd.getV();
+
+    // pad S with a zero row and U with a zero column when S has fewer rows than the input; the
+    // product U * S is unchanged by this padding
+    if (s.numRows()<matrix.numRows()) {
+
+      Matrix sp = new DenseMatrix(s.numRows()+1,s.numCols());
+      Matrix up = new DenseMatrix(u.numRows(),u.numCols()+1);
+
+
+      for (int i = 0; i < u.numRows(); i++) {
+        for (int j = 0; j < u.numCols(); j++) {
+          up.set(i, j, u.get(i, j));
+        }
+      }
+
+      for (int i = 0; i < s.numRows(); i++) {
+        for (int j = 0; j < s.numCols(); j++) {
+          sp.set(i, j, s.get(i, j));
+        }
+      }
+
+      u = up;
+      s = sp;
+    }
+
+    double norm = Algebra.getNorm(u.times(s).times(v.transpose()).minus(matrix));
+    assertEquals(0, norm, NORM_TOLERANCE);
+
+  }
+
+  /** test that U is orthogonal */
+  @Test
+  public void testUOrthogonal() {
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testSquare)).getU());
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare)).getU());
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare).transpose()).getU());
+  }
+
+  /** test that V is orthogonal */
+  @Test
+  public void testVOrthogonal() {
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testSquare)).getV());
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare)).getV());
+    checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare).transpose()).getV());
+  }
+
+  /** Asserts that m' * m is the identity to within {@link #NORM_TOLERANCE}. */
+  public static void checkOrthogonal(Matrix m) {
+    Matrix mTm = m.transpose().times(m);
+    Matrix id = new DenseMatrix(mTm.numRows(),mTm.numRows());
+    for (int i = 0; i < mTm.numRows(); i++) {
+      id.set(i, i, 1);
+    }
+    assertEquals(0, Algebra.getNorm(mTm.minus(id)), NORM_TOLERANCE);
+  }
+
+  /** test matrices values */
+  @Test
+  public void testMatricesValues1() {
+    SingularValueDecomposition svd =
+        new SingularValueDecomposition(new DenseMatrix(testSquare));
+    Matrix uRef = new DenseMatrix(new double[][] {
+        { 3.0 / 5.0, 4.0 / 5.0 },
+        { 4.0 / 5.0, -3.0 / 5.0 }
+    });
+    Matrix sRef = new DenseMatrix(new double[][] {
+        { 3.0, 0.0 },
+        { 0.0, 1.0 }
+    });
+    Matrix vRef = new DenseMatrix(new double[][] {
+        { 4.0 / 5.0, -3.0 / 5.0 },
+        { 3.0 / 5.0, 4.0 / 5.0 }
+    });
+
+    // check values against known references
+    Matrix u = svd.getU();
+
+    assertEquals(0, Algebra.getNorm(u.minus(uRef)), NORM_TOLERANCE);
+    Matrix s = svd.getS();
+    assertEquals(0, Algebra.getNorm(s.minus(sRef)), NORM_TOLERANCE);
+    Matrix v = svd.getV();
+    assertEquals(0, Algebra.getNorm(v.minus(vRef)), NORM_TOLERANCE);
+  }
+
+
+  /** test condition number */
+  @Test
+  public void testConditionNumber() {
+    SingularValueDecomposition svd =
+        new SingularValueDecomposition(new DenseMatrix(testSquare));
+    // tolerance is 1.5e-15 rather than 1.0e-15, presumably to absorb rounding in cond()
+    assertEquals(3.0, svd.cond(), 1.5e-15);
+  }
+
+  /** Regression test: the decomposition of the matrix in hanging-svd.tsv must terminate and be accurate. */
+  @Test
+  public void testSvdHang() throws IOException, InterruptedException, ExecutionException, TimeoutException {
+    System.out.printf("starting hanging-svd\n");
+    final Matrix m = readTsv("hanging-svd.tsv");
+    SingularValueDecomposition svd = new SingularValueDecomposition(m);
+    assertEquals(0, m.minus(svd.getU().times(svd.getS()).times(svd.getV().transpose())).aggregate(Functions.PLUS, Functions.ABS), 1e-10);
+    System.out.printf("No hang\n");
+  }
+
+  /**
+   * Reads a tab-separated matrix from the classpath resource {@code name}. The column count is
+   * taken from the first line.
+   */
+  Matrix readTsv(String name) throws IOException {
+    Splitter onTab = Splitter.on("\t");
+    List<String> lines = Resources.readLines((Resources.getResource(name)), Charsets.UTF_8);
+    int rows = lines.size();
+    int columns = Iterables.size(onTab.split(lines.get(0)));
+    Matrix r = new DenseMatrix(rows, columns);
+    int row = 0;
+    for (String line : lines) {
+      Iterable<String> values = onTab.split(line);
+      int column = 0;
+      for (String value : values) {
+        r.set(row, column, Double.parseDouble(value));
+        column++;
+      }
+      row++;
+    }
+    return r;
+  }
+
+
+  /** Returns U * D * V for random orthogonal U, V and a rows x columns diagonal D of singularValues. */
+  private static Matrix createTestMatrix(Random r, int rows, int columns, double[] singularValues) {
+    Matrix u = createOrthogonalMatrix(r, rows);
+    Matrix d = createDiagonalMatrix(singularValues, rows, columns);
+    Matrix v = createOrthogonalMatrix(r, columns);
+    return u.times(d).times(v);
+  }
+
+
+  /**
+   * Builds a random size x size matrix with orthonormal rows by Gram-Schmidt; a row is regenerated
+   * when its squared norm before normalization is too small.
+   */
+  public static Matrix createOrthogonalMatrix(Random r, int size) {
+
+    double[][] data = new double[size][size];
+
+    for (int i = 0; i < size; ++i) {
+      double[] dataI = data[i];
+      double norm2;
+      do {
+
+        // generate randomly row I
+        for (int j = 0; j < size; ++j) {
+          dataI[j] = 2 * r.nextDouble() - 1;
+        }
+
+        // project the row in the subspace orthogonal to previous rows
+        for (int k = 0; k < i; ++k) {
+          double[] dataK = data[k];
+          double dotProduct = 0;
+          for (int j = 0; j < size; ++j) {
+            dotProduct += dataI[j] * dataK[j];
+          }
+          for (int j = 0; j < size; ++j) {
+            dataI[j] -= dotProduct * dataK[j];
+          }
+        }
+
+        // normalize the row
+        norm2 = 0;
+        for (double dataIJ : dataI) {
+          norm2 += dataIJ * dataIJ;
+        }
+        double inv = 1.0 / Math.sqrt(norm2);
+        for (int j = 0; j < size; ++j) {
+          dataI[j] *= inv;
+        }
+
+      } while (norm2 * size < 0.01);
+    }
+
+    return new DenseMatrix(data);
+
+  }
+
+  /** Returns a rows x columns matrix with {@code diagonal} on its main diagonal and zeros elsewhere. */
+  public static Matrix createDiagonalMatrix(double[] diagonal, int rows, int columns) {
+    double[][] dData = new double[rows][columns];
+    for (int i = 0; i < Math.min(rows, columns); ++i) {
+      dData[i][i] = diagonal[i];
+    }
+    return new DenseMatrix(dData);
+  }
+
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java b/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
new file mode 100644
index 0000000..95b19ad
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import java.util.Arrays;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.junit.Test;
+
+public class AlternatingLeastSquaresSolverTest extends MahoutTestCase {
+
+  /**
+   * Compares {@code getYtransposeY} with a directly computed Y'Y, repeatedly with 4 threads (to
+   * shake out race conditions) and once with a single thread.
+   */
+  @Test
+  public void testYtY() {
+
+    double[][] testMatrix = {
+        { 1, 2, 3, 4, 5 },
+        { 1, 2, 3, 4, 5 },
+        { 1, 2, 3, 4, 5 },
+        { 1, 2, 3, 4, 5 },
+        { 1, 2, 3, 4, 5 }};
+
+    double[][] testMatrix2 = {
+        { 1, 2, 3, 4, 5, 6 },
+        { 5, 4, 3, 2, 1, 7 },
+        { 1, 2, 3, 4, 5, 8 },
+        { 1, 2, 3, 4, 5, 8 },
+        { 11, 12, 13, 20, 27, 8 }};
+
+    double[][][] testData = { testMatrix, testMatrix2 };
+
+    for (double[][] data : testData) {
+      Matrix matrixToTest = new DenseMatrix(data);
+
+      //test for race conditions by trying a few times
+      for (int j = 0; j < 100; j++) {
+        validateYtY(matrixToTest, 4);
+      }
+
+      //one thread @ a time test
+      validateYtY(matrixToTest, 1);
+    }
+
+  }
+
+  /** Asserts that the solver's Y'Y equals {@code matrixToTest' * matrixToTest} exactly. */
+  private void validateYtY(Matrix matrixToTest, int numThreads) {
+
+    OpenIntObjectHashMap<Vector> matrixToTestAsRowVectors = asRowVectors(matrixToTest);
+    ImplicitFeedbackAlternatingLeastSquaresSolver solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(
+        matrixToTest.columnSize(), 1, 1, matrixToTestAsRowVectors, numThreads);
+
+    Matrix yTy = matrixToTest.transpose().times(matrixToTest);
+    Matrix shouldMatchyTy = solver.getYtransposeY(matrixToTestAsRowVectors);
+
+    for (int row = 0; row < yTy.rowSize(); row++) {
+      for (int column = 0; column < yTy.columnSize(); column++) {
+        assertEquals(yTy.getQuick(row, column), shouldMatchyTy.getQuick(row, column), 0);
+      }
+    }
+  }
+
+  /** Copies each row of {@code matrix} into a map keyed by row index. */
+  private OpenIntObjectHashMap<Vector> asRowVectors(Matrix matrix) {
+    OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<>();
+    for (int row = 0; row < matrix.numRows(); row++) {
+      rows.put(row, matrix.viewRow(row).clone());
+    }
+    return rows;
+  }
+
+  @Test
+  public void addLambdaTimesNuiTimesE() {
+    int nui = 5;
+    double lambda = 0.2;
+    Matrix matrix = new SparseMatrix(5, 5);
+
+    AlternatingLeastSquaresSolver.addLambdaTimesNuiTimesE(matrix, lambda, nui);
+
+    // lambda * nui = 0.2 * 5 = 1.0 on every diagonal entry
+    for (int n = 0; n < 5; n++) {
+      assertEquals(1.0, matrix.getQuick(n, n), EPSILON);
+    }
+  }
+
+  @Test
+  public void createMiIi() {
+    Vector f1 = new DenseVector(new double[] { 1, 2, 3 });
+    Vector f2 = new DenseVector(new double[] { 4, 5, 6 });
+
+    Matrix miIi = AlternatingLeastSquaresSolver.createMiIi(Arrays.asList(f1, f2), 3);
+
+    assertEquals(1.0, miIi.getQuick(0, 0), EPSILON);
+    assertEquals(2.0, miIi.getQuick(1, 0), EPSILON);
+    assertEquals(3.0, miIi.getQuick(2, 0), EPSILON);
+    assertEquals(4.0, miIi.getQuick(0, 1), EPSILON);
+    assertEquals(5.0, miIi.getQuick(1, 1), EPSILON);
+    assertEquals(6.0, miIi.getQuick(2, 1), EPSILON);
+  }
+
+  @Test
+  public void createRiIiMaybeTransposed() {
+    // cardinality 6 so that indices 1, 3 and 5 are within range
+    Vector ratings = new SequentialAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    Matrix riIiMaybeTransposed = AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);
+    // expected: a 3 x 1 column holding the three ratings in index order; dimensions are compared
+    // exactly (a delta would let wrong sizes pass)
+    assertEquals(1, riIiMaybeTransposed.numCols());
+    assertEquals(3, riIiMaybeTransposed.numRows());
+
+    assertEquals(1.0, riIiMaybeTransposed.getQuick(0, 0), EPSILON);
+    assertEquals(3.0, riIiMaybeTransposed.getQuick(1, 0), EPSILON);
+    assertEquals(5.0, riIiMaybeTransposed.getQuick(2, 0), EPSILON);
+  }
+
+  /** A non-sequential-access vector must be rejected with an IllegalArgumentException. */
+  @Test(expected = IllegalArgumentException.class)
+  public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() {
+    Vector ratings = new RandomAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java b/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java
new file mode 100644
index 0000000..13baad8
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.decomposer.lanczos.LanczosState;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Random;
+
+public abstract class SolverTest extends MahoutTestCase {
+  private static final Logger log = LoggerFactory.getLogger(SolverTest.class);
+
+  /** Same as {@link #assertOrthonormal(Matrix, double)} with an error margin of 1e-6. */
+  public static void assertOrthonormal(Matrix eigens) {
+    assertOrthonormal(eigens, 1.0e-6);
+  }
+
+  /**
+   * Asserts that every non-zero row of {@code currentEigens} has unit norm to within
+   * {@code errorMargin}. Pairs of distinct non-zero rows whose dot product exceeds the margin are
+   * logged, not failed.
+   */
+  public static void assertOrthonormal(Matrix currentEigens, double errorMargin) {
+    List<String> nonOrthogonals = Lists.newArrayList();
+    for (int i = 0; i < currentEigens.numRows(); i++) {
+      Vector ei = currentEigens.viewRow(i);
+      for (int j = 0; j <= i; j++) {
+        Vector ej = currentEigens.viewRow(j);
+        if (ei.norm(2) == 0 || ej.norm(2) == 0) {
+          continue;
+        }
+        double dot = ei.dot(ej);
+        if (i == j) {
+          assertTrue("not norm 1 : " + dot + " (eigen #" + i + ')', Math.abs(1.0 - dot) < errorMargin);
+        } else if (Math.abs(dot) > errorMargin) {
+          log.info("not orthogonal : {} (eigens {}, {})", dot, i, j);
+          nonOrthogonals.add("(" + i + ',' + j + ')');
+        }
+      }
+    }
+    // summarize once after all pairs have been examined, rather than once per row
+    if (!nonOrthogonals.isEmpty()) {
+      log.info("{}:{}", nonOrthogonals.size(), nonOrthogonals);
+    }
+  }
+
+  /**
+   * Asserts that the first {@code state.getIterationNumber()} non-zero right singular vectors of
+   * {@code state} have unit norm to within 1e-5; non-orthogonal pairs are logged, not failed.
+   */
+  public static void assertOrthonormal(LanczosState state) {
+    double errorMargin = 1.0e-5;
+    List<String> nonOrthogonals = Lists.newArrayList();
+    for (int i = 0; i < state.getIterationNumber(); i++) {
+      Vector ei = state.getRightSingularVector(i);
+      for (int j = 0; j <= i; j++) {
+        Vector ej = state.getRightSingularVector(j);
+        if (ei.norm(2) == 0 || ej.norm(2) == 0) {
+          continue;
+        }
+        double dot = ei.dot(ej);
+        if (i == j) {
+          assertTrue("not norm 1 : " + dot + " (eigen #" + i + ')', Math.abs(1.0 - dot) < errorMargin);
+        } else if (Math.abs(dot) > errorMargin) {
+          log.info("not orthogonal : {} (eigens {}, {})", dot, i, j);
+          nonOrthogonals.add("(" + i + ',' + j + ')');
+        }
+      }
+    }
+    // summarize once after all pairs have been examined, rather than once per row
+    if (!nonOrthogonals.isEmpty()) {
+      log.info("{}:{}", nonOrthogonals.size(), nonOrthogonals);
+    }
+  }
+
+  /** Checks every row of {@code eigens} with {@link #assertEigen(int, Vector, VectorIterable, double, boolean)}. */
+  public static void assertEigen(Matrix eigens, VectorIterable corpus, double errorMargin, boolean isSymmetric) {
+    assertEigen(eigens, corpus, eigens.numRows(), errorMargin, isSymmetric);
+  }
+
+  /** Checks the first {@code numEigensToCheck} rows of {@code eigens} as eigenvectors of {@code corpus}. */
+  public static void assertEigen(Matrix eigens,
+                                 VectorIterable corpus,
+                                 int numEigensToCheck,
+                                 double errorMargin,
+                                 boolean isSymmetric) {
+    for (int i = 0; i < numEigensToCheck; i++) {
+      Vector e = eigens.viewRow(i);
+      assertEigen(i, e, corpus, errorMargin, isSymmetric);
+    }
+  }
+
+  /**
+   * Asserts that {@code e} is (nearly) an eigenvector: the cosine between {@code e} and its image
+   * under the corpus ({@code times} if symmetric, otherwise {@code timesSquared}) must be within
+   * {@code errorMargin} of +/-1. Zero vectors are skipped.
+   */
+  public static void assertEigen(int i, Vector e, VectorIterable corpus, double errorMargin,
+      boolean isSymmetric) {
+    if (e.getLengthSquared() == 0) {
+      return;
+    }
+    Vector afterMultiply = isSymmetric ? corpus.times(e) : corpus.timesSquared(e);
+    double dot = afterMultiply.dot(e);
+    double afterNorm = afterMultiply.getLengthSquared();
+    double error = 1 - Math.abs(dot / Math.sqrt(afterNorm * e.getLengthSquared()));
+    log.info("the eigen-error: {} for eigen {}", error, i);
+    assertTrue("Error: " + error + " too high! (for eigen " + i + ')', Math.abs(error) < errorMargin);
+  }
+
+  /**
+   * Builds up a random sparse matrix, with sometimes repeated rows. Randomness comes from
+   * {@code RandomUtils.getRandom()}; it is reproducible only if that generator is seeded
+   * deterministically for tests (assumed, not checked here).
+   */
+  public static Matrix randomSequentialAccessSparseMatrix(int numRows,
+                                                          int nonNullRows,
+                                                          int numCols,
+                                                          int entriesPerRow,
+                                                          double entryMean) {
+    Matrix m = new SparseRowMatrix(numRows, numCols);
+    Random r = RandomUtils.getRandom();
+    for (int i = 0; i < nonNullRows; i++) {
+      Vector v = new SequentialAccessSparseVector(numCols);
+      for (int j = 0; j < entriesPerRow; j++) {
+        int col = r.nextInt(numCols);
+        double val = r.nextGaussian();
+        v.set(col, val * entryMean);
+      }
+      // drawn unconditionally so the random sequence does not depend on which branch is taken
+      int c = r.nextInt(numRows);
+      if (r.nextBoolean() || numRows == nonNullRows) {
+        m.assignRow(numRows == nonNullRows ? i : c, v);
+      } else {
+        // copy an existing non-empty row into row c to create a repeated row
+        Vector other = m.viewRow(r.nextInt(numRows));
+        if (other != null && other.getLengthSquared() > 0) {
+          m.assignRow(c, other.clone());
+        }
+      }
+    }
+    return m;
+  }
+
+  /**
+   * Builds a numRows x numCols matrix whose row {@code k} is a random Gaussian direction scaled to
+   * norm {@code 1 / (k + 1)}; when {@code symmetric} the result is {@code M * M'} instead.
+   */
+  public static Matrix randomHierarchicalMatrix(int numRows, int numCols, boolean symmetric) {
+    Matrix matrix = new DenseMatrix(numRows, numCols);
+    // TODO rejigger tests so that it doesn't expect this particular seed
+    Random r = new Random(1234L);
+    for (int row = 0; row < numRows; row++) {
+      Vector v = new DenseVector(numCols);
+      for (int col = 0; col < numCols; col++) {
+        double val = r.nextGaussian();
+        v.set(col, val);
+      }
+      v.assign(Functions.MULT, 1/((row + 1) * v.norm(2)));
+      matrix.assignRow(row, v);
+    }
+    if (symmetric) {
+      return matrix.times(matrix.transpose());
+    }
+    return matrix;
+  }
+
+  /** Square symmetric variant of {@link #randomHierarchicalMatrix(int, int, boolean)}. */
+  public static Matrix randomHierarchicalSymmetricMatrix(int size) {
+    return randomHierarchicalMatrix(size, size, true);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java b/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java
new file mode 100644
index 0000000..56ea4f6
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java
@@ -0,0 +1,207 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+
+import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
+import org.apache.mahout.math.decomposer.SolverTest;
+import org.junit.Test;
+
+/**
+ * This test is woefully inadequate, and also requires tons of memory, because it's part
+ * unit test, part performance test, and part comparison test (between the Hebbian and Lanczos
+ * approaches).
+ * TODO: make better.
+ */
+public final class TestHebbianSolver extends SolverTest {
+
+  /**
+   * Runs the Hebbian solver with a desired rank of 10.
+   *
+   * @return time spent solving, in milliseconds
+   */
+  public static long timeSolver(Matrix corpus,
+                                double convergence,
+                                int maxNumPasses,
+                                TrainingState state) {
+    return timeSolver(corpus, convergence, maxNumPasses, 10, state);
+  }
+
+  /**
+   * Solves for {@code desiredRank} eigenvectors of {@code corpus}, copies the resulting eigenvectors
+   * and eigenvalues into {@code state}, and asserts that {@code desiredRank} eigenvectors were
+   * produced.
+   *
+   * @return time spent in {@code HebbianSolver.solve}, in milliseconds
+   */
+  public static long timeSolver(Matrix corpus,
+                                double convergence,
+                                int maxNumPasses,
+                                int desiredRank,
+                                TrainingState state) {
+    HebbianUpdater updater = new HebbianUpdater();
+    AsyncEigenVerifier verifier = new AsyncEigenVerifier();
+    try {
+      HebbianSolver solver = new HebbianSolver(updater, verifier, convergence, maxNumPasses);
+      long start = System.nanoTime();
+      TrainingState finalState = solver.solve(corpus, desiredRank);
+      // Measure only the solve itself, not the bookkeeping below.
+      long elapsedNanos = System.nanoTime() - start;
+      assertNotNull(finalState);
+      state.setCurrentEigens(finalState.getCurrentEigens());
+      state.setCurrentEigenValues(finalState.getCurrentEigenValues());
+      // JUnit convention is (expected, actual) so failure messages read correctly.
+      assertEquals(desiredRank, state.getCurrentEigens().numRows());
+      return elapsedNanos / 1000000L;
+    } finally {
+      // Close the verifier even when solving or an assertion fails.
+      verifier.close();
+    }
+  }
+
+
+
+  public static long timeSolver(Matrix corpus, TrainingState state) {
+    return timeSolver(corpus, state, 10);
+  }
+
+  public static long timeSolver(Matrix corpus, TrainingState state, int rank) {
+    return timeSolver(corpus, 0.01, 20, rank, state);
+  }
+
+  @Test
+  public void testHebbianSolver() {
+    int numColumns = 800;
+    Matrix corpus = randomSequentialAccessSparseMatrix(1000, 900, numColumns, 30, 1.0);
+    int rank = 50;
+    Matrix eigens = new DenseMatrix(rank, numColumns);
+    TrainingState state = new TrainingState(eigens, null);
+    long optimizedTime = timeSolver(corpus,
+                                    0.00001,
+                                    5,
+                                    rank,
+                                    state);
+    eigens = state.getCurrentEigens();
+    assertEigen(eigens, corpus, 0.05, false);
+    assertOrthonormal(eigens, 1.0e-6);
+    System.out.println("Avg solving (Hebbian) time in ms: " + optimizedTime);
+  }
+
+ /*
+ public void testSolverWithSerialization() throws Exception
+ {
+ _corpusProjectionsVectorFactory = new DenseMapVectorFactory();
+ _eigensVectorFactory = new DenseMapVectorFactory();
+
+ timeSolver(TMP_EIGEN_DIR,
+ 0.001,
+ 5,
+ new TrainingState(null, null));
+
+ File eigenDir = new File(TMP_EIGEN_DIR + File.separator + HebbianSolver.EIGEN_VECT_DIR);
+ DiskBufferedDoubleMatrix eigens = new DiskBufferedDoubleMatrix(eigenDir, 10);
+
+ DoubleMatrix inMemoryMatrix = new HashMapDoubleMatrix(_corpusProjectionsVectorFactory, eigens);
+
+ for (Entry<Integer, MapVector> diskEntry : eigens)
+ {
+ for (Entry<Integer, MapVector> inMemoryEntry : inMemoryMatrix)
+ {
+ if (diskEntry.getKey() - inMemoryEntry.getKey() == 0)
+ {
+ assertTrue("vector with index : " + diskEntry.getKey() + " is not the same on disk as in memory",
+ Math.abs(1 - diskEntry.getValue().dot(inMemoryEntry.getValue())) < 1e-6);
+ }
+ else
+ {
+ assertTrue("vector with index : " + diskEntry.getKey()
+ + " is not orthogonal to memory vect with index : " + inMemoryEntry.getKey(),
+ Math.abs(diskEntry.getValue().dot(inMemoryEntry.getValue())) < 1e-6);
+ }
+ }
+ }
+ File corpusDir = new File(TMP_EIGEN_DIR + File.separator + "corpus");
+ corpusDir.mkdir();
+ // TODO: persist to disk?
+ // DiskBufferedDoubleMatrix.persistChunk(corpusDir, corpus, true);
+ // eigens.delete();
+
+ // DiskBufferedDoubleMatrix.delete(new File(TMP_EIGEN_DIR));
+ }
+ */
+/*
+ public void testHebbianVersusLanczos() throws Exception
+ {
+ _corpusProjectionsVectorFactory = new DenseMapVectorFactory();
+ _eigensVectorFactory = new DenseMapVectorFactory();
+ int desiredRank = 200;
+ long time = timeSolver(TMP_EIGEN_DIR,
+ 0.00001,
+ 5,
+ desiredRank,
+ new TrainingState());
+
+ System.out.println("Hebbian time: " + time + "ms");
+ File eigenDir = new File(TMP_EIGEN_DIR + File.separator + HebbianSolver.EIGEN_VECT_DIR);
+ DiskBufferedDoubleMatrix eigens = new DiskBufferedDoubleMatrix(eigenDir, 10);
+
+ DoubleMatrix2D srm = asSparseDoubleMatrix2D(corpus);
+ long timeA = System.nanoTime();
+ EigenvalueDecomposition asSparseRealDecomp = new EigenvalueDecomposition(srm);
+ for (int i=0; i<desiredRank; i++)
+ asSparseRealDecomp.getEigenvector(i);
+ System.out.println("CommonsMath time: " + (System.nanoTime() - timeA)/TimingConstants.NANOS_IN_MILLI + "ms");
+
+ // System.out.println("Hebbian results:");
+ // printEigenVerify(eigens, corpus);
+
+ DoubleMatrix lanczosEigenVectors = new HashMapDoubleMatrix(new HashMapVectorFactory());
+ List<Double> lanczosEigenValues = new ArrayList<Double>();
+
+ LanczosSolver solver = new LanczosSolver();
+ solver.solve(corpus, desiredRank*5, lanczosEigenVectors, lanczosEigenValues);
+
+ for (TimingSection section : LanczosSolver.TimingSection.values())
+ {
+ System.out.println("Lanczos " + section.toString() + " = " + (int)(solver.getTimeMillis(section)/1000) + " seconds");
+ }
+
+ // System.out.println("\nLanczos results:");
+ // printEigenVerify(lanczosEigenVectors, corpus);
+ }
+
+ private DoubleMatrix2D asSparseDoubleMatrix2D(Matrix corpus)
+ {
+ DoubleMatrix2D result = new DenseDoubleMatrix2D(corpus.numRows(), corpus.numRows());
+ for (int i=0; i<corpus.numRows(); i++) {
+ for (int j=i; j<corpus.numRows(); j++) {
+ double v = corpus.getRow(i).dot(corpus.getRow(j));
+ result.set(i, j, v);
+ result.set(j, i, v);
+ }
+ }
+ return result;
+ }
+
+
+ public static void printEigenVerify(DoubleMatrix eigens, DoubleMatrix corpus)
+ {
+ for (Map.Entry<Integer, MapVector> entry : eigens)
+ {
+ MapVector eigen = entry.getValue();
+ MapVector afterMultiply = corpus.timesSquared(eigen);
+ double norm = afterMultiply.norm();
+ double error = 1 - eigen.dot(afterMultiply) / (eigen.norm() * afterMultiply.norm());
+ System.out.println(entry.getKey() + ": error = " + error + ", eVal = " + (norm / eigen.norm()));
+ }
+ }
+ */
+
+}
r***@apache.org
2018-09-08 23:35:08 UTC
Permalink
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/list/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/list/package-info.java b/core/src/main/java/org/apache/mahout/math/list/package-info.java
new file mode 100644
index 0000000..43b5c4b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/list/package-info.java
@@ -0,0 +1,144 @@
+/**
+ * <HTML>
+ * <BODY>
+ * Resizable lists holding objects or primitive data types such as <tt>int</tt>,
+ * <tt>double</tt>, etc. For non-resizable lists (1-dimensional matrices) see
+ * package <code>org.apache.mahout.math.matrix</code>.<p></p>
+ * <h1><a name="Overview"></a>Getting Started</h1>
+ * <h2>1. Overview</h2>
+ * <p>The list package offers flexible object-oriented abstractions modelling dynamically
+ * resizing lists holding objects or primitive data types such as <tt>int</tt>,
+ * <tt>double</tt>, etc. It is designed to be scalable in terms of performance
+ * and memory requirements.</p>
+ * <p>Features include: </p>
+ * <p></p>
+ * <ul>
+ * <li>Lists operating on objects as well as all primitive data types such as <tt>int</tt>,
+ * <tt>double</tt>, etc.
+ * </li>
+ * <li>Compact representations</li>
+ * <li>A number of general purpose list operations including: adding, inserting,
+ * removing, iterating, searching, sorting, extracting ranges and copying. All
+ * operations are designed to perform well on mass data.
+ * </li>
+ * <li>Support for quick access to list elements. This is achieved by bounds-checking
+ * and non-bounds-checking accessor methods as well as zero-copy transformations
+ * to primitive arrays such as <tt>int[]</tt>, <tt>double[]</tt>, etc.
+ * </li>
+ * <li>Allows high-level algorithms to be used on primitive data types without any
+ * space and time overhead. Operations on primitive arrays, Colt lists and JAL
+ * algorithms can freely be mixed at zero copy overhead.
+ * </li>
+ * </ul>
+ * <p>File-based I/O can be achieved through the standard Java built-in serialization
+ * mechanism. All classes implement the {@link java.io.Serializable} interface.
+ * However, the toolkit is entirely decoupled from advanced I/O. It provides data
+ * structures and algorithms only.
+ * <p> This toolkit borrows concepts and terminology from the Javasoft <a
+ * href="http://www.javasoft.com/products/jdk/1.2/docs/guide/collections/index.html">
+ * Collections framework</a> written by Josh Bloch and introduced in JDK 1.2.
+ * <h2>2. Introduction</h2>
+ * <p>Lists are fundamental to virtually any application. Large scale resizable lists
+ * are, for example, used in scientific computations, simulations, database management
+ * systems, to name just a few.</p>
+ *
+ * <p>A list is a container holding elements that can be accessed via zero-based
+ * indexes. Lists may be implemented in different ways (most commonly with arrays).
+ * A resizable list automatically grows as elements are added. The lists of this
+ * package do not automatically shrink. Shrinking needs to be triggered by explicitly
+ * calling <tt>trimToSize()</tt> methods.</p>
+ * <p><i>Growing policy</i>: A list implemented with arrays initially has a certain
+ * <tt>initialCapacity</tt> - per default 10 elements, but customizable upon instance
+ * construction. As elements are added, this capacity may no longer be sufficient.
+ * When a list is automatically grown, its capacity is expanded to <tt>1.5*currentCapacity</tt>.
+ * Thus, excessive resizing (involving copying) is avoided.</p>
+ * <h4>Copying</h4>
+ * <p>
+ * <p>Any list can be copied. A copy is <i>equal</i> to the original but entirely
+ * independent of the original. So changes in the copy are not reflected in the
+ * original, and vice-versa.
+ * <h2>3. Organization of this package</h2>
+ * <p>Class naming follows the schema <tt>&lt;ElementType&gt;&lt;ImplementationTechnique&gt;List</tt>.
+ * For example, we have a {@link org.apache.mahout.math.list.DoubleArrayList}, which is a list
+ * holding <tt>double</tt> elements implemented with <tt>double</tt>[] arrays.
+ * </p>
+ * <p>The classes for lists of a given value type are derived from a common abstract
+ * base class tagged <tt>Abstract&lt;ElementType&gt;</tt><tt>List</tt>. For example,
+ * all lists operating on <tt>double</tt> elements are derived from
+ * {@link org.apache.mahout.math.list.AbstractDoubleList},
+ * which in turn is derived from an abstract base class tying together all lists
+ * regardless of value type, {@link org.apache.mahout.math.list.AbstractList}. The abstract
+ * base classes provide skeleton implementations for all but a few methods. Experimental
+ * data layouts (such as compressed, sparse, linked, etc.) can easily be implemented
+ * and inherit a rich set of functionality. Have a look at the javadoc <a href="package-tree.html">tree
+ * view</a> to get the broad picture.</p>
+ * <h2>4. Example usage</h2>
+ * <p>The following snippet fills a list, randomizes it, extracts the first half
+ * of the elements, sums them up and prints the result. It is implemented entirely
+ * with accessor methods.</p>
+ * <table>
+ * <td class="PRE">
+ * <pre>
+ * int s = 1000000;<br>AbstractDoubleList list = new DoubleArrayList();
+ * for (int i=0; i&lt;s; i++) { list.add((double)i); }
+ * list.shuffle();
+ * AbstractDoubleList part = list.partFromTo(0,list.size()/2 - 1);
+ * double sum = 0.0;
+ * for (int i=0; i&lt;part.size(); i++) { sum += part.get(i); }
+ * log.info(sum);
+ * </pre>
+ * </td>
+ * </table>
+ * <p> For efficiency, all classes provide back doors to enable getting/setting the
+ * backing array directly. In this way, the high level operations of these classes
+ * can be used where appropriate, and one can switch to <tt>[]</tt>-array index
+ * notations where necessary. The key methods for this are <tt>public &lt;ElementType&gt;[]
+ * elements()</tt> and <tt>public void elements(&lt;ElementType&gt;[])</tt>. The
+ * former trustingly returns the array it internally keeps to store the elements.
+ * Holding this array in hand, we can use the <tt>[]</tt>-array operator to
+ * perform iteration over large lists without needing to copy the array or paying
+ * the performance penalty introduced by accessor methods. Alternatively any JAL
+ * algorithm (or other algorithm) can operate on the returned primitive array.
+ * The latter method forces a list to internally hold a user provided array. Using
+ * this approach one can avoid needing to copy the elements into the list.
+ * <p>As a consequence, operations on primitive arrays, Colt lists and JAL algorithms
+ * can freely be mixed at zero-copy overhead.
+ * <p> Note that such special treatment certainly breaks encapsulation. This functionality
+ * is provided for performance reasons only and should only be used when absolutely
+ * necessary. Here is the above example in mixed notation:
+ * <table>
+ * <td class="PRE">
+ * <pre>
+ * int s = 1000000;<br>DoubleArrayList list = new DoubleArrayList(s); // list.size()==0, capacity==s
+ * list.setSize(s); // list.size()==s<br>double[] values = list.elements();
+ * // zero copy, values.length==s<br>for (int i=0; i&lt;s; i++) { values[i]=(double)i; }
+ * list.shuffle();
+ * double sum = 0.0;
+ * int limit = values.length/2;
+ * for (int i=0; i&lt;limit; i++) { sum += values[i]; }
+ * log.info(sum);
+ * </pre>
+ * </td>
+ * </table>
+ * <p> Or even more compact using lists as algorithm objects:
+ * <table>
+ * <td class="PRE">
+ * <pre>
+ * int s = 1000000;<br>double[] values = new double[s];
+ * for (int i=0; i&lt;s; i++) { values[i]=(double)i; }
+ * new DoubleArrayList(values).shuffle(); // zero-copy, shuffle via back door
+ * double sum = 0.0;
+ * int limit = values.length/2;
+ * for (int i=0; i&lt;limit; i++) { sum += values[i]; }
+ * log.info(sum);
+ * </pre>
+ * </td>
+ * </table>
+ * <p>
+ * <h2>5. Notes </h2>
+ * <p>The quicksorts and mergesorts are the JDK 1.2 V1.26 algorithms, modified as
+ * necessary to operate on the given data types.
+ * </BODY>
+ * </HTML>
+ */
+package org.apache.mahout.math.list;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/map/HashFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/map/HashFunctions.java b/core/src/main/java/org/apache/mahout/math/map/HashFunctions.java
new file mode 100644
index 0000000..b749307
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/map/HashFunctions.java
@@ -0,0 +1,115 @@
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.map;
+
+
+/**
+ * Collection of static hash functions: one overload per primitive type plus one for
+ * {@link Object}.
+ */
+public final class HashFunctions {
+
+  /** Not instantiable: this class only hosts static helpers. */
+  private HashFunctions() {
+  }
+
+  /**
+   * Hashes a {@code char} to its numeric UTF-16 code unit.
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(char value) {
+    return (int) value;
+  }
+
+  /**
+   * Hashes a {@code double} by XOR-folding the upper and lower halves of its
+   * {@link Double#doubleToLongBits(double)} bit pattern (the same scheme as {@link Double#hashCode()}).
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(double value) {
+    final long raw = Double.doubleToLongBits(value);
+    final int high = (int) (raw >>> 32);
+    final int low = (int) raw;
+    return high ^ low;
+  }
+
+  /**
+   * Hashes a {@code float} by scaling it with a large constant and taking the bit pattern of the
+   * product.
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(float value) {
+    // The multiplication is meant to avoid excessive collisions for values of the form
+    // (1.0, 2.0, 3.0, ...).
+    final float spread = value * 663608941.737f;
+    return Float.floatToIntBits(spread);
+  }
+
+  /**
+   * Hashes an {@code int} with the 32-bit finalization mix of MurmurHash3 (fmix32): alternating
+   * xor-shifts and multiplications by two odd constants.
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(int value) {
+    int h = value ^ (value >>> 16);
+    h *= 0x85ebca6b;
+    h ^= h >>> 13;
+    h *= 0xc2b2ae35;
+    return h ^ (h >>> 16);
+  }
+
+  /**
+   * Hashes a {@code long} by XOR-folding its upper 32 bits into its lower 32 bits (the same result
+   * as {@link Long#hashCode()}).
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(long value) {
+    // Only the low 32 bits survive the cast, so an unsigned shift gives the same result as a signed one.
+    return (int) (value ^ (value >>> 32));
+  }
+
+  /**
+   * Hashes an arbitrary object via its own {@code hashCode()}, mapping {@code null} to 0.
+   *
+   * @return the hash code of {@code object}
+   */
+  public static int hash(Object object) {
+    if (object == null) {
+      return 0;
+    }
+    return object.hashCode();
+  }
+
+  /**
+   * Hashes a {@code short} to its (sign-extended) numeric value.
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(short value) {
+    return (int) value;
+  }
+
+  /**
+   * Hashes a {@code boolean} using the same two constants as {@link Boolean#hashCode()}.
+   *
+   * @return the hash code of {@code value}
+   */
+  public static int hash(boolean value) {
+    if (value) {
+      return 1231;
+    }
+    return 1237;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/map/OpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/map/OpenHashMap.java b/core/src/main/java/org/apache/mahout/math/map/OpenHashMap.java
new file mode 100644
index 0000000..0efca4b
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/map/OpenHashMap.java
@@ -0,0 +1,652 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+Copyright (c) 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.map;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.mahout.math.function.ObjectObjectProcedure;
+import org.apache.mahout.math.function.ObjectProcedure;
+import org.apache.mahout.math.set.AbstractSet;
+import org.apache.mahout.math.set.OpenHashSet;
+
+/**
+ * Open hash map. This implements Map, but it does not respect several aspects of the Map contract
+ * that impose the very sorts of performance penalties that this class exists to avoid.
+ * {@link #entrySet}, {@link #values}, and {@link #keySet()} do <strong>not</strong> return
+ * collections that share storage with the main map, and changes to those returned objects
+ * are <strong>not</strong> reflected in the container.
+ **/
+public class OpenHashMap<K,V> extends AbstractSet implements Map<K,V> {
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+ protected static final Object NO_KEY_VALUE = null;
+
+ /** The hash table keys. */
+ protected Object[] table;
+
+ /** The hash table values. */
+ protected Object[] values;
+
+ /** The state of each hash table entry (FREE, FULL, REMOVED). */
+ protected byte[] state;
+
+ /** The number of table entries in state==FREE. */
+ protected int freeEntries;
+
+
+  /** Constructs an empty map with the default capacity and the default load factors. */
+  public OpenHashMap() {
+    this(DEFAULT_CAPACITY, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR);
+  }
+
+  /**
+   * Constructs an empty map with the given initial capacity and the default load factors.
+   *
+   * @param initialCapacity the initial capacity of the map.
+   * @throws IllegalArgumentException if the initial capacity is less than zero.
+   */
+  public OpenHashMap(int initialCapacity) {
+    this(initialCapacity, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR);
+  }
+
+  /**
+   * Constructs an empty map with the given initial capacity and minimum/maximum load factors.
+   * All table set-up is delegated to {@code setUp}.
+   *
+   * @param initialCapacity the initial capacity.
+   * @param minLoadFactor the minimum load factor.
+   * @param maxLoadFactor the maximum load factor.
+   * @throws IllegalArgumentException if {@code initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0)
+   *         || (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >= maxLoadFactor)}.
+   */
+  public OpenHashMap(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+    setUp(initialCapacity, minLoadFactor, maxLoadFactor);
+  }
+
+  /**
+   * Removes all (key,value) associations from the receiver, then calls {@code trimToSize()}.
+   * The key and value slots are nulled so the map no longer keeps the removed objects reachable.
+   */
+  @Override
+  public void clear() {
+    Arrays.fill(this.state, FREE);
+    // Drop stale references so cleared keys/values can be garbage collected, regardless of
+    // whether trimToSize() ends up reallocating the backing arrays.
+    Arrays.fill(this.table, null);
+    Arrays.fill(this.values, null);
+    distinct = 0;
+    freeEntries = table.length; // delta
+    trimToSize();
+  }
+
+  /**
+   * Returns a copy of the receiver with its own key, value and state arrays, so structural changes
+   * to the copy do not affect the receiver and vice versa. The key and value objects themselves are
+   * <strong>not</strong> cloned; both maps refer to the same instances.
+   *
+   * @return a copy of the receiver with independent backing arrays.
+   */
+  @Override
+  @SuppressWarnings("unchecked")
+  public Object clone() {
+    OpenHashMap<K,V> copy = (OpenHashMap<K,V>) super.clone();
+    // super.clone() copies the array references only; give the copy its own arrays.
+    copy.table = copy.table.clone();
+    copy.values = copy.values.clone();
+    copy.state = copy.state.clone();
+    return copy;
+  }
+
+ /**
+ * Returns <tt>true</tt> if the receiver contains the specified key.
+ *
+ * @return <tt>true</tt> if the receiver contains the specified key.
+ */
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean containsKey(Object key) {
+ return indexOfKey((K)key) >= 0;
+ }
+
+ /**
+ * Returns <tt>true</tt> if the receiver contains the specified value.
+ *
+ * @return <tt>true</tt> if the receiver contains the specified value.
+ */
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean containsValue(Object value) {
+ return indexOfValue((V)value) >= 0;
+ }
+
+  /**
+   * Grows the backing table so it has at least {@code minCapacity} slots, rehashing into a table
+   * whose size is {@code nextPrime(minCapacity)}. Does nothing if the table is already that large.
+   * <p>This method never needs to be called; it is for performance tuning only. Calling it before
+   * {@code put()}ing many associations lets the receiver grow once instead of potentially many times.
+   *
+   * @param minCapacity the desired minimum capacity.
+   */
+  @Override
+  public void ensureCapacity(int minCapacity) {
+    if (table.length >= minCapacity) {
+      return; // already large enough
+    }
+    rehash(nextPrime(minCapacity));
+  }
+
+  /**
+   * Applies {@code procedure} to each key of the receiver, walking the table from the highest slot
+   * index down to 0 (i.e. in no meaningful order). Methods that can be expressed in terms of this
+   * one must use the same order so that, for example, keys and values yield matching pairs.
+   *
+   * @param procedure the procedure to apply; iteration stops as soon as it returns {@code false}.
+   * @return {@code false} if the procedure stopped the iteration early, {@code true} otherwise.
+   */
+  @SuppressWarnings("unchecked")
+  public boolean forEachKey(ObjectProcedure<K> procedure) {
+    for (int slot = table.length - 1; slot >= 0; slot--) {
+      if (state[slot] != FULL) {
+        continue;
+      }
+      if (!procedure.apply((K) table[slot])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Applies {@code procedure} to each (key,value) pair of the receiver, in the same slot order as
+   * {@link #forEachKey(ObjectProcedure)}.
+   *
+   * @param procedure the procedure to apply; iteration stops as soon as it returns {@code false}.
+   * @return {@code false} if the procedure stopped the iteration early, {@code true} otherwise.
+   */
+  @SuppressWarnings("unchecked")
+  public boolean forEachPair(ObjectObjectProcedure<K,V> procedure) {
+    for (int slot = table.length - 1; slot >= 0; slot--) {
+      if (state[slot] != FULL) {
+        continue;
+      }
+      if (!procedure.apply((K) table[slot], (V) values[slot])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Returns the value associated with {@code key}. Because a key may legitimately be mapped to
+   * {@code null}, use {@link #containsKey(Object)} to distinguish "absent" from "mapped to null".
+   *
+   * @param key the key to be searched for.
+   * @return the associated value, or {@code null} if no such key is present.
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public V get(Object key) {
+    int slot = indexOfKey((K) key);
+    return slot < 0 ? null : (V) values[slot];
+  }
+
+  /**
+   * Finds the slot at which {@code key} is stored or should be inserted, using double hashing.
+   *
+   * @param key the key to be added to the receiver.
+   * @return the index where the key would need to be inserted, if it is not already contained. Returns
+   *         {@code -index-1} if the key is already contained at slot {@code index}. Therefore, if the returned
+   *         value is {@code < 0}, the key is already contained at slot {@code -value-1}; if it is {@code >= 0},
+   *         the key is NOT contained and should be inserted at that slot.
+   */
+  protected int indexOfInsertion(K key) {
+    Object[] tab = table;
+    byte[] stat = state;
+    int length = tab.length;
+
+    int hash = key.hashCode() & 0x7FFFFFFF;
+    int i = hash % length;
+    int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
+    if (decrement == 0) {
+      decrement = 1;
+    }
+
+    // stop if we find a removed or free slot, or if we find the key itself
+    // do NOT skip over removed slots (yes, open addressing is like that...)
+    while (stat[i] == FULL && !equalsMindTheNull(key, tab[i])) {
+      i -= decrement;
+      if (i < 0) {
+        i += length;
+      }
+    }
+
+    if (stat[i] == REMOVED) {
+      // The key may still be stored further along the probe chain, past this tombstone.
+      // Keep probing (skipping removed slots) until we hit a FREE slot or the key itself.
+      // Keys must be compared with equalsMindTheNull(), not ==, otherwise an equal but
+      // non-identical key past the tombstone is missed and put() inserts a duplicate entry.
+      // assertion: there is at least one FREE slot.
+      int j = i;
+      while (stat[i] != FREE && (stat[i] == REMOVED || !equalsMindTheNull(key, tab[i]))) {
+        i -= decrement;
+        if (i < 0) {
+          i += length;
+        }
+      }
+      if (stat[i] == FREE) {
+        // Key not present: reuse the first tombstone encountered on the probe chain.
+        i = j;
+      }
+    }
+
+    if (stat[i] == FULL) {
+      // key already contained at slot i.
+      // return a negative number identifying the slot.
+      return -i - 1;
+    }
+    // not already contained, should be inserted at slot i.
+    // return a number >= 0 identifying the slot.
+    return i;
+  }
+
+ /**
+ * @param key the key to be searched in the receiver.
+ * @return the index where the key is contained in the receiver, returns -1 if the key was not found.
+ */
+ protected int indexOfKey(K key) {
+ Object[] tab = table;
+ byte[] stat = state;
+ int length = tab.length;
+
+ int hash = key.hashCode() & 0x7FFFFFFF;
+ int i = hash % length;
+ int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
+ //int decrement = (hash / length) % length;
+ if (decrement == 0) {
+ decrement = 1;
+ }
+
+ // stop if we find a free slot, or if we find the key itself.
+ // do skip over removed slots (yes, open addressing is like that...)
+ while (stat[i] != FREE && (stat[i] == REMOVED || !equalsMindTheNull(key, tab[i]))) {
+ i -= decrement;
+ //hashCollisions++;
+ if (i < 0) {
+ i += length;
+ }
+ }
+
+ if (stat[i] == FREE) {
+ return -1;
+ } // not found
+ return i; //found, return index where key is contained
+ }
+
+ /**
+ * @param value the value to be searched in the receiver.
+ * @return the index where the value is contained in the receiver, returns -1 if the value was not found.
+ */
+ protected int indexOfValue(V value) {
+ Object[] val = values;
+ byte[] stat = state;
+
+ for (int i = stat.length; --i >= 0;) {
+ if (stat[i] == FULL && equalsMindTheNull(val[i], value)) {
+ return i;
+ }
+ }
+
+ return -1; // not found
+ }
+
+  /**
+   * Replaces the contents of {@code list} with the keys of the receiver, in the same slot order as
+   * {@link #forEachKey(ObjectProcedure)}. Afterwards {@code list.size()} equals {@code this.size()}.
+   *
+   * @param list the list to be filled; any previous contents are cleared first.
+   */
+  @SuppressWarnings("unchecked")
+  public void keys(List<K> list) {
+    list.clear();
+    final Object[] slots = table;
+    final byte[] states = state;
+    for (int slot = slots.length - 1; slot >= 0; slot--) {
+      if (states[slot] == FULL) {
+        list.add((K) slots[slot]);
+      }
+    }
+  }
+
+  /**
+   * Associates {@code value} with {@code key}, replacing any existing association for that key.
+   *
+   * @param key the key the value shall be associated with.
+   * @param value the value to be associated.
+   * @return the value previously associated with {@code key}, or {@code null} if the key was not
+   *         present (or was mapped to {@code null}).
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public V put(K key, V value) {
+    int slot = indexOfInsertion(key);
+    if (slot < 0) {
+      // Key already present: overwrite in place and hand back the old value.
+      int existing = -slot - 1;
+      V previous = (V) values[existing];
+      values[existing] = value;
+      return previous;
+    }
+
+    if (distinct > highWaterMark) {
+      // Grow first and retry: the slot computed above is only valid for the current table.
+      rehash(chooseGrowCapacity(distinct + 1, minLoadFactor, maxLoadFactor));
+      return put(key, value);
+    }
+
+    table[slot] = key;
+    values[slot] = value;
+    if (state[slot] == FREE) {
+      freeEntries--; // reusing a REMOVED slot does not consume a free entry
+    }
+    state[slot] = FULL;
+    distinct++;
+
+    if (freeEntries < 1) { //delta
+      rehash(chooseGrowCapacity(distinct + 1, minLoadFactor, maxLoadFactor));
+    }
+    return null;
+  }
+
+  /**
+   * Moves every live association into freshly allocated arrays of {@code newCapacity} slots and
+   * recomputes the low/high water marks for that capacity. Called automatically when the number of
+   * keys exceeds the high water mark or falls below the low water mark.
+   *
+   * @param newCapacity the number of slots of the new table.
+   */
+  @SuppressWarnings("unchecked")
+  protected void rehash(int newCapacity) {
+    Object[] previousKeys = table;
+    Object[] previousValues = values;
+    byte[] previousStates = state;
+
+    Object[] freshKeys = new Object[newCapacity];
+    Object[] freshValues = new Object[newCapacity];
+    byte[] freshStates = new byte[newCapacity];
+
+    this.lowWaterMark = chooseLowWaterMark(newCapacity, this.minLoadFactor);
+    this.highWaterMark = chooseHighWaterMark(newCapacity, this.maxLoadFactor);
+
+    // Install the new arrays before reinserting: indexOfInsertion(...) is presumably probing
+    // this.table / this.state, which must already be the new ones.
+    this.table = freshKeys;
+    this.values = freshValues;
+    this.state = freshStates;
+    this.freeEntries = newCapacity - this.distinct;
+
+    for (int slot = previousKeys.length - 1; slot >= 0; slot--) {
+      if (previousStates[slot] != FULL) {
+        continue;
+      }
+      Object key = previousKeys[slot];
+      int target = indexOfInsertion((K) key);
+      freshKeys[target] = key;
+      freshValues[target] = previousValues[slot];
+      freshStates[target] = FULL;
+    }
+  }
+
+  /**
+   * Removes the given key and its associated value from the receiver, if present.
+   *
+   * @param key the key to be removed from the receiver.
+   * @return the value that was associated with {@code key}, or {@code null} if the key was not
+   *         present (a {@code null} return is also produced when the key was mapped to {@code null}).
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public V remove(Object key) {
+    int i = indexOfKey((K) key);
+    if (i < 0) {
+      return null; // key not contained
+    }
+    V removed = (V) values[i];
+
+    this.state[i] = REMOVED;
+    // Release the reference so the removed value does not stay reachable until the slot is
+    // reused or the table is rehashed.
+    // NOTE(review): table[i] still references the removed key; clearing it as well is only safe if
+    // every probe loop checks state[i] before reading table[i] -- confirm in indexOfKey/indexOfInsertion.
+    this.values[i] = null;
+    this.distinct--;
+
+    if (this.distinct < this.lowWaterMark) {
+      int newCapacity = chooseShrinkCapacity(this.distinct, this.minLoadFactor, this.maxLoadFactor);
+      rehash(newCapacity);
+    }
+
+    return removed;
+  }
+
+  /**
+   * Initializes the receiver with empty tables sized to a prime near {@code initialCapacity}.
+   *
+   * @param initialCapacity the initial capacity of the receiver.
+   * @param minLoadFactor the minLoadFactor of the receiver.
+   * @param maxLoadFactor the maxLoadFactor of the receiver.
+   * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
+   *                                  (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
+   *                                  maxLoadFactor)</tt>.
+   */
+  @Override
+  protected void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+    super.setUp(initialCapacity, minLoadFactor, maxLoadFactor);
+
+    int primeCapacity = nextPrime(initialCapacity);
+    // Open addressing needs at least one FREE slot at any time.
+    int capacity = (primeCapacity == 0) ? 1 : primeCapacity;
+
+    this.table = new Object[capacity];
+    this.values = new Object[capacity];
+    this.state = new byte[capacity];
+
+    this.minLoadFactor = minLoadFactor;
+    // A table of LARGEST_PRIME slots is allowed to fill up completely; memory will be exhausted long
+    // before this pathological case happens, anyway.
+    this.maxLoadFactor = (capacity == PrimeFinder.LARGEST_PRIME) ? 1.0 : maxLoadFactor;
+
+    this.distinct = 0;
+    this.freeEntries = capacity;
+
+    // The low water mark is only established upon the first rehash: setting it now would make a young
+    // table shrink on its very first put(...), defeating the purpose of an "initialCapacity".
+    // See ensureCapacity(...).
+    this.lowWaterMark = 0;
+    this.highWaterMark = chooseHighWaterMark(capacity, this.maxLoadFactor);
+  }
+
+  /**
+   * Shrinks the table to the prime capacity chosen for about 1.2 times the current number of
+   * associations, releasing superfluous internal memory. Does nothing if the table is already that
+   * small or smaller.
+   */
+  @Override
+  public void trimToSize() {
+    // Keep ~20% head-room: open addressing degrades sharply as the table fills up, to the point where
+    // even rehashing the table can take very long.
+    int target = nextPrime((int) (1 + 1.2 * size()));
+    if (target < table.length) {
+      rehash(target);
+    }
+  }
+
+  /**
+   * Test hook: reports the table length and the load factors through one-element out-arrays.
+   *
+   * @param capacity receives {@code table.length} in element 0.
+   * @param minLoadFactor receives the current minimum load factor in element 0.
+   * @param maxLoadFactor receives the current maximum load factor in element 0.
+   */
+  void getInternalFactors(int[] capacity, double[] minLoadFactor, double[] maxLoadFactor) {
+    int slots = table.length;
+    capacity[0] = slots;
+    minLoadFactor[0] = this.minLoadFactor;
+    maxLoadFactor[0] = this.maxLoadFactor;
+  }
+
+  /**
+   * Read-only (key, value) snapshot handed out by {@code entrySet()}.
+   * Implements the {@link Map.Entry} equality and hash-code contract so that entries compare equal to
+   * entries of any other map with the same key and value, and can be found in sets.
+   */
+  private class MapEntry implements Map.Entry<K,V> {
+    private final K key;
+    private final V value;
+
+    MapEntry(K key, V value) {
+      this.key = key;
+      this.value = value;
+    }
+
+    @Override
+    public K getKey() {
+      return key;
+    }
+
+    @Override
+    public V getValue() {
+      return value;
+    }
+
+    /** Always throws: entries are snapshots and cannot write through to the map. */
+    @Override
+    public V setValue(V newValue) {
+      throw new UnsupportedOperationException("Map.Entry.setValue not supported for OpenHashMap");
+    }
+
+    /** Equal to any {@link Map.Entry} with an equal key and an equal value (nulls allowed). */
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj) {
+        return true;
+      }
+      if (!(obj instanceof Map.Entry)) {
+        return false;
+      }
+      Map.Entry<?, ?> other = (Map.Entry<?, ?>) obj;
+      return java.util.Objects.equals(key, other.getKey())
+          && java.util.Objects.equals(value, other.getValue());
+    }
+
+    /** Hash code as specified by {@link Map.Entry#hashCode()}. */
+    @Override
+    public int hashCode() {
+      return (key == null ? 0 : key.hashCode()) ^ (value == null ? 0 : value.hashCode());
+    }
+
+    @Override
+    public String toString() {
+      return key + "=" + value;
+    }
+  }
+
+ /**
+ * Allocate a set to contain Map.Entry objects for the pairs and return it.
+ */
+ @Override
+ public Set<java.util.Map.Entry<K,V>> entrySet() {
+ final Set<Entry<K, V>> entries = new OpenHashSet<>();
+ forEachPair(new ObjectObjectProcedure<K,V>() {
+ @Override
+ public boolean apply(K key, V value) {
+ entries.add(new MapEntry(key, value));
+ return true;
+ }
+ });
+ return entries;
+ }
+
+  /**
+   * Returns a newly allocated set containing every key of this map.
+   * The set is a copy and is not backed by the map, contrary to the {@link Map#keySet()} contract.
+   */
+  @Override
+  public Set<K> keySet() {
+    final Set<K> snapshot = new OpenHashSet<>();
+    ObjectProcedure<K> collector = new ObjectProcedure<K>() {
+      @Override
+      public boolean apply(K k) {
+        snapshot.add(k);
+        return true; // visit every key
+      }
+    };
+    forEachKey(collector);
+    return snapshot;
+  }
+
+ @Override
+ public void putAll(Map<? extends K,? extends V> m) {
+ for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) {
+ put(e.getKey(), e.getValue());
+ }
+ }
+
+ /**
+ * Allocate a list to contain the values and return it.
+ * This violates the 'backing' provision of the Map interface.
+ */
+ @Override
+ public Collection<V> values() {
+ final List<V> valueList = new ArrayList<>();
+ forEachPair(new ObjectObjectProcedure<K,V>() {
+ @Override
+ public boolean apply(K key, V value) {
+ valueList.add(value);
+ return true;
+ }
+ });
+ return valueList;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof OpenHashMap)) {
+ return false;
+ }
+ final OpenHashMap<K,V> o = (OpenHashMap<K,V>) obj;
+ if (o.size() != size()) {
+ return false;
+ }
+ final boolean[] equal = new boolean[1];
+ equal[0] = true;
+ forEachPair(new ObjectObjectProcedure<K,V>() {
+ @Override
+ public boolean apply(K key, V value) {
+ Object ov = o.get(key);
+ if (!value.equals(ov)) {
+ equal[0] = false;
+ return false;
+ }
+ return true;
+ }
+ });
+ return equal[0];
+ }
+
+  /**
+   * Renders the map as {@code {[k1 -> v1] [k2 -> v2] }}, in the receiver's iteration order.
+   */
+  @Override
+  public String toString() {
+    final StringBuilder out = new StringBuilder("{");
+    forEachPair(new ObjectObjectProcedure<K,V>() {
+      @Override
+      public boolean apply(K key, V value) {
+        out.append('[').append(key).append(" -> ").append(value).append("] ");
+        return true;
+      }
+    });
+    return out.append('}').toString();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/map/PrimeFinder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/map/PrimeFinder.java b/core/src/main/java/org/apache/mahout/math/map/PrimeFinder.java
new file mode 100644
index 0000000..b02611e
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/map/PrimeFinder.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.math.map;
+
+import java.util.Arrays;
+
+/**
+ * Not of interest for users; only for implementors of hashtables.
+ * Used to keep hash table capacities prime numbers.
+ *
+ * <p>Choosing prime numbers as hash table capacities is a good idea to keep them working fast,
+ * particularly under hash table expansions.
+ *
+ * <p>However, JDK 1.2, JGL 3.1 and many other toolkits do nothing to keep capacities prime.
+ * This class provides efficient means to choose prime capacities.
+ *
+ * <p>Choosing a prime is <tt>O(log 300)</tt> (binary search in a sorted list of about 300 ints).
+ * Memory requirements: about 1 KB of static memory.
+ */
+public final class PrimeFinder {
+
+  /** The largest prime this class can generate; currently equal to <tt>Integer.MAX_VALUE</tt>. */
+  public static final int LARGEST_PRIME = Integer.MAX_VALUE; //yes, it is prime.
+
+  /**
+   * The prime number list consists of 11 chunks. Each chunk contains prime numbers. A chunk starts with a prime P1. The
+   * next element is a prime P2. P2 is the smallest prime for which holds: P2 >= 2*P1. The next element is P3, for which
+   * the same holds with respect to P2, and so on.
+   *
+   * Chunks are chosen such that for any desired capacity >= 1000 the list includes a prime number <= desired capacity *
+   * 1.11 (11%). For any desired capacity >= 200 the list includes a prime number <= desired capacity * 1.16 (16%). For
+   * any desired capacity >= 16 the list includes a prime number <= desired capacity * 1.21 (21%).
+   *
+   * Therefore, primes can be retrieved which are quite close to any desired capacity, which in turn avoids wasting
+   * memory. For example, the list includes 1039,1117,1201,1277,1361,1439,1523,1597,1759,1907,2081. So if you need a
+   * prime >= 1040, you will find a prime <= 1040*1.11=1154.
+   *
+   * Chunks are chosen such that they are optimized for a hashtable growth factor of 2.0; if your hashtable has such a
+   * growth factor then, after initially "rounding to a prime" upon hashtable construction, it will later expand to prime
+   * capacities such that there exist no better primes.
+   *
+   * In total these are about 32*10=320 numbers -> 1 KB of static memory needed. If you are stingy, then delete every
+   * second or fourth chunk.
+   *
+   * The array is sorted once in the static initializer so that {@link #nextPrime(int)} can binary-search it.
+   */
+  private static final int[] PRIME_CAPACITIES = {
+    //chunk #0
+    LARGEST_PRIME,
+
+    //chunk #1
+    5, 11, 23, 47, 97, 197, 397, 797, 1597, 3203, 6421, 12853, 25717, 51437, 102877, 205759,
+    411527, 823117, 1646237, 3292489, 6584983, 13169977, 26339969, 52679969, 105359939,
+    210719881, 421439783, 842879579, 1685759167,
+
+    //chunk #2
+    433, 877, 1759, 3527, 7057, 14143, 28289, 56591, 113189, 226379, 452759, 905551, 1811107,
+    3622219, 7244441, 14488931, 28977863, 57955739, 115911563, 231823147, 463646329, 927292699,
+    1854585413,
+
+    //chunk #3
+    953, 1907, 3821, 7643, 15287, 30577, 61169, 122347, 244703, 489407, 978821, 1957651, 3915341,
+    7830701, 15661423, 31322867, 62645741, 125291483, 250582987, 501165979, 1002331963,
+    2004663929,
+
+    //chunk #4
+    1039, 2081, 4177, 8363, 16729, 33461, 66923, 133853, 267713, 535481, 1070981, 2141977, 4283963,
+    8567929, 17135863, 34271747, 68543509, 137087021, 274174111, 548348231, 1096696463,
+
+    //chunk #5
+    31, 67, 137, 277, 557, 1117, 2237, 4481, 8963, 17929, 35863, 71741, 143483, 286973, 573953,
+    1147921, 2295859, 4591721, 9183457, 18366923, 36733847, 73467739, 146935499, 293871013,
+    587742049, 1175484103,
+
+    //chunk #6
+    599, 1201, 2411, 4831, 9677, 19373, 38747, 77509, 155027, 310081, 620171, 1240361, 2480729,
+    4961459, 9922933, 19845871, 39691759, 79383533, 158767069, 317534141, 635068283, 1270136683,
+
+    //chunk #7
+    311, 631, 1277, 2557, 5119, 10243, 20507, 41017, 82037, 164089, 328213, 656429, 1312867,
+    2625761, 5251529, 10503061, 21006137, 42012281, 84024581, 168049163, 336098327, 672196673,
+    1344393353,
+
+    //chunk #8
+    3, 7, 17, 37, 79, 163, 331, 673, 1361, 2729, 5471, 10949, 21911, 43853, 87719, 175447, 350899,
+    701819, 1403641, 2807303, 5614657, 11229331, 22458671, 44917381, 89834777, 179669557,
+    359339171, 718678369, 1437356741,
+
+    //chunk #9
+    43, 89, 179, 359, 719, 1439, 2879, 5779, 11579, 23159, 46327, 92657, 185323, 370661, 741337,
+    1482707, 2965421, 5930887, 11861791, 23723597, 47447201, 94894427, 189788857, 379577741,
+    759155483, 1518310967,
+
+    //chunk #10
+    379, 761, 1523, 3049, 6101, 12203, 24407, 48817, 97649, 195311, 390647, 781301, 1562611,
+    3125257, 6250537, 12501169, 25002389, 50004791, 100009607, 200019221, 400038451, 800076929,
+    1600153859
+  };
+
+  static {
+    // The chunks above are laid out for human readability; sort once so nextPrime can binary-search.
+    Arrays.sort(PRIME_CAPACITIES);
+  }
+
+  /** Static utility class; not instantiable. */
+  private PrimeFinder() {
+  }
+
+  /**
+   * Returns the smallest listed prime that is {@code >= desiredCapacity}. The result is very close to
+   * {@code desiredCapacity} (within 11% if {@code desiredCapacity >= 1000}). Values {@code <= 3} yield 3,
+   * the smallest listed prime.
+   *
+   * @param desiredCapacity the capacity desired by the user.
+   * @return the capacity which should be used for a hashtable.
+   */
+  public static int nextPrime(int desiredCapacity) {
+    int i = Arrays.binarySearch(PRIME_CAPACITIES, desiredCapacity);
+    if (i < 0) {
+      // Not listed: binarySearch returns -(insertionPoint) - 1, and the insertion point is the index of
+      // the next larger prime. It is always in range because LARGEST_PRIME == Integer.MAX_VALUE.
+      i = -i - 1;
+    }
+    return PRIME_CAPACITIES[i];
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/map/QuickOpenIntIntHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/map/QuickOpenIntIntHashMap.java b/core/src/main/java/org/apache/mahout/math/map/QuickOpenIntIntHashMap.java
new file mode 100644
index 0000000..6a7cef8
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/map/QuickOpenIntIntHashMap.java
@@ -0,0 +1,215 @@
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math.map;
+
+/**
+ * Status: Experimental; Do not use for production yet. Hash map holding (key,value) associations of type
+ * <tt>(int-->int)</tt>; Automatically grows and shrinks as needed; Implemented using open addressing with double
+ * hashing. First see the <a href="package-summary.html">package summary</a> and javadoc <a
+ * href="package-tree.html">tree view</a> to get the broad picture.
+ *
+ * Implements open addressing with double hashing, using "Brent's variation". Brent's variation slows insertions a bit
+ * down (not much) but reduces probes (collisions) for successful searches, in particular for large load factors. (It
+ * does not improve unsuccessful searches.) See D. Knuth, Searching and Sorting, 3rd ed., p.533-545
+ *
+ * @author ***@cern.ch
+ * @version 1.0, 09/24/99
+ * @see java.util.HashMap
+ */
+class QuickOpenIntIntHashMap extends OpenIntIntHashMap {
+
+  /** Constructs an empty map with default capacity and default load factors. */
+  QuickOpenIntIntHashMap() {
+    this(DEFAULT_CAPACITY);
+  }
+
+  /**
+   * Constructs an empty map with the specified initial capacity and default load factors.
+   *
+   * @param initialCapacity the initial capacity of the map.
+   * @throws IllegalArgumentException if the initial capacity is less than zero.
+   */
+  QuickOpenIntIntHashMap(int initialCapacity) {
+    this(initialCapacity, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR);
+  }
+
+  /**
+   * Constructs an empty map with the specified initial capacity and the specified minimum and maximum load factor.
+   *
+   * @param initialCapacity the initial capacity.
+   * @param minLoadFactor the minimum load factor.
+   * @param maxLoadFactor the maximum load factor.
+   * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
+   *                                  (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
+   *                                  maxLoadFactor)</tt>.
+   */
+  QuickOpenIntIntHashMap(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
+    setUp(initialCapacity, minLoadFactor, maxLoadFactor);
+  }
+
+  /**
+   * Associates the given key with the given value. Replaces any old <tt>(key,someOtherValue)</tt> association, if
+   * existing.
+   *
+   * @param key the key the value shall be associated with.
+   * @param value the value to be associated.
+   * @return <tt>true</tt> if the receiver did not already contain such a key; <tt>false</tt> if the receiver did
+   *         already contain such a key - the new value has now replaced the formerly associated value.
+   */
+  @Override
+  public boolean put(int key, int value) {
+    /*
+      Open addressing with double hashing:
+        h1(key) = hash % M
+        h2(key) = decrement = Max(1, hash/M % M)
+        M is prime = capacity = table.length
+        probing positions are table[(h1 - j*h2) % M] for j = 0, 1, ...
+      (h2 must be relatively prime to M; M prime guarantees that.)
+    */
+    int[] tab = table;
+    byte[] stat = state;
+    int length = tab.length;
+
+    int hash = HashFunctions.hash(key) & 0x7FFFFFFF;
+    int i = hash % length;
+    int decrement = (hash / length) % length;
+    if (decrement == 0) {
+      decrement = 1;
+    }
+
+    // Walk over FULL slots holding other keys. t counts these collisions; they are exactly the probe
+    // positions p0, p0 - decrement, ... that Brent's variation below may rearrange.
+    int p0 = i; // the first position to probe
+    int t = 0;
+    while (stat[i] == FULL && tab[i] != key) {
+      t++;
+      i -= decrement;
+      if (i < 0) {
+        i += length;
+      }
+    }
+
+    if (stat[i] == REMOVED) {
+      // A REMOVED slot does not terminate the probe sequence: the key may still be stored further along.
+      // Keep searching, but remember this slot so it can be reused if the key turns out to be absent.
+      // (Stopping here, as before, inserted a second copy of a key that lived beyond the removed slot.)
+      int firstRemoved = i;
+      while (stat[i] != FREE && (stat[i] == REMOVED || tab[i] != key)) {
+        i -= decrement;
+        if (i < 0) {
+          i += length;
+        }
+      }
+      if (stat[i] == FREE) {
+        i = firstRemoved;
+      }
+    }
+
+    if (stat[i] == FULL) {
+      // key already contained at slot i.
+      values[i] = value;
+      return false;
+    }
+    // not already contained, should be inserted at slot i.
+
+    if (distinct > highWaterMark) {
+      int newCapacity = chooseGrowCapacity(distinct + 1, minLoadFactor, maxLoadFactor);
+      rehash(newCapacity);
+      return put(key, value);
+    }
+
+    /*
+      Brent's variation: walk the NEW key's probe positions p_j (j = 0, 1, ...). If the occupant of p_j
+      can move one step along its OWN probe sequence into a non-FULL slot, move it and insert the new key
+      at p_j. The new key then needs j+1 probes instead of t+1, while the occupant needs one more, so the
+      move only pays off while j < t - 1. Under most insertions t <= 1 and the loop is not entered.
+    */
+    int pj = p0;
+    for (int j = 0; j < t - 1; j++) {
+      int occupant = tab[pj]; // FULL: every p_j with j < t was a collision above
+      int occupantHash = HashFunctions.hash(occupant) & 0x7FFFFFFF;
+      int occupantDecrement = (occupantHash / length) % length;
+      if (occupantDecrement == 0) {
+        occupantDecrement = 1;
+      }
+      int pc = pj - occupantDecrement; // next position on the occupant's own probe sequence
+      if (pc < 0) {
+        pc += length;
+      }
+      if (stat[pc] != FULL) {
+        if (stat[pc] == FREE) {
+          freeEntries--; // a FREE slot is consumed by the move
+        }
+        tab[pc] = occupant;
+        values[pc] = values[pj];
+        stat[pc] = FULL;
+        i = pj; // insert the new key where the occupant was
+        break;
+      }
+      pj -= decrement;
+      if (pj < 0) {
+        pj += length;
+      }
+    }
+
+    tab[i] = key;
+    values[i] = value;
+    if (stat[i] == FREE) {
+      freeEntries--;
+    }
+    stat[i] = FULL;
+    distinct++;
+
+    if (freeEntries < 1) {
+      // Open addressing needs at least one FREE slot so that probe loops terminate.
+      int newCapacity = chooseGrowCapacity(distinct + 1, minLoadFactor, maxLoadFactor);
+      rehash(newCapacity);
+    }
+
+    return true;
+  }
+
+  /**
+   * Rehashes the contents of the receiver into a new table with a smaller or larger capacity. This method is called
+   * automatically when the number of keys in the receiver exceeds the high water mark or falls below the low water
+   * mark.
+   *
+   * @param newCapacity the number of slots of the new table.
+   */
+  @Override
+  protected void rehash(int newCapacity) {
+    int oldCapacity = table.length;
+    int[] oldTable = table;
+    int[] oldValues = values;
+    byte[] oldState = state;
+
+    int[] newTable = new int[newCapacity];
+    int[] newValues = new int[newCapacity];
+    byte[] newState = new byte[newCapacity];
+
+    this.lowWaterMark = chooseLowWaterMark(newCapacity, this.minLoadFactor);
+    this.highWaterMark = chooseHighWaterMark(newCapacity, this.maxLoadFactor);
+
+    this.table = newTable;
+    this.values = newValues;
+    this.state = newState;
+    // Every reinsertion below goes through put(...), which decrements freeEntries for each FREE slot it
+    // fills; start from a completely free table so entries are not counted twice.
+    this.freeEntries = newCapacity;
+
+    int tmp = this.distinct;
+    this.distinct = Integer.MIN_VALUE; // switch off the high-water-mark check inside put(...)
+    for (int i = oldCapacity; i-- > 0;) {
+      if (oldState[i] == FULL) {
+        put(oldTable[i], oldValues[i]);
+      }
+    }
+    this.distinct = tmp;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/map/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/map/package-info.java b/core/src/main/java/org/apache/mahout/math/map/package-info.java
new file mode 100644
index 0000000..9356f45
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/map/package-info.java
@@ -0,0 +1,250 @@
+/**
+ * <HTML>
+ * <BODY>
+ * Automatically growing and shrinking maps holding objects or primitive
+ * data types such as <tt>int</tt>, <tt>double</tt>, etc. Currently all maps are
+ * based upon hashing.
+ * <h2><a name="Overview"></a>1. Overview</h2>
+ * <p>The map package offers flexible object oriented abstractions modelling automatically
+ * resizing maps. It is designed to be scalable in terms of performance and memory
+ * requirements.</p>
+ * <p>Features include: </p>
+ * <p></p>
+ * <ul>
+ * <li>Maps operating on objects as well as all primitive data types such as <code>int</code>,
+ * <code>double</code>, etc.
+ * </li>
+ * <li>Compact representations</li>
+ * <li>Support for quick access to associations</li>
+ * <li>A number of general purpose map operations</li>
+ * </ul>
+ * <p>File-based I/O can be achieved through the standard Java built-in serialization
+ * mechanism. All classes implement the {@link java.io.Serializable} interface.
+ * However, the toolkit is entirely decoupled from advanced I/O. It provides data
+ * structures and algorithms only.
+ * <p> This toolkit borrows some terminology from the Javasoft <a
+ * href="http://www.javasoft.com/products/jdk/1.2/docs/guide/collections/index.html">
+ * Collections framework</a> written by Josh Bloch and introduced in JDK 1.2.
+ * <h2>2. Introduction</h2>
+ * <p>A map is an associative container that manages a set of (key,value) pairs.
+ * It is useful for implementing a collection of one-to-one mappings. A (key,value)
+ * pair is called an <i>association</i>. A value can be looked up via its key.
+ * Associations can quickly be set, removed and retrieved. They are stored in a
+ * hashing structure based on the hash code of their keys, which is obtained by
+ * using a hash function. </p>
+ * <p> A map can, for example, contain <tt>Name-->Location</tt> associations like
+ * <tt>{("Pete", "Geneva"), ("Steve", "Paris"), ("Robert", "New York")}</tt> used
+ * in address books or <tt>Index-->Value</tt> mappings like <tt>{(0, 100), (3,
+ * 1000), (100000, 70)}</tt> representing sparse lists or matrices. For example
+ * this could mean at index 0 we have a value of 100, at index 3 we have a value
+ * of 1000, at index 100000 we have a value of 70, and at all other indexes we
+ * have a value of, say, zero. Another example is a map of IP addresses to domain
+ * names (DNS). Maps can also be useful to represent<i> multi sets</i>, that is,
+ * sets where elements can occur more than once. For multi sets one would have
+ * <tt>Value-->Frequency</tt> mappings like <tt>{(100, 1), (50, 1000), (101, 3)}</tt>
+ * meaning element 100 occurs 1 time, element 50 occurs 1000 times, element 101
+ * occurs 3 times. Further, maps can also manage <tt>ObjectIdentifier-->Object</tt>
+ * mappings like <tt>{(12, obj1), (7, obj2), (10000, obj3), (9, obj4)}</tt> used
+ * in Object Databases.
+ * <p> A map cannot contain two or more <i>equal</i> keys; a key can map to at most
+ * one value. However, more than one key can map to identical values. For primitive
+ * data types "equality" of keys is defined as identity (operator <tt>==</tt>).
+ * For maps using <tt>Object</tt> keys, the meaning of "equality" can be specified
+ * by the user upon instance construction. It can either be defined to be identity
+ * (operator <tt>==</tt>) or to be given by the method {@link java.lang.Object#equals(Object)}.
+ * Associations of kind <tt>(AnyType,Object)</tt> can be of the form <tt>(AnyKey,null)
+ * </tt>, i.e. values can be <tt>null</tt>.
+ * <p> The classes of this package make no guarantees as to the order of the elements
+ * returned by iterators; in particular, they do not guarantee that the order will
+ * remain constant over time.
+ * <h4>Copying</h4>
+ * <p>Any map can be copied. A copy is <i>equal</i> to the original but entirely
+ * independent of the original. So changes in the copy are not reflected in the
+ * original, and vice-versa.
+ * <h2>3. Package organization</h2>
+ * <p>For most primitive data types and for objects there exists a separate map version.
+ * All versions are just the same, except that they operate on different data types.
+ * Colt includes two kinds of implementations for maps: The two different implementations
+ * are tagged <b>Chained</b> and <b>Open</b>.
+ * Note: Chained is no longer included. Wherever it is mentioned it is of historic interest only.</p>
+ * <ul>
+ * <li><b>Chained</b> uses extendible separate chaining with chains holding unsorted
+ * dynamically linked collision lists.
+ * <li><b>Open</b> uses extendible open addressing with double hashing.
+ * </ul>
+ * <p>Class naming follows the schema <tt>&lt;Implementation&gt;&lt;KeyType&gt;&lt;ValueType&gt;HashMap</tt>.
+ * For example, a {@link org.apache.mahout.math.map.OpenIntDoubleHashMap} holds <tt>(int-->double)</tt>
+ * associations and is implemented with open addressing. A {@link org.apache.mahout.math.map.OpenIntObjectHashMap}
+ * holds <tt>(int-->Object)</tt> associations and is implemented with open addressing.
+ * </p>
+ * <p>The classes for maps of a given (key,value) type are derived from a common
+ * abstract base class tagged <tt>Abstract&lt;KeyType&gt;&lt;ValueType&gt;</tt><tt>Map</tt>.
+ * For example, all maps operating on <tt>(int-->double)</tt> associations are
+ * derived from {@link org.apache.mahout.math.map.AbstractIntDoubleMap}, which in turn is derived
+ * from an abstract base class tying together all maps regardless of association
+ * type, {@link org.apache.mahout.math.set.AbstractSet}. The abstract base classes provide skeleton
+ * implementations for all but few methods. Experimental layouts (such as chaining,
+ * open addressing, extensible hashing, red-black-trees, etc.) can easily be implemented
+ * and inherit a rich set of functionality. Have a look at the javadoc <a href="package-tree.html">tree
+ * view</a> to get the broad picture.</p>
+ * <h2>4. Example usage</h2>
+ * <TABLE>
+ * <TD CLASS="PRE">
+ * <PRE>
+ * int[] keys = {0 , 3 , 100000, 9 };
+ * double[] values = {100.0, 1000.0, 70.0 , 71.0};
+ * AbstractIntDoubleMap map = new OpenIntDoubleHashMap();
+ * // add several associations
+ * for (int i=0; i &lt; keys.length; i++) map.put(keys[i], values[i]);
+ * log.info("map="+map);
+ * log.info("size="+map.size());
+ * log.info(map.containsKey(3));
+ * log.info("get(3)="+map.get(3));
+ * log.info(map.containsKey(4));
+ * log.info("get(4)="+map.get(4));
+ * log.info(map.containsValue(71.0));
+ * log.info("keyOf(71.0)="+map.keyOf(71.0));
+ * // remove one association
+ * map.removeKey(3);
+ * log.info("\nmap="+map);
+ * log.info(map.containsKey(3));
+ * log.info("get(3)="+map.get(3));
+ * log.info(map.containsValue(1000.0));
+ * log.info("keyOf(1000.0)="+map.keyOf(1000.0));
+ * // clear
+ * map.clear();
+ * log.info("\nmap="+map);
+ * log.info("size="+map.size());
+ * </PRE>
+ * </TD>
+ * </TABLE>
+ * yields the following output
+ * <TABLE>
+ * <TD CLASS="PRE">
+ * <PRE>
+ * map=[0->100.0, 3->1000.0, 9->71.0, 100000->70.0]
+ * size=4
+ * true
+ * get(3)=1000.0
+ * false
+ * get(4)=0.0
+ * true
+ * keyOf(71.0)=9
+ * map=[0->100.0, 9->71.0, 100000->70.0]
+ * false
+ * get(3)=0.0
+ * false
+ * keyOf(1000.0)=-2147483648
+ * map=[]
+ * size=0
+ * </PRE>
+ * </TD>
+ * </TABLE>
+ * <h2> 5. Notes </h2>
+ * <p>
+ * Note that implementations are not synchronized.
+ * <p>
+ * Choosing efficient parameters for hash maps is not always easy.
+ * However, since parameters determine efficiency and memory requirements, here is a quick guide how to choose them.
+ * If your use case does not heavily operate on hash maps but uses them just because they provide
+ * convenient functionality, you can safely skip this section.
+ * For those of you who care, read on.
+ * <p>
+ * There are three parameters that can be customized upon map construction: <tt>initialCapacity</tt>,
+ * <tt>minLoadFactor</tt> and <tt>maxLoadFactor</tt>.
+ * The more memory one can afford, the faster a hash map.
+ * The hash map's capacity is the maximum number of associations that can be added without needing to allocate new
+ * internal memory.
+ * A larger capacity means faster adding, searching and removing.
+ * The <tt>initialCapacity</tt> corresponds to the capacity used upon instance construction.
+ * <p>
+ * The <tt>loadFactor</tt> of a hash map measures the degree of "fullness".
+ * It is given by the number of associations (<tt>size()</tt>)
+ * divided by the hash map capacity <tt>(0.0 &lt;= loadFactor &lt;= 1.0)</tt>.
+ * The more associations are added, the larger the loadFactor and the more hash map performance degrades.
+ * Therefore, when the loadFactor exceeds a customizable threshold (<tt>maxLoadFactor</tt>), the hash map is
+ * automatically grown.
+ * In such a way performance degradation can be avoided.
+ * Similarly, when the loadFactor falls below a customizable threshold (<tt>minLoadFactor</tt>), the hash map is
+ * automatically shrunk.
+ * In such a way excessive memory consumption can be avoided.
+ * Automatic resizing (both growing and shrinking) obeys the following invariant:
+ * <p>
+ * <tt>capacity * minLoadFactor &lt;= size() &lt;= capacity * maxLoadFactor</tt>
+ * <p> The term <tt>capacity * minLoadFactor</tt> is called the <i>low water mark</i>,
+ * <tt>capacity * maxLoadFactor</tt> is called the <i>high water mark</i>. In other
+ * words, the number of associations may vary within the water mark constraints.
+ * When it goes out of range, the map is automatically resized and memory consumption
+ * changes proportionally.
+ * <ul>
+ * <li>To tune for memory at the expense of performance, both increase <tt>minLoadFactor</tt> and
+ * <tt>maxLoadFactor</tt>.
+ * <li>To tune for performance at the expense of memory, both decrease <tt>minLoadFactor</tt> and
+ * <tt>maxLoadFactor</tt>.
+ * As a special case, set <tt>minLoadFactor=0</tt> to avoid any automatic shrinking.
+ * </ul>
+ * Resizing large hash maps can be time consuming, <tt>O(size())</tt>, and should be avoided if possible (maintaining
+ * primes is not the reason).
+ * Unnecessary growing operations can be avoided if the number of associations is known before they are added, or can be
+ * estimated.<p>
+ * In such a case good parameters are as follows:
+ * <p>
+ * <i>For chaining:</i>
+ * <br>Set the <tt>initialCapacity = 1.4*expectedSize</tt> or greater.
+ * <br>Set the <tt>maxLoadFactor = 0.8</tt> or greater.
+ * <p>
+ * <i>For open addressing:</i>
+ * <br>Set the <tt>initialCapacity = 2*expectedSize</tt> or greater. Alternatively call <tt>ensureCapacity(...)</tt>.
+ * <br>Set the <tt>maxLoadFactor = 0.5</tt>.
+ * <br>Never set <tt>maxLoadFactor &gt; 0.55</tt>; open addressing exponentially slows down beyond that point.
+ * <p>
+ * In this way the hash map will never need to grow and still stay fast.
+ * It is never a good idea to set <tt>maxLoadFactor &lt; 0.1</tt>,
+ * because the hash map would grow too often.
+ * If it is entirely unknown how many associations the application will use,
+ * the default constructor should be used. The map will grow and shrink as needed.
+ * <p>
+ * <b>Comparison of chaining and open addressing</b>
+ * <p> Chaining is faster than open addressing, when assuming unconstrained memory
+ * consumption. Open addressing is more space efficient than chaining, because
+ * it does not create entry objects but uses primitive arrays which are considerably
+ * smaller. Entry objects consume significant amounts of memory compared to the
+ * information they actually hold. Open addressing also poses no problems to the
+ * garbage collector. In contrast, chaining can create millions of entry objects
+ * which are linked; a nightmare for any garbage collector. In addition, entry
+ * object creation is a bit slow. <br>
+ * Therefore, with the same amount of memory, or even less memory, hash maps with
+ * larger capacity can be maintained under open addressing, which yields smaller
+ * loadFactors, which in turn keeps performance competitive with chaining. In our
+ * benchmarks, using significantly less memory, open addressing usually is not
+ * more than 1.2-1.5 times slower than chaining.
+ * <p><b>Further reading</b>:
+ * <br>Knuth D., The Art of Computer Programming: Searching and Sorting, 3rd ed.
+ * <br>Griswold W., Townsend G., The Design and Implementation of Dynamic Hashing for Sets and Tables in Icon,
+ * Software - Practice and Experience, Vol. 23(4), 351-367 (April 1993).
+ * <br>Larson P., Dynamic hash tables, Comm. of the ACM, 31, (4), 1988.
+ * <p>
+ * <b>Performance:</b>
+ * <p>
+ * Time complexity:
+ * <br>The classes offer <i>expected</i> time complexity <tt>O(1)</tt> (i.e. constant time) for the basic operations
+ * <tt>put</tt>, <tt>get</tt>, <tt>removeKey</tt>, <tt>containsKey</tt> and <tt>size</tt>,
+ * assuming the hash function disperses the elements properly among the buckets.
+ * Otherwise, pathological cases, although highly improbable, can occur, degrading performance to <tt>O(N)</tt> in the
+ * worst case.
+ * Operations <tt>containsValue</tt> and <tt>keyOf</tt> are <tt>O(N)</tt>.
+ * <p>
+ * Memory requirements for <i>open addressing</i>:
+ * <br>worst case: <tt>memory [bytes] = (1/minLoadFactor) * size() * (1 + sizeOf(key) + sizeOf(value))</tt>.
+ * <br>best case: <tt>memory [bytes] = (1/maxLoadFactor) * size() * (1 + sizeOf(key) + sizeOf(value))</tt>.
+ * Where <tt>sizeOf(int) = 4</tt>, <tt>sizeOf(double) = 8</tt>, <tt>sizeOf(Object) = 4</tt>, etc.
+ * Thus, an <tt>OpenIntIntHashMap</tt> with minLoadFactor=0.25 and maxLoadFactor=0.5 and 1000000 associations uses
+ * between 17 MB and 34 MB.
+ * The same map with 1000 associations uses between 17 and 34 KB.
+ * <p>
+ * </BODY>
+ * </HTML>
+ */
+package org.apache.mahout.math.map;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/package-info.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/package-info.java b/core/src/main/java/org/apache/mahout/math/package-info.java
new file mode 100644
index 0000000..de664f0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/package-info.java
@@ -0,0 +1,4 @@
+/**
+ * Core base classes; operations on primitive arrays such as sorting, partitioning and permuting.
+ */
+package org.apache.mahout.math;

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java b/core/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
new file mode 100644
index 0000000..d657fd9
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * This shim allows samplers to be used to initialize vectors.
+ */
+public abstract class AbstractSamplerFunction extends DoubleFunction implements Sampler<Double> {
+ /**
+ * Apply the function to the argument and return the result
+ *
+ * @param ignored Ignored argument
+ * @return A sample from this distribution.
+ */
+ @Override
+ public double apply(double ignored) {
+ return sample();
+ }
+
+ @Override
+ public abstract Double sample();
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java b/core/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
new file mode 100644
index 0000000..8127b92
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import java.util.Random;
+
+/**
+ * Generates samples from a generalized Chinese restaurant process (or Pitman-Yor process).
+ * <p>
+ * Each call to {@link #sample()} returns the index of a "table": either an existing table,
+ * chosen with probability proportional to its occupancy minus the discount, or a new table
+ * whose index is one larger than the largest index returned so far.
+ * <p>
+ * The proportion of distinct values drawn exactly once will asymptotically be equal to the
+ * discount parameter as the total number of draws T increases without bound. The number of
+ * unique values sampled will increase as O(alpha * log T) if discount = 0 or
+ * O(alpha * T^discount) for discount > 0.
+ */
+public final class ChineseRestaurant implements Sampler<Integer> {
+
+  private final double alpha;
+  private final double discount;
+  /** Total number of draws so far (the sum of all entries in {@link #weights}). */
+  private double weight = 0;
+  /** {@code weights.get(j)} is the number of times value j has been returned. */
+  private final DoubleArrayList weights = new DoubleArrayList();
+  private final Random rand = RandomUtils.getRandom();
+
+  /**
+   * Constructs a Dirichlet process sampler. This is done by setting discount = 0.
+   *
+   * @param alpha The strength parameter for the Dirichlet process.
+   */
+  public ChineseRestaurant(double alpha) {
+    this(alpha, 0);
+  }
+
+  /**
+   * Constructs a Pitman-Yor sampler.
+   *
+   * @param alpha The strength parameter that drives the number of unique values as a function of draws.
+   * @param discount The discount parameter that drives the percentage of values that occur once in a large sample.
+   * @throws IllegalArgumentException if alpha is not positive or discount is outside [0, 1]
+   */
+  public ChineseRestaurant(double alpha, double discount) {
+    Preconditions.checkArgument(alpha > 0, "Strength Parameter, alpha must be greater then 0!");
+    Preconditions.checkArgument(discount >= 0 && discount <= 1, "Must be: 0 <= discount <= 1");
+    this.alpha = alpha;
+    this.discount = discount;
+  }
+
+  @Override
+  public Integer sample() {
+    // u is uniform on [0, alpha + weight). Existing value j owns a slice of width
+    // (w_j - discount), so it is selected with probability (w_j - discount) / (alpha + weight).
+    double u = rand.nextDouble() * (alpha + weight);
+    for (int j = 0; j < weights.size(); j++) {
+      double slice = weights.get(j) - discount;
+      if (u < slice) {
+        weights.set(j, weights.get(j) + 1);
+        weight++;
+        return j;
+      }
+      u -= slice;
+    }
+
+    // The slices of the t existing values add up to (weight - discount * t), so the mass that
+    // is left over selects a new value with probability (alpha + discount * t) / (alpha + weight).
+    weights.add(1);
+    weight++;
+    return weights.size() - 1;
+  }
+
+  /**
+   * @return the number of unique values that have been returned.
+   */
+  public int size() {
+    return weights.size();
+  }
+
+  /**
+   * @return the number of draws so far.
+   */
+  public int count() {
+    return (int) weight;
+  }
+
+  /**
+   * @param j Which value to test.
+   * @return The number of times that j has been returned so far (0 if j has never been returned).
+   * @throws IllegalArgumentException if j is negative
+   */
+  public int count(int j) {
+    Preconditions.checkArgument(j >= 0);
+
+    if (j < weights.size()) {
+      return (int) weights.get(j);
+    } else {
+      return 0;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Empirical.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/Empirical.java b/core/src/main/java/org/apache/mahout/math/random/Empirical.java
new file mode 100644
index 0000000..78bfec5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/Empirical.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+/**
+ * Samples from an empirical cumulative distribution.
+ * <p>
+ * The distribution is given as (quantile, value) pairs. A uniform value u is mapped to a sample
+ * by linear interpolation between the pairs, optionally with exponential tails below the first
+ * pair and above the last pair.
+ */
+public final class Empirical extends AbstractSamplerFunction {
+  private final Random gen;
+  private final boolean exceedMinimum;
+  private final boolean exceedMaximum;
+
+  /** Quantiles (cumulative probabilities), strictly increasing, squeezed to leave room for tails. */
+  private final double[] x;
+  /** Values at the corresponding quantiles in {@link #x}, strictly increasing. */
+  private final double[] y;
+  /** Number of (quantile, value) pairs. */
+  private final int n;
+
+  /**
+   * Sets up a sampler for a specified empirical cumulative distribution function. The distribution
+   * can have optional exponential tails on either or both ends, but otherwise does a linear
+   * interpolation between known points.
+   *
+   * @param exceedMinimum Should we generate samples less than the smallest quantile (i.e. generate a left tail)?
+   * @param exceedMaximum Should we generate samples greater than the largest observed quantile (i.e. generate a right
+   *                      tail)?
+   * @param samples       The number of samples observed to get the quantiles.
+   * @param ecdf          Alternating values that represent which percentile (in the [0..1] range)
+   *                      and values. For instance, if you have the min, median and max of 1, 3, 10
+   *                      you should pass 0.0, 1, 0.5, 3, 1.0, 10. Note that the list must include
+   *                      the 0-th (1.0-th) quantile if the left (right) tail is not allowed.
+   * @throws IllegalArgumentException if ecdf has an odd length or fewer than two pairs, samples is
+   *                      less than 3, a required bound is missing, or quantiles/values are not
+   *                      strictly increasing
+   */
+  public Empirical(boolean exceedMinimum, boolean exceedMaximum, int samples, double... ecdf) {
+    Preconditions.checkArgument(ecdf.length % 2 == 0, "ecdf must have an even count of values");
+    Preconditions.checkArgument(samples >= 3, "Sample size must be >= 3");
+    // interpolation and both tails look at two adjacent points, so at least two pairs are needed
+    Preconditions.checkArgument(ecdf.length >= 4, "ecdf must contain at least two (quantile, value) pairs");
+
+    // if we can't exceed the observed bounds, then we have to be given the bounds.
+    Preconditions.checkArgument(exceedMinimum || ecdf[0] == 0);
+    Preconditions.checkArgument(exceedMaximum || ecdf[ecdf.length - 2] == 1);
+
+    gen = RandomUtils.getRandom();
+
+    n = ecdf.length / 2;
+    x = new double[n];
+    y = new double[n];
+
+    for (int i = 0; i < n; i++) {
+      double quantile = ecdf[2 * i];
+      double value = ecdf[2 * i + 1];
+
+      // values have to be strictly increasing
+      Preconditions.checkArgument(i == 0 || value > y[i - 1]);
+
+      // quantiles have to be in [0,1] and be strictly increasing
+      Preconditions.checkArgument(quantile >= 0 && quantile <= 1);
+      Preconditions.checkArgument(i == 0 || quantile > x[i - 1]);
+
+      x[i] = quantile;
+      y[i] = value;
+    }
+
+    // squeeze a bit to allow for unobserved tails
+    double x0 = exceedMinimum ? 0.5 / samples : 0;
+    double x1 = 1 - (exceedMaximum ? 0.5 / samples : 0);
+    for (int i = 0; i < n; i++) {
+      x[i] = x[i] * (x1 - x0) + x0;
+    }
+
+    this.exceedMinimum = exceedMinimum;
+    this.exceedMaximum = exceedMaximum;
+  }
+
+  @Override
+  public Double sample() {
+    return sample(gen.nextDouble());
+  }
+
+  /**
+   * Maps a cumulative probability to a value of this distribution.
+   *
+   * @param u a cumulative probability, normally in [0, 1]
+   * @return the value at quantile u
+   * @throws IllegalArgumentException if u lies above the largest quantile and above 1 with no
+   *         right tail, or is NaN
+   */
+  public double sample(double u) {
+    if (exceedMinimum && u < x[0]) {
+      // generate from left tail; the logarithmic shape matches the first segment's slope at x[0]
+      if (u == 0) {
+        u = 1.0e-16;
+      }
+      return y[0] + Math.log(u / x[0]) * x[0] * (y[1] - y[0]) / (x[1] - x[0]);
+    } else if (exceedMaximum && u > x[n - 1]) {
+      if (u == 1) {
+        u = 1 - 1.0e-16;
+      }
+      // generate from right tail; the logarithmic shape matches the last segment's slope at x[n-1]
+      double dy = y[n - 1] - y[n - 2];
+      double dx = x[n - 1] - x[n - 2];
+      return y[n - 1] - Math.log((1 - u) / (1 - x[n - 1])) * (1 - x[n - 1]) * dy / dx;
+    } else {
+      // linear interpolation
+      for (int i = 1; i < n; i++) {
+        if (x[i] > u) {
+          double dy = y[i] - y[i - 1];
+          double dx = x[i] - x[i - 1];
+          return y[i - 1] + (u - x[i - 1]) * dy / dx;
+        }
+      }
+      // The loop only handles u strictly below some quantile. u can still sit exactly on the last
+      // quantile, or (without a right tail) slightly above it when squeezing rounded x[n-1] below 1.
+      if (u >= x[n - 1] && u <= 1) {
+        return y[n - 1];
+      }
+      throw new IllegalArgumentException(
+          String.format("u = %.3f is not in [%.3f,%.3f]", u, x[0], x[n - 1]));
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/IndianBuffet.java b/core/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
new file mode 100644
index 0000000..27b5d84
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.LineProcessor;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Samples a "document" from an Indian buffet process.
+ * <p>
+ * Each call to {@link #sample()} produces one document: a list of words, where word k is produced
+ * by the converter from the integer index k. The first document contains only new words, as many
+ * as a {@code PoissonSampler(alpha)} draw. Document N (counting from 1) contains each previously
+ * seen word k with probability m_k / N, where m_k is the number of earlier documents that
+ * contained k. It then adds as many new words as a {@code PoissonSampler(alpha / N)} draw.
+ * <p>
+ * See http://mlg.eng.cam.ac.uk/zoubin/talks/turin09.pdf for details
+ */
+public final class IndianBuffet<T> implements Sampler<List<T>> {
+  /** {@code count.get(k)} is the number of documents generated so far that contain word k. */
+  private final List<Integer> count = Lists.newArrayList();
+  /** Number of documents generated so far. */
+  private int documents = 0;
+  private final double alpha;
+  private final WordFunction<T> converter;
+  private final Random gen;
+
+  public IndianBuffet(double alpha, WordFunction<T> converter) {
+    this.alpha = alpha;
+    this.converter = converter;
+    gen = RandomUtils.getRandom();
+  }
+
+  /** Creates a sampler whose words are the integer indices themselves. */
+  public static IndianBuffet<Integer> createIntegerDocumentSampler(double alpha) {
+    return new IndianBuffet<>(alpha, new IdentityConverter());
+  }
+
+  /** Creates a sampler whose words come from {@link WordConverter}. */
+  public static IndianBuffet<String> createTextDocumentSampler(double alpha) {
+    return new IndianBuffet<>(alpha, new WordConverter());
+  }
+
+  @Override
+  public List<T> sample() {
+    List<T> r = Lists.newArrayList();
+    if (documents == 0) {
+      // the first document consists entirely of new words
+      double n = new PoissonSampler(alpha).sample();
+      for (int i = 0; i < n; i++) {
+        r.add(converter.convert(i));
+        count.add(1);
+      }
+      documents++;
+    } else {
+      documents++;
+      // Revisit each existing word k with probability count(k) / documents. An index loop is used
+      // because entries of count are updated while the list is being walked.
+      int existing = count.size();
+      for (int i = 0; i < existing; i++) {
+        int cnt = count.get(i);
+        if (gen.nextDouble() < (double) cnt / documents) {
+          r.add(converter.convert(i));
+          count.set(i, cnt + 1);
+        }
+      }
+      // then add brand-new words, numbered after all existing ones
+      int newItems = new PoissonSampler(alpha / documents).sample().intValue();
+      for (int j = 0; j < newItems; j++) {
+        r.add(converter.convert(existing + j));
+        count.add(1);
+      }
+    }
+    return r;
+  }
+
+  /** Maps a word index to the word emitted in a document. */
+  private interface WordFunction<T> {
+    T convert(int i);
+  }
+
+  /**
+   * Just converts to an integer.
+   */
+  public static class IdentityConverter implements WordFunction<Integer> {
+    @Override
+    public Integer convert(int i) {
+      return i;
+    }
+  }
+
+  /**
+   * Converts to a string.
+   */
+  public static class StringConverter implements WordFunction<String> {
+    @Override
+    public String convert(int i) {
+      return String.valueOf(i);
+    }
+  }
+
+  /**
+   * Converts to one of a list of common English words for reasonably small integers and converts
+   * to a token like w_92463 for big integers.
+   */
+  public static final class WordConverter implements WordFunction<String> {
+    private final Splitter onSpace = Splitter.on(CharMatcher.WHITESPACE).omitEmptyStrings().trimResults();
+    private final List<String> words;
+
+    /**
+     * Loads the word list from the {@code words.txt} classpath resource, splitting every line on
+     * whitespace.
+     *
+     * @throws ImpossibleException if reading the resource fails with an IOException
+     */
+    public WordConverter() {
+      try {
+        words = Resources.readLines(Resources.getResource("words.txt"), Charsets.UTF_8,
+            new LineProcessor<List<String>>() {
+              private final List<String> theWords = Lists.newArrayList();
+
+              @Override
+              public boolean processLine(String line) {
+                Iterables.addAll(theWords, onSpace.split(line));
+                return true;
+              }
+
+              @Override
+              public List<String> getResult() {
+                return theWords;
+              }
+            });
+      } catch (IOException e) {
+        throw new ImpossibleException(e);
+      }
+    }
+
+    @Override
+    public String convert(int i) {
+      if (i < words.size()) {
+        return words.get(i);
+      } else {
+        return "w_" + i;
+      }
+    }
+  }
+
+  /** Unchecked wrapper for an I/O failure while loading the bundled word list. */
+  public static class ImpossibleException extends RuntimeException {
+    public ImpossibleException(Throwable e) {
+      super(e);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Missing.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/Missing.java b/core/src/main/java/org/apache/mahout/math/random/Missing.java
new file mode 100644
index 0000000..8141a71
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/Missing.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Models data with missing values. Note that all variables with the same fraction of missing
+ * values will have the same sequence of missing values. Similarly, if two variables have
+ * missing probabilities of p1 > p2, then all of the p2 missing values will also be missing for
+ * p1.
+ */
+public final class Missing<T> implements Sampler<T> {
+ private final Random gen;
+ private final double p;
+ private final Sampler<T> delegate;
+ private final T missingMarker;
+
+ public Missing(int seed, double p, Sampler<T> delegate, T missingMarker) {
+ this.p = p;
+ this.delegate = delegate;
+ this.missingMarker = missingMarker;
+ gen = RandomUtils.getRandom(seed);
+ }
+
+ public Missing(double p, Sampler<T> delegate, T missingMarker) {
+ this(1, p, delegate, missingMarker);
+ }
+
+ public Missing(double p, Sampler<T> delegate) {
+ this(1, p, delegate, null);
+ }
+
+ @Override
+ public T sample() {
+ if (gen.nextDouble() >= p) {
+ return delegate.sample();
+ } else {
+ return missingMarker;
+ }
+ }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/MultiNormal.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/random/MultiNormal.java b/core/src/main/java/org/apache/mahout/math/random/MultiNormal.java
new file mode 100644
index 0000000..748d4e8
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/random/MultiNormal.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DiagonalMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+import java.util.Random;
+
+/**
+ * Samples from a multi-variate normal distribution.
+ * <p>
+ * This is done by sampling from several independent unit normal distributions to get a vector u.
+ * The sample value that is returned is then A u + m where A is derived from the covariance matrix
+ * and m is the mean of the result.
+ * <p>
+ * The covariance of A u + m is A A'. So if \Sigma is the desired covariance matrix, you can use
+ * any value of A such that A A' = \Sigma. If \Sigma is positive definite, its Cholesky
+ * decomposition \Sigma = L L' gives A = L. Slightly more expensive is to use the SVD
+ * U S V' = \Sigma and then set A = U \sqrt{S}.
+ * <p>
+ * Useful special cases occur when \Sigma is diagonal so that A = \sqrt(\Sigma) or where \Sigma = r I.
+ * <p>
+ * Another special case is where m = 0.
+ */
+public class MultiNormal implements Sampler<Vector> {
+  private final Random gen;
+  private final int dimension;
+  /** The matrix A applied to the unit-normal vector, or null to use u unchanged. */
+  private final Matrix scale;
+  /** The mean m added to each sample, or null for a zero mean. */
+  private final Vector mean;
+
+  /**
+   * Constructs a zero-mean sampler with diagonal scale matrix.
+   *
+   * @param diagonal The diagonal elements of the scale matrix.
+   */
+  public MultiNormal(Vector diagonal) {
+    this(new DiagonalMatrix(diagonal), null);
+  }
+
+  /**
+   * Constructs a sampler with diagonal scale matrix and (potentially)
+   * non-zero mean.
+   *
+   * @param diagonal The scale matrix's principal diagonal.
+   * @param mean     The desired mean. Set to null if zero mean is desired.
+   */
+  public MultiNormal(Vector diagonal, Vector mean) {
+    this(new DiagonalMatrix(diagonal), mean);
+  }
+
+  /**
+   * Constructs a sampler with non-trivial scale matrix and mean.
+   *
+   * @param a    The scale matrix A; its column count determines the length of u.
+   * @param mean The desired mean. Set to null if zero mean is desired.
+   */
+  public MultiNormal(Matrix a, Vector mean) {
+    this(a, mean, a.columnSize());
+  }
+
+  /**
+   * Constructs a sampler that returns vectors of independent unit normal values
+   * (no scale matrix, zero mean).
+   *
+   * @param dimension The length of the sampled vectors.
+   */
+  public MultiNormal(int dimension) {
+    this(null, null, dimension);
+  }
+
+  /**
+   * Constructs a sampler whose scale matrix is {@code new DiagonalMatrix(radius, mean.size())}.
+   *
+   * @param radius The value placed on the diagonal of the scale matrix.
+   * @param mean   The desired mean; must not be null because it determines the dimension.
+   */
+  public MultiNormal(double radius, Vector mean) {
+    this(new DiagonalMatrix(radius, mean.size()), mean);
+  }
+
+  private MultiNormal(Matrix scale, Vector mean, int dimension) {
+    gen = RandomUtils.getRandom();
+    this.dimension = dimension;
+    this.scale = scale;
+    this.mean = mean;
+  }
+
+  @Override
+  public Vector sample() {
+    // u: a vector of independent unit normal values
+    Vector u = new DenseVector(dimension).assign(
+        new DoubleFunction() {
+          @Override
+          public double apply(double ignored) {
+            return gen.nextGaussian();
+          }
+        }
+    );
+    Vector scaled = scale != null ? scale.times(u) : u;
+    return mean != null ? scaled.plus(mean) : scaled;
+  }
+
+  /**
+   * Returns the mean added to each sample.
+   *
+   * @return the mean vector, or {@code null} if no mean was supplied (zero mean)
+   */
+  public Vector getMean() {
+    return mean;
+  }
+
+  /**
+   * Despite its name, this returns the <em>mean</em> vector (or {@code null}), not the scale
+   * matrix. That behavior is kept for backward compatibility.
+   *
+   * @return the mean vector, or {@code null} if no mean was supplied (zero mean)
+   * @deprecated misnamed; use {@link #getMean()} instead.
+   */
+  @Deprecated
+  public Vector getScale() {
+    return mean;
+  }
+}

Loading...