[03/24] mahout git commit: MAHOUT-2034 Split MR and New Examples into seperate modules
2018-06-27 13:14:29 UTC
diff --git a/examples/src/main/resources/cf-data-purchase.txt b/examples/src/main/resources/cf-data-purchase.txt
deleted file mode 100644
index d87c031..0000000
--- a/examples/src/main/resources/cf-data-purchase.txt
+++ /dev/null
@@ -1,7 +0,0 @@

diff --git a/examples/src/main/resources/cf-data-view.txt b/examples/src/main/resources/cf-data-view.txt
deleted file mode 100644
index 09ad9b6..0000000
--- a/examples/src/main/resources/cf-data-view.txt
+++ /dev/null
@@ -1,12 +0,0 @@

diff --git a/examples/src/main/resources/donut-test.csv b/examples/src/main/resources/donut-test.csv
deleted file mode 100644
index 46ea564..0000000
--- a/examples/src/main/resources/donut-test.csv
+++ /dev/null
@@ -1,41 +0,0 @@

diff --git a/examples/src/main/resources/donut.csv b/examples/src/main/resources/donut.csv
deleted file mode 100644
index 33ba3b7..0000000
--- a/examples/src/main/resources/donut.csv
+++ /dev/null
@@ -1,41 +0,0 @@

diff --git a/examples/src/main/resources/test-data.csv b/examples/src/main/resources/test-data.csv
deleted file mode 100644
index ab683cd..0000000
--- a/examples/src/main/resources/test-data.csv
+++ /dev/null
@@ -1,61 +0,0 @@

diff --git a/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java b/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java
deleted file mode 100644
index e849011..0000000
--- a/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java
+++ /dev/null
@@ -1,43 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.mahout.common.MahoutTestCase;
-import org.junit.Test;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.Collections;
-public class LogisticModelParametersTest extends MahoutTestCase {
- @Test
- public void serializationWithoutCsv() throws IOException {
- LogisticModelParameters params = new LogisticModelParameters();
- params.setTargetVariable("foo");
- params.setTypeMap(Collections.<String, String>emptyMap());
- params.setTargetCategories(Arrays.asList("foo", "bar"));
- params.setNumFeatures(1);
- params.createRegression();
- //MAHOUT-1196 should work without "csv" being set
- params.saveTo(new ByteArrayOutputStream());
- }

diff --git a/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java b/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java
deleted file mode 100644
index c8e4879..0000000
--- a/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java
+++ /dev/null
@@ -1,40 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.mahout.examples.MahoutTestCase;
-import org.apache.mahout.math.DenseVector;
-import org.junit.Test;
-public class ModelDissectorTest extends MahoutTestCase {
- @Test
- public void testCategoryOrdering() {
- ModelDissector.Weight w = new ModelDissector.Weight("a", new DenseVector(new double[]{-2, -5, 5, 2, 4, 1, 0}), 4);
- assertEquals(1, w.getCategory(0), 0);
- assertEquals(-5, w.getWeight(0), 0);
- assertEquals(2, w.getCategory(1), 0);
- assertEquals(5, w.getWeight(1), 0);
- assertEquals(4, w.getCategory(2), 0);
- assertEquals(4, w.getWeight(2), 0);
- assertEquals(0, w.getCategory(3), 0);
- assertEquals(-2, w.getWeight(3), 0);
- }

diff --git a/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java b/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
deleted file mode 100644
index 4cde692..0000000
--- a/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
+++ /dev/null
@@ -1,167 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Charsets;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Sets;
-import com.google.common.io.Resources;
-import org.apache.mahout.classifier.AbstractVectorClassifier;
-import org.apache.mahout.examples.MahoutTestCase;
-import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
-import org.junit.Test;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.InputStream;
-import java.io.PrintWriter;
-import java.io.StringWriter;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.TreeSet;
-public class TrainLogisticTest extends MahoutTestCase {
- @Test
- public void example131() throws Exception {
- String outputFile = getTestTempFile("model").getAbsolutePath();
- StringWriter sw = new StringWriter();
- PrintWriter pw = new PrintWriter(sw, true);
- TrainLogistic.mainToOutput(new String[]{
- "--input", "donut.csv",
- "--output", outputFile,
- "--target", "color", "--categories", "2",
- "--predictors", "x", "y",
- "--types", "numeric",
- "--features", "20",
- "--passes", "100",
- "--rate", "50"
- }, pw);
- String trainOut = sw.toString();
- assertTrue(trainOut.contains("x -0.7"));
- assertTrue(trainOut.contains("y -0.4"));
- LogisticModelParameters lmp = TrainLogistic.getParameters();
- assertEquals(1.0e-4, lmp.getLambda(), 1.0e-9);
- assertEquals(20, lmp.getNumFeatures());
- assertTrue(lmp.useBias());
- assertEquals("color", lmp.getTargetVariable());
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- assertEquals("[1, 2]", new TreeSet<>(csv.getTargetCategories()).toString());
- assertEquals("[Intercept Term, x, y]", Sets.newTreeSet(csv.getPredictors()).toString());
- // verify model by building dissector
- AbstractVectorClassifier model = TrainLogistic.getModel();
- List<String> data = Resources.readLines(Resources.getResource("donut.csv"), Charsets.UTF_8);
- Map<String, Double> expectedValues = ImmutableMap.of("x", -0.7, "y", -0.43, "Intercept Term", -0.15);
- verifyModel(lmp, csv, data, model, expectedValues);
- // test saved model
- try (InputStream in = new FileInputStream(new File(outputFile))){
- LogisticModelParameters lmpOut = LogisticModelParameters.loadFrom(in);
- CsvRecordFactory csvOut = lmpOut.getCsvRecordFactory();
- csvOut.firstLine(data.get(0));
- OnlineLogisticRegression lrOut = lmpOut.createRegression();
- verifyModel(lmpOut, csvOut, data, lrOut, expectedValues);
- }
- sw = new StringWriter();
- pw = new PrintWriter(sw, true);
- RunLogistic.mainToOutput(new String[]{
- "--input", "donut.csv",
- "--model", outputFile,
- "--auc",
- "--confusion"
- }, pw);
- trainOut = sw.toString();
- assertTrue(trainOut.contains("AUC = 0.57"));
- assertTrue(trainOut.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
- }
- @Test
- public void example132() throws Exception {
- String outputFile = getTestTempFile("model").getAbsolutePath();
- StringWriter sw = new StringWriter();
- PrintWriter pw = new PrintWriter(sw, true);
- TrainLogistic.mainToOutput(new String[]{
- "--input", "donut.csv",
- "--output", outputFile,
- "--target", "color",
- "--categories", "2",
- "--predictors", "x", "y", "a", "b", "c",
- "--types", "numeric",
- "--features", "20",
- "--passes", "100",
- "--rate", "50"
- }, pw);
- String trainOut = sw.toString();
- assertTrue(trainOut.contains("a 0."));
- assertTrue(trainOut.contains("b -1."));
- assertTrue(trainOut.contains("c -25."));
- sw = new StringWriter();
- pw = new PrintWriter(sw, true);
- RunLogistic.mainToOutput(new String[]{
- "--input", "donut.csv",
- "--model", outputFile,
- "--auc",
- "--confusion"
- }, pw);
- trainOut = sw.toString();
- assertTrue(trainOut.contains("AUC = 1.00"));
- sw = new StringWriter();
- pw = new PrintWriter(sw, true);
- RunLogistic.mainToOutput(new String[]{
- "--input", "donut-test.csv",
- "--model", outputFile,
- "--auc",
- "--confusion"
- }, pw);
- trainOut = sw.toString();
- assertTrue(trainOut.contains("AUC = 0.9"));
- }
- private static void verifyModel(LogisticModelParameters lmp,
- RecordFactory csv,
- List<String> data,
- AbstractVectorClassifier model,
- Map<String, Double> expectedValues) {
- ModelDissector md = new ModelDissector();
- for (String line : data.subList(1, data.size())) {
- Vector v = new DenseVector(lmp.getNumFeatures());
- csv.getTraceDictionary().clear();
- csv.processLine(line, v);
- md.update(v, csv.getTraceDictionary(), model);
- }
- // check right variables are present
- List<ModelDissector.Weight> weights = md.summary(10);
- Set<String> expected = Sets.newHashSet(expectedValues.keySet());
- for (ModelDissector.Weight weight : weights) {
- assertTrue(expected.remove(weight.getFeature()));
- assertEquals(expectedValues.get(weight.getFeature()), weight.getWeight(), 0.1);
- }
- assertEquals(0, expected.size());
- }

diff --git a/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java b/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java
deleted file mode 100644
index 6e43b97..0000000
--- a/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java
+++ /dev/null
@@ -1,75 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-import org.apache.mahout.common.MahoutTestCase;
-import org.junit.Before;
-import org.junit.Test;
-import java.io.IOException;
-public class ClustersFilterTest extends MahoutTestCase {
- private Configuration configuration;
- private Path output;
- @Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- configuration = getConfiguration();
- output = getTestTempDirPath();
- }
- @Test
- public void testAcceptNotFinal() throws Exception {
- Path path0 = new Path(output, "clusters-0");
- Path path1 = new Path(output, "clusters-1");
- path0.getFileSystem(configuration).createNewFile(path0);
- path1.getFileSystem(configuration).createNewFile(path1);
- PathFilter clustersFilter = new ClustersFilter();
- assertTrue(clustersFilter.accept(path0));
- assertTrue(clustersFilter.accept(path1));
- }
- @Test
- public void testAcceptFinalPath() throws IOException {
- Path path0 = new Path(output, "clusters-0");
- Path path1 = new Path(output, "clusters-1");
- Path path2 = new Path(output, "clusters-2");
- Path path3Final = new Path(output, "clusters-3-final");
- path0.getFileSystem(configuration).createNewFile(path0);
- path1.getFileSystem(configuration).createNewFile(path1);
- path2.getFileSystem(configuration).createNewFile(path2);
- path3Final.getFileSystem(configuration).createNewFile(path3Final);
- PathFilter clustersFilter = new ClustersFilter();
- assertTrue(clustersFilter.accept(path0));
- assertTrue(clustersFilter.accept(path1));
- assertTrue(clustersFilter.accept(path2));
- assertTrue(clustersFilter.accept(path3Final));
- }

diff --git a/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java b/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java
deleted file mode 100644
index 4d81e3f..0000000
--- a/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java
+++ /dev/null
@@ -1,30 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.examples;
- * This class should not exist. It's here to work around some bizarre problem in Maven
- * dependency management wherein it can see methods in {@link org.apache.mahout.common.MahoutTestCase}
- * but not constants. Duplicated here to make it jive.
- */
-public abstract class MahoutTestCase extends org.apache.mahout.common.MahoutTestCase {
- /** "Close enough" value for floating-point comparisons. */
- public static final double EPSILON = 0.000001;

diff --git a/examples/src/test/resources/country.txt b/examples/src/test/resources/country.txt
deleted file mode 100644
index 6a22091..0000000
--- a/examples/src/test/resources/country.txt
+++ /dev/null
@@ -1,229 +0,0 @@
-American Samoa
-Antigua and Barbuda
-Bosnia and Herzegovina
-Bouvet Island
-British Indian Ocean Territory
-Brunei Darussalam
-Burkina Faso
-Cape Verde
-Cayman Islands
-Central African Republic
-Christmas Island
-Cocos Islands
-Cook Islands
-Costa Rica
-Côte d'Ivoire
-Czech Republic
-Dominican Republic
-El Salvador
-Equatorial Guinea
-Falkland Islands
-Faroe Islands
-French Guiana
-French Polynesia
-French Southern Territories
-Hong Kong
-Isle of Man
-Marshall Islands
-Netherlands Antilles
-New Caledonia
-New Zealand
-Norfolk Island
-Northern Mariana Islands
-Palestinian Territory
-Papua New Guinea
-Puerto Rico
-Russian Federation
-Saint Barthélemy
-Saint Helena
-Saint Kitts and Nevis
-Saint Lucia
-Saint Martin
-Saint Pierre and Miquelon
-Saint Vincent and the Grenadines
-San Marino
-Sao Tome and Principe
-Saudi Arabia
-Sierra Leone
-Solomon Islands
-South Africa
-South Georgia and the South Sandwich Islands
-Sri Lanka
-Svalbard and Jan Mayen
-Syrian Arab Republic
-Trinidad and Tobago
-Turks and Caicos Islands
-United Arab Emirates
-United Kingdom
-United States
-United States Minor Outlying Islands
-Virgin Islands
-Wallis and Futuna

diff --git a/examples/src/test/resources/country10.txt b/examples/src/test/resources/country10.txt
deleted file mode 100644
index 97a63e1..0000000
--- a/examples/src/test/resources/country10.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-United Kingdom

diff --git a/examples/src/test/resources/country2.txt b/examples/src/test/resources/country2.txt
deleted file mode 100644
index f4b4f61..0000000
--- a/examples/src/test/resources/country2.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-United States
-United Kingdom

diff --git a/examples/src/test/resources/subjects.txt b/examples/src/test/resources/subjects.txt
deleted file mode 100644
index f52ae33..0000000
--- a/examples/src/test/resources/subjects.txt
+++ /dev/null
@@ -1,2 +0,0 @@

diff --git a/examples/src/test/resources/wdbc.infos b/examples/src/test/resources/wdbc.infos
deleted file mode 100644
index 94a63d6..0000000
--- a/examples/src/test/resources/wdbc.infos
+++ /dev/null
@@ -1,32 +0,0 @@
-NUMERICAL, 6.9, 28.2
-NUMERICAL, 9.7, 39.3
-NUMERICAL, 43.7, 188.5
-NUMERICAL, 143.5, 2501.0
-NUMERICAL, 0.0, 0.2
-NUMERICAL, 0.0, 0.4
-NUMERICAL, 0.0, 0.5
-NUMERICAL, 0.0, 0.3
-NUMERICAL, 0.1, 0.4
-NUMERICAL, 0.0, 0.1
-NUMERICAL, 0.1, 2.9
-NUMERICAL, 0.3, 4.9
-NUMERICAL, 0.7, 22.0
-NUMERICAL, 6.8, 542.3
-NUMERICAL, 0.0, 0.1
-NUMERICAL, 0.0, 0.2
-NUMERICAL, 0.0, 0.4
-NUMERICAL, 0.0, 0.1
-NUMERICAL, 0.0, 0.1
-NUMERICAL, 0.0, 0.1
-NUMERICAL, 7.9, 36.1
-NUMERICAL, 12.0, 49.6
-NUMERICAL, 50.4, 251.2
-NUMERICAL, 185.2, 4254.0
-NUMERICAL, 0.0, 0.3
-NUMERICAL, 0.0, 1.1
-NUMERICAL, 0.0, 1.3
-NUMERICAL, 0.0, 0.3
-NUMERICAL, 0.1, 0.7
-NUMERICAL, 0.0, 0.3
2018-06-27 13:14:30 UTC
diff --git a/examples/src/main/resources/bank-full.csv b/examples/src/main/resources/bank-full.csv
deleted file mode 100644
index d7a2ede..0000000
--- a/examples/src/main/resources/bank-full.csv
+++ /dev/null
@@ -1,45212 +0,0 @@

2018-06-27 13:14:32 UTC
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
deleted file mode 100644
index 632b32c..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
+++ /dev/null
@@ -1,154 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import com.google.common.collect.Ordering;
-import org.apache.mahout.classifier.NewsgroupHelper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
- * Reads and trains an adaptive logistic regression model on the 20 newsgroups data.
- * The first command line argument gives the path of the directory holding the training
- * data. The optional second argument, leakType, defines which classes of features to use.
- * Importantly, leakType controls whether a synthetic date is injected into the data as
- * a target leak and if so, how.
- * <p/>
- * The value of leakType % 3 determines whether the target leak is injected according to
- * the following table:
- * <p/>
- * <table>
- * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
- * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and
- * is a perfect target leak since each newsgroup is given a different month</td></tr>
- * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format. The day varies
- * and thus there are more leak symbols that need to be learned. Ultimately this is just
- * as big a leak as case 1.</td></tr>
- * </table>
- * <p/>
- * Leaktype also determines what other text will be indexed. If leakType is greater
- * than or equal to 6, then neither headers nor text body will be used for features and the leak is the only
- * source of data. If leakType is greater than or equal to 3, then subject words will be used as features.
- * If leakType is less than 3, then both subject and body text will be used as features.
- * <p/>
- * A leakType of 0 gives no leak and all textual features.
- * <p/>
- * See the following table for a summary of commonly used values for leakType
- * <p/>
- * <table>
- * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
- * <tr><td colspan=4><hr></td></tr>
- * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
- * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
- * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
- * <tr><td colspan=4><hr></td></tr>
- * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
- * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
- * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
- * <tr><td colspan=4><hr></td></tr>
- * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
- * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
- * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
- * <tr><td colspan=4><hr></td></tr>
- * </table>
- */
-public final class TrainNewsGroups {
- private TrainNewsGroups() {
- }
- public static void main(String[] args) throws IOException {
- File base = new File(args[0]);
- Multiset<String> overallCounts = HashMultiset.create();
- int leakType = 0;
- if (args.length > 1) {
- leakType = Integer.parseInt(args[1]);
- }
- Dictionary newsGroups = new Dictionary();
- NewsgroupHelper helper = new NewsgroupHelper();
- helper.getEncoder().setProbes(2);
- AdaptiveLogisticRegression learningAlgorithm =
- new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
- learningAlgorithm.setInterval(800);
- learningAlgorithm.setAveragingWindow(500);
- List<File> files = new ArrayList<>();
- for (File newsgroup : base.listFiles()) {
- if (newsgroup.isDirectory()) {
- newsGroups.intern(newsgroup.getName());
- files.addAll(Arrays.asList(newsgroup.listFiles()));
- }
- }
- Collections.shuffle(files);
- System.out.println(files.size() + " training files");
- SGDInfo info = new SGDInfo();
- int k = 0;
- for (File file : files) {
- String ng = file.getParentFile().getName();
- int actual = newsGroups.intern(ng);
- Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
- learningAlgorithm.train(actual, v);
- k++;
- State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
- SGDHelper.analyzeState(info, leakType, k, best);
- }
- learningAlgorithm.close();
- SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts);
- System.out.println("exiting main");
- File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model");
- ModelSerializer.writeBinary(modelFile.getAbsolutePath(),
- learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
- List<Integer> counts = new ArrayList<>();
- System.out.println("Word counts");
- for (String count : overallCounts.elementSet()) {
- counts.add(overallCounts.count(count));
- }
- Collections.sort(counts, Ordering.natural().reverse());
- k = 0;
- for (Integer count : counts) {
- System.out.println(k + "\t" + count);
- k++;
- if (k > 1000) {
- break;
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
deleted file mode 100644
index 7a74289..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
+++ /dev/null
@@ -1,218 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.Locale;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.ConfusionMatrix;
-import org.apache.mahout.classifier.evaluation.Auc;
-import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.SequentialAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.stats.OnlineSummarizer;
- * Auc and averageLikelihood are always shown if possible, if the number of target value is more than 2,
- * then Auc and entropy matirx are not shown regardless the value of showAuc and showEntropy
- * the user passes, because the current implementation does not support them on two value targets.
- * */
-public final class ValidateAdaptiveLogistic {
- private static String inputFile;
- private static String modelFile;
- private static String defaultCategory;
- private static boolean showAuc;
- private static boolean showScores;
- private static boolean showConfusion;
- private ValidateAdaptiveLogistic() {
- }
- public static void main(String[] args) throws IOException {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- static void mainToOutput(String[] args, PrintWriter output) throws IOException {
- if (parseArgs(args)) {
- if (!showAuc && !showConfusion && !showScores) {
- showAuc = true;
- showConfusion = true;
- }
- Auc collector = null;
- AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
- .loadFromFile(new File(modelFile));
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
- if (lmp.getTargetCategories().size() <= 2) {
- collector = new Auc();
- }
- OnlineSummarizer slh = new OnlineSummarizer();
- ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory);
- State<Wrapper, CrossFoldLearner> best = lr.getBest();
- if (best == null) {
- output.println("AdaptiveLogisticRegression has not be trained probably.");
- return;
- }
- CrossFoldLearner learner = best.getPayload().getLearner();
- BufferedReader in = TrainLogistic.open(inputFile);
- String line = in.readLine();
- csv.firstLine(line);
- line = in.readLine();
- if (showScores) {
- output.println("\"target\", \"model-output\", \"log-likelihood\", \"average-likelihood\"");
- }
- while (line != null) {
- Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
- //TODO: How to avoid extra target values not shown in the training process.
- int target = csv.processLine(line, v);
- double likelihood = learner.logLikelihood(target, v);
- double score = learner.classifyFull(v).maxValue();
- slh.add(likelihood);
- cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));
- if (showScores) {
- output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f%n", target,
- score, learner.logLikelihood(target, v), slh.getMean());
- }
- if (collector != null) {
- collector.add(target, score);
- }
- line = in.readLine();
- }
- output.printf(Locale.ENGLISH,"\nLog-likelihood:");
- output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f%n",
- slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());
- if (collector != null) {
- output.printf(Locale.ENGLISH, "%nAUC = %.2f%n", collector.auc());
- }
- if (showConfusion) {
- output.printf(Locale.ENGLISH, "%n%s%n%n", cm.toString());
- if (collector != null) {
- Matrix m = collector.entropy();
- output.printf(Locale.ENGLISH,
- "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0),
- m.get(1, 0), m.get(0, 1), m.get(1, 1));
- }
- }
- }
- }
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help")
- .withDescription("print this list").create();
- Option quiet = builder.withLongName("quiet")
- .withDescription("be extra quiet").create();
- Option auc = builder.withLongName("auc").withDescription("print AUC")
- .create();
- Option confusion = builder.withLongName("confusion")
- .withDescription("print confusion matrix").create();
- Option scores = builder.withLongName("scores")
- .withDescription("print scores").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder
- .withLongName("input")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("input").withMaximum(1)
- .create())
- .withDescription("where to get validate data").create();
- Option modelFileOption = builder
- .withLongName("model")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("model").withMaximum(1)
- .create())
- .withDescription("where to get the trained model").create();
- Option defaultCagetoryOption = builder
- .withLongName("defaultCategory")
- .withRequired(false)
- .withArgument(
- argumentBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown")
- .create())
- .withDescription("the default category value to use").create();
- Group normalArgs = new GroupBuilder().withOption(help)
- .withOption(quiet).withOption(auc).withOption(scores)
- .withOption(confusion).withOption(inputFileOption)
- .withOption(modelFileOption).withOption(defaultCagetoryOption).create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- inputFile = getStringArgument(cmdLine, inputFileOption);
- modelFile = getStringArgument(cmdLine, modelFileOption);
- defaultCategory = getStringArgument(cmdLine, defaultCagetoryOption);
- showAuc = getBooleanArgument(cmdLine, auc);
- showScores = getBooleanArgument(cmdLine, scores);
- showConfusion = getBooleanArgument(cmdLine, confusion);
- return true;
- }
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
- return cmdLine.hasOption(option);
- }
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
- return (String) cmdLine.getValue(inputFile);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
deleted file mode 100644
index ab3c861..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
+++ /dev/null
@@ -1,70 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd.bankmarketing;
-import com.google.common.collect.Lists;
-import org.apache.mahout.classifier.evaluation.Auc;
-import org.apache.mahout.classifier.sgd.L1;
-import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
-import java.util.Collections;
-import java.util.List;
- * Uses the SGD classifier on the 'Bank marketing' dataset from UCI.
- *
- * See http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
- *
- * Learn when people accept or reject an offer from the bank via telephone based on income, age, education and more.
- */
-public class BankMarketingClassificationMain {
- public static final int NUM_CATEGORIES = 2;
- public static void main(String[] args) throws Exception {
- List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));
- double heldOutPercentage = 0.10;
- for (int run = 0; run < 20; run++) {
- Collections.shuffle(calls);
- int cutoff = (int) (heldOutPercentage * calls.size());
- List<TelephoneCall> test = calls.subList(0, cutoff);
- List<TelephoneCall> train = calls.subList(cutoff, calls.size());
- OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
- .learningRate(1)
- .alpha(1)
- .lambda(0.000001)
- .stepOffset(10000)
- .decayExponent(0.2);
- for (int pass = 0; pass < 20; pass++) {
- for (TelephoneCall observation : train) {
- lr.train(observation.getTarget(), observation.asVector());
- }
- if (pass % 5 == 0) {
- Auc eval = new Auc(0.5);
- for (TelephoneCall testCall : test) {
- eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
- }
- System.out.printf("%d, %.4f, %.4f\n", pass, lr.currentLearningRate(), eval.auc());
- }
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
deleted file mode 100644
index 728ec20..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
+++ /dev/null
@@ -1,104 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd.bankmarketing;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.Map;
-public class TelephoneCall {
- public static final int FEATURES = 100;
- private static final ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept");
- private static final FeatureVectorEncoder featureEncoder = new StaticWordValueEncoder("feature");
- private RandomAccessSparseVector vector;
- private Map<String, String> fields = new LinkedHashMap<>();
- public TelephoneCall(Iterable<String> fieldNames, Iterable<String> values) {
- vector = new RandomAccessSparseVector(FEATURES);
- Iterator<String> value = values.iterator();
- interceptEncoder.addToVector("1", vector);
- for (String name : fieldNames) {
- String fieldValue = value.next();
- fields.put(name, fieldValue);
- switch (name) {
- case "age": {
- double v = Double.parseDouble(fieldValue);
- featureEncoder.addToVector(name, Math.log(v), vector);
- break;
- }
- case "balance": {
- double v;
- v = Double.parseDouble(fieldValue);
- if (v < -2000) {
- v = -2000;
- }
- featureEncoder.addToVector(name, Math.log(v + 2001) - 8, vector);
- break;
- }
- case "duration": {
- double v;
- v = Double.parseDouble(fieldValue);
- featureEncoder.addToVector(name, Math.log(v + 1) - 5, vector);
- break;
- }
- case "pdays": {
- double v;
- v = Double.parseDouble(fieldValue);
- featureEncoder.addToVector(name, Math.log(v + 2), vector);
- break;
- }
- case "job":
- case "marital":
- case "education":
- case "default":
- case "housing":
- case "loan":
- case "contact":
- case "campaign":
- case "previous":
- case "poutcome":
- featureEncoder.addToVector(name + ":" + fieldValue, 1, vector);
- break;
- case "day":
- case "month":
- case "y":
- // ignore these for vectorizing
- break;
- default:
- throw new IllegalArgumentException(String.format("Bad field name: %s", name));
- }
- }
- }
- public Vector asVector() {
- return vector;
- }
- public int getTarget() {
- return fields.get("y").equals("no") ? 0 : 1;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
deleted file mode 100644
index 5ef6490..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
+++ /dev/null
@@ -1,66 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd.bankmarketing;
-import com.google.common.base.CharMatcher;
-import com.google.common.base.Splitter;
-import com.google.common.collect.AbstractIterator;
-import com.google.common.io.Resources;
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
-import java.util.Iterator;
-/** Parses semi-colon separated data as TelephoneCalls */
-public class TelephoneCallParser implements Iterable<TelephoneCall> {
- private final Splitter onSemi = Splitter.on(";").trimResults(CharMatcher.anyOf("\" ;"));
- private String resourceName;
- public TelephoneCallParser(String resourceName) throws IOException {
- this.resourceName = resourceName;
- }
- @Override
- public Iterator<TelephoneCall> iterator() {
- try {
- return new AbstractIterator<TelephoneCall>() {
- BufferedReader input =
- new BufferedReader(new InputStreamReader(Resources.getResource(resourceName).openStream()));
- Iterable<String> fieldNames = onSemi.split(input.readLine());
- @Override
- protected TelephoneCall computeNext() {
- try {
- String line = input.readLine();
- if (line == null) {
- return endOfData();
- }
- return new TelephoneCall(fieldNames, onSemi.split(line));
- } catch (IOException e) {
- throw new RuntimeException("Error reading data", e);
- }
- }
- };
- } catch (IOException e) {
- throw new RuntimeException("Error reading data", e);
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java b/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
deleted file mode 100644
index a0b845f..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
+++ /dev/null
@@ -1,31 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-final class ClustersFilter implements PathFilter {
- @Override
- public boolean accept(Path path) {
- String pathString = path.toString();
- return pathString.contains("/clusters-");
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
deleted file mode 100644
index 50dba99..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
+++ /dev/null
@@ -1,88 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import java.awt.BasicStroke;
-import java.awt.Color;
-import java.awt.Graphics;
-import java.awt.Graphics2D;
-import java.util.List;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.canopy.CanopyDriver;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
-import org.apache.mahout.math.DenseVector;
- * Java desktop graphics class that runs canopy clustering and displays the results.
- * This class generates random data and clusters it.
- */
-public class DisplayCanopy extends DisplayClustering {
- DisplayCanopy() {
- initialize();
- this.setTitle("Canopy Clusters (>" + (int) (significance * 100) + "% of population)");
- }
- @Override
- public void paint(Graphics g) {
- plotSampleData((Graphics2D) g);
- plotClusters((Graphics2D) g);
- }
- protected static void plotClusters(Graphics2D g2) {
- int cx = CLUSTERS.size() - 1;
- for (List<Cluster> clusters : CLUSTERS) {
- for (Cluster cluster : clusters) {
- if (isSignificant(cluster)) {
- g2.setStroke(new BasicStroke(1));
- g2.setColor(Color.BLUE);
- double[] t1 = {T1, T1};
- plotEllipse(g2, cluster.getCenter(), new DenseVector(t1));
- double[] t2 = {T2, T2};
- plotEllipse(g2, cluster.getCenter(), new DenseVector(t2));
- g2.setColor(COLORS[Math.min(DisplayClustering.COLORS.length - 1, cx)]);
- g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
- plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
- }
- }
- cx--;
- }
- }
- public static void main(String[] args) throws Exception {
- Path samples = new Path("samples");
- Path output = new Path("output");
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, samples);
- HadoopUtil.delete(conf, output);
- RandomUtils.useTestSeed();
- generateSamples();
- writeSampleData(samples);
- CanopyDriver.buildClusters(conf, samples, output, new ManhattanDistanceMeasure(), T1, T2, 0, true);
- loadClustersWritable(output);
- new DisplayCanopy();
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
deleted file mode 100644
index ad85c6a..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
+++ /dev/null
@@ -1,374 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import java.awt.*;
-import java.awt.event.WindowAdapter;
-import java.awt.event.WindowEvent;
-import java.awt.geom.AffineTransform;
-import java.awt.geom.Ellipse2D;
-import java.awt.geom.Rectangle2D;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.clustering.AbstractCluster;
-import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.UncommonDistributions;
-import org.apache.mahout.clustering.classify.WeightedVectorWritable;
-import org.apache.mahout.clustering.iterator.ClusterWritable;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
-import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public class DisplayClustering extends Frame {
- private static final Logger log = LoggerFactory.getLogger(DisplayClustering.class);
- protected static final int DS = 72; // default scale = 72 pixels per inch
- protected static final int SIZE = 8; // screen size in inches
- private static final Collection<Vector> SAMPLE_PARAMS = new ArrayList<>();
- protected static final List<VectorWritable> SAMPLE_DATA = new ArrayList<>();
- protected static final List<List<Cluster>> CLUSTERS = new ArrayList<>();
- static final Color[] COLORS = { Color.red, Color.orange, Color.yellow, Color.green, Color.blue, Color.magenta,
- Color.lightGray };
- protected static final double T1 = 3.0;
- protected static final double T2 = 2.8;
- static double significance = 0.05;
- protected static int res; // screen resolution
- public DisplayClustering() {
- initialize();
- this.setTitle("Sample Data");
- }
- public void initialize() {
- // Get screen resolution
- res = Toolkit.getDefaultToolkit().getScreenResolution();
- // Set Frame size in inches
- this.setSize(SIZE * res, SIZE * res);
- this.setVisible(true);
- this.setTitle("Asymmetric Sample Data");
- // Window listener to terminate program.
- this.addWindowListener(new WindowAdapter() {
- @Override
- public void windowClosing(WindowEvent e) {
- System.exit(0);
- }
- });
- }
- public static void main(String[] args) throws Exception {
- RandomUtils.useTestSeed();
- generateSamples();
- new DisplayClustering();
- }
- // Override the paint() method
- @Override
- public void paint(Graphics g) {
- Graphics2D g2 = (Graphics2D) g;
- plotSampleData(g2);
- plotSampleParameters(g2);
- plotClusters(g2);
- }
- protected static void plotClusters(Graphics2D g2) {
- int cx = CLUSTERS.size() - 1;
- for (List<Cluster> clusters : CLUSTERS) {
- g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
- g2.setColor(COLORS[Math.min(COLORS.length - 1, cx--)]);
- for (Cluster cluster : clusters) {
- plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
- }
- }
- }
- protected static void plotSampleParameters(Graphics2D g2) {
- Vector v = new DenseVector(2);
- Vector dv = new DenseVector(2);
- g2.setColor(Color.RED);
- for (Vector param : SAMPLE_PARAMS) {
- v.set(0, param.get(0));
- v.set(1, param.get(1));
- dv.set(0, param.get(2) * 3);
- dv.set(1, param.get(3) * 3);
- plotEllipse(g2, v, dv);
- }
- }
- protected static void plotSampleData(Graphics2D g2) {
- double sx = (double) res / DS;
- g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
- // plot the axes
- g2.setColor(Color.BLACK);
- Vector dv = new DenseVector(2).assign(SIZE / 2.0);
- plotRectangle(g2, new DenseVector(2).assign(2), dv);
- plotRectangle(g2, new DenseVector(2).assign(-2), dv);
- // plot the sample data
- g2.setColor(Color.DARK_GRAY);
- dv.assign(0.03);
- for (VectorWritable v : SAMPLE_DATA) {
- plotRectangle(g2, v.get(), dv);
- }
- }
- /**
- * This method plots points and colors them according to their cluster
- * membership, rather than drawing ellipses.
- *
- * As of commit, this method is used only by K-means spectral clustering.
- * Since the cluster assignments are set within the eigenspace of the data, it
- * is not inherent that the original data cluster as they would in K-means:
- * that is, as symmetric gaussian mixtures.
- *
- * Since Spectral K-Means uses K-Means to cluster the eigenspace data, the raw
- * output is not directly usable. Rather, the cluster assignments from the raw
- * output need to be transferred back to the original data. As such, this
- * method will read the SequenceFile cluster results of K-means and transfer
- * the cluster assignments to the original data, coloring them appropriately.
- *
- * @param g2
- * @param data
- */
- protected static void plotClusteredSampleData(Graphics2D g2, Path data) {
- double sx = (double) res / DS;
- g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
- g2.setColor(Color.BLACK);
- Vector dv = new DenseVector(2).assign(SIZE / 2.0);
- plotRectangle(g2, new DenseVector(2).assign(2), dv);
- plotRectangle(g2, new DenseVector(2).assign(-2), dv);
- // plot the sample data, colored according to the cluster they belong to
- dv.assign(0.03);
- Path clusteredPointsPath = new Path(data, "clusteredPoints");
- Path inputPath = new Path(clusteredPointsPath, "part-m-00000");
- Map<Integer,Color> colors = new HashMap<>();
- int point = 0;
- for (Pair<IntWritable,WeightedVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedVectorWritable>(
- inputPath, new Configuration())) {
- int clusterId = record.getFirst().get();
- VectorWritable v = SAMPLE_DATA.get(point++);
- Integer key = clusterId;
- if (!colors.containsKey(key)) {
- colors.put(key, COLORS[Math.min(COLORS.length - 1, colors.size())]);
- }
- plotClusteredRectangle(g2, v.get(), dv, colors.get(key));
- }
- }
- /**
- * Identical to plotRectangle(), but with the option of setting the color of
- * the rectangle's stroke.
- *
- * NOTE: This should probably be refactored with plotRectangle() since most of
- * the code here is direct copy/paste from that method.
- *
- * @param g2
- * A Graphics2D context.
- * @param v
- * A vector for the rectangle's center.
- * @param dv
- * A vector for the rectangle's dimensions.
- * @param color
- * The color of the rectangle's stroke.
- */
- protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) {
- double[] flip = {1, -1};
- Vector v2 = v.times(new DenseVector(flip));
- v2 = v2.minus(dv.divide(2));
- int h = SIZE / 2;
- double x = v2.get(0) + h;
- double y = v2.get(1) + h;
- g2.setStroke(new BasicStroke(1));
- g2.setColor(color);
- g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
- }
- /**
- * Draw a rectangle on the graphics context
- *
- * @param g2
- * a Graphics2D context
- * @param v
- * a Vector of rectangle center
- * @param dv
- * a Vector of rectangle dimensions
- */
- protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
- double[] flip = {1, -1};
- Vector v2 = v.times(new DenseVector(flip));
- v2 = v2.minus(dv.divide(2));
- int h = SIZE / 2;
- double x = v2.get(0) + h;
- double y = v2.get(1) + h;
- g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
- }
- /**
- * Draw an ellipse on the graphics context
- *
- * @param g2
- * a Graphics2D context
- * @param v
- * a Vector of ellipse center
- * @param dv
- * a Vector of ellipse dimensions
- */
- protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
- double[] flip = {1, -1};
- Vector v2 = v.times(new DenseVector(flip));
- v2 = v2.minus(dv.divide(2));
- int h = SIZE / 2;
- double x = v2.get(0) + h;
- double y = v2.get(1) + h;
- g2.draw(new Ellipse2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
- }
- protected static void generateSamples() {
- generateSamples(500, 1, 1, 3);
- generateSamples(300, 1, 0, 0.5);
- generateSamples(300, 0, 2, 0.1);
- }
- protected static void generate2dSamples() {
- generate2dSamples(500, 1, 1, 3, 1);
- generate2dSamples(300, 1, 0, 0.5, 1);
- generate2dSamples(300, 0, 2, 0.1, 0.5);
- }
- /**
- * Generate random samples and add them to the sampleData
- *
- * @param num
- * int number of samples to generate
- * @param mx
- * double x-value of the sample mean
- * @param my
- * double y-value of the sample mean
- * @param sd
- * double standard deviation of the samples
- */
- protected static void generateSamples(int num, double mx, double my, double sd) {
- double[] params = {mx, my, sd, sd};
- SAMPLE_PARAMS.add(new DenseVector(params));
- log.info("Generating {} samples m=[{}, {}] sd={}", num, mx, my, sd);
- for (int i = 0; i < num; i++) {
- SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sd),
- UncommonDistributions.rNorm(my, sd)})));
- }
- }
- protected static void writeSampleData(Path output) throws IOException {
- Configuration conf = new Configuration();
- FileSystem fs = FileSystem.get(output.toUri(), conf);
- try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output, Text.class, VectorWritable.class)) {
- int i = 0;
- for (VectorWritable vw : SAMPLE_DATA) {
- writer.append(new Text("sample_" + i++), vw);
- }
- }
- }
- protected static List<Cluster> readClustersWritable(Path clustersIn) {
- List<Cluster> clusters = new ArrayList<>();
- Configuration conf = new Configuration();
- for (ClusterWritable value : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST,
- PathFilters.logsCRCFilter(), conf)) {
- Cluster cluster = value.getValue();
- log.info(
- "Reading Cluster:{} center:{} numPoints:{} radius:{}",
- cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null),
- cluster.getNumObservations(), AbstractCluster.formatVector(cluster.getRadius(), null));
- clusters.add(cluster);
- }
- return clusters;
- }
- protected static void loadClustersWritable(Path output) throws IOException {
- Configuration conf = new Configuration();
- FileSystem fs = FileSystem.get(output.toUri(), conf);
- for (FileStatus s : fs.listStatus(output, new ClustersFilter())) {
- List<Cluster> clusters = readClustersWritable(s.getPath());
- CLUSTERS.add(clusters);
- }
- }
- /**
- * Generate random samples and add them to the sampleData
- *
- * @param num
- * int number of samples to generate
- * @param mx
- * double x-value of the sample mean
- * @param my
- * double y-value of the sample mean
- * @param sdx
- * double x-value standard deviation of the samples
- * @param sdy
- * double y-value standard deviation of the samples
- */
- protected static void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
- double[] params = {mx, my, sdx, sdy};
- SAMPLE_PARAMS.add(new DenseVector(params));
- log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
- for (int i = 0; i < num; i++) {
- SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sdx),
- UncommonDistributions.rNorm(my, sdy)})));
- }
- }
- protected static boolean isSignificant(Cluster cluster) {
- return (double) cluster.getNumObservations() / SAMPLE_DATA.size() > significance;
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
deleted file mode 100644
index f8ce7c7..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
+++ /dev/null
@@ -1,110 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import java.awt.Graphics;
-import java.awt.Graphics2D;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.List;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.classify.ClusterClassifier;
-import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
-import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
-import org.apache.mahout.clustering.iterator.ClusterIterator;
-import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
-import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
-import org.apache.mahout.math.Vector;
-import com.google.common.collect.Lists;
-public class DisplayFuzzyKMeans extends DisplayClustering {
- DisplayFuzzyKMeans() {
- initialize();
- this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
- }
- // Override the paint() method
- @Override
- public void paint(Graphics g) {
- plotSampleData((Graphics2D) g);
- plotClusters((Graphics2D) g);
- }
- public static void main(String[] args) throws Exception {
- DistanceMeasure measure = new ManhattanDistanceMeasure();
- Path samples = new Path("samples");
- Path output = new Path("output");
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, output);
- HadoopUtil.delete(conf, samples);
- RandomUtils.useTestSeed();
- DisplayClustering.generateSamples();
- writeSampleData(samples);
- boolean runClusterer = true;
- int maxIterations = 10;
- float threshold = 0.001F;
- float m = 1.1F;
- if (runClusterer) {
- runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations, m, threshold);
- } else {
- int numClusters = 3;
- runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations, m, threshold);
- }
- new DisplayFuzzyKMeans();
- }
- private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output,
- DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException {
- Collection<Vector> points = Lists.newArrayList();
- for (int i = 0; i < numClusters; i++) {
- points.add(SAMPLE_DATA.get(i).get());
- }
- List<Cluster> initialClusters = Lists.newArrayList();
- int id = 0;
- for (Vector point : points) {
- initialClusters.add(new SoftCluster(point, id++, measure));
- }
- ClusterClassifier prior = new ClusterClassifier(initialClusters, new FuzzyKMeansClusteringPolicy(m, threshold));
- Path priorPath = new Path(output, "classifier-0");
- prior.writeToSeqFiles(priorPath);
- ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
- loadClustersWritable(output);
- }
- private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
- DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
- ClassNotFoundException, InterruptedException {
- Path clustersIn = new Path(output, "random-seeds");
- RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
- FuzzyKMeansDriver.run(samples, clustersIn, output, threshold, maxIterations, m, true, true, threshold,
- true);
- loadClustersWritable(output);
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
deleted file mode 100644
index 336d69e..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
+++ /dev/null
@@ -1,106 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import java.awt.Graphics;
-import java.awt.Graphics2D;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.List;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.classify.ClusterClassifier;
-import org.apache.mahout.clustering.iterator.ClusterIterator;
-import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
-import org.apache.mahout.clustering.kmeans.KMeansDriver;
-import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
-import org.apache.mahout.math.Vector;
-import com.google.common.collect.Lists;
-public class DisplayKMeans extends DisplayClustering {
- DisplayKMeans() {
- initialize();
- this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
- }
- public static void main(String[] args) throws Exception {
- DistanceMeasure measure = new ManhattanDistanceMeasure();
- Path samples = new Path("samples");
- Path output = new Path("output");
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, samples);
- HadoopUtil.delete(conf, output);
- RandomUtils.useTestSeed();
- generateSamples();
- writeSampleData(samples);
- boolean runClusterer = true;
- double convergenceDelta = 0.001;
- int numClusters = 3;
- int maxIterations = 10;
- if (runClusterer) {
- runSequentialKMeansClusterer(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
- } else {
- runSequentialKMeansClassifier(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
- }
- new DisplayKMeans();
- }
- private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output,
- DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) throws IOException {
- Collection<Vector> points = Lists.newArrayList();
- for (int i = 0; i < numClusters; i++) {
- points.add(SAMPLE_DATA.get(i).get());
- }
- List<Cluster> initialClusters = Lists.newArrayList();
- int id = 0;
- for (Vector point : points) {
- initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure));
- }
- ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
- Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
- prior.writeToSeqFiles(priorPath);
- ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
- loadClustersWritable(output);
- }
- private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
- DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta)
- throws IOException, InterruptedException, ClassNotFoundException {
- Path clustersIn = new Path(output, "random-seeds");
- RandomSeedGenerator.buildRandom(conf, samples, clustersIn, numClusters, measure);
- KMeansDriver.run(samples, clustersIn, output, convergenceDelta, maxIterations, true, 0.0, true);
- loadClustersWritable(output);
- }
- // Override the paint() method
- @Override
- public void paint(Graphics g) {
- plotSampleData((Graphics2D) g);
- plotClusters((Graphics2D) g);
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
deleted file mode 100644
index 2b70749..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
+++ /dev/null
@@ -1,85 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.display;
-import java.awt.Graphics;
-import java.awt.Graphics2D;
-import java.io.BufferedWriter;
-import java.io.FileWriter;
-import java.io.Writer;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
-public class DisplaySpectralKMeans extends DisplayClustering {
- protected static final String SAMPLES = "samples";
- protected static final String OUTPUT = "output";
- protected static final String TEMP = "tmp";
- protected static final String AFFINITIES = "affinities";
- DisplaySpectralKMeans() {
- initialize();
- setTitle("Spectral k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
- }
- public static void main(String[] args) throws Exception {
- DistanceMeasure measure = new ManhattanDistanceMeasure();
- Path samples = new Path(SAMPLES);
- Path output = new Path(OUTPUT);
- Path tempDir = new Path(TEMP);
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, samples);
- HadoopUtil.delete(conf, output);
- RandomUtils.useTestSeed();
- DisplayClustering.generateSamples();
- writeSampleData(samples);
- Path affinities = new Path(output, AFFINITIES);
- FileSystem fs = FileSystem.get(output.toUri(), conf);
- if (!fs.exists(output)) {
- fs.mkdirs(output);
- }
- try (Writer writer = new BufferedWriter(new FileWriter(affinities.toString()))){
- for (int i = 0; i < SAMPLE_DATA.size(); i++) {
- for (int j = 0; j < SAMPLE_DATA.size(); j++) {
- writer.write(i + "," + j + ',' + measure.distance(SAMPLE_DATA.get(i).get(),
- SAMPLE_DATA.get(j).get()) + '\n');
- }
- }
- }
- int maxIter = 10;
- double convergenceDelta = 0.001;
- SpectralKMeansDriver.run(new Configuration(), affinities, output, SAMPLE_DATA.size(), 3, measure,
- convergenceDelta, maxIter, tempDir);
- new DisplaySpectralKMeans();
- }
- @Override
- public void paint(Graphics g) {
- plotClusteredSampleData((Graphics2D) g, new Path(new Path(OUTPUT), "kmeans_out"));
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/README.txt b/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
deleted file mode 100644
index 470c16c..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
+++ /dev/null
@@ -1,22 +0,0 @@
-The following classes can be run without parameters to generate a sample data set and
-run the reference clustering implementations over it:
-DisplayClustering - generates 1000 samples from three symmetric distributions. This is the same
- data set that is used by the following clustering programs. It displays the points on a screen
- and superimposes the model parameters that were used to generate the points. You can edit the
- generateSamples() method to change the sample points used by these programs.
- * DisplayCanopy - uses Canopy clustering
- * DisplayKMeans - uses k-Means clustering
- * DisplayFuzzyKMeans - uses Fuzzy k-Means clustering
- * NOTE: some of these programs display the sample points and then superimpose all of the clusters
- from each iteration. The last iteration's clusters are in bold red and the previous several are
- colored (orange, yellow, green, blue, violet) in order after which all earlier clusters are in
- light grey. This helps to visualize how the clusters converge upon a solution over multiple
- iterations.
- * NOTE: by changing the parameter values (k, ALPHA_0, numIterations) and the display SIGNIFICANCE
- you can obtain different results.
\ No newline at end of file

diff --git a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java b/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
deleted file mode 100644
index c29cbc4..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
+++ /dev/null
@@ -1,279 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.streaming.tools;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.PrintWriter;
-import java.util.List;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
-import com.google.common.io.Closeables;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.clustering.iterator.ClusterWritable;
-import org.apache.mahout.clustering.ClusteringUtils;
-import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
-import org.apache.mahout.math.Centroid;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.stats.OnlineSummarizer;
-public class ClusterQualitySummarizer extends AbstractJob {
- private String outputFile;
- private PrintWriter fileOut;
- private String trainFile;
- private String testFile;
- private String centroidFile;
- private String centroidCompareFile;
- private boolean mahoutKMeansFormat;
- private boolean mahoutKMeansFormatCompare;
- private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure();
- public void printSummaries(List<OnlineSummarizer> summarizers, String type) {
- printSummaries(summarizers, type, fileOut);
- }
- public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) {
- double maxDistance = 0;
- for (int i = 0; i < summarizers.size(); ++i) {
- OnlineSummarizer summarizer = summarizers.get(i);
- if (summarizer.getCount() > 1) {
- maxDistance = Math.max(maxDistance, summarizer.getMax());
- System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean());
- // Summary statistics (including quartiles) are only written for clusters with more than one point;
- // clusters with fewer points are reported on stdout in the else branch below.
- if (fileOut != null) {
- fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(),
- summarizer.getSD(),
- summarizer.getQuartile(0),
- summarizer.getQuartile(1),
- summarizer.getQuartile(2),
- summarizer.getQuartile(3),
- summarizer.getQuartile(4), summarizer.getCount(), type);
- }
- } else {
- System.out.printf("Cluster %d is has %d data point. Need atleast 2 data points in a cluster for" +
- " OnlineSummarizer.\n", i, summarizer.getCount());
- }
- }
- System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance);
- }
- public int run(String[] args) throws IOException {
- if (!parseArgs(args)) {
- return -1;
- }
- Configuration conf = new Configuration();
- try {
- fileOut = new PrintWriter(new FileOutputStream(outputFile));
- fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3,"
- + "distance.q4,count,is.train\n");
- // Reading in the centroids (both sets, if a comparison set was given).
- List<Centroid> centroids;
- List<Centroid> centroidsCompare = null;
- if (mahoutKMeansFormat) {
- SequenceFileDirValueIterable<ClusterWritable> clusterIterable =
- new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
- centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable));
- } else {
- SequenceFileDirValueIterable<CentroidWritable> centroidIterable =
- new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
- centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable));
- }
- if (centroidCompareFile != null) {
- if (mahoutKMeansFormatCompare) {
- SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable =
- new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
- centroidsCompare = Lists.newArrayList(
- IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable));
- } else {
- SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable =
- new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
- centroidsCompare = Lists.newArrayList(
- IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable));
- }
- }
- // Reading in the "training" set.
- SequenceFileDirValueIterable<VectorWritable> trainIterable =
- new SequenceFileDirValueIterable<>(new Path(trainFile), PathType.GLOB, conf);
- Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable);
- Iterable<Vector> datapoints = trainDatapoints;
- printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids,
- new SquaredEuclideanDistanceMeasure()), "train");
- // Also adding in the "test" set.
- if (testFile != null) {
- SequenceFileDirValueIterable<VectorWritable> testIterable =
- new SequenceFileDirValueIterable<>(new Path(testFile), PathType.GLOB, conf);
- Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable);
- printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids,
- new SquaredEuclideanDistanceMeasure()), "test");
- datapoints = Iterables.concat(trainDatapoints, testDatapoints);
- }
- // At this point, all train/test CSVs have been written. We now compute quality metrics.
- List<OnlineSummarizer> summaries =
- ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure);
- List<OnlineSummarizer> compareSummaries = null;
- if (centroidsCompare != null) {
- compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure);
- }
- System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries));
- if (compareSummaries != null) {
- System.out.printf(" Second: %f\n", ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries));
- } else {
- System.out.printf("\n");
- }
- System.out.printf("[Davies-Bouldin Index] First: %f",
- ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries));
- if (compareSummaries != null) {
- System.out.printf(" Second: %f\n",
- ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries));
- } else {
- System.out.printf("\n");
- }
- } catch (IOException e) {
- System.out.println(e.getMessage());
- } finally {
- Closeables.close(fileOut, false);
- }
- return 0;
- }
- private boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help").withDescription("print this list").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withShortName("i")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get seq files with the vectors (training set)")
- .create();
- Option testInputFileOption = builder.withLongName("testInput")
- .withShortName("itest")
- .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create())
- .withDescription("where to get seq files with the vectors (test set)")
- .create();
- Option centroidsFileOption = builder.withLongName("centroids")
- .withShortName("c")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create())
- .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)")
- .create();
- Option centroidsCompareFileOption = builder.withLongName("centroidsCompare")
- .withShortName("cc")
- .withRequired(false)
- .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create())
- .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or "
- + "StreamingKMeansDriver)")
- .create();
- Option outputFileOption = builder.withLongName("output")
- .withShortName("o")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
- .withDescription("where to dump the CSV file with the results")
- .create();
- Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat")
- .withShortName("mkm")
- .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
- .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
- .create();
- Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare")
- .withShortName("mkmc")
- .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
- .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
- .create();
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(inputFileOption)
- .withOption(testInputFileOption)
- .withOption(outputFileOption)
- .withOption(centroidsFileOption)
- .withOption(centroidsCompareFileOption)
- .withOption(mahoutKMeansFormatOption)
- .withOption(mahoutKMeansCompareFormatOption)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- trainFile = (String) cmdLine.getValue(inputFileOption);
- if (cmdLine.hasOption(testInputFileOption)) {
- testFile = (String) cmdLine.getValue(testInputFileOption);
- }
- centroidFile = (String) cmdLine.getValue(centroidsFileOption);
- if (cmdLine.hasOption(centroidsCompareFileOption)) {
- centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption);
- }
- outputFile = (String) cmdLine.getValue(outputFileOption);
- if (cmdLine.hasOption(mahoutKMeansFormatOption)) {
- mahoutKMeansFormat = true;
- }
- if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) {
- mahoutKMeansFormatCompare = true;
- }
- return true;
- }
- public static void main(String[] args) throws IOException {
- new ClusterQualitySummarizer().run(args);
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java b/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java
deleted file mode 100644
index bd1149b..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java
+++ /dev/null
@@ -1,80 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.streaming.tools;
-import com.google.common.base.Function;
-import com.google.common.base.Preconditions;
-import com.google.common.collect.Iterables;
-import org.apache.mahout.clustering.iterator.ClusterWritable;
-import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
-import org.apache.mahout.math.Centroid;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-public class IOUtils {
- private IOUtils() {}
- /**
- * Converts CentroidWritable values in a sequence file into Centroids lazily.
- * @param dirIterable the source iterable (comes from a SequenceFileDirIterable).
- * @return an Iterable<Centroid> with the converted vectors.
- */
- public static Iterable<Centroid> getCentroidsFromCentroidWritableIterable(
- Iterable<CentroidWritable> dirIterable) {
- return Iterables.transform(dirIterable, new Function<CentroidWritable, Centroid>() {
- @Override
- public Centroid apply(CentroidWritable input) {
- Preconditions.checkNotNull(input);
- return input.getCentroid().clone();
- }
- });
- }
- /**
- * Converts ClusterWritable values in a sequence file into Centroids lazily.
- * @param dirIterable the source iterable (comes from a SequenceFileDirIterable).
- * @return an Iterable<Centroid> with the converted vectors.
- */
- public static Iterable<Centroid> getCentroidsFromClusterWritableIterable(Iterable<ClusterWritable> dirIterable) {
- return Iterables.transform(dirIterable, new Function<ClusterWritable, Centroid>() {
- int numClusters = 0;
- @Override
- public Centroid apply(ClusterWritable input) {
- Preconditions.checkNotNull(input);
- return new Centroid(numClusters++, input.getValue().getCenter().clone(),
- input.getValue().getTotalObservations());
- }
- });
- }
- /**
- * Converts VectorWritable values in a sequence file into Vectors lazily.
- * @param dirIterable the source iterable (comes from a SequenceFileDirIterable).
- * @return an Iterable<Vector> with the converted vectors.
- */
- public static Iterable<Vector> getVectorsFromVectorWritableIterable(Iterable<VectorWritable> dirIterable) {
- return Iterables.transform(dirIterable, new Function<VectorWritable, Vector>() {
- @Override
- public Vector apply(VectorWritable input) {
- Preconditions.checkNotNull(input);
- return input.get().clone();
- }
- });
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java b/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java
deleted file mode 100644
index 083cd8c..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java
+++ /dev/null
@@ -1,125 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.syntheticcontrol.canopy;
-import java.util.List;
-import java.util.Map;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.clustering.canopy.CanopyDriver;
-import org.apache.mahout.clustering.conversion.InputDriver;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.ClassUtils;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
-import org.apache.mahout.utils.clustering.ClusterDumper;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class Job extends AbstractJob {
- private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
- private Job() {
- }
- private static final Logger log = LoggerFactory.getLogger(Job.class);
- public static void main(String[] args) throws Exception {
- if (args.length > 0) {
- log.info("Running with only user-supplied arguments");
- ToolRunner.run(new Configuration(), new Job(), args);
- } else {
- log.info("Running with default arguments");
- Path output = new Path("output");
- HadoopUtil.delete(new Configuration(), output);
- run(new Path("testdata"), output, new EuclideanDistanceMeasure(), 80, 55);
- }
- }
- /**
- * Run the canopy clustering job on an input dataset using the given distance
- * measure, t1 and t2 parameters. All output data will be written to the
- * output directory, which will be initially deleted if it exists. The
- * clustered points will reside in the path <output>/clusteredPoints. By
- * default, the job expects that a file containing synthetic_control.data as
- * obtained from
- * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
- * resides in a directory named "testdata", and writes output to a directory
- * named "output".
- *
- * @param input
- * the Path denoting the input directory
- * @param output
- * the Path denoting the output directory
- * @param measure
- * the DistanceMeasure to use
- * @param t1
- * the canopy T1 threshold
- * @param t2
- * the canopy T2 threshold
- */
- private static void run(Path input, Path output, DistanceMeasure measure,
- double t1, double t2) throws Exception {
- Path directoryContainingConvertedInput = new Path(output,
- DIRECTORY_CONTAINING_CONVERTED_INPUT);
- InputDriver.runJob(input, directoryContainingConvertedInput,
- "org.apache.mahout.math.RandomAccessSparseVector");
- CanopyDriver.run(new Configuration(), directoryContainingConvertedInput,
- output, measure, t1, t2, true, 0.0, false);
- // run ClusterDumper
- ClusterDumper clusterDumper = new ClusterDumper(new Path(output,
- "clusters-0-final"), new Path(output, "clusteredPoints"));
- clusterDumper.printClusters(null);
- }
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.distanceMeasureOption().create());
- addOption(DefaultOptionCreator.t1Option().create());
- addOption(DefaultOptionCreator.t2Option().create());
- addOption(DefaultOptionCreator.overwriteOption().create());
- Map<String, List<String>> argMap = parseArguments(args);
- if (argMap == null) {
- return -1;
- }
- Path input = getInputPath();
- Path output = getOutputPath();
- if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
- HadoopUtil.delete(new Configuration(), output);
- }
- String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
- double t1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
- double t2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
- DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
- run(input, output, measure, t1, t2);
- return 0;
- }
2018-06-27 13:14:33 UTC
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
deleted file mode 100644
index e762924..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
+++ /dev/null
@@ -1,265 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Preconditions;
-import com.google.common.io.Closeables;
-import java.io.DataInput;
-import java.io.DataInputStream;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import org.apache.hadoop.io.Writable;
- * Encapsulates everything we need to know about a model and how it reads and vectorizes its input.
- * This encapsulation allows us to coherently save and restore a model from a file. This also
- * allows us to keep command line arguments that affect learning in a coherent way.
- */
-public class LogisticModelParameters implements Writable {
- private String targetVariable;
- private Map<String, String> typeMap;
- private int numFeatures;
- private boolean useBias;
- private int maxTargetCategories;
- private List<String> targetCategories;
- private double lambda;
- private double learningRate;
- private CsvRecordFactory csv;
- private OnlineLogisticRegression lr;
- /**
- * Returns a CsvRecordFactory compatible with this logistic model. The reason that this is tied
- * in here is so that we have access to the list of target categories when it comes time to save
- * the model. If the input isn't CSV, then calling setTargetCategories before calling saveTo will
- * suffice.
- *
- * @return The CsvRecordFactory.
- */
- public CsvRecordFactory getCsvRecordFactory() {
- if (csv == null) {
- csv = new CsvRecordFactory(getTargetVariable(), getTypeMap())
- .maxTargetValue(getMaxTargetCategories())
- .includeBiasTerm(useBias());
- if (targetCategories != null) {
- csv.defineTargetCategories(targetCategories);
- }
- }
- return csv;
- }
- /**
- * Creates a logistic regression trainer using the parameters collected here.
- *
- * @return The newly allocated OnlineLogisticRegression object
- */
- public OnlineLogisticRegression createRegression() {
- if (lr == null) {
- lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
- .lambda(getLambda())
- .learningRate(getLearningRate())
- .alpha(1 - 1.0e-3);
- }
- return lr;
- }
- /**
- * Saves a model to an output stream.
- */
- public void saveTo(OutputStream out) throws IOException {
- Closeables.close(lr, false);
- targetCategories = getCsvRecordFactory().getTargetCategories();
- write(new DataOutputStream(out));
- }
- /**
- * Reads a model from a stream.
- */
- public static LogisticModelParameters loadFrom(InputStream in) throws IOException {
- LogisticModelParameters result = new LogisticModelParameters();
- result.readFields(new DataInputStream(in));
- return result;
- }
- /**
- * Reads a model from a file.
- * @throws IOException If there is an error opening or closing the file.
- */
- public static LogisticModelParameters loadFrom(File in) throws IOException {
- try (InputStream input = new FileInputStream(in)) {
- return loadFrom(input);
- }
- }
- @Override
- public void write(DataOutput out) throws IOException {
- out.writeUTF(targetVariable);
- out.writeInt(typeMap.size());
- for (Map.Entry<String,String> entry : typeMap.entrySet()) {
- out.writeUTF(entry.getKey());
- out.writeUTF(entry.getValue());
- }
- out.writeInt(numFeatures);
- out.writeBoolean(useBias);
- out.writeInt(maxTargetCategories);
- if (targetCategories == null) {
- out.writeInt(0);
- } else {
- out.writeInt(targetCategories.size());
- for (String category : targetCategories) {
- out.writeUTF(category);
- }
- }
- out.writeDouble(lambda);
- out.writeDouble(learningRate);
- // skip csv
- lr.write(out);
- }
- @Override
- public void readFields(DataInput in) throws IOException {
- targetVariable = in.readUTF();
- int typeMapSize = in.readInt();
- typeMap = new HashMap<>(typeMapSize);
- for (int i = 0; i < typeMapSize; i++) {
- String key = in.readUTF();
- String value = in.readUTF();
- typeMap.put(key, value);
- }
- numFeatures = in.readInt();
- useBias = in.readBoolean();
- maxTargetCategories = in.readInt();
- int targetCategoriesSize = in.readInt();
- targetCategories = new ArrayList<>(targetCategoriesSize);
- for (int i = 0; i < targetCategoriesSize; i++) {
- targetCategories.add(in.readUTF());
- }
- lambda = in.readDouble();
- learningRate = in.readDouble();
- csv = null;
- lr = new OnlineLogisticRegression();
- lr.readFields(in);
- }
- /**
- * Sets the types of the predictors. This will later be used when reading CSV data. If you don't
- * use the CSV data and convert to vectors on your own, you don't need to call this.
- *
- * @param predictorList The list of variable names.
- * @param typeList The list of types in the format preferred by CsvRecordFactory.
- */
- public void setTypeMap(Iterable<String> predictorList, List<String> typeList) {
- Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier");
- typeMap = new HashMap<>();
- Iterator<String> iTypes = typeList.iterator();
- String lastType = null;
- for (Object x : predictorList) {
- // type list can be short .. we just repeat last spec
- if (iTypes.hasNext()) {
- lastType = iTypes.next();
- }
- typeMap.put(x.toString(), lastType);
- }
- }
- /**
- * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant.
- *
- * @param targetVariable The name of the target variable.
- */
- public void setTargetVariable(String targetVariable) {
- this.targetVariable = targetVariable;
- }
- /**
- * Sets the number of target categories to be considered.
- *
- * @param maxTargetCategories The number of target categories.
- */
- public void setMaxTargetCategories(int maxTargetCategories) {
- this.maxTargetCategories = maxTargetCategories;
- }
- public void setNumFeatures(int numFeatures) {
- this.numFeatures = numFeatures;
- }
- public void setTargetCategories(List<String> targetCategories) {
- this.targetCategories = targetCategories;
- maxTargetCategories = targetCategories.size();
- }
- public List<String> getTargetCategories() {
- return this.targetCategories;
- }
- public void setUseBias(boolean useBias) {
- this.useBias = useBias;
- }
- public boolean useBias() {
- return useBias;
- }
- public String getTargetVariable() {
- return targetVariable;
- }
- public Map<String, String> getTypeMap() {
- return typeMap;
- }
- public void setTypeMap(Map<String, String> map) {
- this.typeMap = map;
- }
- public int getNumFeatures() {
- return numFeatures;
- }
- public int getMaxTargetCategories() {
- return maxTargetCategories;
- }
- public double getLambda() {
- return lambda;
- }
- public void setLambda(double lambda) {
- this.lambda = lambda;
- }
- public double getLearningRate() {
- return learningRate;
- }
- public void setLearningRate(double learningRate) {
- this.learningRate = learningRate;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
deleted file mode 100644
index 3ec6a06..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
+++ /dev/null
@@ -1,42 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Preconditions;
-import java.io.BufferedReader;
- * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead
- * of processing the input, this class just prints the input to standard out.
- */
-public final class PrintResourceOrFile {
- private PrintResourceOrFile() {
- }
- public static void main(String[] args) throws Exception {
- Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource.");
- try (BufferedReader in = TrainLogistic.open(args[0])){
- String line;
- while ((line = in.readLine()) != null) {
- System.out.println(line);
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
deleted file mode 100644
index 678a8f5..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
+++ /dev/null
@@ -1,197 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.SequentialAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.HashMap;
-import java.util.Map;
-public final class RunAdaptiveLogistic {
- private static String inputFile;
- private static String modelFile;
- private static String outputFile;
- private static String idColumn;
- private static boolean maxScoreOnly;
- private RunAdaptiveLogistic() {
- }
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (!parseArgs(args)) {
- return;
- }
- AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
- .loadFromFile(new File(modelFile));
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- csv.setIdName(idColumn);
- AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
- State<Wrapper, CrossFoldLearner> best = lr.getBest();
- if (best == null) {
- output.println("AdaptiveLogisticRegression has not be trained probably.");
- return;
- }
- CrossFoldLearner learner = best.getPayload().getLearner();
- BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
- int k = 0;
- try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile),
- Charsets.UTF_8))) {
- out.write(idColumn + ",target,score");
- out.newLine();
- String line = in.readLine();
- csv.firstLine(line);
- line = in.readLine();
- Map<String, Double> results = new HashMap<>();
- while (line != null) {
- Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
- csv.processLine(line, v, false);
- Vector scores = learner.classifyFull(v);
- results.clear();
- if (maxScoreOnly) {
- results.put(csv.getTargetLabel(scores.maxValueIndex()),
- scores.maxValue());
- } else {
- for (int i = 0; i < scores.size(); i++) {
- results.put(csv.getTargetLabel(i), scores.get(i));
- }
- }
- for (Map.Entry<String, Double> entry : results.entrySet()) {
- out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
- out.newLine();
- }
- k++;
- if (k % 100 == 0) {
- output.println(k + " records processed");
- }
- line = in.readLine();
- }
- out.flush();
- }
- output.println(k + " records processed totally.");
- }
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help")
- .withDescription("print this list").create();
- Option quiet = builder.withLongName("quiet")
- .withDescription("be extra quiet").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder
- .withLongName("input")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("input").withMaximum(1)
- .create())
- .withDescription("where to get training data").create();
- Option modelFileOption = builder
- .withLongName("model")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("model").withMaximum(1)
- .create())
- .withDescription("where to get the trained model").create();
- Option outputFileOption = builder
- .withLongName("output")
- .withRequired(true)
- .withDescription("the file path to output scores")
- .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
- .create();
- Option idColumnOption = builder
- .withLongName("idcolumn")
- .withRequired(true)
- .withDescription("the name of the id column for each record")
- .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
- .create();
- Option maxScoreOnlyOption = builder
- .withLongName("maxscoreonly")
- .withDescription("only output the target label with max scores")
- .create();
- Group normalArgs = new GroupBuilder()
- .withOption(help).withOption(quiet)
- .withOption(inputFileOption).withOption(modelFileOption)
- .withOption(outputFileOption).withOption(idColumnOption)
- .withOption(maxScoreOnlyOption)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- inputFile = getStringArgument(cmdLine, inputFileOption);
- modelFile = getStringArgument(cmdLine, modelFileOption);
- outputFile = getStringArgument(cmdLine, outputFileOption);
- idColumn = getStringArgument(cmdLine, idColumnOption);
- maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
- return true;
- }
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
- return cmdLine.hasOption(option);
- }
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
- return (String) cmdLine.getValue(inputFile);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
deleted file mode 100644
index 2d57016..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
+++ /dev/null
@@ -1,163 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.evaluation.Auc;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.SequentialAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.Locale;
-public final class RunLogistic {
- private static String inputFile;
- private static String modelFile;
- private static boolean showAuc;
- private static boolean showScores;
- private static boolean showConfusion;
- private RunLogistic() {
- }
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (parseArgs(args)) {
- if (!showAuc && !showConfusion && !showScores) {
- showAuc = true;
- showConfusion = true;
- }
- Auc collector = new Auc();
- LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- OnlineLogisticRegression lr = lmp.createRegression();
- BufferedReader in = TrainLogistic.open(inputFile);
- String line = in.readLine();
- csv.firstLine(line);
- line = in.readLine();
- if (showScores) {
- output.println("\"target\",\"model-output\",\"log-likelihood\"");
- }
- while (line != null) {
- Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
- int target = csv.processLine(line, v);
- double score = lr.classifyScalar(v);
- if (showScores) {
- output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
- }
- collector.add(target, score);
- line = in.readLine();
- }
- if (showAuc) {
- output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
- }
- if (showConfusion) {
- Matrix m = collector.confusion();
- output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
- m = collector.entropy();
- output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
- }
- }
- }
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help").withDescription("print this list").create();
- Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
- Option auc = builder.withLongName("auc").withDescription("print AUC").create();
- Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create();
- Option scores = builder.withLongName("scores").withDescription("print scores").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(quiet)
- .withOption(auc)
- .withOption(scores)
- .withOption(confusion)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- inputFile = getStringArgument(cmdLine, inputFileOption);
- modelFile = getStringArgument(cmdLine, modelFileOption);
- showAuc = getBooleanArgument(cmdLine, auc);
- showScores = getBooleanArgument(cmdLine, scores);
- showConfusion = getBooleanArgument(cmdLine, confusion);
- return true;
- }
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
- return cmdLine.hasOption(option);
- }
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
- return (String) cmdLine.getValue(inputFile);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
deleted file mode 100644
index c657803..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
+++ /dev/null
@@ -1,151 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.collect.Multiset;
-import org.apache.mahout.classifier.NewsgroupHelper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.DoubleFunction;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
-import java.util.TreeMap;
-public final class SGDHelper {
- private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
- private SGDHelper() {
- }
- public static void dissect(int leakType,
- Dictionary dictionary,
- AdaptiveLogisticRegression learningAlgorithm,
- Iterable<File> files, Multiset<String> overallCounts) throws IOException {
- CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
- model.close();
- Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
- ModelDissector md = new ModelDissector();
- NewsgroupHelper helper = new NewsgroupHelper();
- helper.getEncoder().setTraceDictionary(traceDictionary);
- helper.getBias().setTraceDictionary(traceDictionary);
- for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
- String ng = file.getParentFile().getName();
- int actual = dictionary.intern(ng);
- traceDictionary.clear();
- Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
- md.update(v, traceDictionary, model);
- }
- List<String> ngNames = new ArrayList<>(dictionary.values());
- List<ModelDissector.Weight> weights = md.summary(100);
- System.out.println("============");
- System.out.println("Model Dissection");
- for (ModelDissector.Weight w : weights) {
- System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
- w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
- w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
- }
- }
- public static List<File> permute(Iterable<File> files, Random rand) {
- List<File> r = new ArrayList<>();
- for (File file : files) {
- int i = rand.nextInt(r.size() + 1);
- if (i == r.size()) {
- r.add(file);
- } else {
- r.add(r.get(i));
- r.set(i, file);
- }
- }
- return r;
- }
- static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper,
- CrossFoldLearner> best) throws IOException {
- int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
- int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
- double maxBeta;
- double nonZeros;
- double positive;
- double norm;
- double lambda = 0;
- double mu = 0;
- if (best != null) {
- CrossFoldLearner state = best.getPayload().getLearner();
- info.setAverageCorrect(state.percentCorrect());
- info.setAverageLL(state.logLikelihood());
- OnlineLogisticRegression model = state.getModels().get(0);
- // finish off pending regularization
- model.close();
- Matrix beta = model.getBeta();
- maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
- nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return Math.abs(v) > 1.0e-6 ? 1 : 0;
- }
- });
- positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return v > 0 ? 1 : 0;
- }
- });
- norm = beta.aggregate(Functions.PLUS, Functions.ABS);
- lambda = best.getMappedParams()[0];
- mu = best.getMappedParams()[1];
- } else {
- maxBeta = 0;
- nonZeros = 0;
- positive = 0;
- norm = 0;
- }
- if (k % (bump * scale) == 0) {
- if (best != null) {
- File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
- ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
- }
- info.setStep(info.getStep() + 0.25);
- System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
- System.out.printf("%d\t%.3f\t%.2f\t%s%n",
- k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
deleted file mode 100644
index be55d43..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
+++ /dev/null
@@ -1,59 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-final class SGDInfo {
- private double averageLL;
- private double averageCorrect;
- private double step;
- private int[] bumps = {1, 2, 5};
- double getAverageLL() {
- return averageLL;
- }
- void setAverageLL(double averageLL) {
- this.averageLL = averageLL;
- }
- double getAverageCorrect() {
- return averageCorrect;
- }
- void setAverageCorrect(double averageCorrect) {
- this.averageCorrect = averageCorrect;
- }
- double getStep() {
- return step;
- }
- void setStep(double step) {
- this.step = step;
- }
- int[] getBumps() {
- return bumps;
- }
- void setBumps(int[] bumps) {
- this.bumps = bumps;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
deleted file mode 100644
index b3da452..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
+++ /dev/null
@@ -1,283 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Joiner;
-import com.google.common.base.Splitter;
-import com.google.common.collect.Lists;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.list.IntArrayList;
-import org.apache.mahout.math.stats.OnlineSummarizer;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.BufferedReader;
-import java.io.Closeable;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
- * Shows how different encoding choices can make big speed differences.
- * <p/>
- * Run with command line options --generate 1000000 test.csv to generate a million data lines in
- * test.csv.
- * <p/>
- * Run with command line options --parser test.csv to time how long it takes to parse and encode
- * those million data points
- * <p/>
- * Run with command line options --fast test.csv to time how long it takes to parse and encode those
- * million data points using byte-level parsing and direct value encoding.
- * <p/>
- * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic
- * idea of caching hash locations and byte level parsing still very much applies to text, however.
- */
-public final class SimpleCsvExamples {
- public static final char SEPARATOR_CHAR = '\t';
- private static final int FIELDS = 100;
- private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class);
- private SimpleCsvExamples() {}
- public static void main(String[] args) throws IOException {
- FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
- for (int i = 0; i < FIELDS; i++) {
- encoder[i] = new ConstantValueEncoder("v" + 1);
- }
- OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
- for (int i = 0; i < FIELDS; i++) {
- s[i] = new OnlineSummarizer();
- }
- long t0 = System.currentTimeMillis();
- Vector v = new DenseVector(1000);
- if ("--generate".equals(args[0])) {
- try (PrintWriter out =
- new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) {
- int n = Integer.parseInt(args[1]);
- for (int i = 0; i < n; i++) {
- Line x = Line.generate();
- out.println(x);
- }
- }
- } else if ("--parse".equals(args[0])) {
- try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)){
- String line = in.readLine();
- while (line != null) {
- v.assign(0);
- Line x = new Line(line);
- for (int i = 0; i < FIELDS; i++) {
- s[i].add(x.getDouble(i));
- encoder[i].addToVector(x.get(i), v);
- }
- line = in.readLine();
- }
- }
- String separator = "";
- for (int i = 0; i < FIELDS; i++) {
- System.out.printf("%s%.3f", separator, s[i].getMean());
- separator = ",";
- }
- } else if ("--fast".equals(args[0])) {
- try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))){
- FastLine line = in.read();
- while (line != null) {
- v.assign(0);
- for (int i = 0; i < FIELDS; i++) {
- double z = line.getDouble(i);
- s[i].add(z);
- encoder[i].addToVector((byte[]) null, z, v);
- }
- line = in.read();
- }
- }
- String separator = "";
- for (int i = 0; i < FIELDS; i++) {
- System.out.printf("%s%.3f", separator, s[i].getMean());
- separator = ",";
- }
- }
- System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0);
- }
- private static final class Line {
- private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults();
- public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR);
- public static final Random RAND = RandomUtils.getRandom();
- private final List<String> data;
- private Line(CharSequence line) {
- data = Lists.newArrayList(ON_TABS.split(line));
- }
- private Line() {
- data = new ArrayList<>();
- }
- public double getDouble(int field) {
- return Double.parseDouble(data.get(field));
- }
- /**
- * Generate a random line with 20 fields each with integer values.
- *
- * @return A new line with data.
- */
- public static Line generate() {
- Line r = new Line();
- for (int i = 0; i < FIELDS; i++) {
- double mean = ((i + 1) * 257) % 50 + 1;
- r.data.add(Integer.toString(randomValue(mean)));
- }
- return r;
- }
- /**
- * Returns a random exponentially distributed integer with a particular mean value. This is
- * just a way to create more small numbers than big numbers.
- *
- * @param mean mean of the distribution
- * @return random exponentially distributed integer with the specific mean
- */
- private static int randomValue(double mean) {
- return (int) (-mean * Math.log1p(-RAND.nextDouble()));
- }
- @Override
- public String toString() {
- return WITH_COMMAS.join(data);
- }
- public String get(int field) {
- return data.get(field);
- }
- }
- private static final class FastLine {
- private final ByteBuffer base;
- private final IntArrayList start = new IntArrayList();
- private final IntArrayList length = new IntArrayList();
- private FastLine(ByteBuffer base) {
- this.base = base;
- }
- public static FastLine read(ByteBuffer buf) {
- FastLine r = new FastLine(buf);
- r.start.add(buf.position());
- int offset = buf.position();
- while (offset < buf.limit()) {
- int ch = buf.get();
- offset = buf.position();
- switch (ch) {
- case '\n':
- r.length.add(offset - r.start.get(r.length.size()) - 1);
- return r;
- r.length.add(offset - r.start.get(r.length.size()) - 1);
- r.start.add(offset);
- break;
- default:
- // nothing to do for now
- }
- }
- throw new IllegalArgumentException("Not enough bytes in buffer");
- }
- public double getDouble(int field) {
- int offset = start.get(field);
- int size = length.get(field);
- switch (size) {
- case 1:
- return base.get(offset) - '0';
- case 2:
- return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0';
- default:
- double r = 0;
- for (int i = 0; i < size; i++) {
- r = 10 * r + base.get(offset + i) - '0';
- }
- return r;
- }
- }
- }
- private static final class FastLineReader implements Closeable {
- private final InputStream in;
- private final ByteBuffer buf = ByteBuffer.allocate(100000);
- private FastLineReader(InputStream in) throws IOException {
- this.in = in;
- buf.limit(0);
- fillBuffer();
- }
- public FastLine read() throws IOException {
- fillBuffer();
- if (buf.remaining() > 0) {
- return FastLine.read(buf);
- } else {
- return null;
- }
- }
- private void fillBuffer() throws IOException {
- if (buf.remaining() < 10000) {
- buf.compact();
- int n = in.read(buf.array(), buf.position(), buf.remaining());
- if (n == -1) {
- buf.flip();
- } else {
- buf.limit(buf.position() + n);
- buf.position(0);
- }
- }
- }
- @Override
- public void close() {
- try {
- Closeables.close(in, true);
- } catch (IOException e) {
- log.error(e.getMessage(), e);
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
deleted file mode 100644
index 074f774..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
+++ /dev/null
@@ -1,152 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.classifier.ClassifierResult;
-import org.apache.mahout.classifier.ResultAnalyzer;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
- * Run the ASF email, as trained by TrainASFEmail
- */
-public final class TestASFEmail {
- private String inputFile;
- private String modelFile;
- private TestASFEmail() {}
- public static void main(String[] args) throws IOException {
- TestASFEmail runner = new TestASFEmail();
- if (runner.parseArgs(args)) {
- runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- }
- public void run(PrintWriter output) throws IOException {
- File base = new File(inputFile);
- //contains the best model
- OnlineLogisticRegression classifier =
- ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);
- Dictionary asfDictionary = new Dictionary();
- Configuration conf = new Configuration();
- PathFilter testFilter = new PathFilter() {
- @Override
- public boolean accept(Path path) {
- return path.getName().contains("test");
- }
- };
- SequenceFileDirIterator<Text, VectorWritable> iter =
- new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
- null, true, conf);
- long numItems = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- asfDictionary.intern(next.getFirst().toString());
- numItems++;
- }
- System.out.println(numItems + " test files");
- ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
- iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
- null, true, conf);
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- String ng = next.getFirst().toString();
- int actual = asfDictionary.intern(ng);
- Vector result = classifier.classifyFull(next.getSecond().get());
- int cat = result.maxValueIndex();
- double score = result.maxValue();
- double ll = classifier.logLikelihood(actual, next.getSecond().get());
- ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll);
- ra.addInstance(asfDictionary.values().get(actual), cr);
- }
- output.println(ra);
- }
- boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help").withDescription("print this list").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- inputFile = (String) cmdLine.getValue(inputFileOption);
- modelFile = (String) cmdLine.getValue(modelFileOption);
- return true;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
deleted file mode 100644
index f0316e9..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
+++ /dev/null
@@ -1,141 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.ClassifierResult;
-import org.apache.mahout.classifier.NewsgroupHelper;
-import org.apache.mahout.classifier.ResultAnalyzer;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
- * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}.
- */
-public final class TestNewsGroups {
- private String inputFile;
- private String modelFile;
- private TestNewsGroups() {
- }
- public static void main(String[] args) throws IOException {
- TestNewsGroups runner = new TestNewsGroups();
- if (runner.parseArgs(args)) {
- runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- }
- public void run(PrintWriter output) throws IOException {
- File base = new File(inputFile);
- //contains the best model
- OnlineLogisticRegression classifier =
- ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);
- Dictionary newsGroups = new Dictionary();
- Multiset<String> overallCounts = HashMultiset.create();
- List<File> files = new ArrayList<>();
- for (File newsgroup : base.listFiles()) {
- if (newsgroup.isDirectory()) {
- newsGroups.intern(newsgroup.getName());
- files.addAll(Arrays.asList(newsgroup.listFiles()));
- }
- }
- System.out.println(files.size() + " test files");
- ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
- for (File file : files) {
- String ng = file.getParentFile().getName();
- int actual = newsGroups.intern(ng);
- NewsgroupHelper helper = new NewsgroupHelper();
- //no leak type ensures this is a normal vector
- Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
- Vector result = classifier.classifyFull(input);
- int cat = result.maxValueIndex();
- double score = result.maxValue();
- double ll = classifier.logLikelihood(actual, input);
- ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
- ra.addInstance(newsGroups.values().get(actual), cr);
- }
- output.println(ra);
- }
- boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help").withDescription("print this list").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- inputFile = (String) cmdLine.getValue(inputFileOption);
- modelFile = (String) cmdLine.getValue(modelFileOption);
- return true;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
deleted file mode 100644
index e681f92..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
+++ /dev/null
@@ -1,137 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import com.google.common.collect.Ordering;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-import java.io.File;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-public final class TrainASFEmail extends AbstractJob {
- private TrainASFEmail() {
- }
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption("categories", "nc", "The number of categories to train on", true);
- addOption("cardinality", "c", "The size of the vectors to use", "100000");
- addOption("threads", "t", "The number of threads to use in the learner", "20");
- addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. "
- + "Higher values require more memory.", "5");
- if (parseArguments(args) == null) {
- return -1;
- }
- File base = new File(getInputPath().toString());
- Multiset<String> overallCounts = HashMultiset.create();
- File output = new File(getOutputPath().toString());
- output.mkdirs();
- int numCats = Integer.parseInt(getOption("categories"));
- int cardinality = Integer.parseInt(getOption("cardinality", "100000"));
- int threadCount = Integer.parseInt(getOption("threads", "20"));
- int poolSize = Integer.parseInt(getOption("poolSize", "5"));
- Dictionary asfDictionary = new Dictionary();
- AdaptiveLogisticRegression learningAlgorithm =
- new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
- learningAlgorithm.setInterval(800);
- learningAlgorithm.setAveragingWindow(500);
- //We ran seq2encoded and split input already, so let's just build up the dictionary
- Configuration conf = new Configuration();
- PathFilter trainFilter = new PathFilter() {
- @Override
- public boolean accept(Path path) {
- return path.getName().contains("training");
- }
- };
- SequenceFileDirIterator<Text, VectorWritable> iter =
- new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf);
- long numItems = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- asfDictionary.intern(next.getFirst().toString());
- numItems++;
- }
- System.out.println(numItems + " training files");
- SGDInfo info = new SGDInfo();
- iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter,
- null, true, conf);
- int k = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- String ng = next.getFirst().toString();
- int actual = asfDictionary.intern(ng);
- //we already have encoded
- learningAlgorithm.train(actual, next.getSecond().get());
- k++;
- State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
- SGDHelper.analyzeState(info, 0, k, best);
- }
- learningAlgorithm.close();
- //TODO: how to dissection since we aren't processing the files here
- //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
- System.out.println("exiting main, writing model to " + output);
- ModelSerializer.writeBinary(output + "/asf.model",
- learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
- List<Integer> counts = new ArrayList<>();
- System.out.println("Word counts");
- for (String count : overallCounts.elementSet()) {
- counts.add(overallCounts.count(count));
- }
- Collections.sort(counts, Ordering.natural().reverse());
- k = 0;
- for (Integer count : counts) {
- System.out.println(k + "\t" + count);
- k++;
- if (k > 1000) {
- break;
- }
- }
- return 0;
- }
- public static void main(String[] args) throws Exception {
- TrainASFEmail trainer = new TrainASFEmail();
- trainer.run(args);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
deleted file mode 100644
index defb5b9..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
+++ /dev/null
@@ -1,377 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.io.Resources;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-public final class TrainAdaptiveLogistic {
- private static String inputFile;
- private static String outputFile;
- private static AdaptiveLogisticModelParameters lmp;
- private static int passes;
- private static boolean showperf;
- private static int skipperfnum = 99;
- private static AdaptiveLogisticRegression model;
- private TrainAdaptiveLogistic() {
- }
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (parseArgs(args)) {
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- model = lmp.createAdaptiveLogisticRegression();
- State<Wrapper, CrossFoldLearner> best;
- CrossFoldLearner learner = null;
- int k = 0;
- for (int pass = 0; pass < passes; pass++) {
- BufferedReader in = open(inputFile);
- // read variable names
- csv.firstLine(in.readLine());
- String line = in.readLine();
- while (line != null) {
- // for each new line, get target and predictors
- Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
- int targetValue = csv.processLine(line, input);
- // update model
- model.train(targetValue, input);
- k++;
- if (showperf && (k % (skipperfnum + 1) == 0)) {
- best = model.getBest();
- if (best != null) {
- learner = best.getPayload().getLearner();
- }
- if (learner != null) {
- double averageCorrect = learner.percentCorrect();
- double averageLL = learner.logLikelihood();
- output.printf("%d\t%.3f\t%.2f%n",
- k, averageLL, averageCorrect * 100);
- } else {
- output.printf(Locale.ENGLISH,
- "%10d %2d %s%n", k, targetValue,
- "AdaptiveLogisticRegression has not found a good model ......");
- }
- }
- line = in.readLine();
- }
- in.close();
- }
- best = model.getBest();
- if (best != null) {
- learner = best.getPayload().getLearner();
- }
- if (learner == null) {
- output.println("AdaptiveLogisticRegression has failed to train a model.");
- return;
- }
- try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
- lmp.saveTo(modelOutput);
- }
- OnlineLogisticRegression lr = learner.getModels().get(0);
- output.println(lmp.getNumFeatures());
- output.println(lmp.getTargetVariable() + " ~ ");
- String sep = "";
- for (String v : csv.getTraceDictionary().keySet()) {
- double weight = predictorWeight(lr, 0, csv, v);
- if (weight != 0) {
- output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
- sep = " + ";
- }
- }
- output.printf("%n");
- for (int row = 0; row < lr.getBeta().numRows(); row++) {
- for (String key : csv.getTraceDictionary().keySet()) {
- double weight = predictorWeight(lr, row, csv, key);
- if (weight != 0) {
- output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
- }
- }
- for (int column = 0; column < lr.getBeta().numCols(); column++) {
- output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
- }
- output.println();
- }
- }
- }
- private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
- double weight = 0;
- for (Integer column : csv.getTraceDictionary().get(predictor)) {
- weight += lr.getBeta().get(row, column);
- }
- return weight;
- }
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
- Option help = builder.withLongName("help")
- .withDescription("print this list").create();
- Option quiet = builder.withLongName("quiet")
- .withDescription("be extra quiet").create();
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option showperf = builder
- .withLongName("showperf")
- .withDescription("output performance measures during training")
- .create();
- Option inputFile = builder
- .withLongName("input")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("input").withMaximum(1)
- .create())
- .withDescription("where to get training data").create();
- Option outputFile = builder
- .withLongName("output")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("output").withMaximum(1)
- .create())
- .withDescription("where to write the model content").create();
- Option threads = builder.withLongName("threads")
- .withArgument(
- argumentBuilder.withName("threads").withDefault("4").create())
- .withDescription("the number of threads AdaptiveLogisticRegression uses")
- .create();
- Option predictors = builder.withLongName("predictors")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("predictors").create())
- .withDescription("a list of predictor variables").create();
- Option types = builder
- .withLongName("types")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("types").create())
- .withDescription(
- "a list of predictor variable types (numeric, word, or text)")
- .create();
- Option target = builder
- .withLongName("target")
- .withDescription("the name of the target variable")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("target").withMaximum(1)
- .create())
- .create();
- Option targetCategories = builder
- .withLongName("categories")
- .withDescription("the number of target categories to be considered")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("categories").withMaximum(1).create())
- .create();
- Option features = builder
- .withLongName("features")
- .withDescription("the number of internal hashed features to use")
- .withArgument(
- argumentBuilder.withName("numFeatures")
- .withDefault("1000").withMaximum(1).create())
- .create();
- Option passes = builder
- .withLongName("passes")
- .withDescription("the number of times to pass over the input data")
- .withArgument(
- argumentBuilder.withName("passes").withDefault("2")
- .withMaximum(1).create())
- .create();
- Option interval = builder.withLongName("interval")
- .withArgument(
- argumentBuilder.withName("interval").withDefault("500").create())
- .withDescription("the interval property of AdaptiveLogisticRegression")
- .create();
- Option window = builder.withLongName("window")
- .withArgument(
- argumentBuilder.withName("window").withDefault("800").create())
- .withDescription("the average propery of AdaptiveLogisticRegression")
- .create();
- Option skipperfnum = builder.withLongName("skipperfnum")
- .withArgument(
- argumentBuilder.withName("skipperfnum").withDefault("99").create())
- .withDescription("show performance measures every (skipperfnum + 1) rows")
- .create();
- Option prior = builder.withLongName("prior")
- .withArgument(
- argumentBuilder.withName("prior").withDefault("L1").create())
- .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
- .create();
- Option priorOption = builder.withLongName("prioroption")
- .withArgument(
- argumentBuilder.withName("prioroption").create())
- .withDescription("constructor parameter for ElasticBandPrior and TPrior")
- .create();
- Option auc = builder.withLongName("auc")
- .withArgument(
- argumentBuilder.withName("auc").withDefault("global").create())
- .withDescription("the auc to use: global or grouped")
- .create();
- Group normalArgs = new GroupBuilder().withOption(help)
- .withOption(quiet).withOption(inputFile).withOption(outputFile)
- .withOption(target).withOption(targetCategories)
- .withOption(predictors).withOption(types).withOption(passes)
- .withOption(interval).withOption(window).withOption(threads)
- .withOption(prior).withOption(features).withOption(showperf)
- .withOption(skipperfnum).withOption(priorOption).withOption(auc)
- .create();
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
- if (cmdLine == null) {
- return false;
- }
- TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile);
- TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine,
- outputFile);
- List<String> typeList = new ArrayList<>();
- for (Object x : cmdLine.getValues(types)) {
- typeList.add(x.toString());
- }
- List<String> predictorList = new ArrayList<>();
- for (Object x : cmdLine.getValues(predictors)) {
- predictorList.add(x.toString());
- }
- lmp = new AdaptiveLogisticModelParameters();
- lmp.setTargetVariable(getStringArgument(cmdLine, target));
- lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
- lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
- lmp.setInterval(getIntegerArgument(cmdLine, interval));
- lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
- lmp.setThreads(getIntegerArgument(cmdLine, threads));
- lmp.setAuc(getStringArgument(cmdLine, auc));
- lmp.setPrior(getStringArgument(cmdLine, prior));
- if (cmdLine.getValue(priorOption) != null) {
- lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
- }
- lmp.setTypeMap(predictorList, typeList);
- TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf);
- TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum);
- TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes);
- lmp.checkParameters();
- return true;
- }
- /**
-  * Reads the value supplied for a command-line option as a string.
-  *
-  * @param cmdLine   the parsed command line
-  * @param inputFile the option whose value is wanted
-  * @return whatever {@code cmdLine.getValue} yields for the option, cast to {@code String}
-  */
- private static String getStringArgument(CommandLine cmdLine,
-     Option inputFile) {
-   Object raw = cmdLine.getValue(inputFile);
-   return (String) raw;
- }
- /**
-  * Treats an option as a flag: true exactly when {@code cmdLine.hasOption} reports it.
-  *
-  * @param cmdLine the parsed command line
-  * @param option  the flag option to look for
-  * @return the result of {@code cmdLine.hasOption(option)}
-  */
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
-   boolean present = cmdLine.hasOption(option);
-   return present;
- }
- /**
-  * Reads an option's value and parses it as a decimal {@code int}.
-  *
-  * @param cmdLine  the parsed command line
-  * @param features the option whose value is wanted
-  * @return the parsed integer
-  * @throws NumberFormatException if {@code Integer.parseInt} rejects the value
-  */
- private static int getIntegerArgument(CommandLine cmdLine, Option features) {
-   String text = (String) cmdLine.getValue(features);
-   return Integer.parseInt(text);
- }
- /**
-  * Reads an option's value and parses it as a {@code double}.
-  *
-  * @param cmdLine the parsed command line
-  * @param op      the option whose value is wanted
-  * @return the parsed number
-  * @throws NumberFormatException if {@code Double.parseDouble} rejects the value
-  */
- private static double getDoubleArgument(CommandLine cmdLine, Option op) {
-   String text = (String) cmdLine.getValue(op);
-   return Double.parseDouble(text);
- }
- /**
-  * @return the current value of the static {@code model} field
-  */
- public static AdaptiveLogisticRegression getModel() {
-   return TrainAdaptiveLogistic.model;
- }
- /**
-  * @return the current value of the static {@code lmp} field, which {@code parseArgs} assigns
-  */
- public static LogisticModelParameters getParameters() {
-   return TrainAdaptiveLogistic.lmp;
- }
- /**
-  * Opens the named input as UTF-8 text. The name is first looked up with
-  * {@code Resources.getResource}; if that throws {@code IllegalArgumentException}
-  * the name is opened as a file path instead.
-  *
-  * @param inputFile resource name or file path
-  * @return a buffered reader over the input; the caller is responsible for closing it
-  * @throws IOException if the resource stream or the file cannot be opened
-  */
- static BufferedReader open(String inputFile) throws IOException {
-   InputStream stream;
-   try {
-     stream = Resources.getResource(inputFile).openStream();
-   } catch (IllegalArgumentException notAResource) {
-     stream = new FileInputStream(new File(inputFile));
-   }
-   InputStreamReader decoder = new InputStreamReader(stream, Charsets.UTF_8);
-   return new BufferedReader(decoder);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
deleted file mode 100644
index f4b8bcb..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import com.google.common.io.Resources;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-/**
- * Train a logistic regression for the examples from Chapter 13 of Mahout in Action
- */
-public final class TrainLogistic {
- // Path or resource name of the training data; assigned by parseArgs from --input.
- private static String inputFile;
- // Destination of the saved model; assigned by parseArgs from --output.
- private static String outputFile;
- // Model configuration assembled by parseArgs.
- private static LogisticModelParameters lmp;
- // Number of passes over the input; assigned by parseArgs from --passes.
- private static int passes;
- // True when --scores was given; enables the per-line diagnostics in mainToOutput.
- private static boolean scores;
- // The regression trained by mainToOutput, exposed through getModel().
- private static OnlineLogisticRegression model;
- // Private constructor: the class only has static members and is not meant to be instantiated.
- private TrainLogistic() {
- }
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- /**
-  * Trains an online logistic regression on the CSV input named on the command line,
-  * writes {@code lmp} to the output file via {@code saveTo}, and prints a summary of
-  * the learned weights. Returns without doing anything when {@code parseArgs} is false.
-  *
-  * @param args   command-line arguments, handed to {@code parseArgs}
-  * @param output destination for the optional score diagnostics and the model summary
-  */
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
-   if (!parseArgs(args)) {
-     return;
-   }
-   double runningLogP = 0;
-   int seen = 0;
-   CsvRecordFactory csv = lmp.getCsvRecordFactory();
-   OnlineLogisticRegression lr = lmp.createRegression();
-   for (int pass = 0; pass < passes; pass++) {
-     try (BufferedReader in = open(inputFile)) {
-       // the first line holds the variable names
-       csv.firstLine(in.readLine());
-       for (String line = in.readLine(); line != null; line = in.readLine()) {
-         // encode this row: predictors go into the vector, the target comes back
-         Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
-         int targetValue = csv.processLine(line, input);
-         // score the row before the model is trained on it
-         double logP = lr.logLikelihood(targetValue, input);
-         if (!Double.isInfinite(logP)) {
-           // plain running mean for the first 20 samples, then an exponential average
-           runningLogP = seen < 20
-               ? (seen * runningLogP + logP) / (seen + 1)
-               : 0.95 * runningLogP + 0.05 * logP;
-           seen++;
-         }
-         double p = lr.classifyScalar(input);
-         if (scores) {
-           output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n",
-               seen, targetValue, lr.currentLearningRate(), p, logP, runningLogP);
-         }
-         // only now let the model learn from the row
-         lr.train(targetValue, input);
-       }
-     }
-   }
-   try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
-     lmp.saveTo(modelOutput);
-   }
-   output.println(lmp.getNumFeatures());
-   output.println(lmp.getTargetVariable() + " ~ ");
-   // one "weight*name" term per predictor with a non-zero weight in row 0
-   String joiner = "";
-   for (String name : csv.getTraceDictionary().keySet()) {
-     double w = predictorWeight(lr, 0, csv, name);
-     if (w != 0) {
-       output.printf(Locale.ENGLISH, "%s%.3f*%s", joiner, w, name);
-       joiner = " + ";
-     }
-   }
-   output.printf("%n");
-   model = lr;
-   // per row of the coefficient matrix: non-zero predictor weights, then every coefficient
-   for (int row = 0; row < lr.getBeta().numRows(); row++) {
-     for (String name : csv.getTraceDictionary().keySet()) {
-       double w = predictorWeight(lr, row, csv, name);
-       if (w != 0) {
-         output.printf(Locale.ENGLISH, "%20s %.5f%n", name, w);
-       }
-     }
-     for (int col = 0; col < lr.getBeta().numCols(); col++) {
-       output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, col));
-     }
-     output.println();
-   }
- }
- /**
-  * Adds up the coefficients of one row of {@code lr.getBeta()} over every column that
-  * the record factory's trace dictionary lists for the given predictor.
-  *
-  * @param lr        the regression whose coefficients are read
-  * @param row       the coefficient-matrix row to read
-  * @param csv       supplies the predictor-to-columns trace dictionary
-  * @param predictor the predictor name to look up
-  * @return the sum of the selected coefficients
-  */
- private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
-   double total = 0;
-   for (int col : csv.getTraceDictionary().get(predictor)) {
-     total += lr.getBeta().get(row, col);
-   }
-   return total;
- }
- /**
-  * Parses the command line into the static training settings ({@code inputFile},
-  * {@code outputFile}, {@code scores}, {@code passes}) and a freshly created
-  * {@code LogisticModelParameters} stored in {@code lmp}.
-  *
-  * @param args the raw command-line arguments
-  * @return false when {@code parser.parseAndHelp} returns null (presumably because help
-  *         was shown or the arguments were rejected - confirm against commons-cli2),
-  *         true once every setting has been stored
-  * @throws NumberFormatException if a numeric option value cannot be parsed
-  */
- private static boolean parseArgs(String[] args) {
-   DefaultOptionBuilder builder = new DefaultOptionBuilder();
-   Option help = builder.withLongName("help").withDescription("print this list").create();
-   // accepted on the command line, but its value is never read below
-   Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
-   Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create();
-   ArgumentBuilder argumentBuilder = new ArgumentBuilder();
-   Option inputFile = builder.withLongName("input")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
-       .withDescription("where to get training data")
-       .create();
-   // the help text used to repeat the --input description; this option names the model destination
-   Option outputFile = builder.withLongName("output")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
-       .withDescription("where to write the model content")
-       .create();
-   Option predictors = builder.withLongName("predictors")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("p").create())
-       .withDescription("a list of predictor variables")
-       .create();
-   Option types = builder.withLongName("types")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("t").create())
-       .withDescription("a list of predictor variable types (numeric, word, or text)")
-       .create();
-   Option target = builder.withLongName("target")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
-       .withDescription("the name of the target variable")
-       .create();
-   Option features = builder.withLongName("features")
-       .withArgument(
-           argumentBuilder.withName("numFeatures")
-               .withDefault("1000")
-               .withMaximum(1).create())
-       .withDescription("the number of internal hashed features to use")
-       .create();
-   Option passes = builder.withLongName("passes")
-       .withArgument(
-           argumentBuilder.withName("passes")
-               .withDefault("2")
-               .withMaximum(1).create())
-       .withDescription("the number of times to pass over the input data")
-       .create();
-   Option lambda = builder.withLongName("lambda")
-       .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
-       .withDescription("the amount of coefficient decay to use")
-       .create();
-   Option rate = builder.withLongName("rate")
-       .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
-       .withDescription("the learning rate")
-       .create();
-   Option noBias = builder.withLongName("noBias")
-       .withDescription("don't include a bias term")
-       .create();
-   Option targetCategories = builder.withLongName("categories")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
-       .withDescription("the number of target categories to be considered")
-       .create();
-   Group normalArgs = new GroupBuilder()
-       .withOption(help)
-       .withOption(quiet)
-       .withOption(inputFile)
-       .withOption(outputFile)
-       .withOption(target)
-       .withOption(targetCategories)
-       .withOption(predictors)
-       .withOption(types)
-       .withOption(passes)
-       .withOption(lambda)
-       .withOption(rate)
-       .withOption(noBias)
-       .withOption(features)
-       .create();
-   Parser parser = new Parser();
-   parser.setHelpOption(help);
-   parser.setHelpTrigger("--help");
-   parser.setGroup(normalArgs);
-   parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
-   CommandLine cmdLine = parser.parseAndHelp(args);
-   if (cmdLine == null) {
-     return false;
-   }
-   // the local Option variables shadow the static fields, hence the class-qualified assignments
-   TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile);
-   TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile);
-   List<String> typeList = new ArrayList<>();
-   for (Object x : cmdLine.getValues(types)) {
-     typeList.add(x.toString());
-   }
-   List<String> predictorList = new ArrayList<>();
-   for (Object x : cmdLine.getValues(predictors)) {
-     predictorList.add(x.toString());
-   }
-   lmp = new LogisticModelParameters();
-   lmp.setTargetVariable(getStringArgument(cmdLine, target));
-   lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
-   lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
-   // --noBias is a negative flag, so its presence switches the bias term off
-   lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
-   lmp.setTypeMap(predictorList, typeList);
-   lmp.setLambda(getDoubleArgument(cmdLine, lambda));
-   lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
-   TrainLogistic.scores = getBooleanArgument(cmdLine, scores);
-   TrainLogistic.passes = getIntegerArgument(cmdLine, passes);
-   return true;
- }
- /**
-  * Reads the value supplied for a command-line option as a string.
-  *
-  * @param cmdLine   the parsed command line
-  * @param inputFile the option whose value is wanted
-  * @return whatever {@code cmdLine.getValue} yields for the option, cast to {@code String}
-  */
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
-   Object raw = cmdLine.getValue(inputFile);
-   return (String) raw;
- }
- /**
-  * Treats an option as a flag: true exactly when {@code cmdLine.hasOption} reports it.
-  *
-  * @param cmdLine the parsed command line
-  * @param option  the flag option to look for
-  * @return the result of {@code cmdLine.hasOption(option)}
-  */
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
-   boolean present = cmdLine.hasOption(option);
-   return present;
- }
- /**
-  * Reads an option's value and parses it as a decimal {@code int}.
-  *
-  * @param cmdLine  the parsed command line
-  * @param features the option whose value is wanted
-  * @return the parsed integer
-  * @throws NumberFormatException if {@code Integer.parseInt} rejects the value
-  */
- private static int getIntegerArgument(CommandLine cmdLine, Option features) {
-   String text = (String) cmdLine.getValue(features);
-   return Integer.parseInt(text);
- }
- /**
-  * Reads an option's value and parses it as a {@code double}.
-  *
-  * @param cmdLine the parsed command line
-  * @param op      the option whose value is wanted
-  * @return the parsed number
-  * @throws NumberFormatException if {@code Double.parseDouble} rejects the value
-  */
- private static double getDoubleArgument(CommandLine cmdLine, Option op) {
-   String text = (String) cmdLine.getValue(op);
-   return Double.parseDouble(text);
- }
- /**
-  * @return the current value of the static {@code model} field, which
-  *         {@code mainToOutput} assigns after training
-  */
- public static OnlineLogisticRegression getModel() {
-   return TrainLogistic.model;
- }
- /**
-  * @return the current value of the static {@code lmp} field, which {@code parseArgs} assigns
-  */
- public static LogisticModelParameters getParameters() {
-   return TrainLogistic.lmp;
- }
- /**
-  * Opens the named input as UTF-8 text. The name is first looked up with
-  * {@code Resources.getResource}; if that throws {@code IllegalArgumentException}
-  * the name is opened as a file path instead.
-  *
-  * @param inputFile resource name or file path
-  * @return a buffered reader over the input; the caller is responsible for closing it
-  * @throws IOException if the resource stream or the file cannot be opened
-  */
- static BufferedReader open(String inputFile) throws IOException {
-   InputStream stream;
-   try {
-     stream = Resources.getResource(inputFile).openStream();
-   } catch (IllegalArgumentException notAResource) {
-     stream = new FileInputStream(new File(inputFile));
-   }
-   InputStreamReader decoder = new InputStreamReader(stream, Charsets.UTF_8);
-   return new BufferedReader(decoder);
- }
2018-06-27 13:14:35 UTC
diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java
deleted file mode 100644
index 752bb48..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java
+++ /dev/null
@@ -1,274 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import com.google.common.io.Closeables;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.NullWritable;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
-import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
-import org.apache.mahout.math.VarIntWritable;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.net.URI;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicInteger;
-/**
- * Convert the Mail archives (see {@link org.apache.mahout.text.SequenceFilesFromMailArchives}) to a preference
- * file that can be consumed by the {@link org.apache.mahout.cf.taste.hadoop.item.RecommenderJob}.
- * <p/>
- * This assumes the input is a Sequence File, that the key is: filename/message id and the value is a list
- * (separated by the user's choosing) containing the from email and any references
- * <p/>
- * The output is a matrix where either the from or to are the rows (represented as longs) and the columns are the
- * message ids that the user has interacted with (as a VectorWritable). This class currently does not account for
- * thread hijacking.
- * <p/>
- * It also outputs a side table mapping the row ids to their original and the message ids to the message thread id
- */
-public final class MailToPrefsDriver extends AbstractJob {
- private static final Logger log = LoggerFactory.getLogger(MailToPrefsDriver.class);
- private static final String OUTPUT_FILES_PATTERN = "part-*";
- private static final int DICTIONARY_BYTE_OVERHEAD = 4;
- public static void main(String[] args) throws Exception {
- ToolRunner.run(new Configuration(), new MailToPrefsDriver(), args);
- }
- /**
- * Runs the conversion in up to three phases, each gated by {@code shouldRunNextPhase}:
- * (1) build the message-id dictionary, (2) build the "from" address dictionary,
- * (3) run MailToRecMapper/MailToRecReducer once per pair of dictionary chunks and
- * collect the resulting part files under {@code <output>/recInput}.
- *
- * @param args command-line arguments, parsed by {@code parseArguments}
- * @return -1 as soon as any job reports failure, otherwise 0
- */
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.overwriteOption().create());
- addOption("chunkSize", "cs", "The size of chunks to write. Default is 100 mb", "100");
- addOption("separator", "sep", "The separator used in the input file to separate to, from, subject. Default is \\n",
- "\n");
- addOption("from", "f", "The position in the input text (value) where the from email is located, starting from "
- + "zero (0).", "0");
- addOption("refs", "r", "The position in the input text (value) where the reference ids are located, "
- + "starting from zero (0).", "1");
- addOption(buildOption("useCounts", "u", "If set, then use the number of times the user has interacted with a "
- + "thread as an indication of their preference. Otherwise, use boolean preferences.", false, false,
- String.valueOf(true)));
- Map<String, List<String>> parsedArgs = parseArguments(args);
- Path input = getInputPath();
- Path output = getOutputPath();
- int chunkSize = Integer.parseInt(getOption("chunkSize"));
- String separator = getOption("separator");
- Configuration conf = getConf();
- boolean useCounts = hasOption("useCounts");
- AtomicInteger currentPhase = new AtomicInteger();
- // single-element array used as an out-parameter of createDictionaryChunks
- int[] msgDim = new int[1];
- //TODO: mod this to not do so many passes over the data. Dictionary creation could probably be a chain mapper
- List<Path> msgIdChunks = null;
- boolean overwrite = hasOption(DefaultOptionCreator.OVERWRITE_OPTION);
- // create the dictionary between message ids and longs
- if (shouldRunNextPhase(parsedArgs, currentPhase)) {
- //TODO: there seems to be a pattern emerging for dictionary creation
- // -- sparse vectors from seq files also has this.
- Path msgIdsPath = new Path(output, "msgIds");
- if (overwrite) {
- HadoopUtil.delete(conf, msgIdsPath);
- }
- log.info("Creating Msg Id Dictionary");
- Job createMsgIdDictionary = prepareJob(input,
- msgIdsPath,
- SequenceFileInputFormat.class,
- MsgIdToDictionaryMapper.class,
- Text.class,
- VarIntWritable.class,
- MailToDictionaryReducer.class,
- Text.class,
- VarIntWritable.class,
- SequenceFileOutputFormat.class);
- boolean succeeded = createMsgIdDictionary.waitForCompletion(true);
- if (!succeeded) {
- return -1;
- }
- //write out the dictionary at the top level
- msgIdChunks = createDictionaryChunks(msgIdsPath, output, "msgIds-dictionary-",
- createMsgIdDictionary.getConfiguration(), chunkSize, msgDim);
- }
- //create the dictionary between from email addresses and longs
- List<Path> fromChunks = null;
- if (shouldRunNextPhase(parsedArgs, currentPhase)) {
- Path fromIdsPath = new Path(output, "fromIds");
- if (overwrite) {
- HadoopUtil.delete(conf, fromIdsPath);
- }
- log.info("Creating From Id Dictionary");
- Job createFromIdDictionary = prepareJob(input,
- fromIdsPath,
- SequenceFileInputFormat.class,
- FromEmailToDictionaryMapper.class,
- Text.class,
- VarIntWritable.class,
- MailToDictionaryReducer.class,
- Text.class,
- VarIntWritable.class,
- SequenceFileOutputFormat.class);
- createFromIdDictionary.getConfiguration().set(EmailUtility.SEPARATOR, separator);
- boolean succeeded = createFromIdDictionary.waitForCompletion(true);
- if (!succeeded) {
- return -1;
- }
- //write out the dictionary at the top level
- // fromDim receives the dictionary size but is not read afterwards (its only use below is commented out)
- int[] fromDim = new int[1];
- fromChunks = createDictionaryChunks(fromIdsPath, output, "fromIds-dictionary-",
- createFromIdDictionary.getConfiguration(), chunkSize, fromDim);
- }
- //OK, we have our dictionaries, let's output the real thing we need: <from_id -> <msgId, msgId, msgId, ...>>
- // both chunk lists are only non-null when the two dictionary phases ran in this invocation
- if (shouldRunNextPhase(parsedArgs, currentPhase) && fromChunks != null && msgIdChunks != null) {
- //Job map
- //may be a way to do this so that we can load the from ids in memory, if they are small enough so that
- // we don't need the double loop
- log.info("Creating recommendation matrix");
- Path vecPath = new Path(output, "recInput");
- if (overwrite) {
- HadoopUtil.delete(conf, vecPath);
- }
- //conf.set(EmailUtility.FROM_DIMENSION, String.valueOf(fromDim[0]));
- conf.set(EmailUtility.MSG_ID_DIMENSION, String.valueOf(msgDim[0]));
- conf.set(EmailUtility.FROM_PREFIX, "fromIds-dictionary-");
- conf.set(EmailUtility.MSG_IDS_PREFIX, "msgIds-dictionary-");
- conf.set(EmailUtility.FROM_INDEX, getOption("from"));
- conf.set(EmailUtility.REFS_INDEX, getOption("refs"));
- conf.set(EmailUtility.SEPARATOR, separator);
- conf.set(MailToRecReducer.USE_COUNTS_PREFERENCE, String.valueOf(useCounts));
- // i counts from-chunks; j is never reset, so it keeps counting across all (from, msgId) chunk pairs
- int j = 0;
- int i = 0;
- for (Path fromChunk : fromChunks) {
- for (Path idChunk : msgIdChunks) {
- Path out = new Path(vecPath, "tmp-" + i + '-' + j);
- // one job per chunk pair: exactly these two dictionary files are set as the cache files
- DistributedCache.setCacheFiles(new URI[]{fromChunk.toUri(), idChunk.toUri()}, conf);
- Job createRecMatrix = prepareJob(input, out, SequenceFileInputFormat.class,
- MailToRecMapper.class, Text.class, LongWritable.class, MailToRecReducer.class, Text.class,
- NullWritable.class, TextOutputFormat.class);
- createRecMatrix.getConfiguration().set("mapred.output.compress", "false");
- boolean succeeded = createRecMatrix.waitForCompletion(true);
- if (!succeeded) {
- return -1;
- }
- //copy the results up a level
- //HadoopUtil.copyMergeSeqFiles(out.getFileSystem(conf), out, vecPath.getFileSystem(conf), outPath, true,
- // conf, "");
- FileStatus[] fs = HadoopUtil.getFileStatus(new Path(out, "*"), PathType.GLOB, PathFilters.partFilter(), null,
- conf);
- for (int k = 0; k < fs.length; k++) {
- FileStatus f = fs[k];
- Path outPath = new Path(vecPath, "chunk-" + i + '-' + j + '-' + k);
- // the literal true is FileUtil.copy's deleteSource flag (see the Hadoop FileUtil API)
- FileUtil.copy(f.getPath().getFileSystem(conf), f.getPath(), outPath.getFileSystem(conf), outPath, true,
- overwrite, conf);
- }
- // the temporary per-pair output directory is removed once its part files are copied
- HadoopUtil.delete(conf, out);
- j++;
- }
- i++;
- }
- //concat the files together
- /*Path mergePath = new Path(output, "vectors.dat");
- if (overwrite) {
- HadoopUtil.delete(conf, mergePath);
- }
- log.info("Merging together output vectors to vectors.dat in {}", output);*/
- //HadoopUtil.copyMergeSeqFiles(vecPath.getFileSystem(conf), vecPath, mergePath.getFileSystem(conf), mergePath,
- // false, conf, "\n");
- }
- return 0;
- }
- /**
-  * Copies the keys of every record matching {@code inputPath/part-*} into one or more
-  * dictionary sequence files named {@code name + chunkIndex} under
-  * {@code dictionaryPathBase}, pairing each key with a sequential id that starts at 1.
-  * A new chunk file is started once the estimated size of the current one exceeds the limit.
-  *
-  * @param inputPath            directory holding the dictionary job's output
-  * @param dictionaryPathBase   directory that receives the chunk files
-  * @param name                 file-name prefix of the chunk files
-  * @param baseConf             configuration to copy for file-system access
-  * @param chunkSizeInMegabytes size limit per chunk, in megabytes
-  * @param maxTermDimension     out-parameter; element 0 is set to the last id written plus one
-  * @return the paths of all chunk files, in creation order
-  * @throws IOException if reading the input or writing a chunk fails
-  */
- private static List<Path> createDictionaryChunks(Path inputPath,
-     Path dictionaryPathBase,
-     String name,
-     Configuration baseConf,
-     int chunkSizeInMegabytes, int[] maxTermDimension)
-     throws IOException {
-   List<Path> written = new ArrayList<>();
-   Configuration conf = new Configuration(baseConf);
-   FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
-   long byteLimit = chunkSizeInMegabytes * 1024L * 1024L;
-   int chunkNo = 0;
-   Path current = new Path(dictionaryPathBase, name + chunkNo);
-   written.add(current);
-   SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, current, Text.class, IntWritable.class);
-   try {
-     long bytesInChunk = 0;
-     Path glob = new Path(inputPath, OUTPUT_FILES_PATTERN);
-     // ids start at 1, since a miss in the OpenObjectIntHashMap returns a 0
-     int nextId = 1;
-     for (Pair<Writable, Writable> entry
-         : new SequenceFileDirIterable<>(glob, PathType.GLOB, null, null, true, conf)) {
-       if (bytesInChunk > byteLimit) {
-         // roll over: finish the full chunk and open the next one
-         Closeables.close(writer, false);
-         chunkNo++;
-         current = new Path(dictionaryPathBase, name + chunkNo);
-         written.add(current);
-         writer = new SequenceFile.Writer(fs, conf, current, Text.class, IntWritable.class);
-         bytesInChunk = 0;
-       }
-       Writable key = entry.getFirst();
-       // size estimate: fixed overhead + two bytes per key character + one int
-       bytesInChunk += DICTIONARY_BYTE_OVERHEAD + key.toString().length() * 2 + Integer.SIZE / 8;
-       writer.append(key, new IntWritable(nextId));
-       nextId++;
-     }
-     maxTermDimension[0] = nextId;
-   } finally {
-     Closeables.close(writer, false);
-   }
-   return written;
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java
deleted file mode 100644
index 91bbd17..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.IOException;
-public final class MailToRecMapper extends Mapper<Text, Text, Text, LongWritable> {
- private static final Logger log = LoggerFactory.getLogger(MailToRecMapper.class);
- private final OpenObjectIntHashMap<String> fromDictionary = new OpenObjectIntHashMap<>();
- private final OpenObjectIntHashMap<String> msgIdDictionary = new OpenObjectIntHashMap<>();
- private String separator = "\n";
- private int fromIdx;
- private int refsIdx;
- public enum Counters {
- }
- /**
-  * Reads the field positions and the separator from the job configuration and fills
-  * both dictionaries through {@code EmailUtility.loadDictionaries}.
-  */
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
-   super.setup(context);
-   Configuration conf = context.getConfiguration();
-   // field positions default to 0 (from) and 1 (references) when not configured
-   fromIdx = conf.getInt(EmailUtility.FROM_INDEX, 0);
-   refsIdx = conf.getInt(EmailUtility.REFS_INDEX, 1);
-   String fromPrefix = conf.get(EmailUtility.FROM_PREFIX);
-   String msgPrefix = conf.get(EmailUtility.MSG_IDS_PREFIX);
-   EmailUtility.loadDictionaries(conf, fromPrefix, fromDictionary, msgPrefix, msgIdDictionary);
-   log.info("From Dictionary size: {} Msg Id Dictionary size: {}", fromDictionary.size(), msgIdDictionary.size());
-   separator = context.getConfiguration().get(EmailUtility.SEPARATOR);
- }
- /**
-  * Maps one mail record to a {@code "fromKey,msgIdKey"} pair with a count of 1.
-  * The sender is taken from field {@code fromIdx} of the value. The message id is the
-  * first reference in field {@code refsIdx} when there is one, otherwise the part of
-  * the record key after its last {@code '/'}.
-  * <p>
-  * NOTE(review): a dictionary miss appears to yield 0 rather than
-  * {@code Integer.MIN_VALUE} (the driver starts ids at 1 for that reason), so unknown
-  * senders or message ids are not filtered out by the final check - confirm.
-  *
-  * @param key     record key, expected to look like {@code filename/message-id}
-  * @param value   separator-delimited fields holding the sender and the references
-  * @param context receives the output pair and the counter updates
-  */
- @Override
- protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
-   int msgIdKey = Integer.MIN_VALUE;
-   int fromKey = Integer.MIN_VALUE;
-   String valStr = value.toString();
-   String[] splits = StringUtils.splitByWholeSeparatorPreserveAllTokens(valStr, separator);
-   if (splits != null && splits.length > 0) {
-     // guard on the index that is actually read; this used to test refsIdx, which could
-     // skip a present sender or read past the end of the array when fromIdx > refsIdx
-     if (splits.length > fromIdx) {
-       String from = EmailUtility.cleanUpEmailAddress(splits[fromIdx]);
-       fromKey = fromDictionary.get(from);
-     }
-     //get the references
-     if (splits.length > refsIdx) {
-       String[] theRefs = EmailUtility.parseReferences(splits[refsIdx]);
-       if (theRefs != null && theRefs.length > 0) {
-         //we have a reference, the first one is the original message id, so map to that one if it exists
-         msgIdKey = msgIdDictionary.get(theRefs[0]);
-         context.getCounter(Counters.REFERENCE).increment(1);
-       }
-     }
-   }
-   //we don't have any references, so use the msg id
-   if (msgIdKey == Integer.MIN_VALUE) {
-     //get the msg id and the from and output the associated ids
-     String keyStr = key.toString();
-     int idx = keyStr.lastIndexOf('/');
-     if (idx != -1) {
-       String msgId = keyStr.substring(idx + 1);
-       msgIdKey = msgIdDictionary.get(msgId);
-       context.getCounter(Counters.ORIGINAL).increment(1);
-     }
-   }
-   // only emit when both a sender and a message id were resolved
-   if (msgIdKey != Integer.MIN_VALUE && fromKey != Integer.MIN_VALUE) {
-     context.write(new Text(fromKey + "," + msgIdKey), new LongWritable(1));
-   }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java
deleted file mode 100644
index ee36a41..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java
+++ /dev/null
@@ -1,53 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.NullWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import java.io.IOException;
-public class MailToRecReducer extends Reducer<Text, LongWritable, Text, NullWritable> {
- //if true, then output weight
- private boolean useCounts = true;
- /**
- * We can either ignore how many times the user interacted (boolean) or output the number of times they interacted.
- */
- public static final String USE_COUNTS_PREFERENCE = "useBooleanPreferences";
-  /** Loads the {@link #USE_COUNTS_PREFERENCE} setting from the job configuration; it defaults to {@code true}. */
-  @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    boolean configured = context.getConfiguration().getBoolean(USE_COUNTS_PREFERENCE, true);
-    this.useCounts = configured;
-  }
-  /**
-   * Writes one output record per key: {@code key,sum} when counts are enabled, otherwise the key alone.
-   * The value side of every record is the {@link NullWritable} singleton.
-   */
-  @Override
-  protected void reduce(Text key, Iterable<LongWritable> values, Context context)
-      throws IOException, InterruptedException {
-    if (!useCounts) {
-      context.write(new Text(key.toString()), NullWritable.get());
-      return;
-    }
-    // Add up the incoming counts instead of counting records, so values that were already
-    // aggregated upstream are not undercounted.
-    long sum = 0;
-    for (LongWritable value : values) {
-      sum += value.get();
-    }
-    context.write(new Text(key.toString() + ',' + sum), NullWritable.get());
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java
deleted file mode 100644
index f3de847..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java
+++ /dev/null
@@ -1,49 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.VarIntWritable;
-import java.io.IOException;
- * Assumes the input is in the format created by {@link org.apache.mahout.text.SequenceFilesFromMailArchives}
- */
-public final class MsgIdToDictionaryMapper extends Mapper<Text, Text, Text, VarIntWritable> {
-  /**
-   * Extracts the message id from the key and emits it with a count of 1.
-   * Keys look like {@code /201008/AANLkTikvVnhNH+Y5AGEwqd2=***@mail.gmail.com}. Keys without an '@', or
-   * whose extracted id matches {@code EmailUtility.WHITESPACE}, only increment the NO_MESSAGE_ID counter.
-   */
-  @Override
-  protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
-    String path = key.toString();
-    int atPos = path.lastIndexOf('@');
-    if (atPos == -1) {
-      context.getCounter(EmailUtility.Counters.NO_MESSAGE_ID).increment(1);
-      return;
-    }
-    // The id is everything after the last slash that precedes the final '@'.
-    int slashPos = path.lastIndexOf('/', atPos);
-    String msgId = path.substring(slashPos + 1);
-    boolean blank = EmailUtility.WHITESPACE.matcher(msgId).matches();
-    if (blank) {
-      context.getCounter(EmailUtility.Counters.NO_MESSAGE_ID).increment(1);
-    } else {
-      context.write(new Text(msgId), new VarIntWritable(1));
-    }
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java
deleted file mode 100644
index c358021..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java
+++ /dev/null
@@ -1,44 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup;
-import java.io.File;
-import java.io.IOException;
-import java.util.Iterator;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-public final class DataFileIterable implements Iterable<Pair<PreferenceArray,long[]>> {
- private final File dataFile;
-  /**
-   * @param dataFile rating file that every iterator created by this iterable will read
-   */
-  public DataFileIterable(File dataFile) {
-    this.dataFile = dataFile;
-  }
- @Override
- public Iterator<Pair<PreferenceArray, long[]>> iterator() {
- try {
- return new DataFileIterator(dataFile);
- } catch (IOException ioe) {
- throw new IllegalStateException(ioe);
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java
deleted file mode 100644
index 786e080..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java
+++ /dev/null
@@ -1,158 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup;
-import java.io.Closeable;
-import java.io.File;
-import java.io.IOException;
-import java.util.regex.Pattern;
-import com.google.common.collect.AbstractIterator;
-import com.google.common.io.Closeables;
-import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
-import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.iterator.FileLineIterator;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
- * <p>An {@link java.util.Iterator} which iterates over any of the KDD Cup's rating files. These include the files
- * {train,test,validation}Idx{1,2}.txt. See http://kddcup.yahoo.com/. Each element in the iteration corresponds
- * to one user's ratings as a {@link PreferenceArray} and corresponding timestamps as a parallel {@code long}
- * array.</p>
- *
- * <p>Timestamps in the data set are relative to some unknown point in time, for anonymity. They are assumed
- * to be relative to the epoch, time 0, or January 1 1970, for purposes here.</p>
- */
-public final class DataFileIterator
- extends AbstractIterator<Pair<PreferenceArray,long[]>>
- implements SkippingIterator<Pair<PreferenceArray,long[]>>, Closeable {
- private static final Pattern COLON_PATTERN = Pattern.compile(":");
- private static final Pattern PIPE_PATTERN = Pattern.compile("\\|");
- private static final Pattern TAB_PATTERN = Pattern.compile("\t");
- private final FileLineIterator lineIterator;
- private static final Logger log = LoggerFactory.getLogger(DataFileIterator.class);
-  /**
-   * @param dataFile rating file to read
-   * @throws IllegalArgumentException if {@code dataFile} is null, is a directory, or does not exist
-   * @throws IOException declared for failures while opening the file
-   */
-  public DataFileIterator(File dataFile) throws IOException {
-    boolean usable = dataFile != null && !dataFile.isDirectory() && dataFile.exists();
-    if (!usable) {
-      throw new IllegalArgumentException("Bad data file: " + dataFile);
-    }
-    lineIterator = new FileLineIterator(dataFile);
-  }
-  /**
-   * Reads one user's block: a {@code userID|ratingsCount} header line followed by that many
-   * tab-separated rating lines.
-   *
-   * @return the user's preferences paired with a parallel array of timestamps, or the end-of-data
-   *  marker when no lines remain
-   */
-  @Override
-  protected Pair<PreferenceArray, long[]> computeNext() {
-    if (!lineIterator.hasNext()) {
-      return endOfData();
-    }
-    // Header line: userID|ratingsCount
-    String[] header = PIPE_PATTERN.split(lineIterator.next());
-    long userID = Long.parseLong(header[0]);
-    int ratingCount = Integer.parseInt(header[1]);
-    PreferenceArray currentUserPrefs = new GenericUserPreferenceArray(ratingCount);
-    long[] timestamps = new long[ratingCount];
-    for (int row = 0; row < ratingCount; row++) {
-      // Data line of 1-4 tokens: the item ID is always first, a preference may follow (it is not in test
-      // data), and a date/time pair may occupy the last two fields (not included in track 2).
-      String[] fields = TAB_PATTERN.split(lineIterator.next());
-      int fieldCount = fields.length;
-      boolean hasPref = fieldCount == 2 || fieldCount == 4;
-      boolean hasDate = fieldCount > 2;
-      long itemID = Long.parseLong(fields[0]);
-      currentUserPrefs.setUserID(0, userID);
-      currentUserPrefs.setItemID(row, itemID);
-      if (hasPref) {
-        currentUserPrefs.setValue(row, Float.parseFloat(fields[1]));
-      }
-      if (hasDate) {
-        // The date/time pair starts right after the preference when one is present.
-        int dateIdx = hasPref ? 2 : 1;
-        timestamps[row] = parseFakeTimestamp(fields[dateIdx], fields[dateIdx + 1]);
-      }
-    }
-    return new Pair<>(currentUserPrefs, timestamps);
-  }
-  /**
-   * Skips up to {@code n} user blocks, using the rating count on each header line to skip that
-   * block's data lines. Stops early once no lines remain.
-   */
-  @Override
-  public void skip(int n) {
-    for (int skipped = 0; skipped < n && lineIterator.hasNext(); skipped++) {
-      // Header line: userID|ratingsCount
-      String[] header = PIPE_PATTERN.split(lineIterator.next());
-      int dataLines = Integer.parseInt(header[1]);
-      lineIterator.skip(dataLines);
-    }
-  }
- @Override
- public void close() {
- endOfData();
- try {
- Closeables.close(lineIterator, true);
- } catch (IOException e) {
- log.error(e.getMessage(), e);
- }
- }
-  /**
-   * @param dateString "date" in days since some undisclosed date, which we will arbitrarily assume to be the
-   *  epoch, January 1 1970.
-   * @param timeString time of day in HH:mm:ss format
-   * @return the number of seconds since the assumed epoch for this moment in time
-   */
-  private static long parseFakeTimestamp(String dateString, CharSequence timeString) {
-    int days = Integer.parseInt(dateString);
-    String[] timeTokens = COLON_PATTERN.split(timeString);
-    int hours = Integer.parseInt(timeTokens[0]);
-    int minutes = Integer.parseInt(timeTokens[1]);
-    int seconds = Integer.parseInt(timeTokens[2]);
-    // Each hour contributes 3600 seconds; the hour count used to be added unscaled to a constant 3600.
-    return 86400L * days + 3600L * hours + 60L * minutes + seconds;
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java
deleted file mode 100644
index 4b62050..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java
+++ /dev/null
@@ -1,231 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup;
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.Iterator;
-import com.google.common.base.Preconditions;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.SamplingIterator;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
- * <p>A {@link DataModel} which reads into memory any of the KDD Cup's rating files; it is really
- * meant for use with training data in the files trainIdx{1,2}.txt.
- * See http://kddcup.yahoo.com/.</p>
- *
- * <p>Timestamps in the data set are relative to some unknown point in time, for anonymity. They are assumed
- * to be relative to the epoch, time 0, or January 1 1970, for purposes here.</p>
- */
-public final class KDDCupDataModel implements DataModel {
- private static final Logger log = LoggerFactory.getLogger(KDDCupDataModel.class);
- private final File dataFileDirectory;
- private final DataModel delegate;
-  /**
-   * Loads the file without storing dates and with a sampling rate of 1.0.
-   *
-   * @param dataFile training rating file
-   */
-  public KDDCupDataModel(File dataFile) throws IOException {
-    this(dataFile, false, 1.0);
-  }
-  /**
-   * Reads the rating file into an in-memory {@link GenericDataModel}.
-   *
-   * @param dataFile training rating file
-   * @param storeDates if true, positive per-item timestamps are kept in the model, otherwise not
-   * @param samplingRate fraction of users to keep, in (0.0, 1.0]; can be used to reduce memory requirements
-   * @throws IllegalArgumentException if {@code samplingRate} is NaN or outside (0.0, 1.0]
-   */
-  public KDDCupDataModel(File dataFile, boolean storeDates, double samplingRate) throws IOException {
-    Preconditions.checkArgument(!Double.isNaN(samplingRate) && samplingRate > 0.0 && samplingRate <= 1.0,
-        "Must be: 0.0 < samplingRate <= 1.0");
-    dataFileDirectory = dataFile.getParentFile();
-    Iterator<Pair<PreferenceArray,long[]>> dataIterator = new DataFileIterator(dataFile);
-    if (samplingRate < 1.0) {
-      dataIterator = new SamplingIterator<>(dataIterator, samplingRate);
-    }
-    FastByIDMap<PreferenceArray> userData = new FastByIDMap<>();
-    FastByIDMap<FastByIDMap<Long>> timestamps = new FastByIDMap<>();
-    while (dataIterator.hasNext()) {
-      Pair<PreferenceArray,long[]> pair = dataIterator.next();
-      PreferenceArray userPrefs = pair.getFirst();
-      long[] timestampsForPrefs = pair.getSecond();
-      long userID = userPrefs.getUserID(0);
-      userData.put(userID, userPrefs);
-      if (storeDates) {
-        FastByIDMap<Long> itemTimestamps = new FastByIDMap<>();
-        for (int i = 0; i < timestampsForPrefs.length; i++) {
-          long timestamp = timestampsForPrefs[i];
-          if (timestamp > 0L) {
-            itemTimestamps.put(userPrefs.getItemID(i), timestamp);
-          }
-        }
-        // Register this user's timestamps; without this the map handed to GenericDataModel stays empty.
-        timestamps.put(userID, itemTimestamps);
-      }
-    }
-    if (storeDates) {
-      delegate = new GenericDataModel(userData, timestamps);
-    } else {
-      delegate = new GenericDataModel(userData);
-    }
-    Runtime runtime = Runtime.getRuntime();
-    log.info("Loaded data model in about {}MB heap", (runtime.totalMemory() - runtime.freeMemory()) / 1000000);
-  }
-  /** @return the directory recorded for this model's data file */
-  public File getDataFileDirectory() {
-    return this.dataFileDirectory;
-  }
- public static File getTrainingFile(File dataFileDirectory) {
- return getFile(dataFileDirectory, "trainIdx");
- }
- public static File getValidationFile(File dataFileDirectory) {
- return getFile(dataFileDirectory, "validationIdx");
- }
- public static File getTestFile(File dataFileDirectory) {
- return getFile(dataFileDirectory, "testIdx");
- }
- public static File getTrackFile(File dataFileDirectory) {
- return getFile(dataFileDirectory, "trackData");
- }
-  /**
-   * Finds the first existing file named {@code prefix + set + [".firstLines"] + ".txt" + [".gz"]}.
-   * Set 1 is tried before set 2, the full file before the ".firstLines" sample, and the gzipped
-   * form before the plain one.
-   *
-   * @throws IllegalArgumentException if none of the candidate files exists in {@code dataFileDirectory}
-   */
-  private static File getFile(File dataFileDirectory, String prefix) {
-    int[] sets = {1, 2};
-    // Sample data from before the contest carries the ".firstLines" marker; real data does not.
-    String[] sampleSuffixes = {"", ".firstLines"};
-    String[] compressionSuffixes = {".gz", ""};
-    for (int set : sets) {
-      for (String sampleSuffix : sampleSuffixes) {
-        for (String compressionSuffix : compressionSuffixes) {
-          String candidateName = prefix + set + sampleSuffix + ".txt" + compressionSuffix;
-          File candidate = new File(dataFileDirectory, candidateName);
-          if (candidate.exists()) {
-            return candidate;
-          }
-        }
-      }
-    }
-    throw new IllegalArgumentException("Can't find " + prefix + " file in " + dataFileDirectory);
-  }
- @Override
- public LongPrimitiveIterator getUserIDs() throws TasteException {
- return delegate.getUserIDs();
- }
- @Override
- public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
- return delegate.getPreferencesFromUser(userID);
- }
- @Override
- public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
- return delegate.getItemIDsFromUser(userID);
- }
- @Override
- public LongPrimitiveIterator getItemIDs() throws TasteException {
- return delegate.getItemIDs();
- }
- @Override
- public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
- return delegate.getPreferencesForItem(itemID);
- }
- @Override
- public Float getPreferenceValue(long userID, long itemID) throws TasteException {
- return delegate.getPreferenceValue(userID, itemID);
- }
- @Override
- public Long getPreferenceTime(long userID, long itemID) throws TasteException {
- return delegate.getPreferenceTime(userID, itemID);
- }
- @Override
- public int getNumItems() throws TasteException {
- return delegate.getNumItems();
- }
- @Override
- public int getNumUsers() throws TasteException {
- return delegate.getNumUsers();
- }
- @Override
- public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
- return delegate.getNumUsersWithPreferenceFor(itemID);
- }
- @Override
- public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
- return delegate.getNumUsersWithPreferenceFor(itemID1, itemID2);
- }
-  /** Passes the new preference value on to the wrapped {@code delegate} model. */
-  @Override
-  public void setPreference(long userID, long itemID, float value) throws TasteException {
-    this.delegate.setPreference(userID, itemID, value);
-  }
-  /** Passes the removal request on to the wrapped {@code delegate} model. */
-  @Override
-  public void removePreference(long userID, long itemID) throws TasteException {
-    this.delegate.removePreference(userID, itemID);
-  }
- @Override
- public boolean hasPreferenceValues() {
- return delegate.hasPreferenceValues();
- }
-  /** @return the fixed upper bound of 100 that this model reports for preference values */
-  @Override
-  public float getMaxPreference() {
-    final float upperBound = 100.0f;
-    return upperBound;
-  }
-  /** @return the fixed lower bound of 0 that this model reports for preference values */
-  @Override
-  public float getMinPreference() {
-    final float lowerBound = 0.0f;
-    return lowerBound;
-  }
-  /** Intentionally a no-op: this method performs no refresh work. */
-  @Override
-  public void refresh(Collection<Refreshable> alreadyRefreshed) {
-    // nothing to refresh
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java
deleted file mode 100644
index 3f4a732..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java
+++ /dev/null
@@ -1,77 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.Writer;
-import java.util.zip.GZIPOutputStream;
- * <p>This class converts a KDD Cup input file into a compressed CSV format. The output format is
- * {@code userID,itemID,score,timestamp}. It can optionally restrict its output to exclude
- * score and/or timestamp.</p>
- *
- * <p>Run as: {@code ToCSV (input file) (output file) [num columns to output]}</p>
- */
-public final class ToCSV {
-  /** Not instantiable; this class only exposes {@code main}. */
-  private ToCSV() {
-  }
- public static void main(String[] args) throws Exception {
- File inputFile = new File(args[0]);
- File outputFile = new File(args[1]);
- int columnsToOutput = 4;
- if (args.length >= 3) {
- columnsToOutput = Integer.parseInt(args[2]);
- }
- OutputStream outStream = new GZIPOutputStream(new FileOutputStream(outputFile));
- try (Writer outWriter = new BufferedWriter(new OutputStreamWriter(outStream, Charsets.UTF_8))){
- for (Pair<PreferenceArray,long[]> user : new DataFileIterable(inputFile)) {
- PreferenceArray prefs = user.getFirst();
- long[] timestamps = user.getSecond();
- for (int i = 0; i < prefs.length(); i++) {
- outWriter.write(String.valueOf(prefs.getUserID(i)));
- outWriter.write(',');
- outWriter.write(String.valueOf(prefs.getItemID(i)));
- if (columnsToOutput > 2) {
- outWriter.write(',');
- outWriter.write(String.valueOf(prefs.getValue(i)));
- }
- if (columnsToOutput > 3) {
- outWriter.write(',');
- outWriter.write(String.valueOf(timestamps[i]));
- }
- outWriter.write('\n');
- }
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
deleted file mode 100644
index 0112ab9..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
+++ /dev/null
@@ -1,43 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class EstimateConverter {
- private static final Logger log = LoggerFactory.getLogger(EstimateConverter.class);
-  /** Not instantiable; this class only exposes the static {@code convert} method. */
-  private EstimateConverter() {
-  }
-  /**
-   * Scales an estimate by 2.55, truncates it to an int, clamps it to 0..255 and narrows it to a byte.
-   * A NaN estimate is logged as a warning and reported as {@code 0x7F}.
-   *
-   * @param estimate estimated preference value
-   * @param userID user the estimate belongs to; used only in the warning message
-   * @param itemID item the estimate belongs to; used only in the warning message
-   */
-  public static byte convert(double estimate, long userID, long itemID) {
-    if (Double.isNaN(estimate)) {
-      log.warn("Unable to compute estimate for user {}, item {}", userID, itemID);
-      return 0x7F;
-    }
-    int scaled = (int) (estimate * 2.55);
-    int clamped = Math.max(0, Math.min(255, scaled));
-    return (byte) clamped;
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
deleted file mode 100644
index 72056da..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
+++ /dev/null
@@ -1,67 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import java.util.concurrent.Callable;
-import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-final class Track1Callable implements Callable<byte[]> {
- private static final Logger log = LoggerFactory.getLogger(Track1Callable.class);
- private static final AtomicInteger COUNT = new AtomicInteger();
- private final Recommender recommender;
- private final PreferenceArray userTest;
-  /**
-   * @param recommender recommender asked for each estimate
-   * @param userTest the test entries of one user whose preferences are to be estimated
-   */
-  Track1Callable(Recommender recommender, PreferenceArray userTest) {
-    this.userTest = userTest;
-    this.recommender = recommender;
-  }
-  /**
-   * Estimates a preference for every item in {@code userTest} and converts each estimate to a byte.
-   * An item the recommender reports as unknown is logged and its slot keeps the array default of 0.
-   *
-   * @return one converted estimate per test item, in the order of {@code userTest}
-   */
-  @Override
-  public byte[] call() throws TasteException {
-    long userID = userTest.get(0).getUserID();
-    byte[] result = new byte[userTest.length()];
-    for (int i = 0; i < userTest.length(); i++) {
-      long itemID = userTest.getItemID(i);
-      double estimate;
-      try {
-        estimate = recommender.estimatePreference(userID, itemID);
-      } catch (NoSuchItemException nsie) {
-        // OK in the sample data provided before the contest, should never happen otherwise
-        log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
-        continue;
-      }
-      result[i] = EstimateConverter.convert(estimate, userID, itemID);
-    }
-    // Log the value this call obtained; re-reading the shared counter could report a number that
-    // other threads have already advanced.
-    int completed = COUNT.incrementAndGet();
-    if (completed % 10000 == 0) {
-      log.info("Completed {} users", completed);
-    }
-    return result;
-  }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java
deleted file mode 100644
index 067daf5..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java
+++ /dev/null
@@ -1,94 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import java.util.Collection;
-import java.util.List;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
-import org.apache.mahout.cf.taste.impl.similarity.UncenteredCosineSimilarity;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-public final class Track1Recommender implements Recommender {
- private final Recommender recommender;
-  /**
-   * Builds an item-based recommender over the given model using uncentered cosine similarity.
-   *
-   * @param dataModel model supplying the preference data
-   */
-  public Track1Recommender(DataModel dataModel) throws TasteException {
-    // Change this to whatever you like!
-    ItemSimilarity itemSimilarity = new UncenteredCosineSimilarity(dataModel);
-    this.recommender = new GenericItemBasedRecommender(dataModel, itemSimilarity);
-  }
-  /** Forwards to the wrapped recommender. */
-  @Override
-  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
-    List<RecommendedItem> items = recommender.recommend(userID, howMany);
-    return items;
-  }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
- return recommend(userID, howMany, null, includeKnownItems);
- }
-  /** Forwards to the wrapped recommender with {@code includeKnownItems} set to {@code false}. */
-  @Override
-  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
-    List<RecommendedItem> items = recommender.recommend(userID, howMany, rescorer, false);
-    return items;
-  }
-  /** Forwards all four arguments to the wrapped recommender. */
-  @Override
-  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
-      throws TasteException {
-    List<RecommendedItem> items = recommender.recommend(userID, howMany, rescorer, includeKnownItems);
-    return items;
-  }
- @Override
- public float estimatePreference(long userID, long itemID) throws TasteException {
- return recommender.estimatePreference(userID, itemID);
- }
-  /** Passes the new preference value on to the wrapped recommender. */
-  @Override
-  public void setPreference(long userID, long itemID, float value) throws TasteException {
-    this.recommender.setPreference(userID, itemID, value);
-  }
-  /** Passes the removal request on to the wrapped recommender. */
-  @Override
-  public void removePreference(long userID, long itemID) throws TasteException {
-    this.recommender.removePreference(userID, itemID);
-  }
- @Override
- public DataModel getDataModel() {
- return recommender.getDataModel();
- }
-  /** Passes the refresh request on to the wrapped recommender. */
-  @Override
-  public void refresh(Collection<Refreshable> alreadyRefreshed) {
-    this.recommender.refresh(alreadyRefreshed);
-  }
- @Override
- public String toString() {
- return "Track1Recommender[recommender:" + recommender + ']';
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java
deleted file mode 100644
index 6b9fe1b..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java
+++ /dev/null
@@ -1,32 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-final class Track1RecommenderBuilder implements RecommenderBuilder {
- @Override
- public Recommender buildRecommender(DataModel dataModel) throws TasteException {
- return new Track1Recommender(dataModel);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java
deleted file mode 100644
index bcd0a3d..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java
+++ /dev/null
@@ -1,108 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import java.io.File;
-import java.util.Collection;
-import java.util.concurrent.Callable;
-import java.util.concurrent.atomic.AtomicInteger;
-import com.google.common.collect.Lists;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.DataModelBuilder;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
-import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
-import org.apache.mahout.cf.taste.impl.eval.AbstractDifferenceRecommenderEvaluator;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
- * Attempts to run an evaluation just like that dictated for Yahoo's KDD Cup, Track 1.
- * It will compute the RMSE of a validation data set against the predicted ratings from
- * the training data set.
- */
-public final class Track1RecommenderEvaluator extends AbstractDifferenceRecommenderEvaluator {
- private static final Logger log = LoggerFactory.getLogger(Track1RecommenderEvaluator.class);
- private RunningAverage average;
- private final File dataFileDirectory;
- public Track1RecommenderEvaluator(File dataFileDirectory) {
- setMaxPreference(100.0f);
- setMinPreference(0.0f);
- average = new FullRunningAverage();
- this.dataFileDirectory = dataFileDirectory;
- }
- @Override
- public double evaluate(RecommenderBuilder recommenderBuilder,
- DataModelBuilder dataModelBuilder,
- DataModel dataModel,
- double trainingPercentage,
- double evaluationPercentage) throws TasteException {
- Recommender recommender = recommenderBuilder.buildRecommender(dataModel);
- Collection<Callable<Void>> estimateCallables = Lists.newArrayList();
- AtomicInteger noEstimateCounter = new AtomicInteger();
- for (Pair<PreferenceArray,long[]> userData
- : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
- PreferenceArray validationPrefs = userData.getFirst();
- long userID = validationPrefs.get(0).getUserID();
- estimateCallables.add(
- new PreferenceEstimateCallable(recommender, userID, validationPrefs, noEstimateCounter));
- }
- RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
- execute(estimateCallables, noEstimateCounter, timing);
- double result = computeFinalEvaluation();
- log.info("Evaluation result: {}", result);
- return result;
- }
- // Use RMSE scoring:
- @Override
- protected void reset() {
- average = new FullRunningAverage();
- }
- @Override
- protected void processOneEstimate(float estimatedPreference, Preference realPref) {
- double diff = realPref.getValue() - estimatedPreference;
- average.addDatum(diff * diff);
- }
- @Override
- protected double computeFinalEvaluation() {
- return Math.sqrt(average.getAverage());
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java
deleted file mode 100644
index deadc00..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java
+++ /dev/null
@@ -1,56 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import java.io.File;
-import java.io.IOException;
-import org.apache.commons.cli2.OptionException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.example.TasteOptionParser;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class Track1RecommenderEvaluatorRunner {
- private static final Logger log = LoggerFactory.getLogger(Track1RecommenderEvaluatorRunner.class);
- private Track1RecommenderEvaluatorRunner() {
- }
- public static void main(String... args) throws IOException, TasteException, OptionException {
- File dataFileDirectory = TasteOptionParser.getRatings(args);
- if (dataFileDirectory == null) {
- throw new IllegalArgumentException("No data directory");
- }
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
- Track1RecommenderEvaluator evaluator = new Track1RecommenderEvaluator(dataFileDirectory);
- DataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
- double evaluation = evaluator.evaluate(new Track1RecommenderBuilder(),
- null,
- model,
- Float.NaN,
- Float.NaN);
- log.info(String.valueOf(evaluation));
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java
deleted file mode 100644
index a0ff126..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java
+++ /dev/null
@@ -1,95 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
- * <p>Runs "track 1" of the KDD Cup competition using whatever recommender is inside {@link Track1Recommender}
- * and attempts to output the result in the correct contest format.</p>
- *
- * <p>Run as: {@code Track1Runner [track 1 data file directory] [output file]}</p>
- */
-public final class Track1Runner {
- private static final Logger log = LoggerFactory.getLogger(Track1Runner.class);
- private Track1Runner() {
- }
- public static void main(String[] args) throws Exception {
- File dataFileDirectory = new File(args[0]);
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
- long start = System.currentTimeMillis();
- KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
- Track1Recommender recommender = new Track1Recommender(model);
- long end = System.currentTimeMillis();
- log.info("Loaded model in {}s", (end - start) / 1000);
- start = end;
- Collection<Track1Callable> callables = new ArrayList<>();
- for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
- PreferenceArray userTest = tests.getFirst();
- callables.add(new Track1Callable(recommender, userTest));
- }
- int cores = Runtime.getRuntime().availableProcessors();
- log.info("Running on {} cores", cores);
- ExecutorService executor = Executors.newFixedThreadPool(cores);
- List<Future<byte[]>> results = executor.invokeAll(callables);
- executor.shutdown();
- end = System.currentTimeMillis();
- log.info("Ran recommendations in {}s", (end - start) / 1000);
- start = end;
- try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){
- for (Future<byte[]> result : results) {
- for (byte estimate : result.get()) {
- out.write(estimate);
- }
- }
- }
- end = System.currentTimeMillis();
- log.info("Wrote output in {}s", (end - start) / 1000);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
deleted file mode 100644
index 022d78c..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
+++ /dev/null
@@ -1,107 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.impl.model.GenericPreference;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
-import java.util.ArrayList;
-import java.util.List;
- * can be used to drop {@link DataModel}s into {@link ParallelArraysSGDFactorizer}
- */
-public class DataModelFactorizablePreferences implements FactorizablePreferences {
- private final FastIDSet userIDs;
- private final FastIDSet itemIDs;
- private final List<Preference> preferences;
- private final float minPreference;
- private final float maxPreference;
- public DataModelFactorizablePreferences(DataModel dataModel) {
- minPreference = dataModel.getMinPreference();
- maxPreference = dataModel.getMaxPreference();
- try {
- userIDs = new FastIDSet(dataModel.getNumUsers());
- itemIDs = new FastIDSet(dataModel.getNumItems());
- preferences = new ArrayList<>();
- LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
- while (userIDsIterator.hasNext()) {
- long userID = userIDsIterator.nextLong();
- userIDs.add(userID);
- for (Preference preference : dataModel.getPreferencesFromUser(userID)) {
- itemIDs.add(preference.getItemID());
- preferences.add(new GenericPreference(userID, preference.getItemID(), preference.getValue()));
- }
- }
- } catch (TasteException te) {
- throw new IllegalStateException("Unable to create factorizable preferences!", te);
- }
- }
- @Override
- public LongPrimitiveIterator getUserIDs() {
- return userIDs.iterator();
- }
- @Override
- public LongPrimitiveIterator getItemIDs() {
- return itemIDs.iterator();
- }
- @Override
- public Iterable<Preference> getPreferences() {
- return preferences;
- }
- @Override
- public float getMinPreference() {
- return minPreference;
- }
- @Override
- public float getMaxPreference() {
- return maxPreference;
- }
- @Override
- public int numUsers() {
- return userIDs.size();
- }
- @Override
- public int numItems() {
- return itemIDs.size();
- }
- @Override
- public int numPreferences() {
- return preferences.size();
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
deleted file mode 100644
index a126dec..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
+++ /dev/null
@@ -1,44 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.model.Preference;
- * models the necessary input for {@link ParallelArraysSGDFactorizer}
- */
-public interface FactorizablePreferences {
- LongPrimitiveIterator getUserIDs();
- LongPrimitiveIterator getItemIDs();
- Iterable<Preference> getPreferences();
- float getMinPreference();
- float getMaxPreference();
- int numUsers();
- int numItems();
- int numPreferences();

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
deleted file mode 100644
index 6dcef6b..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
+++ /dev/null
@@ -1,123 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-import com.google.common.base.Function;
-import com.google.common.collect.Iterables;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.impl.common.AbstractLongPrimitiveIterator;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import java.io.File;
-public class KDDCupFactorizablePreferences implements FactorizablePreferences {
- private final File dataFile;
- public KDDCupFactorizablePreferences(File dataFile) {
- this.dataFile = dataFile;
- }
- @Override
- public LongPrimitiveIterator getUserIDs() {
- return new FixedSizeLongIterator(numUsers());
- }
- @Override
- public LongPrimitiveIterator getItemIDs() {
- return new FixedSizeLongIterator(numItems());
- }
- @Override
- public Iterable<Preference> getPreferences() {
- Iterable<Iterable<Preference>> prefIterators =
- Iterables.transform(new DataFileIterable(dataFile),
- new Function<Pair<PreferenceArray,long[]>,Iterable<Preference>>() {
- @Override
- public Iterable<Preference> apply(Pair<PreferenceArray,long[]> from) {
- return from.getFirst();
- }
- });
- return Iterables.concat(prefIterators);
- }
- @Override
- public float getMinPreference() {
- return 0;
- }
- @Override
- public float getMaxPreference() {
- return 100;
- }
- @Override
- public int numUsers() {
- return 1000990;
- }
- @Override
- public int numItems() {
- return 624961;
- }
- @Override
- public int numPreferences() {
- return 252800275;
- }
- static class FixedSizeLongIterator extends AbstractLongPrimitiveIterator {
- private long currentValue;
- private final long maximum;
- FixedSizeLongIterator(long maximum) {
- this.maximum = maximum;
- currentValue = 0;
- }
- @Override
- public long nextLong() {
- return currentValue++;
- }
- @Override
- public long peek() {
- return currentValue;
- }
- @Override
- public void skip(int n) {
- currentValue += n;
- }
- @Override
- public boolean hasNext() {
- return currentValue < maximum;
- }
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
deleted file mode 100644
index a99d54c..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
+++ /dev/null
@@ -1,265 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.common.RandomUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.util.Collection;
-import java.util.Random;
- * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html">
- * "Netflix Update: Try this at home"</a>.
- *
- * Attempts to be as memory efficient as possible, only iterating once through the
- * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays.
- * Learning works in place on these datastructures after that.
- */
-public class ParallelArraysSGDFactorizer implements Factorizer {
- public static final double DEFAULT_LEARNING_RATE = 0.005;
- public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
- public static final double DEFAULT_RANDOM_NOISE = 0.005;
- private final int numFeatures;
- private final int numIterations;
- private final float minPreference;
- private final float maxPreference;
- private final Random random;
- private final double learningRate;
- private final double preventOverfitting;
- private final FastByIDMap<Integer> userIDMapping;
- private final FastByIDMap<Integer> itemIDMapping;
- private final double[][] userFeatures;
- private final double[][] itemFeatures;
- private final int[] userIndexes;
- private final int[] itemIndexes;
- private final float[] values;
- private final double defaultValue;
- private final double interval;
- private final double[] cachedEstimates;
- private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
- public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
- this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE,
- }
- public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate,
- double preventOverfitting, double randomNoise) {
- this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting,
- randomNoise);
- }
- public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
- this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING,
- }
- public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures,
- int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
- this.numFeatures = numFeatures;
- this.numIterations = numIterations;
- minPreference = factorizablePreferences.getMinPreference();
- maxPreference = factorizablePreferences.getMaxPreference();
- this.random = RandomUtils.getRandom();
- this.learningRate = learningRate;
- this.preventOverfitting = preventOverfitting;
- int numUsers = factorizablePreferences.numUsers();
- int numItems = factorizablePreferences.numItems();
- int numPrefs = factorizablePreferences.numPreferences();
- log.info("Mapping {} users...", numUsers);
- userIDMapping = new FastByIDMap<>(numUsers);
- int index = 0;
- LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
- while (userIterator.hasNext()) {
- userIDMapping.put(userIterator.nextLong(), index++);
- }
- log.info("Mapping {} items", numItems);
- itemIDMapping = new FastByIDMap<>(numItems);
- index = 0;
- LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
- while (itemIterator.hasNext()) {
- itemIDMapping.put(itemIterator.nextLong(), index++);
- }
- this.userIndexes = new int[numPrefs];
- this.itemIndexes = new int[numPrefs];
- this.values = new float[numPrefs];
- this.cachedEstimates = new double[numPrefs];
- index = 0;
- log.info("Loading {} preferences into memory", numPrefs);
- RunningAverage average = new FullRunningAverage();
- for (Preference preference : factorizablePreferences.getPreferences()) {
- userIndexes[index] = userIDMapping.get(preference.getUserID());
- itemIndexes[index] = itemIDMapping.get(preference.getItemID());
- values[index] = preference.getValue();
- cachedEstimates[index] = 0;
- average.addDatum(preference.getValue());
- index++;
- if (index % 1000000 == 0) {
- log.info("Processed {} preferences", index);
- }
- }
- log.info("Processed {} preferences, done.", index);
- double averagePreference = average.getAverage();
- log.info("Average preference value is {}", averagePreference);
- double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
- defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures);
- interval = prefInterval * 0.1 / numFeatures;
- userFeatures = new double[numUsers][numFeatures];
- itemFeatures = new double[numItems][numFeatures];
- log.info("Initializing feature vectors...");
- for (int feature = 0; feature < numFeatures; feature++) {
- for (int userIndex = 0; userIndex < numUsers; userIndex++) {
- userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
- }
- for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
- itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
- }
- }
- }
- @Override
- public Factorization factorize() throws TasteException {
- for (int feature = 0; feature < numFeatures; feature++) {
- log.info("Shuffling preferences...");
- shufflePreferences();
- log.info("Starting training of feature {} ...", feature);
- for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
- if (currentIteration == numIterations - 1) {
- double rmse = trainingIterationWithRmse(feature);
- log.info("Finished training feature {} with RMSE {}", feature, rmse);
- } else {
- trainingIteration(feature);
- }
- }
- if (feature < numFeatures - 1) {
- log.info("Updating cache...");
- for (int index = 0; index < userIndexes.length; index++) {
- cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index],
- false);
- }
- }
- }
- log.info("Factorization done");
- return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
- }
- private void trainingIteration(int feature) {
- for (int index = 0; index < userIndexes.length; index++) {
- train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
- }
- }
- private double trainingIterationWithRmse(int feature) {
- double rmse = 0.0;
- for (int index = 0; index < userIndexes.length; index++) {
- double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
- rmse += error * error;
- }
- return Math.sqrt(rmse / userIndexes.length);
- }
- private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
- double sum = cachedEstimate;
- sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
- if (trailing) {
- sum += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval);
- if (sum > maxPreference) {
- sum = maxPreference;
- } else if (sum < minPreference) {
- sum = minPreference;
- }
- }
- return sum;
- }
- public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
- double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
- double[] userVector = userFeatures[userIndex];
- double[] itemVector = itemFeatures[itemIndex];
- userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]);
- itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]);
- return error;
- }
- protected void shufflePreferences() {
- /* Durstenfeld shuffle */
- for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) {
- int swapPos = random.nextInt(currentPos + 1);
- swapPreferences(currentPos, swapPos);
- }
- }
- private void swapPreferences(int posA, int posB) {
- int tmpUserIndex = userIndexes[posA];
- int tmpItemIndex = itemIndexes[posA];
- float tmpValue = values[posA];
- double tmpEstimate = cachedEstimates[posA];
- userIndexes[posA] = userIndexes[posB];
- itemIndexes[posA] = itemIndexes[posB];
- values[posA] = values[posB];
- cachedEstimates[posA] = cachedEstimates[posB];
- userIndexes[posB] = tmpUserIndex;
- itemIndexes[posB] = tmpItemIndex;
- values[posB] = tmpValue;
- cachedEstimates[posB] = tmpEstimate;
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- // do nothing
- }
2018-06-27 13:14:38 UTC
diff --git a/community/mahout-mr/pom.xml b/community/mahout-mr/pom.xml
index 625f6b0..0ea47c8 100644
--- a/community/mahout-mr/pom.xml
+++ b/community/mahout-mr/pom.xml
@@ -34,6 +34,10 @@


+ <modules>
+ <module>mr-examples</module>
+ </modules>

diff --git a/community/spark-cli-drivers/pom.xml b/community/spark-cli-drivers/pom.xml
index a2e6b5f..2e9ca58 100644
--- a/community/spark-cli-drivers/pom.xml
+++ b/community/spark-cli-drivers/pom.xml
@@ -72,6 +72,27 @@

+ <!-- create fat jar -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>dependency-reduced</id>
+ <phase>package</phase>
+ <goals>
+ <goal>single</goal>
+ </goals>
+ <configuration>
+ <descriptors>
+ <descriptor>src/main/assembly/dependency-reduced.xml</descriptor>
+ </descriptors>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
<!-- ensure licenses -->

diff --git a/community/spark-cli-drivers/src/main/assembly/dependency-reduced.xml b/community/spark-cli-drivers/src/main/assembly/dependency-reduced.xml
new file mode 100644
index 0000000..5cf7d7e
--- /dev/null
+++ b/community/spark-cli-drivers/src/main/assembly/dependency-reduced.xml
@@ -0,0 +1,51 @@
+<?xml version="1.0" encoding="UTF-8"?>
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ xmlns="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0
+ http://maven.apache.org/xsd/assembly-1.1.0.xsd">
+ <id>dependency-reduced</id>
+ <formats>
+ <format>jar</format>
+ </formats>
+ <includeBaseDirectory>false</includeBaseDirectory>
+ <dependencySets>
+ <dependencySet>
+ <unpack>true</unpack>
+ <unpackOptions>
+ <!-- MAHOUT-1126 -->
+ <excludes>
+ <exclude>META-INF/LICENSE</exclude>
+ </excludes>
+ </unpackOptions>
+ <scope>runtime</scope>
+ <outputDirectory>/</outputDirectory>
+ <useTransitiveFiltering>true</useTransitiveFiltering>
+ <!--<includes>-->
+ <!--&lt;!&ndash; guava only included to get Preconditions in mahout-math and mahout-hdfs &ndash;&gt;-->
+ <!--<include>com.google.guava:guava</include>-->
+ <!--<include>com.github.scopt_2.11</include>-->
+ <!--&lt;!&ndash;<include>com.tdunning:t-digest</include>&ndash;&gt;-->
+ <!--<include>org.apache.commons:commons-math3</include>-->
+ <!--<include>it.unimi.dsi:fastutil</include>-->
+ <!--<include>org.apache.mahout:mahout-native-viennacl_${scala.compat.version}</include>-->
+ <!--<include>org.apache.mahout:mahout-native-viennacl-omp_${scala.compat.version}</include>-->
+ <!--<include>org.bytedeco:javacpp</include>-->
+ <!--</includes>-->
+ </dependencySet>
+ </dependencySets>

diff --git a/engine/spark/src/main/assembly/dependency-reduced.xml b/engine/spark/src/main/assembly/dependency-reduced.xml
index 2e90e06..25f05fb 100644
--- a/engine/spark/src/main/assembly/dependency-reduced.xml
+++ b/engine/spark/src/main/assembly/dependency-reduced.xml
@@ -39,7 +39,7 @@
<!-- guava only included to get Preconditions in mahout-math and mahout-hdfs -->
- <include>com.tdunning:t-digest</include>
+ <!--<include>com.tdunning:t-digest</include>-->

diff --git a/examples/bin/README.txt b/examples/bin/README.txt
deleted file mode 100644
index 7ad3a38..0000000
--- a/examples/bin/README.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-This directory contains helpful shell scripts for working with some of Mahout's examples.
-To set a non-default temporary work directory: `export MAHOUT_WORK_DIR=/path/in/hdfs/to/temp/dir`
- Note that this requires the same path to be writable both on the local file system as well as on HDFS.
-Here's a description of what each does:
-classify-20newsgroups.sh -- Run SGD and Bayes classifiers over the classic 20 News Groups. Downloads the data set automatically.
-cluster-reuters.sh -- Cluster the Reuters data set using a variety of algorithms. Downloads the data set automatically.
-cluster-syntheticcontrol.sh -- Cluster the Synthetic Control data set. Downloads the data set automatically.
-factorize-movielens-1m.sh -- Run the Alternating Least Squares Recommender on the Grouplens data set (size 1M).
-factorize-netflix.sh -- (Deprecated due to lack of availability of the data set) Run the ALS Recommender on the Netflix data set.
-spark-document-classifier.mscala -- A mahout-shell script which trains and tests a Naive Bayes model on the Wikipedia XML dump and defines simple methods to classify new text.

diff --git a/examples/bin/basicOLS.scala b/examples/bin/basicOLS.scala
new file mode 100644
index 0000000..97e4f83
--- /dev/null
+++ b/examples/bin/basicOLS.scala
@@ -0,0 +1,61 @@
+import org.apache.mahout.math._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.math.drm.RLikeDrmOps._
+import org.apache.mahout.sparkbindings._
+implicit val sdc: org.apache.mahout.sparkbindings.SparkDistributedContext = sc2sdc(sc)
+val drmData = drmParallelize(dense(
+ (2, 2, 10.5, 10, 29.509541), // Apple Cinnamon Cheerios
+ (1, 2, 12, 12, 18.042851), // Cap'n'Crunch
+ (1, 1, 12, 13, 22.736446), // Cocoa Puffs
+ (2, 1, 11, 13, 32.207582), // Froot Loops
+ (1, 2, 12, 11, 21.871292), // Honey Graham Ohs
+ (2, 1, 16, 8, 36.187559), // Wheaties Honey Gold
+ (6, 2, 17, 1, 50.764999), // Cheerios
+ (3, 2, 13, 7, 40.400208), // Clusters
+ (3, 3, 13, 4, 45.811716)), // Great Grains Pecan
+ numPartitions = 2);
+val drmX = drmData(::, 0 until 4)
+val y = drmData.collect(::, 4)
+val drmXtX = drmX.t %*% drmX
+val drmXty = drmX.t %*% y
+val XtX = drmXtX.collect
+val Xty = drmXty.collect(::, 0)
+val beta = solve(XtX, Xty)
+val yFitted = (drmX %*% beta).collect(::, 0)
+(y - yFitted).norm(2)
+def ols(drmX: DrmLike[Int], y: Vector) =
+ solve(drmX.t %*% drmX, drmX.t %*% y)(::, 0)
+def goodnessOfFit(drmX: DrmLike[Int], beta: Vector, y: Vector) = {
+ val fittedY = (drmX %*% beta).collect(::, 0)
+ (y - fittedY).norm(2)
+val drmXwithBiasColumn = drmX cbind 1
+val betaWithBiasTerm = ols(drmXwithBiasColumn, y)
+goodnessOfFit(drmXwithBiasColumn, betaWithBiasTerm, y)
+val cachedDrmX = drmXwithBiasColumn.checkpoint()
+val betaWithBiasTerm = ols(cachedDrmX, y)
+val goodness = goodnessOfFit(cachedDrmX, betaWithBiasTerm, y)
\ No newline at end of file

diff --git a/examples/bin/cco-lastfm.scala b/examples/bin/cco-lastfm.scala
new file mode 100644
index 0000000..709ab2a
--- /dev/null
+++ b/examples/bin/cco-lastfm.scala
@@ -0,0 +1,112 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * Download data from: http://files.grouplens.org/datasets/hetrec2011/hetrec2011-lastfm-2k.zip
+ * then run this in the mahout shell.
+ */
+import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark
+// We need to turn our raw text files into RDD[(String, String)]
+val userTagsRDD = sc.textFile("/path/to/lastfm/user_taggedartists.dat").map(line => line.split("\t")).map(a => (a(0), a(2))).filter(_._1 != "userID")
+val userTagsIDS = IndexedDatasetSpark.apply(userTagsRDD)(sc)
+val userArtistsRDD = sc.textFile("/path/to/lastfm/user_artists.dat").map(line => line.split("\t")).map(a => (a(0), a(1))).filter(_._1 != "userID")
+val userArtistsIDS = IndexedDatasetSpark.apply(userArtistsRDD)(sc)
+val userFriendsRDD = sc.textFile("/path/to/data/lastfm/user_friends.dat").map(line => line.split("\t")).map(a => (a(0), a(1))).filter(_._1 != "userID")
+val userFriendsIDS = IndexedDatasetSpark.apply(userFriendsRDD)(sc)
+val primaryIDS = userFriendsIDS
+val secondaryActionRDDs = List(userArtistsRDD, userTagsRDD)
+import org.apache.mahout.math.indexeddataset.{IndexedDataset, BiDictionary}
+def adjustRowCardinality(rowCardinality: Integer, datasetA: IndexedDataset): IndexedDataset = {
+ val returnedA = if (rowCardinality != datasetA.matrix.nrow) datasetA.newRowCardinality(rowCardinality)
+ else datasetA // this guarantees matching cardinality
+ returnedA
+var rowCardinality = primaryIDS.rowIDs.size
+val secondaryActionIDS: Array[IndexedDataset] = new Array[IndexedDataset](secondaryActionRDDs.length)
+for (i <- secondaryActionRDDs.indices) {
+ val bcPrimaryRowIDs = sc.broadcast(primaryIDS.rowIDs)
+ bcPrimaryRowIDs.value
+ val tempRDD = secondaryActionRDDs(i).filter(a => bcPrimaryRowIDs.value.contains(a._1))
+ var tempIDS = IndexedDatasetSpark.apply(tempRDD, existingRowIDs = Some(primaryIDS.rowIDs))(sc)
+ secondaryActionIDS(i) = adjustRowCardinality(rowCardinality,tempIDS)
+import org.apache.mahout.math.cf.SimilarityAnalysis
+val artistReccosLlrDrmListByArtist = SimilarityAnalysis.cooccurrencesIDSs(
+ Array(primaryIDS, secondaryActionIDS(0), secondaryActionIDS(1)),
+ maxInterestingItemsPerThing = 20,
+ maxNumInteractions = 500,
+ randomSeed = 1234)
+// Anonymous User
+val artistMap = sc.textFile("/path/to/lastfm/artists.dat").map(line => line.split("\t")).map(a => (a(1), a(0))).filter(_._1 != "name").collect.toMap
+val tagsMap = sc.textFile("/path/to/lastfm/tags.dat").map(line => line.split("\t")).map(a => (a(1), a(0))).filter(_._1 != "tagValue").collect.toMap
+// Watch your skin- you're not wearing armour. (This will fail on misspelled artists.)
+// This is necessary because the ids are integer-strings already, and for this demo I didn't want to change them to Integer types (bc more often you'll have strings).
+val kilroyUserArtists = svec( (userArtistsIDS.columnIDs.get(artistMap("Beck")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("David Bowie")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Gary Numan")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Less Than Jake")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Lou Reed")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Parliament")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Radiohead")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Seu Jorge")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("The Skatalites")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Reverend Horton Heat")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Talking Heads")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Tom Waits")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Waylon Jennings")).get, 1) ::
+ (userArtistsIDS.columnIDs.get(artistMap("Wu-Tang Clan")).get, 1) :: Nil, cardinality = userArtistsIDS.columnIDs.size
+val kilroyUserTags = svec(
+ (userTagsIDS.columnIDs.get(tagsMap("classical")).get, 1) ::
+ (userTagsIDS.columnIDs.get(tagsMap("skacore")).get, 1) ::
+ (userTagsIDS.columnIDs.get(tagsMap("why on earth is this just a bonus track")).get, 1) ::
+ (userTagsIDS.columnIDs.get(tagsMap("punk rock")).get, 1) :: Nil, cardinality = userTagsIDS.columnIDs.size)
+val kilroysRecs = (artistReccosLlrDrmListByArtist(0).matrix %*% kilroyUserArtists + artistReccosLlrDrmListByArtist(1).matrix %*% kilroyUserTags).collect
+import org.apache.mahout.math.scalabindings.MahoutCollections._
+import collection._
+import JavaConversions._
+// Which Users I should Be Friends with.
+println(kilroysRecs(::, 0).toMap.toList.sortWith(_._2 > _._2).take(5))
+ * So there you have it- the basis for a new dating/friend finding app based on musical preferences which
+ * is actually a pretty dope idea.
+ *
+ * Solving for which bands a user might like is left as an exercise to the reader.
+ */
\ No newline at end of file

diff --git a/examples/bin/classify-20newsgroups.sh b/examples/bin/classify-20newsgroups.sh
deleted file mode 100755
index f47d5c5..0000000
--- a/examples/bin/classify-20newsgroups.sh
+++ /dev/null
@@ -1,197 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Downloads the 20newsgroups dataset, trains and tests a classifier.
-# To run: change into the mahout directory and type:
-# examples/bin/classify-20newsgroups.sh
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs SGD and Bayes classifiers over the classic 20 News Groups."
- exit
-if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
-# Set commands for dfs
-source ${START_PATH}/set-dfs-commands.sh
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( cnaivebayes-MapReduce naivebayes-MapReduce cnaivebayes-Spark naivebayes-Spark sgd clean)
-if [ -n "$1" ]; then
- choice=$1
- echo "Please select a number to choose the corresponding task to run"
- echo "1. ${algorithm[0]}"
- echo "2. ${algorithm[1]}"
- echo "3. ${algorithm[2]}"
- echo "4. ${algorithm[3]}"
- echo "5. ${algorithm[4]}"
- echo "6. ${algorithm[5]}-- cleans up the work area in $WORK_DIR"
- read -p "Enter your choice : " choice
-echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
-# Spark specific check and work
-if [ "x$alg" == "xnaivebayes-Spark" -o "x$alg" == "xcnaivebayes-Spark" ]; then
- if [ "$MASTER" == "" ] ; then
- echo "Please set your MASTER env variable to point to your Spark Master URL. exiting..."
- exit 1
- fi
- if [ "$MAHOUT_LOCAL" != "" ] ; then
- echo "Options 3 and 4 can not run in MAHOUT_LOCAL mode. exiting..."
- exit 1
- fi
-if [ "x$alg" != "xclean" ]; then
- echo "creating work directory at ${WORK_DIR}"
- mkdir -p ${WORK_DIR}
- if [ ! -e ${WORK_DIR}/20news-bayesinput ]; then
- if [ ! -e ${WORK_DIR}/20news-bydate ]; then
- if [ ! -f ${WORK_DIR}/20news-bydate.tar.gz ]; then
- echo "Downloading 20news-bydate"
- curl http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz -o ${WORK_DIR}/20news-bydate.tar.gz
- fi
- mkdir -p ${WORK_DIR}/20news-bydate
- echo "Extracting..."
- cd ${WORK_DIR}/20news-bydate && tar xzf ../20news-bydate.tar.gz && cd .. && cd ..
- fi
- fi
-#echo $START_PATH
-cd ../..
-set -e
-if ( [ "x$alg" == "xnaivebayes-MapReduce" ] || [ "x$alg" == "xcnaivebayes-MapReduce" ] || [ "x$alg" == "xnaivebayes-Spark" ] || [ "x$alg" == "xcnaivebayes-Spark" ] ); then
- c=""
- if [ "x$alg" == "xcnaivebayes-MapReduce" -o "x$alg" == "xnaivebayes-Spark" ]; then
- c=" -c"
- fi
- set -x
- echo "Preparing 20newsgroups data"
- rm -rf ${WORK_DIR}/20news-all
- mkdir ${WORK_DIR}/20news-all
- cp -R ${WORK_DIR}/20news-bydate/*/* ${WORK_DIR}/20news-all
- if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
- echo "Copying 20newsgroups data to HDFS"
- set +e
- $DFSRM ${WORK_DIR}/20news-all
- $DFS -mkdir -p ${WORK_DIR}
- $DFS -mkdir ${WORK_DIR}/20news-all
- set -e
- if [ $HVERSION -eq "1" ] ; then
- echo "Copying 20newsgroups data to Hadoop 1 HDFS"
- $DFS -put ${WORK_DIR}/20news-all ${WORK_DIR}/20news-all
- elif [ $HVERSION -eq "2" ] ; then
- echo "Copying 20newsgroups data to Hadoop 2 HDFS"
- $DFS -put ${WORK_DIR}/20news-all ${WORK_DIR}/
- fi
- fi
- echo "Creating sequence files from 20newsgroups data"
- ./bin/mahout seqdirectory \
- -i ${WORK_DIR}/20news-all \
- -o ${WORK_DIR}/20news-seq -ow
- echo "Converting sequence files to vectors"
- ./bin/mahout seq2sparse \
- -i ${WORK_DIR}/20news-seq \
- -o ${WORK_DIR}/20news-vectors -lnorm -nv -wt tfidf
- echo "Creating training and holdout set with a random 80-20 split of the generated vector dataset"
- ./bin/mahout split \
- -i ${WORK_DIR}/20news-vectors/tfidf-vectors \
- --trainingOutput ${WORK_DIR}/20news-train-vectors \
- --testOutput ${WORK_DIR}/20news-test-vectors \
- --randomSelectionPct 40 --overwrite --sequenceFiles -xm sequential
- if [ "x$alg" == "xnaivebayes-MapReduce" -o "x$alg" == "xcnaivebayes-MapReduce" ]; then
- echo "Training Naive Bayes model"
- ./bin/mahout trainnb \
- -i ${WORK_DIR}/20news-train-vectors \
- -o ${WORK_DIR}/model \
- -li ${WORK_DIR}/labelindex \
- -ow $c
- echo "Self testing on training set"
- ./bin/mahout testnb \
- -i ${WORK_DIR}/20news-train-vectors\
- -m ${WORK_DIR}/model \
- -l ${WORK_DIR}/labelindex \
- -ow -o ${WORK_DIR}/20news-testing $c
- echo "Testing on holdout set"
- ./bin/mahout testnb \
- -i ${WORK_DIR}/20news-test-vectors\
- -m ${WORK_DIR}/model \
- -l ${WORK_DIR}/labelindex \
- -ow -o ${WORK_DIR}/20news-testing $c
- elif [ "x$alg" == "xnaivebayes-Spark" -o "x$alg" == "xcnaivebayes-Spark" ]; then
- echo "Training Naive Bayes model"
- ./bin/mahout spark-trainnb \
- -i ${WORK_DIR}/20news-train-vectors \
- -o ${WORK_DIR}/spark-model $c -ow -ma $MASTER
- echo "Self testing on training set"
- ./bin/mahout spark-testnb \
- -i ${WORK_DIR}/20news-train-vectors\
- -m ${WORK_DIR}/spark-model $c -ma $MASTER
- echo "Testing on holdout set"
- ./bin/mahout spark-testnb \
- -i ${WORK_DIR}/20news-test-vectors\
- -m ${WORK_DIR}/spark-model $c -ma $MASTER
- fi
-elif [ "x$alg" == "xsgd" ]; then
- if [ ! -e "/tmp/news-group.model" ]; then
- echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
- ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/
- fi
- echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: /tmp/news-group.model"
- ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model /tmp/news-group.model
-elif [ "x$alg" == "xclean" ]; then
- rm -rf $WORK_DIR
- rm -rf /tmp/news-group.model
-# Remove the work directory

diff --git a/examples/bin/classify-wikipedia.sh b/examples/bin/classify-wikipedia.sh
deleted file mode 100755
index 41dc0c9..0000000
--- a/examples/bin/classify-wikipedia.sh
+++ /dev/null
@@ -1,196 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Downloads a (partial) wikipedia dump, trains and tests a classifier.
-# To run: change into the mahout directory and type:
-# examples/bin/classify-wikipedia.sh
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script Bayes and CBayes classifiers over the last wikipedia dump."
- exit
-# ensure that MAHOUT_HOME is set
-if [[ -z "$MAHOUT_HOME" ]]; then
- echo "Please set MAHOUT_HOME."
- exit
-if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
-# Set commands for dfs
-source ${START_PATH}/set-dfs-commands.sh
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-wiki
-algorithm=( CBayes BinaryCBayes clean)
-if [ -n "$1" ]; then
- choice=$1
- echo "Please select a number to choose the corresponding task to run"
- echo "1. ${algorithm[0]} (may require increased heap space on yarn)"
- echo "2. ${algorithm[1]}"
- echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
- read -p "Enter your choice : " choice
-echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
-if [ "x$alg" != "xclean" ]; then
- echo "creating work directory at ${WORK_DIR}"
- mkdir -p ${WORK_DIR}
- if [ ! -e ${WORK_DIR}/wikixml ]; then
- mkdir -p ${WORK_DIR}/wikixml
- fi
- if [ ! -e ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2 ]; then
- echo "Downloading wikipedia XML dump"
- ########################################################
- # Datasets: uncomment and run "clean" to change dataset
- ########################################################
- ########## partial small 42.5M zipped
- # curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles1.xml-p000000010p000030302.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
- ########## partial larger 256M zipped
- curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles10.xml-p2336425p3046511.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
- ######### full wikipedia dump: 10G zipped
- # curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
- ########################################################
- fi
- if [ ! -e ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml ]; then
- echo "Extracting..."
- cd ${WORK_DIR}/wikixml && bunzip2 enwiki-latest-pages-articles.xml.bz2 && cd .. && cd ..
- fi
-set -e
-if [ "x$alg" == "xCBayes" ] || [ "x$alg" == "xBinaryCBayes" ] ; then
- set -x
- echo "Preparing wikipedia data"
- rm -rf ${WORK_DIR}/wiki
- mkdir ${WORK_DIR}/wiki
- if [ "x$alg" == "xCBayes" ] ; then
- # use a list of 10 countries as categories
- cp $MAHOUT_HOME/examples/bin/resources/country10.txt ${WORK_DIR}/country.txt
- chmod 666 ${WORK_DIR}/country.txt
- fi
- if [ "x$alg" == "xBinaryCBayes" ] ; then
- # use United States and United Kingdom as categories
- cp $MAHOUT_HOME/examples/bin/resources/country2.txt ${WORK_DIR}/country.txt
- chmod 666 ${WORK_DIR}/country.txt
- fi
- if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
- echo "Copying wikipedia data to HDFS"
- set +e
- $DFSRM ${WORK_DIR}/wikixml
- $DFS -mkdir -p ${WORK_DIR}
- set -e
- $DFS -put ${WORK_DIR}/wikixml ${WORK_DIR}/wikixml
- fi
- echo "Creating sequence files from wikiXML"
- $MAHOUT_HOME/bin/mahout seqwiki -c ${WORK_DIR}/country.txt \
- -i ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml \
- -o ${WORK_DIR}/wikipediainput
- # if using the 10 class problem use bigrams
- if [ "x$alg" == "xCBayes" ] ; then
- echo "Converting sequence files to vectors using bigrams"
- $MAHOUT_HOME/bin/mahout seq2sparse -i ${WORK_DIR}/wikipediainput \
- -o ${WORK_DIR}/wikipediaVecs \
- -wt tfidf \
- -lnorm -nv \
- -ow -ng 2
- fi
- # if using the 2 class problem try different options
- if [ "x$alg" == "xBinaryCBayes" ] ; then
- echo "Converting sequence files to vectors using unigrams and a max document frequency of 30%"
- $MAHOUT_HOME/bin/mahout seq2sparse -i ${WORK_DIR}/wikipediainput \
- -o ${WORK_DIR}/wikipediaVecs \
- -wt tfidf \
- -lnorm \
- -nv \
- -ow \
- -ng 1 \
- -x 30
- fi
- echo "Creating training and holdout set with a random 80-20 split of the generated vector dataset"
- $MAHOUT_HOME/bin/mahout split -i ${WORK_DIR}/wikipediaVecs/tfidf-vectors/ \
- --trainingOutput ${WORK_DIR}/training \
- --testOutput ${WORK_DIR}/testing \
- -rp 20 \
- -ow \
- -seq \
- -xm sequential
- echo "Training Naive Bayes model"
- $MAHOUT_HOME/bin/mahout trainnb -i ${WORK_DIR}/training \
- -o ${WORK_DIR}/model \
- -li ${WORK_DIR}/labelindex \
- -ow \
- -c
- echo "Self testing on training set"
- $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/training \
- -m ${WORK_DIR}/model \
- -l ${WORK_DIR}/labelindex \
- -ow \
- -o ${WORK_DIR}/output \
- -c
- echo "Testing on holdout set: Bayes"
- $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/testing \
- -m ${WORK_DIR}/model \
- -l ${WORK_DIR}/labelindex \
- -ow \
- -o ${WORK_DIR}/output \
- -seq
- echo "Testing on holdout set: CBayes"
- $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/testing \
- -m ${WORK_DIR}/model -l \
- ${WORK_DIR}/labelindex \
- -ow \
- -o ${WORK_DIR}/output \
- -c \
- -seq
-elif [ "x$alg" == "xclean" ]; then
- rm -rf $WORK_DIR
-# Remove the work directory

diff --git a/examples/bin/cluster-reuters.sh b/examples/bin/cluster-reuters.sh
deleted file mode 100755
index 49f6c94..0000000
--- a/examples/bin/cluster-reuters.sh
+++ /dev/null
@@ -1,203 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Downloads the Reuters dataset and prepares it for clustering
-# To run: change into the mahout directory and type:
-# examples/bin/cluster-reuters.sh
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script clusters the Reuters data set using a variety of algorithms. The data set is downloaded automatically."
- exit
-if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
-# Set commands for dfs
-source ${START_PATH}/set-dfs-commands.sh
-if [ ! -e $MAHOUT ]; then
- echo "Can't find mahout driver in $MAHOUT, cwd `pwd`, exiting.."
- exit 1
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( kmeans fuzzykmeans lda streamingkmeans clean)
-if [ -n "$1" ]; then
- choice=$1
- echo "Please select a number to choose the corresponding clustering algorithm"
- echo "1. ${algorithm[0]} clustering (runs from this example script in cluster mode only)"
- echo "2. ${algorithm[1]} clustering (may require increased heap space on yarn)"
- echo "3. ${algorithm[2]} clustering"
- echo "4. ${algorithm[3]} clustering"
- echo "5. ${algorithm[4]} -- cleans up the work area in $WORK_DIR"
- read -p "Enter your choice : " choice
-echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]} Clustering"
-if [ "x$clustertype" == "xclean" ]; then
- rm -rf $WORK_DIR
- exit 1
- $DFS -mkdir -p $WORK_DIR
- mkdir -p $WORK_DIR
- echo "Creating work directory at ${WORK_DIR}"
-if [ ! -e ${WORK_DIR}/reuters-out-seqdir ]; then
- if [ ! -e ${WORK_DIR}/reuters-out ]; then
- if [ ! -e ${WORK_DIR}/reuters-sgm ]; then
- if [ ! -f ${WORK_DIR}/reuters21578.tar.gz ]; then
- if [ -n "$2" ]; then
- echo "Copying Reuters from local download"
- cp $2 ${WORK_DIR}/reuters21578.tar.gz
- else
- echo "Downloading Reuters-21578"
- curl http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz -o ${WORK_DIR}/reuters21578.tar.gz
- fi
- fi
- #make sure it was actually downloaded
- if [ ! -f ${WORK_DIR}/reuters21578.tar.gz ]; then
- echo "Failed to download reuters"
- exit 1
- fi
- mkdir -p ${WORK_DIR}/reuters-sgm
- echo "Extracting..."
- tar xzf ${WORK_DIR}/reuters21578.tar.gz -C ${WORK_DIR}/reuters-sgm
- fi
- echo "Extracting Reuters"
- $MAHOUT org.apache.lucene.benchmark.utils.ExtractReuters ${WORK_DIR}/reuters-sgm ${WORK_DIR}/reuters-out
- if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
- echo "Copying Reuters data to Hadoop"
- set +e
- $DFSRM ${WORK_DIR}/reuters-sgm
- $DFSRM ${WORK_DIR}/reuters-out
- $DFS -mkdir -p ${WORK_DIR}/
- $DFS -mkdir ${WORK_DIR}/reuters-sgm
- $DFS -mkdir ${WORK_DIR}/reuters-out
- $DFS -put ${WORK_DIR}/reuters-sgm ${WORK_DIR}/reuters-sgm
- $DFS -put ${WORK_DIR}/reuters-out ${WORK_DIR}/reuters-out
- set -e
- fi
- fi
- echo "Converting to Sequence Files from Directory"
- $MAHOUT seqdirectory -i ${WORK_DIR}/reuters-out -o ${WORK_DIR}/reuters-out-seqdir -c UTF-8 -chunk 64 -xm sequential
-if [ "x$clustertype" == "xkmeans" ]; then
- $MAHOUT seq2sparse \
- -i ${WORK_DIR}/reuters-out-seqdir/ \
- -o ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans --maxDFPercent 85 --namedVector \
- && \
- $MAHOUT kmeans \
- -i ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/tfidf-vectors/ \
- -c ${WORK_DIR}/reuters-kmeans-clusters \
- -o ${WORK_DIR}/reuters-kmeans \
- -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure \
- -x 10 -k 20 -ow --clustering \
- && \
- $MAHOUT clusterdump \
- -i `$DFS -ls -d ${WORK_DIR}/reuters-kmeans/clusters-*-final | awk '{print $8}'` \
- -o ${WORK_DIR}/reuters-kmeans/clusterdump \
- -d ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/dictionary.file-0 \
- -dt sequencefile -b 100 -n 20 --evaluate -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure -sp 0 \
- --pointsDir ${WORK_DIR}/reuters-kmeans/clusteredPoints \
- && \
- cat ${WORK_DIR}/reuters-kmeans/clusterdump
-elif [ "x$clustertype" == "xfuzzykmeans" ]; then
- $MAHOUT seq2sparse \
- -i ${WORK_DIR}/reuters-out-seqdir/ \
- -o ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans --maxDFPercent 85 --namedVector \
- && \
- $MAHOUT fkmeans \
- -i ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans/tfidf-vectors/ \
- -c ${WORK_DIR}/reuters-fkmeans-clusters \
- -o ${WORK_DIR}/reuters-fkmeans \
- -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure \
- -x 10 -k 20 -ow -m 1.1 \
- && \
- $MAHOUT clusterdump \
- -i ${WORK_DIR}/reuters-fkmeans/clusters-*-final \
- -o ${WORK_DIR}/reuters-fkmeans/clusterdump \
- -d ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans/dictionary.file-0 \
- -dt sequencefile -b 100 -n 20 -sp 0 \
- && \
- cat ${WORK_DIR}/reuters-fkmeans/clusterdump
-elif [ "x$clustertype" == "xlda" ]; then
- $MAHOUT seq2sparse \
- -i ${WORK_DIR}/reuters-out-seqdir/ \
- -o ${WORK_DIR}/reuters-out-seqdir-sparse-lda -ow --maxDFPercent 85 --namedVector \
- && \
- $MAHOUT rowid \
- -i ${WORK_DIR}/reuters-out-seqdir-sparse-lda/tfidf-vectors \
- -o ${WORK_DIR}/reuters-out-matrix \
- && \
- rm -rf ${WORK_DIR}/reuters-lda ${WORK_DIR}/reuters-lda-topics ${WORK_DIR}/reuters-lda-model \
- && \
- $MAHOUT cvb \
- -i ${WORK_DIR}/reuters-out-matrix/matrix \
- -o ${WORK_DIR}/reuters-lda -k 20 -ow -x 20 \
- -dict ${WORK_DIR}/reuters-out-seqdir-sparse-lda/dictionary.file-* \
- -dt ${WORK_DIR}/reuters-lda-topics \
- -mt ${WORK_DIR}/reuters-lda-model \
- && \
- $MAHOUT vectordump \
- -i ${WORK_DIR}/reuters-lda-topics/part-m-00000 \
- -o ${WORK_DIR}/reuters-lda/vectordump \
- -vs 10 -p true \
- -d ${WORK_DIR}/reuters-out-seqdir-sparse-lda/dictionary.file-* \
- -dt sequencefile -sort ${WORK_DIR}/reuters-lda-topics/part-m-00000 \
- && \
- cat ${WORK_DIR}/reuters-lda/vectordump
-elif [ "x$clustertype" == "xstreamingkmeans" ]; then
- $MAHOUT seq2sparse \
- -i ${WORK_DIR}/reuters-out-seqdir/ \
- -o ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans -ow --maxDFPercent 85 --namedVector \
- && \
- rm -rf ${WORK_DIR}/reuters-streamingkmeans \
- && \
- $MAHOUT streamingkmeans \
- -i ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans/tfidf-vectors/ \
- --tempDir ${WORK_DIR}/tmp \
- -o ${WORK_DIR}/reuters-streamingkmeans \
- -sc org.apache.mahout.math.neighborhood.FastProjectionSearch \
- -dm org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure \
- -k 10 -km 100 -ow \
- && \
- $MAHOUT qualcluster \
- -i ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans/tfidf-vectors/part-r-00000 \
- -c ${WORK_DIR}/reuters-streamingkmeans/part-r-00000 \
- -o ${WORK_DIR}/reuters-cluster-distance.csv \
- && \
- cat ${WORK_DIR}/reuters-cluster-distance.csv

diff --git a/examples/bin/cluster-syntheticcontrol.sh b/examples/bin/cluster-syntheticcontrol.sh
deleted file mode 100755
index 39b2255..0000000
--- a/examples/bin/cluster-syntheticcontrol.sh
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Downloads the Synthetic control dataset and prepares it for clustering
-# To run: change into the mahout directory and type:
-# examples/bin/cluster-syntheticcontrol.sh
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script clusters the Synthetic Control data set. The data set is downloaded automatically."
- exit
-algorithm=( kmeans fuzzykmeans )
-if [ -n "$1" ]; then
- choice=$1
- echo "Please select a number to choose the corresponding clustering algorithm"
- echo "1. ${algorithm[0]} clustering"
- echo "2. ${algorithm[1]} clustering"
- read -p "Enter your choice : " choice
-echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]} Clustering"
-if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
-# Set commands for dfs
-source ${START_PATH}/set-dfs-commands.sh
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-${USER}
-echo "creating work directory at ${WORK_DIR}"
-mkdir -p ${WORK_DIR}
-if [ ! -f ${WORK_DIR}/synthetic_control.data ]; then
- if [ -n "$2" ]; then
- cp $2 ${WORK_DIR}/.
- else
- echo "Downloading Synthetic control data"
- curl http://archive.ics.uci.edu/ml/databases/synthetic_control/synthetic_control.data -o ${WORK_DIR}/synthetic_control.data
- fi
-if [ ! -f ${WORK_DIR}/synthetic_control.data ]; then
- echo "Couldn't download synthetic control"
- exit 1
-if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ]; then
- echo "Checking the health of DFS..."
- $DFS -ls /
- if [ $? -eq 0 ];then
- echo "DFS is healthy... "
- echo "Uploading Synthetic control data to HDFS"
- $DFSRM ${WORK_DIR}/testdata
- $DFS -mkdir -p ${WORK_DIR}/testdata
- $DFS -put ${WORK_DIR}/synthetic_control.data ${WORK_DIR}/testdata
- echo "Successfully Uploaded Synthetic control data to HDFS "
- options="--input ${WORK_DIR}/testdata --output ${WORK_DIR}/output --maxIter 10 --convergenceDelta 0.5"
- if [ "${clustertype}" == "kmeans" ]; then
- options="${options} --numClusters 6"
- # t1 & t2 not used if --numClusters specified, but parser requires input
- options="${options} --t1 1 --t2 2"
- ../../bin/mahout org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job ${options}
- else
- options="${options} --m 2.0f --t1 80 --t2 55"
- ../../bin/mahout org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job ${options}
- fi
- else
- echo " HADOOP is not running. Please make sure you hadoop is running. "
- fi
-elif [ "$MAHOUT_LOCAL" != "" ]; then
- echo "running MAHOUT_LOCAL"
- cp ${WORK_DIR}/synthetic_control.data testdata
- ../../bin/mahout org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job
- rm testdata
- echo " HADOOP_HOME variable is not set. Please set this environment variable and rerun the script"
-# Remove the work directory
-rm -rf ${WORK_DIR}

diff --git a/examples/bin/factorize-movielens-1M.sh b/examples/bin/factorize-movielens-1M.sh
deleted file mode 100755
index 29730e1..0000000
--- a/examples/bin/factorize-movielens-1M.sh
+++ /dev/null
@@ -1,85 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Instructions:
-# Before using this script, you have to download and extract the Movielens 1M dataset
-# from http://www.grouplens.org/node/73
-# To run: change into the mahout directory and type:
-# export MAHOUT_LOCAL=true
-# Then:
-# examples/bin/factorize-movielens-1M.sh /path/to/ratings.dat
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs the Alternating Least Squares Recommender on the Grouplens data set (size 1M)."
- echo "Syntax: $0 /path/to/ratings.dat\n"
- exit
-if [ $# -ne 1 ]
- echo -e "\nYou have to download the Movielens 1M dataset from http://www.grouplens.org/node/73 before"
- echo -e "you can run this example. After that extract it and supply the path to the ratings.dat file.\n"
- echo -e "Syntax: $0 /path/to/ratings.dat\n"
- exit -1
-export MAHOUT_LOCAL=true
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-${USER}
-echo "creating work directory at ${WORK_DIR}"
-mkdir -p ${WORK_DIR}/movielens
-echo "Converting ratings..."
-cat $1 |sed -e s/::/,/g| cut -d, -f1,2,3 > ${WORK_DIR}/movielens/ratings.csv
-# create a 90% percent training set and a 10% probe set
-$MAHOUT splitDataset --input ${WORK_DIR}/movielens/ratings.csv --output ${WORK_DIR}/dataset \
- --trainingPercentage 0.9 --probePercentage 0.1 --tempDir ${WORK_DIR}/dataset/tmp
-# run distributed ALS-WR to factorize the rating matrix defined by the training set
-$MAHOUT parallelALS --input ${WORK_DIR}/dataset/trainingSet/ --output ${WORK_DIR}/als/out \
- --tempDir ${WORK_DIR}/als/tmp --numFeatures 20 --numIterations 10 --lambda 0.065 --numThreadsPerSolver 2
-# compute predictions against the probe set, measure the error
-$MAHOUT evaluateFactorization --input ${WORK_DIR}/dataset/probeSet/ --output ${WORK_DIR}/als/rmse/ \
- --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ --tempDir ${WORK_DIR}/als/tmp
-# compute recommendations
-$MAHOUT recommendfactorized --input ${WORK_DIR}/als/out/userRatings/ --output ${WORK_DIR}/recommendations/ \
- --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ \
- --numRecommendations 6 --maxRating 5 --numThreads 2
-# print the error
-echo -e "\nRMSE is:\n"
-cat ${WORK_DIR}/als/rmse/rmse.txt
-echo -e "\n"
-echo -e "\nSample recommendations:\n"
-shuf ${WORK_DIR}/recommendations/part-m-00000 |head
-echo -e "\n\n"
-echo "removing work directory"
-rm -rf ${WORK_DIR}

diff --git a/examples/bin/factorize-netflix.sh b/examples/bin/factorize-netflix.sh
deleted file mode 100755
index 26faf66..0000000
--- a/examples/bin/factorize-netflix.sh
+++ /dev/null
@@ -1,90 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Instructions:
-# You can only use this script in conjunction with the Netflix dataset. Unpack the Netflix dataset and provide the
-# following:
-# 1) the path to the folder 'training_set' that contains all the movie rating files
-# 2) the path to the file 'qualifying.txt' that contains the user,item pairs to predict
-# 3) the path to the file 'judging.txt' that contains the ratings of user,item pairs to predict for
-# To run:
-# ./factorize-netflix.sh /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt
-echo "Note this script has been deprecated due to the lack of access to the Netflix data set."
-exit 1
-if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs the ALS Recommender on the Netflix data set."
- echo "Syntax: $0 /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt\n"
- exit
-if [ $# -ne 3 ]
- echo -e "Syntax: $0 /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt\n"
- exit -1
-if [[ -z "$MAHOUT_WORK_DIR" ]]; then
- WORK_DIR=/tmp/mahout-work-${USER}
-# Set commands for dfs
-source ${START_PATH}/set-dfs-commands.sh
-echo "Preparing data..."
-$MAHOUT org.apache.mahout.cf.taste.hadoop.example.als.netflix.NetflixDatasetConverter $1 $2 $3 ${WORK_DIR}
-# run distributed ALS-WR to factorize the rating matrix defined by the training set
-$MAHOUT parallelALS --input ${WORK_DIR}/trainingSet/ratings.tsv --output ${WORK_DIR}/als/out \
- --tempDir ${WORK_DIR}/als/tmp --numFeatures 25 --numIterations 10 --lambda 0.065 --numThreadsPerSolver 4
-# compute predictions against the probe set, measure the error
-$MAHOUT evaluateFactorization --input ${WORK_DIR}/probeSet/ratings.tsv --output ${WORK_DIR}/als/rmse/ \
- --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ --tempDir ${WORK_DIR}/als/tmp
-if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
- # print the error, should be around 0.923
- echo -e "\nRMSE is:\n"
- $DFS -tail ${WORK_DIR}/als/rmse/rmse.txt
- echo -e "\n"
- echo "removing work directory"
- set +e
- # print the error, should be around 0.923
- echo -e "\nRMSE is:\n"
- cat ${WORK_DIR}/als/rmse/rmse.txt
- echo -e "\n"
- echo "removing work directory"
- rm -rf ${WORK_DIR}

diff --git a/examples/bin/get-all-examples.sh b/examples/bin/get-all-examples.sh
deleted file mode 100755
index 4128e47..0000000
--- a/examples/bin/get-all-examples.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env bash
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Clones Mahout example code from remote repositories with their own
-# build process. Follow the README for each example for instructions.
-# Usage: change into the mahout directory and type:
-# examples/bin/get-all-examples.sh
-# Solr-recommender
-echo " Solr-recommender example: "
-echo " 1) imports text 'log files' of some delimited form for user preferences"
-echo " 2) creates the correct Mahout files and stores distionaries to translate external Id to and from Mahout Ids"
-echo " 3) it implements a prototype two actions 'cross-recommender', which takes two actions made by the same user and creates recommendations"
-echo " 4) it creates output for user->preference history CSV and and item->similar items 'similarity' matrix for use in a Solr-recommender."
-echo " To use Solr you would index the similarity matrix CSV, and use user preference history from the history CSV as a query, the result"
-echo " from Solr will be an ordered list of recommendations returning the same item Ids as were input."
-echo " For further description see the README.md here https://github.com/pferrel/solr-recommender"
-echo " To build run 'cd solr-recommender; mvn install'"
-echo " To process the example after building make sure MAHOUT_LOCAL IS SET and hadoop is in local mode then "
-echo " run 'cd scripts; ./solr-recommender-example'"
-git clone https://github.com/pferrel/solr-recommender

diff --git a/examples/bin/lda.algorithm b/examples/bin/lda.algorithm
deleted file mode 100644
index fb84ea0..0000000
--- a/examples/bin/lda.algorithm
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# task at this depth or less would print when they start
-# --------- alg
-{ "BuildReuters"
- CreateIndex
- { "AddDocs" AddDoc > : *
-# Optimize
- CloseIndex
2018-06-27 13:14:34 UTC
diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
deleted file mode 100644
index 5cce02d..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
+++ /dev/null
@@ -1,141 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-import com.google.common.io.Closeables;
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.NoSuchUserException;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
- * run an SVD factorization of the KDD track1 data.
- *
- * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M
- *
- */
-public final class Track1SVDRunner {
- private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class);
- private Track1SVDRunner() {
- }
- public static void main(String[] args) throws Exception {
- if (args.length != 2) {
- System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>");
- return;
- }
- File dataFileDirectory = new File(args[0]);
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
- File resultFile = new File(args[1]);
- /* the knobs to turn */
- int numFeatures = 20;
- int numIterations = 5;
- double learningRate = 0.0001;
- double preventOverfitting = 0.002;
- double randomNoise = 0.0001;
- KDDCupFactorizablePreferences factorizablePreferences =
- new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
- Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
- learningRate, preventOverfitting, randomNoise);
- Factorization factorization = sgdFactorizer.factorize();
- log.info("Estimating validation preferences...");
- int prefsProcessed = 0;
- RunningAverage average = new FullRunningAverage();
- for (Pair<PreferenceArray,long[]> validationPair
- : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
- for (Preference validationPref : validationPair.getFirst()) {
- double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(),
- factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
- double error = validationPref.getValue() - estimate;
- average.addDatum(error * error);
- prefsProcessed++;
- if (prefsProcessed % 100000 == 0) {
- log.info("Computed {} estimations", prefsProcessed);
- }
- }
- }
- log.info("Computed {} estimations, done.", prefsProcessed);
- double rmse = Math.sqrt(average.getAverage());
- log.info("RMSE {}", rmse);
- log.info("Estimating test preferences...");
- OutputStream out = null;
- try {
- out = new BufferedOutputStream(new FileOutputStream(resultFile));
- for (Pair<PreferenceArray,long[]> testPair
- : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
- for (Preference testPref : testPair.getFirst()) {
- double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(),
- factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
- byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID());
- out.write(result);
- }
- }
- } finally {
- Closeables.close(out, false);
- }
- log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
- }
- static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference,
- float maxPreference) throws NoSuchUserException, NoSuchItemException {
- double[] userFeatures = factorization.getUserFeatures(userID);
- double[] itemFeatures = factorization.getItemFeatures(itemID);
- double estimate = 0;
- for (int feature = 0; feature < userFeatures.length; feature++) {
- estimate += userFeatures[feature] * itemFeatures[feature];
- }
- if (estimate < minPreference) {
- estimate = minPreference;
- } else if (estimate > maxPreference) {
- estimate = maxPreference;
- }
- return estimate;
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
deleted file mode 100644
index ce025a9..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
+++ /dev/null
@@ -1,62 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-final class HybridSimilarity extends AbstractItemSimilarity {
- private final ItemSimilarity cfSimilarity;
- private final ItemSimilarity contentSimilarity;
- HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException {
- super(dataModel);
- cfSimilarity = new LogLikelihoodSimilarity(dataModel);
- contentSimilarity = new TrackItemSimilarity(dataFileDirectory);
- }
- @Override
- public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
- return contentSimilarity.itemSimilarity(itemID1, itemID2) * cfSimilarity.itemSimilarity(itemID1, itemID2);
- }
- @Override
- public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
- double[] result = contentSimilarity.itemSimilarities(itemID1, itemID2s);
- double[] multipliers = cfSimilarity.itemSimilarities(itemID1, itemID2s);
- for (int i = 0; i < result.length; i++) {
- result[i] *= multipliers[i];
- }
- return result;
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- cfSimilarity.refresh(alreadyRefreshed);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
deleted file mode 100644
index 50fd35e..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
+++ /dev/null
@@ -1,106 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.TreeMap;
-import java.util.concurrent.Callable;
-import java.util.concurrent.atomic.AtomicInteger;
-final class Track2Callable implements Callable<UserResult> {
- private static final Logger log = LoggerFactory.getLogger(Track2Callable.class);
- private static final AtomicInteger COUNT = new AtomicInteger();
- private final Recommender recommender;
- private final PreferenceArray userTest;
- Track2Callable(Recommender recommender, PreferenceArray userTest) {
- this.recommender = recommender;
- this.userTest = userTest;
- }
- @Override
- public UserResult call() throws TasteException {
- int testSize = userTest.length();
- if (testSize != 6) {
- throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest);
- }
- long userID = userTest.get(0).getUserID();
- TreeMap<Double,Long> estimateToItemID = new TreeMap<>(Collections.reverseOrder());
- for (int i = 0; i < testSize; i++) {
- long itemID = userTest.getItemID(i);
- double estimate;
- try {
- estimate = recommender.estimatePreference(userID, itemID);
- } catch (NoSuchItemException nsie) {
- // OK in the sample data provided before the contest, should never happen otherwise
- log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
- continue;
- }
- if (!Double.isNaN(estimate)) {
- estimateToItemID.put(estimate, itemID);
- }
- }
- Collection<Long> itemIDs = estimateToItemID.values();
- List<Long> topThree = new ArrayList<>(itemIDs);
- if (topThree.size() > 3) {
- topThree = topThree.subList(0, 3);
- } else if (topThree.size() < 3) {
- log.warn("Unable to recommend three items for {}", userID);
- // Some NaNs - just guess at the rest then
- Collection<Long> newItemIDs = new HashSet<>(3);
- newItemIDs.addAll(itemIDs);
- int i = 0;
- while (i < testSize && newItemIDs.size() < 3) {
- newItemIDs.add(userTest.getItemID(i));
- i++;
- }
- topThree = new ArrayList<>(newItemIDs);
- }
- if (topThree.size() != 3) {
- throw new IllegalStateException();
- }
- boolean[] result = new boolean[testSize];
- for (int i = 0; i < testSize; i++) {
- result[i] = topThree.contains(userTest.getItemID(i));
- }
- if (COUNT.incrementAndGet() % 1000 == 0) {
- log.info("Completed {} users", COUNT.get());
- }
- return new UserResult(userID, result);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
deleted file mode 100644
index 185a00d..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
+++ /dev/null
@@ -1,100 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.List;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-public final class Track2Recommender implements Recommender {
- private final Recommender recommender;
- public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException {
- // Change this to whatever you like!
- ItemSimilarity similarity;
- try {
- similarity = new HybridSimilarity(dataModel, dataFileDirectory);
- } catch (IOException ioe) {
- throw new TasteException(ioe);
- }
- recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
- return recommender.recommend(userID, howMany);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
- return recommend(userID, howMany, null, includeKnownItems);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, false);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
- throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
- }
- @Override
- public float estimatePreference(long userID, long itemID) throws TasteException {
- return recommender.estimatePreference(userID, itemID);
- }
- @Override
- public void setPreference(long userID, long itemID, float value) throws TasteException {
- recommender.setPreference(userID, itemID, value);
- }
- @Override
- public void removePreference(long userID, long itemID) throws TasteException {
- recommender.removePreference(userID, itemID);
- }
- @Override
- public DataModel getDataModel() {
- return recommender.getDataModel();
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- recommender.refresh(alreadyRefreshed);
- }
- @Override
- public String toString() {
- return "Track1Recommender[recommender:" + recommender + ']';
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
deleted file mode 100644
index 09ade5d..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
+++ /dev/null
@@ -1,33 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-final class Track2RecommenderBuilder implements RecommenderBuilder {
- @Override
- public Recommender buildRecommender(DataModel dataModel) throws TasteException {
- return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory());
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
deleted file mode 100644
index 3cbb61c..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
+++ /dev/null
@@ -1,100 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
- * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender}
- * and attempts to output the result in the correct contest format.</p>
- *
- * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p>
- */
-public final class Track2Runner {
- private static final Logger log = LoggerFactory.getLogger(Track2Runner.class);
- private Track2Runner() {
- }
- public static void main(String[] args) throws Exception {
- File dataFileDirectory = new File(args[0]);
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
- long start = System.currentTimeMillis();
- KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
- Track2Recommender recommender = new Track2Recommender(model, dataFileDirectory);
- long end = System.currentTimeMillis();
- log.info("Loaded model in {}s", (end - start) / 1000);
- start = end;
- Collection<Track2Callable> callables = new ArrayList<>();
- for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
- PreferenceArray userTest = tests.getFirst();
- callables.add(new Track2Callable(recommender, userTest));
- }
- int cores = Runtime.getRuntime().availableProcessors();
- log.info("Running on {} cores", cores);
- ExecutorService executor = Executors.newFixedThreadPool(cores);
- List<Future<UserResult>> futures = executor.invokeAll(callables);
- executor.shutdown();
- end = System.currentTimeMillis();
- log.info("Ran recommendations in {}s", (end - start) / 1000);
- start = end;
- try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){
- long lastUserID = Long.MIN_VALUE;
- for (Future<UserResult> future : futures) {
- UserResult result = future.get();
- long userID = result.getUserID();
- if (userID <= lastUserID) {
- throw new IllegalStateException();
- }
- lastUserID = userID;
- out.write(result.getResultBytes());
- }
- }
- end = System.currentTimeMillis();
- log.info("Wrote output in {}s", (end - start) / 1000);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
deleted file mode 100644
index abd15f8..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
+++ /dev/null
@@ -1,71 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import java.util.regex.Pattern;
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-final class TrackData {
- private static final Pattern PIPE = Pattern.compile("\\|");
- private static final String NO_VALUE = "None";
- static final long NO_VALUE_ID = Long.MIN_VALUE;
- private static final FastIDSet NO_GENRES = new FastIDSet();
- private final long trackID;
- private final long albumID;
- private final long artistID;
- private final FastIDSet genreIDs;
- TrackData(CharSequence line) {
- String[] tokens = PIPE.split(line);
- trackID = Long.parseLong(tokens[0]);
- albumID = parse(tokens[1]);
- artistID = parse(tokens[2]);
- if (tokens.length > 3) {
- genreIDs = new FastIDSet(tokens.length - 3);
- for (int i = 3; i < tokens.length; i++) {
- genreIDs.add(Long.parseLong(tokens[i]));
- }
- } else {
- genreIDs = NO_GENRES;
- }
- }
- private static long parse(String value) {
- return NO_VALUE.equals(value) ? NO_VALUE_ID : Long.parseLong(value);
- }
- public long getTrackID() {
- return trackID;
- }
- public long getAlbumID() {
- return albumID;
- }
- public long getArtistID() {
- return artistID;
- }
- public FastIDSet getGenreIDs() {
- return genreIDs;
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
deleted file mode 100644
index 3012a84..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
+++ /dev/null
@@ -1,106 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-import org.apache.mahout.common.iterator.FileLineIterable;
-final class TrackItemSimilarity implements ItemSimilarity {
- private final FastByIDMap<TrackData> trackData;
- TrackItemSimilarity(File dataFileDirectory) throws IOException {
- trackData = new FastByIDMap<>();
- for (String line : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) {
- TrackData trackDatum = new TrackData(line);
- trackData.put(trackDatum.getTrackID(), trackDatum);
- }
- }
- @Override
- public double itemSimilarity(long itemID1, long itemID2) {
- if (itemID1 == itemID2) {
- return 1.0;
- }
- TrackData data1 = trackData.get(itemID1);
- TrackData data2 = trackData.get(itemID2);
- if (data1 == null || data2 == null) {
- return 0.0;
- }
- // Arbitrarily decide that same album means "very similar"
- if (data1.getAlbumID() != TrackData.NO_VALUE_ID && data1.getAlbumID() == data2.getAlbumID()) {
- return 0.9;
- }
- // ... and same artist means "fairly similar"
- if (data1.getArtistID() != TrackData.NO_VALUE_ID && data1.getArtistID() == data2.getArtistID()) {
- return 0.7;
- }
- // Tanimoto coefficient similarity based on genre, but maximum value of 0.25
- FastIDSet genres1 = data1.getGenreIDs();
- FastIDSet genres2 = data2.getGenreIDs();
- if (genres1 == null || genres2 == null) {
- return 0.0;
- }
- int intersectionSize = genres1.intersectionSize(genres2);
- if (intersectionSize == 0) {
- return 0.0;
- }
- int unionSize = genres1.size() + genres2.size() - intersectionSize;
- return intersectionSize / (4.0 * unionSize);
- }
- @Override
- public double[] itemSimilarities(long itemID1, long[] itemID2s) {
- int length = itemID2s.length;
- double[] result = new double[length];
- for (int i = 0; i < length; i++) {
- result[i] = itemSimilarity(itemID1, itemID2s[i]);
- }
- return result;
- }
- @Override
- public long[] allSimilarItemIDs(long itemID) {
- FastIDSet allSimilarItemIDs = new FastIDSet();
- LongPrimitiveIterator allItemIDs = trackData.keySetIterator();
- while (allItemIDs.hasNext()) {
- long possiblySimilarItemID = allItemIDs.nextLong();
- if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) {
- allSimilarItemIDs.add(possiblySimilarItemID);
- }
- }
- return allSimilarItemIDs.toArray();
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- // do nothing
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
deleted file mode 100644
index e554d10..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
+++ /dev/null
@@ -1,54 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-final class UserResult {
- private final long userID;
- private final byte[] resultBytes;
- UserResult(long userID, boolean[] result) {
- this.userID = userID;
- int trueCount = 0;
- for (boolean b : result) {
- if (b) {
- trueCount++;
- }
- }
- if (trueCount != 3) {
- throw new IllegalStateException();
- }
- resultBytes = new byte[result.length];
- for (int i = 0; i < result.length; i++) {
- resultBytes[i] = (byte) (result[i] ? '1' : '0');
- }
- }
- public long getUserID() {
- return userID;
- }
- public byte[] getResultBytes() {
- return resultBytes;
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java b/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
deleted file mode 100644
index 22f122e..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
+++ /dev/null
@@ -1,140 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.hadoop.example.als.netflix;
-import com.google.common.base.Preconditions;
-import org.apache.commons.io.Charsets;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.cf.taste.impl.model.GenericPreference;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.common.iterator.FileLineIterable;
-import org.apache.mahout.common.iterator.FileLineIterator;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.regex.Pattern;
-/** converts the raw files provided by netflix to an appropriate input format */
-public final class NetflixDatasetConverter {
- private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class);
- private static final Pattern SEPARATOR = Pattern.compile(",");
- private static final String MOVIE_DENOTER = ":";
- private static final String TAB = "\t";
- private static final String NEWLINE = "\n";
- private NetflixDatasetConverter() {
- }
- public static void main(String[] args) throws IOException {
- if (args.length != 4) {
- System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt "
- + "/path/to/judging.txt /path/to/destination");
- return;
- }
- String trainingDataDir = args[0];
- String qualifyingTxt = args[1];
- String judgingTxt = args[2];
- Path outputPath = new Path(args[3]);
- Configuration conf = new Configuration();
- FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
- Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified");
- log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath);
- try (BufferedWriter writer =
- new BufferedWriter(
- new OutputStreamWriter(
- fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){
- int ratingsProcessed = 0;
- for (File movieRatings : new File(trainingDataDir).listFiles()) {
- try (FileLineIterator lines = new FileLineIterator(movieRatings)) {
- boolean firstLineRead = false;
- String movieID = null;
- while (lines.hasNext()) {
- String line = lines.next();
- if (firstLineRead) {
- String[] tokens = SEPARATOR.split(line);
- String userID = tokens[0];
- String rating = tokens[1];
- writer.write(userID + TAB + movieID + TAB + rating + NEWLINE);
- ratingsProcessed++;
- if (ratingsProcessed % 1000000 == 0) {
- log.info("{} ratings processed...", ratingsProcessed);
- }
- } else {
- movieID = line.replaceAll(MOVIE_DENOTER, "");
- firstLineRead = true;
- }
- }
- }
- }
- log.info("{} ratings processed. done.", ratingsProcessed);
- }
- log.info("Reading probes...");
- List<Preference> probes = new ArrayList<>(2817131);
- long currentMovieID = -1;
- for (String line : new FileLineIterable(new File(qualifyingTxt))) {
- if (line.contains(MOVIE_DENOTER)) {
- currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
- } else {
- long userID = Long.parseLong(SEPARATOR.split(line)[0]);
- probes.add(new GenericPreference(userID, currentMovieID, 0));
- }
- }
- log.info("{} probes read...", probes.size());
- log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath);
- try (BufferedWriter writer =
- new BufferedWriter(new OutputStreamWriter(
- fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){
- int ratingsProcessed = 0;
- for (String line : new FileLineIterable(new File(judgingTxt))) {
- if (line.contains(MOVIE_DENOTER)) {
- currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
- } else {
- float rating = Float.parseFloat(SEPARATOR.split(line)[0]);
- Preference pref = probes.get(ratingsProcessed);
- Preconditions.checkState(pref.getItemID() == currentMovieID);
- ratingsProcessed++;
- writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE);
- if (ratingsProcessed % 1000000 == 0) {
- log.info("{} ratings processed...", ratingsProcessed);
- }
- }
- }
- log.info("{} ratings processed. done.", ratingsProcessed);
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
deleted file mode 100644
index 8021d00..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
+++ /dev/null
@@ -1,65 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.similarity.precompute.example;
-import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
-import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter;
-import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
-import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
-import java.io.File;
- * Example that precomputes all item similarities of the Movielens1M dataset
- *
- * Usage: download movielens1M from http://www.grouplens.org/node/73 , unzip it and invoke this code with the path
- * to the ratings.dat file as argument
- *
- */
-public final class BatchItemSimilaritiesGroupLens {
- private BatchItemSimilaritiesGroupLens() {}
- public static void main(String[] args) throws Exception {
- if (args.length != 1) {
- System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!");
- System.exit(-1);
- }
- File resultFile = new File(System.getProperty("java.io.tmpdir"), "similarities.csv");
- if (resultFile.exists()) {
- resultFile.delete();
- }
- DataModel dataModel = new GroupLensDataModel(new File(args[0]));
- ItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel,
- new LogLikelihoodSimilarity(dataModel));
- BatchItemSimilarities batch = new MultithreadedBatchItemSimilarities(recommender, 5);
- int numSimilarities = batch.computeItemSimilarities(Runtime.getRuntime().availableProcessors(), 1,
- new FileSimilarItemsWriter(resultFile));
- System.out.println("Computed " + numSimilarities + " similarities for " + dataModel.getNumItems() + " items "
- + "and saved them to " + resultFile.getAbsolutePath());
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
deleted file mode 100644
index 7ee9b17..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
+++ /dev/null
@@ -1,96 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.similarity.precompute.example;
-import com.google.common.io.Files;
-import com.google.common.io.InputSupplier;
-import com.google.common.io.Resources;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStreamWriter;
-import java.io.Writer;
-import java.net.URL;
-import java.util.regex.Pattern;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
-import org.apache.mahout.common.iterator.FileLineIterable;
-public final class GroupLensDataModel extends FileDataModel {
- private static final String COLON_DELIMTER = "::";
- private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER);
- public GroupLensDataModel() throws IOException {
- this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat"));
- }
- /**
- * @param ratingsFile GroupLens ratings.dat file in its native format
- * @throws IOException if an error occurs while reading or writing files
- */
- public GroupLensDataModel(File ratingsFile) throws IOException {
- super(convertGLFile(ratingsFile));
- }
- private static File convertGLFile(File originalFile) throws IOException {
- // Now translate the file; remove commas, then convert "::" delimiter to comma
- File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "ratings.txt");
- if (resultFile.exists()) {
- resultFile.delete();
- }
- try (Writer writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8)){
- for (String line : new FileLineIterable(originalFile, false)) {
- int lastDelimiterStart = line.lastIndexOf(COLON_DELIMTER);
- if (lastDelimiterStart < 0) {
- throw new IOException("Unexpected input format on line: " + line);
- }
- String subLine = line.substring(0, lastDelimiterStart);
- String convertedLine = COLON_DELIMITER_PATTERN.matcher(subLine).replaceAll(",");
- writer.write(convertedLine);
- writer.write('\n');
- }
- } catch (IOException ioe) {
- resultFile.delete();
- throw ioe;
- }
- return resultFile;
- }
- public static File readResourceToTempFile(String resourceName) throws IOException {
- InputSupplier<? extends InputStream> inSupplier;
- try {
- URL resourceURL = Resources.getResource(GroupLensDataModel.class, resourceName);
- inSupplier = Resources.newInputStreamSupplier(resourceURL);
- } catch (IllegalArgumentException iae) {
- File resourceFile = new File("src/main/java" + resourceName);
- inSupplier = Files.newInputStreamSupplier(resourceFile);
- }
- File tempFile = File.createTempFile("taste", null);
- tempFile.deleteOnExit();
- Files.copy(inSupplier, tempFile);
- return tempFile;
- }
- @Override
- public String toString() {
- return "GroupLensDataModel";
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
deleted file mode 100644
index 5cec51c..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
+++ /dev/null
@@ -1,128 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier;
-import com.google.common.collect.ConcurrentHashMultiset;
-import com.google.common.collect.Multiset;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.commons.io.Charsets;
-import org.apache.lucene.analysis.Analyzer;
-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.standard.StandardAnalyzer;
-import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.Reader;
-import java.io.StringReader;
-import java.text.SimpleDateFormat;
-import java.util.Collection;
-import java.util.Date;
-import java.util.Locale;
-import java.util.Random;
-public final class NewsgroupHelper {
- private static final SimpleDateFormat[] DATE_FORMATS = {
- new SimpleDateFormat("", Locale.ENGLISH),
- new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
- new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
- };
- public static final int FEATURES = 10000;
- // 1997-01-15 00:01:00 GMT
- private static final long DATE_REFERENCE = 853286460;
- private static final long MONTH = 30 * 24 * 3600;
- private static final long WEEK = 7 * 24 * 3600;
- private final Random rand = RandomUtils.getRandom();
- private final Analyzer analyzer = new StandardAnalyzer();
- private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
- private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
- public FeatureVectorEncoder getEncoder() {
- return encoder;
- }
- public FeatureVectorEncoder getBias() {
- return bias;
- }
- public Random getRandom() {
- return rand;
- }
- public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts)
- throws IOException {
- long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * rand.nextDouble()));
- Multiset<String> words = ConcurrentHashMultiset.create();
- try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) {
- String line = reader.readLine();
- Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
- countWords(analyzer, words, dateString, overallCounts);
- while (line != null && !line.isEmpty()) {
- boolean countHeader = (
- line.startsWith("From:") || line.startsWith("Subject:")
- || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6;
- do {
- Reader in = new StringReader(line);
- if (countHeader) {
- countWords(analyzer, words, in, overallCounts);
- }
- line = reader.readLine();
- } while (line != null && line.startsWith(" "));
- }
- if (leakType < 3) {
- countWords(analyzer, words, reader, overallCounts);
- }
- }
- Vector v = new RandomAccessSparseVector(FEATURES);
- bias.addToVector("", 1, v);
- for (String word : words.elementSet()) {
- encoder.addToVector(word, Math.log1p(words.count(word)), v);
- }
- return v;
- }
- public static void countWords(Analyzer analyzer,
- Collection<String> words,
- Reader in,
- Multiset<String> overallCounts) throws IOException {
- TokenStream ts = analyzer.tokenStream("text", in);
- ts.addAttribute(CharTermAttribute.class);
- ts.reset();
- while (ts.incrementToken()) {
- String s = ts.getAttribute(CharTermAttribute.class).toString();
- words.add(s);
- }
- overallCounts.addAll(words);
- ts.end();
- Closeables.close(ts, true);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
deleted file mode 100644
index 16e9d80..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
+++ /dev/null
@@ -1,65 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.email;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.Locale;
-import java.util.regex.Pattern;
- * Convert the labels created by the {@link org.apache.mahout.utils.email.MailProcessor} to one consumable
- * by the classifiers
- */
-public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
- private static final Pattern DASH_DOT = Pattern.compile("-|\\.");
- private static final Pattern SLASH = Pattern.compile("\\/");
- private boolean useListName = false; //if true, use the project name and the list name in label creation
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- useListName = Boolean.parseBoolean(context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME));
- }
- @Override
- protected void map(WritableComparable<?> key, VectorWritable value, Context context)
- throws IOException, InterruptedException {
- String input = key.toString();
- ///Example: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$***@WRPO
- String[] splits = SLASH.split(input);
- //we need the first two splits;
- if (splits.length >= 3) {
- StringBuilder bldr = new StringBuilder();
- bldr.append(escape(splits[1]));
- if (useListName) {
- bldr.append('_').append(escape(splits[2]));
- }
- context.write(new Text(bldr.toString()), value);
- }
- }
- private static String escape(CharSequence value) {
- return DASH_DOT.matcher(value).replaceAll("_").toLowerCase(Locale.ENGLISH);
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
deleted file mode 100644
index da6e613..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
+++ /dev/null
@@ -1,47 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.email;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.Iterator;
-public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> {
- private long maxItemsPerLabel = 10000;
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- maxItemsPerLabel = Long.parseLong(context.getConfiguration().get(PrepEmailVectorsDriver.ITEMS_PER_CLASS));
- }
- @Override
- protected void reduce(Text key, Iterable<VectorWritable> values, Context context)
- throws IOException, InterruptedException {
- //TODO: support randomization? Likely not needed due to the SplitInput utility which does random selection
- long i = 0;
- Iterator<VectorWritable> iterator = values.iterator();
- while (i < maxItemsPerLabel && iterator.hasNext()) {
- context.write(key, iterator.next());
- i++;
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
deleted file mode 100644
index 8fba739..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
+++ /dev/null
@@ -1,76 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.email;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.math.VectorWritable;
-import java.util.List;
-import java.util.Map;
- * Convert the labels generated by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and
- * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} to ones consumable by the classifiers. We do this
- * here b/c if it is done in the creation of sparse vectors, the Reducer collapses all the vectors.
- */
-public class PrepEmailVectorsDriver extends AbstractJob {
- public static final String ITEMS_PER_CLASS = "itemsPerClass";
- public static final String USE_LIST_NAME = "USE_LIST_NAME";
- public static void main(String[] args) throws Exception {
- ToolRunner.run(new Configuration(), new PrepEmailVectorsDriver(), args);
- }
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.overwriteOption().create());
- addOption("maxItemsPerLabel", "mipl", "The maximum number of items per label. Can be useful for making the "
- + "training sets the same size", String.valueOf(100000));
- addOption(buildOption("useListName", "ul", "Use the name of the list as part of the label. If not set, then "
- + "just use the project name", false, false, "false"));
- Map<String,List<String>> parsedArgs = parseArguments(args);
- if (parsedArgs == null) {
- return -1;
- }
- Path input = getInputPath();
- Path output = getOutputPath();
- if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
- HadoopUtil.delete(getConf(), output);
- }
- Job convertJob = prepareJob(input, output, SequenceFileInputFormat.class, PrepEmailMapper.class, Text.class,
- VectorWritable.class, PrepEmailReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
- convertJob.getConfiguration().set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel"));
- convertJob.getConfiguration().set(USE_LIST_NAME, String.valueOf(hasOption("useListName")));
- boolean succeeded = convertJob.waitForCompletion(true);
- return succeeded ? 0 : -1;
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
deleted file mode 100644
index 9c0ef56..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
+++ /dev/null
@@ -1,277 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sequencelearning.hmm;
-import com.google.common.io.Resources;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.math.Matrix;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.net.URL;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
- * This class implements a sample program that uses a pre-tagged training data
- * set to train an HMM model as a POS tagger. The training data is automatically
- * downloaded from the following URL:
- * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
- * trains an HMM Model using supervised learning and tests the model on the
- * following test data set:
- * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
- * details regarding the data files can be found at
- * http://flexcrfs.sourceforge.net/#Case_Study
- */
-public final class PosTagger {
- private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
- private static final Pattern SPACE = Pattern.compile(" ");
- private static final Pattern SPACES = Pattern.compile("[ ]+");
- /**
- * No public constructors for utility classes.
- */
- private PosTagger() {
- // nothing to do here really.
- }
- /**
- * Model trained in the example.
- */
- private static HmmModel taggingModel;
- /**
- * Map for storing the IDs for the POS tags (hidden states)
- */
- private static Map<String, Integer> tagIDs;
- /**
- * Counter for the next assigned POS tag ID The value of 0 is reserved for
- * "unknown POS tag"
- */
- private static int nextTagId;
- /**
- * Map for storing the IDs for observed words (observed states)
- */
- private static Map<String, Integer> wordIDs;
- /**
- * Counter for the next assigned word ID The value of 0 is reserved for
- * "unknown word"
- */
- private static int nextWordId = 1; // 0 is reserved for "unknown word"
- /**
- * Used for storing a list of POS tags of read sentences.
- */
- private static List<int[]> hiddenSequences;
- /**
- * Used for storing a list of word tags of read sentences.
- */
- private static List<int[]> observedSequences;
- /**
- * number of read lines
- */
- private static int readLines;
- /**
- * Given an URL, this function fetches the data file, parses it, assigns POS
- * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
- * data from those files. The data is expected to be in the following format
- * (one word per line): word pos-tag np-tag sentences are closed with the .
- * pos tag
- *
- * @param url Where the data file is stored
- * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
- * training data, not needed for test data)
- * @throws IOException in case data file cannot be read.
- */
- private static void readFromURL(String url, boolean assignIDs) throws IOException {
- // initialize the data structure
- hiddenSequences = new LinkedList<>();
- observedSequences = new LinkedList<>();
- readLines = 0;
- // now read line by line of the input file
- List<Integer> observedSequence = new LinkedList<>();
- List<Integer> hiddenSequence = new LinkedList<>();
- for (String line :Resources.readLines(new URL(url), Charsets.UTF_8)) {
- if (line.isEmpty()) {
- // new sentence starts
- int[] observedSequenceArray = new int[observedSequence.size()];
- int[] hiddenSequenceArray = new int[hiddenSequence.size()];
- for (int i = 0; i < observedSequence.size(); ++i) {
- observedSequenceArray[i] = observedSequence.get(i);
- hiddenSequenceArray[i] = hiddenSequence.get(i);
- }
- // now register those arrays
- hiddenSequences.add(hiddenSequenceArray);
- observedSequences.add(observedSequenceArray);
- // and reset the linked lists
- observedSequence.clear();
- hiddenSequence.clear();
- continue;
- }
- readLines++;
- // we expect the format [word] [POS tag] [NP tag]
- String[] tags = SPACE.split(line);
- // when analyzing the training set, assign IDs
- if (assignIDs) {
- if (!wordIDs.containsKey(tags[0])) {
- wordIDs.put(tags[0], nextWordId++);
- }
- if (!tagIDs.containsKey(tags[1])) {
- tagIDs.put(tags[1], nextTagId++);
- }
- }
- // determine the IDs
- Integer wordID = wordIDs.get(tags[0]);
- Integer tagID = tagIDs.get(tags[1]);
- // now construct the current sequence
- if (wordID == null) {
- observedSequence.add(0);
- } else {
- observedSequence.add(wordID);
- }
- if (tagID == null) {
- hiddenSequence.add(0);
- } else {
- hiddenSequence.add(tagID);
- }
- }
- // if there is still something in the pipe, register it
- if (!observedSequence.isEmpty()) {
- int[] observedSequenceArray = new int[observedSequence.size()];
- int[] hiddenSequenceArray = new int[hiddenSequence.size()];
- for (int i = 0; i < observedSequence.size(); ++i) {
- observedSequenceArray[i] = observedSequence.get(i);
- hiddenSequenceArray[i] = hiddenSequence.get(i);
- }
- // now register those arrays
- hiddenSequences.add(hiddenSequenceArray);
- observedSequences.add(observedSequenceArray);
- }
- }
- private static void trainModel(String trainingURL) throws IOException {
- tagIDs = new HashMap<>(44); // we expect 44 distinct tags
- wordIDs = new HashMap<>(19122); // we expect 19122
- // distinct words
- log.info("Reading and parsing training data file from URL: {}", trainingURL);
- long start = System.currentTimeMillis();
- readFromURL(trainingURL, true);
- long end = System.currentTimeMillis();
- double duration = (end - start) / 1000.0;
- log.info("Parsing done in {} seconds!", duration);
- log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
- readLines, hiddenSequences.size(), nextWordId - 1, nextTagId - 1);
- start = System.currentTimeMillis();
- taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
- hiddenSequences, observedSequences, 0.05);
- // we have to adjust the model a bit,
- // since we assume a higher probability that a given unknown word is NNP
- // than anything else
- Matrix emissions = taggingModel.getEmissionMatrix();
- for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
- emissions.setQuick(i, 0, 0.1 / taggingModel.getNrOfHiddenStates());
- }
- int nnptag = tagIDs.get("NNP");
- emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
- // re-normalize the emission probabilities
- HmmUtils.normalizeModel(taggingModel);
- // now register the names
- taggingModel.registerHiddenStateNames(tagIDs);
- taggingModel.registerOutputStateNames(wordIDs);
- end = System.currentTimeMillis();
- duration = (end - start) / 1000.0;
- log.info("Trained HMM models in {} seconds!", duration);
- }
- private static void testModel(String testingURL) throws IOException {
- log.info("Reading and parsing test data file from URL: {}", testingURL);
- long start = System.currentTimeMillis();
- readFromURL(testingURL, false);
- long end = System.currentTimeMillis();
- double duration = (end - start) / 1000.0;
- log.info("Parsing done in {} seconds!", duration);
- log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
- start = System.currentTimeMillis();
- int errorCount = 0;
- int totalCount = 0;
- for (int i = 0; i < observedSequences.size(); ++i) {
- // fetch the viterbi path as the POS tag for this observed sequence
- int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
- // compare with the expected
- int[] posExpected = hiddenSequences.get(i);
- for (int j = 0; j < posExpected.length; ++j) {
- totalCount++;
- if (posEstimate[j] != posExpected[j]) {
- errorCount++;
- }
- }
- }
- end = System.currentTimeMillis();
- duration = (end - start) / 1000.0;
- log.info("POS tagged test file in {} seconds!", duration);
- double errorRate = (double) errorCount / totalCount;
- log.info("Tagged the test file with an error rate of: {}", errorRate);
- }
- private static List<String> tagSentence(String sentence) {
- // first, we need to isolate all punctuation characters, so that they
- // can be recognized
- sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
- sentence = sentence.replaceAll("''", " '' ");
- // now we tokenize the sentence
- String[] tokens = SPACES.split(sentence);
- // now generate the observed sequence
- int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
- // POS tag this observedSequence
- int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
- // and now decode the tag names
- return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
- }
- public static void main(String[] args) throws IOException {
- // generate the model from URL
- trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
- testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
- // tag an exemplary sentence
- String test = "McDonalds is a huge company with many employees .";
- String[] testWords = SPACE.split(test);
- List<String> posTags = tagSentence(test);
- for (int i = 0; i < posTags.size(); ++i) {
- log.info("{}[{}]", testWords[i], posTags.get(i));
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
deleted file mode 100644
index b2ce8b1..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
+++ /dev/null
@@ -1,236 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.classifier.sgd;
-import org.apache.mahout.math.stats.GlobalOnlineAuc;
-import org.apache.mahout.math.stats.GroupedOnlineAuc;
-import org.apache.mahout.math.stats.OnlineAuc;
-import java.io.DataInput;
-import java.io.DataInputStream;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
- private AdaptiveLogisticRegression alr;
- private int interval = 800;
- private int averageWindow = 500;
- private int threads = 4;
- private String prior = "L1";
- private double priorOption = Double.NaN;
- private String auc = null;
- public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
- if (alr == null) {
- alr = new AdaptiveLogisticRegression(getMaxTargetCategories(),
- getNumFeatures(), createPrior(prior, priorOption));
- alr.setInterval(interval);
- alr.setAveragingWindow(averageWindow);
- alr.setThreadCount(threads);
- alr.setAucEvaluator(createAUC(auc));
- }
- return alr;
- }
- public void checkParameters() {
- if (prior != null) {
- String priorUppercase = prior.toUpperCase(Locale.ENGLISH).trim();
- if (("TP".equals(priorUppercase) || "EBP".equals(priorUppercase)) && Double.isNaN(priorOption)) {
- throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
- }
- }
- }
- private static PriorFunction createPrior(String cmd, double priorOption) {
- if (cmd == null) {
- return null;
- }
- if ("L1".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new L1();
- }
- if ("L2".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new L2();
- }
- if ("UP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new UniformPrior();
- }
- if ("TP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new TPrior(priorOption);
- }
- if ("EBP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new ElasticBandPrior(priorOption);
- }
- return null;
- }
- private static OnlineAuc createAUC(String cmd) {
- if (cmd == null) {
- return null;
- }
- if ("GLOBAL".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new GlobalOnlineAuc();
- }
- if ("GROUPED".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) {
- return new GroupedOnlineAuc();
- }
- return null;
- }
- @Override
- public void saveTo(OutputStream out) throws IOException {
- if (alr != null) {
- alr.close();
- }
- setTargetCategories(getCsvRecordFactory().getTargetCategories());
- write(new DataOutputStream(out));
- }
- @Override
- public void write(DataOutput out) throws IOException {
- out.writeUTF(getTargetVariable());
- out.writeInt(getTypeMap().size());
- for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
- out.writeUTF(entry.getKey());
- out.writeUTF(entry.getValue());
- }
- out.writeInt(getNumFeatures());
- out.writeInt(getMaxTargetCategories());
- out.writeInt(getTargetCategories().size());
- for (String category : getTargetCategories()) {
- out.writeUTF(category);
- }
- out.writeInt(interval);
- out.writeInt(averageWindow);
- out.writeInt(threads);
- out.writeUTF(prior);
- out.writeDouble(priorOption);
- out.writeUTF(auc);
- // skip csv
- alr.write(out);
- }
- @Override
- public void readFields(DataInput in) throws IOException {
- setTargetVariable(in.readUTF());
- int typeMapSize = in.readInt();
- Map<String, String> typeMap = new HashMap<>(typeMapSize);
- for (int i = 0; i < typeMapSize; i++) {
- String key = in.readUTF();
- String value = in.readUTF();
- typeMap.put(key, value);
- }
- setTypeMap(typeMap);
- setNumFeatures(in.readInt());
- setMaxTargetCategories(in.readInt());
- int targetCategoriesSize = in.readInt();
- List<String> targetCategories = new ArrayList<>(targetCategoriesSize);
- for (int i = 0; i < targetCategoriesSize; i++) {
- targetCategories.add(in.readUTF());
- }
- setTargetCategories(targetCategories);
- interval = in.readInt();
- averageWindow = in.readInt();
- threads = in.readInt();
- prior = in.readUTF();
- priorOption = in.readDouble();
- auc = in.readUTF();
- alr = new AdaptiveLogisticRegression();
- alr.readFields(in);
- }
- private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
- AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters();
- result.readFields(new DataInputStream(in));
- return result;
- }
- public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
- try (InputStream input = new FileInputStream(in)) {
- return loadFromStream(input);
- }
- }
- public int getInterval() {
- return interval;
- }
- public void setInterval(int interval) {
- this.interval = interval;
- }
- public int getAverageWindow() {
- return averageWindow;
- }
- public void setAverageWindow(int averageWindow) {
- this.averageWindow = averageWindow;
- }
- public int getThreads() {
- return threads;
- }
- public void setThreads(int threads) {
- this.threads = threads;
- }
- public String getPrior() {
- return prior;
- }
- public void setPrior(String prior) {
- this.prior = prior;
- }
- public String getAuc() {
- return auc;
- }
- public void setAuc(String auc) {
- this.auc = auc;
- }
- public double getPriorOption() {
- return priorOption;
- }
- public void setPriorOption(double priorOption) {
- this.priorOption = priorOption;
- }
2018-06-27 13:14:39 UTC
diff --git a/community/mahout-mr/examples/src/test/resources/wdbc/wdbc.data b/community/mahout-mr/examples/src/test/resources/wdbc/wdbc.data
new file mode 100644
index 0000000..8885375
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/wdbc/wdbc.data
@@ -0,0 +1,569 @@

2018-06-27 13:14:40 UTC
diff --git a/community/mahout-mr/examples/src/main/resources/cf-data-purchase.txt b/community/mahout-mr/examples/src/main/resources/cf-data-purchase.txt
new file mode 100644
index 0000000..d87c031
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/cf-data-purchase.txt
@@ -0,0 +1,7 @@

diff --git a/community/mahout-mr/examples/src/main/resources/cf-data-view.txt b/community/mahout-mr/examples/src/main/resources/cf-data-view.txt
new file mode 100644
index 0000000..09ad9b6
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/cf-data-view.txt
@@ -0,0 +1,12 @@

diff --git a/community/mahout-mr/examples/src/main/resources/donut-test.csv b/community/mahout-mr/examples/src/main/resources/donut-test.csv
new file mode 100644
index 0000000..46ea564
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/donut-test.csv
@@ -0,0 +1,41 @@

diff --git a/community/mahout-mr/examples/src/main/resources/donut.csv b/community/mahout-mr/examples/src/main/resources/donut.csv
new file mode 100644
index 0000000..33ba3b7
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/donut.csv
@@ -0,0 +1,41 @@

diff --git a/community/mahout-mr/examples/src/main/resources/test-data.csv b/community/mahout-mr/examples/src/main/resources/test-data.csv
new file mode 100644
index 0000000..ab683cd
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/test-data.csv
@@ -0,0 +1,61 @@

diff --git a/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java
new file mode 100644
index 0000000..e849011
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java
@@ -0,0 +1,43 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+public class LogisticModelParametersTest extends MahoutTestCase {
+ @Test
+ public void serializationWithoutCsv() throws IOException {
+ LogisticModelParameters params = new LogisticModelParameters();
+ params.setTargetVariable("foo");
+ params.setTypeMap(Collections.<String, String>emptyMap());
+ params.setTargetCategories(Arrays.asList("foo", "bar"));
+ params.setNumFeatures(1);
+ params.createRegression();
+ //MAHOUT-1196 should work without "csv" being set
+ params.saveTo(new ByteArrayOutputStream());
+ }

diff --git a/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java
new file mode 100644
index 0000000..c8e4879
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/ModelDissectorTest.java
@@ -0,0 +1,40 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.examples.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.junit.Test;
+public class ModelDissectorTest extends MahoutTestCase {
+ @Test
+ public void testCategoryOrdering() {
+ ModelDissector.Weight w = new ModelDissector.Weight("a", new DenseVector(new double[]{-2, -5, 5, 2, 4, 1, 0}), 4);
+ assertEquals(1, w.getCategory(0), 0);
+ assertEquals(-5, w.getWeight(0), 0);
+ assertEquals(2, w.getCategory(1), 0);
+ assertEquals(5, w.getWeight(1), 0);
+ assertEquals(4, w.getCategory(2), 0);
+ assertEquals(4, w.getWeight(2), 0);
+ assertEquals(0, w.getCategory(3), 0);
+ assertEquals(-2, w.getWeight(3), 0);
+ }

diff --git a/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
new file mode 100644
index 0000000..4cde692
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
@@ -0,0 +1,167 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.base.Charsets;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.examples.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+public class TrainLogisticTest extends MahoutTestCase {
+ @Test // Trains on donut.csv with predictors x and y, then checks the training output, the resulting parameters and model, the saved model file, and RunLogistic's report.
+ public void example131() throws Exception {
+ String outputFile = getTestTempFile("model").getAbsolutePath(); // passed as --output below and read back later in this test
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw, true); // second argument enables autoflush; everything printed accumulates in sw
+ TrainLogistic.mainToOutput(new String[]{
+ "--input", "donut.csv",
+ "--output", outputFile,
+ "--target", "color", "--categories", "2",
+ "--predictors", "x", "y",
+ "--types", "numeric",
+ "--features", "20",
+ "--passes", "100",
+ "--rate", "50"
+ }, pw);
+ String trainOut = sw.toString();
+ assertTrue(trainOut.contains("x -0.7")); // substring checks on the text TrainLogistic printed
+ assertTrue(trainOut.contains("y -0.4"));
+ LogisticModelParameters lmp = TrainLogistic.getParameters(); // static accessor, read after the training call above
+ assertEquals(1.0e-4, lmp.getLambda(), 1.0e-9); // no --lambda is passed above, so presumably TrainLogistic's default -- confirm
+ assertEquals(20, lmp.getNumFeatures()); // matches --features 20
+ assertTrue(lmp.useBias());
+ assertEquals("color", lmp.getTargetVariable()); // matches --target color
+ CsvRecordFactory csv = lmp.getCsvRecordFactory();
+ assertEquals("[1, 2]", new TreeSet<>(csv.getTargetCategories()).toString()); // TreeSet sorts the categories, making the string comparison order-independent
+ assertEquals("[Intercept Term, x, y]", Sets.newTreeSet(csv.getPredictors()).toString()); // sorted for the same reason
+ // verify model by building dissector
+ AbstractVectorClassifier model = TrainLogistic.getModel();
+ List<String> data = Resources.readLines(Resources.getResource("donut.csv"), Charsets.UTF_8); // all lines of the donut.csv classpath resource, decoded as UTF-8
+ Map<String, Double> expectedValues = ImmutableMap.of("x", -0.7, "y", -0.43, "Intercept Term", -0.15); // expected per-feature weights; verifyModel compares them with a tolerance of 0.1
+ verifyModel(lmp, csv, data, model, expectedValues);
+ // test saved model
+ try (InputStream in = new FileInputStream(new File(outputFile))){ // try-with-resources closes the stream when the block exits
+ LogisticModelParameters lmpOut = LogisticModelParameters.loadFrom(in); // parameters reloaded from the --output file
+ CsvRecordFactory csvOut = lmpOut.getCsvRecordFactory();
+ csvOut.firstLine(data.get(0)); // first line of donut.csv (presumably the CSV header -- confirm)
+ OnlineLogisticRegression lrOut = lmpOut.createRegression();
+ verifyModel(lmpOut, csvOut, data, lrOut, expectedValues); // the reloaded model must meet the same expectations
+ }
+ sw = new StringWriter(); // fresh buffer so the next checks see only RunLogistic's output
+ pw = new PrintWriter(sw, true);
+ RunLogistic.mainToOutput(new String[]{
+ "--input", "donut.csv",
+ "--model", outputFile,
+ "--auc",
+ "--confusion"
+ }, pw);
+ trainOut = sw.toString(); // variable reused: it now holds RunLogistic's output
+ assertTrue(trainOut.contains("AUC = 0.57"));
+ assertTrue(trainOut.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
+ }
+ @Test // Trains on donut.csv with predictors x, y, a, b, c, then runs the saved model against donut.csv and donut-test.csv and checks the reported AUC.
+ public void example132() throws Exception {
+ String outputFile = getTestTempFile("model").getAbsolutePath(); // passed as --output to training and as --model to both RunLogistic calls
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw, true); // second argument enables autoflush; everything printed accumulates in sw
+ TrainLogistic.mainToOutput(new String[]{
+ "--input", "donut.csv",
+ "--output", outputFile,
+ "--target", "color",
+ "--categories", "2",
+ "--predictors", "x", "y", "a", "b", "c",
+ "--types", "numeric",
+ "--features", "20",
+ "--passes", "100",
+ "--rate", "50"
+ }, pw);
+ String trainOut = sw.toString();
+ assertTrue(trainOut.contains("a 0.")); // prefix-style substring checks: only the leading digits of each printed value are pinned
+ assertTrue(trainOut.contains("b -1."));
+ assertTrue(trainOut.contains("c -25."));
+ sw = new StringWriter(); // fresh buffer for the first RunLogistic call
+ pw = new PrintWriter(sw, true);
+ RunLogistic.mainToOutput(new String[]{
+ "--input", "donut.csv",
+ "--model", outputFile,
+ "--auc",
+ "--confusion"
+ }, pw);
+ trainOut = sw.toString(); // variable reused: it now holds RunLogistic's output for donut.csv
+ assertTrue(trainOut.contains("AUC = 1.00"));
+ sw = new StringWriter(); // fresh buffer for the second RunLogistic call
+ pw = new PrintWriter(sw, true);
+ RunLogistic.mainToOutput(new String[]{
+ "--input", "donut-test.csv",
+ "--model", outputFile,
+ "--auc",
+ "--confusion"
+ }, pw);
+ trainOut = sw.toString(); // now holds RunLogistic's output for donut-test.csv
+ assertTrue(trainOut.contains("AUC = 0.9")); // substring match, so any reported value beginning with 0.9 passes
+ }
+ private static void verifyModel(LogisticModelParameters lmp, // Feeds every data line after the first through csv into a ModelDissector, then checks the dissector's summary against expectedValues.
+ RecordFactory csv,
+ List<String> data,
+ AbstractVectorClassifier model,
+ Map<String, Double> expectedValues) {
+ ModelDissector md = new ModelDissector();
+ for (String line : data.subList(1, data.size())) { // skips data.get(0)
+ Vector v = new DenseVector(lmp.getNumFeatures()); // new vector per line, sized by the configured feature count
+ csv.getTraceDictionary().clear(); // emptied before each line; presumably so it only reflects the current line -- confirm
+ csv.processLine(line, v);
+ md.update(v, csv.getTraceDictionary(), model);
+ }
+ // check right variables are present
+ List<ModelDissector.Weight> weights = md.summary(10); // 10 is presumably the maximum number of weights returned -- confirm against ModelDissector.summary
+ Set<String> expected = Sets.newHashSet(expectedValues.keySet()); // mutable copy so that seen features can be removed
+ for (ModelDissector.Weight weight : weights) {
+ assertTrue(expected.remove(weight.getFeature())); // fails if the feature is not expected or has already been seen
+ assertEquals(expectedValues.get(weight.getFeature()), weight.getWeight(), 0.1); // weight must be within 0.1 of the expected value
+ }
+ assertEquals(0, expected.size()); // every expected feature must have appeared in the summary
+ }

diff --git a/community/mahout-mr/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java b/community/mahout-mr/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java
new file mode 100644
index 0000000..6e43b97
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/java/org/apache/mahout/clustering/display/ClustersFilterTest.java
@@ -0,0 +1,75 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Before;
+import org.junit.Test;
+import java.io.IOException;
+public class ClustersFilterTest extends MahoutTestCase {
+ private Configuration configuration;
+ private Path output;
+ @Override // runs the superclass setUp() first, then initialises the two fields used by the tests
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ configuration = getConfiguration(); // used by the tests to resolve the FileSystem for each path
+ output = getTestTempDirPath(); // parent path under which each test creates its "clusters-*" entries
+ }
+ @Test // ClustersFilter must accept "clusters-N" paths when no "-final" path is involved.
+ public void testAcceptNotFinal() throws Exception {
+ Path path0 = new Path(output, "clusters-0");
+ Path path1 = new Path(output, "clusters-1");
+ path0.getFileSystem(configuration).createNewFile(path0); // create both paths on the file system before filtering
+ path1.getFileSystem(configuration).createNewFile(path1);
+ PathFilter clustersFilter = new ClustersFilter();
+ assertTrue(clustersFilter.accept(path0));
+ assertTrue(clustersFilter.accept(path1));
+ }
+ @Test // ClustersFilter must accept both the plain "clusters-N" paths and the "clusters-3-final" path.
+ public void testAcceptFinalPath() throws IOException {
+ Path path0 = new Path(output, "clusters-0");
+ Path path1 = new Path(output, "clusters-1");
+ Path path2 = new Path(output, "clusters-2");
+ Path path3Final = new Path(output, "clusters-3-final");
+ path0.getFileSystem(configuration).createNewFile(path0); // create all four paths on the file system before filtering
+ path1.getFileSystem(configuration).createNewFile(path1);
+ path2.getFileSystem(configuration).createNewFile(path2);
+ path3Final.getFileSystem(configuration).createNewFile(path3Final);
+ PathFilter clustersFilter = new ClustersFilter();
+ assertTrue(clustersFilter.accept(path0)); // non-final paths are still expected to be accepted
+ assertTrue(clustersFilter.accept(path1));
+ assertTrue(clustersFilter.accept(path2));
+ assertTrue(clustersFilter.accept(path3Final));
+ }

diff --git a/community/mahout-mr/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java b/community/mahout-mr/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java
new file mode 100644
index 0000000..4d81e3f
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java
@@ -0,0 +1,30 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.examples;
+ * This class should not exist. It's here to work around some bizarre problem in Maven
+ * dependency management wherein it can see methods in {@link org.apache.mahout.common.MahoutTestCase}
+ * but not constants. Duplicated here to make it jive.
+ */
+public abstract class MahoutTestCase extends org.apache.mahout.common.MahoutTestCase { // examples-module base class; adds only the EPSILON constant on top of the common MahoutTestCase
+ /** "Close enough" value for floating-point comparisons. */
+ public static final double EPSILON = 0.000001;

diff --git a/community/mahout-mr/examples/src/test/resources/country.txt b/community/mahout-mr/examples/src/test/resources/country.txt
new file mode 100644
index 0000000..6a22091
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/country.txt
@@ -0,0 +1,229 @@
+American Samoa
+Antigua and Barbuda
+Bosnia and Herzegovina
+Bouvet Island
+British Indian Ocean Territory
+Brunei Darussalam
+Burkina Faso
+Cape Verde
+Cayman Islands
+Central African Republic
+Christmas Island
+Cocos Islands
+Cook Islands
+Costa Rica
+Côte d'Ivoire
+Czech Republic
+Dominican Republic
+El Salvador
+Equatorial Guinea
+Falkland Islands
+Faroe Islands
+French Guiana
+French Polynesia
+French Southern Territories
+Hong Kong
+Isle of Man
+Marshall Islands
+Netherlands Antilles
+New Caledonia
+New Zealand
+Norfolk Island
+Northern Mariana Islands
+Palestinian Territory
+Papua New Guinea
+Puerto Rico
+Russian Federation
+Saint Barthélemy
+Saint Helena
+Saint Kitts and Nevis
+Saint Lucia
+Saint Martin
+Saint Pierre and Miquelon
+Saint Vincent and the Grenadines
+San Marino
+Sao Tome and Principe
+Saudi Arabia
+Sierra Leone
+Solomon Islands
+South Africa
+South Georgia and the South Sandwich Islands
+Sri Lanka
+Svalbard and Jan Mayen
+Syrian Arab Republic
+Trinidad and Tobago
+Turks and Caicos Islands
+United Arab Emirates
+United Kingdom
+United States
+United States Minor Outlying Islands
+Virgin Islands
+Wallis and Futuna

diff --git a/community/mahout-mr/examples/src/test/resources/country10.txt b/community/mahout-mr/examples/src/test/resources/country10.txt
new file mode 100644
index 0000000..97a63e1
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/country10.txt
@@ -0,0 +1,10 @@
+United Kingdom

diff --git a/community/mahout-mr/examples/src/test/resources/country2.txt b/community/mahout-mr/examples/src/test/resources/country2.txt
new file mode 100644
index 0000000..f4b4f61
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/country2.txt
@@ -0,0 +1,2 @@
+United States
+United Kingdom

diff --git a/community/mahout-mr/examples/src/test/resources/subjects.txt b/community/mahout-mr/examples/src/test/resources/subjects.txt
new file mode 100644
index 0000000..f52ae33
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/subjects.txt
@@ -0,0 +1,2 @@

diff --git a/community/mahout-mr/examples/src/test/resources/wdbc.infos b/community/mahout-mr/examples/src/test/resources/wdbc.infos
new file mode 100644
index 0000000..94a63d6
--- /dev/null
+++ b/community/mahout-mr/examples/src/test/resources/wdbc.infos
@@ -0,0 +1,32 @@
+NUMERICAL, 6.9, 28.2
+NUMERICAL, 9.7, 39.3
+NUMERICAL, 43.7, 188.5
+NUMERICAL, 143.5, 2501.0
+NUMERICAL, 0.0, 0.2
+NUMERICAL, 0.0, 0.4
+NUMERICAL, 0.0, 0.5
+NUMERICAL, 0.0, 0.3
+NUMERICAL, 0.1, 0.4
+NUMERICAL, 0.0, 0.1
+NUMERICAL, 0.1, 2.9
+NUMERICAL, 0.3, 4.9
+NUMERICAL, 0.7, 22.0
+NUMERICAL, 6.8, 542.3
+NUMERICAL, 0.0, 0.1
+NUMERICAL, 0.0, 0.2
+NUMERICAL, 0.0, 0.4
+NUMERICAL, 0.0, 0.1
+NUMERICAL, 0.0, 0.1
+NUMERICAL, 0.0, 0.1
+NUMERICAL, 7.9, 36.1
+NUMERICAL, 12.0, 49.6
+NUMERICAL, 50.4, 251.2
+NUMERICAL, 185.2, 4254.0
+NUMERICAL, 0.0, 0.3
+NUMERICAL, 0.0, 1.1
+NUMERICAL, 0.0, 1.3
+NUMERICAL, 0.0, 0.3
+NUMERICAL, 0.1, 0.7
+NUMERICAL, 0.0, 0.3
2018-06-27 13:14:36 UTC
diff --git a/examples/bin/resources/country.txt b/examples/bin/resources/country.txt
deleted file mode 100644
index 6a22091..0000000
--- a/examples/bin/resources/country.txt
+++ /dev/null
@@ -1,229 +0,0 @@
-American Samoa
-Antigua and Barbuda
-Bosnia and Herzegovina
-Bouvet Island
-British Indian Ocean Territory
-Brunei Darussalam
-Burkina Faso
-Cape Verde
-Cayman Islands
-Central African Republic
-Christmas Island
-Cocos Islands
-Cook Islands
-Costa Rica
-Côte d'Ivoire
-Czech Republic
-Dominican Republic
-El Salvador
-Equatorial Guinea
-Falkland Islands
-Faroe Islands
-French Guiana
-French Polynesia
-French Southern Territories
-Hong Kong
-Isle of Man
-Marshall Islands
-Netherlands Antilles
-New Caledonia
-New Zealand
-Norfolk Island
-Northern Mariana Islands
-Palestinian Territory
-Papua New Guinea
-Puerto Rico
-Russian Federation
-Saint Barthélemy
-Saint Helena
-Saint Kitts and Nevis
-Saint Lucia
-Saint Martin
-Saint Pierre and Miquelon
-Saint Vincent and the Grenadines
-San Marino
-Sao Tome and Principe
-Saudi Arabia
-Sierra Leone
-Solomon Islands
-South Africa
-South Georgia and the South Sandwich Islands
-Sri Lanka
-Svalbard and Jan Mayen
-Syrian Arab Republic
-Trinidad and Tobago
-Turks and Caicos Islands
-United Arab Emirates
-United Kingdom
-United States
-United States Minor Outlying Islands
-Virgin Islands
-Wallis and Futuna

diff --git a/examples/bin/resources/country10.txt b/examples/bin/resources/country10.txt
deleted file mode 100644
index 97a63e1..0000000
--- a/examples/bin/resources/country10.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-United Kingdom

diff --git a/examples/bin/resources/country2.txt b/examples/bin/resources/country2.txt
deleted file mode 100644
index f4b4f61..0000000
--- a/examples/bin/resources/country2.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-United States
-United Kingdom

diff --git a/examples/bin/resources/donut-test.csv b/examples/bin/resources/donut-test.csv
deleted file mode 100644
index 46ea564..0000000
--- a/examples/bin/resources/donut-test.csv
+++ /dev/null
@@ -1,41 +0,0 @@

diff --git a/examples/bin/resources/donut.csv b/examples/bin/resources/donut.csv
deleted file mode 100644
index 33ba3b7..0000000
--- a/examples/bin/resources/donut.csv
+++ /dev/null
@@ -1,41 +0,0 @@

diff --git a/examples/bin/resources/test-data.csv b/examples/bin/resources/test-data.csv
deleted file mode 100644
index ab683cd..0000000
--- a/examples/bin/resources/test-data.csv
+++ /dev/null
@@ -1,61 +0,0 @@

diff --git a/examples/bin/run-item-sim.sh b/examples/bin/run-item-sim.sh
index 258cdfc..bfe75e2 100755
--- a/examples/bin/run-item-sim.sh
+++ b/examples/bin/run-item-sim.sh
@@ -68,7 +68,7 @@ echo "Removing old output file if it exists"

-mahout spark-itemsimilarity -i $PURCHASE -i2 $VIEW -o $FS_OUPUT -ma local
+$MAHOUT_HOME/bin/mahout spark-itemsimilarity -i $PURCHASE -i2 $VIEW -o $FS_OUPUT -ma local

export MAHOUT_LOCAL=$LOCAL #restore state

@@ -77,9 +77,9 @@ echo "Look in " $FS_OUPUT " for spark-itemsimilarity indicator data."
echo ""
echo "Purchase cooccurrence indicators (itemid<tab>simliar items by purchase)"
echo ""
-cat .$OUTPUT1
+cat ../..$OUTPUT1
echo ""
echo "View cross-cooccurrence indicators (items<tab>similar items where views led to purchases)"
echo ""
-cat .$OUTPUT2
+cat ../..$OUTPUT2
echo ""

diff --git a/examples/bin/set-dfs-commands.sh b/examples/bin/set-dfs-commands.sh
deleted file mode 100755
index 0ee5fe1..0000000
--- a/examples/bin/set-dfs-commands.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Requires $HADOOP_HOME to be set.
-# Figures out the major version of Hadoop we're using and sets commands
-# for dfs commands
-# Run by each example script.
-# Find a hadoop shell
-if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
- HADOOP="${HADOOP_HOME}/bin/hadoop"
- if [ ! -e $HADOOP ]; then
- echo "Can't find hadoop in $HADOOP, exiting"
- exit 1
- fi
-# Check Hadoop version
-v=`${HADOOP_HOME}/bin/hadoop version | egrep "Hadoop [0-9]+.[0-9]+.[0-9]+" | cut -f 2 -d ' ' | cut -f 1 -d '.'`
-if [ $v -eq "1" -o $v -eq "0" ]
- echo "Discovered Hadoop v0 or v1."
- export DFS="${HADOOP_HOME}/bin/hadoop dfs"
- export DFSRM="$DFS -rmr -skipTrash"
-elif [ $v -eq "2" ]
- echo "Discovered Hadoop v2."
- export DFS="${HADOOP_HOME}/bin/hdfs dfs"
- export DFSRM="$DFS -rm -r -skipTrash"
- echo "Can't determine Hadoop version."
- exit 1
-echo "Setting dfs command to $DFS, dfs rm to $DFSRM."
-export HVERSION=$v

diff --git a/examples/pom.xml b/examples/pom.xml
index 3798117..e76ff1a 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -23,177 +23,14 @@
- <version>0.13.1-SNAPSHOT</version>
+ <version>0.14.0-SNAPSHOT</version>

- <artifactId>mahout-examples</artifactId>
- <name>Mahout Examples</name>
- <description>Scalable machine learning library examples</description>
+ <artifactId>engine</artifactId>
+ <name>Mahout Engine</name>
+ <description>Apache Mahout Examples.</description>

- <properties>
- <mahout.skip.example>false</mahout.skip.example>
- </properties>
- <build>
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-dependency-plugin</artifactId>
- <executions>
- <execution>
- <id>copy-dependencies</id>
- <phase>package</phase>
- <goals>
- <goal>copy-dependencies</goal>
- </goals>
- <configuration>
- <!-- configure the plugin here -->
- </configuration>
- </execution>
- </executions>
- </plugin>

- <!-- create examples hadoop job jar -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-assembly-plugin</artifactId>
- <executions>
- <execution>
- <id>job</id>
- <phase>package</phase>
- <goals>
- <goal>single</goal>
- </goals>
- <configuration>
- <skipAssembly>${mahout.skip.example}</skipAssembly>
- <descriptors>
- <descriptor>src/main/assembly/job.xml</descriptor>
- </descriptors>
- </configuration>
- </execution>
- </executions>
- </plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-remote-resources-plugin</artifactId>
- <configuration>
- <appendedResourcesDirectory>../src/main/appended-resources</appendedResourcesDirectory>
- <resourceBundles>
- <resourceBundle>org.apache:apache-jar-resource-bundle:1.4</resourceBundle>
- </resourceBundles>
- <supplementalModels>
- <supplementalModel>supplemental-models.xml</supplementalModel>
- </supplementalModels>
- </configuration>
- </plugin>
- <plugin>
- <artifactId>maven-source-plugin</artifactId>
- </plugin>
- <plugin>
- <groupId>org.mortbay.jetty</groupId>
- <artifactId>maven-jetty-plugin</artifactId>
- <version>6.1.26</version>
- </plugin>
- </plugins>
- </build>
- <dependencies>
- <!-- our modules -->
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-hdfs</artifactId>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-mr</artifactId>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-hdfs</artifactId>
- <type>test-jar</type>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-mr</artifactId>
- <type>test-jar</type>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-math</artifactId>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-math</artifactId>
- <type>test-jar</type>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>${project.groupId}</groupId>
- <artifactId>mahout-integration</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.lucene</groupId>
- <artifactId>lucene-benchmark</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.lucene</groupId>
- <artifactId>lucene-analyzers-common</artifactId>
- </dependency>
- <dependency>
- <groupId>com.carrotsearch.randomizedtesting</groupId>
- <artifactId>randomizedtesting-runner</artifactId>
- </dependency>
- <dependency>
- <groupId>org.easymock</groupId>
- <artifactId>easymock</artifactId>
- </dependency>
- <dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
- </dependency>
- <dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </dependency>
- <dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-log4j12</artifactId>
- </dependency>
- <dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>jcl-over-slf4j</artifactId>
- </dependency>
- <dependency>
- <groupId>commons-logging</groupId>
- <artifactId>commons-logging</artifactId>
- </dependency>
- <dependency>
- <groupId>log4j</groupId>
- <artifactId>log4j</artifactId>
- </dependency>
- </dependencies>
- <profiles>
- <profile>
- <id>release.prepare</id>
- <properties>
- <mahout.skip.example>true</mahout.skip.example>
- </properties>
- </profile>
- </profiles>
\ No newline at end of file

diff --git a/examples/src/main/assembly/job.xml b/examples/src/main/assembly/job.xml
deleted file mode 100644
index 0c41f3d..0000000
--- a/examples/src/main/assembly/job.xml
+++ /dev/null
@@ -1,46 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- See the License for the specific language governing permissions and
- limitations under the License.
- xmlns="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0
- http://maven.apache.org/xsd/assembly-1.1.0.xsd">
- <id>job</id>
- <formats>
- <format>jar</format>
- </formats>
- <includeBaseDirectory>false</includeBaseDirectory>
- <dependencySets>
- <dependencySet>
- <unpack>true</unpack>
- <unpackOptions>
- <!-- MAHOUT-1126 -->
- <excludes>
- <exclude>META-INF/LICENSE</exclude>
- </excludes>
- </unpackOptions>
- <scope>runtime</scope>
- <outputDirectory>/</outputDirectory>
- <useTransitiveFiltering>true</useTransitiveFiltering>
- <excludes>
- <exclude>org.apache.hadoop:hadoop-core</exclude>
- </excludes>
- </dependencySet>
- </dependencySets>
\ No newline at end of file

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java
deleted file mode 100644
index 6392b9f..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java
+++ /dev/null
@@ -1,75 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example;
-import java.io.File;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.mahout.common.CommandLineUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
- * This class provides a common implementation for parsing input parameters for
- * all taste examples. Currently they only need the path to the recommendations
- * file as input.
- *
- * The class is safe to be used in threaded contexts.
- */
-public final class TasteOptionParser {
- private TasteOptionParser() {
- }
- /**
- * Parse the given command line arguments.
- * @param args the arguments as given to the application.
- * @return the input file if a file was given on the command line, null otherwise.
- */
- public static File getRatings(String[] args) throws OptionException {
- DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
- ArgumentBuilder abuilder = new ArgumentBuilder();
- GroupBuilder gbuilder = new GroupBuilder();
- Option inputOpt = obuilder.withLongName("input").withRequired(false).withShortName("i")
- .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
- .withDescription("The Path for input data directory.").create();
- Option helpOpt = DefaultOptionCreator.helpOption();
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(helpOpt).create();
- Parser parser = new Parser();
- parser.setGroup(group);
- CommandLine cmdLine = parser.parse(args);
- if (cmdLine.hasOption(helpOpt)) {
- CommandLineUtil.printHelp(group);
- return null;
- }
- return cmdLine.hasOption(inputOpt) ? new File(cmdLine.getValue(inputOpt).toString()) : null;
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java
deleted file mode 100644
index c908e5b..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java
+++ /dev/null
@@ -1,102 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
-import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefUserBasedRecommender;
-import org.apache.mahout.cf.taste.impl.similarity.CachingUserSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.cf.taste.similarity.UserSimilarity;
-import java.util.Collection;
-import java.util.List;
- * A simple {@link Recommender} implemented for the Book Crossing demo.
- * See the <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/">Book Crossing site</a>.
- */
-public final class BookCrossingBooleanRecommender implements Recommender {
- private final Recommender recommender;
- public BookCrossingBooleanRecommender(DataModel bcModel) throws TasteException {
- UserSimilarity similarity = new CachingUserSimilarity(new LogLikelihoodSimilarity(bcModel), bcModel);
- UserNeighborhood neighborhood =
- new NearestNUserNeighborhood(10, Double.NEGATIVE_INFINITY, similarity, bcModel, 1.0);
- recommender = new GenericBooleanPrefUserBasedRecommender(bcModel, neighborhood, similarity);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
- return recommender.recommend(userID, howMany);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
- return recommend(userID, howMany, null, includeKnownItems);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, false);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
- throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
- }
- @Override
- public float estimatePreference(long userID, long itemID) throws TasteException {
- return recommender.estimatePreference(userID, itemID);
- }
- @Override
- public void setPreference(long userID, long itemID, float value) throws TasteException {
- recommender.setPreference(userID, itemID, value);
- }
- @Override
- public void removePreference(long userID, long itemID) throws TasteException {
- recommender.removePreference(userID, itemID);
- }
- @Override
- public DataModel getDataModel() {
- return recommender.getDataModel();
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- recommender.refresh(alreadyRefreshed);
- }
- @Override
- public String toString() {
- return "BookCrossingBooleanRecommender[recommender:" + recommender + ']';
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java
deleted file mode 100644
index 2219bce..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java
+++ /dev/null
@@ -1,32 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-final class BookCrossingBooleanRecommenderBuilder implements RecommenderBuilder {
- @Override
- public Recommender buildRecommender(DataModel dataModel) throws TasteException {
- return new BookCrossingBooleanRecommender(dataModel);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java
deleted file mode 100644
index b9814c7..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java
+++ /dev/null
@@ -1,59 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import org.apache.commons.cli2.OptionException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.IRStatistics;
-import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
-import org.apache.mahout.cf.taste.example.TasteOptionParser;
-import org.apache.mahout.cf.taste.impl.eval.GenericRecommenderIRStatsEvaluator;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.File;
-import java.io.IOException;
-public final class BookCrossingBooleanRecommenderEvaluatorRunner {
- private static final Logger log = LoggerFactory.getLogger(BookCrossingBooleanRecommenderEvaluatorRunner.class);
- private BookCrossingBooleanRecommenderEvaluatorRunner() {
- // do nothing
- }
- /**
-  * Evaluates the boolean-preference BookCrossing recommender with precision/recall statistics
-  * and logs the resulting {@link IRStatistics}.
-  *
-  * @param args command line; the ratings file is taken from {@link TasteOptionParser#getRatings}.
-  *             When no file is given the bundled BookCrossing resource is used.
-  * @throws IOException if the ratings data cannot be read or converted
-  * @throws TasteException if the evaluation fails
-  * @throws OptionException if the command line cannot be parsed
-  */
- public static void main(String... args) throws IOException, TasteException, OptionException {
-   RecommenderIRStatsEvaluator evaluator = new GenericRecommenderIRStatsEvaluator();
-   File ratingsFile = TasteOptionParser.getRatings(args);
-   DataModel model =
-       ratingsFile == null ? new BookCrossingDataModel(true) : new BookCrossingDataModel(ratingsFile, true);
-   IRStatistics evaluation = evaluator.evaluate(
-       new BookCrossingBooleanRecommenderBuilder(),
-       new BookCrossingDataModelBuilder(),
-       model,
-       null, // no rescorer
-       3, // evaluate precision/recall "at 3"
-       // NOTE(review): relevance threshold restored; the evaluator takes it between "at" and the
-       // evaluation percentage. Ratings are ignored here, so no preference value should be
-       // filtered out -- confirm the intended value against the evaluator's documentation.
-       Double.NEGATIVE_INFINITY,
-       1.0); // evaluation percentage
-   log.info(String.valueOf(evaluation));
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java
deleted file mode 100644
index 3e2f8b5..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java
+++ /dev/null
@@ -1,99 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.Writer;
-import java.util.regex.Pattern;
-import com.google.common.base.Charsets;
-import com.google.common.io.Closeables;
-import org.apache.mahout.cf.taste.similarity.precompute.example.GroupLensDataModel;
-import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
-import org.apache.mahout.common.iterator.FileLineIterable;
- * See <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/BX-CSV-Dump.zip">download</a> for
- * data needed by this class. The BX-Book-Ratings.csv file is needed.
- */
-public final class BookCrossingDataModel extends FileDataModel {
- private static final Pattern NON_DIGIT_SEMICOLON_PATTERN = Pattern.compile("[^0-9;]");
- public BookCrossingDataModel(boolean ignoreRatings) throws IOException {
- this(GroupLensDataModel.readResourceToTempFile(
- "/org/apache/mahout/cf/taste/example/bookcrossing/BX-Book-Ratings.csv"),
- ignoreRatings);
- }
- /**
- * Converts the given BookCrossing ratings file with {@code convertBCFile} and hands the
- * converted file to the {@link FileDataModel} constructor.
- *
- * @param ratingsFile BookCrossing ratings file in its native format
- * @param ignoreRatings if {@code true}, the rating column is dropped from every converted line
- * @throws IOException if an error occurs while reading or writing files
- */
- public BookCrossingDataModel(File ratingsFile, boolean ignoreRatings) throws IOException {
- super(convertBCFile(ratingsFile, ignoreRatings));
- }
- /**
-  * Rewrites a BookCrossing ratings file as comma-delimited, digits-only lines in
-  * {@code taste.bookcrossing.txt} under {@code java.io.tmpdir}.
-  *
-  * <p>Lines ending in {@code "0"} (no rating) and lines that lose a whole ID during cleanup are
-  * skipped. When {@code ignoreRatings} is set, the last column is removed; lines that contain no
-  * delimiter at all are skipped in that mode instead of failing with a
-  * {@link StringIndexOutOfBoundsException}.
-  *
-  * @param originalFile ratings file in the native BookCrossing format
-  * @param ignoreRatings whether to drop the rating column from each line
-  * @return the converted file
-  * @throws FileNotFoundException if {@code originalFile} does not exist
-  * @throws IOException if reading or writing fails; the partial result file is deleted first
-  */
- private static File convertBCFile(File originalFile, boolean ignoreRatings) throws IOException {
-   if (!originalFile.exists()) {
-     throw new FileNotFoundException(originalFile.toString());
-   }
-   File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "taste.bookcrossing.txt");
-   // Best effort: remove any result left over from a previous run.
-   resultFile.delete();
-   Writer writer = null;
-   try {
-     writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8);
-     for (String line : new FileLineIterable(originalFile, true)) {
-       // 0 ratings are basically "no rating", ignore them (thanks h.9000)
-       if (line.endsWith("\"0\"")) {
-         continue;
-       }
-       // Delete anything that isn't numeric or a semicolon delimiter. Make comma the delimiter.
-       String convertedLine = NON_DIGIT_SEMICOLON_PATTERN.matcher(line)
-           .replaceAll("").replace(';', ',');
-       // If this means we deleted an entire ID -- few cases like that -- skip the line
-       if (convertedLine.contains(",,")) {
-         continue;
-       }
-       if (ignoreRatings) {
-         int lastDelimiter = convertedLine.lastIndexOf(',');
-         if (lastDelimiter < 0) {
-           // Blank or malformed line: there is no rating column to drop, and substring(0, -1)
-           // would throw. Skip it.
-           continue;
-         }
-         // drop rating
-         convertedLine = convertedLine.substring(0, lastDelimiter);
-       }
-       writer.write(convertedLine);
-       writer.write('\n');
-     }
-     writer.flush();
-   } catch (IOException ioe) {
-     // Do not leave a half-written file behind.
-     resultFile.delete();
-     throw ioe;
-   } finally {
-     Closeables.close(writer, false);
-   }
-   return resultFile;
- }
- @Override
- public String toString() {
- return "BookCrossingDataModel";
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java
deleted file mode 100644
index 9ec2eaf..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java
+++ /dev/null
@@ -1,33 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import org.apache.mahout.cf.taste.eval.DataModelBuilder;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-final class BookCrossingDataModelBuilder implements DataModelBuilder {
- @Override
- public DataModel buildDataModel(FastByIDMap<PreferenceArray> trainingData) {
- return new GenericBooleanPrefDataModel(GenericBooleanPrefDataModel.toDataMap(trainingData));
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java
deleted file mode 100644
index c06ca2f..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java
+++ /dev/null
@@ -1,101 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import java.util.Collection;
-import java.util.List;
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
-import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
-import org.apache.mahout.cf.taste.impl.similarity.CachingUserSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.EuclideanDistanceSimilarity;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.cf.taste.similarity.UserSimilarity;
- * A simple {@link Recommender} implemented for the Book Crossing demo.
- * See the <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/">Book Crossing site</a>.
- */
-public final class BookCrossingRecommender implements Recommender {
- private final Recommender recommender;
- public BookCrossingRecommender(DataModel bcModel) throws TasteException {
- UserSimilarity similarity = new CachingUserSimilarity(new EuclideanDistanceSimilarity(bcModel), bcModel);
- UserNeighborhood neighborhood = new NearestNUserNeighborhood(10, 0.2, similarity, bcModel, 0.2);
- recommender = new GenericUserBasedRecommender(bcModel, neighborhood, similarity);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
- return recommender.recommend(userID, howMany);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
- return recommend(userID, howMany, null, includeKnownItems);
- }
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, false);
- }
- /**
-  * Delegates to the wrapped recommender.
-  *
-  * @param userID user to recommend for
-  * @param howMany number of recommendations requested
-  * @param rescorer rescoring function passed through to the wrapped recommender; may be null
-  * @param includeKnownItems passed through unchanged to the wrapped recommender
-  * @return the wrapped recommender's result
-  * @throws TasteException if the wrapped recommender fails
-  */
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
-   throws TasteException {
-   // Forward the caller's flag. It used to be hard-coded to false, which silently ignored
-   // includeKnownItems here and in the three-argument boolean overload that calls this method.
-   return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
- }
- @Override
- public float estimatePreference(long userID, long itemID) throws TasteException {
- return recommender.estimatePreference(userID, itemID);
- }
- @Override
- public void setPreference(long userID, long itemID, float value) throws TasteException {
- recommender.setPreference(userID, itemID, value);
- }
- @Override
- public void removePreference(long userID, long itemID) throws TasteException {
- recommender.removePreference(userID, itemID);
- }
- @Override
- public DataModel getDataModel() {
- return recommender.getDataModel();
- }
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- recommender.refresh(alreadyRefreshed);
- }
- @Override
- public String toString() {
- return "BookCrossingRecommender[recommender:" + recommender + ']';
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java
deleted file mode 100644
index bb6d3e1..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java
+++ /dev/null
@@ -1,32 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-final class BookCrossingRecommenderBuilder implements RecommenderBuilder {
- @Override
- public Recommender buildRecommender(DataModel dataModel) throws TasteException {
- return new BookCrossingRecommender(dataModel);
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java
deleted file mode 100644
index 97074d2..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java
+++ /dev/null
@@ -1,54 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.bookcrossing;
-import java.io.File;
-import java.io.IOException;
-import org.apache.commons.cli2.OptionException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderEvaluator;
-import org.apache.mahout.cf.taste.example.TasteOptionParser;
-import org.apache.mahout.cf.taste.impl.eval.AverageAbsoluteDifferenceRecommenderEvaluator;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class BookCrossingRecommenderEvaluatorRunner {
- private static final Logger log = LoggerFactory.getLogger(BookCrossingRecommenderEvaluatorRunner.class);
- private BookCrossingRecommenderEvaluatorRunner() {
- // do nothing
- }
- public static void main(String... args) throws IOException, TasteException, OptionException {
- RecommenderEvaluator evaluator = new AverageAbsoluteDifferenceRecommenderEvaluator();
- File ratingsFile = TasteOptionParser.getRatings(args);
- DataModel model =
- ratingsFile == null ? new BookCrossingDataModel(false) : new BookCrossingDataModel(ratingsFile, false);
- double evaluation = evaluator.evaluate(new BookCrossingRecommenderBuilder(),
- null,
- model,
- 0.9,
- 0.3);
- log.info(String.valueOf(evaluation));
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README b/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README
deleted file mode 100644
index 9244fe3..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README
+++ /dev/null
@@ -1,9 +0,0 @@
-Code works with BookCrossing data set, which is not included in this distribution but is downloadable from
-http://www.informatik.uni-freiburg.de/~cziegler/BX/
-Data set originated from:
-Improving Recommendation Lists Through Topic Diversification,
- Cai-Nicolas Ziegler, Sean M. McNee, Joseph A. Konstan, Georg Lausen;
- Proceedings of the 14th International World Wide Web Conference (WWW '05), May 10-14, 2005, Chiba, Japan.
- To appear.
\ No newline at end of file

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java
deleted file mode 100644
index 033daa2..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java
+++ /dev/null
@@ -1,104 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Writable;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-import java.io.IOException;
-import java.util.regex.Pattern;
-public final class EmailUtility {
- public static final String SEPARATOR = "separator";
- public static final String MSG_IDS_PREFIX = "msgIdsPrefix";
- public static final String FROM_PREFIX = "fromPrefix";
- public static final String MSG_ID_DIMENSION = "msgIdDim";
- public static final String FROM_INDEX = "fromIdx";
- public static final String REFS_INDEX = "refsIdx";
- private static final String[] EMPTY = new String[0];
- private static final Pattern ADDRESS_CLEANUP = Pattern.compile("mailto:|<|>|\\[|\\]|\\=20");
- private static final Pattern ANGLE_BRACES = Pattern.compile("<|>");
- private static final Pattern SPACE_OR_CLOSE_ANGLE = Pattern.compile(">|\\s+");
- public static final Pattern WHITESPACE = Pattern.compile("\\s*");
- private EmailUtility() {
- }
- /**
- * Removes every match of {@code ADDRESS_CLEANUP} from the address: the text {@code mailto:},
- * the characters {@code <}, {@code >}, {@code [} and {@code ]}, and the sequence {@code =20}.
- * These spurious characters make it harder to dedup addresses.
- *
- * @param address the raw address text
- * @return the address with all such matches deleted
- */
- public static String cleanUpEmailAddress(CharSequence address) {
- // Normalizes variants of the same sender, e.g. Key: karthik ananth <***@gmail.com>: Value: 178
- // and Key: karthik ananth [mailto:***@gmail.com]=20: Value: 179
- //TODO: is there more to clean up here?
- return ADDRESS_CLEANUP.matcher(address).replaceAll("");
- }
- public static void loadDictionaries(Configuration conf, String fromPrefix,
- OpenObjectIntHashMap<String> fromDictionary,
- String msgIdPrefix,
- OpenObjectIntHashMap<String> msgIdDictionary) throws IOException {
- Path[] localFiles = HadoopUtil.getCachedFiles(conf);
- FileSystem fs = FileSystem.getLocal(conf);
- for (Path dictionaryFile : localFiles) {
- // key is word value is id
- OpenObjectIntHashMap<String> dictionary = null;
- if (dictionaryFile.getName().startsWith(fromPrefix)) {
- dictionary = fromDictionary;
- } else if (dictionaryFile.getName().startsWith(msgIdPrefix)) {
- dictionary = msgIdDictionary;
- }
- if (dictionary != null) {
- dictionaryFile = fs.makeQualified(dictionaryFile);
- for (Pair<Writable, IntWritable> record
- : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
- dictionary.put(record.getFirst().toString(), record.getSecond().get());
- }
- }
- }
- }
- public static String[] parseReferences(CharSequence rawRefs) {
- String[] splits;
- if (rawRefs != null && rawRefs.length() > 0) {
- splits = SPACE_OR_CLOSE_ANGLE.split(rawRefs);
- for (int i = 0; i < splits.length; i++) {
- splits[i] = ANGLE_BRACES.matcher(splits[i]).replaceAll("");
- }
- } else {
- splits = EMPTY;
- }
- return splits;
- }
- public enum Counters {
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java
deleted file mode 100644
index 5cd308d..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java
+++ /dev/null
@@ -1,61 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.VarIntWritable;
-import java.io.IOException;
- * Assumes the input is in the format created by {@link org.apache.mahout.text.SequenceFilesFromMailArchives}
- */
-public final class FromEmailToDictionaryMapper extends Mapper<Text, Text, Text, VarIntWritable> {
- private String separator;
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- separator = context.getConfiguration().get(EmailUtility.SEPARATOR);
- }
- @Override
- protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
- //From is in the value
- String valStr = value.toString();
- int idx = valStr.indexOf(separator);
- if (idx == -1) {
- context.getCounter(EmailUtility.Counters.NO_FROM_ADDRESS).increment(1);
- } else {
- String full = valStr.substring(0, idx);
- //do some cleanup to normalize some things, like: Key: karthik ananth <***@gmail.com>: Value: 178
- //Key: karthik ananth [mailto:***@gmail.com]=20: Value: 179
- //TODO: is there more to clean up here?
- full = EmailUtility.cleanUpEmailAddress(full);
- if (EmailUtility.WHITESPACE.matcher(full).matches()) {
- context.getCounter(EmailUtility.Counters.NO_FROM_ADDRESS).increment(1);
- } else {
- context.write(new Text(full), new VarIntWritable(1));
- }
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java
deleted file mode 100644
index 72fcde9..0000000
--- a/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java
+++ /dev/null
@@ -1,43 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.cf.taste.example.email;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.mahout.math.VarIntWritable;
-import java.io.IOException;
- * Key: the string id
- * Value: the count
- * Out Key: the string id
- * Out Value: the sum of the counts
- */
-public final class MailToDictionaryReducer extends Reducer<Text, VarIntWritable, Text, VarIntWritable> {
- @Override
- protected void reduce(Text key, Iterable<VarIntWritable> values, Context context)
- throws IOException, InterruptedException {
- int sum = 0;
- for (VarIntWritable value : values) {
- sum += value.get();
- }
- context.write(new Text(key), new VarIntWritable(sum));
- }
2018-06-27 13:14:37 UTC
diff --git a/examples/bin/resources/bank-full.csv b/examples/bin/resources/bank-full.csv
deleted file mode 100644
index d7a2ede..0000000
--- a/examples/bin/resources/bank-full.csv
+++ /dev/null
@@ -1,45212 +0,0 @@

2018-06-27 13:14:42 UTC
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java
new file mode 100644
index 0000000..bd1149b
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/IOUtils.java
@@ -0,0 +1,80 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.streaming.tools;
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+public class IOUtils {
+ private IOUtils() {}
+ /**
+ * Converts CentroidWritable values in a sequence file into Centroids lazily.
+ * Each writable is unwrapped and its centroid is cloned, so callers receive copies.
+ * @param dirIterable the source iterable (comes from a SequenceFileDirIterable); every element
+ * is passed through {@code Preconditions.checkNotNull} when it is converted.
+ * @return an {@code Iterable<Centroid>} with the converted vectors.
+ */
+ public static Iterable<Centroid> getCentroidsFromCentroidWritableIterable(
+ Iterable<CentroidWritable> dirIterable) {
+ return Iterables.transform(dirIterable, new Function<CentroidWritable, Centroid>() {
+ @Override
+ public Centroid apply(CentroidWritable input) {
+ Preconditions.checkNotNull(input);
+ return input.getCentroid().clone();
+ }
+ });
+ }
+ /**
+ * Converts ClusterWritable values in a sequence file into Centroids lazily.
+ * Each Centroid is built from a clone of the cluster's center and the cluster's total
+ * observation count. Its index comes from a counter held by the transforming function,
+ * which starts at 0 and is incremented on every conversion.
+ * @param dirIterable the source iterable (comes from a SequenceFileDirIterable); every element
+ * is passed through {@code Preconditions.checkNotNull} when it is converted.
+ * @return an {@code Iterable<Centroid>} with the converted clusters.
+ */
+ public static Iterable<Centroid> getCentroidsFromClusterWritableIterable(Iterable<ClusterWritable> dirIterable) {
+ return Iterables.transform(dirIterable, new Function<ClusterWritable, Centroid>() {
+ int numClusters = 0;
+ @Override
+ public Centroid apply(ClusterWritable input) {
+ Preconditions.checkNotNull(input);
+ return new Centroid(numClusters++, input.getValue().getCenter().clone(),
+ input.getValue().getTotalObservations());
+ }
+ });
+ }
+ /**
+ * Converts VectorWritable values in a sequence file into Vectors lazily.
+ * Each writable is unwrapped and its vector is cloned, so callers receive copies.
+ * @param dirIterable the source iterable (comes from a SequenceFileDirIterable); every element
+ * is passed through {@code Preconditions.checkNotNull} when it is converted.
+ * @return an {@code Iterable<Vector>} with the converted vectors.
+ */
+ public static Iterable<Vector> getVectorsFromVectorWritableIterable(Iterable<VectorWritable> dirIterable) {
+ return Iterables.transform(dirIterable, new Function<VectorWritable, Vector>() {
+ @Override
+ public Vector apply(VectorWritable input) {
+ Preconditions.checkNotNull(input);
+ return input.get().clone();
+ }
+ });
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java
new file mode 100644
index 0000000..083cd8c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java
@@ -0,0 +1,125 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.syntheticcontrol.canopy;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.conversion.InputDriver;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.utils.clustering.ClusterDumper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public final class Job extends AbstractJob {
+ // Directory name used for the converted input data.
+ private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
+ // Private constructor: in the visible code only main() creates an instance, to hand it to ToolRunner.
+ private Job() {
+ }
+ private static final Logger log = LoggerFactory.getLogger(Job.class);
+ /**
+  * Entry point of the canopy example. Without arguments the job runs with built-in defaults
+  * (input "testdata", output "output", Euclidean distance, t1 = 80, t2 = 55) after deleting the
+  * output directory; otherwise the arguments are handed to ToolRunner.
+  *
+  * @param args
+  *          the command line arguments, possibly empty
+  */
+ public static void main(String[] args) throws Exception {
+   if (args.length == 0) {
+     log.info("Running with default arguments");
+     Path defaultOutput = new Path("output");
+     HadoopUtil.delete(new Configuration(), defaultOutput);
+     run(new Path("testdata"), defaultOutput, new EuclideanDistanceMeasure(), 80, 55);
+     return;
+   }
+   log.info("Running with only user-supplied arguments");
+   ToolRunner.run(new Configuration(), new Job(), args);
+ }
+ /**
+  * Run the canopy clustering job on an input dataset using the given distance
+  * measure, t1 and t2 parameters. The input is first converted into vectors
+  * below the output directory, then canopy clustering is run and the resulting
+  * clusters are printed with ClusterDumper. This method does not delete the
+  * output directory itself; the callers in this class do so beforehand where
+  * requested. The clustered points are read from {@code <output>/clusteredPoints}.
+  * By default (see main), the job expects a file containing
+  * synthetic_control.data as obtained from
+  * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
+  * to reside in a directory named "testdata", and writes output to a directory
+  * named "output".
+  *
+  * @param input
+  *          the Path of the input directory
+  * @param output
+  *          the Path of the output directory
+  * @param measure
+  *          the DistanceMeasure to use
+  * @param t1
+  *          the canopy T1 threshold
+  * @param t2
+  *          the canopy T2 threshold
+  */
+ private static void run(Path input, Path output, DistanceMeasure measure,
+     double t1, double t2) throws Exception {
+   // The converted input lives in its own sub-directory of the output path.
+   Path directoryContainingConvertedInput = new Path(output,
+       DIRECTORY_CONTAINING_CONVERTED_INPUT);
+   InputDriver.runJob(input, directoryContainingConvertedInput,
+       "org.apache.mahout.math.RandomAccessSparseVector");
+   CanopyDriver.run(new Configuration(), directoryContainingConvertedInput,
+       output, measure, t1, t2, true, 0.0, false);
+   // run ClusterDumper
+   ClusterDumper clusterDumper = new ClusterDumper(new Path(output,
+       "clusters-0-final"), new Path(output, "clusteredPoints"));
+   clusterDumper.printClusters(null);
+ }
+ /**
+  * Parses the command line (input, output, distance measure, t1, t2, overwrite) and runs the
+  * canopy example with the parsed values. The output directory is deleted first when the
+  * overwrite option is present.
+  *
+  * @param args
+  *          the command line arguments
+  * @return 0 when the job was run, -1 when parseArguments returned null
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.distanceMeasureOption().create());
+   addOption(DefaultOptionCreator.t1Option().create());
+   addOption(DefaultOptionCreator.t2Option().create());
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   if (parseArguments(args) == null) {
+     return -1;
+   }
+   Path inputPath = getInputPath();
+   Path outputPath = getOutputPath();
+   if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+     HadoopUtil.delete(new Configuration(), outputPath);
+   }
+   String measureClassName = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+   double canopyT1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
+   double canopyT2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
+   run(inputPath, outputPath, ClassUtils.instantiateAs(measureClassName, DistanceMeasure.class),
+       canopyT1, canopyT2);
+   return 0;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java
new file mode 100644
index 0000000..43beb78
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java
@@ -0,0 +1,144 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.syntheticcontrol.fuzzykmeans;
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.conversion.InputDriver;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.utils.clustering.ClusterDumper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public final class Job extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(Job.class);
+ // Name of the directory, resolved against the output path, that receives the converted input.
+ private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
+ // Command line option name for the fuzziness coefficient; reuses FuzzyKMeansDriver's constant.
+ private static final String M_OPTION = FuzzyKMeansDriver.M_OPTION;
+ // Private constructor: in the visible code only main() creates an instance, to hand it to ToolRunner.
+ private Job() {
+ }
+ /**
+  * Entry point of the fuzzy k-means example. Without arguments the job runs with built-in
+  * defaults (input "testdata", output "output", Euclidean distance, t1 = 80, t2 = 55,
+  * 10 iterations, m = 2.0, convergence delta 0.5) after deleting the output directory;
+  * otherwise the arguments are handed to ToolRunner.
+  *
+  * @param args
+  *          the command line arguments, possibly empty
+  */
+ public static void main(String[] args) throws Exception {
+   if (args.length == 0) {
+     log.info("Running with default arguments");
+     Path defaultOutput = new Path("output");
+     Configuration defaultConf = new Configuration();
+     HadoopUtil.delete(defaultConf, defaultOutput);
+     run(defaultConf, new Path("testdata"), defaultOutput, new EuclideanDistanceMeasure(), 80, 55, 10, 2.0f, 0.5);
+     return;
+   }
+   log.info("Running with only user-supplied arguments");
+   ToolRunner.run(new Configuration(), new Job(), args);
+ }
+ /**
+  * Parses the command line (input, output, distance measure, convergence delta, max iterations,
+  * overwrite, t1, t2 and the fuzziness option) and runs the fuzzy k-means example with the parsed
+  * values. When no distance measure is given, SquaredEuclideanDistanceMeasure is used. The output
+  * directory is deleted first when the overwrite option is present.
+  *
+  * @param args
+  *          the command line arguments
+  * @return 0 when the job was run, -1 when parseArguments returned null
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.distanceMeasureOption().create());
+   addOption(DefaultOptionCreator.convergenceOption().create());
+   addOption(DefaultOptionCreator.maxIterationsOption().create());
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   addOption(DefaultOptionCreator.t1Option().create());
+   addOption(DefaultOptionCreator.t2Option().create());
+   // The fuzziness option is registered exactly once, before the arguments are parsed.
+   addOption(M_OPTION, M_OPTION, "coefficient normalization factor, must be greater than 1", true);
+   Map<String,List<String>> argMap = parseArguments(args);
+   if (argMap == null) {
+     return -1;
+   }
+   Path input = getInputPath();
+   Path output = getOutputPath();
+   String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+   if (measureClass == null) {
+     measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+   }
+   double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+   int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+   float fuzziness = Float.parseFloat(getOption(M_OPTION));
+   if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+     HadoopUtil.delete(getConf(), output);
+   }
+   DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+   double t1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
+   double t2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
+   run(getConf(), input, output, measure, t1, t2, maxIterations, fuzziness, convergenceDelta);
+   return 0;
+ }
+ /**
+  * Run the fuzzy k-means clustering job on an input dataset using the given distance measure, t1, t2 and
+  * iteration parameters. The input is converted into vectors, canopy clustering produces the initial clusters,
+  * fuzzy k-means is run from those clusters and the final clusters are printed with ClusterDumper. All output is
+  * written below the output directory; this method does not delete it (the callers in this class do so beforehand
+  * where requested). The clustered points are read from {@code <output>/clusteredPoints}. By default (see main),
+  * the job expects a file containing synthetic_control.data as obtained from
+  * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series to reside in a directory named
+  * "testdata", and writes output to a directory named "output".
+  *
+  * @param conf
+  *          the Configuration to use for the canopy step
+  * @param input
+  *          the Path of the input directory
+  * @param output
+  *          the Path of the output directory
+  * @param measure
+  *          the DistanceMeasure to use
+  * @param t1
+  *          the canopy T1 threshold
+  * @param t2
+  *          the canopy T2 threshold
+  * @param maxIterations
+  *          the int maximum number of iterations
+  * @param fuzziness
+  *          the float "m" fuzziness coefficient
+  * @param convergenceDelta
+  *          the double convergence criteria for iterations
+  */
+ public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, double t1, double t2,
+     int maxIterations, float fuzziness, double convergenceDelta) throws Exception {
+   Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
+   log.info("Preparing Input");
+   InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
+   log.info("Running Canopy to get initial clusters");
+   Path canopyOutput = new Path(output, "canopies");
+   // Use the caller's configuration rather than a fresh one, so settings supplied by the caller are honoured.
+   CanopyDriver.run(conf, directoryContainingConvertedInput, canopyOutput, measure, t1, t2, false, 0.0, false);
+   log.info("Running FuzzyKMeans");
+   // NOTE(review): this overload takes no Configuration; if FuzzyKMeansDriver offers one that does, prefer it.
+   FuzzyKMeansDriver.run(directoryContainingConvertedInput, new Path(canopyOutput, "clusters-0-final"), output,
+       convergenceDelta, maxIterations, fuzziness, true, true, 0.0, false);
+   // run ClusterDumper
+   ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-*-final"), new Path(output, "clusteredPoints"));
+   clusterDumper.printClusters(null);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
new file mode 100644
index 0000000..70c41fe
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
@@ -0,0 +1,187 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.syntheticcontrol.kmeans;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.conversion.InputDriver;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.utils.clustering.ClusterDumper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public final class Job extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(Job.class);
+ // Name of the directory, resolved against the output path, that receives the converted input.
+ private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
+ // Private constructor: in the visible code only main() creates an instance, to hand it to ToolRunner.
+ private Job() {
+ }
+ /**
+  * Entry point of the k-means example. Without arguments the job runs with built-in defaults
+  * (input "testdata", output "output", Euclidean distance, k = 6, convergence delta 0.5,
+  * 10 iterations) after deleting the output directory; otherwise the arguments are handed to
+  * ToolRunner.
+  *
+  * @param args
+  *          the command line arguments, possibly empty
+  */
+ public static void main(String[] args) throws Exception {
+   if (args.length == 0) {
+     log.info("Running with default arguments");
+     Path defaultOutput = new Path("output");
+     Configuration defaultConf = new Configuration();
+     HadoopUtil.delete(defaultConf, defaultOutput);
+     run(defaultConf, new Path("testdata"), defaultOutput, new EuclideanDistanceMeasure(), 6, 0.5, 10);
+     return;
+   }
+   log.info("Running with only user-supplied arguments");
+   ToolRunner.run(new Configuration(), new Job(), args);
+ }
+ /**
+  * Parses the command line and runs the k-means example. When the number-of-clusters option is
+  * present the initial clusters come from random seeding with that k; otherwise they come from
+  * canopy clustering with the t1 and t2 options. When no distance measure is given,
+  * SquaredEuclideanDistanceMeasure is used. The output directory is deleted first when the
+  * overwrite option is present.
+  *
+  * @param args
+  *          the command line arguments
+  * @return 0 when the job was run, -1 when parseArguments returned null
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.distanceMeasureOption().create());
+   addOption(DefaultOptionCreator.numClustersOption().create());
+   addOption(DefaultOptionCreator.t1Option().create());
+   addOption(DefaultOptionCreator.t2Option().create());
+   addOption(DefaultOptionCreator.convergenceOption().create());
+   addOption(DefaultOptionCreator.maxIterationsOption().create());
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   if (parseArguments(args) == null) {
+     return -1;
+   }
+   Path inputPath = getInputPath();
+   Path outputPath = getOutputPath();
+   String measureClassName = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+   if (measureClassName == null) {
+     measureClassName = SquaredEuclideanDistanceMeasure.class.getName();
+   }
+   double delta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+   int iterationLimit = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+   if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+     HadoopUtil.delete(getConf(), outputPath);
+   }
+   DistanceMeasure distance = ClassUtils.instantiateAs(measureClassName, DistanceMeasure.class);
+   if (!hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
+     // No k given: seed k-means from canopy clustering.
+     double canopyT1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
+     double canopyT2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
+     run(getConf(), inputPath, outputPath, distance, canopyT1, canopyT2, delta, iterationLimit);
+   } else {
+     int clusterCount = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+     run(getConf(), inputPath, outputPath, distance, clusterCount, delta, iterationLimit);
+   }
+   return 0;
+ }
+ /**
+  * Run the k-means clustering job on an input dataset using the given number of clusters k and iteration
+  * parameters. The input is converted into vectors, k random seeds are generated as initial clusters, k-means is
+  * run from those seeds and the final clusters are printed with ClusterDumper. All output is written below the
+  * output directory, and the clustered points are read from {@code <output>/clusteredPoints}. By default (see
+  * main), the job reads its input from a directory named "testdata" and writes output to a directory named
+  * "output".
+  *
+  * @param conf
+  *          the Configuration to use
+  * @param input
+  *          the Path of the input directory
+  * @param output
+  *          the Path of the output directory
+  * @param measure
+  *          the DistanceMeasure to use
+  * @param k
+  *          the number of clusters in Kmeans
+  * @param convergenceDelta
+  *          the double convergence criteria for iterations
+  * @param maxIterations
+  *          the int maximum number of iterations
+  */
+ public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, int k,
+     double convergenceDelta, int maxIterations) throws Exception {
+   Path convertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
+   log.info("Preparing Input");
+   InputDriver.runJob(input, convertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
+   log.info("Running random seed to get initial clusters");
+   Path seedDir = new Path(output, "random-seeds");
+   Path initialClusters = RandomSeedGenerator.buildRandom(conf, convertedInput, seedDir, k, measure);
+   log.info("Running KMeans with k = {}", k);
+   KMeansDriver.run(conf, convertedInput, initialClusters, output, convergenceDelta,
+       maxIterations, true, 0.0, false);
+   // run ClusterDumper
+   Path finalClustersGlob = new Path(output, "clusters-*-final");
+   Path pointsDir = new Path(output,"clusteredPoints");
+   log.info("Dumping out clusters from clusters: {} and clusteredPoints: {}", finalClustersGlob, pointsDir);
+   new ClusterDumper(finalClustersGlob, pointsDir).printClusters(null);
+ }
+ /**
+  * Run the k-means clustering job on an input dataset using the given distance measure, t1, t2 and iteration
+  * parameters. The input is converted into vectors, canopy clustering produces the initial clusters, k-means is
+  * run from those clusters and the final clusters are printed with ClusterDumper. All output is written below the
+  * output directory; this method does not delete it (the callers in this class do so beforehand where requested).
+  * The clustered points are read from {@code <output>/clusteredPoints}. By default (see main), the job expects a
+  * file containing synthetic_control.data as obtained from
+  * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series to reside in a directory named
+  * "testdata", and writes output to a directory named "output".
+  *
+  * @param conf
+  *          the Configuration to use
+  * @param input
+  *          the Path of the input directory
+  * @param output
+  *          the Path of the output directory
+  * @param measure
+  *          the DistanceMeasure to use
+  * @param t1
+  *          the canopy T1 threshold
+  * @param t2
+  *          the canopy T2 threshold
+  * @param convergenceDelta
+  *          the double convergence criteria for iterations
+  * @param maxIterations
+  *          the int maximum number of iterations
+  */
+ public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, double t1, double t2,
+     double convergenceDelta, int maxIterations) throws Exception {
+   Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
+   log.info("Preparing Input");
+   InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
+   log.info("Running Canopy to get initial clusters");
+   Path canopyOutput = new Path(output, "canopies");
+   // Use the caller's configuration for the canopy step too, so both steps run with the same settings.
+   CanopyDriver.run(conf, directoryContainingConvertedInput, canopyOutput, measure, t1, t2, false, 0.0,
+       false);
+   log.info("Running KMeans");
+   KMeansDriver.run(conf, directoryContainingConvertedInput, new Path(canopyOutput, Cluster.INITIAL_CLUSTERS_DIR
+       + "-final"), output, convergenceDelta, maxIterations, true, 0.0, false);
+   // run ClusterDumper
+   ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-*-final"), new Path(output,
+       "clusteredPoints"));
+   clusterDumper.printClusters(null);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java
new file mode 100644
index 0000000..92363e5
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java
@@ -0,0 +1,94 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.fpm.pfpgrowth;
+import java.io.IOException;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.fpm.pfpgrowth.dataset.KeyBasedStringTupleGrouper;
+public final class DeliciousTagsExample {
+ // Private constructor: the class only exposes the static main() entry point.
+ private DeliciousTagsExample() { }
+ /**
+  * Command line entry point. Builds the supported options (input, output, help, splitterPattern
+  * and encoding), parses the arguments, copies the parsed values plus a fixed set of grouping
+  * parameters into a Parameters object and starts KeyBasedStringTupleGrouper with it. Help is
+  * printed when the help option is given or when parsing fails with an OptionException.
+  *
+  * @param args
+  *          the command line arguments
+  */
+ public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
+ // The option and argument builders are shared by the option definitions below.
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+ Option inputDirOpt = DefaultOptionCreator.inputOption().create();
+ Option outputOpt = DefaultOptionCreator.outputOption().create();
+ Option helpOpt = DefaultOptionCreator.helpOption();
+ Option recordSplitterOpt = obuilder.withLongName("splitterPattern").withArgument(
+ abuilder.withName("splitterPattern").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Regular Expression pattern used to split given line into fields."
+ + " Default value splits comma or tab separated fields."
+ + " Default Value: \"[ ,\\t]*\\t[ ,\\t]*\" ").withShortName("regex").create();
+ Option encodingOpt = obuilder.withLongName("encoding").withArgument(
+ abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
+ "(Optional) The file encoding. Default value: UTF-8").withShortName("e").create();
+ Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(outputOpt).withOption(
+ helpOpt).withOption(recordSplitterOpt).withOption(encodingOpt).create();
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+ Parameters params = new Parameters();
+ // The split pattern is only set when given on the command line; no default is set here.
+ if (cmdLine.hasOption(recordSplitterOpt)) {
+ params.set("splitPattern", (String) cmdLine.getValue(recordSplitterOpt));
+ }
+ // The encoding falls back to UTF-8 when the option is absent.
+ String encoding = "UTF-8";
+ if (cmdLine.hasOption(encodingOpt)) {
+ encoding = (String) cmdLine.getValue(encodingOpt);
+ }
+ params.set("encoding", encoding);
+ String inputDir = (String) cmdLine.getValue(inputDirOpt);
+ String outputDir = (String) cmdLine.getValue(outputOpt);
+ params.set("input", inputDir);
+ params.set("output", outputDir);
+ // Fixed grouping/selection settings for this example; presumably field positions within the
+ // split line - confirm against KeyBasedStringTupleMapper.
+ params.set("groupingFieldCount", "2");
+ params.set("gfield0", "1");
+ params.set("gfield1", "2");
+ params.set("selectedFieldCount", "1");
+ params.set("field0", "3");
+ params.set("maxTransactionLength", "100");
+ KeyBasedStringTupleGrouper.startJob(params);
+ } catch (OptionException ex) {
+ // Arguments that cannot be parsed result in the usage help being printed.
+ CommandLineUtil.printHelp(group);
+ }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java
new file mode 100644
index 0000000..4c80a31
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java
@@ -0,0 +1,40 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.fpm.pfpgrowth.dataset;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.StringTuple;
+public class KeyBasedStringTupleCombiner extends Reducer<Text,StringTuple,Text,StringTuple> {
+ /**
+  * Merges the entries of all tuples received for a key into a single tuple without duplicates
+  * and writes that tuple out under the same key.
+  */
+ @Override
+ protected void reduce(Text key,
+     Iterable<StringTuple> values,
+     Context context) throws IOException, InterruptedException {
+   // A set drops entries that occur in more than one incoming tuple.
+   Set<String> mergedEntries = new HashSet<>();
+   for (StringTuple tuple : values) {
+     mergedEntries.addAll(tuple.getEntries());
+   }
+   StringTuple combined = new StringTuple(mergedEntries);
+   context.write(key, combined);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java
new file mode 100644
index 0000000..cd17770
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java
@@ -0,0 +1,77 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.fpm.pfpgrowth.dataset;
+import java.io.IOException;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.StringTuple;
+public final class KeyBasedStringTupleGrouper {
+ // Private constructor: the class only exposes the static startJob() method.
+ private KeyBasedStringTupleGrouper() { }
+ /**
+  * Configures and runs the MapReduce job that groups the input by key, and waits for it to finish.
+  * The parameters are handed to the tasks through the "job.parameters" configuration entry. Any
+  * existing output path is deleted before the job is started.
+  *
+  * @param params
+  *          the job parameters; "input" and "output" are read here to set up the job paths
+  * @throws IllegalStateException
+  *           if the job does not complete successfully
+  */
+ public static void startJob(Parameters params) throws IOException,
+     InterruptedException,
+     ClassNotFoundException {
+   Configuration conf = new Configuration();
+   conf.set("job.parameters", params.toString());
+   conf.set("mapred.compress.map.output", "true");
+   conf.set("mapred.output.compression.type", "BLOCK");
+   conf.set("mapred.map.output.compression.codec", "org.apache.hadoop.io.compress.GzipCodec");
+   conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+       + "org.apache.hadoop.io.serializer.WritableSerialization");
+   String input = params.get("input");
+   // Separate the fixed text from the path so the job name stays readable.
+   Job job = new Job(conf, "Generating dataset based from input " + input);
+   job.setJarByClass(KeyBasedStringTupleGrouper.class);
+   job.setMapOutputKeyClass(Text.class);
+   job.setMapOutputValueClass(StringTuple.class);
+   job.setOutputKeyClass(Text.class);
+   job.setOutputValueClass(Text.class);
+   FileInputFormat.addInputPath(job, new Path(input));
+   Path outPath = new Path(params.get("output"));
+   FileOutputFormat.setOutputPath(job, outPath);
+   HadoopUtil.delete(conf, outPath);
+   job.setInputFormatClass(TextInputFormat.class);
+   job.setMapperClass(KeyBasedStringTupleMapper.class);
+   job.setCombinerClass(KeyBasedStringTupleCombiner.class);
+   job.setReducerClass(KeyBasedStringTupleReducer.class);
+   job.setOutputFormatClass(TextOutputFormat.class);
+   boolean succeeded = job.waitForCompletion(true);
+   if (!succeeded) {
+     throw new IllegalStateException("Job failed!");
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java
new file mode 100644
index 0000000..362d1ce
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java
@@ -0,0 +1,90 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.fpm.pfpgrowth.dataset;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.regex.Pattern;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.StringTuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * Splits the line using a {@link Pattern} and outputs key as given by the groupingFields
+ *
+ */
+public class KeyBasedStringTupleMapper extends Mapper<LongWritable,Text,Text,StringTuple> {
+ private static final Logger log = LoggerFactory.getLogger(KeyBasedStringTupleMapper.class);
+ private Pattern splitter;
+ private int[] selectedFields;
+ private int[] groupingFields;
+ /**
+  * Splits one input line and emits the grouping fields as key with the selected fields as value.
+  * A line that does not split into exactly four fields is logged, counted under the
+  * {@code Map}/{@code ERROR} counter and skipped.
+  */
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+   String line = value.toString();
+   String[] columns = splitter.split(line);
+   if (columns.length != 4) {
+     log.info("{} {}", columns.length, line);
+     context.getCounter("Map", "ERROR").increment(1);
+     return;
+   }
+
+   // The key text is the list rendering of the grouping columns, e.g. "[a, b]".
+   Collection<String> keyParts = new ArrayList<>();
+   for (int index : groupingFields) {
+     String part = columns[index];
+     keyParts.add(part);
+     context.setStatus(part);
+   }
+
+   List<String> valueParts = new ArrayList<>();
+   for (int index : selectedFields) {
+     valueParts.add(columns[index]);
+   }
+
+   Text outKey = new Text(keyParts.toString());
+   context.write(outKey, new StringTuple(valueParts));
+ }
+ /**
+  * Reads the split pattern and the selected/grouping field indexes from the job parameters.
+  *
+  * <p>{@code selectedFieldCount} and {@code groupingFieldCount} give the number of indexes, which
+  * are then looked up as {@code field<i>} and {@code gfield<i>}; a missing entry defaults to 0.
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   super.setup(context);
+   Parameters params = new Parameters(context.getConfiguration().get("job.parameters", ""));
+   splitter = Pattern.compile(params.get("splitPattern", "[ \t]*\t[ \t]*"));
+   selectedFields = readFieldIndexes(params, "selectedFieldCount", "field");
+   groupingFields = readFieldIndexes(params, "groupingFieldCount", "gfield");
+ }
+
+ /** Reads as many indexes as {@code countKey} announces from the entries {@code prefix<i>}. */
+ private static int[] readFieldIndexes(Parameters params, String countKey, String prefix) {
+   // parseInt yields the primitive directly instead of boxing through Integer.valueOf.
+   int count = Integer.parseInt(params.get(countKey, "0"));
+   int[] indexes = new int[count];
+   for (int i = 0; i < count; i++) {
+     indexes[i] = Integer.parseInt(params.get(prefix + i, "0"));
+   }
+   return indexes;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java
new file mode 100644
index 0000000..a7ef762
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java
@@ -0,0 +1,74 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.fpm.pfpgrowth.dataset;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.StringTuple;
+public class KeyBasedStringTupleReducer extends Reducer<Text,StringTuple,Text,Text> {
+ private int maxTransactionLength = 100;
+ /**
+  * Merges all tuples of a key into one set of distinct items and writes them as tab-separated
+  * lines. Keys with fewer than two distinct items produce no output; a new output line is started
+  * every {@code maxTransactionLength} items. Records are written with a null key.
+  */
+ @Override
+ protected void reduce(Text key, Iterable<StringTuple> values, Context context)
+   throws IOException, InterruptedException {
+   Collection<String> distinctItems = new HashSet<>();
+   for (StringTuple tuple : values) {
+     for (String entry : tuple.getEntries()) {
+       distinctItems.add(entry);
+     }
+   }
+   if (distinctItems.size() <= 1) {
+     return;
+   }
+
+   StringBuilder line = new StringBuilder();
+   int position = 0;
+   for (String item : distinctItems) {
+     boolean startsNewLine = position % maxTransactionLength == 0;
+     if (startsNewLine) {
+       // Flush the previous full line, except before the very first item.
+       if (position > 0) {
+         context.write(null, new Text(line.toString()));
+       }
+       line.setLength(0);
+     } else {
+       line.append("\t");
+     }
+     line.append(item);
+     position++;
+   }
+   if (line.length() > 0) {
+     context.write(null, new Text(line.toString()));
+   }
+ }
+ /**
+  * Reads {@code maxTransactionLength} (default 100) from the job parameters.
+  *
+  * @throws IllegalArgumentException if the configured length is 0, which would otherwise only
+  *         surface as a division by zero inside {@code reduce}
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   super.setup(context);
+   Parameters params = new Parameters(context.getConfiguration().get("job.parameters", ""));
+   // parseInt yields the primitive directly instead of boxing through Integer.valueOf.
+   int configured = Integer.parseInt(params.get("maxTransactionLength", "100"));
+   if (configured == 0) {
+     throw new IllegalArgumentException("maxTransactionLength must not be 0");
+   }
+   maxTransactionLength = configured;
+ }
2018-06-27 13:14:44 UTC
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
new file mode 100644
index 0000000..b2ce8b1
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
@@ -0,0 +1,236 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.GroupedOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
+ private AdaptiveLogisticRegression alr;
+ private int interval = 800;
+ private int averageWindow = 500;
+ private int threads = 4;
+ private String prior = "L1";
+ private double priorOption = Double.NaN;
+ private String auc = null;
+ public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
+ if (alr == null) {
+ alr = new AdaptiveLogisticRegression(getMaxTargetCategories(),
+ getNumFeatures(), createPrior(prior, priorOption));
+ alr.setInterval(interval);
+ alr.setAveragingWindow(averageWindow);
+ alr.setThreadCount(threads);
+ alr.setAucEvaluator(createAUC(auc));
+ }
+ return alr;
+ }
+ public void checkParameters() {
+ if (prior != null) {
+ String priorUppercase = prior.toUpperCase(Locale.ENGLISH).trim();
+ if (("TP".equals(priorUppercase) || "EBP".equals(priorUppercase)) && Double.isNaN(priorOption)) {
+ throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
+ }
+ }
+ }
+ /**
+  * Maps a prior name (case-insensitive, surrounding whitespace ignored) to a prior function.
+  *
+  * @param cmd one of {@code L1}, {@code L2}, {@code UP}, {@code TP} or {@code EBP}; may be null
+  * @param priorOption argument passed to the {@code TP} and {@code EBP} priors
+  * @return the matching prior, or null when {@code cmd} is null or not recognized
+  */
+ private static PriorFunction createPrior(String cmd, double priorOption) {
+   if (cmd == null) {
+     return null;
+   }
+   String name = cmd.toUpperCase(Locale.ENGLISH).trim();
+   switch (name) {
+     case "L1":
+       return new L1();
+     case "L2":
+       return new L2();
+     case "UP":
+       return new UniformPrior();
+     case "TP":
+       return new TPrior(priorOption);
+     case "EBP":
+       return new ElasticBandPrior(priorOption);
+     default:
+       return null;
+   }
+ }
+ /**
+  * Maps an AUC evaluator name (case-insensitive, surrounding whitespace ignored) to an evaluator.
+  *
+  * @param cmd {@code GLOBAL} or {@code GROUPED}; may be null
+  * @return the matching evaluator, or null when {@code cmd} is null or not recognized
+  */
+ private static OnlineAuc createAUC(String cmd) {
+   if (cmd == null) {
+     return null;
+   }
+   String name = cmd.toUpperCase(Locale.ENGLISH).trim();
+   switch (name) {
+     case "GLOBAL":
+       return new GlobalOnlineAuc();
+     case "GROUPED":
+       return new GroupedOnlineAuc();
+     default:
+       return null;
+   }
+ }
+ @Override
+ public void saveTo(OutputStream out) throws IOException {
+ if (alr != null) {
+ alr.close();
+ }
+ setTargetCategories(getCsvRecordFactory().getTargetCategories());
+ write(new DataOutputStream(out));
+ }
+ /**
+  * Serializes the model parameters followed by the regression itself.
+  *
+  * <p>Field order: target variable, type map, feature count, maximum target categories, target
+  * categories, interval, averaging window, threads, prior name, prior option, AUC name, regression.
+  * A null prior or AUC name is written as the empty string; {@code createPrior} and
+  * {@code createAUC} return null for the empty string just as they do for null.
+  */
+ @Override
+ public void write(DataOutput out) throws IOException {
+   out.writeUTF(getTargetVariable());
+   out.writeInt(getTypeMap().size());
+   for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
+     out.writeUTF(entry.getKey());
+     out.writeUTF(entry.getValue());
+   }
+   out.writeInt(getNumFeatures());
+   out.writeInt(getMaxTargetCategories());
+   out.writeInt(getTargetCategories().size());
+   for (String category : getTargetCategories()) {
+     out.writeUTF(category);
+   }
+
+   out.writeInt(interval);
+   out.writeInt(averageWindow);
+   out.writeInt(threads);
+   // writeUTF throws NullPointerException on null; auc is null unless setAuc was called.
+   out.writeUTF(prior == null ? "" : prior);
+   out.writeDouble(priorOption);
+   out.writeUTF(auc == null ? "" : auc);
+
+   // skip csv
+   alr.write(out);
+ }
+ /**
+  * Restores the parameters and the regression from the layout produced by {@code write}.
+  */
+ @Override
+ public void readFields(DataInput in) throws IOException {
+   setTargetVariable(in.readUTF());
+
+   int typeCount = in.readInt();
+   Map<String, String> types = new HashMap<>(typeCount);
+   for (int remaining = typeCount; remaining > 0; remaining--) {
+     String name = in.readUTF();
+     types.put(name, in.readUTF());
+   }
+   setTypeMap(types);
+
+   setNumFeatures(in.readInt());
+   setMaxTargetCategories(in.readInt());
+
+   int categoryCount = in.readInt();
+   List<String> categories = new ArrayList<>(categoryCount);
+   for (int remaining = categoryCount; remaining > 0; remaining--) {
+     categories.add(in.readUTF());
+   }
+   setTargetCategories(categories);
+
+   interval = in.readInt();
+   averageWindow = in.readInt();
+   threads = in.readInt();
+   prior = in.readUTF();
+   priorOption = in.readDouble();
+   auc = in.readUTF();
+
+   // The regression restores its own state from the remainder of the stream.
+   alr = new AdaptiveLogisticRegression();
+   alr.readFields(in);
+ }
+ /** Reads a new parameter object from the stream; the stream is not closed here. */
+ private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
+   AdaptiveLogisticModelParameters loaded = new AdaptiveLogisticModelParameters();
+   DataInput source = new DataInputStream(in);
+   loaded.readFields(source);
+   return loaded;
+ }
+ public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
+ try (InputStream input = new FileInputStream(in)) {
+ return loadFromStream(input);
+ }
+ }
+ public int getInterval() { // accessors for the tuning values applied in createAdaptiveLogisticRegression()
+ return interval;
+ }
+ public void setInterval(int interval) { // passed to alr.setInterval when the regression is created
+ this.interval = interval;
+ }
+ public int getAverageWindow() {
+ return averageWindow;
+ }
+ public void setAverageWindow(int averageWindow) { // passed to alr.setAveragingWindow when the regression is created
+ this.averageWindow = averageWindow;
+ }
+ public int getThreads() {
+ return threads;
+ }
+ public void setThreads(int threads) { // passed to alr.setThreadCount when the regression is created
+ this.threads = threads;
+ }
+ public String getPrior() {
+ return prior;
+ }
+ public void setPrior(String prior) { // name resolved by createPrior: L1, L2, UP, TP or EBP
+ this.prior = prior;
+ }
+ public String getAuc() {
+ return auc;
+ }
+ public void setAuc(String auc) { // name resolved by createAUC: GLOBAL or GROUPED
+ this.auc = auc;
+ }
+ public double getPriorOption() {
+ return priorOption;
+ }
+ public void setPriorOption(double priorOption) { // argument for the TP and EBP priors; NaN means "not set" in checkParameters
+ this.priorOption = priorOption;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
new file mode 100644
index 0000000..e762924
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
@@ -0,0 +1,265 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.io.Writable;
+/**
+ * Encapsulates everything we need to know about a model and how it reads and vectorizes its input.
+ * This encapsulation allows us to coherently save and restore a model from a file. This also
+ * allows us to keep command line arguments that affect learning in a coherent way.
+ */
+public class LogisticModelParameters implements Writable {
+ private String targetVariable;
+ private Map<String, String> typeMap;
+ private int numFeatures;
+ private boolean useBias;
+ private int maxTargetCategories;
+ private List<String> targetCategories;
+ private double lambda;
+ private double learningRate;
+ private CsvRecordFactory csv;
+ private OnlineLogisticRegression lr;
+ /**
+  * Returns the CsvRecordFactory that matches this model, creating it on first use. The factory
+  * lives here so that the list of target categories is available when the model is saved. For
+  * non-CSV input it is enough to call setTargetCategories before saveTo.
+  *
+  * @return The CsvRecordFactory.
+  */
+ public CsvRecordFactory getCsvRecordFactory() {
+   if (csv != null) {
+     return csv;
+   }
+   csv = new CsvRecordFactory(getTargetVariable(), getTypeMap())
+       .maxTargetValue(getMaxTargetCategories())
+       .includeBiasTerm(useBias());
+   // Categories restored from a saved model are handed to the new factory.
+   if (targetCategories != null) {
+     csv.defineTargetCategories(targetCategories);
+   }
+   return csv;
+ }
+ /**
+  * Returns the logistic regression trainer for the parameters collected here, allocating it with
+  * an L1 prior on first use and returning the same instance afterwards.
+  *
+  * @return The OnlineLogisticRegression object
+  */
+ public OnlineLogisticRegression createRegression() {
+   if (lr != null) {
+     return lr;
+   }
+   lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
+       .lambda(getLambda())
+       .learningRate(getLearningRate())
+       .alpha(1 - 1.0e-3);
+   return lr;
+ }
+ /**
+ * Saves a model to an output stream.
+ */
+ public void saveTo(OutputStream out) throws IOException {
+ Closeables.close(lr, false);
+ targetCategories = getCsvRecordFactory().getTargetCategories();
+ write(new DataOutputStream(out));
+ }
+ /**
+  * Reads a model from a stream. The stream is not closed by this method.
+  */
+ public static LogisticModelParameters loadFrom(InputStream in) throws IOException {
+   LogisticModelParameters loaded = new LogisticModelParameters();
+   DataInput source = new DataInputStream(in);
+   loaded.readFields(source);
+   return loaded;
+ }
+ /**
+ * Reads a model from a file.
+ * @throws IOException If there is an error opening or closing the file.
+ */
+ public static LogisticModelParameters loadFrom(File in) throws IOException {
+ try (InputStream input = new FileInputStream(in)) {
+ return loadFrom(input);
+ }
+ }
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeUTF(targetVariable);
+ out.writeInt(typeMap.size());
+ for (Map.Entry<String,String> entry : typeMap.entrySet()) {
+ out.writeUTF(entry.getKey());
+ out.writeUTF(entry.getValue());
+ }
+ out.writeInt(numFeatures);
+ out.writeBoolean(useBias);
+ out.writeInt(maxTargetCategories);
+ if (targetCategories == null) {
+ out.writeInt(0);
+ } else {
+ out.writeInt(targetCategories.size());
+ for (String category : targetCategories) {
+ out.writeUTF(category);
+ }
+ }
+ out.writeDouble(lambda);
+ out.writeDouble(learningRate);
+ // skip csv
+ lr.write(out);
+ }
+ /**
+  * Restores the model from the layout produced by {@code write}. The cached CSV record factory
+  * is dropped and a fresh regression is read from the remainder of the stream.
+  */
+ @Override
+ public void readFields(DataInput in) throws IOException {
+   targetVariable = in.readUTF();
+
+   int typeCount = in.readInt();
+   typeMap = new HashMap<>(typeCount);
+   for (int remaining = typeCount; remaining > 0; remaining--) {
+     String name = in.readUTF();
+     typeMap.put(name, in.readUTF());
+   }
+
+   numFeatures = in.readInt();
+   useBias = in.readBoolean();
+   maxTargetCategories = in.readInt();
+
+   int categoryCount = in.readInt();
+   targetCategories = new ArrayList<>(categoryCount);
+   for (int remaining = categoryCount; remaining > 0; remaining--) {
+     targetCategories.add(in.readUTF());
+   }
+
+   lambda = in.readDouble();
+   learningRate = in.readDouble();
+
+   csv = null;
+   lr = new OnlineLogisticRegression();
+   lr.readFields(in);
+ }
+ /**
+  * Sets the types of the predictors. This will later be used when reading CSV data. If you don't
+  * use the CSV data and convert to vectors on your own, you don't need to call this.
+  *
+  * <p>The type list may be shorter than the predictor list; the last type is then repeated for
+  * the remaining predictors.
+  *
+  * @param predictorList The list of variable names; no element may be null.
+  * @param typeList The list of types in the format preferred by CsvRecordFactory; must not be empty.
+  * @throws IllegalArgumentException if {@code typeList} is empty
+  * @throws NullPointerException if a predictor name is null
+  */
+ public void setTypeMap(Iterable<String> predictorList, List<String> typeList) {
+   Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier");
+   typeMap = new HashMap<>();
+   Iterator<String> iTypes = typeList.iterator();
+   String lastType = null;
+   // Iterate with the element type instead of Object plus toString().
+   for (String predictor : predictorList) {
+     // type list can be short .. we just repeat last spec
+     if (iTypes.hasNext()) {
+       lastType = iTypes.next();
+     }
+     // Keeps the NullPointerException that toString() on a null element used to raise.
+     Preconditions.checkNotNull(predictor, "Predictor names must not be null");
+     typeMap.put(predictor, lastType);
+   }
+ }
+ /**
+ * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant.
+ *
+ * @param targetVariable The name of the target variable.
+ */
+ public void setTargetVariable(String targetVariable) {
+ this.targetVariable = targetVariable;
+ }
+ /**
+ * Sets the number of target categories to be considered.
+ *
+ * @param maxTargetCategories The number of target categories.
+ */
+ public void setMaxTargetCategories(int maxTargetCategories) {
+ this.maxTargetCategories = maxTargetCategories;
+ }
+ public void setNumFeatures(int numFeatures) { // feature count passed to the regression in createRegression()
+ this.numFeatures = numFeatures;
+ }
+ public void setTargetCategories(List<String> targetCategories) { // also resets maxTargetCategories to the list size
+ this.targetCategories = targetCategories;
+ maxTargetCategories = targetCategories.size(); // throws NullPointerException for a null list
+ }
+ public List<String> getTargetCategories() { // returns the stored list itself, not a copy
+ return this.targetCategories;
+ }
+ public void setUseBias(boolean useBias) { // forwarded to CsvRecordFactory.includeBiasTerm in getCsvRecordFactory()
+ this.useBias = useBias;
+ }
+ public boolean useBias() {
+ return useBias;
+ }
+ public String getTargetVariable() {
+ return targetVariable;
+ }
+ public Map<String, String> getTypeMap() { // returns the stored map itself, not a copy
+ return typeMap;
+ }
+ public void setTypeMap(Map<String, String> map) { // stores the given map without copying it
+ this.typeMap = map;
+ }
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+ public int getMaxTargetCategories() {
+ return maxTargetCategories;
+ }
+ public double getLambda() {
+ return lambda;
+ }
+ public void setLambda(double lambda) { // applied via lambda(...) when createRegression() builds the regression
+ this.lambda = lambda;
+ }
+ public double getLearningRate() {
+ return learningRate;
+ }
+ public void setLearningRate(double learningRate) { // applied via learningRate(...) when createRegression() builds the regression
+ this.learningRate = learningRate;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
new file mode 100644
index 0000000..3ec6a06
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
@@ -0,0 +1,42 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.base.Preconditions;
+import java.io.BufferedReader;
+/**
+ * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead
+ * of processing the input, this class just prints the input to standard out.
+ */
+public final class PrintResourceOrFile {
+ private PrintResourceOrFile() {
+ }
+ /**
+  * Prints every line of the named file or resource to standard out.
+  *
+  * @param args exactly one element naming the file or resource
+  */
+ public static void main(String[] args) throws Exception {
+   Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource.");
+   String source = args[0];
+   try (BufferedReader reader = TrainLogistic.open(source)) {
+     for (String line = reader.readLine(); line != null; line = reader.readLine()) {
+       System.out.println(line);
+     }
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
new file mode 100644
index 0000000..678a8f5
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
@@ -0,0 +1,197 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.HashMap;
+import java.util.Map;
+public final class RunAdaptiveLogistic {
+ private static String inputFile;
+ private static String modelFile;
+ private static String outputFile;
+ private static String idColumn;
+ private static boolean maxScoreOnly;
+ private RunAdaptiveLogistic() {
+ }
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ /**
+  * Scores every record of the input file with the best learner of the stored model.
+  *
+  * <p>The output file starts with the header {@code <idcolumn>,target,score} and receives one
+  * line per reported label: only the top-scoring label when {@code maxScoreOnly} is set,
+  * otherwise every label. Progress is reported on {@code output} every 100 records.
+  *
+  * @param args command line arguments as understood by {@code parseArgs}
+  * @param output receives progress and status messages
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+       .loadFromFile(new File(modelFile));
+
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   csv.setIdName(idColumn);
+
+   AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+   State<Wrapper, CrossFoldLearner> best = lr.getBest();
+   if (best == null) {
+     output.println("AdaptiveLogisticRegression has not be trained probably.");
+     return;
+   }
+   CrossFoldLearner learner = best.getPayload().getLearner();
+
+   int k = 0;
+   // The input reader is now part of the try-with-resources; it was previously never closed.
+   try (BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
+        BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile),
+            Charsets.UTF_8))) {
+     out.write(idColumn + ",target,score");
+     out.newLine();
+
+     // The first line is handed to the record factory and is not scored.
+     String line = in.readLine();
+     csv.firstLine(line);
+     line = in.readLine();
+     Map<String, Double> results = new HashMap<>();
+     while (line != null) {
+       Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+       csv.processLine(line, v, false);
+       Vector scores = learner.classifyFull(v);
+       results.clear();
+       if (maxScoreOnly) {
+         results.put(csv.getTargetLabel(scores.maxValueIndex()),
+             scores.maxValue());
+       } else {
+         for (int i = 0; i < scores.size(); i++) {
+           results.put(csv.getTargetLabel(i), scores.get(i));
+         }
+       }
+
+       for (Map.Entry<String, Double> entry : results.entrySet()) {
+         out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
+         out.newLine();
+       }
+       k++;
+       if (k % 100 == 0) {
+         output.println(k + " records processed");
+       }
+       line = in.readLine();
+     }
+     out.flush();
+   }
+   output.println(k + " records processed totally.");
+ }
+ private static boolean parseArgs(String[] args) { // parses the command line into the static fields; returns false when no command line was produced
+ DefaultOptionBuilder builder = new DefaultOptionBuilder(); // one builder instance is reused for every option below
+ Option help = builder.withLongName("help")
+ .withDescription("print this list").create();
+ Option quiet = builder.withLongName("quiet") // accepted, but its value is not read in this method
+ .withDescription("be extra quiet").create();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder(); // likewise shared by all options that take a value
+ Option inputFileOption = builder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("input").withMaximum(1)
+ .create())
+ .withDescription("where to get training data").create();
+ Option modelFileOption = builder
+ .withLongName("model")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("model").withMaximum(1)
+ .create())
+ .withDescription("where to get the trained model").create();
+ Option outputFileOption = builder
+ .withLongName("output")
+ .withRequired(true)
+ .withDescription("the file path to output scores")
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .create();
+ Option idColumnOption = builder
+ .withLongName("idcolumn")
+ .withRequired(true)
+ .withDescription("the name of the id column for each record")
+ .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
+ .create();
+ Option maxScoreOnlyOption = builder // flag without a value; only its presence is checked
+ .withLongName("maxscoreonly")
+ .withDescription("only output the target label with max scores")
+ .create();
+ Group normalArgs = new GroupBuilder()
+ .withOption(help).withOption(quiet)
+ .withOption(inputFileOption).withOption(modelFileOption)
+ .withOption(outputFileOption).withOption(idColumnOption)
+ .withOption(maxScoreOnlyOption)
+ .create();
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+ if (cmdLine == null) { // NOTE(review): presumably null after printing help or on a parse error - confirm against commons-cli2
+ return false;
+ }
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ outputFile = getStringArgument(cmdLine, outputFileOption);
+ idColumn = getStringArgument(cmdLine, idColumnOption);
+ maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
+ return true;
+ }
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { // true when the flag is present on the command line
+ return cmdLine.hasOption(option);
+ }
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) { // value of the option cast to String; the parameter shadows the static inputFile field
+ return (String) cmdLine.getValue(inputFile);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
new file mode 100644
index 0000000..2d57016
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
@@ -0,0 +1,163 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.Locale;
+public final class RunLogistic {
+ // Settings populated by parseArgs(); they are static because the tool is driven from static methods.
+ // Path given with --input.
+ private static String inputFile;
+ // Path given with --model.
+ private static String modelFile;
+ // Output selectors set from the --auc, --scores and --confusion flags.
+ private static boolean showAuc;
+ private static boolean showScores;
+ private static boolean showConfusion;
+ // Not instantiable: every entry point of this class is static.
+ private RunLogistic() {
+ }
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ /**
+  * Loads a saved logistic model, scores every data line of the input file and writes the
+  * requested reports to {@code output}.
+  * <p>
+  * When none of the auc/confusion/scores flags is given, AUC and the confusion matrix are printed.
+  * Nothing is done when argument parsing fails or help was requested.
+  *
+  * @param args   command-line arguments, interpreted by {@code parseArgs}
+  * @param output destination for scores and summary statistics
+  * @throws Exception if the model or the input cannot be read
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   if (!showAuc && !showConfusion && !showScores) {
+     showAuc = true;
+     showConfusion = true;
+   }
+   Auc collector = new Auc();
+   LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   OnlineLogisticRegression lr = lmp.createRegression();
+   // try-with-resources: the reader used to be left open, leaking the file handle.
+   try (BufferedReader in = TrainLogistic.open(inputFile)) {
+     // The first line is handed to the record factory as the header; data starts on the next line.
+     String line = in.readLine();
+     csv.firstLine(line);
+     line = in.readLine();
+     if (showScores) {
+       output.println("\"target\",\"model-output\",\"log-likelihood\"");
+     }
+     while (line != null) {
+       Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+       int target = csv.processLine(line, v);
+       double score = lr.classifyScalar(v);
+       if (showScores) {
+         output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
+       }
+       collector.add(target, score);
+       line = in.readLine();
+     }
+   }
+   if (showAuc) {
+     output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
+   }
+   if (showConfusion) {
+     Matrix m = collector.confusion();
+     output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",
+         m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+     m = collector.entropy();
+     output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",
+         m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+   }
+ }
+ /**
+  * Declares the supported options, parses {@code args} and stores the results in the static
+  * configuration fields.
+  *
+  * @param args raw command-line arguments
+  * @return {@code true} when a command line was obtained and the fields were set; {@code false}
+  *         when {@code parseAndHelp} returned {@code null}
+  */
+ private static boolean parseArgs(String[] args) {
+   DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+   Option helpOption = optionBuilder.withLongName("help").withDescription("print this list").create();
+   Option quietOption = optionBuilder.withLongName("quiet").withDescription("be extra quiet").create();
+   Option aucOption = optionBuilder.withLongName("auc").withDescription("print AUC").create();
+   Option confusionOption =
+       optionBuilder.withLongName("confusion").withDescription("print confusion matrix").create();
+   Option scoresOption = optionBuilder.withLongName("scores").withDescription("print scores").create();
+
+   ArgumentBuilder argBuilder = new ArgumentBuilder();
+   Option inputOption = optionBuilder.withLongName("input")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("input").withMaximum(1).create())
+       .withDescription("where to get training data")
+       .create();
+   Option modelOption = optionBuilder.withLongName("model")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("model").withMaximum(1).create())
+       .withDescription("where to get a model")
+       .create();
+
+   Group allOptions = new GroupBuilder()
+       .withOption(helpOption)
+       .withOption(quietOption)
+       .withOption(aucOption)
+       .withOption(scoresOption)
+       .withOption(confusionOption)
+       .withOption(inputOption)
+       .withOption(modelOption)
+       .create();
+
+   Parser cliParser = new Parser();
+   cliParser.setHelpOption(helpOption);
+   cliParser.setHelpTrigger("--help");
+   cliParser.setGroup(allOptions);
+   cliParser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+
+   CommandLine parsed = cliParser.parseAndHelp(args);
+   if (parsed != null) {
+     inputFile = getStringArgument(parsed, inputOption);
+     modelFile = getStringArgument(parsed, modelOption);
+     showAuc = getBooleanArgument(parsed, aucOption);
+     showScores = getBooleanArgument(parsed, scoresOption);
+     showConfusion = getBooleanArgument(parsed, confusionOption);
+     return true;
+   }
+   return false;
+ }
+ /**
+  * Reports whether a flag-style option was supplied on the command line.
+  *
+  * @param cmdLine the parsed command line
+  * @param option  the option to look for
+  * @return the result of {@code cmdLine.hasOption(option)}
+  */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+   boolean present = cmdLine.hasOption(option);
+   return present;
+ }
+ /**
+  * Fetches the value bound to an option and returns it as a string.
+  *
+  * @param cmdLine   the parsed command line
+  * @param inputFile the option whose value is wanted (any option, despite the parameter name)
+  * @return the option's value cast to {@code String}
+  */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+   Object value = cmdLine.getValue(inputFile);
+   return (String) value;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
new file mode 100644
index 0000000..c657803
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
@@ -0,0 +1,151 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeMap;
+public final class SGDHelper {
+ // Labels printed by analyzeState(), selected with leakType % 3.
+ private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
+ // Utility class with static methods only; not instantiable.
+ private SGDHelper() {
+ }
+ /**
+  * Feeds a random sample of at most 500 files through a {@link ModelDissector} using the best
+  * learner found so far, then prints the 100 top-ranked feature weights to standard out.
+  *
+  * @param leakType          passed through to {@code NewsgroupHelper.encodeFeatureVector}
+  * @param dictionary        maps newsgroup (parent directory) names to category indexes
+  * @param learningAlgorithm trained learner whose best model is dissected
+  * @param files             candidate files; each file's parent directory name is its label
+  * @param overallCounts     multiset handed to the feature encoder
+  * @throws IOException if a file cannot be encoded
+  */
+ public static void dissect(int leakType,
+     Dictionary dictionary,
+     AdaptiveLogisticRegression learningAlgorithm,
+     Iterable<File> files, Multiset<String> overallCounts) throws IOException {
+   CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
+   model.close();
+   Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
+   ModelDissector md = new ModelDissector();
+   NewsgroupHelper helper = new NewsgroupHelper();
+   helper.getEncoder().setTraceDictionary(traceDictionary);
+   helper.getBias().setTraceDictionary(traceDictionary);
+   // Cap the sample at the number of files available; an unconditional subList(0, 500)
+   // throws IndexOutOfBoundsException for smaller inputs.
+   List<File> shuffled = permute(files, helper.getRandom());
+   int sampleSize = Math.min(500, shuffled.size());
+   for (File file : shuffled.subList(0, sampleSize)) {
+     String ng = file.getParentFile().getName();
+     int actual = dictionary.intern(ng);
+     // The trace dictionary is per-document: reset it before encoding the next file.
+     traceDictionary.clear();
+     Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+     md.update(v, traceDictionary, model);
+   }
+   List<String> ngNames = new ArrayList<>(dictionary.values());
+   List<ModelDissector.Weight> weights = md.summary(100);
+   System.out.println("============");
+   System.out.println("Model Dissection");
+   for (ModelDissector.Weight w : weights) {
+     // NOTE(review): the "+ 1" offset on getMaxImpact() is kept from the original; confirm that
+     // it cannot index past the end of ngNames.
+     System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
+         w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
+         w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
+   }
+ }
+ /**
+  * Returns the given files in random order, built incrementally: each incoming file is placed
+  * at a random slot and the element previously in that slot moves to the end.
+  *
+  * @param files files to shuffle; iterated exactly once
+  * @param rand  source of randomness; one {@code nextInt} call is made per file
+  * @return a new list containing every input file exactly once
+  */
+ public static List<File> permute(Iterable<File> files, Random rand) {
+   List<File> shuffled = new ArrayList<>();
+   for (File incoming : files) {
+     int slot = rand.nextInt(shuffled.size() + 1);
+     if (slot < shuffled.size()) {
+       // Evict the current occupant to the tail and take its place.
+       shuffled.add(shuffled.get(slot));
+       shuffled.set(slot, incoming);
+     } else {
+       shuffled.add(incoming);
+     }
+   }
+   return shuffled;
+ }
+ /**
+  * Refreshes the running statistics in {@code info} from the best learner and, on selected
+  * iterations, saves the model to the temp directory and prints a tab-separated progress line.
+  * <p>
+  * A report is produced when {@code k} is a multiple of {@code bump * scale}, where {@code bump}
+  * cycles through {@code info.getBumps()} and {@code scale} is a power of ten, both derived from
+  * {@code info.getStep()}. Each report advances the step by 0.25.
+  *
+  * @param info     accumulator for step and averaged metrics; updated in place
+  * @param leakType selects the leak label printed at the end of the line (modulo 3)
+  * @param k        number of examples processed so far
+  * @param best     best state so far, or {@code null} if none; when null all coefficient
+  *                 statistics are reported as zero and no model file is written
+  * @throws IOException if the model cannot be written
+  */
+ static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper,
+     CrossFoldLearner> best) throws IOException {
+   double step = info.getStep();
+   int[] bumps = info.getBumps();
+   int bump = bumps[(int) Math.floor(step) % bumps.length];
+   int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
+
+   // Defaults used when there is no best state yet.
+   double maxBeta = 0;
+   double nonZeros = 0;
+   double positive = 0;
+   double norm = 0;
+   double lambda = 0;
+   double mu = 0;
+
+   if (best != null) {
+     CrossFoldLearner learner = best.getPayload().getLearner();
+     info.setAverageCorrect(learner.percentCorrect());
+     info.setAverageLL(learner.logLikelihood());
+
+     OnlineLogisticRegression model = learner.getModels().get(0);
+     // finish off pending regularization
+     model.close();
+     Matrix beta = model.getBeta();
+
+     // Maps a coefficient to 1 when its magnitude exceeds 1e-6, else 0.
+     DoubleFunction isNonZero = new DoubleFunction() {
+       @Override
+       public double apply(double v) {
+         return Math.abs(v) > 1.0e-6 ? 1 : 0;
+       }
+     };
+     // Maps a coefficient to 1 when it is strictly positive, else 0.
+     DoubleFunction isPositive = new DoubleFunction() {
+       @Override
+       public double apply(double v) {
+         return v > 0 ? 1 : 0;
+       }
+     };
+
+     maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+     nonZeros = beta.aggregate(Functions.PLUS, isNonZero);
+     positive = beta.aggregate(Functions.PLUS, isPositive);
+     norm = beta.aggregate(Functions.PLUS, Functions.ABS);
+     lambda = best.getMappedParams()[0];
+     mu = best.getMappedParams()[1];
+   }
+
+   if (k % (bump * scale) != 0) {
+     return;
+   }
+   if (best != null) {
+     File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
+     ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
+   }
+   info.setStep(step + 0.25);
+   System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
+   System.out.printf("%d\t%.3f\t%.2f\t%s%n",
+       k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
new file mode 100644
index 0000000..be55d43
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
@@ -0,0 +1,59 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+// Mutable holder for the progress statistics that SGDHelper.analyzeState() reads and updates.
+final class SGDInfo {
+ // Average log-likelihood, as last stored via setAverageLL().
+ private double averageLL;
+ // Fraction correct, as last stored via setAverageCorrect(); analyzeState() prints it times 100.
+ private double averageCorrect;
+ // Reporting position; analyzeState() advances it by 0.25 each time it prints a report.
+ private double step;
+ // Multipliers that analyzeState() cycles through when choosing the next reporting interval.
+ private int[] bumps = {1, 2, 5};
+ double getAverageLL() {
+ return averageLL;
+ }
+ void setAverageLL(double averageLL) {
+ this.averageLL = averageLL;
+ }
+ double getAverageCorrect() {
+ return averageCorrect;
+ }
+ void setAverageCorrect(double averageCorrect) {
+ this.averageCorrect = averageCorrect;
+ }
+ double getStep() {
+ return step;
+ }
+ void setStep(double step) {
+ this.step = step;
+ }
+ // Returns the internal array itself (no copy), so callers can modify it in place.
+ int[] getBumps() {
+ return bumps;
+ }
+ // Stores the caller's array by reference (no copy).
+ void setBumps(int[] bumps) {
+ this.bumps = bumps;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
new file mode 100644
index 0000000..b3da452
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
@@ -0,0 +1,283 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.base.Joiner;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+/**
+ * Shows how different encoding choices can make big speed differences.
+ * <p/>
+ * Run with command line options --generate 1000000 test.csv to generate a million data lines in
+ * test.csv.
+ * <p/>
+ * Run with command line options --parse test.csv to time how long it takes to parse and encode
+ * those million data points.
+ * <p/>
+ * Run with command line options --fast test.csv to time how long it takes to parse and encode those
+ * million data points using byte-level parsing and direct value encoding.
+ * <p/>
+ * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic
+ * idea of caching hash locations and byte level parsing still very much applies to text, however.
+ */
+public final class SimpleCsvExamples {
+ // Field delimiter used both when writing generated lines and when parsing them back.
+ public static final char SEPARATOR_CHAR = '\t';
+ // Number of fields generated per line and read from each parsed line.
+ private static final int FIELDS = 100;
+ private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class);
+ // Not instantiable: the example is run through main().
+ private SimpleCsvExamples() {}
+ /**
+  * Runs one of three modes selected by {@code args[0]}:
+  * <ul>
+  * <li>{@code --generate <count> <file>} writes {@code count} random lines to {@code file};</li>
+  * <li>{@code --parse <file>} parses each line with string splitting and encodes every field;</li>
+  * <li>{@code --fast <file>} does the same with byte-level parsing and direct value encoding.</li>
+  * </ul>
+  * The two parsing modes print the per-field means; every mode prints the elapsed time.
+  *
+  * @param args mode flag followed by the mode's arguments
+  * @throws IOException if the file cannot be read or written
+  */
+ public static void main(String[] args) throws IOException {
+   // Every mode needs the flag plus a file; --generate additionally needs the line count.
+   if (args.length < 2 || ("--generate".equals(args[0]) && args.length < 3)) {
+     System.err.println("Usage: --generate <count> <file> | --parse <file> | --fast <file>");
+     return;
+   }
+   FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
+   OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
+   for (int i = 0; i < FIELDS; i++) {
+     // One distinctly named encoder per field; the name used to be the constant "v1" for all.
+     encoder[i] = new ConstantValueEncoder("v" + i);
+     s[i] = new OnlineSummarizer();
+   }
+   long t0 = System.currentTimeMillis();
+   Vector v = new DenseVector(1000);
+   if ("--generate".equals(args[0])) {
+     try (PrintWriter out =
+              new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) {
+       int n = Integer.parseInt(args[1]);
+       for (int i = 0; i < n; i++) {
+         Line x = Line.generate();
+         out.println(x);
+       }
+     }
+   } else if ("--parse".equals(args[0])) {
+     try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)) {
+       String line = in.readLine();
+       while (line != null) {
+         v.assign(0);
+         Line x = new Line(line);
+         for (int i = 0; i < FIELDS; i++) {
+           s[i].add(x.getDouble(i));
+           encoder[i].addToVector(x.get(i), v);
+         }
+         line = in.readLine();
+       }
+     }
+     printMeans(s);
+   } else if ("--fast".equals(args[0])) {
+     try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))) {
+       FastLine line = in.read();
+       while (line != null) {
+         v.assign(0);
+         for (int i = 0; i < FIELDS; i++) {
+           double z = line.getDouble(i);
+           s[i].add(z);
+           // No original-form bytes are passed; only the numeric weight is supplied.
+           encoder[i].addToVector((byte[]) null, z, v);
+         }
+         line = in.read();
+       }
+     }
+     printMeans(s);
+   }
+   System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0);
+ }
+ /** Prints the mean of every summarizer on one line, comma-separated, without a trailing newline. */
+ private static void printMeans(OnlineSummarizer[] s) {
+   String separator = "";
+   for (int i = 0; i < FIELDS; i++) {
+     System.out.printf("%s%.3f", separator, s[i].getMean());
+     separator = ",";
+   }
+ }
+ /**
+  * One record of the example data set, held as a list of string fields separated by
+  * {@link #SEPARATOR_CHAR} in its textual form.
+  */
+ private static final class Line {
+   private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults();
+   // Joins with the same separator the splitter uses, so toString() output can be parsed back.
+   private static final Joiner WITH_TABS = Joiner.on(SEPARATOR_CHAR);
+   public static final Random RAND = RandomUtils.getRandom();
+   private final List<String> data;
+   /** Parses a separator-delimited line into its fields, trimming each one. */
+   private Line(CharSequence line) {
+     data = Lists.newArrayList(ON_TABS.split(line));
+   }
+   /** Creates an empty line to be filled by {@link #generate()}. */
+   private Line() {
+     data = new ArrayList<>();
+   }
+   /**
+    * @param field zero-based field index
+    * @return the field parsed with {@link Double#parseDouble(String)}
+    */
+   public double getDouble(int field) {
+     return Double.parseDouble(data.get(field));
+   }
+   /**
+    * Generate a random line with {@code FIELDS} fields, each with an integer value.
+    *
+    * @return A new line with data.
+    */
+   public static Line generate() {
+     Line r = new Line();
+     for (int i = 0; i < FIELDS; i++) {
+       // Per-field mean in the range 1..50, varied deterministically with the field index.
+       double mean = ((i + 1) * 257) % 50 + 1;
+       r.data.add(Integer.toString(randomValue(mean)));
+     }
+     return r;
+   }
+   /**
+    * Returns a random exponentially distributed integer with a particular mean value. This is
+    * just a way to create more small numbers than big numbers.
+    *
+    * @param mean mean of the distribution
+    * @return random exponentially distributed integer with the specific mean
+    */
+   private static int randomValue(double mean) {
+     return (int) (-mean * Math.log1p(-RAND.nextDouble()));
+   }
+   /** Returns the fields joined with {@link #SEPARATOR_CHAR}. */
+   @Override
+   public String toString() {
+     return WITH_TABS.join(data);
+   }
+   /**
+    * @param field zero-based field index
+    * @return the raw string value of the field
+    */
+   public String get(int field) {
+     return data.get(field);
+   }
+ }
+ /**
+  * A view of one line inside a shared {@link ByteBuffer}: instead of copying field text, it
+  * records the start offset and byte length of each field.
+  */
+ private static final class FastLine {
+   private final ByteBuffer base;
+   // Parallel lists: start.get(i) and length.get(i) describe field i within base.
+   private final IntArrayList start = new IntArrayList();
+   private final IntArrayList length = new IntArrayList();
+   private FastLine(ByteBuffer base) {
+     this.base = base;
+   }
+   /**
+    * Scans {@code buf} from its current position through the next newline, recording the
+    * boundaries of every {@link #SEPARATOR_CHAR}-delimited field. The buffer position is left
+    * just past the newline.
+    *
+    * @param buf buffer positioned at the first byte of a line
+    * @return the parsed line
+    * @throws IllegalArgumentException if the buffer's limit is reached before a newline
+    */
+   public static FastLine read(ByteBuffer buf) {
+     FastLine r = new FastLine(buf);
+     r.start.add(buf.position());
+     int offset = buf.position();
+     while (offset < buf.limit()) {
+       int ch = buf.get();
+       offset = buf.position();
+       switch (ch) {
+         case '\n':
+           // offset is already past the newline, hence the -1; length.size() indexes the open field.
+           r.length.add(offset - r.start.get(r.length.size()) - 1);
+           return r;
+         case SEPARATOR_CHAR:
+           // Close the current field and open the next one right after the separator.
+           // This label was missing, which left the two statements below unreachable.
+           r.length.add(offset - r.start.get(r.length.size()) - 1);
+           r.start.add(offset);
+           break;
+         default:
+           // nothing to do for now
+       }
+     }
+     throw new IllegalArgumentException("Not enough bytes in buffer");
+   }
+   /**
+    * Decodes a field as a run of decimal digits, with short-cuts for one- and two-byte fields.
+    * No sign, fraction or validation handling is performed.
+    *
+    * @param field zero-based field index
+    * @return the numeric value of the field
+    */
+   public double getDouble(int field) {
+     int offset = start.get(field);
+     int size = length.get(field);
+     switch (size) {
+       case 1:
+         return base.get(offset) - '0';
+       case 2:
+         return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0';
+       default:
+         double r = 0;
+         for (int i = 0; i < size; i++) {
+           r = 10 * r + base.get(offset + i) - '0';
+         }
+         return r;
+     }
+   }
+ }
+ // Reads FastLine records from a stream through a fixed 100000-byte buffer that is topped up
+ // whenever fewer than 10000 unread bytes remain.
+ private static final class FastLineReader implements Closeable {
+ private final InputStream in;
+ private final ByteBuffer buf = ByteBuffer.allocate(100000);
+ private FastLineReader(InputStream in) throws IOException {
+ this.in = in;
+ // Mark the buffer as empty (nothing to read) before the first fill.
+ buf.limit(0);
+ fillBuffer();
+ }
+ // Returns the next line, or null when no unread bytes remain after attempting a refill.
+ public FastLine read() throws IOException {
+ fillBuffer();
+ if (buf.remaining() > 0) {
+ return FastLine.read(buf);
+ } else {
+ return null;
+ }
+ }
+ // Tops up the buffer with a single read when it is running low; a no-op otherwise.
+ private void fillBuffer() throws IOException {
+ if (buf.remaining() < 10000) {
+ // Move the unread bytes to the front; position now marks where new data goes.
+ buf.compact();
+ int n = in.read(buf.array(), buf.position(), buf.remaining());
+ if (n == -1) {
+ // End of stream: expose only the bytes carried over by compact().
+ buf.flip();
+ } else {
+ // Expose carried-over bytes plus the n new ones, reading again from the start.
+ buf.limit(buf.position() + n);
+ buf.position(0);
+ }
+ }
+ }
+ @Override
+ public void close() {
+ try {
+ // The 'true' flag asks Closeables to swallow an IOException from close(), so the catch
+ // below is presumably only a safeguard -- confirm against the Guava version in use.
+ Closeables.close(in, true);
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
new file mode 100644
index 0000000..074f774
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
@@ -0,0 +1,152 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+/**
+ * Runs the ASF email test data through the classifier trained by TrainASFEmail.
+ */
+public final class TestASFEmail {
+ // Value of --input: the location whose sequence files are scanned by run().
+ private String inputFile;
+ // Value of --model: the serialized model file read by run().
+ private String modelFile;
+ // Instances are created only by main().
+ private TestASFEmail() {}
+ /**
+  * Command-line entry point: parses the arguments and, when that succeeds, prints the
+  * evaluation report to standard out as UTF-8.
+  *
+  * @param args command-line arguments (see {@code parseArgs})
+  * @throws IOException if the model or the test data cannot be read
+  */
+ public static void main(String[] args) throws IOException {
+   TestASFEmail runner = new TestASFEmail();
+   if (!runner.parseArgs(args)) {
+     return;
+   }
+   PrintWriter stdout = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true);
+   runner.run(stdout);
+ }
+ /**
+  * Classifies every record found in sequence files whose name contains "test" under the input
+  * location and prints the resulting {@link ResultAnalyzer} summary.
+  * <p>
+  * The data is scanned twice: first to intern every label into a dictionary and count the
+  * records, then to classify each record and compare against its label.
+  *
+  * @param output destination for the final analysis report (the record count goes to System.out)
+  * @throws IOException if the model or the sequence files cannot be read
+  */
+ public void run(PrintWriter output) throws IOException {
+   File base = new File(inputFile);
+   //contains the best model
+   OnlineLogisticRegression classifier;
+   // Close the model stream at the call site rather than relying on readBinary to do so.
+   try (FileInputStream modelStream = new FileInputStream(modelFile)) {
+     classifier = ModelSerializer.readBinary(modelStream, OnlineLogisticRegression.class);
+   }
+   Dictionary asfDictionary = new Dictionary();
+   Configuration conf = new Configuration();
+   PathFilter testFilter = new PathFilter() {
+     @Override
+     public boolean accept(Path path) {
+       return path.getName().contains("test");
+     }
+   };
+   // NOTE(review): the directory iterators are never closed here; confirm whether
+   // SequenceFileDirIterator needs an explicit close.
+   SequenceFileDirIterator<Text, VectorWritable> iter =
+       new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
+           null, true, conf);
+   // First pass: register every label and count the records.
+   long numItems = 0;
+   while (iter.hasNext()) {
+     Pair<Text, VectorWritable> next = iter.next();
+     asfDictionary.intern(next.getFirst().toString());
+     numItems++;
+   }
+   System.out.println(numItems + " test files");
+   ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
+   // Second pass: classify each record and record predicted versus actual label.
+   iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
+       null, true, conf);
+   while (iter.hasNext()) {
+     Pair<Text, VectorWritable> next = iter.next();
+     String ng = next.getFirst().toString();
+     int actual = asfDictionary.intern(ng);
+     Vector result = classifier.classifyFull(next.getSecond().get());
+     int cat = result.maxValueIndex();
+     double score = result.maxValue();
+     double ll = classifier.logLikelihood(actual, next.getSecond().get());
+     ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll);
+     ra.addInstance(asfDictionary.values().get(actual), cr);
+   }
+   output.println(ra);
+ }
+ /**
+  * Declares the --help, --input and --model options, parses {@code args} and stores the input
+  * and model paths in this instance.
+  *
+  * @param args raw command-line arguments
+  * @return {@code true} when a command line was obtained and the fields were set; {@code false}
+  *         when {@code parseAndHelp} returned {@code null}
+  */
+ boolean parseArgs(String[] args) {
+   DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+   Option helpOption = optionBuilder.withLongName("help").withDescription("print this list").create();
+
+   ArgumentBuilder argBuilder = new ArgumentBuilder();
+   Option inputOption = optionBuilder.withLongName("input")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("input").withMaximum(1).create())
+       .withDescription("where to get training data")
+       .create();
+   Option modelOption = optionBuilder.withLongName("model")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("model").withMaximum(1).create())
+       .withDescription("where to get a model")
+       .create();
+
+   Group allOptions = new GroupBuilder()
+       .withOption(helpOption)
+       .withOption(inputOption)
+       .withOption(modelOption)
+       .create();
+
+   Parser cliParser = new Parser();
+   cliParser.setHelpOption(helpOption);
+   cliParser.setHelpTrigger("--help");
+   cliParser.setGroup(allOptions);
+   cliParser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+
+   CommandLine parsed = cliParser.parseAndHelp(args);
+   if (parsed != null) {
+     inputFile = (String) parsed.getValue(inputOption);
+     modelFile = (String) parsed.getValue(modelOption);
+     return true;
+   }
+   return false;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
new file mode 100644
index 0000000..f0316e9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
@@ -0,0 +1,141 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+/**
+ * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ */
+public final class TestNewsGroups {
+ // Value of --input: the directory whose sub-directories (one per newsgroup) are listed by run().
+ private String inputFile;
+ // Value of --model: the serialized model file read by run().
+ private String modelFile;
+ // Instances are created only by main().
+ private TestNewsGroups() {
+ }
+ /**
+  * Command-line entry point: parses the arguments and, when that succeeds, prints the
+  * evaluation report to standard out as UTF-8.
+  *
+  * @param args command-line arguments (see {@code parseArgs})
+  * @throws IOException if the model or the test files cannot be read
+  */
+ public static void main(String[] args) throws IOException {
+   TestNewsGroups runner = new TestNewsGroups();
+   if (!runner.parseArgs(args)) {
+     return;
+   }
+   PrintWriter stdout = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true);
+   runner.run(stdout);
+ }
+ public void run(PrintWriter output) throws IOException {
+ File base = new File(inputFile);
+ //contains the best model
+ OnlineLogisticRegression classifier =
+ ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);
+ Dictionary newsGroups = new Dictionary();
+ Multiset<String> overallCounts = HashMultiset.create();
+ List<File> files = new ArrayList<>();
+ for (File newsgroup : base.listFiles()) {
+ if (newsgroup.isDirectory()) {
+ newsGroups.intern(newsgroup.getName());
+ files.addAll(Arrays.asList(newsgroup.listFiles()));
+ }
+ }
+ System.out.println(files.size() + " test files");
+ ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+ for (File file : files) {
+ String ng = file.getParentFile().getName();
+ int actual = newsGroups.intern(ng);
+ NewsgroupHelper helper = new NewsgroupHelper();
+ //no leak type ensures this is a normal vector
+ Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
+ Vector result = classifier.classifyFull(input);
+ int cat = result.maxValueIndex();
+ double score = result.maxValue();
+ double ll = classifier.logLikelihood(actual, input);
+ ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
+ ra.addInstance(newsGroups.values().get(actual), cr);
+ }
+ output.println(ra);
+ }
+ /**
+  * Declares the --help, --input and --model options, parses {@code args} and stores the input
+  * and model paths in this instance.
+  *
+  * @param args raw command-line arguments
+  * @return {@code true} when a command line was obtained and the fields were set; {@code false}
+  *         when {@code parseAndHelp} returned {@code null}
+  */
+ boolean parseArgs(String[] args) {
+   DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+   Option helpOption = optionBuilder.withLongName("help").withDescription("print this list").create();
+
+   ArgumentBuilder argBuilder = new ArgumentBuilder();
+   Option inputOption = optionBuilder.withLongName("input")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("input").withMaximum(1).create())
+       .withDescription("where to get training data")
+       .create();
+   Option modelOption = optionBuilder.withLongName("model")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("model").withMaximum(1).create())
+       .withDescription("where to get a model")
+       .create();
+
+   Group allOptions = new GroupBuilder()
+       .withOption(helpOption)
+       .withOption(inputOption)
+       .withOption(modelOption)
+       .create();
+
+   Parser cliParser = new Parser();
+   cliParser.setHelpOption(helpOption);
+   cliParser.setHelpTrigger("--help");
+   cliParser.setGroup(allOptions);
+   cliParser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+
+   CommandLine parsed = cliParser.parseAndHelp(args);
+   if (parsed != null) {
+     inputFile = (String) parsed.getValue(inputOption);
+     modelFile = (String) parsed.getValue(modelOption);
+     return true;
+   }
+   return false;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
new file mode 100644
index 0000000..e681f92
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
@@ -0,0 +1,137 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+public final class TrainASFEmail extends AbstractJob {
+ // Private: instances are only created by main() in this class.
+ private TrainASFEmail() {
+ }
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("categories", "nc", "The number of categories to train on", true);
+ addOption("cardinality", "c", "The size of the vectors to use", "100000");
+ addOption("threads", "t", "The number of threads to use in the learner", "20");
+ addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. "
+ + "Higher values require more memory.", "5");
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+ File base = new File(getInputPath().toString());
+ Multiset<String> overallCounts = HashMultiset.create();
+ File output = new File(getOutputPath().toString());
+ output.mkdirs();
+ int numCats = Integer.parseInt(getOption("categories"));
+ int cardinality = Integer.parseInt(getOption("cardinality", "100000"));
+ int threadCount = Integer.parseInt(getOption("threads", "20"));
+ int poolSize = Integer.parseInt(getOption("poolSize", "5"));
+ Dictionary asfDictionary = new Dictionary();
+ AdaptiveLogisticRegression learningAlgorithm =
+ new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
+ learningAlgorithm.setInterval(800);
+ learningAlgorithm.setAveragingWindow(500);
+ //We ran seq2encoded and split input already, so let's just build up the dictionary
+ Configuration conf = new Configuration();
+ PathFilter trainFilter = new PathFilter() {
+ @Override
+ public boolean accept(Path path) {
+ return path.getName().contains("training");
+ }
+ };
+ SequenceFileDirIterator<Text, VectorWritable> iter =
+ new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf);
+ long numItems = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ asfDictionary.intern(next.getFirst().toString());
+ numItems++;
+ }
+ System.out.println(numItems + " training files");
+ SGDInfo info = new SGDInfo();
+ iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter,
+ null, true, conf);
+ int k = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ String ng = next.getFirst().toString();
+ int actual = asfDictionary.intern(ng);
+ //we already have encoded
+ learningAlgorithm.train(actual, next.getSecond().get());
+ k++;
+ State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
+ SGDHelper.analyzeState(info, 0, k, best);
+ }
+ learningAlgorithm.close();
+ //TODO: how to dissection since we aren't processing the files here
+ //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
+ System.out.println("exiting main, writing model to " + output);
+ ModelSerializer.writeBinary(output + "/asf.model",
+ learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+ List<Integer> counts = new ArrayList<>();
+ System.out.println("Word counts");
+ for (String count : overallCounts.elementSet()) {
+ counts.add(overallCounts.count(count));
+ }
+ Collections.sort(counts, Ordering.natural().reverse());
+ k = 0;
+ for (Integer count : counts) {
+ System.out.println(k + "\t" + count);
+ k++;
+ if (k > 1000) {
+ break;
+ }
+ }
+ return 0;
+ }
+ /**
+  * Command-line entry point: creates the job and passes it the raw arguments.
+  * The status code returned by {@code run} is not used.
+  */
+ public static void main(String[] args) throws Exception {
+   new TrainASFEmail().run(args);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
new file mode 100644
index 0000000..defb5b9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
@@ -0,0 +1,377 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.io.Resources;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+public final class TrainAdaptiveLogistic {
+ // Training data location taken from --input; open() tries it as a classpath resource, then as a file.
+ private static String inputFile;
+ // File the model parameters are saved to, taken from --output.
+ private static String outputFile;
+ // Training configuration assembled by parseArgs.
+ private static AdaptiveLogisticModelParameters lmp;
+ // Number of passes over the input, taken from --passes.
+ private static int passes;
+ // True when --showperf was given: print performance measures during training.
+ private static boolean showperf;
+ // Performance is printed every (skipperfnum + 1) rows; overwritten from --skipperfnum.
+ private static int skipperfnum = 99;
+ // Learner created by mainToOutput; returned by getModel().
+ private static AdaptiveLogisticRegression model;
+ // Private constructor: the class exposes only static members.
+ private TrainAdaptiveLogistic() {
+ }
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ /**
+  * Parses the arguments, trains the adaptive logistic regression over the CSV input for the
+  * configured number of passes, saves the model parameters to the output file and prints the
+  * resulting coefficients.
+  *
+  * @param args   command-line arguments, see {@code parseArgs}
+  * @param output destination for progress lines and the final coefficient listing
+  * @throws Exception if the input cannot be read or the model cannot be written
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   model = lmp.createAdaptiveLogisticRegression();
+   State<Wrapper, CrossFoldLearner> best;
+   CrossFoldLearner learner = null;
+   int k = 0;
+   for (int pass = 0; pass < passes; pass++) {
+     // try-with-resources closes the reader even when parsing or training throws
+     try (BufferedReader in = open(inputFile)) {
+       // read variable names
+       csv.firstLine(in.readLine());
+       String line = in.readLine();
+       while (line != null) {
+         // for each new line, get target and predictors
+         Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+         int targetValue = csv.processLine(line, input);
+         // update model
+         model.train(targetValue, input);
+         k++;
+         if (showperf && (k % (skipperfnum + 1) == 0)) {
+           best = model.getBest();
+           if (best != null) {
+             learner = best.getPayload().getLearner();
+           }
+           if (learner != null) {
+             double averageCorrect = learner.percentCorrect();
+             double averageLL = learner.logLikelihood();
+             // Locale.ENGLISH keeps the decimal separator consistent with the other output
+             output.printf(Locale.ENGLISH, "%d\t%.3f\t%.2f%n",
+                 k, averageLL, averageCorrect * 100);
+           } else {
+             output.printf(Locale.ENGLISH,
+                 "%10d %2d %s%n", k, targetValue,
+                 "AdaptiveLogisticRegression has not found a good model ......");
+           }
+         }
+         line = in.readLine();
+       }
+     }
+   }
+   best = model.getBest();
+   if (best != null) {
+     learner = best.getPayload().getLearner();
+   }
+   if (learner == null) {
+     output.println("AdaptiveLogisticRegression has failed to train a model.");
+     return;
+   }
+   try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
+     lmp.saveTo(modelOutput);
+   }
+   OnlineLogisticRegression lr = learner.getModels().get(0);
+   output.println(lmp.getNumFeatures());
+   output.println(lmp.getTargetVariable() + " ~ ");
+   // One-line formula built from the non-zero predictor weights of row 0.
+   String sep = "";
+   for (String v : csv.getTraceDictionary().keySet()) {
+     double weight = predictorWeight(lr, 0, csv, v);
+     if (weight != 0) {
+       output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
+       sep = " + ";
+     }
+   }
+   output.printf("%n");
+   // Per row: non-zero predictor weights, then every raw coefficient of that row.
+   for (int row = 0; row < lr.getBeta().numRows(); row++) {
+     for (String key : csv.getTraceDictionary().keySet()) {
+       double weight = predictorWeight(lr, row, csv, key);
+       if (weight != 0) {
+         output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
+       }
+     }
+     for (int column = 0; column < lr.getBeta().numCols(); column++) {
+       output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
+     }
+     output.println();
+   }
+ }
+ /**
+  * Sums the coefficients in the given row of the model's beta matrix over every column that
+  * the record factory's trace dictionary lists for the predictor.
+  */
+ private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+   double total = 0;
+   for (int col : csv.getTraceDictionary().get(predictor)) {
+     total += lr.getBeta().get(row, col);
+   }
+   return total;
+ }
+ /**
+  * Declares the command-line options, parses {@code args}, and fills the static configuration
+  * ({@code inputFile}, {@code outputFile}, {@code lmp}, {@code showperf}, {@code skipperfnum},
+  * {@code passes}) from the result.
+  *
+  * @param args raw command-line arguments
+  * @return {@code true} when a command line was parsed, {@code false} when
+  *         {@code parseAndHelp} returned {@code null}
+  */
+ private static boolean parseArgs(String[] args) {
+   DefaultOptionBuilder builder = new DefaultOptionBuilder();
+   Option help = builder.withLongName("help")
+       .withDescription("print this list").create();
+   // Accepted on the command line, but its value is not read anywhere in this method.
+   Option quiet = builder.withLongName("quiet")
+       .withDescription("be extra quiet").create();
+   ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+   // Option locals carry an "Option" suffix so they do not shadow the static fields they feed.
+   Option showperfOption = builder
+       .withLongName("showperf")
+       .withDescription("output performance measures during training")
+       .create();
+   Option inputFileOption = builder
+       .withLongName("input")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("input").withMaximum(1)
+               .create())
+       .withDescription("where to get training data").create();
+   Option outputFileOption = builder
+       .withLongName("output")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("output").withMaximum(1)
+               .create())
+       .withDescription("where to write the model content").create();
+   Option threads = builder.withLongName("threads")
+       .withArgument(
+           argumentBuilder.withName("threads").withDefault("4").create())
+       .withDescription("the number of threads AdaptiveLogisticRegression uses")
+       .create();
+   Option predictors = builder.withLongName("predictors")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("predictors").create())
+       .withDescription("a list of predictor variables").create();
+   Option types = builder
+       .withLongName("types")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("types").create())
+       .withDescription(
+           "a list of predictor variable types (numeric, word, or text)")
+       .create();
+   Option target = builder
+       .withLongName("target")
+       .withDescription("the name of the target variable")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("target").withMaximum(1)
+               .create())
+       .create();
+   Option targetCategories = builder
+       .withLongName("categories")
+       .withDescription("the number of target categories to be considered")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("categories").withMaximum(1).create())
+       .create();
+   Option features = builder
+       .withLongName("features")
+       .withDescription("the number of internal hashed features to use")
+       .withArgument(
+           argumentBuilder.withName("numFeatures")
+               .withDefault("1000").withMaximum(1).create())
+       .create();
+   Option passesOption = builder
+       .withLongName("passes")
+       .withDescription("the number of times to pass over the input data")
+       .withArgument(
+           argumentBuilder.withName("passes").withDefault("2")
+               .withMaximum(1).create())
+       .create();
+   Option interval = builder.withLongName("interval")
+       .withArgument(
+           argumentBuilder.withName("interval").withDefault("500").create())
+       .withDescription("the interval property of AdaptiveLogisticRegression")
+       .create();
+   Option window = builder.withLongName("window")
+       .withArgument(
+           argumentBuilder.withName("window").withDefault("800").create())
+       .withDescription("the average property of AdaptiveLogisticRegression")
+       .create();
+   Option skipperfnumOption = builder.withLongName("skipperfnum")
+       .withArgument(
+           argumentBuilder.withName("skipperfnum").withDefault("99").create())
+       .withDescription("show performance measures every (skipperfnum + 1) rows")
+       .create();
+   Option prior = builder.withLongName("prior")
+       .withArgument(
+           argumentBuilder.withName("prior").withDefault("L1").create())
+       .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
+       .create();
+   Option priorOption = builder.withLongName("prioroption")
+       .withArgument(
+           argumentBuilder.withName("prioroption").create())
+       .withDescription("constructor parameter for ElasticBandPrior and TPrior")
+       .create();
+   Option auc = builder.withLongName("auc")
+       .withArgument(
+           argumentBuilder.withName("auc").withDefault("global").create())
+       .withDescription("the auc to use: global or grouped")
+       .create();
+   Group normalArgs = new GroupBuilder().withOption(help)
+       .withOption(quiet).withOption(inputFileOption).withOption(outputFileOption)
+       .withOption(target).withOption(targetCategories)
+       .withOption(predictors).withOption(types).withOption(passesOption)
+       .withOption(interval).withOption(window).withOption(threads)
+       .withOption(prior).withOption(features).withOption(showperfOption)
+       .withOption(skipperfnumOption).withOption(priorOption).withOption(auc)
+       .create();
+   Parser parser = new Parser();
+   parser.setHelpOption(help);
+   parser.setHelpTrigger("--help");
+   parser.setGroup(normalArgs);
+   parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+   CommandLine cmdLine = parser.parseAndHelp(args);
+   if (cmdLine == null) {
+     return false;
+   }
+   TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFileOption);
+   TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine, outputFileOption);
+   List<String> typeList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(types)) {
+     typeList.add(x.toString());
+   }
+   List<String> predictorList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(predictors)) {
+     predictorList.add(x.toString());
+   }
+   lmp = new AdaptiveLogisticModelParameters();
+   lmp.setTargetVariable(getStringArgument(cmdLine, target));
+   lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+   lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+   lmp.setInterval(getIntegerArgument(cmdLine, interval));
+   lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
+   lmp.setThreads(getIntegerArgument(cmdLine, threads));
+   lmp.setAuc(getStringArgument(cmdLine, auc));
+   lmp.setPrior(getStringArgument(cmdLine, prior));
+   // prioroption has no default, so it is only applied when a value was supplied
+   if (cmdLine.getValue(priorOption) != null) {
+     lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
+   }
+   lmp.setTypeMap(predictorList, typeList);
+   TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperfOption);
+   TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnumOption);
+   TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passesOption);
+   lmp.checkParameters();
+   return true;
+ }
+ /** Returns the parsed value of the given option, cast to {@code String}. */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+   Object value = cmdLine.getValue(inputFile);
+   return (String) value;
+ }
+ /** Returns whether the command line reports the given option as present. */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+ /** Parses the string value of the given option as an {@code int}. */
+ private static int getIntegerArgument(CommandLine cmdLine, Option features) {
+   String raw = (String) cmdLine.getValue(features);
+   return Integer.parseInt(raw);
+ }
+ /** Parses the string value of the given option as a {@code double}. */
+ private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+   String raw = (String) cmdLine.getValue(op);
+   return Double.parseDouble(raw);
+ }
+ /** Returns the learner created by the last {@code mainToOutput} call, or {@code null} if none has been created. */
+ public static AdaptiveLogisticRegression getModel() {
+ return model;
+ }
+ /** Returns the parameters built by the last {@code parseArgs} call, or {@code null} if none have been built. */
+ public static LogisticModelParameters getParameters() {
+ return lmp;
+ }
+ /**
+  * Opens the training data as UTF-8 text. The name is first tried as a resource via
+  * {@code Resources.getResource}; if that throws {@code IllegalArgumentException}, it is
+  * opened as a file instead.
+  *
+  * @param inputFile resource name or file path
+  * @return a buffered reader over the data
+  * @throws IOException if the stream cannot be opened
+  */
+ static BufferedReader open(String inputFile) throws IOException {
+   InputStream stream;
+   try {
+     stream = Resources.getResource(inputFile).openStream();
+   } catch (IllegalArgumentException notAResource) {
+     File asFile = new File(inputFile);
+     stream = new FileInputStream(asFile);
+   }
+   InputStreamReader utf8Reader = new InputStreamReader(stream, Charsets.UTF_8);
+   return new BufferedReader(utf8Reader);
+ }
2018-06-27 13:14:41 UTC
diff --git a/community/mahout-mr/examples/src/main/resources/bank-full.csv b/community/mahout-mr/examples/src/main/resources/bank-full.csv
new file mode 100644
index 0000000..d7a2ede
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/resources/bank-full.csv
@@ -0,0 +1,45212 @@

2018-06-27 13:14:43 UTC
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
new file mode 100644
index 0000000..f4b8bcb
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
@@ -0,0 +1,311 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.io.Resources;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+ * Train a logistic regression for the examples from Chapter 13 of Mahout in Action
+ */
+public final class TrainLogistic {
+ // Training data location taken from --input; open() tries it as a classpath resource, then as a file.
+ private static String inputFile;
+ // File the model parameters are saved to, taken from --output.
+ private static String outputFile;
+ // Training configuration assembled by parseArgs.
+ private static LogisticModelParameters lmp;
+ // Number of passes over the input, taken from --passes.
+ private static int passes;
+ // When true, mainToOutput prints a diagnostics line for every training row.
+ private static boolean scores;
+ // Regression trained by mainToOutput; returned by getModel().
+ private static OnlineLogisticRegression model;
+ // Private constructor: the class exposes only static members.
+ private TrainLogistic() {
+ }
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ /**
+  * Parses the arguments, trains the logistic regression over the CSV input for the configured
+  * number of passes, saves the model parameters to the output file and prints the resulting
+  * coefficients.
+  *
+  * @param args   command-line arguments, see {@code parseArgs}
+  * @param output destination for the optional per-row diagnostics and the coefficient listing
+  * @throws Exception if the input cannot be read or the model cannot be written
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   double runningLogP = 0;
+   int seen = 0;
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   OnlineLogisticRegression lr = lmp.createRegression();
+   for (int pass = 0; pass < passes; pass++) {
+     try (BufferedReader in = open(inputFile)) {
+       // the header row supplies the variable names
+       csv.firstLine(in.readLine());
+       for (String line = in.readLine(); line != null; line = in.readLine()) {
+         // turn the row into a target value and a feature vector
+         Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+         int targetValue = csv.processLine(line, input);
+         // score the row before the model has been trained on it
+         double logP = lr.logLikelihood(targetValue, input);
+         if (!Double.isInfinite(logP)) {
+           // plain running mean for the first 20 samples, exponential average afterwards
+           runningLogP = seen < 20
+               ? (seen * runningLogP + logP) / (seen + 1)
+               : 0.95 * runningLogP + 0.05 * logP;
+           seen++;
+         }
+         double p = lr.classifyScalar(input);
+         if (scores) {
+           output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n",
+               seen, targetValue, lr.currentLearningRate(), p, logP, runningLogP);
+         }
+         // only now let the model learn from the row
+         lr.train(targetValue, input);
+       }
+     }
+   }
+   try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
+     lmp.saveTo(modelOutput);
+   }
+   output.println(lmp.getNumFeatures());
+   output.println(lmp.getTargetVariable() + " ~ ");
+   // one-line formula built from the non-zero predictor weights of row 0
+   String joiner = "";
+   for (String name : csv.getTraceDictionary().keySet()) {
+     double w = predictorWeight(lr, 0, csv, name);
+     if (w != 0) {
+       output.printf(Locale.ENGLISH, "%s%.3f*%s", joiner, w, name);
+       joiner = " + ";
+     }
+   }
+   output.printf("%n");
+   model = lr;
+   // per row: non-zero predictor weights, then every raw coefficient of that row
+   for (int r = 0; r < lr.getBeta().numRows(); r++) {
+     for (String name : csv.getTraceDictionary().keySet()) {
+       double w = predictorWeight(lr, r, csv, name);
+       if (w != 0) {
+         output.printf(Locale.ENGLISH, "%20s %.5f%n", name, w);
+       }
+     }
+     for (int c = 0; c < lr.getBeta().numCols(); c++) {
+       output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(r, c));
+     }
+     output.println();
+   }
+ }
+ /**
+  * Sums the coefficients in the given row of the model's beta matrix over every column that
+  * the record factory's trace dictionary lists for the predictor.
+  */
+ private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+   double total = 0;
+   for (int col : csv.getTraceDictionary().get(predictor)) {
+     total += lr.getBeta().get(row, col);
+   }
+   return total;
+ }
+ /**
+  * Declares the command-line options, parses {@code args}, and fills the static configuration
+  * ({@code inputFile}, {@code outputFile}, {@code lmp}, {@code scores}, {@code passes}) from
+  * the result.
+  *
+  * @param args raw command-line arguments
+  * @return {@code true} when a command line was parsed, {@code false} when
+  *         {@code parseAndHelp} returned {@code null}
+  */
+ private static boolean parseArgs(String[] args) {
+   DefaultOptionBuilder builder = new DefaultOptionBuilder();
+   Option help = builder.withLongName("help").withDescription("print this list").create();
+   // Accepted on the command line, but its value is not read anywhere in this method.
+   Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
+   // Option locals carry an "Option" suffix so they do not shadow the static fields they feed.
+   Option scoresOption = builder.withLongName("scores")
+       .withDescription("output score diagnostics during training")
+       .create();
+   ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+   Option inputFileOption = builder.withLongName("input")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+       .withDescription("where to get training data")
+       .create();
+   Option outputFileOption = builder.withLongName("output")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+       .withDescription("where to write the model content")
+       .create();
+   Option predictors = builder.withLongName("predictors")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("p").create())
+       .withDescription("a list of predictor variables")
+       .create();
+   Option types = builder.withLongName("types")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("t").create())
+       .withDescription("a list of predictor variable types (numeric, word, or text)")
+       .create();
+   Option target = builder.withLongName("target")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
+       .withDescription("the name of the target variable")
+       .create();
+   Option features = builder.withLongName("features")
+       .withArgument(
+           argumentBuilder.withName("numFeatures")
+               .withDefault("1000")
+               .withMaximum(1).create())
+       .withDescription("the number of internal hashed features to use")
+       .create();
+   Option passesOption = builder.withLongName("passes")
+       .withArgument(
+           argumentBuilder.withName("passes")
+               .withDefault("2")
+               .withMaximum(1).create())
+       .withDescription("the number of times to pass over the input data")
+       .create();
+   Option lambda = builder.withLongName("lambda")
+       .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
+       .withDescription("the amount of coefficient decay to use")
+       .create();
+   Option rate = builder.withLongName("rate")
+       .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
+       .withDescription("the learning rate")
+       .create();
+   Option noBias = builder.withLongName("noBias")
+       .withDescription("don't include a bias term")
+       .create();
+   Option targetCategories = builder.withLongName("categories")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
+       .withDescription("the number of target categories to be considered")
+       .create();
+   Group normalArgs = new GroupBuilder()
+       .withOption(help)
+       .withOption(quiet)
+       .withOption(inputFileOption)
+       .withOption(outputFileOption)
+       .withOption(target)
+       .withOption(targetCategories)
+       .withOption(predictors)
+       .withOption(types)
+       .withOption(passesOption)
+       .withOption(lambda)
+       .withOption(rate)
+       .withOption(noBias)
+       .withOption(features)
+       // --scores is queried below, so it has to be part of the group the parser accepts
+       .withOption(scoresOption)
+       .create();
+   Parser parser = new Parser();
+   parser.setHelpOption(help);
+   parser.setHelpTrigger("--help");
+   parser.setGroup(normalArgs);
+   parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+   CommandLine cmdLine = parser.parseAndHelp(args);
+   if (cmdLine == null) {
+     return false;
+   }
+   TrainLogistic.inputFile = getStringArgument(cmdLine, inputFileOption);
+   TrainLogistic.outputFile = getStringArgument(cmdLine, outputFileOption);
+   List<String> typeList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(types)) {
+     typeList.add(x.toString());
+   }
+   List<String> predictorList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(predictors)) {
+     predictorList.add(x.toString());
+   }
+   lmp = new LogisticModelParameters();
+   lmp.setTargetVariable(getStringArgument(cmdLine, target));
+   lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+   lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+   lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
+   lmp.setTypeMap(predictorList, typeList);
+   lmp.setLambda(getDoubleArgument(cmdLine, lambda));
+   lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
+   TrainLogistic.scores = getBooleanArgument(cmdLine, scoresOption);
+   TrainLogistic.passes = getIntegerArgument(cmdLine, passesOption);
+   return true;
+ }
+ /** Returns the parsed value of the given option, cast to {@code String}. */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+   Object value = cmdLine.getValue(inputFile);
+   return (String) value;
+ }
+ /** Returns whether the command line reports the given option as present. */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+ /** Parses the string value of the given option as an {@code int}. */
+ private static int getIntegerArgument(CommandLine cmdLine, Option features) {
+   String raw = (String) cmdLine.getValue(features);
+   return Integer.parseInt(raw);
+ }
+ /** Parses the string value of the given option as a {@code double}. */
+ private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+   String raw = (String) cmdLine.getValue(op);
+   return Double.parseDouble(raw);
+ }
+ /** Returns the regression stored by the last {@code mainToOutput} run, or {@code null} if none has been stored. */
+ public static OnlineLogisticRegression getModel() {
+ return model;
+ }
+ /** Returns the parameters built by the last {@code parseArgs} call, or {@code null} if none have been built. */
+ public static LogisticModelParameters getParameters() {
+ return lmp;
+ }
+ /**
+  * Opens the training data as UTF-8 text. The name is first tried as a resource via
+  * {@code Resources.getResource}; if that throws {@code IllegalArgumentException}, it is
+  * opened as a file instead.
+  *
+  * @param inputFile resource name or file path
+  * @return a buffered reader over the data
+  * @throws IOException if the stream cannot be opened
+  */
+ static BufferedReader open(String inputFile) throws IOException {
+   InputStream stream;
+   try {
+     stream = Resources.getResource(inputFile).openStream();
+   } catch (IllegalArgumentException notAResource) {
+     File asFile = new File(inputFile);
+     stream = new FileInputStream(asFile);
+   }
+   InputStreamReader utf8Reader = new InputStreamReader(stream, Charsets.UTF_8);
+   return new BufferedReader(utf8Reader);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
new file mode 100644
index 0000000..632b32c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
@@ -0,0 +1,154 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+ * Reads and trains an adaptive logistic regression model on the 20 newsgroups data.
+ * The first command line argument gives the path of the directory holding the training
+ * data. The optional second argument, leakType, defines which classes of features to use.
+ * Importantly, leakType controls whether a synthetic date is injected into the data as
+ * a target leak and if so, how.
+ * <p/>
+ * The value of leakType % 3 determines whether the target leak is injected according to
+ * the following table:
+ * <p/>
+ * <table>
+ * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
+ * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and
+ * is a perfect target leak since each newsgroup is given a different month</td></tr>
+ * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format. The day varies
+ * and thus there are more leak symbols that need to be learned. Ultimately this is just
+ * as big a leak as case 1.</td></tr>
+ * </table>
+ * <p/>
+ * The value of leakType also determines what other text will be indexed. If leakType is greater
+ * than or equal to 6, then neither headers nor text body will be used for features and the leak is the only
+ * source of data. If leakType is at least 3 but less than 6, then only subject words will be used as features.
+ * If leakType is less than 3, then both subject and body text will be used as features.
+ * <p/>
+ * A leakType of 0 gives no leak and all textual features.
+ * <p/>
+ * See the following table for a summary of commonly used values for leakType
+ * <p/>
+ * <table>
+ * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
+ * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
+ * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * </table>
+ */
+public final class TrainNewsGroups {
+ /** Private so that the class, which only exposes a static {@code main}, is not instantiated from outside. */
+ private TrainNewsGroups() {
+ }
+ /**
+  * Trains an adaptive logistic regression model on the files found below the training directory,
+  * writes the first model of the best learner to {@code news-group.model} in the directory named
+  * by the {@code java.io.tmpdir} system property, and prints up to 1001 word counts in descending order.
+  *
+  * @param args {@code args[0]} is the directory that holds one sub-directory per newsgroup;
+  *             the optional {@code args[1]} is the integer leakType (0 when omitted)
+  * @throws IllegalArgumentException if no training directory is given
+  * @throws IOException if the training directory or one of its newsgroup directories cannot be listed,
+  *                     or if a callee reports an I/O failure
+  */
+ public static void main(String[] args) throws IOException {
+   if (args.length < 1) {
+     throw new IllegalArgumentException("Usage: TrainNewsGroups <training-data-dir> [leakType]");
+   }
+   File base = new File(args[0]);
+   Multiset<String> overallCounts = HashMultiset.create();
+   int leakType = 0;
+   if (args.length > 1) {
+     leakType = Integer.parseInt(args[1]);
+   }
+   Dictionary newsGroups = new Dictionary();
+   NewsgroupHelper helper = new NewsgroupHelper();
+   helper.getEncoder().setProbes(2);
+   AdaptiveLogisticRegression learningAlgorithm =
+       new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
+   learningAlgorithm.setInterval(800);
+   learningAlgorithm.setAveragingWindow(500);
+   // Every sub-directory of base is one newsgroup; its name is interned as the class label.
+   List<File> files = new ArrayList<>();
+   for (File newsgroup : listFilesOrFail(base)) {
+     if (newsgroup.isDirectory()) {
+       newsGroups.intern(newsgroup.getName());
+       files.addAll(Arrays.asList(listFilesOrFail(newsgroup)));
+     }
+   }
+   Collections.shuffle(files);
+   System.out.println(files.size() + " training files");
+   SGDInfo info = new SGDInfo();
+   int k = 0;
+   for (File file : files) {
+     // the label of a document is the name of the directory that contains it
+     String ng = file.getParentFile().getName();
+     int actual = newsGroups.intern(ng);
+     Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+     learningAlgorithm.train(actual, v);
+     k++;
+     State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
+     SGDHelper.analyzeState(info, leakType, k, best);
+   }
+   learningAlgorithm.close();
+   SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts);
+   System.out.println("exiting main");
+   File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model");
+   ModelSerializer.writeBinary(modelFile.getAbsolutePath(),
+       learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+   List<Integer> counts = new ArrayList<>();
+   System.out.println("Word counts");
+   for (String count : overallCounts.elementSet()) {
+     counts.add(overallCounts.count(count));
+   }
+   Collections.sort(counts, Ordering.natural().reverse());
+   k = 0;
+   for (Integer count : counts) {
+     System.out.println(k + "\t" + count);
+     k++;
+     if (k > 1000) {
+       break;
+     }
+   }
+ }
+ /**
+  * Lists the entries of a directory, turning the {@code null} that {@code File.listFiles()} returns
+  * for a non-directory or on an I/O error into a descriptive exception.
+  */
+ private static File[] listFilesOrFail(File dir) throws IOException {
+   File[] entries = dir.listFiles();
+   if (entries == null) {
+     throw new IOException("Cannot list files in " + dir.getAbsolutePath());
+   }
+   return entries;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
new file mode 100644
index 0000000..7a74289
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
@@ -0,0 +1,218 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.Locale;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+ * Auc and averageLikelihood are always shown if possible. If the number of target values is more than 2,
+ * then Auc and the entropy matrix are not shown regardless of the values of showAuc and showEntropy
+ * the user passes, because the current implementation only supports them on two-value targets.
+ * */
+public final class ValidateAdaptiveLogistic {
+ // Command-line settings. All six are assigned by parseArgs; mainToOutput additionally switches
+ // showAuc and showConfusion on when none of the three show* flags was given.
+ // Value of the --input option: where the validation data is read from.
+ private static String inputFile;
+ // Value of the --model option: where the trained model is loaded from.
+ private static String modelFile;
+ // Value of the --defaultCategory option; passed to the ConfusionMatrix constructor.
+ private static String defaultCategory;
+ // True when --auc was given.
+ private static boolean showAuc;
+ // True when --scores was given: print one line per scored record.
+ private static boolean showScores;
+ // True when --confusion was given: print the confusion matrix.
+ private static boolean showConfusion;
+ /** Private so that the class, which only exposes static entry points, is not instantiated from outside. */
+ private ValidateAdaptiveLogistic() {
+ }
+ public static void main(String[] args) throws IOException {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ /**
+  * Parses the command line, loads the model file, scores every data line of the CSV input and
+  * prints the requested reports to {@code output}. Does nothing when {@code parseArgs} returns false.
+  * <p>
+  * When none of the auc, confusion or scores flags is given, the auc and confusion flags are switched on.
+  * The AUC line is printed whenever the model has at most two target categories; it is not
+  * conditioned on the auc flag. The input reader is closed before the summary is printed,
+  * including when scoring fails with an exception.
+  *
+  * @param args   the command-line arguments
+  * @param output where all report text is written
+  * @throws IOException if the model or the input cannot be read
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+   if (!parseArgs(args)) {
+     return;
+   }
+   if (!showAuc && !showConfusion && !showScores) {
+     showAuc = true;
+     showConfusion = true;
+   }
+   Auc collector = null;
+   AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+       .loadFromFile(new File(modelFile));
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+   // AUC (and the entropy matrix derived from it) is only collected for at most two target categories
+   if (lmp.getTargetCategories().size() <= 2) {
+     collector = new Auc();
+   }
+   OnlineSummarizer slh = new OnlineSummarizer();
+   ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory);
+   State<Wrapper, CrossFoldLearner> best = lr.getBest();
+   if (best == null) {
+     output.println("AdaptiveLogisticRegression has not be trained probably.");
+     return;
+   }
+   CrossFoldLearner learner = best.getPayload().getLearner();
+   // try-with-resources: the reader used to stay open after the loop and on every exception path
+   try (BufferedReader in = TrainLogistic.open(inputFile)) {
+     // the first line is handed to the record factory via firstLine(); scoring starts at the second
+     String line = in.readLine();
+     csv.firstLine(line);
+     line = in.readLine();
+     if (showScores) {
+       output.println("\"target\", \"model-output\", \"log-likelihood\", \"average-likelihood\"");
+     }
+     while (line != null) {
+       Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+       //TODO: How to avoid extra target values not shown in the training process.
+       int target = csv.processLine(line, v);
+       double likelihood = learner.logLikelihood(target, v);
+       double score = learner.classifyFull(v).maxValue();
+       slh.add(likelihood);
+       cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));
+       if (showScores) {
+         output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f%n", target,
+             score, learner.logLikelihood(target, v), slh.getMean());
+       }
+       if (collector != null) {
+         collector.add(target, score);
+       }
+       line = in.readLine();
+     }
+   }
+   output.printf(Locale.ENGLISH,"\nLog-likelihood:");
+   output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f%n",
+       slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());
+   if (collector != null) {
+     output.printf(Locale.ENGLISH, "%nAUC = %.2f%n", collector.auc());
+   }
+   if (showConfusion) {
+     output.printf(Locale.ENGLISH, "%n%s%n%n", cm.toString());
+     if (collector != null) {
+       Matrix m = collector.entropy();
+       output.printf(Locale.ENGLISH,
+           "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0),
+           m.get(1, 0), m.get(0, 1), m.get(1, 1));
+     }
+   }
+ }
+ /**
+  * Declares the command-line options, parses {@code args} and stores the results in the static
+  * fields of this class.
+  *
+  * @param args the raw command-line arguments
+  * @return {@code false} when {@code parseAndHelp} yields no command line, {@code true} once all
+  *         static settings have been assigned
+  */
+ private static boolean parseArgs(String[] args) {
+   DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+   ArgumentBuilder argBuilder;
+
+   // flag-style options
+   Option helpOption = optionBuilder.withLongName("help").withDescription("print this list").create();
+   Option quietOption = optionBuilder.withLongName("quiet").withDescription("be extra quiet").create();
+   Option aucOption = optionBuilder.withLongName("auc").withDescription("print AUC").create();
+   Option confusionOption =
+       optionBuilder.withLongName("confusion").withDescription("print confusion matrix").create();
+   Option scoresOption = optionBuilder.withLongName("scores").withDescription("print scores").create();
+
+   // options that carry a single value
+   argBuilder = new ArgumentBuilder();
+   Option inputOption = optionBuilder
+       .withLongName("input")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("input").withMaximum(1).create())
+       .withDescription("where to get validate data")
+       .create();
+   Option modelOption = optionBuilder
+       .withLongName("model")
+       .withRequired(true)
+       .withArgument(argBuilder.withName("model").withMaximum(1).create())
+       .withDescription("where to get the trained model")
+       .create();
+   Option defaultCategoryOption = optionBuilder
+       .withLongName("defaultCategory")
+       .withRequired(false)
+       .withArgument(argBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown").create())
+       .withDescription("the default category value to use")
+       .create();
+
+   Group allOptions = new GroupBuilder()
+       .withOption(helpOption)
+       .withOption(quietOption)
+       .withOption(aucOption)
+       .withOption(scoresOption)
+       .withOption(confusionOption)
+       .withOption(inputOption)
+       .withOption(modelOption)
+       .withOption(defaultCategoryOption)
+       .create();
+
+   Parser cliParser = new Parser();
+   cliParser.setHelpOption(helpOption);
+   cliParser.setHelpTrigger("--help");
+   cliParser.setGroup(allOptions);
+   cliParser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+
+   CommandLine parsed = cliParser.parseAndHelp(args);
+   if (parsed == null) {
+     return false;
+   }
+
+   inputFile = getStringArgument(parsed, inputOption);
+   modelFile = getStringArgument(parsed, modelOption);
+   defaultCategory = getStringArgument(parsed, defaultCategoryOption);
+   showAuc = getBooleanArgument(parsed, aucOption);
+   showScores = getBooleanArgument(parsed, scoresOption);
+   showConfusion = getBooleanArgument(parsed, confusionOption);
+   return true;
+ }
+ /**
+  * Tells whether a flag-style option was present on the command line.
+  *
+  * @param cmdLine the parsed command line
+  * @param option  the option to test for
+  * @return the result of {@code cmdLine.hasOption(option)}
+  */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+   boolean present = cmdLine.hasOption(option);
+   return present;
+ }
+ /**
+  * Looks up the value given on the command line for an option and returns it as a string.
+  *
+  * @param cmdLine   the parsed command line
+  * @param inputFile the option to look up (the parameter name is historical; callers pass other options too)
+  * @return the value reported by {@code cmdLine.getValue}, cast to {@code String}
+  */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+   Object rawValue = cmdLine.getValue(inputFile);
+   return (String) rawValue;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
new file mode 100644
index 0000000..ab3c861
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
@@ -0,0 +1,70 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd.bankmarketing;
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+import java.util.Collections;
+import java.util.List;
+ * Uses the SGD classifier on the 'Bank marketing' dataset from UCI.
+ *
+ * See http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
+ *
+ * Learn when people accept or reject an offer from the bank via telephone based on income, age, education and more.
+ */
+public class BankMarketingClassificationMain {
+ public static final int NUM_CATEGORIES = 2;
+ /**
+  * Loads all calls from the {@code bank-full.csv} resource and performs 20 runs. Each run
+  * reshuffles the calls, holds out the first 10% as test data, trains a fresh
+  * {@code OnlineLogisticRegression} for 20 passes over the rest, and on every fifth pass
+  * (0, 5, 10, 15) prints the pass number, the current learning rate and the AUC on the held-out calls.
+  *
+  * @param args ignored
+  * @throws Exception if the data cannot be loaded or training fails
+  */
+ public static void main(String[] args) throws Exception {
+   List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));
+   double heldOutPercentage = 0.10;
+
+   for (int run = 0; run < 20; run++) {
+     Collections.shuffle(calls);
+
+     // the first cutoff calls are the held-out set, everything after is training data
+     int total = calls.size();
+     int cutoff = (int) (heldOutPercentage * total);
+     List<TelephoneCall> heldOut = calls.subList(0, cutoff);
+     List<TelephoneCall> training = calls.subList(cutoff, total);
+
+     OnlineLogisticRegression regression =
+         new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
+             .learningRate(1)
+             .alpha(1)
+             .lambda(0.000001)
+             .stepOffset(10000)
+             .decayExponent(0.2);
+
+     for (int pass = 0; pass < 20; pass++) {
+       for (TelephoneCall call : training) {
+         regression.train(call.getTarget(), call.asVector());
+       }
+       if (pass % 5 != 0) {
+         continue;
+       }
+       Auc auc = new Auc(0.5);
+       for (TelephoneCall call : heldOut) {
+         auc.add(call.getTarget(), regression.classifyScalar(call.asVector()));
+       }
+       System.out.printf("%d, %.4f, %.4f\n", pass, regression.currentLearningRate(), auc.auc());
+     }
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
new file mode 100644
index 0000000..728ec20
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
@@ -0,0 +1,104 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd.bankmarketing;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+public class TelephoneCall {
+ // Cardinality of the sparse vector the constructor creates for each call.
+ public static final int FEATURES = 100;
+ // Encoder used by the constructor to add the constant "1" (intercept) entry to the vector.
+ private static final ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept");
+ // Encoder used by the constructor for all other vectorized fields.
+ private static final FeatureVectorEncoder featureEncoder = new StaticWordValueEncoder("feature");
+ // The encoded call; filled in by the constructor and returned by asVector().
+ private RandomAccessSparseVector vector;
+ // Raw field values keyed by field name, in the order the constructor received them.
+ private Map<String, String> fields = new LinkedHashMap<>();
+ /**
+  * Builds a call from parallel sequences of field names and values. Every name/value pair is
+  * remembered in {@code fields}; the pairs are also encoded into the feature vector as follows:
+  * <ul>
+  *   <li>{@code age}, {@code balance}, {@code duration}, {@code pdays}: parsed as doubles and
+  *       added log-transformed ({@code balance} is first clamped to at least -2000)</li>
+  *   <li>{@code job}, {@code marital}, {@code education}, {@code default}, {@code housing},
+  *       {@code loan}, {@code contact}, {@code campaign}, {@code previous}, {@code poutcome}:
+  *       added as the word {@code name:value} with weight 1</li>
+  *   <li>{@code day}, {@code month}, {@code y}: stored but not vectorized</li>
+  * </ul>
+  *
+  * @param fieldNames the column names, in column order
+  * @param values     the values of one record, in the same order as {@code fieldNames}
+  * @throws IllegalArgumentException if a field name is not one of those listed above
+  */
+ public TelephoneCall(Iterable<String> fieldNames, Iterable<String> values) {
+   vector = new RandomAccessSparseVector(FEATURES);
+   Iterator<String> remainingValues = values.iterator();
+   interceptEncoder.addToVector("1", vector);
+
+   for (String name : fieldNames) {
+     String fieldValue = remainingValues.next();
+     fields.put(name, fieldValue);
+
+     switch (name) {
+       // numeric fields, log-scaled
+       case "age":
+         featureEncoder.addToVector(name, Math.log(Double.parseDouble(fieldValue)), vector);
+         break;
+       case "balance": {
+         // clamp at -2000 so that the argument of the logarithm stays at least 1
+         double balance = Math.max(Double.parseDouble(fieldValue), -2000);
+         featureEncoder.addToVector(name, Math.log(balance + 2001) - 8, vector);
+         break;
+       }
+       case "duration": {
+         double duration = Double.parseDouble(fieldValue);
+         featureEncoder.addToVector(name, Math.log(duration + 1) - 5, vector);
+         break;
+       }
+       case "pdays": {
+         double pdays = Double.parseDouble(fieldValue);
+         featureEncoder.addToVector(name, Math.log(pdays + 2), vector);
+         break;
+       }
+
+       // categorical fields, one word per name/value combination
+       case "job":
+       case "marital":
+       case "education":
+       case "default":
+       case "housing":
+       case "loan":
+       case "contact":
+       case "campaign":
+       case "previous":
+       case "poutcome":
+         featureEncoder.addToVector(name + ":" + fieldValue, 1, vector);
+         break;
+
+       // ignore these for vectorizing
+       case "day":
+       case "month":
+       case "y":
+         break;
+
+       default:
+         throw new IllegalArgumentException(String.format("Bad field name: %s", name));
+     }
+   }
+ }
+ /**
+  * Returns the feature vector built by the constructor. The internal vector itself is returned,
+  * not a copy.
+  */
+ public Vector asVector() {
+ return vector;
+ }
+ public int getTarget() {
+ return fields.get("y").equals("no") ? 0 : 1;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
new file mode 100644
index 0000000..5ef6490
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
@@ -0,0 +1,66 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sgd.bankmarketing;
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Splitter;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Resources;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.Iterator;
+/** Parses semi-colon separated data as TelephoneCalls */
+public class TelephoneCallParser implements Iterable<TelephoneCall> {
+ // Splits a line on ';' and trims double quotes, spaces and semicolons from both ends of each piece.
+ private final Splitter onSemi = Splitter.on(";").trimResults(CharMatcher.anyOf("\" ;"));
+ // Name of the resource to parse; it is resolved and opened each time iterator() is called.
+ private String resourceName;
+ /**
+  * Remembers the name of the resource to parse. Nothing is opened here; the resource is only
+  * read once {@link #iterator()} is called.
+  *
+  * @param resourceName name of the semicolon-separated resource
+  * @throws IOException declared in the signature, although this body performs no I/O
+  */
+ public TelephoneCallParser(String resourceName) throws IOException {
+ this.resourceName = resourceName;
+ }
+ /**
+  * Opens the resource and returns an iterator that yields one {@code TelephoneCall} per data line.
+  * The first line of the resource is consumed as the list of field names. The resource is decoded
+  * as UTF-8, and the underlying reader is closed as soon as the last line has been consumed.
+  *
+  * @return a fresh iterator over the calls in the resource
+  * @throws RuntimeException wrapping the {@code IOException} if the resource cannot be opened or read
+  */
+ @Override
+ public Iterator<TelephoneCall> iterator() {
+   try {
+     return new AbstractIterator<TelephoneCall>() {
+       // Decode explicitly as UTF-8 so that parsing does not depend on the platform default charset.
+       private final BufferedReader input = new BufferedReader(new InputStreamReader(
+           Resources.getResource(resourceName).openStream(), java.nio.charset.StandardCharsets.UTF_8));
+       // header line: the field names shared by every record
+       private final Iterable<String> fieldNames = onSemi.split(input.readLine());
+       @Override
+       protected TelephoneCall computeNext() {
+         try {
+           String line = input.readLine();
+           if (line == null) {
+             // end of data: release the stream instead of leaving it open
+             input.close();
+             return endOfData();
+           }
+           return new TelephoneCall(fieldNames, onSemi.split(line));
+         } catch (IOException e) {
+           throw new RuntimeException("Error reading data", e);
+         }
+       }
+     };
+   } catch (IOException e) {
+     throw new RuntimeException("Error reading data", e);
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
new file mode 100644
index 0000000..a0b845f
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
@@ -0,0 +1,31 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+final class ClustersFilter implements PathFilter {
+ /**
+  * Accepts a path when its string form contains {@code "/clusters-"} anywhere.
+  *
+  * @param path the candidate path
+  * @return {@code true} if {@code path.toString()} contains {@code "/clusters-"}
+  */
+ @Override
+ public boolean accept(Path path) {
+   return path.toString().contains("/clusters-");
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
new file mode 100644
index 0000000..50dba99
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
@@ -0,0 +1,88 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import java.awt.BasicStroke;
+import java.awt.Color;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+ * Java desktop graphics class that runs canopy clustering and displays the results.
+ * This class generates random data and clusters it.
+ */
+public class DisplayCanopy extends DisplayClustering {
+ /**
+  * Runs {@code initialize()} and then sets the window title, which includes the
+  * {@code significance} threshold expressed as a whole percentage.
+  */
+ DisplayCanopy() {
+   initialize();
+   int significancePercent = (int) (significance * 100);
+   this.setTitle("Canopy Clusters (>" + significancePercent + "% of population)");
+ }
+ @Override
+ public void paint(Graphics g) {
+ plotSampleData((Graphics2D) g);
+ plotClusters((Graphics2D) g);
+ }
+ /**
+  * Draws every significant cluster of every list in {@code CLUSTERS}. For each such cluster two
+  * blue ellipses with half-axes T1 and T2 are drawn around its center, followed by an ellipse of
+  * three times the cluster radius. The color of that last ellipse is chosen by a counter that
+  * starts at {@code CLUSTERS.size() - 1} and decreases by one per list; the list for which the
+  * counter reaches 0 is drawn with a stroke of width 3, all others with width 1.
+  *
+  * @param g2 the graphics context to draw on
+  */
+ protected static void plotClusters(Graphics2D g2) {
+   int cx = CLUSTERS.size() - 1;
+   for (List<Cluster> clusters : CLUSTERS) {
+     // color slot and stroke width depend only on the position of this list
+     int colorIndex = Math.min(DisplayClustering.COLORS.length - 1, cx);
+     int radiusStrokeWidth = cx == 0 ? 3 : 1;
+     for (Cluster cluster : clusters) {
+       if (!isSignificant(cluster)) {
+         continue;
+       }
+       // T1 and T2 outlines in thin blue
+       g2.setStroke(new BasicStroke(1));
+       g2.setColor(Color.BLUE);
+       plotEllipse(g2, cluster.getCenter(), new DenseVector(new double[] {T1, T1}));
+       plotEllipse(g2, cluster.getCenter(), new DenseVector(new double[] {T2, T2}));
+       // cluster extent: three times the radius
+       g2.setColor(COLORS[colorIndex]);
+       g2.setStroke(new BasicStroke(radiusStrokeWidth));
+       plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
+     }
+     cx--;
+   }
+ }
+ /**
+  * Deletes the {@code samples} and {@code output} paths, switches {@code RandomUtils} to its test
+  * seed, generates and writes sample data, builds canopy clusters from it with a Manhattan distance
+  * measure, loads the resulting clusters and creates the display.
+  *
+  * @param args ignored
+  * @throws Exception if any of the steps fails
+  */
+ public static void main(String[] args) throws Exception {
+   Configuration conf = new Configuration();
+   Path samples = new Path("samples");
+   Path output = new Path("output");
+
+   // start from a clean slate
+   HadoopUtil.delete(conf, samples);
+   HadoopUtil.delete(conf, output);
+
+   RandomUtils.useTestSeed();
+   generateSamples();
+   writeSampleData(samples);
+
+   ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
+   CanopyDriver.buildClusters(conf, samples, output, measure, T1, T2, 0, true);
+   loadClustersWritable(output);
+
+   new DisplayCanopy();
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
new file mode 100644
index 0000000..ad85c6a
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
@@ -0,0 +1,374 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import java.awt.*;
+import java.awt.event.WindowAdapter;
+import java.awt.event.WindowEvent;
+import java.awt.geom.AffineTransform;
+import java.awt.geom.Ellipse2D;
+import java.awt.geom.Rectangle2D;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.UncommonDistributions;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public class DisplayClustering extends Frame {
+ private static final Logger log = LoggerFactory.getLogger(DisplayClustering.class);
+ protected static final int DS = 72; // default scale = 72 pixels per inch
+ protected static final int SIZE = 8; // screen size in inches
+ // One vector {mx, my, sdx, sdy} per generated distribution; drawn by plotSampleParameters().
+ private static final Collection<Vector> SAMPLE_PARAMS = new ArrayList<>();
+ // The generated sample points, appended to by the generateSamples()/generate2dSamples() methods.
+ protected static final List<VectorWritable> SAMPLE_DATA = new ArrayList<>();
+ // One list of clusters per directory read by loadClustersWritable().
+ protected static final List<List<Cluster>> CLUSTERS = new ArrayList<>();
+ // Palette used by plotClusters(): index 0 goes to the last list in CLUSTERS, and
+ // indexes past the end of the array are clamped to the final entry (lightGray).
+ static final Color[] COLORS = { Color.red, Color.orange, Color.yellow, Color.green, Color.blue, Color.magenta,
+ Color.lightGray };
+ // Thresholds passed to CanopyDriver.buildClusters() and drawn as blue ellipses by DisplayCanopy.
+ protected static final double T1 = 3.0;
+ protected static final double T2 = 2.8;
+ // Fraction of SAMPLE_DATA a cluster's observation count must exceed to pass isSignificant().
+ static double significance = 0.05;
+ protected static int res; // screen resolution
+ /**
+  * Creates and shows the frame via initialize(), then titles it "Sample Data".
+  */
+ public DisplayClustering() {
+   initialize();
+   setTitle("Sample Data");
+ }
+ public void initialize() {
+ // Get screen resolution
+ res = Toolkit.getDefaultToolkit().getScreenResolution();
+ // Set Frame size in inches
+ this.setSize(SIZE * res, SIZE * res);
+ this.setVisible(true);
+ this.setTitle("Asymmetric Sample Data");
+ // Window listener to terminate program.
+ this.addWindowListener(new WindowAdapter() {
+ @Override
+ public void windowClosing(WindowEvent e) {
+ System.exit(0);
+ }
+ });
+ }
+ /**
+  * Generates the default sample data set with the test seed and displays it.
+  *
+  * @param args
+  *          ignored
+  */
+ public static void main(String[] args) throws Exception {
+   RandomUtils.useTestSeed();
+   DisplayClustering.generateSamples();
+   new DisplayClustering();
+ }
+ /**
+  * Draws the sample points, the parameters they were generated from and any
+  * loaded clusters, in that order.
+  *
+  * @param g
+  *          the graphics context supplied by AWT; it is cast to Graphics2D
+  */
+ @Override
+ public void paint(Graphics g) {
+   Graphics2D canvas = (Graphics2D) g;
+   plotSampleData(canvas);
+   plotSampleParameters(canvas);
+   plotClusters(canvas);
+ }
+ /**
+  * Draws every cluster in CLUSTERS as an ellipse of three times its radius.
+  * Each list of clusters gets its own color; the last list is drawn with
+  * COLORS[0] and a stroke of width 3, all others with width 1.
+  *
+  * @param g2
+  *          a Graphics2D context
+  */
+ protected static void plotClusters(Graphics2D g2) {
+   int age = CLUSTERS.size() - 1;
+   for (List<Cluster> generation : CLUSTERS) {
+     float outlineWidth = age == 0 ? 3 : 1;
+     g2.setStroke(new BasicStroke(outlineWidth));
+     // lists older than the palette is long all share the final color
+     int paletteIndex = Math.min(COLORS.length - 1, age);
+     g2.setColor(COLORS[paletteIndex]);
+     age--;
+     for (Cluster member : generation) {
+       plotEllipse(g2, member.getCenter(), member.getRadius().times(3));
+     }
+   }
+ }
+ /**
+  * Draws, in red, one ellipse per entry of SAMPLE_PARAMS: centered on the
+  * distribution mean and sized to three times the standard deviations.
+  *
+  * @param g2
+  *          a Graphics2D context
+  */
+ protected static void plotSampleParameters(Graphics2D g2) {
+   g2.setColor(Color.RED);
+   for (Vector param : SAMPLE_PARAMS) {
+     // param layout is {mx, my, sdx, sdy}
+     Vector mean = new DenseVector(new double[] {param.get(0), param.get(1)});
+     Vector extent = new DenseVector(new double[] {param.get(2) * 3, param.get(3) * 3});
+     plotEllipse(g2, mean, extent);
+   }
+ }
+ /**
+  * Scales the graphics context to the screen resolution, draws the axes and
+  * then draws every point of SAMPLE_DATA as a small dark gray square.
+  *
+  * @param g2
+  *          a Graphics2D context; its transform is replaced
+  */
+ protected static void plotSampleData(Graphics2D g2) {
+   double scale = (double) res / DS;
+   g2.setTransform(AffineTransform.getScaleInstance(scale, scale));
+   // plot the axes: two squares of side SIZE / 2 centered on (2, 2) and (-2, -2)
+   g2.setColor(Color.BLACK);
+   Vector quadrantSize = new DenseVector(2).assign(SIZE / 2.0);
+   Vector firstCenter = new DenseVector(2).assign(2);
+   Vector secondCenter = new DenseVector(2).assign(-2);
+   plotRectangle(g2, firstCenter, quadrantSize);
+   plotRectangle(g2, secondCenter, quadrantSize);
+   // plot the sample data
+   g2.setColor(Color.DARK_GRAY);
+   Vector markerSize = new DenseVector(2).assign(0.03);
+   for (VectorWritable sample : SAMPLE_DATA) {
+     plotRectangle(g2, sample.get(), markerSize);
+   }
+ }
+ /**
+  * This method plots points and colors them according to their cluster
+  * membership, rather than drawing ellipses.
+  *
+  * As of commit, this method is used only by K-means spectral clustering.
+  * Since the cluster assignments are set within the eigenspace of the data, it
+  * is not inherent that the original data cluster as they would in K-means:
+  * that is, as symmetric gaussian mixtures.
+  *
+  * Since Spectral K-Means uses K-Means to cluster the eigenspace data, the raw
+  * output is not directly usable. Rather, the cluster assignments from the raw
+  * output need to be transferred back to the original data. As such, this
+  * method will read the SequenceFile cluster results of K-means and transfer
+  * the cluster assignments to the original data, coloring them appropriately.
+  *
+  * The n-th record of the file is paired with the n-th entry of SAMPLE_DATA.
+  * Records beyond the end of SAMPLE_DATA have no point to color; they are
+  * skipped and reported once instead of failing the repaint with an
+  * IndexOutOfBoundsException.
+  *
+  * @param g2
+  *          a Graphics2D context; its transform is replaced
+  * @param data
+  *          the directory that contains "clusteredPoints/part-m-00000"
+  */
+ protected static void plotClusteredSampleData(Graphics2D g2, Path data) {
+   double sx = (double) res / DS;
+   g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+   // plot the axes
+   g2.setColor(Color.BLACK);
+   Vector dv = new DenseVector(2).assign(SIZE / 2.0);
+   plotRectangle(g2, new DenseVector(2).assign(2), dv);
+   plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+   // plot the sample data, colored according to the cluster they belong to
+   dv.assign(0.03);
+   Path clusteredPointsPath = new Path(data, "clusteredPoints");
+   Path inputPath = new Path(clusteredPointsPath, "part-m-00000");
+   // cluster id -> color, assigned in order of first appearance
+   Map<Integer,Color> colors = new HashMap<>();
+   int point = 0;
+   int unmatched = 0;
+   for (Pair<IntWritable,WeightedVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedVectorWritable>(
+       inputPath, new Configuration())) {
+     if (point >= SAMPLE_DATA.size()) {
+       // keep iterating instead of breaking so the file is still consumed to its end
+       unmatched++;
+       continue;
+     }
+     int clusterId = record.getFirst().get();
+     VectorWritable v = SAMPLE_DATA.get(point++);
+     Integer key = clusterId;
+     if (!colors.containsKey(key)) {
+       colors.put(key, COLORS[Math.min(COLORS.length - 1, colors.size())]);
+     }
+     plotClusteredRectangle(g2, v.get(), dv, colors.get(key));
+   }
+   if (unmatched > 0) {
+     log.warn("Ignored {} clustered points in {} that have no matching sample", unmatched, inputPath);
+   }
+ }
+ /**
+  * Draws a rectangle exactly as plotRectangle() does, after first setting the
+  * stroke to width 1 and the drawing color to the given color. The geometry is
+  * delegated to plotRectangle() so the two methods cannot drift apart.
+  *
+  * @param g2
+  *          A Graphics2D context; its stroke and color are changed.
+  * @param v
+  *          A vector for the rectangle's center.
+  * @param dv
+  *          A vector for the rectangle's dimensions.
+  * @param color
+  *          The color of the rectangle's stroke.
+  */
+ protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) {
+   g2.setStroke(new BasicStroke(1));
+   g2.setColor(color);
+   plotRectangle(g2, v, dv);
+ }
+ /**
+  * Draws the outline of a rectangle using the context's current stroke and color.
+  *
+  * @param g2
+  *          a Graphics2D context
+  * @param v
+  *          a Vector holding the rectangle's center
+  * @param dv
+  *          a Vector holding the rectangle's width and height
+  */
+ protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
+   // negate y, then step back half the dimensions to reach the corner Rectangle2D expects
+   Vector mirrored = v.times(new DenseVector(new double[] {1, -1}));
+   Vector corner = mirrored.minus(dv.divide(2));
+   // shift by half the display size, then convert to pixels
+   int halfSize = SIZE / 2;
+   double left = (corner.get(0) + halfSize) * DS;
+   double top = (corner.get(1) + halfSize) * DS;
+   double width = dv.get(0) * DS;
+   double height = dv.get(1) * DS;
+   g2.draw(new Rectangle2D.Double(left, top, width, height));
+ }
+ /**
+  * Draws the outline of an ellipse using the context's current stroke and color.
+  *
+  * @param g2
+  *          a Graphics2D context
+  * @param v
+  *          a Vector holding the ellipse's center
+  * @param dv
+  *          a Vector holding the ellipse's width and height
+  */
+ protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
+   // negate y, then step back half the dimensions to reach the corner of the bounding box
+   Vector mirrored = v.times(new DenseVector(new double[] {1, -1}));
+   Vector corner = mirrored.minus(dv.divide(2));
+   // shift by half the display size, then convert to pixels
+   int halfSize = SIZE / 2;
+   double left = (corner.get(0) + halfSize) * DS;
+   double top = (corner.get(1) + halfSize) * DS;
+   double width = dv.get(0) * DS;
+   double height = dv.get(1) * DS;
+   g2.draw(new Ellipse2D.Double(left, top, width, height));
+ }
+ /**
+  * Generates the default data set: three symmetric distributions of 500, 300
+  * and 300 points. Arguments are (count, mean x, mean y, standard deviation).
+  */
+ protected static void generateSamples() {
+   generateSamples(500, 1.0, 1.0, 3.0);
+   generateSamples(300, 1.0, 0.0, 0.5);
+   generateSamples(300, 0.0, 2.0, 0.1);
+ }
+ /**
+  * Generates a data set of three distributions of 500, 300 and 300 points with
+  * separate x and y standard deviations. Arguments are
+  * (count, mean x, mean y, sd x, sd y).
+  */
+ protected static void generate2dSamples() {
+   generate2dSamples(500, 1.0, 1.0, 3.0, 1.0);
+   generate2dSamples(300, 1.0, 0.0, 0.5, 1.0);
+   generate2dSamples(300, 0.0, 2.0, 0.1, 0.5);
+ }
+ /**
+  * Draws num points from a symmetric normal distribution and appends them to
+  * SAMPLE_DATA; the distribution's parameters are recorded in SAMPLE_PARAMS.
+  *
+  * @param num
+  *          int number of samples to generate
+  * @param mx
+  *          double x-value of the sample mean
+  * @param my
+  *          double y-value of the sample mean
+  * @param sd
+  *          double standard deviation of the samples, used for both axes
+  */
+ protected static void generateSamples(int num, double mx, double my, double sd) {
+   SAMPLE_PARAMS.add(new DenseVector(new double[] {mx, my, sd, sd}));
+   log.info("Generating {} samples m=[{}, {}] sd={}", num, mx, my, sd);
+   for (int remaining = num; remaining > 0; remaining--) {
+     // x is drawn before y, so the sequence of random draws is unchanged
+     double x = UncommonDistributions.rNorm(mx, sd);
+     double y = UncommonDistributions.rNorm(my, sd);
+     SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {x, y})));
+   }
+ }
+ protected static void writeSampleData(Path output) throws IOException {
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output, Text.class, VectorWritable.class)) {
+ int i = 0;
+ for (VectorWritable vw : SAMPLE_DATA) {
+ writer.append(new Text("sample_" + i++), vw);
+ }
+ }
+ }
+ /**
+  * Reads every ClusterWritable found under the given directory, logs each
+  * cluster and returns them in the order read.
+  *
+  * @param clustersIn
+  *          the directory holding the cluster SequenceFiles
+  * @return the clusters that were read
+  */
+ protected static List<Cluster> readClustersWritable(Path clustersIn) {
+   Configuration conf = new Configuration();
+   List<Cluster> loaded = new ArrayList<>();
+   for (ClusterWritable wrapper : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST,
+       PathFilters.logsCRCFilter(), conf)) {
+     Cluster current = wrapper.getValue();
+     String center = AbstractCluster.formatVector(current.getCenter(), null);
+     String radius = AbstractCluster.formatVector(current.getRadius(), null);
+     log.info(
+         "Reading Cluster:{} center:{} numPoints:{} radius:{}",
+         current.getId(), center, current.getNumObservations(), radius);
+     loaded.add(current);
+   }
+   return loaded;
+ }
+ /**
+  * Appends to CLUSTERS one list of clusters for every entry of the output
+  * directory that ClustersFilter accepts, in the order listStatus returns them.
+  *
+  * @param output
+  *          the clustering output directory
+  * @throws IOException
+  *           if the directory cannot be listed
+  */
+ protected static void loadClustersWritable(Path output) throws IOException {
+   FileSystem fs = FileSystem.get(output.toUri(), new Configuration());
+   FileStatus[] clusterDirs = fs.listStatus(output, new ClustersFilter());
+   for (FileStatus dir : clusterDirs) {
+     CLUSTERS.add(readClustersWritable(dir.getPath()));
+   }
+ }
+ /**
+  * Draws num points from a normal distribution with independent x and y
+  * standard deviations and appends them to SAMPLE_DATA; the distribution's
+  * parameters are recorded in SAMPLE_PARAMS.
+  *
+  * @param num
+  *          int number of samples to generate
+  * @param mx
+  *          double x-value of the sample mean
+  * @param my
+  *          double y-value of the sample mean
+  * @param sdx
+  *          double x-value standard deviation of the samples
+  * @param sdy
+  *          double y-value standard deviation of the samples
+  */
+ protected static void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+   SAMPLE_PARAMS.add(new DenseVector(new double[] {mx, my, sdx, sdy}));
+   log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+   for (int remaining = num; remaining > 0; remaining--) {
+     // x is drawn before y, so the sequence of random draws is unchanged
+     double x = UncommonDistributions.rNorm(mx, sdx);
+     double y = UncommonDistributions.rNorm(my, sdy);
+     SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {x, y})));
+   }
+ }
+ /**
+  * Tells whether a cluster holds more than the significance fraction of all
+  * generated samples.
+  *
+  * @param cluster
+  *          the cluster to test
+  * @return true if numObservations / SAMPLE_DATA.size() exceeds significance
+  */
+ protected static boolean isSignificant(Cluster cluster) {
+   double populationShare = (double) cluster.getNumObservations() / SAMPLE_DATA.size();
+   return populationShare > significance;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
new file mode 100644
index 0000000..f8ce7c7
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
@@ -0,0 +1,110 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import com.google.common.collect.Lists;
+public class DisplayFuzzyKMeans extends DisplayClustering {
+ /**
+  * Creates the fuzzy k-means display frame and titles it with the significance cut-off.
+  *
+  * The implicit call to the DisplayClustering constructor already runs
+  * initialize(); calling it a second time here would size and show the frame
+  * again and register a duplicate window-closing listener, so it is not repeated.
+  */
+ DisplayFuzzyKMeans() {
+   this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+ // Override the paint() method
+ @Override
+ public void paint(Graphics g) {
+ plotSampleData((Graphics2D) g);
+ plotClusters((Graphics2D) g);
+ }
+ /**
+  * Regenerates the sample data, runs fuzzy k-means over it and opens the
+  * display frame. The runClusterer flag selects between the driver-based run
+  * and the ClusterClassifier-based run.
+  *
+  * @param args
+  *          ignored
+  */
+ public static void main(String[] args) throws Exception {
+   DistanceMeasure metric = new ManhattanDistanceMeasure();
+   Path samplePath = new Path("samples");
+   Path outputPath = new Path("output");
+   Configuration hadoopConf = new Configuration();
+   // remove whatever a previous run left behind
+   HadoopUtil.delete(hadoopConf, outputPath);
+   HadoopUtil.delete(hadoopConf, samplePath);
+   RandomUtils.useTestSeed();
+   DisplayClustering.generateSamples();
+   writeSampleData(samplePath);
+   boolean runClusterer = true;
+   int iterationLimit = 10;
+   float convergenceThreshold = 0.001F;
+   float fuzziness = 1.1F;
+   if (!runClusterer) {
+     int clusterCount = 3;
+     runSequentialFuzzyKClassifier(hadoopConf, samplePath, outputPath, metric, clusterCount, iterationLimit,
+         fuzziness, convergenceThreshold);
+   } else {
+     runSequentialFuzzyKClusterer(hadoopConf, samplePath, outputPath, metric, iterationLimit, fuzziness,
+         convergenceThreshold);
+   }
+   new DisplayFuzzyKMeans();
+ }
+ /**
+  * Seeds numClusters SoftClusters from the first numClusters samples, writes
+  * them as the prior classifier under output/classifier-0, iterates with
+  * ClusterIterator.iterateSeq and loads the resulting clusters for display.
+  */
+ private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output,
+     DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException {
+   List<Cluster> seeds = Lists.newArrayList();
+   for (int id = 0; id < numClusters; id++) {
+     // the id-th sample becomes the center of cluster id
+     Vector center = SAMPLE_DATA.get(id).get();
+     seeds.add(new SoftCluster(center, id, measure));
+   }
+   FuzzyKMeansClusteringPolicy policy = new FuzzyKMeansClusteringPolicy(m, threshold);
+   ClusterClassifier prior = new ClusterClassifier(seeds, policy);
+   Path priorPath = new Path(output, "classifier-0");
+   prior.writeToSeqFiles(priorPath);
+   ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
+   loadClustersWritable(output);
+ }
+ /**
+  * Picks three random seed clusters under output/random-seeds, runs
+  * FuzzyKMeansDriver from them and loads the resulting clusters for display.
+  */
+ private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
+     DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
+     ClassNotFoundException, InterruptedException {
+   Path seedPath = new Path(output, "random-seeds");
+   RandomSeedGenerator.buildRandom(conf, samples, seedPath, 3, measure);
+   // threshold is deliberately passed twice, exactly as before
+   FuzzyKMeansDriver.run(samples, seedPath, output, threshold, maxIterations, m, true, true, threshold, true);
+   loadClustersWritable(output);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
new file mode 100644
index 0000000..336d69e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
@@ -0,0 +1,106 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import com.google.common.collect.Lists;
+public class DisplayKMeans extends DisplayClustering {
+ /**
+  * Creates the k-means display frame and titles it with the significance cut-off.
+  *
+  * The implicit call to the DisplayClustering constructor already runs
+  * initialize(); calling it a second time here would size and show the frame
+  * again and register a duplicate window-closing listener, so it is not repeated.
+  */
+ DisplayKMeans() {
+   this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+ /**
+  * Regenerates the sample data, runs k-means over it and opens the display
+  * frame. The runClusterer flag selects between the driver-based run and the
+  * ClusterClassifier-based run.
+  *
+  * @param args
+  *          ignored
+  */
+ public static void main(String[] args) throws Exception {
+   DistanceMeasure metric = new ManhattanDistanceMeasure();
+   Path samplePath = new Path("samples");
+   Path outputPath = new Path("output");
+   Configuration hadoopConf = new Configuration();
+   // remove whatever a previous run left behind
+   HadoopUtil.delete(hadoopConf, samplePath);
+   HadoopUtil.delete(hadoopConf, outputPath);
+   RandomUtils.useTestSeed();
+   generateSamples();
+   writeSampleData(samplePath);
+   boolean runClusterer = true;
+   double delta = 0.001;
+   int clusterCount = 3;
+   int iterationLimit = 10;
+   if (!runClusterer) {
+     runSequentialKMeansClassifier(hadoopConf, samplePath, outputPath, metric, clusterCount, iterationLimit, delta);
+   } else {
+     runSequentialKMeansClusterer(hadoopConf, samplePath, outputPath, metric, clusterCount, iterationLimit, delta);
+   }
+   new DisplayKMeans();
+ }
+ /**
+  * Seeds numClusters Klusters from the first numClusters samples, writes them
+  * as the prior classifier under output/Cluster.INITIAL_CLUSTERS_DIR, iterates
+  * with ClusterIterator.iterateSeq and loads the resulting clusters for display.
+  */
+ private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output,
+     DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) throws IOException {
+   List<Cluster> seeds = Lists.newArrayList();
+   for (int id = 0; id < numClusters; id++) {
+     // the id-th sample becomes the center of cluster id
+     Vector center = SAMPLE_DATA.get(id).get();
+     seeds.add(new org.apache.mahout.clustering.kmeans.Kluster(center, id, measure));
+   }
+   KMeansClusteringPolicy policy = new KMeansClusteringPolicy(convergenceDelta);
+   ClusterClassifier prior = new ClusterClassifier(seeds, policy);
+   Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+   prior.writeToSeqFiles(priorPath);
+   ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
+   loadClustersWritable(output);
+ }
+ /**
+  * Picks numClusters random seed clusters under output/random-seeds, runs
+  * KMeansDriver from them and loads the resulting clusters for display.
+  */
+ private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
+     DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta)
+     throws IOException, InterruptedException, ClassNotFoundException {
+   Path seedPath = new Path(output, "random-seeds");
+   RandomSeedGenerator.buildRandom(conf, samples, seedPath, numClusters, measure);
+   KMeansDriver.run(samples, seedPath, output, convergenceDelta, maxIterations, true, 0.0, true);
+   loadClustersWritable(output);
+ }
+ // Override the paint() method
+ @Override
+ public void paint(Graphics g) {
+ plotSampleData((Graphics2D) g);
+ plotClusters((Graphics2D) g);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
new file mode 100644
index 0000000..2b70749
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
@@ -0,0 +1,85 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.display;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.Writer;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+public class DisplaySpectralKMeans extends DisplayClustering {
+ // Relative path the generated sample data is written to by main().
+ protected static final String SAMPLES = "samples";
+ // Relative path of the clustering output; paint() reads "kmeans_out" beneath it.
+ protected static final String OUTPUT = "output";
+ // Relative path handed to SpectralKMeansDriver.run() as its temp directory.
+ protected static final String TEMP = "tmp";
+ // Name of the affinity file main() writes inside the OUTPUT directory.
+ protected static final String AFFINITIES = "affinities";
+ /**
+  * Creates the spectral k-means display frame and titles it with the significance cut-off.
+  *
+  * The implicit call to the DisplayClustering constructor already runs
+  * initialize(); calling it a second time here would size and show the frame
+  * again and register a duplicate window-closing listener, so it is not repeated.
+  */
+ DisplaySpectralKMeans() {
+   setTitle("Spectral k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+ /**
+  * Regenerates the sample data, writes an "i,j,distance" line for every ordered
+  * pair of samples to the affinities file, runs SpectralKMeansDriver over that
+  * file and opens the display frame.
+  *
+  * @param args
+  *          ignored
+  */
+ public static void main(String[] args) throws Exception {
+   DistanceMeasure metric = new ManhattanDistanceMeasure();
+   Path samplePath = new Path(SAMPLES);
+   Path outputPath = new Path(OUTPUT);
+   Path tempPath = new Path(TEMP);
+   Configuration hadoopConf = new Configuration();
+   // remove whatever a previous run left behind
+   HadoopUtil.delete(hadoopConf, samplePath);
+   HadoopUtil.delete(hadoopConf, outputPath);
+   RandomUtils.useTestSeed();
+   DisplayClustering.generateSamples();
+   writeSampleData(samplePath);
+   Path affinityPath = new Path(outputPath, AFFINITIES);
+   FileSystem fs = FileSystem.get(outputPath.toUri(), hadoopConf);
+   if (!fs.exists(outputPath)) {
+     fs.mkdirs(outputPath);
+   }
+   try (Writer out = new BufferedWriter(new FileWriter(affinityPath.toString()))) {
+     // one line per ordered pair (row, col), including row == col
+     for (int row = 0; row < SAMPLE_DATA.size(); row++) {
+       for (int col = 0; col < SAMPLE_DATA.size(); col++) {
+         double distance = metric.distance(SAMPLE_DATA.get(row).get(), SAMPLE_DATA.get(col).get());
+         out.write(row + "," + col + ',' + distance + '\n');
+       }
+     }
+   }
+   int iterationLimit = 10;
+   double delta = 0.001;
+   SpectralKMeansDriver.run(new Configuration(), affinityPath, outputPath, SAMPLE_DATA.size(), 3, metric,
+       delta, iterationLimit, tempPath);
+   new DisplaySpectralKMeans();
+ }
+ @Override
+ public void paint(Graphics g) {
+ plotClusteredSampleData((Graphics2D) g, new Path(new Path(OUTPUT), "kmeans_out"));
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
new file mode 100644
index 0000000..470c16c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
@@ -0,0 +1,22 @@
+The following classes can be run without parameters to generate a sample data set and
+run the reference clustering implementations over them:
+DisplayClustering - generates 1100 samples (500 + 300 + 300) from three symmetric distributions. This is the same
+ data set that is used by the following clustering programs. It displays the points on a screen
+ and superimposes the model parameters that were used to generate the points. You can edit the
+ generateSamples() method to change the sample points used by these programs.
+ * DisplayCanopy - uses Canopy clustering
+ * DisplayKMeans - uses k-Means clustering
+ * DisplayFuzzyKMeans - uses Fuzzy k-Means clustering
+ * NOTE: some of these programs display the sample points and then superimpose all of the clusters
+ from each iteration. The last iteration's clusters are in bold red and the previous several are
+ colored (orange, yellow, green, blue, violet) in order after which all earlier clusters are in
+ light grey. This helps to visualize how the clusters converge upon a solution over multiple
+ iterations.
+ * NOTE: by changing the parameter values (k, ALPHA_0, numIterations) and the display SIGNIFICANCE
+ you can obtain different results.
\ No newline at end of file

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
new file mode 100644
index 0000000..c29cbc4
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
@@ -0,0 +1,279 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.streaming.tools;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.List;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+public class ClusterQualitySummarizer extends AbstractJob {
+ private String outputFile;
+ private PrintWriter fileOut;
+ private String trainFile;
+ private String testFile;
+ private String centroidFile;
+ private String centroidCompareFile;
+ private boolean mahoutKMeansFormat;
+ private boolean mahoutKMeansFormatCompare;
+ private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure();
+ /**
+  * Prints distance summaries for every cluster by delegating to the static overload, using this
+  * instance's {@code fileOut} writer as the CSV destination.
+  *
+  * @param summarizers one summarizer per cluster; the list index is reported as the cluster id
+  * @param type label written in the last CSV column ({@code run} passes "train" or "test")
+  */
+ public void printSummaries(List<OnlineSummarizer> summarizers, String type) {
+ printSummaries(summarizers, type, fileOut);
+ }
+ /**
+  * Prints per-cluster distance statistics to standard output and, when a writer is supplied,
+  * appends one CSV row per summarized cluster (cluster id, mean, standard deviation, the five
+  * quartile values 0..4, point count and {@code type}).
+  *
+  * Clusters holding fewer than two points are only reported on standard output: no CSV row is
+  * written for them and they do not contribute to the reported maximum distance.
+  *
+  * @param summarizers one summarizer per cluster; the list index is reported as the cluster id
+  * @param type label written in the last CSV column
+  * @param fileOut CSV destination, or {@code null} to print to standard output only
+  */
+ public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) {
+   double maxDistance = 0;
+   for (int i = 0; i < summarizers.size(); ++i) {
+     OnlineSummarizer summarizer = summarizers.get(i);
+     if (summarizer.getCount() > 1) {
+       maxDistance = Math.max(maxDistance, summarizer.getMax());
+       System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean());
+       if (fileOut != null) {
+         fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(),
+             summarizer.getSD(),
+             summarizer.getQuartile(0),
+             summarizer.getQuartile(1),
+             summarizer.getQuartile(2),
+             summarizer.getQuartile(3),
+             summarizer.getQuartile(4), summarizer.getCount(), type);
+       }
+     } else {
+       // Too few points to summarize: report the cluster and skip it. No fallback values are
+       // written to the CSV for such clusters.
+       System.out.printf("Cluster %d has %d data point(s). Need at least 2 data points in a cluster for"
+           + " OnlineSummarizer.\n", i, summarizer.getCount());
+     }
+   }
+   System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance);
+ }
+ public int run(String[] args) throws IOException {
+ if (!parseArgs(args)) {
+ return -1;
+ }
+ Configuration conf = new Configuration();
+ try {
+ fileOut = new PrintWriter(new FileOutputStream(outputFile));
+ fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3,"
+ + "distance.q4,count,is.train\n");
+ // Reading in the centroids (both pairs, if they exist).
+ List<Centroid> centroids;
+ List<Centroid> centroidsCompare = null;
+ if (mahoutKMeansFormat) {
+ SequenceFileDirValueIterable<ClusterWritable> clusterIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
+ centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable));
+ } else {
+ SequenceFileDirValueIterable<CentroidWritable> centroidIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
+ centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable));
+ }
+ if (centroidCompareFile != null) {
+ if (mahoutKMeansFormatCompare) {
+ SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
+ centroidsCompare = Lists.newArrayList(
+ IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable));
+ } else {
+ SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
+ centroidsCompare = Lists.newArrayList(
+ IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable));
+ }
+ }
+ // Reading in the "training" set.
+ SequenceFileDirValueIterable<VectorWritable> trainIterable =
+ new SequenceFileDirValueIterable<>(new Path(trainFile), PathType.GLOB, conf);
+ Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable);
+ Iterable<Vector> datapoints = trainDatapoints;
+ printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids,
+ new SquaredEuclideanDistanceMeasure()), "train");
+ // Also adding in the "test" set.
+ if (testFile != null) {
+ SequenceFileDirValueIterable<VectorWritable> testIterable =
+ new SequenceFileDirValueIterable<>(new Path(testFile), PathType.GLOB, conf);
+ Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable);
+ printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids,
+ new SquaredEuclideanDistanceMeasure()), "test");
+ datapoints = Iterables.concat(trainDatapoints, testDatapoints);
+ }
+ // At this point, all train/test CSVs have been written. We now compute quality metrics.
+ List<OnlineSummarizer> summaries =
+ ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure);
+ List<OnlineSummarizer> compareSummaries = null;
+ if (centroidsCompare != null) {
+ compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure);
+ }
+ System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries));
+ if (compareSummaries != null) {
+ System.out.printf(" Second: %f\n", ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries));
+ } else {
+ System.out.printf("\n");
+ }
+ System.out.printf("[Davies-Bouldin Index] First: %f",
+ ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries));
+ if (compareSummaries != null) {
+ System.out.printf(" Second: %f\n",
+ ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries));
+ } else {
+ System.out.printf("\n");
+ }
+ } catch (IOException e) {
+ System.out.println(e.getMessage());
+ } finally {
+ Closeables.close(fileOut, false);
+ }
+ return 0;
+ }
+ /**
+  * Parses the command-line arguments and stores the results in this instance's fields.
+  *
+  * Options marked required: input (i), centroids (c), output (o). The remaining options --
+  * testInput (itest), centroidsCompare (cc), mahoutkmeansformat (mkm) and
+  * mahoutkmeansformatCompare (mkmc) -- are only read when present.
+  *
+  * @param args raw command-line arguments
+  * @return {@code false} when the parser yields no command line (presumably because help was
+  *         requested or parsing failed -- confirm against commons-cli2 Parser.parseAndHelp);
+  *         {@code true} once the fields have been populated
+  */
+ private boolean parseArgs(String[] args) {
+ // The same option and argument builders are reused for every option declared below.
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+ Option help = builder.withLongName("help").withDescription("print this list").create();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder.withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("where to get seq files with the vectors (training set)")
+ .create();
+ Option testInputFileOption = builder.withLongName("testInput")
+ .withShortName("itest")
+ .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create())
+ .withDescription("where to get seq files with the vectors (test set)")
+ .create();
+ Option centroidsFileOption = builder.withLongName("centroids")
+ .withShortName("c")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create())
+ .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)")
+ .create();
+ Option centroidsCompareFileOption = builder.withLongName("centroidsCompare")
+ .withShortName("cc")
+ .withRequired(false)
+ .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create())
+ .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or "
+ + "StreamingKMeansDriver)")
+ .create();
+ Option outputFileOption = builder.withLongName("output")
+ .withShortName("o")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .withDescription("where to dump the CSV file with the results")
+ .create();
+ // NOTE(review): both format flags declare an argument named "numpoints" although only their
+ // presence is checked further down -- looks like a copy-paste leftover; confirm.
+ Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat")
+ .withShortName("mkm")
+ .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
+ .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
+ .create();
+ Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare")
+ .withShortName("mkmc")
+ .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
+ .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
+ .create();
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(inputFileOption)
+ .withOption(testInputFileOption)
+ .withOption(outputFileOption)
+ .withOption(centroidsFileOption)
+ .withOption(centroidsCompareFileOption)
+ .withOption(mahoutKMeansFormatOption)
+ .withOption(mahoutKMeansCompareFormatOption)
+ .create();
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+ if (cmdLine == null) {
+ return false;
+ }
+ // Copy the parsed values into the fields; optional values stay untouched when absent.
+ trainFile = (String) cmdLine.getValue(inputFileOption);
+ if (cmdLine.hasOption(testInputFileOption)) {
+ testFile = (String) cmdLine.getValue(testInputFileOption);
+ }
+ centroidFile = (String) cmdLine.getValue(centroidsFileOption);
+ if (cmdLine.hasOption(centroidsCompareFileOption)) {
+ centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption);
+ }
+ outputFile = (String) cmdLine.getValue(outputFileOption);
+ // The format flags are only ever switched on here; when a flag is absent the field keeps
+ // whatever value it already had.
+ if (cmdLine.hasOption(mahoutKMeansFormatOption)) {
+ mahoutKMeansFormat = true;
+ }
+ if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) {
+ mahoutKMeansFormatCompare = true;
+ }
+ return true;
+ }
+ /**
+  * Command-line entry point: creates a summarizer and runs it with the given arguments.
+  * The status code returned by {@code run} is discarded, so the process exit code does not
+  * reflect it.
+  *
+  * @param args command-line arguments forwarded to {@code run}
+  * @throws IOException if {@code run} propagates an I/O failure
+  */
+ public static void main(String[] args) throws IOException {
+ new ClusterQualitySummarizer().run(args);
+ }
2018-06-27 13:14:48 UTC
diff --git a/community/mahout-mr/examples/bin/resources/bank-full.csv b/community/mahout-mr/examples/bin/resources/bank-full.csv
new file mode 100644
index 0000000..d7a2ede
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/bank-full.csv
@@ -0,0 +1,45212 @@

2018-06-27 13:14:49 UTC
diff --git a/community/mahout-mr/examples/bin/cluster-syntheticcontrol.sh b/community/mahout-mr/examples/bin/cluster-syntheticcontrol.sh
new file mode 100755
index 0000000..796da33
--- /dev/null
+++ b/community/mahout-mr/examples/bin/cluster-syntheticcontrol.sh
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Downloads the Synthetic control dataset and prepares it for clustering
+# To run: change into the mahout directory and type:
+# examples/bin/cluster-syntheticcontrol.sh
+# NOTE(review): several closing/branch lines (fi, else) are absent from this copy of the
+# script; restore them from the repository before running.
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script clusters the Synthetic Control data set. The data set is downloaded automatically."
+ exit
+# Supported algorithms; the 1-based choice below indexes into this array.
+algorithm=( kmeans fuzzykmeans )
+# Take the choice from the first argument when given; the prompt lines that follow
+# presumably belong to the missing else branch.
+if [ -n "$1" ]; then
+ choice=$1
+ echo "Please select a number to choose the corresponding clustering algorithm"
+ echo "1. ${algorithm[0]} clustering"
+ echo "2. ${algorithm[1]} clustering"
+ read -p "Enter your choice : " choice
+echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]} Clustering"
+# NOTE(review): ${clustertype} is used further down but never assigned in the lines shown
+# here -- confirm it is set from ${algorithm[$choice-1]}.
+if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
+# Set commands for dfs
+source ${START_PATH}/set-dfs-commands.sh
+# Fall back to a per-user directory under /tmp when MAHOUT_WORK_DIR is unset.
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-${USER}
+echo "creating work directory at ${WORK_DIR}"
+mkdir -p ${WORK_DIR}
+# Obtain the data set unless it is already present: copy it from $2 when a second
+# argument is given, otherwise download it.
+if [ ! -f ${WORK_DIR}/synthetic_control.data ]; then
+ if [ -n "$2" ]; then
+ cp $2 ${WORK_DIR}/.
+ else
+ echo "Downloading Synthetic control data"
+ curl http://archive.ics.uci.edu/ml/databases/synthetic_control/synthetic_control.data -o ${WORK_DIR}/synthetic_control.data
+ fi
+# Give up when the data file is still missing.
+if [ ! -f ${WORK_DIR}/synthetic_control.data ]; then
+ echo "Couldn't download synthetic control"
+ exit 1
+# Cluster mode: HADOOP_HOME is set and MAHOUT_LOCAL is empty.
+if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ]; then
+ echo "Checking the health of DFS..."
+ $DFS -ls /
+ if [ $? -eq 0 ];then
+ echo "DFS is healthy... "
+ echo "Uploading Synthetic control data to HDFS"
+ $DFSRM ${WORK_DIR}/testdata
+ $DFS -mkdir -p ${WORK_DIR}/testdata
+ $DFS -put ${WORK_DIR}/synthetic_control.data ${WORK_DIR}/testdata
+ echo "Successfully Uploaded Synthetic control data to HDFS "
+ options="--input ${WORK_DIR}/testdata --output ${WORK_DIR}/output --maxIter 10 --convergenceDelta 0.5"
+ # NOTE(review): the launcher is invoked as ../../bin/mahout.bu relative to the current
+ # directory -- confirm the intended launcher and working directory.
+ if [ "${clustertype}" == "kmeans" ]; then
+ options="${options} --numClusters 6"
+ # t1 & t2 not used if --numClusters specified, but parser requires input
+ options="${options} --t1 1 --t2 2"
+ ../../bin/mahout.bu org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job ${options}
+ else
+ options="${options} --m 2.0f --t1 80 --t2 55"
+ ../../bin/mahout.bu org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job ${options}
+ fi
+ else
+ echo " HADOOP is not running. Please make sure you hadoop is running. "
+ fi
+# Local mode: stage the data as ./testdata, run the job without options, then remove the copy.
+elif [ "$MAHOUT_LOCAL" != "" ]; then
+ echo "running MAHOUT_LOCAL"
+ cp ${WORK_DIR}/synthetic_control.data testdata
+ ../../bin/mahout.bu org.apache.mahout.clustering.syntheticcontrol."${clustertype}".Job
+ rm testdata
+ echo " HADOOP_HOME variable is not set. Please set this environment variable and rerun the script"
+# Remove the work directory
+rm -rf ${WORK_DIR}

diff --git a/community/mahout-mr/examples/bin/factorize-movielens-1M.sh b/community/mahout-mr/examples/bin/factorize-movielens-1M.sh
new file mode 100755
index 0000000..29730e1
--- /dev/null
+++ b/community/mahout-mr/examples/bin/factorize-movielens-1M.sh
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Instructions:
+# Before using this script, you have to download and extract the Movielens 1M dataset
+# from http://www.grouplens.org/node/73
+# To run: change into the mahout directory and type:
+# export MAHOUT_LOCAL=true
+# Then:
+# examples/bin/factorize-movielens-1M.sh /path/to/ratings.dat
+# NOTE(review): several closing/branch lines (then, fi, else) are absent from this copy of
+# the script, and $MAHOUT is not assigned in the lines shown; restore from the repository
+# before running.
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script runs the Alternating Least Squares Recommender on the Grouplens data set (size 1M)."
+ echo "Syntax: $0 /path/to/ratings.dat\n"
+ exit
+# Exactly one argument (the path to ratings.dat) is expected.
+if [ $# -ne 1 ]
+ echo -e "\nYou have to download the Movielens 1M dataset from http://www.grouplens.org/node/73 before"
+ echo -e "you can run this example. After that extract it and supply the path to the ratings.dat file.\n"
+ echo -e "Syntax: $0 /path/to/ratings.dat\n"
+ exit -1
+export MAHOUT_LOCAL=true
+# Fall back to a per-user directory under /tmp when MAHOUT_WORK_DIR is unset.
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-${USER}
+echo "creating work directory at ${WORK_DIR}"
+mkdir -p ${WORK_DIR}/movielens
+echo "Converting ratings..."
+# Replace the '::' field separator with commas and keep only the first three fields.
+cat $1 |sed -e s/::/,/g| cut -d, -f1,2,3 > ${WORK_DIR}/movielens/ratings.csv
+# create a 90% training set and a 10% probe set
+$MAHOUT splitDataset --input ${WORK_DIR}/movielens/ratings.csv --output ${WORK_DIR}/dataset \
+ --trainingPercentage 0.9 --probePercentage 0.1 --tempDir ${WORK_DIR}/dataset/tmp
+# run distributed ALS-WR to factorize the rating matrix defined by the training set
+$MAHOUT parallelALS --input ${WORK_DIR}/dataset/trainingSet/ --output ${WORK_DIR}/als/out \
+ --tempDir ${WORK_DIR}/als/tmp --numFeatures 20 --numIterations 10 --lambda 0.065 --numThreadsPerSolver 2
+# compute predictions against the probe set, measure the error
+$MAHOUT evaluateFactorization --input ${WORK_DIR}/dataset/probeSet/ --output ${WORK_DIR}/als/rmse/ \
+ --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ --tempDir ${WORK_DIR}/als/tmp
+# compute recommendations
+$MAHOUT recommendfactorized --input ${WORK_DIR}/als/out/userRatings/ --output ${WORK_DIR}/recommendations/ \
+ --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ \
+ --numRecommendations 6 --maxRating 5 --numThreads 2
+# print the error
+echo -e "\nRMSE is:\n"
+cat ${WORK_DIR}/als/rmse/rmse.txt
+echo -e "\n"
+# print a few randomly chosen lines from the first recommendations part file
+echo -e "\nSample recommendations:\n"
+shuf ${WORK_DIR}/recommendations/part-m-00000 |head
+echo -e "\n\n"
+echo "removing work directory"
+rm -rf ${WORK_DIR}

diff --git a/community/mahout-mr/examples/bin/factorize-netflix.sh b/community/mahout-mr/examples/bin/factorize-netflix.sh
new file mode 100755
index 0000000..26faf66
--- /dev/null
+++ b/community/mahout-mr/examples/bin/factorize-netflix.sh
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Instructions:
+# You can only use this script in conjunction with the Netflix dataset. Unpack the Netflix dataset and provide the
+# following:
+# 1) the path to the folder 'training_set' that contains all the movie rating files
+# 2) the path to the file 'qualifying.txt' that contains the user,item pairs to predict
+# 3) the path to the file 'judging.txt' that contains the ratings of user,item pairs to predict for
+# To run:
+# ./factorize-netflix.sh /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt
+# The script is deprecated: it prints a notice and exits with status 1 right away, so
+# nothing below the exit is executed.
+echo "Note this script has been deprecated due to the lack of access to the Netflix data set."
+exit 1
+# NOTE(review): several closing/branch lines (then, fi, else) are absent from this copy of
+# the script, and $MAHOUT is not assigned in the lines shown.
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script runs the ALS Recommender on the Netflix data set."
+ echo "Syntax: $0 /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt\n"
+ exit
+# Exactly three arguments are expected.
+if [ $# -ne 3 ]
+ echo -e "Syntax: $0 /path/to/training_set/ /path/to/qualifying.txt /path/to/judging.txt\n"
+ exit -1
+# Fall back to a per-user directory under /tmp when MAHOUT_WORK_DIR is unset.
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-${USER}
+# Set commands for dfs
+source ${START_PATH}/set-dfs-commands.sh
+echo "Preparing data..."
+$MAHOUT org.apache.mahout.cf.taste.hadoop.example.als.netflix.NetflixDatasetConverter $1 $2 $3 ${WORK_DIR}
+# run distributed ALS-WR to factorize the rating matrix defined by the training set
+$MAHOUT parallelALS --input ${WORK_DIR}/trainingSet/ratings.tsv --output ${WORK_DIR}/als/out \
+ --tempDir ${WORK_DIR}/als/tmp --numFeatures 25 --numIterations 10 --lambda 0.065 --numThreadsPerSolver 4
+# compute predictions against the probe set, measure the error
+$MAHOUT evaluateFactorization --input ${WORK_DIR}/probeSet/ratings.tsv --output ${WORK_DIR}/als/rmse/ \
+ --userFeatures ${WORK_DIR}/als/out/U/ --itemFeatures ${WORK_DIR}/als/out/M/ --tempDir ${WORK_DIR}/als/tmp
+# Cluster mode reads the RMSE file through $DFS; the second group of lines below reads it
+# with cat and presumably forms the local (else) branch.
+if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
+ # print the error, should be around 0.923
+ echo -e "\nRMSE is:\n"
+ $DFS -tail ${WORK_DIR}/als/rmse/rmse.txt
+ echo -e "\n"
+ echo "removing work directory"
+ set +e
+ # print the error, should be around 0.923
+ echo -e "\nRMSE is:\n"
+ cat ${WORK_DIR}/als/rmse/rmse.txt
+ echo -e "\n"
+ echo "removing work directory"
+ rm -rf ${WORK_DIR}

diff --git a/community/mahout-mr/examples/bin/get-all-examples.sh b/community/mahout-mr/examples/bin/get-all-examples.sh
new file mode 100755
index 0000000..4128e47
--- /dev/null
+++ b/community/mahout-mr/examples/bin/get-all-examples.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Clones Mahout example code from remote repositories with their own
+# build process. Follow the README for each example for instructions.
+# Usage: change into the mahout directory and type:
+# examples/bin/get-all-examples.sh
+# Solr-recommender
+echo " Solr-recommender example: "
+echo " 1) imports text 'log files' of some delimited form for user preferences"
+echo " 2) creates the correct Mahout files and stores distionaries to translate external Id to and from Mahout Ids"
+echo " 3) it implements a prototype two actions 'cross-recommender', which takes two actions made by the same user and creates recommendations"
+echo " 4) it creates output for user->preference history CSV and and item->similar items 'similarity' matrix for use in a Solr-recommender."
+echo " To use Solr you would index the similarity matrix CSV, and use user preference history from the history CSV as a query, the result"
+echo " from Solr will be an ordered list of recommendations returning the same item Ids as were input."
+echo " For further description see the README.md here https://github.com/pferrel/solr-recommender"
+echo " To build run 'cd solr-recommender; mvn install'"
+echo " To process the example after building make sure MAHOUT_LOCAL IS SET and hadoop is in local mode then "
+echo " run 'cd scripts; ./solr-recommender-example'"
+git clone https://github.com/pferrel/solr-recommender

diff --git a/community/mahout-mr/examples/bin/lda.algorithm b/community/mahout-mr/examples/bin/lda.algorithm
new file mode 100644
index 0000000..fb84ea0
--- /dev/null
+++ b/community/mahout-mr/examples/bin/lda.algorithm
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# task at this depth or less would print when they start
+# --------- alg
+{ "BuildReuters"
+ CreateIndex
+ { "AddDocs" AddDoc > : *
+# Optimize
+ CloseIndex
2018-06-27 13:14:50 UTC
MAHOUT-2034 Split MR and New Examples into seperate modules

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/02f75f99
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/02f75f99
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/02f75f99

Branch: refs/heads/branch-0.14.0
Commit: 02f75f997bbc01083a345287072e821bfe4f1558
Parents: aa57e2f
Author: Trevor a.k.a @rawkintrevo <***@gmail.com>
Authored: Wed Jun 27 08:13:16 2018 -0500
Committer: Trevor a.k.a @rawkintrevo <***@gmail.com>
Committed: Wed Jun 27 08:13:16 2018 -0500

bin/load-shell.scala | 2 +-
bin/mahout | 196 +-
bin/mahout.bu | 395 +
community/mahout-mr/bin/mahout | 395 +
community/mahout-mr/bin/mahout.cmd | 397 +
community/mahout-mr/examples/bin/README.txt | 13 +
.../examples/bin/classify-20newsgroups.sh | 197 +
.../examples/bin/classify-wikipedia.sh | 196 +
.../mahout-mr/examples/bin/cluster-reuters.sh | 203 +
.../examples/bin/cluster-syntheticcontrol.sh | 105 +
.../examples/bin/factorize-movielens-1M.sh | 85 +
.../mahout-mr/examples/bin/factorize-netflix.sh | 90 +
.../mahout-mr/examples/bin/get-all-examples.sh | 36 +
community/mahout-mr/examples/bin/lda.algorithm | 45 +
.../examples/bin/resources/bank-full.csv | 45212 +++++++++++++++++
.../examples/bin/resources/country.txt | 229 +
.../examples/bin/resources/country10.txt | 10 +
.../examples/bin/resources/country2.txt | 2 +
.../examples/bin/resources/donut-test.csv | 41 +
.../mahout-mr/examples/bin/resources/donut.csv | 41 +
.../examples/bin/resources/test-data.csv | 61 +
.../mahout-mr/examples/bin/set-dfs-commands.sh | 54 +
community/mahout-mr/examples/pom.xml | 199 +
.../examples/src/main/assembly/job.xml | 46 +
.../cf/taste/example/TasteOptionParser.java | 75 +
.../BookCrossingBooleanRecommender.java | 102 +
.../BookCrossingBooleanRecommenderBuilder.java | 32 +
...ossingBooleanRecommenderEvaluatorRunner.java | 59 +
.../bookcrossing/BookCrossingDataModel.java | 99 +
.../BookCrossingDataModelBuilder.java | 33 +
.../bookcrossing/BookCrossingRecommender.java | 101 +
.../BookCrossingRecommenderBuilder.java | 32 +
.../BookCrossingRecommenderEvaluatorRunner.java | 54 +
.../mahout/cf/taste/example/bookcrossing/README | 9 +
.../cf/taste/example/email/EmailUtility.java | 104 +
.../email/FromEmailToDictionaryMapper.java | 61 +
.../example/email/MailToDictionaryReducer.java | 43 +
.../taste/example/email/MailToPrefsDriver.java | 274 +
.../cf/taste/example/email/MailToRecMapper.java | 101 +
.../taste/example/email/MailToRecReducer.java | 53 +
.../example/email/MsgIdToDictionaryMapper.java | 49 +
.../taste/example/kddcup/DataFileIterable.java | 44 +
.../taste/example/kddcup/DataFileIterator.java | 158 +
.../taste/example/kddcup/KDDCupDataModel.java | 231 +
.../mahout/cf/taste/example/kddcup/ToCSV.java | 77 +
.../kddcup/track1/EstimateConverter.java | 43 +
.../example/kddcup/track1/Track1Callable.java | 67 +
.../kddcup/track1/Track1Recommender.java | 94 +
.../kddcup/track1/Track1RecommenderBuilder.java | 32 +
.../track1/Track1RecommenderEvaluator.java | 108 +
.../Track1RecommenderEvaluatorRunner.java | 56 +
.../example/kddcup/track1/Track1Runner.java | 95 +
.../svd/DataModelFactorizablePreferences.java | 107 +
.../track1/svd/FactorizablePreferences.java | 44 +
.../svd/KDDCupFactorizablePreferences.java | 123 +
.../track1/svd/ParallelArraysSGDFactorizer.java | 265 +
.../kddcup/track1/svd/Track1SVDRunner.java | 141 +
.../example/kddcup/track2/HybridSimilarity.java | 62 +
.../example/kddcup/track2/Track2Callable.java | 106 +
.../kddcup/track2/Track2Recommender.java | 100 +
.../kddcup/track2/Track2RecommenderBuilder.java | 33 +
.../example/kddcup/track2/Track2Runner.java | 100 +
.../taste/example/kddcup/track2/TrackData.java | 71 +
.../kddcup/track2/TrackItemSimilarity.java | 106 +
.../taste/example/kddcup/track2/UserResult.java | 54 +
.../als/netflix/NetflixDatasetConverter.java | 140 +
.../example/BatchItemSimilaritiesGroupLens.java | 65 +
.../precompute/example/GroupLensDataModel.java | 96 +
.../mahout/classifier/NewsgroupHelper.java | 128 +
.../classifier/email/PrepEmailMapper.java | 65 +
.../classifier/email/PrepEmailReducer.java | 47 +
.../email/PrepEmailVectorsDriver.java | 76 +
.../sequencelearning/hmm/PosTagger.java | 277 +
.../sgd/AdaptiveLogisticModelParameters.java | 236 +
.../classifier/sgd/LogisticModelParameters.java | 265 +
.../classifier/sgd/PrintResourceOrFile.java | 42 +
.../classifier/sgd/RunAdaptiveLogistic.java | 197 +
.../mahout/classifier/sgd/RunLogistic.java | 163 +
.../apache/mahout/classifier/sgd/SGDHelper.java | 151 +
.../apache/mahout/classifier/sgd/SGDInfo.java | 59 +
.../classifier/sgd/SimpleCsvExamples.java | 283 +
.../mahout/classifier/sgd/TestASFEmail.java | 152 +
.../mahout/classifier/sgd/TestNewsGroups.java | 141 +
.../mahout/classifier/sgd/TrainASFEmail.java | 137 +
.../classifier/sgd/TrainAdaptiveLogistic.java | 377 +
.../mahout/classifier/sgd/TrainLogistic.java | 311 +
.../mahout/classifier/sgd/TrainNewsGroups.java | 154 +
.../sgd/ValidateAdaptiveLogistic.java | 218 +
.../BankMarketingClassificationMain.java | 70 +
.../sgd/bankmarketing/TelephoneCall.java | 104 +
.../sgd/bankmarketing/TelephoneCallParser.java | 66 +
.../clustering/display/ClustersFilter.java | 31 +
.../clustering/display/DisplayCanopy.java | 88 +
.../clustering/display/DisplayClustering.java | 374 +
.../clustering/display/DisplayFuzzyKMeans.java | 110 +
.../clustering/display/DisplayKMeans.java | 106 +
.../display/DisplaySpectralKMeans.java | 85 +
.../apache/mahout/clustering/display/README.txt | 22 +
.../tools/ClusterQualitySummarizer.java | 279 +
.../clustering/streaming/tools/IOUtils.java | 80 +
.../clustering/syntheticcontrol/canopy/Job.java | 125 +
.../syntheticcontrol/fuzzykmeans/Job.java | 144 +
.../clustering/syntheticcontrol/kmeans/Job.java | 187 +
.../fpm/pfpgrowth/DeliciousTagsExample.java | 94 +
.../dataset/KeyBasedStringTupleCombiner.java | 40 +
.../dataset/KeyBasedStringTupleGrouper.java | 77 +
.../dataset/KeyBasedStringTupleMapper.java | 90 +
.../dataset/KeyBasedStringTupleReducer.java | 74 +
.../examples/src/main/resources/bank-full.csv | 45212 +++++++++++++++++
.../src/main/resources/cf-data-purchase.txt | 7 +
.../src/main/resources/cf-data-view.txt | 12 +
.../examples/src/main/resources/donut-test.csv | 41 +
.../examples/src/main/resources/donut.csv | 41 +
.../examples/src/main/resources/test-data.csv | 61 +
.../sgd/LogisticModelParametersTest.java | 43 +
.../classifier/sgd/ModelDissectorTest.java | 40 +
.../classifier/sgd/TrainLogisticTest.java | 167 +
.../clustering/display/ClustersFilterTest.java | 75 +
.../apache/mahout/examples/MahoutTestCase.java | 30 +
.../examples/src/test/resources/country.txt | 229 +
.../examples/src/test/resources/country10.txt | 10 +
.../examples/src/test/resources/country2.txt | 2 +
.../examples/src/test/resources/subjects.txt | 2 +
.../examples/src/test/resources/wdbc.infos | 32 +
.../examples/src/test/resources/wdbc/wdbc.data | 569 +
community/mahout-mr/pom.xml | 4 +
community/spark-cli-drivers/pom.xml | 21 +
.../src/main/assembly/dependency-reduced.xml | 51 +
.../src/main/assembly/dependency-reduced.xml | 2 +-
examples/bin/README.txt | 13 -
examples/bin/basicOLS.scala | 61 +
examples/bin/cco-lastfm.scala | 112 +
examples/bin/classify-20newsgroups.sh | 197 -
examples/bin/classify-wikipedia.sh | 196 -
examples/bin/cluster-reuters.sh | 203 -
examples/bin/cluster-syntheticcontrol.sh | 105 -
examples/bin/factorize-movielens-1M.sh | 85 -
examples/bin/factorize-netflix.sh | 90 -
examples/bin/get-all-examples.sh | 36 -
examples/bin/lda.algorithm | 45 -
examples/bin/resources/bank-full.csv | 45212 -----------------
examples/bin/resources/country.txt | 229 -
examples/bin/resources/country10.txt | 10 -
examples/bin/resources/country2.txt | 2 -
examples/bin/resources/donut-test.csv | 41 -
examples/bin/resources/donut.csv | 41 -
examples/bin/resources/test-data.csv | 61 -
examples/bin/run-item-sim.sh | 6 +-
examples/bin/set-dfs-commands.sh | 54 -
examples/pom.xml | 173 +-
examples/src/main/assembly/job.xml | 46 -
.../cf/taste/example/TasteOptionParser.java | 75 -
.../BookCrossingBooleanRecommender.java | 102 -
.../BookCrossingBooleanRecommenderBuilder.java | 32 -
...ossingBooleanRecommenderEvaluatorRunner.java | 59 -
.../bookcrossing/BookCrossingDataModel.java | 99 -
.../BookCrossingDataModelBuilder.java | 33 -
.../bookcrossing/BookCrossingRecommender.java | 101 -
.../BookCrossingRecommenderBuilder.java | 32 -
.../BookCrossingRecommenderEvaluatorRunner.java | 54 -
.../mahout/cf/taste/example/bookcrossing/README | 9 -
.../cf/taste/example/email/EmailUtility.java | 104 -
.../email/FromEmailToDictionaryMapper.java | 61 -
.../example/email/MailToDictionaryReducer.java | 43 -
.../taste/example/email/MailToPrefsDriver.java | 274 -
.../cf/taste/example/email/MailToRecMapper.java | 101 -
.../taste/example/email/MailToRecReducer.java | 53 -
.../example/email/MsgIdToDictionaryMapper.java | 49 -
.../taste/example/kddcup/DataFileIterable.java | 44 -
.../taste/example/kddcup/DataFileIterator.java | 158 -
.../taste/example/kddcup/KDDCupDataModel.java | 231 -
.../mahout/cf/taste/example/kddcup/ToCSV.java | 77 -
.../kddcup/track1/EstimateConverter.java | 43 -
.../example/kddcup/track1/Track1Callable.java | 67 -
.../kddcup/track1/Track1Recommender.java | 94 -
.../kddcup/track1/Track1RecommenderBuilder.java | 32 -
.../track1/Track1RecommenderEvaluator.java | 108 -
.../Track1RecommenderEvaluatorRunner.java | 56 -
.../example/kddcup/track1/Track1Runner.java | 95 -
.../svd/DataModelFactorizablePreferences.java | 107 -
.../track1/svd/FactorizablePreferences.java | 44 -
.../svd/KDDCupFactorizablePreferences.java | 123 -
.../track1/svd/ParallelArraysSGDFactorizer.java | 265 -
.../kddcup/track1/svd/Track1SVDRunner.java | 141 -
.../example/kddcup/track2/HybridSimilarity.java | 62 -
.../example/kddcup/track2/Track2Callable.java | 106 -
.../kddcup/track2/Track2Recommender.java | 100 -
.../kddcup/track2/Track2RecommenderBuilder.java | 33 -
.../example/kddcup/track2/Track2Runner.java | 100 -
.../taste/example/kddcup/track2/TrackData.java | 71 -
.../kddcup/track2/TrackItemSimilarity.java | 106 -
.../taste/example/kddcup/track2/UserResult.java | 54 -
.../als/netflix/NetflixDatasetConverter.java | 140 -
.../example/BatchItemSimilaritiesGroupLens.java | 65 -
.../precompute/example/GroupLensDataModel.java | 96 -
.../mahout/classifier/NewsgroupHelper.java | 128 -
.../classifier/email/PrepEmailMapper.java | 65 -
.../classifier/email/PrepEmailReducer.java | 47 -
.../email/PrepEmailVectorsDriver.java | 76 -
.../sequencelearning/hmm/PosTagger.java | 277 -
.../sgd/AdaptiveLogisticModelParameters.java | 236 -
.../classifier/sgd/LogisticModelParameters.java | 265 -
.../classifier/sgd/PrintResourceOrFile.java | 42 -
.../classifier/sgd/RunAdaptiveLogistic.java | 197 -
.../mahout/classifier/sgd/RunLogistic.java | 163 -
.../apache/mahout/classifier/sgd/SGDHelper.java | 151 -
.../apache/mahout/classifier/sgd/SGDInfo.java | 59 -
.../classifier/sgd/SimpleCsvExamples.java | 283 -
.../mahout/classifier/sgd/TestASFEmail.java | 152 -
.../mahout/classifier/sgd/TestNewsGroups.java | 141 -
.../mahout/classifier/sgd/TrainASFEmail.java | 137 -
.../classifier/sgd/TrainAdaptiveLogistic.java | 377 -
.../mahout/classifier/sgd/TrainLogistic.java | 311 -
.../mahout/classifier/sgd/TrainNewsGroups.java | 154 -
.../sgd/ValidateAdaptiveLogistic.java | 218 -
.../BankMarketingClassificationMain.java | 70 -
.../sgd/bankmarketing/TelephoneCall.java | 104 -
.../sgd/bankmarketing/TelephoneCallParser.java | 66 -
.../clustering/display/ClustersFilter.java | 31 -
.../clustering/display/DisplayCanopy.java | 88 -
.../clustering/display/DisplayClustering.java | 374 -
.../clustering/display/DisplayFuzzyKMeans.java | 110 -
.../clustering/display/DisplayKMeans.java | 106 -
.../display/DisplaySpectralKMeans.java | 85 -
.../apache/mahout/clustering/display/README.txt | 22 -
.../tools/ClusterQualitySummarizer.java | 279 -
.../clustering/streaming/tools/IOUtils.java | 80 -
.../clustering/syntheticcontrol/canopy/Job.java | 125 -
.../syntheticcontrol/fuzzykmeans/Job.java | 144 -
.../clustering/syntheticcontrol/kmeans/Job.java | 187 -
.../fpm/pfpgrowth/DeliciousTagsExample.java | 94 -
.../dataset/KeyBasedStringTupleCombiner.java | 40 -
.../dataset/KeyBasedStringTupleGrouper.java | 77 -
.../dataset/KeyBasedStringTupleMapper.java | 90 -
.../dataset/KeyBasedStringTupleReducer.java | 74 -
examples/src/main/resources/bank-full.csv | 45212 -----------------
.../src/main/resources/cf-data-purchase.txt | 7 -
examples/src/main/resources/cf-data-view.txt | 12 -
examples/src/main/resources/donut-test.csv | 41 -
examples/src/main/resources/donut.csv | 41 -
examples/src/main/resources/test-data.csv | 61 -
.../sgd/LogisticModelParametersTest.java | 43 -
.../classifier/sgd/ModelDissectorTest.java | 40 -
.../classifier/sgd/TrainLogisticTest.java | 167 -
.../clustering/display/ClustersFilterTest.java | 75 -
.../apache/mahout/examples/MahoutTestCase.java | 30 -
examples/src/test/resources/country.txt | 229 -
examples/src/test/resources/country10.txt | 10 -
examples/src/test/resources/country2.txt | 2 -
examples/src/test/resources/subjects.txt | 2 -
examples/src/test/resources/wdbc.infos | 32 -
examples/src/test/resources/wdbc/wdbc.data | 569 -
pom.xml | 4 +-
253 files changed, 104613 insertions(+), 103131 deletions(-)

diff --git a/bin/load-shell.scala b/bin/load-shell.scala
index 7468b76..f60705c 100644
--- a/bin/load-shell.scala
+++ b/bin/load-shell.scala
@@ -29,6 +29,6 @@ println("""
_ __ ___ __ _| |__ ___ _ _| |_
'_ ` _ \ / _` | '_ \ / _ \| | | | __|
| | | | | (_| | | | | (_) | |_| | |_
-_| |_| |_|\__,_|_| |_|\___/ \__,_|\__| version 0.13.0
+_| |_| |_|\__,_|_| |_|\___/ \__,_|\__| version 0.14.0

\ No newline at end of file

diff --git a/bin/mahout b/bin/mahout
index 3017c9e..fd40fe0 100755
--- a/bin/mahout
+++ b/bin/mahout
@@ -57,6 +57,8 @@ case "`uname`" in
CYGWIN*) cygwin=true;;

+# Check that mahout home is set, if not set it to one dir up.
# resolve links - $0 may be a softlink
while [ -h "$THIS" ]; do
@@ -123,6 +125,13 @@ if [ "$JAVA_HOME" = "" ]; then
exit 1

+if [ "$SPARK" = "1" ]; then
+ if [ "$SPARK_HOME" = "" ]; then
+ echo "Error: SPARK_HOME is not set."
+ exit 1
+ fi

@@ -133,53 +142,57 @@ if [ "$MAHOUT_HEAPSIZE" != "" ]; then

-if [ "x$MAHOUT_CONF_DIR" = "x" ]; then
- if [ -d $MAHOUT_HOME/src/conf ]; then
- else
- if [ -d $MAHOUT_HOME/conf ]; then
- else
- echo No MAHOUT_CONF_DIR found
- fi
- fi
+#if [ "x$MAHOUT_CONF_DIR" = "x" ]; then
+# if [ -d $MAHOUT_HOME/src/conf ]; then
+# else
+# if [ -d $MAHOUT_HOME/conf ]; then
+# else
+# echo No MAHOUT_CONF_DIR found
+# fi
+# fi

# CLASSPATH initially contains $MAHOUT_CONF_DIR, or defaults to $MAHOUT_HOME/src/conf

-if [ "$MAHOUT_LOCAL" != "" ]; then
- echo "MAHOUT_LOCAL is set, so we don't add HADOOP_CONF_DIR to classpath."
-elif [ -n "$HADOOP_CONF_DIR" ] ; then
- echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."
+#if [ "$MAHOUT_LOCAL" != "" ]; then
+# echo "MAHOUT_LOCAL is set, so we don't add HADOOP_CONF_DIR to classpath."
+#elif [ -n "$HADOOP_CONF_DIR" ] ; then
+# echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."


# so that filenames w/ spaces are handled correctly in loops below

if [ $IS_CORE == 0 ]
# add release dependencies to CLASSPATH
- for f in $MAHOUT_HOME/mahout-*.jar; do
+ echo "Adding lib/ to CLASSPATH"
+ for f in $MAHOUT_HOME/lib/*.jar; do

- if [ "$SPARK" != "1" ]; then

- # add dev targets if they exist
- for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
- done
- fi
+# if [ "$SPARK" != "1" ]; then
+# # add dev targets if they exist
+# for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
+# done
+# fi

# add scala dev target
- for f in $MAHOUT_HOME/math-scala/target/mahout-math-scala_*.jar ; do
- done
+# for f in $MAHOUT_HOME/math-scala/target/mahout-math-scala_*.jar ; do
+# done

if [ "$H2O" == "1" ]; then
for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar; do
@@ -193,38 +206,34 @@ then

# add jars for running from the command line if we requested shell or spark CLI driver
- if [ "$SPARK" == "1" ]; then
- for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar ; do
- done
- for f in $MAHOUT_HOME/math/target/mahout-math-*.jar ; do
- done
- for f in $MAHOUT_HOME/spark/target/mahout-spark_*.jar ; do
- done
- for f in $MAHOUT_HOME/spark-shell/target/mahout-spark-shell_*.jar ; do
- done
- # viennacl jars- may or may not be available depending on build profile
- for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
- done
- # viennacl jars- may or may not be available depending on build profile
- for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
- done
+# if [ "$SPARK" == "1" ]; then
+# for f in $MAHOUT_HOME/lib/mahout-hdfs-*.jar ; do
+# done
+# for f in $MAHOUT_HOME/lib/mahout-core-*.jar ; do
+# done
+# for f in $MAHOUT_HOME/lib/spark_*.jar ; do
+# done
+# for f in $MAHOUT_HOME/lib/spark-cli_*.jar ; do
+# done
+# # viennacl jars- may or may not be available depending on build profile
+# for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+# done
+# # viennacl jars- may or may not be available depending on build profile
+# for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+# done

- # viennacl jars- may or may not be available depending on build profile
- for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
- done

if [ -x "${SPARK_CP_BIN}" ]; then
@@ -245,39 +254,39 @@ then

- # add vcl jars at any point.
- # viennacl jars- may or may not be available depending on build profile
- for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
- done
- # viennacl jars- may or may not be available depending on build profile
- for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
- done
- # add release dependencies to CLASSPATH
- for f in $MAHOUT_HOME/lib/*.jar; do
- done
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/hdfs/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/integration/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/examples/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math-scala/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark-shell/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/h2o/target/classes
+ # add vcl jars at any point.
+ # viennacl jars- may or may not be available depending on build profile
+# for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+# done
+# # viennacl jars- may or may not be available depending on build profile
+# for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+# done
+# # add release dependencies to CLASSPATH
+# for f in $MAHOUT_HOME/lib/*.jar; do
+# done
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/hdfs/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/mr/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/integration/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/examples/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math-scala/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark-shell/target/classes
+# CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/h2o/target/classes

# add development dependencies to CLASSPATH
-if [ "$SPARK" != "1" ]; then
- for f in $MAHOUT_HOME/examples/target/dependency/*.jar; do
- done
+#if [ "$SPARK" != "1" ]; then
+# for f in $MAHOUT_HOME/examples/target/dependency/*.jar; do
+# done

# cygwin path translation
@@ -287,7 +296,7 @@ fi

# restore ordinary behaviour
unset IFS
-JARS=$(echo "$MAHOUT_HOME"/*.jar | tr ' ' ',')
+JARS=$(echo "$MAHOUT_HOME"/lib/*.jar | tr ' ' ',')
case "$1" in
save_stty=$(stty -g 2>/dev/null);
@@ -297,6 +306,7 @@ case "$1" in
# Spark CLI drivers go here
"$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.ItemSimilarityDriver" "$@"
@@ -333,7 +343,7 @@ case "$1" in


if [ "x$JAVA_LIBRARY_PATH" != "x" ]; then

diff --git a/bin/mahout.bu b/bin/mahout.bu
new file mode 100755
index 0000000..20f9c3d
--- /dev/null
+++ b/bin/mahout.bu
@@ -0,0 +1,395 @@
+# The Mahout command script
+# Environment Variables
+# MAHOUT_JAVA_HOME The java implementation to use. Overrides JAVA_HOME.
+# MAHOUT_HEAPSIZE The maximum amount of heap to use, in MB.
+# Default is 4000.
+# HADOOP_CONF_DIR The location of a hadoop config directory
+# MAHOUT_OPTS Extra Java runtime options.
+# MAHOUT_CONF_DIR The location of the program short-name to class name
+# mappings and the default properties files
+# defaults to "$MAHOUT_HOME/src/conf"
+# MAHOUT_LOCAL set to anything other than an empty string to force
+# mahout to run locally even if HADOOP_CONF_DIR and HADOOP_HOME are set
+# MAHOUT_CORE set to anything other than an empty string to force
+# mahout to run in developer 'core' mode, just as if the
+# -core option was presented on the command-line
+# Command-line Options
+# -core -core is used to switch into 'developer mode' when
+# running mahout locally. If specified, the classes
+# from the 'target/classes' directories in each project
+# are used. Otherwise classes will be retrieved from
+# jars in the binary release collection or *-job.jar files
+# found in build directories. When running on hadoop
+# the job files will always be used.
+# * Licensed to the Apache Software Foundation (ASF) under one or more
+# * contributor license agreements. See the NOTICE file distributed with
+# * this work for additional information regarding copyright ownership.
+# * The ASF licenses this file to You under the Apache License, Version 2.0
+# * (the "License"); you may not use this file except in compliance with
+# * the License. You may obtain a copy of the License at
+# *
+# * http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# */
+case "`uname`" in
+CYGWIN*) cygwin=true;;
+# Check that mahout home is set, if not set it to one dir up.
+# resolve links - $0 may be a softlink
+while [ -h "$THIS" ]; do
+ ls=`ls -ld "$THIS"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '.*/.*' > /dev/null; then
+ THIS="$link"
+ else
+ THIS=`dirname "$THIS"`/"$link"
+ fi
+if [ "$1" == "-core" ] ; then
+ shift
+if [ "$1" == "-spark" ]; then
+ shift
+if [ "$1" == "spark-shell" ]; then
+if [ "$1" == "spark-itemsimilarity" ]; then
+if [ "$1" == "spark-rowsimilarity" ]; then
+if [ "$1" == "spark-trainnb" ]; then
+if [ "$1" == "spark-testnb" ]; then
+if [ "$MAHOUT_CORE" != "" ]; then
+if [ "$1" == "h2o-node" ]; then
+ H2O=1
+# some directories
+THIS_DIR=`dirname "$THIS"`
+MAHOUT_HOME=`cd "$THIS_DIR/.." ; pwd`
+# some Java parameters
+if [ "$MAHOUT_JAVA_HOME" != "" ]; then
+ #echo "run java in $MAHOUT_JAVA_HOME"
+if [ "$JAVA_HOME" = "" ]; then
+ echo "Error: JAVA_HOME is not set."
+ exit 1
+# check envvars which might override default args
+if [ "$MAHOUT_HEAPSIZE" != "" ]; then
+ #echo "run with heapsize $MAHOUT_HEAPSIZE"
+ #echo $JAVA_HEAP_MAX
+if [ "x$MAHOUT_CONF_DIR" = "x" ]; then
+ if [ -d $MAHOUT_HOME/src/conf ]; then
+ else
+ if [ -d $MAHOUT_HOME/conf ]; then
+ else
+ echo No MAHOUT_CONF_DIR found
+ fi
+ fi
+# CLASSPATH initially contains $MAHOUT_CONF_DIR, or defaults to $MAHOUT_HOME/src/conf
+if [ "$MAHOUT_LOCAL" != "" ]; then
+ echo "MAHOUT_LOCAL is set, so we don't add HADOOP_CONF_DIR to classpath."
+elif [ -n "$HADOOP_CONF_DIR" ] ; then
+ echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."
+# so that filenames w/ spaces are handled correctly in loops below
+if [ $IS_CORE == 0 ]
+ # add release dependencies to CLASSPATH
+ for f in $MAHOUT_HOME/lib/*.jar; do
+ done
+ if [ "$SPARK" != "1" ]; then
+ if [$SPARK_HOME == ""]; then
+ echo "Have you set SPARK_HOME ?"
+ fi
+ # add dev targets if they exist
+ for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
+ done
+ fi
+ # add scala dev target
+ for f in $MAHOUT_HOME/math-scala/target/mahout-math-scala_*.jar ; do
+ done
+ if [ "$H2O" == "1" ]; then
+ for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar; do
+ done
+ for f in $MAHOUT_HOME/h2o/target/mahout-h2o*.jar; do
+ done
+ fi
+ # add jars for running from the command line if we requested shell or spark CLI driver
+ if [ "$SPARK" == "1" ]; then
+ for f in $MAHOUT_HOME/lib/mahout-hdfs-*.jar ; do
+ done
+ for f in $MAHOUT_HOME/lib/mahout-core-*.jar ; do
+ done
+ for f in $MAHOUT_HOME/lib/spark_*.jar ; do
+ done
+ for f in $MAHOUT_HOME/lib/spark-cli_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+ done
+ SPARK_CP_BIN="${MAHOUT_HOME}/bin/compute-classpath.sh"
+ if [ -x "${SPARK_CP_BIN}" ]; then
+ SPARK_CLASSPATH=$("${SPARK_CP_BIN}" 2>/dev/null)
+ else
+ echo "Cannot find Spark classpath. Is 'SPARK_HOME' set?"
+ exit -1
+ fi
+ SPARK_ASSEMBLY_BIN="${MAHOUT_HOME}/bin/mahout-spark-class.sh"
+ if [ -x "${SPARK_ASSEMBLY_BIN}" ]; then
+ else
+ echo "Cannot find Spark assembly classpath. Is 'SPARK_HOME' set?"
+ exit -1
+ fi
+ fi
+ # add vcl jars at any point.
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+ done
+ # add release dependencies to CLASSPATH
+ for f in $MAHOUT_HOME/lib/*.jar; do
+ done
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/hdfs/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/integration/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/examples/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math-scala/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark-shell/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/h2o/target/classes
+# add development dependencies to CLASSPATH
+if [ "$SPARK" != "1" ]; then
+ for f in $MAHOUT_HOME/examples/target/dependency/*.jar; do
+ done
+# cygwin path translation
+if $cygwin; then
+ CLASSPATH=`cygpath -p -w "$CLASSPATH"`
+# restore ordinary behaviour
+unset IFS
+JARS=$(echo "$MAHOUT_HOME"/*.jar | tr ' ' ',')
+case "$1" in
+ (spark-shell)
+ save_stty=$(stty -g 2>/dev/null);
+ $SPARK_HOME/bin/spark-shell --jars "$JARS" -i $MAHOUT_HOME/bin/load-shell.scala --conf spark.kryo.referenceTracking=false --conf spark.kryo.registrator=org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator --conf spark.kryoserializer.buffer=32k --conf spark.kryoserializer.buffer.max=600m --conf spark.serializer=org.apache.spark.serializer.KryoSerializer $@
+ stty sane; stty $save_stty
+ ;;
+ # Spark CLI drivers go here
+ (spark-itemsimilarity)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.ItemSimilarityDriver" "$@"
+ ;;
+ (spark-rowsimilarity)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.RowSimilarityDriver" "$@"
+ ;;
+ (spark-trainnb)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TrainNBDriver" "$@"
+ ;;
+ (spark-testnb)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TestNBDriver" "$@"
+ ;;
+ (h2o-node)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "water.H2O" -md5skip "$@" -name mah2out
+ ;;
+ (*)
+ # default log directory & file
+ if [ "$MAHOUT_LOG_DIR" = "" ]; then
+ fi
+ if [ "$MAHOUT_LOGFILE" = "" ]; then
+ MAHOUT_LOGFILE='mahout.log'
+ fi
+ #Fix log path under cygwin
+ if $cygwin; then
+ MAHOUT_LOG_DIR=`cygpath -p -w "$MAHOUT_LOG_DIR"`
+ fi
+ if [ "x$JAVA_LIBRARY_PATH" != "x" ]; then
+ fi
+ CLASS=org.apache.mahout.driver.MahoutDriver
+ for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
+ if [ -e "$f" ]; then
+ fi
+ done
+ # run it
+ HADOOP_BINARY=$(PATH="${HADOOP_HOME:-${HADOOP_PREFIX}}/bin:$PATH" which hadoop 2>/dev/null)
+ if [ -x "$HADOOP_BINARY" ] ; then
+ fi
+ if [ ! -x "$HADOOP_BINARY" ] || [ "$MAHOUT_LOCAL" != "" ] ; then
+ if [ ! -x "$HADOOP_BINARY" ] ; then
+ echo "hadoop binary is not in PATH,HADOOP_HOME/bin,HADOOP_PREFIX/bin, running locally"
+ elif [ "$MAHOUT_LOCAL" != "" ] ; then
+ echo "MAHOUT_LOCAL is set, running locally"
+ fi
+ case $1 in
+ (classpath)
+ ;;
+ (*)
+ exec "$JAVA" $JAVA_HEAP_MAX $MAHOUT_OPTS -classpath "$CLASSPATH" $CLASS "$@"
+ esac
+ else
+ echo "Running on hadoop, using $HADOOP_BINARY and HADOOP_CONF_DIR=$HADOOP_CONF_DIR"
+ if [ "$MAHOUT_JOB" = "" ] ; then
+ echo "ERROR: Could not find mahout-examples-*.job in $MAHOUT_HOME or $MAHOUT_HOME/examples/target, please run 'mvn install' to create the .job file"
+ exit 1
+ else
+ case "$1" in
+ (hadoop)
+ shift
+ exec "$HADOOP_BINARY" "$@"
+ ;;
+ (classpath)
+ ;;
+ (*)
+ esac
+ fi
+ fi
+ ;;

diff --git a/community/mahout-mr/bin/mahout b/community/mahout-mr/bin/mahout
new file mode 100755
index 0000000..3017c9e
--- /dev/null
+++ b/community/mahout-mr/bin/mahout
@@ -0,0 +1,395 @@
+# The Mahout command script
+# Environment Variables
+# MAHOUT_JAVA_HOME The java implementation to use. Overrides JAVA_HOME.
+# MAHOUT_HEAPSIZE The maximum amount of heap to use, in MB.
+# Default is 4000.
+# HADOOP_CONF_DIR The location of a hadoop config directory
+# MAHOUT_OPTS Extra Java runtime options.
+# MAHOUT_CONF_DIR The location of the program short-name to class name
+# mappings and the default properties files
+# defaults to "$MAHOUT_HOME/src/conf"
+# MAHOUT_LOCAL set to anything other than an empty string to force
+# mahout to run locally even if HADOOP_CONF_DIR and HADOOP_HOME are set
+# MAHOUT_CORE set to anything other than an empty string to force
+# mahout to run in developer 'core' mode, just as if the
+# -core option was presented on the command-line
+# Command-line Options
+# -core -core is used to switch into 'developer mode' when
+# running mahout locally. If specified, the classes
+# from the 'target/classes' directories in each project
+# are used. Otherwise classes will be retrieved from
+# jars in the binary release collection or *-job.jar files
+# found in build directories. When running on hadoop
+# the job files will always be used.
+# * Licensed to the Apache Software Foundation (ASF) under one or more
+# * contributor license agreements. See the NOTICE file distributed with
+# * this work for additional information regarding copyright ownership.
+# * The ASF licenses this file to You under the Apache License, Version 2.0
+# * (the "License"); you may not use this file except in compliance with
+# * the License. You may obtain a copy of the License at
+# *
+# * http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# */
+case "`uname`" in
+CYGWIN*) cygwin=true;;
+# resolve links - $0 may be a softlink
+while [ -h "$THIS" ]; do
+ ls=`ls -ld "$THIS"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '.*/.*' > /dev/null; then
+ THIS="$link"
+ else
+ THIS=`dirname "$THIS"`/"$link"
+ fi
+if [ "$1" == "-core" ] ; then
+ shift
+if [ "$1" == "-spark" ]; then
+ shift
+if [ "$1" == "spark-shell" ]; then
+if [ "$1" == "spark-itemsimilarity" ]; then
+if [ "$1" == "spark-rowsimilarity" ]; then
+if [ "$1" == "spark-trainnb" ]; then
+if [ "$1" == "spark-testnb" ]; then
+if [ "$MAHOUT_CORE" != "" ]; then
+if [ "$1" == "h2o-node" ]; then
+ H2O=1
+# some directories
+THIS_DIR=`dirname "$THIS"`
+MAHOUT_HOME=`cd "$THIS_DIR/.." ; pwd`
+# some Java parameters
+if [ "$MAHOUT_JAVA_HOME" != "" ]; then
+ #echo "run java in $MAHOUT_JAVA_HOME"
+if [ "$JAVA_HOME" = "" ]; then
+ echo "Error: JAVA_HOME is not set."
+ exit 1
+# check envvars which might override default args
+if [ "$MAHOUT_HEAPSIZE" != "" ]; then
+ #echo "run with heapsize $MAHOUT_HEAPSIZE"
+ #echo $JAVA_HEAP_MAX
+if [ "x$MAHOUT_CONF_DIR" = "x" ]; then
+ if [ -d $MAHOUT_HOME/src/conf ]; then
+ else
+ if [ -d $MAHOUT_HOME/conf ]; then
+ else
+ echo No MAHOUT_CONF_DIR found
+ fi
+ fi
+# CLASSPATH initially contains $MAHOUT_CONF_DIR, or defaults to $MAHOUT_HOME/src/conf
+if [ "$MAHOUT_LOCAL" != "" ]; then
+ echo "MAHOUT_LOCAL is set, so we don't add HADOOP_CONF_DIR to classpath."
+elif [ -n "$HADOOP_CONF_DIR" ] ; then
+ echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."
+# so that filenames w/ spaces are handled correctly in loops below
+if [ $IS_CORE == 0 ]
+ # add release dependencies to CLASSPATH
+ for f in $MAHOUT_HOME/mahout-*.jar; do
+ done
+ if [ "$SPARK" != "1" ]; then
+ # add dev targets if they exist
+ for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
+ done
+ fi
+ # add scala dev target
+ for f in $MAHOUT_HOME/math-scala/target/mahout-math-scala_*.jar ; do
+ done
+ if [ "$H2O" == "1" ]; then
+ for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar; do
+ done
+ for f in $MAHOUT_HOME/h2o/target/mahout-h2o*.jar; do
+ done
+ fi
+ # add jars for running from the command line if we requested shell or spark CLI driver
+ if [ "$SPARK" == "1" ]; then
+ for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar ; do
+ done
+ for f in $MAHOUT_HOME/math/target/mahout-math-*.jar ; do
+ done
+ for f in $MAHOUT_HOME/spark/target/mahout-spark_*.jar ; do
+ done
+ for f in $MAHOUT_HOME/spark-shell/target/mahout-spark-shell_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+ done
+ SPARK_CP_BIN="${MAHOUT_HOME}/bin/compute-classpath.sh"
+ if [ -x "${SPARK_CP_BIN}" ]; then
+ SPARK_CLASSPATH=$("${SPARK_CP_BIN}" 2>/dev/null)
+ else
+ echo "Cannot find Spark classpath. Is 'SPARK_HOME' set?"
+ exit -1
+ fi
+ SPARK_ASSEMBLY_BIN="${MAHOUT_HOME}/bin/mahout-spark-class.sh"
+ if [ -x "${SPARK_ASSEMBLY_BIN}" ]; then
+ else
+ echo "Cannot find Spark assembly classpath. Is 'SPARK_HOME' set?"
+ exit -1
+ fi
+ fi
+ # add vcl jars at any point.
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl/target/mahout-native-viennacl_*.jar ; do
+ done
+ # viennacl jars- may or may not be available depending on build profile
+ for f in $MAHOUT_HOME/viennacl-omp/target/mahout-native-viennacl-omp_*.jar ; do
+ done
+ # add release dependencies to CLASSPATH
+ for f in $MAHOUT_HOME/lib/*.jar; do
+ done
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/hdfs/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/integration/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/examples/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math-scala/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark-shell/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/h2o/target/classes
+# add development dependencies to CLASSPATH
+if [ "$SPARK" != "1" ]; then
+ for f in $MAHOUT_HOME/examples/target/dependency/*.jar; do
+ done
+# cygwin path translation
+if $cygwin; then
+ CLASSPATH=`cygpath -p -w "$CLASSPATH"`
+# restore ordinary behaviour
+unset IFS
+JARS=$(echo "$MAHOUT_HOME"/*.jar | tr ' ' ',')
+case "$1" in
+ (spark-shell)
+ save_stty=$(stty -g 2>/dev/null);
+ $SPARK_HOME/bin/spark-shell --jars "$JARS" -i $MAHOUT_HOME/bin/load-shell.scala --conf spark.kryo.referenceTracking=false --conf spark.kryo.registrator=org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator --conf spark.kryoserializer.buffer=32k --conf spark.kryoserializer.buffer.max=600m --conf spark.serializer=org.apache.spark.serializer.KryoSerializer $@
+ stty sane; stty $save_stty
+ ;;
+ # Spark CLI drivers go here
+ (spark-itemsimilarity)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.ItemSimilarityDriver" "$@"
+ ;;
+ (spark-rowsimilarity)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.RowSimilarityDriver" "$@"
+ ;;
+ (spark-trainnb)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TrainNBDriver" "$@"
+ ;;
+ (spark-testnb)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TestNBDriver" "$@"
+ ;;
+ (h2o-node)
+ shift
+ "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "water.H2O" -md5skip "$@" -name mah2out
+ ;;
+ (*)
+ # default log directory & file
+ if [ "$MAHOUT_LOG_DIR" = "" ]; then
+ fi
+ if [ "$MAHOUT_LOGFILE" = "" ]; then
+ MAHOUT_LOGFILE='mahout.log'
+ fi
+ #Fix log path under cygwin
+ if $cygwin; then
+ MAHOUT_LOG_DIR=`cygpath -p -w "$MAHOUT_LOG_DIR"`
+ fi
+ if [ "x$JAVA_LIBRARY_PATH" != "x" ]; then
+ fi
+ CLASS=org.apache.mahout.driver.MahoutDriver
+ for f in $MAHOUT_HOME/examples/target/mahout-examples-*-job.jar $MAHOUT_HOME/mahout-examples-*-job.jar ; do
+ if [ -e "$f" ]; then
+ fi
+ done
+ # run it
+ HADOOP_BINARY=$(PATH="${HADOOP_HOME:-${HADOOP_PREFIX}}/bin:$PATH" which hadoop 2>/dev/null)
+ if [ -x "$HADOOP_BINARY" ] ; then
+ fi
+ if [ ! -x "$HADOOP_BINARY" ] || [ "$MAHOUT_LOCAL" != "" ] ; then
+ if [ ! -x "$HADOOP_BINARY" ] ; then
+ echo "hadoop binary is not in PATH,HADOOP_HOME/bin,HADOOP_PREFIX/bin, running locally"
+ elif [ "$MAHOUT_LOCAL" != "" ] ; then
+ echo "MAHOUT_LOCAL is set, running locally"
+ fi
+ case $1 in
+ (classpath)
+ ;;
+ (*)
+ exec "$JAVA" $JAVA_HEAP_MAX $MAHOUT_OPTS -classpath "$CLASSPATH" $CLASS "$@"
+ esac
+ else
+ echo "Running on hadoop, using $HADOOP_BINARY and HADOOP_CONF_DIR=$HADOOP_CONF_DIR"
+ if [ "$MAHOUT_JOB" = "" ] ; then
+ echo "ERROR: Could not find mahout-examples-*.job in $MAHOUT_HOME or $MAHOUT_HOME/examples/target, please run 'mvn install' to create the .job file"
+ exit 1
+ else
+ case "$1" in
+ (hadoop)
+ shift
+ exec "$HADOOP_BINARY" "$@"
+ ;;
+ (classpath)
+ ;;
+ (*)
+ esac
+ fi
+ fi
+ ;;

diff --git a/community/mahout-mr/bin/mahout.cmd b/community/mahout-mr/bin/mahout.cmd
new file mode 100644
index 0000000..86bae79
--- /dev/null
+++ b/community/mahout-mr/bin/mahout.cmd
@@ -0,0 +1,397 @@
+@echo off
+echo "===============DEPRECATION WARNING==============="
+echo "This script is no longer supported for new drivers as of Mahout 0.10.0"
+echo "Mahout's bash script is supported and if someone wants to contribute a fix for this"
+echo "it would be appreciated."
+@rem The Mahout command script
+@rem Environment Variables
+@rem MAHOUT_JAVA_HOME The java implementation to use. Overrides JAVA_HOME.
+@rem MAHOUT_HEAPSIZE The maximum amount of heap to use, in MB.
+@rem Default is 3 GB (JAVA_HEAP_MAX=-Xmx3g).
+@rem HADOOP_CONF_DIR The location of a hadoop config directory
+@rem MAHOUT_OPTS Extra Java runtime options.
+@rem MAHOUT_CONF_DIR The location of the program short-name to class name
+@rem mappings and the default properties files
+@rem defaults to "$MAHOUT_HOME/src/conf"
+@rem MAHOUT_LOCAL set to anything other than an empty string to force
+@rem mahout to run locally even if HADOOP_CONF_DIR and HADOOP_HOME are set
+@rem MAHOUT_CORE set to anything other than an empty string to force
+@rem mahout to run in developer 'core' mode, just as if the
+@rem -core option was presented on the command-line
+@rem Command-line Options
+@rem -core -core is used to switch into 'developer mode' when
+@rem running mahout locally. If specified, the classes
+@rem from the 'target/classes' directories in each project
+@rem are used. Otherwise classes will be retrieved from
+@rem jars in the binary release collection or *-job.jar files
+@rem found in build directories. When running on hadoop
+@rem the job files will always be used.
+@rem /*
+@rem * Licensed to the Apache Software Foundation (ASF) under one or more
+@rem * contributor license agreements. See the NOTICE file distributed with
+@rem * this work for additional information regarding copyright ownership.
+@rem * The ASF licenses this file to You under the Apache License, Version 2.0
+@rem * (the "License"); you may not use this file except in compliance with
+@rem * the License. You may obtain a copy of the License at
+@rem *
+@rem * http://www.apache.org/licenses/LICENSE-2.0
+@rem *
+@rem * Unless required by applicable law or agreed to in writing, software
+@rem * distributed under the License is distributed on an "AS IS" BASIS,
+@rem * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@rem * See the License for the specific language governing permissions and
+@rem * limitations under the License.
+@rem */
+setlocal enabledelayedexpansion
+@rem disable "developer mode"
+set IS_CORE=0
+if [%1] == [-core] (
+ set IS_CORE=1
+ shift
+if not [%MAHOUT_CORE%] == [] (
+set IS_CORE=1
+if [%MAHOUT_HOME%] == [] set MAHOUT_HOME=%~dp0..
+echo "Mahout home set %MAHOUT_HOME%"
+@rem some Java parameters
+if not [%MAHOUT_JAVA_HOME%] == [] (
+@rem echo run java in %MAHOUT_JAVA_HOME%
+if [%JAVA_HOME%] == [] (
+ echo Error: JAVA_HOME is not set.
+ exit /B 1
+set JAVA=%JAVA_HOME%\bin\java
+set JAVA_HEAP_MAX=-Xmx3g
+@rem check envvars which might override default args
+if not [%MAHOUT_HEAPSIZE%] == [] (
+@rem echo run with heapsize %MAHOUT_HEAPSIZE%
+@rem echo %JAVA_HEAP_MAX%
+if [%MAHOUT_CONF_DIR%] == [] (
+@rem MAHOUT_CLASSPATH initially contains $MAHOUT_CONF_DIR, or defaults to $MAHOUT_HOME/src/conf
+if not [%MAHOUT_LOCAL%] == [] (
+echo "MAHOUT_LOCAL is set, so we do not add HADOOP_CONF_DIR to classpath."
+) else (
+if not [%HADOOP_CONF_DIR%] == [] (
+echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."
+set CLASSPATH=%CLASSPATH%;%JAVA_HOME%\lib\tools.jar
+if %IS_CORE% == 0 (
+@rem add release dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\mahout-*.jar) do (
+@rem add dev targets if they exist
+for %%f in (%MAHOUT_HOME%\examples\target\mahout-examples-*-job.jar) do (
+for %%f in (%MAHOUT_HOME%\mahout-examples-*-job.jar) do (
+@rem add release dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\lib\*.jar) do (
+) else (
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\math\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\core\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\integration\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\examples\target\classes
+@rem set CLASSPATH=%CLASSPATH%;%MAHOUT_HOME%\core\src\main\resources
+@rem add development dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\examples\target\dependency\*.jar) do (
+@rem default log directory & file
+if [%MAHOUT_LOG_DIR%] == [] (
+if [%MAHOUT_LOGFILE%] == [] (
+set MAHOUT_LOGFILE=mahout.log
+if not [%JAVA_LIBRARY_PATH%] == [] (
+set CLASS=org.apache.mahout.driver.MahoutDriver
+for %%f in (%MAHOUT_HOME%\examples\target\mahout-examples-*-job.jar) do (
+set MAHOUT_JOB=%%f
+@rem run it
+if not [%MAHOUT_LOCAL%] == [] (
+ echo "MAHOUT_LOCAL is set, running locally"
+) else (
+ if [%MAHOUT_JOB%] == [] (
+ echo "ERROR: Could not find mahout-examples-*.job in %MAHOUT_HOME% or %MAHOUT_HOME%/examples/target, please run 'mvn install' to create the .job file"
+ exit /B 1
+ ) else (
+ if /i [%1] == [hadoop] (
+ call %HADOOP_HOME%\bin\%*
+ ) else (
+if /i [%1] == [classpath] (
+) else (
+call %HADOOP_HOME%\bin\hadoop jar %MAHOUT_JOB% %CLASS% %*
+ )
+ )
+@echo off
+@rem The Mahout command script
+@rem Environment Variables
+@rem MAHOUT_JAVA_HOME The java implementation to use. Overrides JAVA_HOME.
+@rem MAHOUT_HEAPSIZE The maximum amount of heap to use, in MB.
+@rem Default is 3 GB (JAVA_HEAP_MAX=-Xmx3g).
+@rem HADOOP_CONF_DIR The location of a hadoop config directory
+@rem MAHOUT_OPTS Extra Java runtime options.
+@rem MAHOUT_CONF_DIR The location of the program short-name to class name
+@rem mappings and the default properties files
+@rem defaults to "$MAHOUT_HOME/src/conf"
+@rem MAHOUT_LOCAL set to anything other than an empty string to force
+@rem mahout to run locally even if HADOOP_CONF_DIR and HADOOP_HOME are set
+@rem MAHOUT_CORE set to anything other than an empty string to force
+@rem mahout to run in developer 'core' mode, just as if the
+@rem -core option was presented on the command-line
+@rem Command-line Options
+@rem -core -core is used to switch into 'developer mode' when
+@rem running mahout locally. If specified, the classes
+@rem from the 'target/classes' directories in each project
+@rem are used. Otherwise classes will be retrieved from
+@rem jars in the binary release collection or *-job.jar files
+@rem found in build directories. When running on hadoop
+@rem the job files will always be used.
+@rem /*
+@rem * Licensed to the Apache Software Foundation (ASF) under one or more
+@rem * contributor license agreements. See the NOTICE file distributed with
+@rem * this work for additional information regarding copyright ownership.
+@rem * The ASF licenses this file to You under the Apache License, Version 2.0
+@rem * (the "License"); you may not use this file except in compliance with
+@rem * the License. You may obtain a copy of the License at
+@rem *
+@rem * http://www.apache.org/licenses/LICENSE-2.0
+@rem *
+@rem * Unless required by applicable law or agreed to in writing, software
+@rem * distributed under the License is distributed on an "AS IS" BASIS,
+@rem * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@rem * See the License for the specific language governing permissions and
+@rem * limitations under the License.
+@rem */
+setlocal enabledelayedexpansion
+@rem disable "developer mode"
+set IS_CORE=0
+if [%1] == [-core] (
+ set IS_CORE=1
+ shift
+if not [%MAHOUT_CORE%] == [] (
+set IS_CORE=1
+if [%MAHOUT_HOME%] == [] set MAHOUT_HOME=%~dp0..
+echo "Mahout home set %MAHOUT_HOME%"
+@rem some Java parameters
+if not [%MAHOUT_JAVA_HOME%] == [] (
+@rem echo run java in %MAHOUT_JAVA_HOME%
+if [%JAVA_HOME%] == [] (
+ echo Error: JAVA_HOME is not set.
+ exit /B 1
+set JAVA=%JAVA_HOME%\bin\java
+set JAVA_HEAP_MAX=-Xmx3g
+@rem check envvars which might override default args
+if not [%MAHOUT_HEAPSIZE%] == [] (
+@rem echo run with heapsize %MAHOUT_HEAPSIZE%
+@rem echo %JAVA_HEAP_MAX%
+if [%MAHOUT_CONF_DIR%] == [] (
+@rem MAHOUT_CLASSPATH initially contains $MAHOUT_CONF_DIR, or defaults to $MAHOUT_HOME/src/conf
+if not [%MAHOUT_LOCAL%] == [] (
+echo "MAHOUT_LOCAL is set, so we do not add HADOOP_CONF_DIR to classpath."
+) else (
+if not [%HADOOP_CONF_DIR%] == [] (
+echo "MAHOUT_LOCAL is not set; adding HADOOP_CONF_DIR to classpath."
+set CLASSPATH=%CLASSPATH%;%JAVA_HOME%\lib\tools.jar
+if %IS_CORE% == 0 (
+@rem add release dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\mahout-*.jar) do (
+@rem add dev targets if they exist
+for %%f in (%MAHOUT_HOME%\examples\target\mahout-examples-*-job.jar) do (
+for %%f in (%MAHOUT_HOME%\mahout-examples-*-job.jar) do (
+@rem add release dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\lib\*.jar) do (
+) else (
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\math\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\core\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\integration\target\classes
+set CLASSPATH=!CLASSPATH!;%MAHOUT_HOME%\examples\target\classes
+@rem set CLASSPATH=%CLASSPATH%;%MAHOUT_HOME%\core\src\main\resources
+@rem add development dependencies to CLASSPATH
+for %%f in (%MAHOUT_HOME%\examples\target\dependency\*.jar) do (
+@rem default log directory & file
+if [%MAHOUT_LOG_DIR%] == [] (
+if [%MAHOUT_LOGFILE%] == [] (
+set MAHOUT_LOGFILE=mahout.log
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.min.split.size=512MB
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.map.child.java.opts=-Xmx4096m
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.reduce.child.java.opts=-Xmx4096m
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.output.compress=true
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.compress.map.output=true
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.map.tasks=1
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dmapred.reduce.tasks=1
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dio.sort.factor=30
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dio.sort.mb=1024
+set MAHOUT_OPTS=%MAHOUT_OPTS% -Dio.file.buffer.size=32786
+set HADOOP_OPTS=%HADOOP_OPTS% -Djava.library.path=%HADOOP_HOME%\bin
+if not [%JAVA_LIBRARY_PATH%] == [] (
+set CLASS=org.apache.mahout.driver.MahoutDriver
+for %%f in (%MAHOUT_HOME%\examples\target\mahout-examples-*-job.jar) do (
+set MAHOUT_JOB=%%f
+@rem run it
+if not [%MAHOUT_LOCAL%] == [] (
+ echo "MAHOUT_LOCAL is set, running locally"
+) else (
+ if [%MAHOUT_JOB%] == [] (
+ echo "ERROR: Could not find mahout-examples-*.job in %MAHOUT_HOME% or %MAHOUT_HOME%/examples/target, please run 'mvn install' to create the .job file"
+ exit /B 1
+ ) else (
+ if /i [%1] == [hadoop] (
+ call %HADOOP_HOME%\bin\%*
+ ) else (
+if /i [%1] == [classpath] (
+) else (
+call %HADOOP_HOME%\bin\hadoop jar %MAHOUT_JOB% %CLASS% %*
+ )
+ )

diff --git a/community/mahout-mr/examples/bin/README.txt b/community/mahout-mr/examples/bin/README.txt
new file mode 100644
index 0000000..7ad3a38
--- /dev/null
+++ b/community/mahout-mr/examples/bin/README.txt
@@ -0,0 +1,13 @@
+This directory contains helpful shell scripts for working with some of Mahout's examples.
+To set a non-default temporary work directory: `export MAHOUT_WORK_DIR=/path/in/hdfs/to/temp/dir`
+ Note that this requires the same path to be writable on both the local file system and HDFS.
+Here's a description of what each does:
+classify-20newsgroups.sh -- Run SGD and Bayes classifiers over the classic 20 News Groups. Downloads the data set automatically.
+cluster-reuters.sh -- Cluster the Reuters data set using a variety of algorithms. Downloads the data set automatically.
+cluster-syntheticcontrol.sh -- Cluster the Synthetic Control data set. Downloads the data set automatically.
+factorize-movielens-1m.sh -- Run the Alternating Least Squares Recommender on the Grouplens data set (size 1M).
+factorize-netflix.sh -- (Deprecated due to lack of availability of the data set) Run the ALS Recommender on the Netflix data set.
+spark-document-classifier.mscala -- A mahout-shell script which trains and tests a Naive Bayes model on the Wikipedia XML dump and defines simple methods to classify new text.

diff --git a/community/mahout-mr/examples/bin/classify-20newsgroups.sh b/community/mahout-mr/examples/bin/classify-20newsgroups.sh
new file mode 100755
index 0000000..f47d5c5
--- /dev/null
+++ b/community/mahout-mr/examples/bin/classify-20newsgroups.sh
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Downloads the 20newsgroups dataset, trains and tests a classifier.
+# To run: change into the mahout directory and type:
+# examples/bin/classify-20newsgroups.sh
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script runs SGD and Bayes classifiers over the classic 20 News Groups."
+ exit
+if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
+# Set commands for dfs
+source ${START_PATH}/set-dfs-commands.sh
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-${USER}
+algorithm=( cnaivebayes-MapReduce naivebayes-MapReduce cnaivebayes-Spark naivebayes-Spark sgd clean)
+if [ -n "$1" ]; then
+ choice=$1
+ echo "Please select a number to choose the corresponding task to run"
+ echo "1. ${algorithm[0]}"
+ echo "2. ${algorithm[1]}"
+ echo "3. ${algorithm[2]}"
+ echo "4. ${algorithm[3]}"
+ echo "5. ${algorithm[4]}"
+ echo "6. ${algorithm[5]}-- cleans up the work area in $WORK_DIR"
+ read -p "Enter your choice : " choice
+echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
+# Spark specific check and work
+if [ "x$alg" == "xnaivebayes-Spark" -o "x$alg" == "xcnaivebayes-Spark" ]; then
+ if [ "$MASTER" == "" ] ; then
+ echo "Please set your MASTER env variable to point to your Spark Master URL. exiting..."
+ exit 1
+ fi
+ if [ "$MAHOUT_LOCAL" != "" ] ; then
+ echo "Options 3 and 4 can not run in MAHOUT_LOCAL mode. exiting..."
+ exit 1
+ fi
+if [ "x$alg" != "xclean" ]; then
+ echo "creating work directory at ${WORK_DIR}"
+ mkdir -p ${WORK_DIR}
+ if [ ! -e ${WORK_DIR}/20news-bayesinput ]; then
+ if [ ! -e ${WORK_DIR}/20news-bydate ]; then
+ if [ ! -f ${WORK_DIR}/20news-bydate.tar.gz ]; then
+ echo "Downloading 20news-bydate"
+ curl http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz -o ${WORK_DIR}/20news-bydate.tar.gz
+ fi
+ mkdir -p ${WORK_DIR}/20news-bydate
+ echo "Extracting..."
+ cd ${WORK_DIR}/20news-bydate && tar xzf ../20news-bydate.tar.gz && cd .. && cd ..
+ fi
+ fi
+#echo $START_PATH
+cd ../..
+set -e
+if ( [ "x$alg" == "xnaivebayes-MapReduce" ] || [ "x$alg" == "xcnaivebayes-MapReduce" ] || [ "x$alg" == "xnaivebayes-Spark" ] || [ "x$alg" == "xcnaivebayes-Spark" ] ); then
+ c=""
+ # Pass -c (the complementary Naive Bayes switch of trainnb/testnb) only for the
+ # two "cnaivebayes" variants. The previous condition named "naivebayes-Spark"
+ # instead of "cnaivebayes-Spark", which swapped the behaviour of the two Spark options.
+ if [ "x$alg" == "xcnaivebayes-MapReduce" -o "x$alg" == "xcnaivebayes-Spark" ]; then
+ c=" -c"
+ fi
+ set -x
+ echo "Preparing 20newsgroups data"
+ rm -rf ${WORK_DIR}/20news-all
+ mkdir ${WORK_DIR}/20news-all
+ cp -R ${WORK_DIR}/20news-bydate/*/* ${WORK_DIR}/20news-all
+ # Stage the prepared data on HDFS, but only when a Hadoop install is configured
+ # and the user has not forced local mode.
+ if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
+ echo "Copying 20newsgroups data to HDFS"
+ # Best-effort cleanup/setup: the target may not exist yet (rm) or may already
+ # exist (mkdir), so failures of these three commands must not abort the script.
+ set +e
+ $DFSRM ${WORK_DIR}/20news-all
+ $DFS -mkdir -p ${WORK_DIR}
+ $DFS -mkdir ${WORK_DIR}/20news-all
+ set -e
+ # HVERSION is not assigned in this script (presumably set by set-dfs-commands.sh - TODO confirm).
+ # Default it to 0 so an unset value does not make the numeric test itself fail,
+ # and report the case where nothing was uploaded instead of skipping silently.
+ if [ "${HVERSION:-0}" -eq "1" ] ; then
+ echo "Copying 20newsgroups data to Hadoop 1 HDFS"
+ $DFS -put ${WORK_DIR}/20news-all ${WORK_DIR}/20news-all
+ elif [ "${HVERSION:-0}" -eq "2" ] ; then
+ echo "Copying 20newsgroups data to Hadoop 2 HDFS"
+ $DFS -put ${WORK_DIR}/20news-all ${WORK_DIR}/
+ else
+ echo "WARNING: unrecognized Hadoop version '$HVERSION'; 20newsgroups data was not copied to HDFS" >&2
+ fi
+ fi
+ echo "Creating sequence files from 20newsgroups data"
+ ./bin/mahout seqdirectory \
+ -i ${WORK_DIR}/20news-all \
+ -o ${WORK_DIR}/20news-seq -ow
+ echo "Converting sequence files to vectors"
+ ./bin/mahout seq2sparse \
+ -i ${WORK_DIR}/20news-seq \
+ -o ${WORK_DIR}/20news-vectors -lnorm -nv -wt tfidf
+ # --randomSelectionPct 40 holds out 40% of the vectors as the test set, so the
+ # message reports a 60-40 split (it previously claimed 80-20).
+ echo "Creating training and holdout set with a random 60-40 split of the generated vector dataset"
+ ./bin/mahout split \
+ -i ${WORK_DIR}/20news-vectors/tfidf-vectors \
+ --trainingOutput ${WORK_DIR}/20news-train-vectors \
+ --testOutput ${WORK_DIR}/20news-test-vectors \
+ --randomSelectionPct 40 --overwrite --sequenceFiles -xm sequential
+ if [ "x$alg" == "xnaivebayes-MapReduce" -o "x$alg" == "xcnaivebayes-MapReduce" ]; then
+ echo "Training Naive Bayes model"
+ ./bin/mahout trainnb \
+ -i ${WORK_DIR}/20news-train-vectors \
+ -o ${WORK_DIR}/model \
+ -li ${WORK_DIR}/labelindex \
+ -ow $c
+ echo "Self testing on training set"
+ ./bin/mahout testnb \
+ -i ${WORK_DIR}/20news-train-vectors\
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow -o ${WORK_DIR}/20news-testing $c
+ echo "Testing on holdout set"
+ ./bin/mahout testnb \
+ -i ${WORK_DIR}/20news-test-vectors\
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow -o ${WORK_DIR}/20news-testing $c
+ elif [ "x$alg" == "xnaivebayes-Spark" -o "x$alg" == "xcnaivebayes-Spark" ]; then
+ echo "Training Naive Bayes model"
+ ./bin/mahout spark-trainnb \
+ -i ${WORK_DIR}/20news-train-vectors \
+ -o ${WORK_DIR}/spark-model $c -ow -ma $MASTER
+ echo "Self testing on training set"
+ ./bin/mahout spark-testnb \
+ -i ${WORK_DIR}/20news-train-vectors\
+ -m ${WORK_DIR}/spark-model $c -ma $MASTER
+ echo "Testing on holdout set"
+ ./bin/mahout spark-testnb \
+ -i ${WORK_DIR}/20news-test-vectors\
+ -m ${WORK_DIR}/spark-model $c -ma $MASTER
+ fi
+elif [ "x$alg" == "xsgd" ]; then
+ if [ ! -e "/tmp/news-group.model" ]; then
+ echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
+ ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/
+ fi
+ echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: /tmp/news-group.model"
+ ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model /tmp/news-group.model
+elif [ "x$alg" == "xclean" ]; then
+ rm -rf $WORK_DIR
+ rm -rf /tmp/news-group.model
+# Remove the work directory

diff --git a/community/mahout-mr/examples/bin/classify-wikipedia.sh b/community/mahout-mr/examples/bin/classify-wikipedia.sh
new file mode 100755
index 0000000..41dc0c9
--- /dev/null
+++ b/community/mahout-mr/examples/bin/classify-wikipedia.sh
@@ -0,0 +1,196 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Downloads a (partial) wikipedia dump, trains and tests a classifier.
+# To run: change into the mahout directory and type:
+# examples/bin/classify-wikipedia.sh
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script Bayes and CBayes classifiers over the last wikipedia dump."
+ exit
+# ensure that MAHOUT_HOME is set
+if [[ -z "$MAHOUT_HOME" ]]; then
+ echo "Please set MAHOUT_HOME."
+ exit
+if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
+# Set commands for dfs
+source ${START_PATH}/set-dfs-commands.sh
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-wiki
+algorithm=( CBayes BinaryCBayes clean)
+if [ -n "$1" ]; then
+ choice=$1
+ echo "Please select a number to choose the corresponding task to run"
+ echo "1. ${algorithm[0]} (may require increased heap space on yarn)"
+ echo "2. ${algorithm[1]}"
+ echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+ read -p "Enter your choice : " choice
+echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
+if [ "x$alg" != "xclean" ]; then
+ echo "creating work directory at ${WORK_DIR}"
+ mkdir -p ${WORK_DIR}
+ if [ ! -e ${WORK_DIR}/wikixml ]; then
+ mkdir -p ${WORK_DIR}/wikixml
+ fi
+ if [ ! -e ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2 ]; then
+ echo "Downloading wikipedia XML dump"
+ ########################################################
+ # Datasets: uncomment and run "clean" to change dataset
+ ########################################################
+ ########## partial small 42.5M zipped
+ # curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles1.xml-p000000010p000030302.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
+ ########## partial larger 256M zipped
+ curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles10.xml-p2336425p3046511.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
+ ######### full wikipedia dump: 10G zipped
+ # curl https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 -o ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml.bz2
+ ########################################################
+ fi
+ if [ ! -e ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml ]; then
+ echo "Extracting..."
+ cd ${WORK_DIR}/wikixml && bunzip2 enwiki-latest-pages-articles.xml.bz2 && cd .. && cd ..
+ fi
+set -e
+if [ "x$alg" == "xCBayes" ] || [ "x$alg" == "xBinaryCBayes" ] ; then
+ set -x
+ echo "Preparing wikipedia data"
+ rm -rf ${WORK_DIR}/wiki
+ mkdir ${WORK_DIR}/wiki
+ if [ "x$alg" == "xCBayes" ] ; then
+ # use a list of 10 countries as categories
+ cp $MAHOUT_HOME/examples/bin/resources/country10.txt ${WORK_DIR}/country.txt
+ chmod 666 ${WORK_DIR}/country.txt
+ fi
+ if [ "x$alg" == "xBinaryCBayes" ] ; then
+ # use United States and United Kingdom as categories
+ cp $MAHOUT_HOME/examples/bin/resources/country2.txt ${WORK_DIR}/country.txt
+ chmod 666 ${WORK_DIR}/country.txt
+ fi
+ # Stage the downloaded wiki XML on HDFS, but only when HADOOP_HOME is set and
+ # MAHOUT_LOCAL is empty (i.e. local mode has not been requested).
+ if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
+ echo "Copying wikipedia data to HDFS"
+ # Errors from removing a previous copy and creating the work dir are tolerated:
+ # exit-on-error is switched off around these two commands only.
+ set +e
+ $DFSRM ${WORK_DIR}/wikixml
+ $DFS -mkdir -p ${WORK_DIR}
+ # Exit-on-error is back on before the upload, so a failed -put stops the script.
+ set -e
+ $DFS -put ${WORK_DIR}/wikixml ${WORK_DIR}/wikixml
+ fi
+ echo "Creating sequence files from wikiXML"
+ $MAHOUT_HOME/bin/mahout seqwiki -c ${WORK_DIR}/country.txt \
+ -i ${WORK_DIR}/wikixml/enwiki-latest-pages-articles.xml \
+ -o ${WORK_DIR}/wikipediainput
+ # if using the 10 class problem use bigrams
+ if [ "x$alg" == "xCBayes" ] ; then
+ echo "Converting sequence files to vectors using bigrams"
+ $MAHOUT_HOME/bin/mahout seq2sparse -i ${WORK_DIR}/wikipediainput \
+ -o ${WORK_DIR}/wikipediaVecs \
+ -wt tfidf \
+ -lnorm -nv \
+ -ow -ng 2
+ fi
+ # if using the 2 class problem try different options
+ if [ "x$alg" == "xBinaryCBayes" ] ; then
+ echo "Converting sequence files to vectors using unigrams and a max document frequency of 30%"
+ $MAHOUT_HOME/bin/mahout seq2sparse -i ${WORK_DIR}/wikipediainput \
+ -o ${WORK_DIR}/wikipediaVecs \
+ -wt tfidf \
+ -lnorm \
+ -nv \
+ -ow \
+ -ng 1 \
+ -x 30
+ fi
+ echo "Creating training and holdout set with a random 80-20 split of the generated vector dataset"
+ $MAHOUT_HOME/bin/mahout split -i ${WORK_DIR}/wikipediaVecs/tfidf-vectors/ \
+ --trainingOutput ${WORK_DIR}/training \
+ --testOutput ${WORK_DIR}/testing \
+ -rp 20 \
+ -ow \
+ -seq \
+ -xm sequential
+ echo "Training Naive Bayes model"
+ $MAHOUT_HOME/bin/mahout trainnb -i ${WORK_DIR}/training \
+ -o ${WORK_DIR}/model \
+ -li ${WORK_DIR}/labelindex \
+ -ow \
+ -c
+ echo "Self testing on training set"
+ $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/training \
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow \
+ -o ${WORK_DIR}/output \
+ -c
+ echo "Testing on holdout set: Bayes"
+ $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/testing \
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow \
+ -o ${WORK_DIR}/output \
+ -seq
+ echo "Testing on holdout set: CBayes"
+ $MAHOUT_HOME/bin/mahout testnb -i ${WORK_DIR}/testing \
+ -m ${WORK_DIR}/model -l \
+ ${WORK_DIR}/labelindex \
+ -ow \
+ -o ${WORK_DIR}/output \
+ -c \
+ -seq
+elif [ "x$alg" == "xclean" ]; then
+ rm -rf $WORK_DIR
+# Remove the work directory

diff --git a/community/mahout-mr/examples/bin/cluster-reuters.sh b/community/mahout-mr/examples/bin/cluster-reuters.sh
new file mode 100755
index 0000000..49f6c94
--- /dev/null
+++ b/community/mahout-mr/examples/bin/cluster-reuters.sh
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Downloads the Reuters dataset and prepares it for clustering
+# To run: change into the mahout directory and type:
+# examples/bin/cluster-reuters.sh
+if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
+ echo "This script clusters the Reuters data set using a variety of algorithms. The data set is downloaded automatically."
+ exit
+if [ "$0" != "$SCRIPT_PATH" ] && [ "$SCRIPT_PATH" != "" ]; then
+# Set commands for dfs
+source ${START_PATH}/set-dfs-commands.sh
+if [ ! -e $MAHOUT ]; then
+ echo "Can't find mahout driver in $MAHOUT, cwd `pwd`, exiting.."
+ exit 1
+if [[ -z "$MAHOUT_WORK_DIR" ]]; then
+ WORK_DIR=/tmp/mahout-work-${USER}
+algorithm=( kmeans fuzzykmeans lda streamingkmeans clean)
+if [ -n "$1" ]; then
+ choice=$1
+ echo "Please select a number to choose the corresponding clustering algorithm"
+ echo "1. ${algorithm[0]} clustering (runs from this example script in cluster mode only)"
+ echo "2. ${algorithm[1]} clustering (may require increased heap space on yarn)"
+ echo "3. ${algorithm[2]} clustering"
+ echo "4. ${algorithm[3]} clustering"
+ echo "5. ${algorithm[4]} -- cleans up the work area in $WORK_DIR"
+ read -p "Enter your choice : " choice
+echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]} Clustering"
+if [ "x$clustertype" == "xclean" ]; then
+ rm -rf $WORK_DIR
+ exit 1
+ $DFS -mkdir -p $WORK_DIR
+ mkdir -p $WORK_DIR
+ echo "Creating work directory at ${WORK_DIR}"
+if [ ! -e ${WORK_DIR}/reuters-out-seqdir ]; then
+ if [ ! -e ${WORK_DIR}/reuters-out ]; then
+ if [ ! -e ${WORK_DIR}/reuters-sgm ]; then
+ if [ ! -f ${WORK_DIR}/reuters21578.tar.gz ]; then
+ if [ -n "$2" ]; then
+ echo "Copying Reuters from local download"
+ cp $2 ${WORK_DIR}/reuters21578.tar.gz
+ else
+ echo "Downloading Reuters-21578"
+ curl http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz -o ${WORK_DIR}/reuters21578.tar.gz
+ fi
+ fi
+ #make sure it was actually downloaded
+ if [ ! -f ${WORK_DIR}/reuters21578.tar.gz ]; then
+ echo "Failed to download reuters"
+ exit 1
+ fi
+ mkdir -p ${WORK_DIR}/reuters-sgm
+ echo "Extracting..."
+ tar xzf ${WORK_DIR}/reuters21578.tar.gz -C ${WORK_DIR}/reuters-sgm
+ fi
+ echo "Extracting Reuters"
+ $MAHOUT org.apache.lucene.benchmark.utils.ExtractReuters ${WORK_DIR}/reuters-sgm ${WORK_DIR}/reuters-out
+ if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
+ echo "Copying Reuters data to Hadoop"
+ set +e
+ $DFSRM ${WORK_DIR}/reuters-sgm
+ $DFSRM ${WORK_DIR}/reuters-out
+ $DFS -mkdir -p ${WORK_DIR}/
+ $DFS -mkdir ${WORK_DIR}/reuters-sgm
+ $DFS -mkdir ${WORK_DIR}/reuters-out
+ $DFS -put ${WORK_DIR}/reuters-sgm ${WORK_DIR}/reuters-sgm
+ $DFS -put ${WORK_DIR}/reuters-out ${WORK_DIR}/reuters-out
+ set -e
+ fi
+ fi
+ echo "Converting to Sequence Files from Directory"
+ $MAHOUT seqdirectory -i ${WORK_DIR}/reuters-out -o ${WORK_DIR}/reuters-out-seqdir -c UTF-8 -chunk 64 -xm sequential
+if [ "x$clustertype" == "xkmeans" ]; then
+ $MAHOUT seq2sparse \
+ -i ${WORK_DIR}/reuters-out-seqdir/ \
+ -o ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans --maxDFPercent 85 --namedVector \
+ && \
+ $MAHOUT kmeans \
+ -i ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/tfidf-vectors/ \
+ -c ${WORK_DIR}/reuters-kmeans-clusters \
+ -o ${WORK_DIR}/reuters-kmeans \
+ -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure \
+ -x 10 -k 20 -ow --clustering \
+ && \
+ $MAHOUT clusterdump \
+ -i `$DFS -ls -d ${WORK_DIR}/reuters-kmeans/clusters-*-final | awk '{print $8}'` \
+ -o ${WORK_DIR}/reuters-kmeans/clusterdump \
+ -d ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/dictionary.file-0 \
+ -dt sequencefile -b 100 -n 20 --evaluate -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure -sp 0 \
+ --pointsDir ${WORK_DIR}/reuters-kmeans/clusteredPoints \
+ && \
+ cat ${WORK_DIR}/reuters-kmeans/clusterdump
+elif [ "x$clustertype" == "xfuzzykmeans" ]; then
+ $MAHOUT seq2sparse \
+ -i ${WORK_DIR}/reuters-out-seqdir/ \
+ -o ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans --maxDFPercent 85 --namedVector \
+ && \
+ $MAHOUT fkmeans \
+ -i ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans/tfidf-vectors/ \
+ -c ${WORK_DIR}/reuters-fkmeans-clusters \
+ -o ${WORK_DIR}/reuters-fkmeans \
+ -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure \
+ -x 10 -k 20 -ow -m 1.1 \
+ && \
+ $MAHOUT clusterdump \
+ -i ${WORK_DIR}/reuters-fkmeans/clusters-*-final \
+ -o ${WORK_DIR}/reuters-fkmeans/clusterdump \
+ -d ${WORK_DIR}/reuters-out-seqdir-sparse-fkmeans/dictionary.file-0 \
+ -dt sequencefile -b 100 -n 20 -sp 0 \
+ && \
+ cat ${WORK_DIR}/reuters-fkmeans/clusterdump
+elif [ "x$clustertype" == "xlda" ]; then
+ $MAHOUT seq2sparse \
+ -i ${WORK_DIR}/reuters-out-seqdir/ \
+ -o ${WORK_DIR}/reuters-out-seqdir-sparse-lda -ow --maxDFPercent 85 --namedVector \
+ && \
+ $MAHOUT rowid \
+ -i ${WORK_DIR}/reuters-out-seqdir-sparse-lda/tfidf-vectors \
+ -o ${WORK_DIR}/reuters-out-matrix \
+ && \
+ rm -rf ${WORK_DIR}/reuters-lda ${WORK_DIR}/reuters-lda-topics ${WORK_DIR}/reuters-lda-model \
+ && \
+ $MAHOUT cvb \
+ -i ${WORK_DIR}/reuters-out-matrix/matrix \
+ -o ${WORK_DIR}/reuters-lda -k 20 -ow -x 20 \
+ -dict ${WORK_DIR}/reuters-out-seqdir-sparse-lda/dictionary.file-* \
+ -dt ${WORK_DIR}/reuters-lda-topics \
+ -mt ${WORK_DIR}/reuters-lda-model \
+ && \
+ $MAHOUT vectordump \
+ -i ${WORK_DIR}/reuters-lda-topics/part-m-00000 \
+ -o ${WORK_DIR}/reuters-lda/vectordump \
+ -vs 10 -p true \
+ -d ${WORK_DIR}/reuters-out-seqdir-sparse-lda/dictionary.file-* \
+ -dt sequencefile -sort ${WORK_DIR}/reuters-lda-topics/part-m-00000 \
+ && \
+ cat ${WORK_DIR}/reuters-lda/vectordump
+elif [ "x$clustertype" == "xstreamingkmeans" ]; then
+ $MAHOUT seq2sparse \
+ -i ${WORK_DIR}/reuters-out-seqdir/ \
+ -o ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans -ow --maxDFPercent 85 --namedVector \
+ && \
+ rm -rf ${WORK_DIR}/reuters-streamingkmeans \
+ && \
+ $MAHOUT streamingkmeans \
+ -i ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans/tfidf-vectors/ \
+ --tempDir ${WORK_DIR}/tmp \
+ -o ${WORK_DIR}/reuters-streamingkmeans \
+ -sc org.apache.mahout.math.neighborhood.FastProjectionSearch \
+ -dm org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure \
+ -k 10 -km 100 -ow \
+ && \
+ $MAHOUT qualcluster \
+ -i ${WORK_DIR}/reuters-out-seqdir-sparse-streamingkmeans/tfidf-vectors/part-r-00000 \
+ -c ${WORK_DIR}/reuters-streamingkmeans/part-r-00000 \
+ -o ${WORK_DIR}/reuters-cluster-distance.csv \
+ && \
+ cat ${WORK_DIR}/reuters-cluster-distance.csv
2018-06-27 13:14:45 UTC
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
new file mode 100644
index 0000000..a99d54c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
@@ -0,0 +1,265 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.util.Collection;
+import java.util.Random;
+ * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html">
+ * "Netflix Update: Try this at home"</a>.
+ *
+ * Attempts to be as memory efficient as possible, only iterating once through the
+ * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays.
+ * Learning works in place on these data structures after that.
+ */
+public class ParallelArraysSGDFactorizer implements Factorizer {
+ // Default hyper-parameter values.
+ public static final double DEFAULT_LEARNING_RATE = 0.005;
+ public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
+ public static final double DEFAULT_RANDOM_NOISE = 0.005;
+ // Number of features per user/item vector and number of training passes per feature.
+ private final int numFeatures;
+ private final int numIterations;
+ // Bounds used to clamp estimates while training.
+ private final float minPreference;
+ private final float maxPreference;
+ // Source of randomness for feature initialisation and preference shuffling.
+ private final Random random;
+ // Step size and regularisation weight applied in train().
+ private final double learningRate;
+ private final double preventOverfitting;
+ // Map external user/item IDs to dense array indexes.
+ private final FastByIDMap<Integer> userIDMapping;
+ private final FastByIDMap<Integer> itemIDMapping;
+ // Feature matrices, indexed as [userIndex][feature] and [itemIndex][feature].
+ private final double[][] userFeatures;
+ private final double[][] itemFeatures;
+ // Parallel arrays with one slot per preference: user index, item index and preference value.
+ private final int[] userIndexes;
+ private final int[] itemIndexes;
+ private final float[] values;
+ // Initial feature value and the scale of the random perturbation added to it.
+ private final double defaultValue;
+ private final double interval;
+ // Per-preference partial estimate accumulated from the features that are already trained.
+ private final double[] cachedEstimates;
+ private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
+ // Convenience constructor: wraps the DataModel in a DataModelFactorizablePreferences and delegates.
+ // NOTE(review): the continuation line of this this(...) call is missing in this copy of the patch;
+ // presumably it passes the remaining DEFAULT_* values - confirm against the original file.
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE,
+ }
+ // Convenience constructor: wraps the DataModel and forwards all explicitly supplied hyper-parameters.
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate,
+ double preventOverfitting, double randomNoise) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting,
+ randomNoise);
+ }
+ // Convenience constructor for FactorizablePreferences using default hyper-parameters.
+ // NOTE(review): the continuation line of this this(...) call is missing in this copy of the patch as well.
+ public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
+ this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING,
+ }
+ /**
+ * Copies all preferences into parallel primitive arrays and initialises the user and item feature matrices.
+ *
+ * @param factorizablePreferences source of user IDs, item IDs, preferences and the min/max preference bounds
+ * @param numFeatures number of features per user and per item vector
+ * @param numIterations number of training passes {@link #factorize()} runs for each feature
+ * @param learningRate step size used when training
+ * @param preventOverfitting regularisation weight used when training
+ * @param randomNoise scale factor for the random perturbation added to each initial feature value
+ */
+ public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures,
+ int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
+ this.numFeatures = numFeatures;
+ this.numIterations = numIterations;
+ minPreference = factorizablePreferences.getMinPreference();
+ maxPreference = factorizablePreferences.getMaxPreference();
+ this.random = RandomUtils.getRandom();
+ this.learningRate = learningRate;
+ this.preventOverfitting = preventOverfitting;
+ int numUsers = factorizablePreferences.numUsers();
+ int numItems = factorizablePreferences.numItems();
+ int numPrefs = factorizablePreferences.numPreferences();
+ log.info("Mapping {} users...", numUsers);
+ // Assign every user ID a dense index in iteration order.
+ userIDMapping = new FastByIDMap<>(numUsers);
+ int index = 0;
+ LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
+ while (userIterator.hasNext()) {
+ userIDMapping.put(userIterator.nextLong(), index++);
+ }
+ log.info("Mapping {} items", numItems);
+ // Same for item IDs, restarting the index at zero.
+ itemIDMapping = new FastByIDMap<>(numItems);
+ index = 0;
+ LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
+ while (itemIterator.hasNext()) {
+ itemIDMapping.put(itemIterator.nextLong(), index++);
+ }
+ // One slot per preference in each of the four parallel arrays.
+ this.userIndexes = new int[numPrefs];
+ this.itemIndexes = new int[numPrefs];
+ this.values = new float[numPrefs];
+ this.cachedEstimates = new double[numPrefs];
+ index = 0;
+ log.info("Loading {} preferences into memory", numPrefs);
+ RunningAverage average = new FullRunningAverage();
+ for (Preference preference : factorizablePreferences.getPreferences()) {
+ // NOTE(review): the mapped Integer is unboxed here; presumably get() returns null for an ID that was
+ // not reported by getUserIDs()/getItemIDs(), which would surface as a NullPointerException - confirm.
+ userIndexes[index] = userIDMapping.get(preference.getUserID());
+ itemIndexes[index] = itemIDMapping.get(preference.getItemID());
+ values[index] = preference.getValue();
+ cachedEstimates[index] = 0;
+ average.addDatum(preference.getValue());
+ index++;
+ if (index % 1000000 == 0) {
+ log.info("Processed {} preferences", index);
+ }
+ }
+ log.info("Processed {} preferences, done.", index);
+ double averagePreference = average.getAverage();
+ log.info("Average preference value is {}", averagePreference);
+ double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
+ // Initial feature value: sqrt((average - 10% of the preference range) / numFeatures).
+ // NOTE(review): this is NaN when the average is smaller than 10% of the preference range.
+ defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures);
+ interval = prefInterval * 0.1 / numFeatures;
+ userFeatures = new double[numUsers][numFeatures];
+ itemFeatures = new double[numItems][numFeatures];
+ log.info("Initializing feature vectors...");
+ // Feature-major fill: for each feature, all users first, then all items. Every cell is the default
+ // value plus a random offset in [-0.5, 0.5) * interval * randomNoise, so the order of
+ // random.nextDouble() calls determines the initial values.
+ for (int feature = 0; feature < numFeatures; feature++) {
+ for (int userIndex = 0; userIndex < numUsers; userIndex++) {
+ userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+ }
+ for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
+ itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+ }
+ }
+ }
+ @Override
+ public Factorization factorize() throws TasteException {
+ for (int feature = 0; feature < numFeatures; feature++) {
+ log.info("Shuffling preferences...");
+ shufflePreferences();
+ log.info("Starting training of feature {} ...", feature);
+ for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+ if (currentIteration == numIterations - 1) {
+ double rmse = trainingIterationWithRmse(feature);
+ log.info("Finished training feature {} with RMSE {}", feature, rmse);
+ } else {
+ trainingIteration(feature);
+ }
+ }
+ if (feature < numFeatures - 1) {
+ log.info("Updating cache...");
+ for (int index = 0; index < userIndexes.length; index++) {
+ cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index],
+ false);
+ }
+ }
+ }
+ log.info("Factorization done");
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+ }
+ private void trainingIteration(int feature) {
+ for (int index = 0; index < userIndexes.length; index++) {
+ train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+ }
+ }
+ private double trainingIterationWithRmse(int feature) {
+ double rmse = 0.0;
+ for (int index = 0; index < userIndexes.length; index++) {
+ double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+ rmse += error * error;
+ }
+ return Math.sqrt(rmse / userIndexes.length);
+ }
+ /**
+  * Estimates a preference from the cached partial estimate plus the contribution of {@code feature}.
+  *
+  * @param cachedEstimate contribution of the previously trained features
+  * @param trailing if true, the not-yet-trained features are accounted for with their initial value
+  *        and the result is clamped to the preference bounds
+  * @return the (optionally extended and clamped) estimate
+  */
+ private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
+   double prediction = cachedEstimate;
+   prediction += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
+   if (!trailing) {
+     return prediction;
+   }
+   // Each remaining feature is assumed to contribute (defaultValue + interval) squared.
+   prediction += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval);
+   if (prediction > maxPreference) {
+     return maxPreference;
+   }
+   if (prediction < minPreference) {
+     return minPreference;
+   }
+   return prediction;
+ }
+ /**
+  * Performs one stochastic gradient step on a single feature for one (user, item, value) triple.
+  *
+  * <p>Both feature values are read before either is written, so the item update is computed from the
+  * user feature as it was before this step (and vice versa) rather than from the freshly updated one.</p>
+  *
+  * @param userIndex dense index of the user
+  * @param itemIndex dense index of the item
+  * @param feature index of the feature being trained
+  * @param original the known preference value
+  * @param cachedEstimate contribution of the previously trained features to the estimate
+  * @return the error (known value minus clamped estimate) measured before the update
+  */
+ public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
+   double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
+   double[] userVector = userFeatures[userIndex];
+   double[] itemVector = itemFeatures[itemIndex];
+   // Snapshot both values so that the two updates below are based on the same point.
+   double userValue = userVector[feature];
+   double itemValue = itemVector[feature];
+   userVector[feature] += learningRate * (error * itemValue - preventOverfitting * userValue);
+   itemVector[feature] += learningRate * (error * userValue - preventOverfitting * itemValue);
+   return error;
+ }
+ /** Randomly permutes the preferences in place, keeping the parallel arrays aligned. */
+ protected void shufflePreferences() {
+   /* Durstenfeld shuffle */
+   int pos = userIndexes.length - 1;
+   while (pos > 0) {
+     // Pick a partner from [0, pos] and exchange the two preferences.
+     int partner = random.nextInt(pos + 1);
+     swapPreferences(pos, partner);
+     pos--;
+   }
+ }
+ /** Exchanges the preferences at {@code posA} and {@code posB} in all four parallel arrays. */
+ private void swapPreferences(int posA, int posB) {
+   int userIndex = userIndexes[posA];
+   userIndexes[posA] = userIndexes[posB];
+   userIndexes[posB] = userIndex;
+
+   int itemIndex = itemIndexes[posA];
+   itemIndexes[posA] = itemIndexes[posB];
+   itemIndexes[posB] = itemIndex;
+
+   float value = values[posA];
+   values[posA] = values[posB];
+   values[posB] = value;
+
+   double estimate = cachedEstimates[posA];
+   cachedEstimates[posA] = cachedEstimates[posB];
+   cachedEstimates[posB] = estimate;
+ }
+ // Intentionally a no-op: this implementation performs no work on refresh.
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // do nothing
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
new file mode 100644
index 0000000..5cce02d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
@@ -0,0 +1,141 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+ * run an SVD factorization of the KDD track1 data.
+ *
+ * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M
+ *
+ */
+public final class Track1SVDRunner {
+ private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class);
+ // Private constructor: the class only exposes static entry points and is not meant to be instantiated.
+ private Track1SVDRunner() {
+ }
+ /**
+  * Factorizes the KDD Cup track 1 training data, logs the RMSE measured on the validation file and
+  * writes one converted estimate per test preference to the result file.
+  *
+  * @param args {@code args[0]} is the KDD data file directory, {@code args[1]} the result file to write
+  * @throws Exception if reading the data, factorizing or writing the result fails
+  */
+ public static void main(String[] args) throws Exception {
+   if (args.length != 2) {
+     System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>");
+     return;
+   }
+   File dataFileDirectory = new File(args[0]);
+   if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+     throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+   }
+   File resultFile = new File(args[1]);
+   /* the knobs to turn */
+   int numFeatures = 20;
+   int numIterations = 5;
+   double learningRate = 0.0001;
+   double preventOverfitting = 0.002;
+   double randomNoise = 0.0001;
+   KDDCupFactorizablePreferences factorizablePreferences =
+       new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+   Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
+       learningRate, preventOverfitting, randomNoise);
+   Factorization factorization = sgdFactorizer.factorize();
+   log.info("Estimating validation preferences...");
+   int prefsProcessed = 0;
+   // Accumulates the squared errors; the square root of its average is the validation RMSE.
+   RunningAverage average = new FullRunningAverage();
+   for (Pair<PreferenceArray,long[]> validationPair
+       : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
+     for (Preference validationPref : validationPair.getFirst()) {
+       double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(),
+           factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
+       double error = validationPref.getValue() - estimate;
+       average.addDatum(error * error);
+       prefsProcessed++;
+       if (prefsProcessed % 100000 == 0) {
+         log.info("Computed {} estimations", prefsProcessed);
+       }
+     }
+   }
+   log.info("Computed {} estimations, done.", prefsProcessed);
+   double rmse = Math.sqrt(average.getAverage());
+   log.info("RMSE {}", rmse);
+   log.info("Estimating test preferences...");
+   // try-with-resources closes the stream on every path and, unlike closing in a finally block,
+   // does not let a failure in close() replace an exception thrown while writing.
+   try (OutputStream out = new BufferedOutputStream(new FileOutputStream(resultFile))) {
+     for (Pair<PreferenceArray,long[]> testPair
+         : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
+       for (Preference testPref : testPair.getFirst()) {
+         double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(),
+             factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
+         byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID());
+         out.write(result);
+       }
+     }
+   }
+   log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
+ }
+ /**
+  * Estimates a preference as the dot product of the user's and the item's feature vectors, limited to
+  * the given bounds.
+  *
+  * @return {@code minPreference} if the dot product is below it, {@code maxPreference} if it is above
+  *         it, otherwise the dot product itself
+  * @throws NoSuchUserException declared for the user feature lookup
+  * @throws NoSuchItemException declared for the item feature lookup
+  */
+ static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference,
+     float maxPreference) throws NoSuchUserException, NoSuchItemException {
+   double[] userVector = factorization.getUserFeatures(userID);
+   double[] itemVector = factorization.getItemFeatures(itemID);
+   double dotProduct = 0;
+   int f = 0;
+   while (f < userVector.length) {
+     dotProduct += userVector[f] * itemVector[f];
+     f++;
+   }
+   if (dotProduct < minPreference) {
+     return minPreference;
+   }
+   if (dotProduct > maxPreference) {
+     return maxPreference;
+   }
+   return dotProduct;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
new file mode 100644
index 0000000..ce025a9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
@@ -0,0 +1,62 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+/**
+ * Item similarity that multiplies a content-based similarity ({@code TrackItemSimilarity}) with a
+ * collaborative-filtering similarity ({@link LogLikelihoodSimilarity}).
+ */
+final class HybridSimilarity extends AbstractItemSimilarity {
+  private final ItemSimilarity cfSimilarity;
+  private final ItemSimilarity contentSimilarity;
+
+  /**
+   * @param dataModel model handed to the superclass and to the log-likelihood similarity
+   * @param dataFileDirectory directory handed to the track-based content similarity
+   * @throws IOException declared for the construction of the content similarity
+   */
+  HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException {
+    super(dataModel);
+    cfSimilarity = new LogLikelihoodSimilarity(dataModel);
+    contentSimilarity = new TrackItemSimilarity(dataFileDirectory);
+  }
+
+  /** Returns the product of the content similarity and the CF similarity of the two items. */
+  @Override
+  public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+    double content = contentSimilarity.itemSimilarity(itemID1, itemID2);
+    double collaborative = cfSimilarity.itemSimilarity(itemID1, itemID2);
+    return content * collaborative;
+  }
+
+  /**
+   * Returns, per entry of {@code itemID2s}, the product of content and CF similarity. The array
+   * returned by the content similarity is scaled in place and returned.
+   */
+  @Override
+  public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+    double[] combined = contentSimilarity.itemSimilarities(itemID1, itemID2s);
+    double[] collaborative = cfSimilarity.itemSimilarities(itemID1, itemID2s);
+    int i = 0;
+    while (i < combined.length) {
+      combined[i] = combined[i] * collaborative[i];
+      i++;
+    }
+    return combined;
+  }
+
+  /**
+   * Forwards the refresh to the CF similarity.
+   * NOTE(review): the content similarity is not refreshed here - confirm that this is intentional.
+   */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    cfSimilarity.refresh(alreadyRefreshed);
+  }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
new file mode 100644
index 0000000..50fd35e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
@@ -0,0 +1,106 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.TreeMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+final class Track2Callable implements Callable<UserResult> {
+ private static final Logger log = LoggerFactory.getLogger(Track2Callable.class);
+ // Counter shared by all instances; incremented once per completed call() for progress logging.
+ private static final AtomicInteger COUNT = new AtomicInteger();
+ // Recommender used to estimate preferences, and the test preferences of the one user this task handles.
+ private final Recommender recommender;
+ private final PreferenceArray userTest;
+ // Stores the recommender and the user's test preferences; no work happens until call().
+ Track2Callable(Recommender recommender, PreferenceArray userTest) {
+ this.recommender = recommender;
+ this.userTest = userTest;
+ }
+ /**
+  * Ranks the user's six test items by estimated preference and marks the top three.
+  *
+  * <p>Items that share the same estimate are all kept (in test order) instead of overwriting each other.
+  * If fewer than three items receive a usable estimate, the selection is padded with further test items.</p>
+  *
+  * @return the user ID together with one flag per test item telling whether it is among the chosen three
+  * @throws TasteException if the recommender fails for a reason other than an unknown item
+  * @throws IllegalArgumentException if the user does not have exactly six test items
+  * @throws IllegalStateException if three distinct items cannot be selected
+  */
+ @Override
+ public UserResult call() throws TasteException {
+   int testSize = userTest.length();
+   if (testSize != 6) {
+     throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest);
+   }
+   long userID = userTest.get(0).getUserID();
+   // Highest estimate first. Each estimate maps to every item that produced it, so ties are not lost.
+   TreeMap<Double,List<Long>> estimateToItemIDs = new TreeMap<>(Collections.reverseOrder());
+   for (int i = 0; i < testSize; i++) {
+     long itemID = userTest.getItemID(i);
+     double estimate;
+     try {
+       estimate = recommender.estimatePreference(userID, itemID);
+     } catch (NoSuchItemException nsie) {
+       // OK in the sample data provided before the contest, should never happen otherwise
+       log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
+       continue;
+     }
+     if (!Double.isNaN(estimate)) {
+       List<Long> tied = estimateToItemIDs.get(estimate);
+       if (tied == null) {
+         tied = new ArrayList<>();
+         estimateToItemIDs.put(estimate, tied);
+       }
+       tied.add(itemID);
+     }
+   }
+   // Flatten to a ranking of distinct item IDs, best estimate first.
+   List<Long> ranked = new ArrayList<>(testSize);
+   for (List<Long> tied : estimateToItemIDs.values()) {
+     for (Long itemID : tied) {
+       if (!ranked.contains(itemID)) {
+         ranked.add(itemID);
+       }
+     }
+   }
+   List<Long> topThree;
+   if (ranked.size() >= 3) {
+     topThree = ranked.subList(0, 3);
+   } else {
+     log.warn("Unable to recommend three items for {}", userID);
+     // Some NaNs - just guess at the rest then
+     Collection<Long> padded = new HashSet<>(ranked);
+     for (int i = 0; i < testSize && padded.size() < 3; i++) {
+       padded.add(userTest.getItemID(i));
+     }
+     topThree = new ArrayList<>(padded);
+   }
+   if (topThree.size() != 3) {
+     throw new IllegalStateException("Could not select three items for user " + userID);
+   }
+   boolean[] result = new boolean[testSize];
+   for (int i = 0; i < testSize; i++) {
+     result[i] = topThree.contains(userTest.getItemID(i));
+   }
+   // Log the value this call produced; re-reading COUNT could report another thread's increment.
+   int completed = COUNT.incrementAndGet();
+   if (completed % 1000 == 0) {
+     log.info("Completed {} users", completed);
+   }
+   return new UserResult(userID, result);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
new file mode 100644
index 0000000..185a00d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
@@ -0,0 +1,100 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+/**
+ * Recommender for KDD Cup track 2. All operations are delegated to a
+ * {@link GenericBooleanPrefItemBasedRecommender} that uses a {@code HybridSimilarity}.
+ */
+public final class Track2Recommender implements Recommender {
+  private final Recommender recommender;
+
+  /**
+   * @param dataModel model the delegate recommender works on
+   * @param dataFileDirectory directory handed to the similarity
+   * @throws TasteException wrapping any {@link IOException} raised while building the similarity
+   */
+  public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException {
+    // Change this to whatever you like!
+    ItemSimilarity similarity;
+    try {
+      similarity = new HybridSimilarity(dataModel, dataFileDirectory);
+    } catch (IOException ioe) {
+      throw new TasteException(ioe);
+    }
+    recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity);
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return recommender.recommend(userID, howMany);
+  }
+
+  /** Same as the four-argument overload with no rescorer. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+    return recommend(userID, howMany, null, includeKnownItems);
+  }
+
+  /** Delegates to the wrapped recommender, excluding known items. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, false);
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+    throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return recommender.estimatePreference(userID, itemID);
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    recommender.setPreference(userID, itemID, value);
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    recommender.removePreference(userID, itemID);
+  }
+
+  /** Returns the data model of the wrapped recommender. */
+  @Override
+  public DataModel getDataModel() {
+    return recommender.getDataModel();
+  }
+
+  /** Delegates to the wrapped recommender. */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    recommender.refresh(alreadyRefreshed);
+  }
+
+  /** Names this class (previously mislabelled as "Track1Recommender") and the wrapped recommender. */
+  @Override
+  public String toString() {
+    return "Track2Recommender[recommender:" + recommender + ']';
+  }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
new file mode 100644
index 0000000..09ade5d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
@@ -0,0 +1,33 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+final class Track2RecommenderBuilder implements RecommenderBuilder {
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) throws TasteException {
+ return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory());
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
new file mode 100644
index 0000000..3cbb61c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
@@ -0,0 +1,100 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+ * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender}
+ * and attempts to output the result in the correct contest format.</p>
+ *
+ * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p>
+ */
+public final class Track2Runner {
+ private static final Logger log = LoggerFactory.getLogger(Track2Runner.class);
+ // Private constructor: the class only exposes a static main() and is not meant to be instantiated.
+ private Track2Runner() {
+ }
+ /**
+  * Loads the training data, estimates the test preferences of every user in parallel and writes the
+  * per-user results to the output file in the order of the test file.
+  *
+  * @param args {@code args[0]} is the track 2 data file directory, {@code args[1]} the output file
+  * @throws Exception if loading, recommending or writing fails
+  */
+ public static void main(String[] args) throws Exception {
+   // Fail with a usage message rather than an ArrayIndexOutOfBoundsException; extra arguments are
+   // still tolerated as before.
+   if (args.length < 2) {
+     System.err.println("Necessary arguments: <track 2 data file directory> <output file>");
+     return;
+   }
+   File dataFileDirectory = new File(args[0]);
+   if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+     throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+   }
+   long start = System.currentTimeMillis();
+   KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+   Track2Recommender recommender = new Track2Recommender(model, dataFileDirectory);
+   long end = System.currentTimeMillis();
+   log.info("Loaded model in {}s", (end - start) / 1000);
+   start = end;
+   Collection<Track2Callable> callables = new ArrayList<>();
+   for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
+     PreferenceArray userTest = tests.getFirst();
+     callables.add(new Track2Callable(recommender, userTest));
+   }
+   int cores = Runtime.getRuntime().availableProcessors();
+   log.info("Running on {} cores", cores);
+   ExecutorService executor = Executors.newFixedThreadPool(cores);
+   List<Future<UserResult>> futures;
+   try {
+     futures = executor.invokeAll(callables);
+   } finally {
+     // Always release the worker threads, even if invokeAll is interrupted or rejects the tasks;
+     // otherwise the pool's non-daemon threads would keep the JVM alive.
+     executor.shutdown();
+   }
+   end = System.currentTimeMillis();
+   log.info("Ran recommendations in {}s", (end - start) / 1000);
+   start = end;
+   try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){
+     long lastUserID = Long.MIN_VALUE;
+     for (Future<UserResult> future : futures) {
+       UserResult result = future.get();
+       long userID = result.getUserID();
+       // Results are expected in strictly increasing user ID order.
+       if (userID <= lastUserID) {
+         throw new IllegalStateException();
+       }
+       lastUserID = userID;
+       out.write(result.getResultBytes());
+     }
+   }
+   end = System.currentTimeMillis();
+   log.info("Wrote output in {}s", (end - start) / 1000);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
new file mode 100644
index 0000000..abd15f8
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
@@ -0,0 +1,71 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import java.util.regex.Pattern;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+final class TrackData {
+ private static final Pattern PIPE = Pattern.compile("\\|");
+ private static final String NO_VALUE = "None";
+ static final long NO_VALUE_ID = Long.MIN_VALUE;
+ private static final FastIDSet NO_GENRES = new FastIDSet();
+ private final long trackID;
+ private final long albumID;
+ private final long artistID;
+ private final FastIDSet genreIDs;
+ /**
+  * Parses one pipe-delimited line of track metadata: track ID, album ID, artist ID, followed by
+  * zero or more genre IDs. Album and artist fields may hold the literal "None".
+  *
+  * @param line a single line of the track data file
+  */
+ TrackData(CharSequence line) {
+   String[] fields = PIPE.split(line);
+   trackID = Long.parseLong(fields[0]);
+   albumID = parse(fields[1]);
+   artistID = parse(fields[2]);
+   int genreCount = fields.length - 3;
+   if (genreCount <= 0) {
+     // No genre columns: reuse the shared empty set rather than allocating one per track.
+     genreIDs = NO_GENRES;
+   } else {
+     genreIDs = new FastIDSet(genreCount);
+     for (int column = 3; column < fields.length; column++) {
+       genreIDs.add(Long.parseLong(fields[column]));
+     }
+   }
+ }
+ /**
+  * Converts an ID field to a long.
+  *
+  * @param value the raw field text
+  * @return {@link #NO_VALUE_ID} when the field is the literal "None", otherwise the parsed number
+  */
+ private static long parse(String value) {
+   if (NO_VALUE.equals(value)) {
+     return NO_VALUE_ID;
+   }
+   return Long.parseLong(value);
+ }
+ /** @return the track ID parsed from the first field of the line */
+ public long getTrackID() {
+ return trackID;
+ }
+ /** @return the album ID, or {@link #NO_VALUE_ID} when the album field was "None" */
+ public long getAlbumID() {
+ return albumID;
+ }
+ /** @return the artist ID, or {@link #NO_VALUE_ID} when the artist field was "None" */
+ public long getArtistID() {
+ return artistID;
+ }
+ /**
+  * @return the genre IDs of this track; the shared empty set when the line carried no genre fields.
+  *         The internal set is returned directly, without a defensive copy.
+  */
+ public FastIDSet getGenreIDs() {
+ return genreIDs;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
new file mode 100644
index 0000000..3012a84
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
@@ -0,0 +1,106 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.iterator.FileLineIterable;
+final class TrackItemSimilarity implements ItemSimilarity {
+ private final FastByIDMap<TrackData> trackData;
+ /**
+  * Reads every line of the track file located via {@code KDDCupDataModel.getTrackFile} and indexes
+  * the parsed {@link TrackData} by track ID.
+  *
+  * @param dataFileDirectory directory handed to {@code KDDCupDataModel.getTrackFile}
+  * @throws IOException declared for failures while opening or reading the track file
+  */
+ TrackItemSimilarity(File dataFileDirectory) throws IOException {
+   FastByIDMap<TrackData> byTrackID = new FastByIDMap<>();
+   for (String trackLine : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) {
+     TrackData parsed = new TrackData(trackLine);
+     byTrackID.put(parsed.getTrackID(), parsed);
+   }
+   trackData = byTrackID;
+ }
+ /**
+  * Heuristic similarity between two tracks, checked in this order:
+  * identical IDs give 1.0; a track missing from the metadata gives 0.0; a shared known album
+  * gives 0.9; a shared known artist gives 0.7; otherwise one quarter of the Tanimoto coefficient
+  * of the two genre sets (so at most 0.25).
+  */
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) {
+   if (itemID1 == itemID2) {
+     return 1.0;
+   }
+   TrackData first = trackData.get(itemID1);
+   TrackData second = trackData.get(itemID2);
+   if (first == null || second == null) {
+     return 0.0;
+   }
+   // Same album is arbitrarily treated as "very similar".
+   long firstAlbum = first.getAlbumID();
+   if (firstAlbum != TrackData.NO_VALUE_ID && firstAlbum == second.getAlbumID()) {
+     return 0.9;
+   }
+   // Same artist is treated as "fairly similar".
+   long firstArtist = first.getArtistID();
+   if (firstArtist != TrackData.NO_VALUE_ID && firstArtist == second.getArtistID()) {
+     return 0.7;
+   }
+   // Fall back to genre overlap: Tanimoto coefficient scaled down by a factor of four.
+   FastIDSet firstGenres = first.getGenreIDs();
+   FastIDSet secondGenres = second.getGenreIDs();
+   if (firstGenres == null || secondGenres == null) {
+     return 0.0;
+   }
+   int shared = firstGenres.intersectionSize(secondGenres);
+   if (shared == 0) {
+     return 0.0;
+   }
+   int combined = firstGenres.size() + secondGenres.size() - shared;
+   return shared / (4.0 * combined);
+ }
+ /**
+  * Computes {@link #itemSimilarity(long, long)} between {@code itemID1} and each ID in
+  * {@code itemID2s}; the result array is index-aligned with {@code itemID2s}.
+  */
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) {
+   double[] similarities = new double[itemID2s.length];
+   int slot = 0;
+   for (long otherID : itemID2s) {
+     similarities[slot++] = itemSimilarity(itemID1, otherID);
+   }
+   return similarities;
+ }
+ /**
+  * Collects every known track ID whose similarity to {@code itemID} is not NaN.
+  * NOTE(review): every return path of {@code itemSimilarity} in this class appears to yield a
+  * finite value, so in practice this seems to return all track IDs (including {@code itemID}
+  * itself when it is known) - confirm whether that is intended.
+  */
+ @Override
+ public long[] allSimilarItemIDs(long itemID) {
+   FastIDSet similarIDs = new FastIDSet();
+   for (LongPrimitiveIterator candidates = trackData.keySetIterator(); candidates.hasNext();) {
+     long candidateID = candidates.nextLong();
+     if (Double.isNaN(itemSimilarity(itemID, candidateID))) {
+       continue;
+     }
+     similarIDs.add(candidateID);
+   }
+   return similarIDs.toArray();
+ }
+ /** No-op: the track metadata is loaded once in the constructor and never reloaded here. */
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // do nothing
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
new file mode 100644
index 0000000..e554d10
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
@@ -0,0 +1,54 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+final class UserResult {
+ private final long userID;
+ private final byte[] resultBytes;
+ /**
+  * Encodes a user's boolean results as ASCII bytes, '1' for true and '0' for false.
+  *
+  * @param userID the user these results belong to
+  * @param result one flag per test item; exactly three must be true
+  * @throws IllegalStateException if the number of true flags is not three
+  */
+ UserResult(long userID, boolean[] result) {
+   this.userID = userID;
+   byte[] encoded = new byte[result.length];
+   int selected = 0;
+   for (int i = 0; i < result.length; i++) {
+     if (result[i]) {
+       selected++;
+       encoded[i] = (byte) '1';
+     } else {
+       encoded[i] = (byte) '0';
+     }
+   }
+   if (selected != 3) {
+     throw new IllegalStateException();
+   }
+   resultBytes = encoded;
+ }
+ /** @return the user ID supplied at construction */
+ public long getUserID() {
+ return userID;
+ }
+ /**
+  * @return the results as ASCII '1'/'0' bytes, one per input flag. The internal array is returned
+  *         directly, without a defensive copy.
+  */
+ public byte[] getResultBytes() {
+ return resultBytes;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
new file mode 100644
index 0000000..22f122e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
@@ -0,0 +1,140 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.hadoop.example.als.netflix;
+import com.google.common.base.Preconditions;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.iterator.FileLineIterable;
+import org.apache.mahout.common.iterator.FileLineIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Pattern;
+/** converts the raw files provided by netflix to an appropriate input format */
+public final class NetflixDatasetConverter {
+ private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class);
+ private static final Pattern SEPARATOR = Pattern.compile(",");
+ private static final String MOVIE_DENOTER = ":";
+ private static final String TAB = "\t";
+ private static final String NEWLINE = "\n";
+ /** Private constructor: the class is only used through its static {@code main} method. */
+ private NetflixDatasetConverter() {
+ }
+ public static void main(String[] args) throws IOException {
+ if (args.length != 4) {
+ System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt "
+ + "/path/to/judging.txt /path/to/destination");
+ return;
+ }
+ String trainingDataDir = args[0];
+ String qualifyingTxt = args[1];
+ String judgingTxt = args[2];
+ Path outputPath = new Path(args[3]);
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
+ Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified");
+ log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath);
+ try (BufferedWriter writer =
+ new BufferedWriter(
+ new OutputStreamWriter(
+ fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){
+ int ratingsProcessed = 0;
+ for (File movieRatings : new File(trainingDataDir).listFiles()) {
+ try (FileLineIterator lines = new FileLineIterator(movieRatings)) {
+ boolean firstLineRead = false;
+ String movieID = null;
+ while (lines.hasNext()) {
+ String line = lines.next();
+ if (firstLineRead) {
+ String[] tokens = SEPARATOR.split(line);
+ String userID = tokens[0];
+ String rating = tokens[1];
+ writer.write(userID + TAB + movieID + TAB + rating + NEWLINE);
+ ratingsProcessed++;
+ if (ratingsProcessed % 1000000 == 0) {
+ log.info("{} ratings processed...", ratingsProcessed);
+ }
+ } else {
+ movieID = line.replaceAll(MOVIE_DENOTER, "");
+ firstLineRead = true;
+ }
+ }
+ }
+ }
+ log.info("{} ratings processed. done.", ratingsProcessed);
+ }
+ log.info("Reading probes...");
+ List<Preference> probes = new ArrayList<>(2817131);
+ long currentMovieID = -1;
+ for (String line : new FileLineIterable(new File(qualifyingTxt))) {
+ if (line.contains(MOVIE_DENOTER)) {
+ currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
+ } else {
+ long userID = Long.parseLong(SEPARATOR.split(line)[0]);
+ probes.add(new GenericPreference(userID, currentMovieID, 0));
+ }
+ }
+ log.info("{} probes read...", probes.size());
+ log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath);
+ try (BufferedWriter writer =
+ new BufferedWriter(new OutputStreamWriter(
+ fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){
+ int ratingsProcessed = 0;
+ for (String line : new FileLineIterable(new File(judgingTxt))) {
+ if (line.contains(MOVIE_DENOTER)) {
+ currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
+ } else {
+ float rating = Float.parseFloat(SEPARATOR.split(line)[0]);
+ Preference pref = probes.get(ratingsProcessed);
+ Preconditions.checkState(pref.getItemID() == currentMovieID);
+ ratingsProcessed++;
+ writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE);
+ if (ratingsProcessed % 1000000 == 0) {
+ log.info("{} ratings processed...", ratingsProcessed);
+ }
+ }
+ }
+ log.info("{} ratings processed. done.", ratingsProcessed);
+ }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
new file mode 100644
index 0000000..8021d00
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
@@ -0,0 +1,65 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.similarity.precompute.example;
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter;
+import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
+import java.io.File;
+/**
+ * Example that precomputes all item similarities of the Movielens1M dataset.
+ *
+ * Usage: download movielens1M from http://www.grouplens.org/node/73, unzip it, and invoke this code with the path
+ * to the ratings.dat file as its only argument.
+ */
+public final class BatchItemSimilaritiesGroupLens {
+ /** Private constructor: the class is only used through its static {@code main} method. */
+ private BatchItemSimilaritiesGroupLens() {}
+ /**
+  * Precomputes item similarities for the ratings file given as the single argument and writes them
+  * to {@code similarities.csv} in the JVM temp directory, replacing any existing file of that name.
+  *
+  * @param args exactly one element: the path to the ratings.dat file
+  * @throws Exception if loading the data model or computing the similarities fails
+  */
+ public static void main(String[] args) throws Exception {
+   if (args.length != 1) {
+     System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!");
+     System.exit(-1);
+   }
+   File output = new File(System.getProperty("java.io.tmpdir"), "similarities.csv");
+   if (output.exists()) {
+     output.delete();
+   }
+   DataModel model = new GroupLensDataModel(new File(args[0]));
+   LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(model);
+   ItemBasedRecommender itemRecommender = new GenericItemBasedRecommender(model, similarity);
+   BatchItemSimilarities batchJob = new MultithreadedBatchItemSimilarities(itemRecommender, 5);
+   int workerCount = Runtime.getRuntime().availableProcessors();
+   int computed = batchJob.computeItemSimilarities(workerCount, 1, new FileSimilarItemsWriter(output));
+   String summary = "Computed " + computed + " similarities for " + model.getNumItems() + " items "
+       + "and saved them to " + output.getAbsolutePath();
+   System.out.println(summary);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
new file mode 100644
index 0000000..7ee9b17
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
@@ -0,0 +1,96 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.similarity.precompute.example;
+import com.google.common.io.Files;
+import com.google.common.io.InputSupplier;
+import com.google.common.io.Resources;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.net.URL;
+import java.util.regex.Pattern;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
+import org.apache.mahout.common.iterator.FileLineIterable;
+public final class GroupLensDataModel extends FileDataModel {
+ private static final String COLON_DELIMTER = "::";
+ private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER);
+ /**
+  * Builds the model from the bundled resource {@code /org/apache/mahout/cf/taste/example/grouplens/ratings.dat},
+  * which is first copied to a temporary file by {@link #readResourceToTempFile(String)}.
+  *
+  * @throws IOException if the resource cannot be copied or converted
+  */
+ public GroupLensDataModel() throws IOException {
+ this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat"));
+ }
+ /**
+ * Builds the model from a GroupLens ratings file, which is first rewritten into a comma-separated
+ * temporary file by {@code convertGLFile} and then handed to the {@link FileDataModel} constructor.
+ *
+ * @param ratingsFile GroupLens ratings.dat file in its native format
+ * @throws IOException if an error occurs while reading or writing files
+ */
+ public GroupLensDataModel(File ratingsFile) throws IOException {
+ super(convertGLFile(ratingsFile));
+ }
+ /**
+  * Rewrites a GroupLens ratings file into {@code ratings.txt} in the JVM temp directory: on every
+  * line, everything from the last "::" onward is dropped and the remaining "::" delimiters become
+  * commas. Any existing {@code ratings.txt} is deleted first, and a partially written file is
+  * deleted if an I/O error occurs.
+  *
+  * @param originalFile ratings file using "::" as field delimiter
+  * @return the converted, comma-separated file
+  * @throws IOException if a line contains no "::" or reading/writing fails
+  */
+ private static File convertGLFile(File originalFile) throws IOException {
+   File tempDirectory = new File(System.getProperty("java.io.tmpdir"));
+   File converted = new File(tempDirectory, "ratings.txt");
+   if (converted.exists()) {
+     converted.delete();
+   }
+   try (Writer out = new OutputStreamWriter(new FileOutputStream(converted), Charsets.UTF_8)) {
+     for (String rawLine : new FileLineIterable(originalFile, false)) {
+       int cut = rawLine.lastIndexOf(COLON_DELIMTER);
+       if (cut < 0) {
+         throw new IOException("Unexpected input format on line: " + rawLine);
+       }
+       // Keep the fields before the last delimiter, then switch the delimiter to a comma.
+       String kept = rawLine.substring(0, cut);
+       out.write(COLON_DELIMITER_PATTERN.matcher(kept).replaceAll(",") + '\n');
+     }
+   } catch (IOException ioe) {
+     converted.delete();
+     throw ioe;
+   }
+   return converted;
+ }
+ /**
+  * Copies a resource into a fresh temporary file that is scheduled for deletion on JVM exit.
+  * The resource is looked up relative to this class; if that lookup throws
+  * {@link IllegalArgumentException}, the file {@code "src/main/java" + resourceName} is used instead.
+  *
+  * @param resourceName resource path, e.g. starting with "/"
+  * @return the temporary file holding a copy of the resource
+  * @throws IOException if the temporary file cannot be created or the copy fails
+  */
+ public static File readResourceToTempFile(String resourceName) throws IOException {
+   InputSupplier<? extends InputStream> source;
+   try {
+     source = Resources.newInputStreamSupplier(Resources.getResource(GroupLensDataModel.class, resourceName));
+   } catch (IllegalArgumentException notOnClasspath) {
+     // Fall back to the source tree when the resource is not available via the class.
+     source = Files.newInputStreamSupplier(new File("src/main/java" + resourceName));
+   }
+   File copy = File.createTempFile("taste", null);
+   copy.deleteOnExit();
+   Files.copy(source, copy);
+   return copy;
+ }
+ /** @return the fixed string "GroupLensDataModel" */
+ @Override
+ public String toString() {
+ return "GroupLensDataModel";
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
new file mode 100644
index 0000000..5cec51c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
@@ -0,0 +1,128 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier;
+import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.io.Charsets;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.text.SimpleDateFormat;
+import java.util.Collection;
+import java.util.Date;
+import java.util.Locale;
+import java.util.Random;
+public final class NewsgroupHelper {
+ private static final SimpleDateFormat[] DATE_FORMATS = {
+ new SimpleDateFormat("", Locale.ENGLISH),
+ new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
+ new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
+ };
+ public static final int FEATURES = 10000;
+ // 1997-01-15 00:01:00 GMT
+ private static final long DATE_REFERENCE = 853286460;
+ private static final long MONTH = 30 * 24 * 3600;
+ private static final long WEEK = 7 * 24 * 3600;
+ private final Random rand = RandomUtils.getRandom();
+ private final Analyzer analyzer = new StandardAnalyzer();
+ private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
+ private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
+ /** @return the word encoder (named "body") used for the counted words */
+ public FeatureVectorEncoder getEncoder() {
+ return encoder;
+ }
+ /** @return the constant-value encoder (named "Intercept") used for the bias term */
+ public FeatureVectorEncoder getBias() {
+ return bias;
+ }
+ /** @return the random number generator this helper uses when synthesizing dates */
+ public Random getRandom() {
+ return rand;
+ }
+ /**
+  * Encodes one message file as a sparse feature vector of cardinality {@link #FEATURES}.
+  *
+  * <p>Words are gathered from up to three sources, controlled by {@code leakType}: a synthetic date
+  * string (format chosen by {@code leakType % 3}), selected header lines (only when
+  * {@code leakType < 6}) and the remainder of the file after the headers (only when
+  * {@code leakType < 3}). The vector holds a bias entry plus {@code log1p(count)} for each distinct word.
+  *
+  * <p>NOTE(review): {@code DATE_FORMATS} holds shared static {@link SimpleDateFormat} instances, which
+  * the JDK documents as not synchronized - confirm this method is never called concurrently.
+  *
+  * @param file message file, read as UTF-8
+  * @param actual shifts the synthetic date by this many {@code MONTH}s (presumably the target class
+  *        index - verify against callers)
+  * @param leakType selects the date format and which parts of the message are counted
+  * @param overallCounts global multiset passed through to {@link #countWords}
+  * @return the encoded feature vector
+  * @throws IOException if the file cannot be read or tokenized
+  */
+ public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts)
+ throws IOException {
+ // Seconds-based reference plus offsets, scaled by 1000 to the milliseconds expected by new Date(long).
+ long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * rand.nextDouble()));
+ Multiset<String> words = ConcurrentHashMultiset.create();
+ try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) {
+ String line = reader.readLine();
+ // DATE_FORMATS[0] uses an empty pattern, so leakType % 3 == 0 yields an empty date string.
+ Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
+ countWords(analyzer, words, dateString, overallCounts);
+ // Header section: runs until the first empty line or end of file.
+ while (line != null && !line.isEmpty()) {
+ boolean countHeader = (
+ line.startsWith("From:") || line.startsWith("Subject:")
+ || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6;
+ // Consume the header line and any following lines that start with a space, treating them alike.
+ do {
+ Reader in = new StringReader(line);
+ if (countHeader) {
+ countWords(analyzer, words, in, overallCounts);
+ }
+ line = reader.readLine();
+ } while (line != null && line.startsWith(" "));
+ }
+ // Whatever is left in the reader after the headers is counted only for leakType < 3.
+ if (leakType < 3) {
+ countWords(analyzer, words, reader, overallCounts);
+ }
+ }
+ Vector v = new RandomAccessSparseVector(FEATURES);
+ bias.addToVector("", 1, v);
+ for (String word : words.elementSet()) {
+ encoder.addToVector(word, Math.log1p(words.count(word)), v);
+ }
+ return v;
+ }
+ /**
+  * Tokenizes {@code in} with {@code analyzer} and appends every token to {@code words}.
+  *
+  * <p>After tokenizing, the entire {@code words} collection (not only the tokens from this call) is
+  * added to {@code overallCounts}; this existing behaviour is preserved. The token stream is always
+  * closed, even when tokenizing fails, and an {@link IOException} raised by the close itself is
+  * swallowed by {@code Closeables.close(ts, true)}.
+  *
+  * @param analyzer analyzer used to create the token stream (field name "text")
+  * @param words collection receiving each token
+  * @param in text to tokenize
+  * @param overallCounts multiset that receives all elements of {@code words}
+  * @throws IOException if the token stream fails while being reset, advanced or ended
+  */
+ public static void countWords(Analyzer analyzer,
+                               Collection<String> words,
+                               Reader in,
+                               Multiset<String> overallCounts) throws IOException {
+   TokenStream ts = analyzer.tokenStream("text", in);
+   try {
+     // addAttribute returns the attribute instance, so it is looked up once rather than per token.
+     CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
+     ts.reset();
+     while (ts.incrementToken()) {
+       words.add(term.toString());
+     }
+     overallCounts.addAll(words);
+     ts.end();
+   } finally {
+     // Release the stream even if tokenizing threw; previously it leaked on failure.
+     Closeables.close(ts, true);
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
new file mode 100644
index 0000000..16e9d80
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
@@ -0,0 +1,65 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.email;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.regex.Pattern;
+/**
+ * Converts the labels created by the {@link org.apache.mahout.utils.email.MailProcessor} to ones consumable
+ * by the classifiers.
+ */
+public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+ private static final Pattern DASH_DOT = Pattern.compile("-|\\.");
+ private static final Pattern SLASH = Pattern.compile("\\/");
+ private boolean useListName = false; //if true, use the project name and the list name in label creation
+ /** Reads the {@code USE_LIST_NAME} job setting; anything other than "true" (ignoring case) disables it. */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   String configured = context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME);
+   useListName = Boolean.parseBoolean(configured);
+ }
+ /**
+  * Re-keys a vector by a label derived from its slash-separated key: the second path segment,
+  * optionally followed by '_' and the third segment when {@code useListName} is set. Keys with
+  * fewer than three segments are dropped.
+  */
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+     throws IOException, InterruptedException {
+   // Example key: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$***@WRPO
+   String[] segments = SLASH.split(key.toString());
+   if (segments.length < 3) {
+     return;
+   }
+   String label = escape(segments[1]);
+   if (useListName) {
+     label = label + '_' + escape(segments[2]);
+   }
+   context.write(new Text(label), value);
+ }
+ /** Replaces every '-' and '.' with '_' and lower-cases the result using the English locale. */
+ private static String escape(CharSequence value) {
+   String underscored = DASH_DOT.matcher(value).replaceAll("_");
+   return underscored.toLowerCase(Locale.ENGLISH);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
new file mode 100644
index 0000000..da6e613
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
@@ -0,0 +1,47 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.email;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+import java.io.IOException;
+import java.util.Iterator;
+public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> {
+ private long maxItemsPerLabel = 10000;
+ /**
+  * Reads the per-label item limit from the job configuration key
+  * {@code PrepEmailVectorsDriver.ITEMS_PER_CLASS}. When the key is not set, the field's existing
+  * default is kept instead of failing with a {@link NumberFormatException}.
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   maxItemsPerLabel =
+       context.getConfiguration().getLong(PrepEmailVectorsDriver.ITEMS_PER_CLASS, maxItemsPerLabel);
+ }
+ /**
+  * Emits at most {@code maxItemsPerLabel} vectors for the given label, in the order they arrive;
+  * any further values for that key are not read.
+  */
+ @Override
+ protected void reduce(Text key, Iterable<VectorWritable> values, Context context)
+     throws IOException, InterruptedException {
+   //TODO: support randomization? Likely not needed due to the SplitInput utility which does random selection
+   Iterator<VectorWritable> remaining = values.iterator();
+   for (long written = 0; written < maxItemsPerLabel && remaining.hasNext(); written++) {
+     context.write(key, remaining.next());
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
new file mode 100644
index 0000000..8fba739
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
@@ -0,0 +1,76 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.email;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.VectorWritable;
+import java.util.List;
+import java.util.Map;
+/**
+ * Convert the labels generated by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and
+ * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} to ones consumable by the classifiers. We do
+ * this here because, if it is done during the creation of the sparse vectors, the Reducer collapses all the vectors.
+ */
+public class PrepEmailVectorsDriver extends AbstractJob {
+ public static final String ITEMS_PER_CLASS = "itemsPerClass";
+ public static final String USE_LIST_NAME = "USE_LIST_NAME";
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new PrepEmailVectorsDriver(), args);
+ }
+ /**
+  * Registers the command-line options, then runs the label-conversion job
+  * ({@link PrepEmailMapper} / {@link PrepEmailReducer}) from the input path to the output path.
+  *
+  * @param args command-line arguments
+  * @return 0 when the job succeeds, -1 when argument parsing yields nothing or the job fails
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   addOption("maxItemsPerLabel", "mipl", "The maximum number of items per label. Can be useful for making the "
+       + "training sets the same size", String.valueOf(100000));
+   addOption(buildOption("useListName", "ul", "Use the name of the list as part of the label. If not set, then "
+       + "just use the project name", false, false, "false"));
+   if (parseArguments(args) == null) {
+     return -1;
+   }
+
+   Path source = getInputPath();
+   Path target = getOutputPath();
+   if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+     HadoopUtil.delete(getConf(), target);
+   }
+
+   Job job = prepareJob(source, target, SequenceFileInputFormat.class, PrepEmailMapper.class, Text.class,
+       VectorWritable.class, PrepEmailReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+   // Hand the two settings to the tasks through the job configuration.
+   Configuration jobConf = job.getConfiguration();
+   jobConf.set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel"));
+   jobConf.set(USE_LIST_NAME, String.valueOf(hasOption("useListName")));
+
+   if (job.waitForCompletion(true)) {
+     return 0;
+   }
+   return -1;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
new file mode 100644
index 0000000..9c0ef56
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
@@ -0,0 +1,277 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.sequencelearning.hmm;
+import com.google.common.io.Resources;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.math.Matrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.IOException;
+import java.net.URL;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+/**
+ * This class implements a sample program that uses a pre-tagged training data
+ * set to train an HMM model as a POS tagger. The training data is automatically
+ * downloaded from the following URL:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt
+ * It then trains an HMM model using supervised learning and tests the model on
+ * the following test data set:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt
+ * Further details regarding the data files can be found at
+ * http://flexcrfs.sourceforge.net/#Case_Study
+ */
+public final class PosTagger {
+ private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
+ private static final Pattern SPACE = Pattern.compile(" ");
+ private static final Pattern SPACES = Pattern.compile("[ ]+");
+ /**
+  * Private constructor: this class only exposes static members and is never instantiated.
+  */
+ private PosTagger() {
+ }
+ /**
+ * Model trained in the example.
+ */
+ private static HmmModel taggingModel;
+ /**
+ * Map for storing the IDs for the POS tags (hidden states)
+ */
+ private static Map<String, Integer> tagIDs;
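+ /**
+ * Counter for the next assigned POS tag ID. The value of 0 is used for
+ * "unknown POS tag"; note that this counter is not initialized to 1, so the
+ * first tag read from the training data is also assigned ID 0.
+ */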
+ private static int nextTagId;
+ /**
+ * Map for storing the IDs for observed words (observed states)
+ */
+ private static Map<String, Integer> wordIDs;
+ /**
+ * Counter for the next assigned word ID The value of 0 is reserved for
+ * "unknown word"
+ */
+ private static int nextWordId = 1; // 0 is reserved for "unknown word"
+ /**
+ * Used for storing a list of POS tags of read sentences.
+ */
+ private static List<int[]> hiddenSequences;
+ /**
+ * Used for storing a list of word tags of read sentences.
+ */
+ private static List<int[]> observedSequences;
+ /**
+ * number of read lines
+ */
+ private static int readLines;
+ /**
+  * Given an URL, this function fetches the data file, parses it, assigns POS
+  * tag/word IDs and fills the hiddenSequences/observedSequences lists with
+  * data from that file. The data is expected to hold one word per line, with
+  * the fields separated by single spaces: word pos-tag np-tag. An empty line
+  * ends the current sentence.
+  *
+  * @param url       Where the data file is stored
+  * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
+  *                  training data, not needed for test data)
+  * @throws IOException in case data file cannot be read.
+  */
+ private static void readFromURL(String url, boolean assignIDs) throws IOException {
+   // initialize the data structure
+   hiddenSequences = new LinkedList<>();
+   observedSequences = new LinkedList<>();
+   readLines = 0;
+
+   // IDs of the sentence that is currently being read
+   List<Integer> observedSequence = new LinkedList<>();
+   List<Integer> hiddenSequence = new LinkedList<>();
+
+   for (String line : Resources.readLines(new URL(url), Charsets.UTF_8)) {
+     if (line.isEmpty()) {
+       // an empty line closes the current sentence
+       registerSentence(observedSequence, hiddenSequence);
+       continue;
+     }
+     readLines++;
+     // we expect the format [word] [POS tag] [NP tag]
+     String[] tags = SPACE.split(line);
+     // when analyzing the training set, assign IDs
+     if (assignIDs) {
+       if (!wordIDs.containsKey(tags[0])) {
+         wordIDs.put(tags[0], nextWordId++);
+       }
+       if (!tagIDs.containsKey(tags[1])) {
+         tagIDs.put(tags[1], nextTagId++);
+       }
+     }
+     // determine the IDs; words/tags without an ID are mapped to 0
+     Integer wordID = wordIDs.get(tags[0]);
+     Integer tagID = tagIDs.get(tags[1]);
+     observedSequence.add(wordID == null ? 0 : wordID);
+     hiddenSequence.add(tagID == null ? 0 : tagID);
+   }
+   // if there is still something in the pipe, register it
+   registerSentence(observedSequence, hiddenSequence);
+ }
+
+ /**
+  * Appends the buffered sentence to hiddenSequences/observedSequences and empties both buffers.
+  * Does nothing when the buffers are empty, so consecutive blank lines do not register
+  * zero-length sentences.
+  */
+ private static void registerSentence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
+   if (observedSequence.isEmpty()) {
+     return;
+   }
+   hiddenSequences.add(toIntArray(hiddenSequence));
+   observedSequences.add(toIntArray(observedSequence));
+   observedSequence.clear();
+   hiddenSequence.clear();
+ }
+
+ /**
+  * Copies the IDs into a primitive array by iterating once; indexed get(i) on a linked list
+  * would walk the list for every element.
+  */
+ private static int[] toIntArray(List<Integer> ids) {
+   int[] result = new int[ids.size()];
+   int position = 0;
+   for (int id : ids) {
+     result[position++] = id;
+   }
+   return result;
+ }
+ /**
+  * Reads the training data, trains the HMM with supervised learning and stores the result in
+  * taggingModel. After training, the emission probabilities of the word with ID 0 are
+  * overwritten so that it is most likely emitted by the NNP tag, and the model is re-normalized.
+  *
+  * @param trainingURL location of the pre-tagged training file
+  * @throws IOException in case the training file cannot be read
+  */
+ private static void trainModel(String trainingURL) throws IOException {
+   tagIDs = new HashMap<>(44); // we expect 44 distinct tags
+   wordIDs = new HashMap<>(19122); // we expect 19122
+   // distinct words
+   log.info("Reading and parsing training data file from URL: {}", trainingURL);
+   long start = System.currentTimeMillis();
+   readFromURL(trainingURL, true);
+   long end = System.currentTimeMillis();
+   double duration = (end - start) / 1000.0;
+   log.info("Parsing done in {} seconds!", duration);
+   // Report the map sizes: tag IDs are handed out starting at 0, so "nextTagId - 1"
+   // under-reported the number of distinct tags by one.
+   log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
+       readLines, hiddenSequences.size(), wordIDs.size(), tagIDs.size());
+   start = System.currentTimeMillis();
+   taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
+       hiddenSequences, observedSequences, 0.05);
+   // we have to adjust the model a bit,
+   // since we assume a higher probability that a given unknown word is NNP
+   // than anything else
+   Matrix emissions = taggingModel.getEmissionMatrix();
+   int hiddenStates = taggingModel.getNrOfHiddenStates();
+   for (int i = 0; i < hiddenStates; ++i) {
+     emissions.setQuick(i, 0, 0.1 / hiddenStates);
+   }
+   // The lookup returns null when the training data holds no NNP tag; unboxing it
+   // directly into an int would throw a NullPointerException.
+   Integer nnpTag = tagIDs.get("NNP");
+   if (nnpTag != null) {
+     emissions.setQuick(nnpTag, 0, 1 / (double) hiddenStates);
+   }
+   // re-normalize the emission probabilities
+   HmmUtils.normalizeModel(taggingModel);
+   // now register the names
+   taggingModel.registerHiddenStateNames(tagIDs);
+   taggingModel.registerOutputStateNames(wordIDs);
+   end = System.currentTimeMillis();
+   duration = (end - start) / 1000.0;
+   log.info("Trained HMM models in {} seconds!", duration);
+ }
+ /**
+  * Reads the test data without assigning new IDs, decodes every observed sentence with the
+  * trained model and logs the share of positions whose decoded tag differs from the expected one.
+  *
+  * @param testingURL location of the pre-tagged test file
+  * @throws IOException in case the test file cannot be read
+  */
+ private static void testModel(String testingURL) throws IOException {
+   log.info("Reading and parsing test data file from URL: {}", testingURL);
+   long parseStart = System.currentTimeMillis();
+   readFromURL(testingURL, false);
+   long parseEnd = System.currentTimeMillis();
+   log.info("Parsing done in {} seconds!", (parseEnd - parseStart) / 1000.0);
+   log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
+
+   long tagStart = System.currentTimeMillis();
+   int mismatches = 0;
+   int taggedWords = 0;
+   for (int sentence = 0; sentence < observedSequences.size(); ++sentence) {
+     // the viterbi path is the POS tag estimate for this observed sequence
+     int[] predicted = HmmEvaluator.decode(taggingModel, observedSequences.get(sentence), false);
+     int[] expected = hiddenSequences.get(sentence);
+     taggedWords += expected.length;
+     for (int position = 0; position < expected.length; ++position) {
+       if (predicted[position] != expected[position]) {
+         mismatches++;
+       }
+     }
+   }
+   long tagEnd = System.currentTimeMillis();
+   log.info("POS tagged test file in {} seconds!", (tagEnd - tagStart) / 1000.0);
+   double errorRate = (double) mismatches / taggedWords;
+   log.info("Tagged the test file with an error rate of: {}", errorRate);
+ }
+ /**
+  * Tags a free-text sentence with the trained model.
+  *
+  * @param sentence the sentence to tag
+  * @return the decoded tag names, one per token
+  */
+ private static List<String> tagSentence(String sentence) {
+   // surround punctuation characters (and '') with spaces so that they become tokens of their own
+   String spaced = sentence.replaceAll("[,.!?:;\"]", " $0 ").replaceAll("''", " '' ");
+   List<String> tokens = Arrays.asList(SPACES.split(spaced));
+   // words -> observed state IDs
+   int[] observed = HmmUtils.encodeStateSequence(taggingModel, tokens, true, 0);
+   // observed state IDs -> hidden (POS tag) state IDs
+   int[] hidden = HmmEvaluator.decode(taggingModel, observed, false);
+   // hidden state IDs -> tag names
+   return HmmUtils.decodeStateSequence(taggingModel, hidden, false, null);
+ }
+ /**
+  * Trains the tagger, evaluates it on the test set and logs the tags of one example sentence.
+  *
+  * @param args unused
+  * @throws IOException in case one of the data files cannot be read
+  */
+ public static void main(String[] args) throws IOException {
+   // generate the model from URL
+   trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
+   testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
+   // tag an exemplary sentence
+   String test = "McDonalds is a huge company with many employees .";
+   List<String> posTags = tagSentence(test);
+   String[] testWords = SPACE.split(test);
+   int position = 0;
+   for (String posTag : posTags) {
+     log.info("{}[{}]", testWords[position], posTag);
+     position++;
+   }
+ }
2018-06-27 13:14:46 UTC
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java
new file mode 100644
index 0000000..752bb48
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToPrefsDriver.java
@@ -0,0 +1,274 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VarIntWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.IOException;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+/**
+ * Convert the Mail archives (see {@link org.apache.mahout.text.SequenceFilesFromMailArchives}) to a preference
+ * file that can be consumed by the {@link org.apache.mahout.cf.taste.hadoop.item.RecommenderJob}.
+ * <p/>
+ * This assumes the input is a Sequence File, that the key is: filename/message id and the value is a list
+ * (separated by a separator of the user's choosing) containing the from email and any references.
+ * <p/>
+ * The output is a matrix where either the from or to are the rows (represented as longs) and the columns are the
+ * message ids that the user has interacted with (as a VectorWritable). This class currently does not account for
+ * thread hijacking.
+ * <p/>
+ * It also outputs a side table mapping the row ids to their original and the message ids to the message thread id.
+ */
+public final class MailToPrefsDriver extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(MailToPrefsDriver.class);
+ private static final String OUTPUT_FILES_PATTERN = "part-*";
+ private static final int DICTIONARY_BYTE_OVERHEAD = 4;
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new MailToPrefsDriver(), args);
+ }
+ /**
+  * Runs up to three phases, each gated by {@code shouldRunNextPhase}:
+  * <ol>
+  * <li>build the message-id dictionary and write it out in chunks,</li>
+  * <li>build the from-address dictionary and write it out in chunks,</li>
+  * <li>for every (from chunk, message-id chunk) pair, run the job that emits the preference lines and
+  * move its part files into {@code recInput} below the output path.</li>
+  * </ol>
+  * The third phase only runs when both dictionaries were produced in this invocation.
+  *
+  * @param args command-line arguments
+  * @return 0 on success, -1 when the arguments could not be parsed or one of the jobs failed
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   addOption("chunkSize", "cs", "The size of chunks to write. Default is 100 mb", "100");
+   addOption("separator", "sep", "The separator used in the input file to separate to, from, subject. Default is \\n",
+       "\n");
+   addOption("from", "f", "The position in the input text (value) where the from email is located, starting from "
+       + "zero (0).", "0");
+   addOption("refs", "r", "The position in the input text (value) where the reference ids are located, "
+       + "starting from zero (0).", "1");
+   addOption(buildOption("useCounts", "u", "If set, then use the number of times the user has interacted with a "
+       + "thread as an indication of their preference. Otherwise, use boolean preferences.", false, false,
+       String.valueOf(true)));
+   Map<String, List<String>> parsedArgs = parseArguments(args);
+   // Stop here when parsing produced nothing, as PrepEmailVectorsDriver.run does; the
+   // option lookups below have nothing to work with in that case.
+   if (parsedArgs == null) {
+     return -1;
+   }
+   Path input = getInputPath();
+   Path output = getOutputPath();
+   int chunkSize = Integer.parseInt(getOption("chunkSize"));
+   String separator = getOption("separator");
+   Configuration conf = getConf();
+   boolean useCounts = hasOption("useCounts");
+   AtomicInteger currentPhase = new AtomicInteger();
+   int[] msgDim = new int[1];
+   //TODO: mod this to not do so many passes over the data. Dictionary creation could probably be a chain mapper
+   List<Path> msgIdChunks = null;
+   boolean overwrite = hasOption(DefaultOptionCreator.OVERWRITE_OPTION);
+   // create the dictionary between message ids and longs
+   if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+     //TODO: there seems to be a pattern emerging for dictionary creation
+     // -- sparse vectors from seq files also has this.
+     Path msgIdsPath = new Path(output, "msgIds");
+     if (overwrite) {
+       HadoopUtil.delete(conf, msgIdsPath);
+     }
+     log.info("Creating Msg Id Dictionary");
+     Job createMsgIdDictionary = prepareJob(input,
+         msgIdsPath,
+         SequenceFileInputFormat.class,
+         MsgIdToDictionaryMapper.class,
+         Text.class,
+         VarIntWritable.class,
+         MailToDictionaryReducer.class,
+         Text.class,
+         VarIntWritable.class,
+         SequenceFileOutputFormat.class);
+     boolean succeeded = createMsgIdDictionary.waitForCompletion(true);
+     if (!succeeded) {
+       return -1;
+     }
+     //write out the dictionary at the top level
+     msgIdChunks = createDictionaryChunks(msgIdsPath, output, "msgIds-dictionary-",
+         createMsgIdDictionary.getConfiguration(), chunkSize, msgDim);
+   }
+   //create the dictionary between from email addresses and longs
+   List<Path> fromChunks = null;
+   if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+     Path fromIdsPath = new Path(output, "fromIds");
+     if (overwrite) {
+       HadoopUtil.delete(conf, fromIdsPath);
+     }
+     log.info("Creating From Id Dictionary");
+     Job createFromIdDictionary = prepareJob(input,
+         fromIdsPath,
+         SequenceFileInputFormat.class,
+         FromEmailToDictionaryMapper.class,
+         Text.class,
+         VarIntWritable.class,
+         MailToDictionaryReducer.class,
+         Text.class,
+         VarIntWritable.class,
+         SequenceFileOutputFormat.class);
+     createFromIdDictionary.getConfiguration().set(EmailUtility.SEPARATOR, separator);
+     boolean succeeded = createFromIdDictionary.waitForCompletion(true);
+     if (!succeeded) {
+       return -1;
+     }
+     //write out the dictionary at the top level
+     int[] fromDim = new int[1];
+     fromChunks = createDictionaryChunks(fromIdsPath, output, "fromIds-dictionary-",
+         createFromIdDictionary.getConfiguration(), chunkSize, fromDim);
+   }
+   //OK, we have our dictionaries, let's output the real thing we need: <from_id -> <msgId, msgId, msgId, ...>>
+   if (shouldRunNextPhase(parsedArgs, currentPhase) && fromChunks != null && msgIdChunks != null) {
+     //Job map
+     //may be a way to do this so that we can load the from ids in memory, if they are small enough so that
+     // we don't need the double loop
+     log.info("Creating recommendation matrix");
+     Path vecPath = new Path(output, "recInput");
+     if (overwrite) {
+       HadoopUtil.delete(conf, vecPath);
+     }
+     //conf.set(EmailUtility.FROM_DIMENSION, String.valueOf(fromDim[0]));
+     conf.set(EmailUtility.MSG_ID_DIMENSION, String.valueOf(msgDim[0]));
+     conf.set(EmailUtility.FROM_PREFIX, "fromIds-dictionary-");
+     conf.set(EmailUtility.MSG_IDS_PREFIX, "msgIds-dictionary-");
+     conf.set(EmailUtility.FROM_INDEX, getOption("from"));
+     conf.set(EmailUtility.REFS_INDEX, getOption("refs"));
+     conf.set(EmailUtility.SEPARATOR, separator);
+     conf.set(MailToRecReducer.USE_COUNTS_PREFERENCE, String.valueOf(useCounts));
+     // NOTE: j is not reset per from-chunk, so it counts (from chunk, id chunk) pairs over the
+     // whole run; the generated names are unique either way.
+     int j = 0;
+     int i = 0;
+     for (Path fromChunk : fromChunks) {
+       for (Path idChunk : msgIdChunks) {
+         Path out = new Path(vecPath, "tmp-" + i + '-' + j);
+         DistributedCache.setCacheFiles(new URI[]{fromChunk.toUri(), idChunk.toUri()}, conf);
+         Job createRecMatrix = prepareJob(input, out, SequenceFileInputFormat.class,
+             MailToRecMapper.class, Text.class, LongWritable.class, MailToRecReducer.class, Text.class,
+             NullWritable.class, TextOutputFormat.class);
+         createRecMatrix.getConfiguration().set("mapred.output.compress", "false");
+         boolean succeeded = createRecMatrix.waitForCompletion(true);
+         if (!succeeded) {
+           return -1;
+         }
+         //copy the results up a level
+         //HadoopUtil.copyMergeSeqFiles(out.getFileSystem(conf), out, vecPath.getFileSystem(conf), outPath, true,
+         // conf, "");
+         FileStatus[] fs = HadoopUtil.getFileStatus(new Path(out, "*"), PathType.GLOB, PathFilters.partFilter(), null,
+             conf);
+         for (int k = 0; k < fs.length; k++) {
+           FileStatus f = fs[k];
+           Path outPath = new Path(vecPath, "chunk-" + i + '-' + j + '-' + k);
+           FileUtil.copy(f.getPath().getFileSystem(conf), f.getPath(), outPath.getFileSystem(conf), outPath, true,
+               overwrite, conf);
+         }
+         HadoopUtil.delete(conf, out);
+         j++;
+       }
+       i++;
+     }
+     //concat the files together
+     /*Path mergePath = new Path(output, "vectors.dat");
+     if (overwrite) {
+       HadoopUtil.delete(conf, mergePath);
+     }
+     log.info("Merging together output vectors to vectors.dat in {}", output);*/
+     //HadoopUtil.copyMergeSeqFiles(vecPath.getFileSystem(conf), vecPath, mergePath.getFileSystem(conf), mergePath,
+     // false, conf, "\n");
+   }
+   return 0;
+ }
+ /**
+ * Reads every record of the "part-*" sequence files below {@code inputPath}, gives each key a
+ * sequential int ID starting at 1 and writes the (key, ID) pairs into one or more sequence files
+ * named {@code name + chunkIndex} below {@code dictionaryPathBase}. A new chunk file is started
+ * once the estimated size of the current one exceeds the limit.
+ *
+ * @param inputPath directory holding the "part-*" files to read
+ * @param dictionaryPathBase directory the chunk files are written to
+ * @param name file-name prefix of the chunk files
+ * @param baseConf configuration that is copied for reading and writing
+ * @param chunkSizeInMegabytes size limit of one chunk, in megabytes
+ * @param maxTermDimension output parameter; element 0 receives the last assigned ID plus one
+ * @return the paths of all chunk files, in the order they were created
+ * @throws IOException propagated from the file system and sequence-file calls
+ */
+ private static List<Path> createDictionaryChunks(Path inputPath,
+ Path dictionaryPathBase,
+ String name,
+ Configuration baseConf,
+ int chunkSizeInMegabytes, int[] maxTermDimension)
+ throws IOException {
+ List<Path> chunkPaths = new ArrayList<>();
+ Configuration conf = new Configuration(baseConf);
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+ long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L;
+ int chunkIndex = 0;
+ // the first chunk file is created up front, so at least one path is always returned
+ Path chunkPath = new Path(dictionaryPathBase, name + chunkIndex);
+ chunkPaths.add(chunkPath);
+ SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class);
+ try {
+ long currentChunkSize = 0;
+ Path filesPattern = new Path(inputPath, OUTPUT_FILES_PATTERN);
+ int i = 1; //start at 1, since a miss in the OpenObjectIntHashMap returns a 0
+ for (Pair<Writable, Writable> record
+ : new SequenceFileDirIterable<>(filesPattern, PathType.GLOB, null, null, true, conf)) {
+ // roll over to a new chunk file once the running estimate exceeds the limit
+ if (currentChunkSize > chunkSizeLimit) {
+ Closeables.close(dictWriter, false);
+ chunkIndex++;
+ chunkPath = new Path(dictionaryPathBase, name + chunkIndex);
+ chunkPaths.add(chunkPath);
+ dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class);
+ currentChunkSize = 0;
+ }
+ Writable key = record.getFirst();
+ // size estimate per entry: fixed overhead + two bytes per key character + four bytes for the int ID
+ int fieldSize = DICTIONARY_BYTE_OVERHEAD + key.toString().length() * 2 + Integer.SIZE / 8;
+ currentChunkSize += fieldSize;
+ dictWriter.append(key, new IntWritable(i++));
+ }
+ maxTermDimension[0] = i;
+ } finally {
+ // closes whichever writer is current, also when an exception ended the loop
+ Closeables.close(dictWriter, false);
+ }
+ return chunkPaths;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java
new file mode 100644
index 0000000..91bbd17
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecMapper.java
@@ -0,0 +1,101 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.IOException;
+public final class MailToRecMapper extends Mapper<Text, Text, Text, LongWritable> {
+ private static final Logger log = LoggerFactory.getLogger(MailToRecMapper.class);
+ private final OpenObjectIntHashMap<String> fromDictionary = new OpenObjectIntHashMap<>();
+ private final OpenObjectIntHashMap<String> msgIdDictionary = new OpenObjectIntHashMap<>();
+ private String separator = "\n";
+ private int fromIdx;
+ private int refsIdx;
+ public enum Counters {
+ }
+ /**
+  * Loads the from-address and message-id dictionaries and reads the field positions and the
+  * separator from the job configuration.
+  *
+  * @param context mapper context providing the configuration
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   super.setup(context);
+   Configuration conf = context.getConfiguration();
+   String fromPrefix = conf.get(EmailUtility.FROM_PREFIX);
+   String msgPrefix = conf.get(EmailUtility.MSG_IDS_PREFIX);
+   fromIdx = conf.getInt(EmailUtility.FROM_INDEX, 0);
+   refsIdx = conf.getInt(EmailUtility.REFS_INDEX, 1);
+   EmailUtility.loadDictionaries(conf, fromPrefix, fromDictionary, msgPrefix, msgIdDictionary);
+   log.info("From Dictionary size: {} Msg Id Dictionary size: {}", fromDictionary.size(), msgIdDictionary.size());
+   // Keep the field's default ("\n") when the key is not set instead of replacing it with null.
+   separator = conf.get(EmailUtility.SEPARATOR, separator);
+ }
+ /**
+  * Emits one {@code "fromId,msgId" -> 1} record per input message. The from id is looked up from
+  * the field at {@code fromIdx}. The message id is looked up from the first reference in the field
+  * at {@code refsIdx}; when no reference is present, the part of the key after the last '/' is
+  * used instead. Nothing is emitted unless both lookups were performed.
+  *
+  * @param key     record key; the text after the last '/' is treated as the message id
+  * @param value   the message fields, joined by the configured separator
+  * @param context sink for the output record and the counters
+  */
+ @Override
+ protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
+   int msgIdKey = Integer.MIN_VALUE;
+   int fromKey = Integer.MIN_VALUE;
+   String valStr = value.toString();
+   String[] splits = StringUtils.splitByWholeSeparatorPreserveAllTokens(valStr, separator);
+   if (splits != null && splits.length > 0) {
+     // Guard the "from" access with fromIdx. The previous check against refsIdx dropped the
+     // sender of messages without a references field and could index past the array when
+     // fromIdx is larger than refsIdx.
+     if (splits.length > fromIdx) {
+       String from = EmailUtility.cleanUpEmailAddress(splits[fromIdx]);
+       fromKey = fromDictionary.get(from);
+     }
+     //get the references
+     if (splits.length > refsIdx) {
+       String[] theRefs = EmailUtility.parseReferences(splits[refsIdx]);
+       if (theRefs != null && theRefs.length > 0) {
+         //we have a reference, the first one is the original message id, so map to that one if it exists
+         msgIdKey = msgIdDictionary.get(theRefs[0]);
+         context.getCounter(Counters.REFERENCE).increment(1);
+       }
+     }
+   }
+   //we don't have any references, so use the msg id
+   if (msgIdKey == Integer.MIN_VALUE) {
+     //get the msg id and the from and output the associated ids
+     String keyStr = key.toString();
+     int idx = keyStr.lastIndexOf('/');
+     if (idx != -1) {
+       String msgId = keyStr.substring(idx + 1);
+       msgIdKey = msgIdDictionary.get(msgId);
+       context.getCounter(Counters.ORIGINAL).increment(1);
+     }
+   }
+   // MIN_VALUE still marks "no lookup performed" for either id
+   if (msgIdKey != Integer.MIN_VALUE && fromKey != Integer.MIN_VALUE) {
+     context.write(new Text(fromKey + "," + msgIdKey), new LongWritable(1));
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java
new file mode 100644
index 0000000..ee36a41
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToRecReducer.java
@@ -0,0 +1,53 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import java.io.IOException;
+public class MailToRecReducer extends Reducer<Text, LongWritable, Text, NullWritable> {
+ //if true, then output weight
+ private boolean useCounts = true;
+ /**
+ * We can either ignore how many times the user interacted (boolean) or output the number of times they interacted.
+ */
+ public static final String USE_COUNTS_PREFERENCE = "useBooleanPreferences";
+ /**
+  * Reads the {@link #USE_COUNTS_PREFERENCE} flag from the job configuration; it defaults to
+  * {@code true} when the key is absent.
+  *
+  * @param context reducer context providing the configuration
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   boolean countsRequested = context.getConfiguration().getBoolean(USE_COUNTS_PREFERENCE, true);
+   useCounts = countsRequested;
+ }
+ /**
+  * Writes one output line per key: the key followed by the summed interaction count when
+  * {@code useCounts} is set, otherwise the key alone.
+  *
+  * @param key     the key emitted by the mapper
+  * @param values  the partial counts emitted for that key
+  * @param context sink for the output line
+  */
+ @Override
+ protected void reduce(Text key, Iterable<LongWritable> values, Context context)
+     throws IOException, InterruptedException {
+   if (useCounts) {
+     long sum = 0;
+     for (LongWritable value : values) {
+       // Add the carried count instead of counting records, so the total stays right
+       // if values ever carry more than 1 (e.g. after pre-aggregation).
+       sum += value.get();
+     }
+     context.write(new Text(key.toString() + ',' + sum), NullWritable.get());
+   } else {
+     context.write(new Text(key.toString()), NullWritable.get());
+   }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java
new file mode 100644
index 0000000..f3de847
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MsgIdToDictionaryMapper.java
@@ -0,0 +1,49 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VarIntWritable;
+import java.io.IOException;
+/**
+ * Assumes the input is in the format created by {@link org.apache.mahout.text.SequenceFilesFromMailArchives}
+ */
+public final class MsgIdToDictionaryMapper extends Mapper<Text, Text, Text, VarIntWritable> {
+
+  /**
+   * Extracts the message ID from the record key and emits it with a count of 1.
+   * Keys without an {@code '@'}, or whose extracted ID matches {@code EmailUtility.WHITESPACE},
+   * only increment the {@code NO_MESSAGE_ID} counter and produce no output.
+   *
+   * @param key the record key that carries the message ID
+   * @param value the record value; not used here
+   * @param context the mapper context used for counters and output
+   */
+  @Override
+  protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
+    // message id is in the key: /201008/AANLkTikvVnhNH+Y5AGEwqd2=***@mail.gmail.com
+    String rawKey = key.toString();
+    int atPos = rawKey.lastIndexOf('@');
+    if (atPos == -1) {
+      context.getCounter(EmailUtility.Counters.NO_MESSAGE_ID).increment(1);
+      return;
+    }
+    // Take everything after the last slash that precedes the last '@'.
+    // When there is no such slash, lastIndexOf yields -1 and the whole key is used.
+    int slashPos = rawKey.lastIndexOf('/', atPos);
+    String msgId = rawKey.substring(slashPos + 1);
+    boolean blank = EmailUtility.WHITESPACE.matcher(msgId).matches();
+    if (blank) {
+      context.getCounter(EmailUtility.Counters.NO_MESSAGE_ID).increment(1);
+    } else {
+      context.write(new Text(msgId), new VarIntWritable(1));
+    }
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java
new file mode 100644
index 0000000..c358021
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterable.java
@@ -0,0 +1,44 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup;
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+public final class DataFileIterable implements Iterable<Pair<PreferenceArray,long[]>> {
+ private final File dataFile;
+ public DataFileIterable(File dataFile) {
+ this.dataFile = dataFile;
+ }
+ @Override
+ public Iterator<Pair<PreferenceArray, long[]>> iterator() {
+ try {
+ return new DataFileIterator(dataFile);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java
new file mode 100644
index 0000000..786e080
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/DataFileIterator.java
@@ -0,0 +1,158 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup;
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.util.regex.Pattern;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.FileLineIterator;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+ * <p>An {@link java.util.Iterator} which iterates over any of the KDD Cup's rating files. These include the files
+ * {train,test,validation}Idx{1,2}}.txt. See http://kddcup.yahoo.com/. Each element in the iteration corresponds
+ * to one user's ratings as a {@link PreferenceArray} and corresponding timestamps as a parallel {@code long}
+ * array.</p>
+ *
+ * <p>Timestamps in the data set are relative to some unknown point in time, for anonymity. They are assumed
+ * to be relative to the epoch, time 0, or January 1 1970, for purposes here.</p>
+ */
+public final class DataFileIterator
+ extends AbstractIterator<Pair<PreferenceArray,long[]>>
+ implements SkippingIterator<Pair<PreferenceArray,long[]>>, Closeable {
+ private static final Pattern COLON_PATTERN = Pattern.compile(":");
+ private static final Pattern PIPE_PATTERN = Pattern.compile("\\|");
+ private static final Pattern TAB_PATTERN = Pattern.compile("\t");
+ private final FileLineIterator lineIterator;
+ private static final Logger log = LoggerFactory.getLogger(DataFileIterator.class);
+ public DataFileIterator(File dataFile) throws IOException {
+ if (dataFile == null || dataFile.isDirectory() || !dataFile.exists()) {
+ throw new IllegalArgumentException("Bad data file: " + dataFile);
+ }
+ lineIterator = new FileLineIterator(dataFile);
+ }
+ @Override
+ protected Pair<PreferenceArray, long[]> computeNext() {
+ if (!lineIterator.hasNext()) {
+ return endOfData();
+ }
+ String line = lineIterator.next();
+ // First a userID|ratingsCount line
+ String[] tokens = PIPE_PATTERN.split(line);
+ long userID = Long.parseLong(tokens[0]);
+ int ratingsLeftToRead = Integer.parseInt(tokens[1]);
+ int ratingsRead = 0;
+ PreferenceArray currentUserPrefs = new GenericUserPreferenceArray(ratingsLeftToRead);
+ long[] timestamps = new long[ratingsLeftToRead];
+ while (ratingsLeftToRead > 0) {
+ line = lineIterator.next();
+ // Then a data line. May be 1-4 tokens depending on whether preference info is included (it's not in test data)
+ // or whether date info is included (not inluded in track 2). Item ID is always first, and date is the last
+ // two fields if it exists.
+ tokens = TAB_PATTERN.split(line);
+ boolean hasPref = tokens.length == 2 || tokens.length == 4;
+ boolean hasDate = tokens.length > 2;
+ long itemID = Long.parseLong(tokens[0]);
+ currentUserPrefs.setUserID(0, userID);
+ currentUserPrefs.setItemID(ratingsRead, itemID);
+ if (hasPref) {
+ float preference = Float.parseFloat(tokens[1]);
+ currentUserPrefs.setValue(ratingsRead, preference);
+ }
+ if (hasDate) {
+ long timestamp;
+ if (hasPref) {
+ timestamp = parseFakeTimestamp(tokens[2], tokens[3]);
+ } else {
+ timestamp = parseFakeTimestamp(tokens[1], tokens[2]);
+ }
+ timestamps[ratingsRead] = timestamp;
+ }
+ ratingsRead++;
+ ratingsLeftToRead--;
+ }
+ return new Pair<>(currentUserPrefs, timestamps);
+ }
+ @Override
+ public void skip(int n) {
+ for (int i = 0; i < n; i++) {
+ if (lineIterator.hasNext()) {
+ String line = lineIterator.next();
+ // First a userID|ratingsCount line
+ String[] tokens = PIPE_PATTERN.split(line);
+ int linesToSKip = Integer.parseInt(tokens[1]);
+ lineIterator.skip(linesToSKip);
+ } else {
+ break;
+ }
+ }
+ }
+ @Override
+ public void close() {
+ endOfData();
+ try {
+ Closeables.close(lineIterator, true);
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ }
+ /**
+ * @param dateString "date" in days since some undisclosed date, which we will arbitrarily assume to be the
+ * epoch, January 1 1970.
+ * @param timeString time of day in HH:mm:ss format
+ * @return the UNIX timestamp for this moment in time
+ */
+ private static long parseFakeTimestamp(String dateString, CharSequence timeString) {
+ int days = Integer.parseInt(dateString);
+ String[] timeTokens = COLON_PATTERN.split(timeString);
+ int hours = Integer.parseInt(timeTokens[0]);
+ int minutes = Integer.parseInt(timeTokens[1]);
+ int seconds = Integer.parseInt(timeTokens[2]);
+ return 86400L * days + 3600L + hours + 60L * minutes + seconds;
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java
new file mode 100644
index 0000000..4b62050
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/KDDCupDataModel.java
@@ -0,0 +1,231 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup;
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.SamplingIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * <p>An {@link DataModel} which reads into memory any of the KDD Cup's rating files; it is really
+ * meant for use with training data in the files trainIdx{1,2}.txt.
+ * See http://kddcup.yahoo.com/.</p>
+ *
+ * <p>Timestamps in the data set are relative to some unknown point in time, for anonymity. They are assumed
+ * to be relative to the epoch, time 0, or January 1 1970, for purposes here.</p>
+ */
+public final class KDDCupDataModel implements DataModel {
+
+  private static final Logger log = LoggerFactory.getLogger(KDDCupDataModel.class);
+
+  private final File dataFileDirectory;
+  private final DataModel delegate;
+
+  /**
+   * Loads every user from the file without storing dates.
+   *
+   * @param dataFile training rating file
+   */
+  public KDDCupDataModel(File dataFile) throws IOException {
+    this(dataFile, false, 1.0);
+  }
+
+  /**
+   * @param dataFile training rating file
+   * @param storeDates if true, dates are parsed and stored, otherwise not
+   * @param samplingRate percentage of users to keep; can be used to reduce memory requirements
+   * @throws IllegalArgumentException if {@code samplingRate} is NaN or outside (0.0, 1.0]
+   * @throws IOException if the data file cannot be opened
+   */
+  public KDDCupDataModel(File dataFile, boolean storeDates, double samplingRate) throws IOException {
+
+    Preconditions.checkArgument(!Double.isNaN(samplingRate) && samplingRate > 0.0 && samplingRate <= 1.0,
+        "Must be: 0.0 < samplingRate <= 1.0");
+
+    dataFileDirectory = dataFile.getParentFile();
+
+    FastByIDMap<PreferenceArray> userData = new FastByIDMap<>();
+    FastByIDMap<FastByIDMap<Long>> timestamps = new FastByIDMap<>();
+
+    // The file iterator is Closeable; close it once loading finishes or fails.
+    try (DataFileIterator fileIterator = new DataFileIterator(dataFile)) {
+      Iterator<Pair<PreferenceArray,long[]>> dataIterator = fileIterator;
+      if (samplingRate < 1.0) {
+        dataIterator = new SamplingIterator<>(dataIterator, samplingRate);
+      }
+
+      while (dataIterator.hasNext()) {
+        Pair<PreferenceArray,long[]> pair = dataIterator.next();
+        PreferenceArray userPrefs = pair.getFirst();
+        long[] timestampsForPrefs = pair.getSecond();
+        long userID = userPrefs.getUserID(0);
+
+        userData.put(userID, userPrefs);
+        if (storeDates) {
+          FastByIDMap<Long> itemTimestamps = new FastByIDMap<>();
+          for (int i = 0; i < timestampsForPrefs.length; i++) {
+            long timestamp = timestampsForPrefs[i];
+            // Only strictly positive timestamps are stored.
+            if (timestamp > 0L) {
+              itemTimestamps.put(userPrefs.getItemID(i), timestamp);
+            }
+          }
+          // Register this user's timestamps; without this the map handed to the delegate stays empty.
+          timestamps.put(userID, itemTimestamps);
+        }
+      }
+    }
+
+    if (storeDates) {
+      delegate = new GenericDataModel(userData, timestamps);
+    } else {
+      delegate = new GenericDataModel(userData);
+    }
+
+    Runtime runtime = Runtime.getRuntime();
+    log.info("Loaded data model in about {}MB heap", (runtime.totalMemory() - runtime.freeMemory()) / 1000000);
+  }
+
+  /** @return the parent directory of the data file this model was loaded from */
+  public File getDataFileDirectory() {
+    return dataFileDirectory;
+  }
+
+  /** @return the first existing file in the directory whose name starts with {@code trainIdx} */
+  public static File getTrainingFile(File dataFileDirectory) {
+    return getFile(dataFileDirectory, "trainIdx");
+  }
+
+  /** @return the first existing file in the directory whose name starts with {@code validationIdx} */
+  public static File getValidationFile(File dataFileDirectory) {
+    return getFile(dataFileDirectory, "validationIdx");
+  }
+
+  /** @return the first existing file in the directory whose name starts with {@code testIdx} */
+  public static File getTestFile(File dataFileDirectory) {
+    return getFile(dataFileDirectory, "testIdx");
+  }
+
+  /** @return the first existing file in the directory whose name starts with {@code trackData} */
+  public static File getTrackFile(File dataFileDirectory) {
+    return getFile(dataFileDirectory, "trackData");
+  }
+
+  /**
+   * Probes {@code prefix + set + [.firstLines] + ".txt" + [.gz]} for set 1 then 2, and returns the first
+   * candidate that exists. For each name the gzipped variant is tried before the plain one.
+   *
+   * @throws IllegalArgumentException if none of the candidate files exists
+   */
+  private static File getFile(File dataFileDirectory, String prefix) {
+    // Works on set 1 or 2
+    for (int set : new int[] {1,2}) {
+      // Works on sample data from before contest or real data
+      for (String firstLinesOrNot : new String[] {"", ".firstLines"}) {
+        for (String gzippedOrNot : new String[] {".gz", ""}) {
+          File dataFile = new File(dataFileDirectory, prefix + set + firstLinesOrNot + ".txt" + gzippedOrNot);
+          if (dataFile.exists()) {
+            return dataFile;
+          }
+        }
+      }
+    }
+    throw new IllegalArgumentException("Can't find " + prefix + " file in " + dataFileDirectory);
+  }
+
+  // The DataModel methods below forward to the in-memory delegate unless noted otherwise.
+
+  @Override
+  public LongPrimitiveIterator getUserIDs() throws TasteException {
+    return delegate.getUserIDs();
+  }
+
+  @Override
+  public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
+    return delegate.getPreferencesFromUser(userID);
+  }
+
+  @Override
+  public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+    return delegate.getItemIDsFromUser(userID);
+  }
+
+  @Override
+  public LongPrimitiveIterator getItemIDs() throws TasteException {
+    return delegate.getItemIDs();
+  }
+
+  @Override
+  public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
+    return delegate.getPreferencesForItem(itemID);
+  }
+
+  @Override
+  public Float getPreferenceValue(long userID, long itemID) throws TasteException {
+    return delegate.getPreferenceValue(userID, itemID);
+  }
+
+  @Override
+  public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+    return delegate.getPreferenceTime(userID, itemID);
+  }
+
+  @Override
+  public int getNumItems() throws TasteException {
+    return delegate.getNumItems();
+  }
+
+  @Override
+  public int getNumUsers() throws TasteException {
+    return delegate.getNumUsers();
+  }
+
+  @Override
+  public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
+    return delegate.getNumUsersWithPreferenceFor(itemID);
+  }
+
+  @Override
+  public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
+    return delegate.getNumUsersWithPreferenceFor(itemID1, itemID2);
+  }
+
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    delegate.setPreference(userID, itemID, value);
+  }
+
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    delegate.removePreference(userID, itemID);
+  }
+
+  @Override
+  public boolean hasPreferenceValues() {
+    return delegate.hasPreferenceValues();
+  }
+
+  /** @return the fixed upper bound 100.0; not derived from the loaded data */
+  @Override
+  public float getMaxPreference() {
+    return 100.0f;
+  }
+
+  /** @return the fixed lower bound 0.0; not derived from the loaded data */
+  @Override
+  public float getMinPreference() {
+    return 0.0f;
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    // do nothing
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java
new file mode 100644
index 0000000..3f4a732
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/ToCSV.java
@@ -0,0 +1,77 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.zip.GZIPOutputStream;
+ * <p>This class converts a KDD Cup input file into a compressed CSV format. The output format is
+ * {@code userID,itemID,score,timestamp}. It can optionally restrict its output to exclude
+ * score and/or timestamp.</p>
+ *
+ * <p>Run as: {@code ToCSV (input file) (output file) [num columns to output]}</p>
+ */
+public final class ToCSV {
+ private ToCSV() {
+ }
+ public static void main(String[] args) throws Exception {
+ File inputFile = new File(args[0]);
+ File outputFile = new File(args[1]);
+ int columnsToOutput = 4;
+ if (args.length >= 3) {
+ columnsToOutput = Integer.parseInt(args[2]);
+ }
+ OutputStream outStream = new GZIPOutputStream(new FileOutputStream(outputFile));
+ try (Writer outWriter = new BufferedWriter(new OutputStreamWriter(outStream, Charsets.UTF_8))){
+ for (Pair<PreferenceArray,long[]> user : new DataFileIterable(inputFile)) {
+ PreferenceArray prefs = user.getFirst();
+ long[] timestamps = user.getSecond();
+ for (int i = 0; i < prefs.length(); i++) {
+ outWriter.write(String.valueOf(prefs.getUserID(i)));
+ outWriter.write(',');
+ outWriter.write(String.valueOf(prefs.getItemID(i)));
+ if (columnsToOutput > 2) {
+ outWriter.write(',');
+ outWriter.write(String.valueOf(prefs.getValue(i)));
+ }
+ if (columnsToOutput > 3) {
+ outWriter.write(',');
+ outWriter.write(String.valueOf(timestamps[i]));
+ }
+ outWriter.write('\n');
+ }
+ }
+ }
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
new file mode 100644
index 0000000..0112ab9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
@@ -0,0 +1,43 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * Converts a floating-point preference estimate into the single-byte form returned by {@link #convert}.
+ */
+public final class EstimateConverter {
+
+  private static final Logger log = LoggerFactory.getLogger(EstimateConverter.class);
+
+  private EstimateConverter() {}
+
+  /**
+   * Scales the estimate by 2.55, truncates it to an int, clamps it to [0, 255] and narrows it to a byte.
+   * Because Java bytes are signed, clamped values above 127 come back as negative bytes carrying the same
+   * low eight bits.
+   *
+   * @param estimate the estimate to convert
+   * @param userID user the estimate belongs to; used only in the warning message
+   * @param itemID item the estimate belongs to; used only in the warning message
+   * @return the converted byte, or {@code 0x7F} (after logging a warning) when {@code estimate} is NaN
+   */
+  public static byte convert(double estimate, long userID, long itemID) {
+    if (Double.isNaN(estimate)) {
+      log.warn("Unable to compute estimate for user {}, item {}", userID, itemID);
+      return 0x7F;
+    }
+    int scaled = (int) (estimate * 2.55);
+    int clamped = Math.max(0, Math.min(255, scaled));
+    return (byte) clamped;
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
new file mode 100644
index 0000000..72056da
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
@@ -0,0 +1,67 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * Task that estimates a preference for every item in one user's test set and returns the estimates encoded
+ * as bytes by {@link EstimateConverter}.
+ */
+final class Track1Callable implements Callable<byte[]> {
+
+  private static final Logger log = LoggerFactory.getLogger(Track1Callable.class);
+  // Shared across all instances: counts completed users for progress logging.
+  private static final AtomicInteger COUNT = new AtomicInteger();
+
+  private final Recommender recommender;
+  private final PreferenceArray userTest;
+
+  /**
+   * @param recommender recommender used to estimate each preference
+   * @param userTest the test items of a single user; the user ID is read from its first entry
+   */
+  Track1Callable(Recommender recommender, PreferenceArray userTest) {
+    this.recommender = recommender;
+    this.userTest = userTest;
+  }
+
+  /**
+   * @return one byte per test item, in the order of {@code userTest}; the slot of an item the recommender
+   *   does not know is left at 0
+   * @throws TasteException if estimation fails for a reason other than an unknown item
+   */
+  @Override
+  public byte[] call() throws TasteException {
+    long userID = userTest.get(0).getUserID();
+    byte[] result = new byte[userTest.length()];
+    for (int i = 0; i < userTest.length(); i++) {
+      long itemID = userTest.getItemID(i);
+      double estimate;
+      try {
+        estimate = recommender.estimatePreference(userID, itemID);
+      } catch (NoSuchItemException nsie) {
+        // OK in the sample data provided before the contest, should never happen otherwise
+        log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
+        continue;
+      }
+      result[i] = EstimateConverter.convert(estimate, userID, itemID);
+    }
+    // Log the value returned by the increment itself; re-reading the counter could observe increments
+    // made by other threads in the meantime.
+    int completed = COUNT.incrementAndGet();
+    if (completed % 10000 == 0) {
+      log.info("Completed {} users", completed);
+    }
+    return result;
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java
new file mode 100644
index 0000000..067daf5
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Recommender.java
@@ -0,0 +1,94 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import java.util.Collection;
+import java.util.List;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.UncenteredCosineSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+/**
+ * A {@link Recommender} that wraps a {@link GenericItemBasedRecommender} built on
+ * {@link UncenteredCosineSimilarity} and forwards every call to it.
+ */
+public final class Track1Recommender implements Recommender {
+
+  private final Recommender delegate;
+
+  /**
+   * @param dataModel the data model handed to both the similarity and the wrapped recommender
+   */
+  public Track1Recommender(DataModel dataModel) throws TasteException {
+    // Change this to whatever you like!
+    ItemSimilarity similarity = new UncenteredCosineSimilarity(dataModel);
+    this.delegate = new GenericItemBasedRecommender(dataModel, similarity);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return delegate.recommend(userID, howMany);
+  }
+
+  /** Forwards with no rescorer. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+    return delegate.recommend(userID, howMany, null, includeKnownItems);
+  }
+
+  /** Forwards with {@code includeKnownItems} set to {@code false}. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return delegate.recommend(userID, howMany, rescorer, false);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+    throws TasteException {
+    return delegate.recommend(userID, howMany, rescorer, includeKnownItems);
+  }
+
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return delegate.estimatePreference(userID, itemID);
+  }
+
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    delegate.setPreference(userID, itemID, value);
+  }
+
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    delegate.removePreference(userID, itemID);
+  }
+
+  @Override
+  public DataModel getDataModel() {
+    return delegate.getDataModel();
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    delegate.refresh(alreadyRefreshed);
+  }
+
+  @Override
+  public String toString() {
+    return "Track1Recommender[recommender:" + delegate + ']';
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java
new file mode 100644
index 0000000..6b9fe1b
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderBuilder.java
@@ -0,0 +1,32 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+final class Track1RecommenderBuilder implements RecommenderBuilder {
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) throws TasteException {
+ return new Track1Recommender(dataModel);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java
new file mode 100644
index 0000000..bcd0a3d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluator.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import java.io.File;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.eval.AbstractDifferenceRecommenderEvaluator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * Attempts to run an evaluation just like that dictated for Yahoo's KDD Cup, Track 1.
+ * It will compute the RMSE of a validation data set against the predicted ratings from
+ * the training data set.
+ */
+public final class Track1RecommenderEvaluator extends AbstractDifferenceRecommenderEvaluator {
+ private static final Logger log = LoggerFactory.getLogger(Track1RecommenderEvaluator.class);
+ private RunningAverage average;
+ private final File dataFileDirectory;
+ /**
+  * Creates an evaluator that reads its validation data from the given directory and
+  * bounds preferences to the range 0 to 100.
+  *
+  * @param dataFileDirectory directory from which the validation file is later resolved
+  */
+ public Track1RecommenderEvaluator(File dataFileDirectory) {
+   this.dataFileDirectory = dataFileDirectory;
+   this.average = new FullRunningAverage();
+   // preference bounds used by the evaluator base class
+   setMaxPreference(100.0f);
+   setMinPreference(0.0f);
+ }
+ /**
+  * Computes the RMSE of the recommender's estimates against the validation file found in
+  * the data file directory given at construction time.
+  *
+  * <p>{@code dataModelBuilder}, {@code trainingPercentage} and {@code evaluationPercentage}
+  * are not used by this implementation: the training model is passed in as
+  * {@code dataModel} and the whole validation file is always evaluated.</p>
+  *
+  * @param recommenderBuilder builds the recommender to evaluate from {@code dataModel}
+  * @param dataModelBuilder ignored
+  * @param dataModel training data handed to {@code recommenderBuilder}
+  * @param trainingPercentage ignored
+  * @param evaluationPercentage ignored
+  * @return square root of the average squared difference between real and estimated preference
+  * @throws TasteException if building the recommender or running the estimates fails
+  */
+ @Override
+ public double evaluate(RecommenderBuilder recommenderBuilder,
+                        DataModelBuilder dataModelBuilder,
+                        DataModel dataModel,
+                        double trainingPercentage,
+                        double evaluationPercentage) throws TasteException {
+   // Start from an empty accumulator so that calling evaluate() more than once on the same
+   // instance does not mix squared errors of earlier runs into this result.
+   reset();
+   Recommender recommender = recommenderBuilder.buildRecommender(dataModel);
+   Collection<Callable<Void>> estimateCallables = Lists.newArrayList();
+   AtomicInteger noEstimateCounter = new AtomicInteger();
+   // one callable per user found in the validation file
+   for (Pair<PreferenceArray,long[]> userData
+       : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
+     PreferenceArray validationPrefs = userData.getFirst();
+     long userID = validationPrefs.get(0).getUserID();
+     estimateCallables.add(
+         new PreferenceEstimateCallable(recommender, userID, validationPrefs, noEstimateCounter));
+   }
+   RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
+   execute(estimateCallables, noEstimateCounter, timing);
+   double result = computeFinalEvaluation();
+   log.info("Evaluation result: {}", result);
+   return result;
+ }
+ // Use RMSE scoring:
+ /**
+  * Discards everything accumulated so far by replacing the running average of squared
+  * errors with a fresh, empty one.
+  */
+ @Override
+ protected void reset() {
+ average = new FullRunningAverage();
+ }
+ @Override
+ protected void processOneEstimate(float estimatedPreference, Preference realPref) {
+ double diff = realPref.getValue() - estimatedPreference;
+ average.addDatum(diff * diff);
+ }
+ /**
+  * @return the square root of the running average of squared errors, i.e. the RMSE of all
+  *     estimates processed since the accumulator was last replaced
+  */
+ @Override
+ protected double computeFinalEvaluation() {
+ return Math.sqrt(average.getAverage());
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java
new file mode 100644
index 0000000..deadc00
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1RecommenderEvaluatorRunner.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import java.io.File;
+import java.io.IOException;
+import org.apache.commons.cli2.OptionException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.example.TasteOptionParser;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public final class Track1RecommenderEvaluatorRunner {
+ private static final Logger log = LoggerFactory.getLogger(Track1RecommenderEvaluatorRunner.class);
+ /** Private and empty: the class only exposes the static {@code main} entry point. */
+ private Track1RecommenderEvaluatorRunner() {
+ }
+ /**
+  * Evaluates {@link Track1RecommenderBuilder}'s recommender against the KDD Cup track 1
+  * data and logs the resulting score.
+  *
+  * @param args command line arguments from which {@link TasteOptionParser} extracts the
+  *     data file directory
+  * @throws IllegalArgumentException if no directory was given, or if it does not exist or
+  *     is not a directory
+  */
+ public static void main(String... args) throws IOException, TasteException, OptionException {
+   File ratingsDirectory = TasteOptionParser.getRatings(args);
+   if (ratingsDirectory == null) {
+     throw new IllegalArgumentException("No data directory");
+   }
+   boolean usableDirectory = ratingsDirectory.exists() && ratingsDirectory.isDirectory();
+   if (!usableDirectory) {
+     throw new IllegalArgumentException("Bad data file directory: " + ratingsDirectory);
+   }
+   Track1RecommenderEvaluator rmseEvaluator = new Track1RecommenderEvaluator(ratingsDirectory);
+   DataModel trainingModel = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(ratingsDirectory));
+   // the builder, split percentages and NaN placeholders are passed exactly as the evaluator expects
+   double score =
+       rmseEvaluator.evaluate(new Track1RecommenderBuilder(), null, trainingModel, Float.NaN, Float.NaN);
+   log.info(String.valueOf(score));
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java
new file mode 100644
index 0000000..a0ff126
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Runner.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+/**
+ * <p>Runs "track 1" of the KDD Cup competition using whatever recommender is inside {@link Track1Recommender}
+ * and attempts to output the result in the correct contest format.</p>
+ *
+ * <p>Run as: {@code Track1Runner [track 1 data file directory] [output file]}</p>
+ */
+public final class Track1Runner {
+ private static final Logger log = LoggerFactory.getLogger(Track1Runner.class);
+ /** Private and empty: the class only exposes the static {@code main} entry point. */
+ private Track1Runner() {
+ }
+ /**
+  * Loads the training data, computes the estimates for every user in the test file in
+  * parallel and writes the estimated bytes, in test file order, to the output file.
+  *
+  * @param args {@code args[0]} is the track 1 data file directory, {@code args[1]} the
+  *     output file
+  * @throws IllegalArgumentException if fewer than two arguments are given, or if the data
+  *     directory does not exist or is not a directory
+  * @throws Exception if loading the data, running the estimates or writing the output fails
+  */
+ public static void main(String[] args) throws Exception {
+   // fail with a usage message instead of an ArrayIndexOutOfBoundsException
+   if (args.length < 2) {
+     throw new IllegalArgumentException(
+         "Usage: Track1Runner [track 1 data file directory] [output file]");
+   }
+   File dataFileDirectory = new File(args[0]);
+   if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+     throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+   }
+   long start = System.currentTimeMillis();
+   KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+   Track1Recommender recommender = new Track1Recommender(model);
+   long end = System.currentTimeMillis();
+   log.info("Loaded model in {}s", (end - start) / 1000);
+   start = end;
+   Collection<Track1Callable> callables = new ArrayList<>();
+   for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
+     PreferenceArray userTest = tests.getFirst();
+     callables.add(new Track1Callable(recommender, userTest));
+   }
+   int cores = Runtime.getRuntime().availableProcessors();
+   log.info("Running on {} cores", cores);
+   ExecutorService executor = Executors.newFixedThreadPool(cores);
+   List<Future<byte[]>> results;
+   try {
+     results = executor.invokeAll(callables);
+   } finally {
+     // release the worker threads even when invokeAll is interrupted or rejects the tasks
+     executor.shutdown();
+   }
+   end = System.currentTimeMillis();
+   log.info("Ran recommendations in {}s", (end - start) / 1000);
+   start = end;
+   try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))) {
+     for (Future<byte[]> result : results) {
+       // write each user's estimates as one array rather than byte by byte
+       out.write(result.get());
+     }
+   }
+   end = System.currentTimeMillis();
+   log.info("Wrote output in {}s", (end - start) / 1000);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
new file mode 100644
index 0000000..022d78c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import java.util.ArrayList;
+import java.util.List;
+/**
+ * can be used to drop {@link DataModel}s into {@link ParallelArraysSGDFactorizer}
+ */
+public class DataModelFactorizablePreferences implements FactorizablePreferences {
+ private final FastIDSet userIDs;
+ private final FastIDSet itemIDs;
+ private final List<Preference> preferences;
+ private final float minPreference;
+ private final float maxPreference;
+ /**
+  * Copies the user IDs, item IDs and preferences of the given data model into this object.
+  *
+  * @param dataModel source of the preference bounds, users and preferences
+  * @throws IllegalStateException if reading the data model raises a {@link TasteException}
+  */
+ public DataModelFactorizablePreferences(DataModel dataModel) {
+   minPreference = dataModel.getMinPreference();
+   maxPreference = dataModel.getMaxPreference();
+   try {
+     userIDs = new FastIDSet(dataModel.getNumUsers());
+     itemIDs = new FastIDSet(dataModel.getNumItems());
+     preferences = new ArrayList<>();
+     // walk every user once and copy each of that user's preferences
+     for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
+       long currentUser = users.nextLong();
+       userIDs.add(currentUser);
+       for (Preference rating : dataModel.getPreferencesFromUser(currentUser)) {
+         long ratedItem = rating.getItemID();
+         itemIDs.add(ratedItem);
+         preferences.add(new GenericPreference(currentUser, ratedItem, rating.getValue()));
+       }
+     }
+   } catch (TasteException cause) {
+     throw new IllegalStateException("Unable to create factorizable preferences!", cause);
+   }
+ }
+ /** @return a new iterator over the user IDs collected in the constructor */
+ @Override
+ public LongPrimitiveIterator getUserIDs() {
+ return userIDs.iterator();
+ }
+ /** @return a new iterator over the item IDs collected in the constructor */
+ @Override
+ public LongPrimitiveIterator getItemIDs() {
+ return itemIDs.iterator();
+ }
+ /** @return the internal list of preferences copied in the constructor (not a defensive copy) */
+ @Override
+ public Iterable<Preference> getPreferences() {
+ return preferences;
+ }
+ /** @return the minimum preference read from the data model at construction time */
+ @Override
+ public float getMinPreference() {
+ return minPreference;
+ }
+ /** @return the maximum preference read from the data model at construction time */
+ @Override
+ public float getMaxPreference() {
+ return maxPreference;
+ }
+ /** @return the number of user IDs collected in the constructor */
+ @Override
+ public int numUsers() {
+ return userIDs.size();
+ }
+ /** @return the number of distinct item IDs seen while copying the preferences */
+ @Override
+ public int numItems() {
+ return itemIDs.size();
+ }
+ /** @return the number of preferences copied in the constructor */
+ @Override
+ public int numPreferences() {
+ return preferences.size();
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
new file mode 100644
index 0000000..a126dec
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.Preference;
+/**
+ * models the necessary input for {@link ParallelArraysSGDFactorizer}
+ */
+public interface FactorizablePreferences {
+ /** @return an iterator over the user IDs */
+ LongPrimitiveIterator getUserIDs();
+ /** @return an iterator over the item IDs */
+ LongPrimitiveIterator getItemIDs();
+ /** @return all preferences to factorize */
+ Iterable<Preference> getPreferences();
+ /** @return the lower bound of preference values */
+ float getMinPreference();
+ /** @return the upper bound of preference values */
+ float getMaxPreference();
+ /** @return the number of users */
+ int numUsers();
+ /** @return the number of items */
+ int numItems();
+ /** @return the number of preferences */
+ int numPreferences();

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
new file mode 100644
index 0000000..6dcef6b
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.impl.common.AbstractLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import java.io.File;
+public class KDDCupFactorizablePreferences implements FactorizablePreferences {
+ private final File dataFile;
+ /**
+  * Stores the data file; nothing is read here, the file is only opened when
+  * {@code getPreferences()} builds its {@link DataFileIterable}.
+  *
+  * @param dataFile file the preferences are later read from
+  */
+ public KDDCupFactorizablePreferences(File dataFile) {
+ this.dataFile = dataFile;
+ }
+ /** @return an iterator counting from 0 up to, but not including, {@code numUsers()} */
+ @Override
+ public LongPrimitiveIterator getUserIDs() {
+ return new FixedSizeLongIterator(numUsers());
+ }
+ /** @return an iterator counting from 0 up to, but not including, {@code numItems()} */
+ @Override
+ public LongPrimitiveIterator getItemIDs() {
+ return new FixedSizeLongIterator(numItems());
+ }
+ @Override
+ public Iterable<Preference> getPreferences() {
+ Iterable<Iterable<Preference>> prefIterators =
+ Iterables.transform(new DataFileIterable(dataFile),
+ new Function<Pair<PreferenceArray,long[]>,Iterable<Preference>>() {
+ @Override
+ public Iterable<Preference> apply(Pair<PreferenceArray,long[]> from) {
+ return from.getFirst();
+ }
+ });
+ return Iterables.concat(prefIterators);
+ }
+ /** @return the hard-coded lower preference bound, 0 */
+ @Override
+ public float getMinPreference() {
+ return 0;
+ }
+ /** @return the hard-coded upper preference bound, 100 */
+ @Override
+ public float getMaxPreference() {
+ return 100;
+ }
+ /** @return a hard-coded user count; it is not derived from the data file */
+ @Override
+ public int numUsers() {
+ return 1000990;
+ }
+ /** @return a hard-coded item count; it is not derived from the data file */
+ @Override
+ public int numItems() {
+ return 624961;
+ }
+ /** @return a hard-coded preference count; it is not derived from the data file */
+ @Override
+ public int numPreferences() {
+ return 252800275;
+ }
+ /**
+  * Iterator over the consecutive values {@code 0, 1, ..., maximum - 1}.
+  */
+ static class FixedSizeLongIterator extends AbstractLongPrimitiveIterator {
+   /** Next value to hand out. */
+   private long currentValue;
+   /** Exclusive upper bound of the values produced. */
+   private final long maximum;
+   /**
+    * @param maximum exclusive upper bound; a value of 0 or less yields an empty iterator
+    */
+   FixedSizeLongIterator(long maximum) {
+     this.maximum = maximum;
+     currentValue = 0;
+   }
+   /**
+    * @return the next value, advancing the iterator
+    * @throws java.util.NoSuchElementException if all values have been consumed
+    */
+   @Override
+   public long nextLong() {
+     // honour the Iterator contract instead of silently counting past the bound
+     if (!hasNext()) {
+       throw new java.util.NoSuchElementException();
+     }
+     return currentValue++;
+   }
+   /**
+    * @return the next value without advancing the iterator
+    * @throws java.util.NoSuchElementException if all values have been consumed
+    */
+   @Override
+   public long peek() {
+     if (!hasNext()) {
+       throw new java.util.NoSuchElementException();
+     }
+     return currentValue;
+   }
+   /**
+    * Skips the next {@code n} values; zero or negative {@code n} is ignored so that the
+    * iterator can never move backwards.
+    */
+   @Override
+   public void skip(int n) {
+     if (n > 0) {
+       currentValue += n;
+     }
+   }
+   /** @return {@code true} while the next value is below the exclusive bound */
+   @Override
+   public boolean hasNext() {
+     return currentValue < maximum;
+   }
+   /** Always throws: the sequence is virtual and nothing can be removed from it. */
+   @Override
+   public void remove() {
+     throw new UnsupportedOperationException();
+   }
+ }
2018-06-27 13:14:31 UTC
diff --git a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java b/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java
deleted file mode 100644
index 43beb78..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java
+++ /dev/null
@@ -1,144 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.syntheticcontrol.fuzzykmeans;
-import java.util.List;
-import java.util.Map;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.clustering.canopy.CanopyDriver;
-import org.apache.mahout.clustering.conversion.InputDriver;
-import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.ClassUtils;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
-import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
-import org.apache.mahout.utils.clustering.ClusterDumper;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class Job extends AbstractJob {
- private static final Logger log = LoggerFactory.getLogger(Job.class);
- private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
- private static final String M_OPTION = FuzzyKMeansDriver.M_OPTION;
- private Job() {
- }
- public static void main(String[] args) throws Exception {
- if (args.length > 0) {
- log.info("Running with only user-supplied arguments");
- ToolRunner.run(new Configuration(), new Job(), args);
- } else {
- log.info("Running with default arguments");
- Path output = new Path("output");
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, output);
- run(conf, new Path("testdata"), output, new EuclideanDistanceMeasure(), 80, 55, 10, 2.0f, 0.5);
- }
- }
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.distanceMeasureOption().create());
- addOption(DefaultOptionCreator.convergenceOption().create());
- addOption(DefaultOptionCreator.maxIterationsOption().create());
- addOption(DefaultOptionCreator.overwriteOption().create());
- addOption(DefaultOptionCreator.t1Option().create());
- addOption(DefaultOptionCreator.t2Option().create());
- addOption(M_OPTION, M_OPTION, "coefficient normalization factor, must be greater than 1", true);
- Map<String,List<String>> argMap = parseArguments(args);
- if (argMap == null) {
- return -1;
- }
- Path input = getInputPath();
- Path output = getOutputPath();
- String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
- if (measureClass == null) {
- measureClass = SquaredEuclideanDistanceMeasure.class.getName();
- }
- double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
- int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
- float fuzziness = Float.parseFloat(getOption(M_OPTION));
- addOption(new DefaultOptionBuilder().withLongName(M_OPTION).withRequired(true)
- .withArgument(new ArgumentBuilder().withName(M_OPTION).withMinimum(1).withMaximum(1).create())
- .withDescription("coefficient normalization factor, must be greater than 1").withShortName(M_OPTION).create());
- if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
- HadoopUtil.delete(getConf(), output);
- }
- DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
- double t1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
- double t2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
- run(getConf(), input, output, measure, t1, t2, maxIterations, fuzziness, convergenceDelta);
- return 0;
- }
- /**
- * Run the kmeans clustering job on an input dataset using the given distance measure, t1, t2 and iteration
- * parameters. All output data will be written to the output directory, which will be initially deleted if it exists.
- * The clustered points will reside in the path <output>/clustered-points. By default, the job expects the a file
- * containing synthetic_control.data as obtained from
- * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series resides in a directory named "testdata",
- * and writes output to a directory named "output".
- *
- * @param input
- * the String denoting the input directory path
- * @param output
- * the String denoting the output directory path
- * @param t1
- * the canopy T1 threshold
- * @param t2
- * the canopy T2 threshold
- * @param maxIterations
- * the int maximum number of iterations
- * @param fuzziness
- * the float "m" fuzziness coefficient
- * @param convergenceDelta
- * the double convergence criteria for iterations
- */
- public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, double t1, double t2,
- int maxIterations, float fuzziness, double convergenceDelta) throws Exception {
- Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
- log.info("Preparing Input");
- InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
- log.info("Running Canopy to get initial clusters");
- Path canopyOutput = new Path(output, "canopies");
- CanopyDriver.run(new Configuration(), directoryContainingConvertedInput, canopyOutput, measure, t1, t2, false, 0.0, false);
- log.info("Running FuzzyKMeans");
- FuzzyKMeansDriver.run(directoryContainingConvertedInput, new Path(canopyOutput, "clusters-0-final"), output,
- convergenceDelta, maxIterations, fuzziness, true, true, 0.0, false);
- // run ClusterDumper
- ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-*-final"), new Path(output, "clusteredPoints"));
- clusterDumper.printClusters(null);
- }

diff --git a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java b/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
deleted file mode 100644
index 70c41fe..0000000
--- a/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
+++ /dev/null
@@ -1,187 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.clustering.syntheticcontrol.kmeans;
-import java.util.List;
-import java.util.Map;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.canopy.CanopyDriver;
-import org.apache.mahout.clustering.conversion.InputDriver;
-import org.apache.mahout.clustering.kmeans.KMeansDriver;
-import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.ClassUtils;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
-import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
-import org.apache.mahout.utils.clustering.ClusterDumper;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-public final class Job extends AbstractJob {
- private static final Logger log = LoggerFactory.getLogger(Job.class);
- private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data";
- private Job() {
- }
- public static void main(String[] args) throws Exception {
- if (args.length > 0) {
- log.info("Running with only user-supplied arguments");
- ToolRunner.run(new Configuration(), new Job(), args);
- } else {
- log.info("Running with default arguments");
- Path output = new Path("output");
- Configuration conf = new Configuration();
- HadoopUtil.delete(conf, output);
- run(conf, new Path("testdata"), output, new EuclideanDistanceMeasure(), 6, 0.5, 10);
- }
- }
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.distanceMeasureOption().create());
- addOption(DefaultOptionCreator.numClustersOption().create());
- addOption(DefaultOptionCreator.t1Option().create());
- addOption(DefaultOptionCreator.t2Option().create());
- addOption(DefaultOptionCreator.convergenceOption().create());
- addOption(DefaultOptionCreator.maxIterationsOption().create());
- addOption(DefaultOptionCreator.overwriteOption().create());
- Map<String,List<String>> argMap = parseArguments(args);
- if (argMap == null) {
- return -1;
- }
- Path input = getInputPath();
- Path output = getOutputPath();
- String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
- if (measureClass == null) {
- measureClass = SquaredEuclideanDistanceMeasure.class.getName();
- }
- double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
- int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
- if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
- HadoopUtil.delete(getConf(), output);
- }
- DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
- if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
- int k = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
- run(getConf(), input, output, measure, k, convergenceDelta, maxIterations);
- } else {
- double t1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
- double t2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
- run(getConf(), input, output, measure, t1, t2, convergenceDelta, maxIterations);
- }
- return 0;
- }
- /**
- * Run the kmeans clustering job on an input dataset using the given the number of clusters k and iteration
- * parameters. All output data will be written to the output directory, which will be initially deleted if it exists.
- * The clustered points will reside in the path <output>/clustered-points. By default, the job expects a file
- * containing equal length space delimited data that resides in a directory named "testdata", and writes output to a
- * directory named "output".
- *
- * @param conf
- * the Configuration to use
- * @param input
- * the String denoting the input directory path
- * @param output
- * the String denoting the output directory path
- * @param measure
- * the DistanceMeasure to use
- * @param k
- * the number of clusters in Kmeans
- * @param convergenceDelta
- * the double convergence criteria for iterations
- * @param maxIterations
- * the int maximum number of iterations
- */
- public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, int k,
- double convergenceDelta, int maxIterations) throws Exception {
- Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
- log.info("Preparing Input");
- InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
- log.info("Running random seed to get initial clusters");
- Path clusters = new Path(output, "random-seeds");
- clusters = RandomSeedGenerator.buildRandom(conf, directoryContainingConvertedInput, clusters, k, measure);
- log.info("Running KMeans with k = {}", k);
- KMeansDriver.run(conf, directoryContainingConvertedInput, clusters, output, convergenceDelta,
- maxIterations, true, 0.0, false);
- // run ClusterDumper
- Path outGlob = new Path(output, "clusters-*-final");
- Path clusteredPoints = new Path(output,"clusteredPoints");
- log.info("Dumping out clusters from clusters: {} and clusteredPoints: {}", outGlob, clusteredPoints);
- ClusterDumper clusterDumper = new ClusterDumper(outGlob, clusteredPoints);
- clusterDumper.printClusters(null);
- }
- /**
- * Run the kmeans clustering job on an input dataset using the given distance measure, t1, t2 and iteration
- * parameters. All output data will be written to the output directory, which will be initially deleted if it exists.
- * The clustered points will reside in the path <output>/clusteredPoints. By default, the job expects a file
- * containing synthetic_control.data as obtained from
- * http://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series to reside in a directory named "testdata",
- * and writes output to a directory named "output".
- *
- * @param conf
- * the Configuration to use
- * @param input
- * the String denoting the input directory path
- * @param output
- * the String denoting the output directory path
- * @param measure
- * the DistanceMeasure to use
- * @param t1
- * the canopy T1 threshold
- * @param t2
- * the canopy T2 threshold
- * @param convergenceDelta
- * the double convergence criteria for iterations
- * @param maxIterations
- * the int maximum number of iterations
- */
- public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, double t1, double t2,
- double convergenceDelta, int maxIterations) throws Exception {
- Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
- log.info("Preparing Input");
- InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
- log.info("Running Canopy to get initial clusters");
- Path canopyOutput = new Path(output, "canopies");
- CanopyDriver.run(new Configuration(), directoryContainingConvertedInput, canopyOutput, measure, t1, t2, false, 0.0,
- false);
- log.info("Running KMeans");
- KMeansDriver.run(conf, directoryContainingConvertedInput, new Path(canopyOutput, Cluster.INITIAL_CLUSTERS_DIR
- + "-final"), output, convergenceDelta, maxIterations, true, 0.0, false);
- // run ClusterDumper
- ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-*-final"), new Path(output,
- "clusteredPoints"));
- clusterDumper.printClusters(null);
- }

diff --git a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java b/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java
deleted file mode 100644
index 92363e5..0000000
--- a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java
+++ /dev/null
@@ -1,94 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.fpm.pfpgrowth;
-import java.io.IOException;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.mahout.common.CommandLineUtil;
-import org.apache.mahout.common.Parameters;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.fpm.pfpgrowth.dataset.KeyBasedStringTupleGrouper;
-public final class DeliciousTagsExample {
- private DeliciousTagsExample() { }
- public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
- DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
- ArgumentBuilder abuilder = new ArgumentBuilder();
- GroupBuilder gbuilder = new GroupBuilder();
- Option inputDirOpt = DefaultOptionCreator.inputOption().create();
- Option outputOpt = DefaultOptionCreator.outputOption().create();
- Option helpOpt = DefaultOptionCreator.helpOption();
- Option recordSplitterOpt = obuilder.withLongName("splitterPattern").withArgument(
- abuilder.withName("splitterPattern").withMinimum(1).withMaximum(1).create()).withDescription(
- "Regular Expression pattern used to split given line into fields."
- + " Default value splits comma or tab separated fields."
- + " Default Value: \"[ ,\\t]*\\t[ ,\\t]*\" ").withShortName("regex").create();
- Option encodingOpt = obuilder.withLongName("encoding").withArgument(
- abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
- "(Optional) The file encoding. Default value: UTF-8").withShortName("e").create();
- Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(outputOpt).withOption(
- helpOpt).withOption(recordSplitterOpt).withOption(encodingOpt).create();
- try {
- Parser parser = new Parser();
- parser.setGroup(group);
- CommandLine cmdLine = parser.parse(args);
- if (cmdLine.hasOption(helpOpt)) {
- CommandLineUtil.printHelp(group);
- return;
- }
- Parameters params = new Parameters();
- if (cmdLine.hasOption(recordSplitterOpt)) {
- params.set("splitPattern", (String) cmdLine.getValue(recordSplitterOpt));
- }
- String encoding = "UTF-8";
- if (cmdLine.hasOption(encodingOpt)) {
- encoding = (String) cmdLine.getValue(encodingOpt);
- }
- params.set("encoding", encoding);
- String inputDir = (String) cmdLine.getValue(inputDirOpt);
- String outputDir = (String) cmdLine.getValue(outputOpt);
- params.set("input", inputDir);
- params.set("output", outputDir);
- params.set("groupingFieldCount", "2");
- params.set("gfield0", "1");
- params.set("gfield1", "2");
- params.set("selectedFieldCount", "1");
- params.set("field0", "3");
- params.set("maxTransactionLength", "100");
- KeyBasedStringTupleGrouper.startJob(params);
- } catch (OptionException ex) {
- CommandLineUtil.printHelp(group);
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java b/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java
deleted file mode 100644
index 4c80a31..0000000
--- a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java
+++ /dev/null
@@ -1,40 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.fpm.pfpgrowth.dataset;
-import java.io.IOException;
-import java.util.HashSet;
-import java.util.Set;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.mahout.common.StringTuple;
-public class KeyBasedStringTupleCombiner extends Reducer<Text,StringTuple,Text,StringTuple> {
- @Override
- protected void reduce(Text key,
- Iterable<StringTuple> values,
- Context context) throws IOException, InterruptedException {
- Set<String> outputValues = new HashSet<>();
- for (StringTuple value : values) {
- outputValues.addAll(value.getEntries());
- }
- context.write(key, new StringTuple(outputValues));
- }

diff --git a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java b/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java
deleted file mode 100644
index cd17770..0000000
--- a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java
+++ /dev/null
@@ -1,77 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.fpm.pfpgrowth.dataset;
-import java.io.IOException;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
-import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.Parameters;
-import org.apache.mahout.common.StringTuple;
-public final class KeyBasedStringTupleGrouper {
- private KeyBasedStringTupleGrouper() { }
- public static void startJob(Parameters params) throws IOException,
- InterruptedException,
- ClassNotFoundException {
- Configuration conf = new Configuration();
- conf.set("job.parameters", params.toString());
- conf.set("mapred.compress.map.output", "true");
- conf.set("mapred.output.compression.type", "BLOCK");
- conf.set("mapred.map.output.compression.codec", "org.apache.hadoop.io.compress.GzipCodec");
- conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
- String input = params.get("input");
- Job job = new Job(conf, "Generating dataset based from input" + input);
- job.setJarByClass(KeyBasedStringTupleGrouper.class);
- job.setMapOutputKeyClass(Text.class);
- job.setMapOutputValueClass(StringTuple.class);
- job.setOutputKeyClass(Text.class);
- job.setOutputValueClass(Text.class);
- FileInputFormat.addInputPath(job, new Path(input));
- Path outPath = new Path(params.get("output"));
- FileOutputFormat.setOutputPath(job, outPath);
- HadoopUtil.delete(conf, outPath);
- job.setInputFormatClass(TextInputFormat.class);
- job.setMapperClass(KeyBasedStringTupleMapper.class);
- job.setCombinerClass(KeyBasedStringTupleCombiner.class);
- job.setReducerClass(KeyBasedStringTupleReducer.class);
- job.setOutputFormatClass(TextOutputFormat.class);
- boolean succeeded = job.waitForCompletion(true);
- if (!succeeded) {
- throw new IllegalStateException("Job failed!");
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java b/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java
deleted file mode 100644
index 362d1ce..0000000
--- a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java
+++ /dev/null
@@ -1,90 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.fpm.pfpgrowth.dataset;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.regex.Pattern;
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.common.Parameters;
-import org.apache.mahout.common.StringTuple;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
- * Splits the line using a {@link Pattern} and outputs key as given by the groupingFields
- *
- */
-public class KeyBasedStringTupleMapper extends Mapper<LongWritable,Text,Text,StringTuple> {
- private static final Logger log = LoggerFactory.getLogger(KeyBasedStringTupleMapper.class);
- private Pattern splitter;
- private int[] selectedFields;
- private int[] groupingFields;
- @Override
- protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
- String[] fields = splitter.split(value.toString());
- if (fields.length != 4) {
- log.info("{} {}", fields.length, value.toString());
- context.getCounter("Map", "ERROR").increment(1);
- return;
- }
- Collection<String> oKey = new ArrayList<>();
- for (int groupingField : groupingFields) {
- oKey.add(fields[groupingField]);
- context.setStatus(fields[groupingField]);
- }
- List<String> oValue = new ArrayList<>();
- for (int selectedField : selectedFields) {
- oValue.add(fields[selectedField]);
- }
- context.write(new Text(oKey.toString()), new StringTuple(oValue));
- }
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Parameters params = new Parameters(context.getConfiguration().get("job.parameters", ""));
- splitter = Pattern.compile(params.get("splitPattern", "[ \t]*\t[ \t]*"));
- int selectedFieldCount = Integer.valueOf(params.get("selectedFieldCount", "0"));
- selectedFields = new int[selectedFieldCount];
- for (int i = 0; i < selectedFieldCount; i++) {
- selectedFields[i] = Integer.valueOf(params.get("field" + i, "0"));
- }
- int groupingFieldCount = Integer.valueOf(params.get("groupingFieldCount", "0"));
- groupingFields = new int[groupingFieldCount];
- for (int i = 0; i < groupingFieldCount; i++) {
- groupingFields[i] = Integer.valueOf(params.get("gfield" + i, "0"));
- }
- }

diff --git a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java b/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java
deleted file mode 100644
index a7ef762..0000000
--- a/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java
+++ /dev/null
@@ -1,74 +0,0 @@
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.mahout.fpm.pfpgrowth.dataset;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashSet;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.mahout.common.Parameters;
-import org.apache.mahout.common.StringTuple;
-public class KeyBasedStringTupleReducer extends Reducer<Text,StringTuple,Text,Text> {
- private int maxTransactionLength = 100;
- @Override
- protected void reduce(Text key, Iterable<StringTuple> values, Context context)
- throws IOException, InterruptedException {
- Collection<String> items = new HashSet<>();
- for (StringTuple value : values) {
- for (String field : value.getEntries()) {
- items.add(field);
- }
- }
- if (items.size() > 1) {
- int i = 0;
- StringBuilder sb = new StringBuilder();
- String sep = "";
- for (String field : items) {
- if (i % maxTransactionLength == 0) {
- if (i != 0) {
- context.write(null, new Text(sb.toString()));
- }
- sb.replace(0, sb.length(), "");
- sep = "";
- }
- sb.append(sep).append(field);
- sep = "\t";
- i++;
- }
- if (sb.length() > 0) {
- context.write(null, new Text(sb.toString()));
- }
- }
- }
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Parameters params = new Parameters(context.getConfiguration().get("job.parameters", ""));
- maxTransactionLength = Integer.valueOf(params.get("maxTransactionLength", "100"));
- }
2018-06-27 13:14:47 UTC
diff --git a/community/mahout-mr/examples/bin/resources/country.txt b/community/mahout-mr/examples/bin/resources/country.txt
new file mode 100644
index 0000000..6a22091
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/country.txt
@@ -0,0 +1,229 @@
+American Samoa
+Antigua and Barbuda
+Bosnia and Herzegovina
+Bouvet Island
+British Indian Ocean Territory
+Brunei Darussalam
+Burkina Faso
+Cape Verde
+Cayman Islands
+Central African Republic
+Christmas Island
+Cocos Islands
+Cook Islands
+Costa Rica
+Côte d'Ivoire
+Czech Republic
+Dominican Republic
+El Salvador
+Equatorial Guinea
+Falkland Islands
+Faroe Islands
+French Guiana
+French Polynesia
+French Southern Territories
+Hong Kong
+Isle of Man
+Marshall Islands
+Netherlands Antilles
+New Caledonia
+New Zealand
+Norfolk Island
+Northern Mariana Islands
+Palestinian Territory
+Papua New Guinea
+Puerto Rico
+Russian Federation
+Saint Barthélemy
+Saint Helena
+Saint Kitts and Nevis
+Saint Lucia
+Saint Martin
+Saint Pierre and Miquelon
+Saint Vincent and the Grenadines
+San Marino
+Sao Tome and Principe
+Saudi Arabia
+Sierra Leone
+Solomon Islands
+South Africa
+South Georgia and the South Sandwich Islands
+Sri Lanka
+Svalbard and Jan Mayen
+Syrian Arab Republic
+Trinidad and Tobago
+Turks and Caicos Islands
+United Arab Emirates
+United Kingdom
+United States
+United States Minor Outlying Islands
+Virgin Islands
+Wallis and Futuna

diff --git a/community/mahout-mr/examples/bin/resources/country10.txt b/community/mahout-mr/examples/bin/resources/country10.txt
new file mode 100644
index 0000000..97a63e1
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/country10.txt
@@ -0,0 +1,10 @@
+United Kingdom

diff --git a/community/mahout-mr/examples/bin/resources/country2.txt b/community/mahout-mr/examples/bin/resources/country2.txt
new file mode 100644
index 0000000..f4b4f61
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/country2.txt
@@ -0,0 +1,2 @@
+United States
+United Kingdom

diff --git a/community/mahout-mr/examples/bin/resources/donut-test.csv b/community/mahout-mr/examples/bin/resources/donut-test.csv
new file mode 100644
index 0000000..46ea564
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/donut-test.csv
@@ -0,0 +1,41 @@

diff --git a/community/mahout-mr/examples/bin/resources/donut.csv b/community/mahout-mr/examples/bin/resources/donut.csv
new file mode 100644
index 0000000..33ba3b7
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/donut.csv
@@ -0,0 +1,41 @@

diff --git a/community/mahout-mr/examples/bin/resources/test-data.csv b/community/mahout-mr/examples/bin/resources/test-data.csv
new file mode 100644
index 0000000..ab683cd
--- /dev/null
+++ b/community/mahout-mr/examples/bin/resources/test-data.csv
@@ -0,0 +1,61 @@

diff --git a/community/mahout-mr/examples/bin/set-dfs-commands.sh b/community/mahout-mr/examples/bin/set-dfs-commands.sh
new file mode 100755
index 0000000..0ee5fe1
--- /dev/null
+++ b/community/mahout-mr/examples/bin/set-dfs-commands.sh
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Requires $HADOOP_HOME to be set.
+# Figures out the major version of Hadoop we're using and sets commands
+# for dfs commands
+# Run by each example script.
+# Find a hadoop shell
+if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
+ HADOOP="${HADOOP_HOME}/bin/hadoop"
+ if [ ! -e $HADOOP ]; then
+ echo "Can't find hadoop in $HADOOP, exiting"
+ exit 1
+ fi
+# Check Hadoop version
+v=`${HADOOP_HOME}/bin/hadoop version | egrep "Hadoop [0-9]+.[0-9]+.[0-9]+" | cut -f 2 -d ' ' | cut -f 1 -d '.'`
+if [ $v -eq "1" -o $v -eq "0" ]
+ echo "Discovered Hadoop v0 or v1."
+ export DFS="${HADOOP_HOME}/bin/hadoop dfs"
+ export DFSRM="$DFS -rmr -skipTrash"
+elif [ $v -eq "2" ]
+ echo "Discovered Hadoop v2."
+ export DFS="${HADOOP_HOME}/bin/hdfs dfs"
+ export DFSRM="$DFS -rm -r -skipTrash"
+ echo "Can't determine Hadoop version."
+ exit 1
+echo "Setting dfs command to $DFS, dfs rm to $DFSRM."
+export HVERSION=$v

diff --git a/community/mahout-mr/examples/pom.xml b/community/mahout-mr/examples/pom.xml
new file mode 100644
index 0000000..28a5795
--- /dev/null
+++ b/community/mahout-mr/examples/pom.xml
@@ -0,0 +1,199 @@
+<?xml version="1.0" encoding="UTF-8"?>
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ See the License for the specific language governing permissions and
+ limitations under the License.
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout-mr</artifactId>
+ <version>0.14.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+ <artifactId>mr-examples</artifactId>
+ <name>Mahout Examples</name>
+ <description>Scalable machine learning library examples</description>
+ <packaging>jar</packaging>
+ <properties>
+ <mahout.skip.example>false</mahout.skip.example>
+ </properties>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>copy-dependencies</id>
+ <phase>package</phase>
+ <goals>
+ <goal>copy-dependencies</goal>
+ </goals>
+ <configuration>
+ <!-- configure the plugin here -->
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- create examples hadoop job jar -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>job</id>
+ <phase>package</phase>
+ <goals>
+ <goal>single</goal>
+ </goals>
+ <configuration>
+ <skipAssembly>${mahout.skip.example}</skipAssembly>
+ <descriptors>
+ <descriptor>src/main/assembly/job.xml</descriptor>
+ </descriptors>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-remote-resources-plugin</artifactId>
+ <configuration>
+ <appendedResourcesDirectory>../src/main/appended-resources</appendedResourcesDirectory>
+ <resourceBundles>
+ <resourceBundle>org.apache:apache-jar-resource-bundle:1.4</resourceBundle>
+ </resourceBundles>
+ <supplementalModels>
+ <supplementalModel>supplemental-models.xml</supplementalModel>
+ </supplementalModels>
+ </configuration>
+ </plugin>
+ <plugin>
+ <artifactId>maven-source-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.mortbay.jetty</groupId>
+ <artifactId>maven-jetty-plugin</artifactId>
+ <version>6.1.26</version>
+ </plugin>
+ </plugins>
+ </build>
+ <dependencies>
+ <!-- our modules -->
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-mr</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-mr</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-integration</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-benchmark</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-analyzers-common</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.carrotsearch.randomizedtesting</groupId>
+ <artifactId>randomizedtesting-runner</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>jcl-over-slf4j</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>commons-logging</groupId>
+ <artifactId>commons-logging</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </dependency>
+ </dependencies>
+ <profiles>
+ <profile>
+ <id>release.prepare</id>
+ <properties>
+ <mahout.skip.example>true</mahout.skip.example>
+ </properties>
+ </profile>
+ </profiles>

diff --git a/community/mahout-mr/examples/src/main/assembly/job.xml b/community/mahout-mr/examples/src/main/assembly/job.xml
new file mode 100644
index 0000000..0c41f3d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/assembly/job.xml
@@ -0,0 +1,46 @@
+<?xml version="1.0" encoding="UTF-8"?>
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ xmlns="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0
+ http://maven.apache.org/xsd/assembly-1.1.0.xsd">
+ <id>job</id>
+ <formats>
+ <format>jar</format>
+ </formats>
+ <includeBaseDirectory>false</includeBaseDirectory>
+ <dependencySets>
+ <dependencySet>
+ <unpack>true</unpack>
+ <unpackOptions>
+ <!-- MAHOUT-1126 -->
+ <excludes>
+ <exclude>META-INF/LICENSE</exclude>
+ </excludes>
+ </unpackOptions>
+ <scope>runtime</scope>
+ <outputDirectory>/</outputDirectory>
+ <useTransitiveFiltering>true</useTransitiveFiltering>
+ <excludes>
+ <exclude>org.apache.hadoop:hadoop-core</exclude>
+ </excludes>
+ </dependencySet>
+ </dependencySets>
\ No newline at end of file

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java
new file mode 100644
index 0000000..6392b9f
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/TasteOptionParser.java
@@ -0,0 +1,75 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example;
+import java.io.File;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+/**
+ * This class provides a common implementation for parsing input parameters for
+ * all taste examples. Currently they only need the path to the recommendations
+ * file as input.
+ *
+ * The class is safe to be used in threaded contexts.
+ */
+public final class TasteOptionParser {
+
+  /** Holder for static helpers only; never instantiated. */
+  private TasteOptionParser() {
+  }
+
+  /**
+   * Reads the optional {@code --input} / {@code -i} argument from the command line.
+   * If the help option is present, {@link CommandLineUtil#printHelp} is invoked for the
+   * option group and no file is returned.
+   *
+   * @param args the arguments as given to the application.
+   * @return the file named by the input option, or {@code null} when help was requested
+   *         or no input option was given.
+   * @throws OptionException if the command line cannot be parsed.
+   */
+  public static File getRatings(String[] args) throws OptionException {
+    Option inputOpt = new DefaultOptionBuilder()
+        .withLongName("input")
+        .withRequired(false)
+        .withShortName("i")
+        .withArgument(new ArgumentBuilder().withName("input").withMinimum(1).withMaximum(1).create())
+        .withDescription("The Path for input data directory.")
+        .create();
+    Option helpOpt = DefaultOptionCreator.helpOption();
+
+    Group options = new GroupBuilder()
+        .withName("Options")
+        .withOption(inputOpt)
+        .withOption(helpOpt)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setGroup(options);
+    CommandLine parsed = parser.parse(args);
+
+    // Help wins over everything else: print usage and report "no file".
+    if (parsed.hasOption(helpOpt)) {
+      CommandLineUtil.printHelp(options);
+      return null;
+    }
+    if (!parsed.hasOption(inputOpt)) {
+      return null;
+    }
+    return new File(parsed.getValue(inputOpt).toString());
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java
new file mode 100644
index 0000000..c908e5b
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommender.java
@@ -0,0 +1,102 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
+import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefUserBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.CachingUserSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import java.util.Collection;
+import java.util.List;
+/**
+ * A simple {@link Recommender} implemented for the Book Crossing demo.
+ * See the <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/">Book Crossing site</a>.
+ */
+public final class BookCrossingBooleanRecommender implements Recommender {
+
+  /** The user-based recommender that every call is forwarded to. */
+  private final Recommender delegate;
+
+  /**
+   * Wires up a {@link GenericBooleanPrefUserBasedRecommender} over the given model, using a
+   * cached {@link LogLikelihoodSimilarity} and a {@link NearestNUserNeighborhood}.
+   *
+   * @param bcModel the Book Crossing data to recommend from
+   * @throws TasteException if any of the underlying components cannot be created
+   */
+  public BookCrossingBooleanRecommender(DataModel bcModel) throws TasteException {
+    UserSimilarity logLikelihood = new LogLikelihoodSimilarity(bcModel);
+    UserSimilarity cachedSimilarity = new CachingUserSimilarity(logLikelihood, bcModel);
+    UserNeighborhood nearestUsers =
+        new NearestNUserNeighborhood(10, Double.NEGATIVE_INFINITY, cachedSimilarity, bcModel, 1.0);
+    this.delegate = new GenericBooleanPrefUserBasedRecommender(bcModel, nearestUsers, cachedSimilarity);
+  }
+
+  /** Forwards to the wrapped recommender's two-argument {@code recommend}. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return delegate.recommend(userID, howMany);
+  }
+
+  /** Same as the four-argument overload with no rescorer. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+    return recommend(userID, howMany, null, includeKnownItems);
+  }
+
+  /** Recommends with the given rescorer; {@code false} is passed for includeKnownItems. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return delegate.recommend(userID, howMany, rescorer, false);
+  }
+
+  /** Forwards all four arguments unchanged to the wrapped recommender. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+    throws TasteException {
+    return delegate.recommend(userID, howMany, rescorer, includeKnownItems);
+  }
+
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return delegate.estimatePreference(userID, itemID);
+  }
+
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    delegate.setPreference(userID, itemID, value);
+  }
+
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    delegate.removePreference(userID, itemID);
+  }
+
+  @Override
+  public DataModel getDataModel() {
+    return delegate.getDataModel();
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    delegate.refresh(alreadyRefreshed);
+  }
+
+  @Override
+  public String toString() {
+    return "BookCrossingBooleanRecommender[recommender:" + delegate + ']';
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java
new file mode 100644
index 0000000..2219bce
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderBuilder.java
@@ -0,0 +1,32 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+final class BookCrossingBooleanRecommenderBuilder implements RecommenderBuilder {
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) throws TasteException {
+ return new BookCrossingBooleanRecommender(dataModel);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java
new file mode 100644
index 0000000..b9814c7
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingBooleanRecommenderEvaluatorRunner.java
@@ -0,0 +1,59 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import org.apache.commons.cli2.OptionException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
+import org.apache.mahout.cf.taste.example.TasteOptionParser;
+import org.apache.mahout.cf.taste.impl.eval.GenericRecommenderIRStatsEvaluator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.File;
+import java.io.IOException;
+public final class BookCrossingBooleanRecommenderEvaluatorRunner {
+
+  private static final Logger log = LoggerFactory.getLogger(BookCrossingBooleanRecommenderEvaluatorRunner.class);
+
+  private BookCrossingBooleanRecommenderEvaluatorRunner() {
+    // do nothing
+  }
+
+  /**
+   * Runs an IR-statistics evaluation of the boolean Book Crossing recommender and logs the result.
+   *
+   * @param args command line; an optional input option naming the ratings file is read via
+   *        {@link TasteOptionParser#getRatings(String[])}. Without it the no-file
+   *        {@link BookCrossingDataModel} constructor is used.
+   * @throws IOException if the ratings data cannot be read or converted
+   * @throws TasteException if the evaluation fails
+   * @throws OptionException if the command line cannot be parsed
+   */
+  public static void main(String... args) throws IOException, TasteException, OptionException {
+    RecommenderIRStatsEvaluator evaluator = new GenericRecommenderIRStatsEvaluator();
+    File ratingsFile = TasteOptionParser.getRatings(args);
+    DataModel model =
+        ratingsFile == null ? new BookCrossingDataModel(true) : new BookCrossingDataModel(ratingsFile, true);
+
+    // NOTE(review): the call previously listed only six arguments (..., null, 3, 1.0), while
+    // RecommenderIRStatsEvaluator.evaluate expects a relevance threshold between the "at" value
+    // and the evaluation percentage. Double.NEGATIVE_INFINITY is restored for that argument
+    // (the same value this example uses as minimum similarity) -- confirm against the API.
+    IRStatistics evaluation = evaluator.evaluate(
+        new BookCrossingBooleanRecommenderBuilder(),
+        new BookCrossingDataModelBuilder(),
+        model,
+        null,
+        3,
+        Double.NEGATIVE_INFINITY,
+        1.0);
+
+    log.info(String.valueOf(evaluation));
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java
new file mode 100644
index 0000000..3e2f8b5
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModel.java
@@ -0,0 +1,99 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.regex.Pattern;
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.similarity.precompute.example.GroupLensDataModel;
+import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
+import org.apache.mahout.common.iterator.FileLineIterable;
+/**
+ * See <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/BX-CSV-Dump.zip">download</a> for
+ * data needed by this class. The BX-Book-Ratings.csv file is needed.
+ */
+public final class BookCrossingDataModel extends FileDataModel {
+ private static final Pattern NON_DIGIT_SEMICOLON_PATTERN = Pattern.compile("[^0-9;]");
+ public BookCrossingDataModel(boolean ignoreRatings) throws IOException {
+ this(GroupLensDataModel.readResourceToTempFile(
+ "/org/apache/mahout/cf/taste/example/bookcrossing/BX-Book-Ratings.csv"),
+ ignoreRatings);
+ }
+ /**
+ * @param ratingsFile BookCrossing ratings file in its native format
+ * @throws IOException if an error occurs while reading or writing files
+ */
+ public BookCrossingDataModel(File ratingsFile, boolean ignoreRatings) throws IOException {
+ super(convertBCFile(ratingsFile, ignoreRatings));
+ }
+ private static File convertBCFile(File originalFile, boolean ignoreRatings) throws IOException {
+ if (!originalFile.exists()) {
+ throw new FileNotFoundException(originalFile.toString());
+ }
+ File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "taste.bookcrossing.txt");
+ resultFile.delete();
+ Writer writer = null;
+ try {
+ writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8);
+ for (String line : new FileLineIterable(originalFile, true)) {
+ // 0 ratings are basically "no rating", ignore them (thanks h.9000)
+ if (line.endsWith("\"0\"")) {
+ continue;
+ }
+ // Delete replace anything that isn't numeric, or a semicolon delimiter. Make comma the delimiter.
+ String convertedLine = NON_DIGIT_SEMICOLON_PATTERN.matcher(line)
+ .replaceAll("").replace(';', ',');
+ // If this means we deleted an entire ID -- few cases like that -- skip the line
+ if (convertedLine.contains(",,")) {
+ continue;
+ }
+ if (ignoreRatings) {
+ // drop rating
+ convertedLine = convertedLine.substring(0, convertedLine.lastIndexOf(','));
+ }
+ writer.write(convertedLine);
+ writer.write('\n');
+ }
+ writer.flush();
+ } catch (IOException ioe) {
+ resultFile.delete();
+ throw ioe;
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return resultFile;
+ }
+ @Override
+ public String toString() {
+ return "BookCrossingDataModel";
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java
new file mode 100644
index 0000000..9ec2eaf
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingDataModelBuilder.java
@@ -0,0 +1,33 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+final class BookCrossingDataModelBuilder implements DataModelBuilder {
+ /**
+ * Builds a boolean-preference model from the given training data by passing it through
+ * {@code GenericBooleanPrefDataModel.toDataMap} and wrapping the result.
+ *
+ * @param trainingData training preferences; NOTE(review): presumably keyed by user ID -- confirm
+ * @return a new {@link GenericBooleanPrefDataModel} over the converted data
+ */
+ @Override
+ public DataModel buildDataModel(FastByIDMap<PreferenceArray> trainingData) {
+ return new GenericBooleanPrefDataModel(GenericBooleanPrefDataModel.toDataMap(trainingData));
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java
new file mode 100644
index 0000000..c06ca2f
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommender.java
@@ -0,0 +1,101 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import java.util.Collection;
+import java.util.List;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
+import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.CachingUserSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.EuclideanDistanceSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+/**
+ * A simple {@link Recommender} implemented for the Book Crossing demo.
+ * See the <a href="http://www.informatik.uni-freiburg.de/~cziegler/BX/">Book Crossing site</a>.
+ */
+public final class BookCrossingRecommender implements Recommender {
+
+  /** The user-based recommender that every call is forwarded to. */
+  private final Recommender recommender;
+
+  /**
+   * Wires up a {@link GenericUserBasedRecommender} over the given model, using a cached
+   * {@link EuclideanDistanceSimilarity} and a {@link NearestNUserNeighborhood}.
+   *
+   * @param bcModel the Book Crossing data to recommend from
+   * @throws TasteException if any of the underlying components cannot be created
+   */
+  public BookCrossingRecommender(DataModel bcModel) throws TasteException {
+    UserSimilarity similarity = new CachingUserSimilarity(new EuclideanDistanceSimilarity(bcModel), bcModel);
+    UserNeighborhood neighborhood = new NearestNUserNeighborhood(10, 0.2, similarity, bcModel, 0.2);
+    recommender = new GenericUserBasedRecommender(bcModel, neighborhood, similarity);
+  }
+
+  /** Forwards to the wrapped recommender's two-argument {@code recommend}. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return recommender.recommend(userID, howMany);
+  }
+
+  /** Same as the four-argument overload with no rescorer. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+    return recommend(userID, howMany, null, includeKnownItems);
+  }
+
+  /** Recommends with the given rescorer; {@code false} is passed for includeKnownItems. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, false);
+  }
+
+  /**
+   * Forwards all four arguments to the wrapped recommender.
+   *
+   * @param includeKnownItems passed through as given; it was previously replaced by a
+   *        hard-coded {@code false}, so callers asking for known items never received them
+   */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+    throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
+  }
+
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return recommender.estimatePreference(userID, itemID);
+  }
+
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    recommender.setPreference(userID, itemID, value);
+  }
+
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    recommender.removePreference(userID, itemID);
+  }
+
+  @Override
+  public DataModel getDataModel() {
+    return recommender.getDataModel();
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    recommender.refresh(alreadyRefreshed);
+  }
+
+  @Override
+  public String toString() {
+    return "BookCrossingRecommender[recommender:" + recommender + ']';
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java
new file mode 100644
index 0000000..bb6d3e1
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderBuilder.java
@@ -0,0 +1,32 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+final class BookCrossingRecommenderBuilder implements RecommenderBuilder {
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) throws TasteException {
+ return new BookCrossingRecommender(dataModel);
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java
new file mode 100644
index 0000000..97074d2
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/BookCrossingRecommenderEvaluatorRunner.java
@@ -0,0 +1,54 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.bookcrossing;
+import java.io.File;
+import java.io.IOException;
+import org.apache.commons.cli2.OptionException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderEvaluator;
+import org.apache.mahout.cf.taste.example.TasteOptionParser;
+import org.apache.mahout.cf.taste.impl.eval.AverageAbsoluteDifferenceRecommenderEvaluator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+public final class BookCrossingRecommenderEvaluatorRunner {
+
+  private static final Logger log = LoggerFactory.getLogger(BookCrossingRecommenderEvaluatorRunner.class);
+
+  private BookCrossingRecommenderEvaluatorRunner() {
+    // do nothing
+  }
+
+  /**
+   * Evaluates the Book Crossing recommender with an
+   * {@link AverageAbsoluteDifferenceRecommenderEvaluator} and logs the resulting score.
+   *
+   * @param args command line; an optional input option naming the ratings file is read via
+   *        {@link TasteOptionParser#getRatings(String[])}
+   * @throws IOException if the ratings data cannot be read or converted
+   * @throws TasteException if the evaluation fails
+   * @throws OptionException if the command line cannot be parsed
+   */
+  public static void main(String... args) throws IOException, TasteException, OptionException {
+    RecommenderEvaluator evaluator = new AverageAbsoluteDifferenceRecommenderEvaluator();
+    File ratingsFile = TasteOptionParser.getRatings(args);
+
+    // Without an explicit file, use the data model's no-file constructor.
+    DataModel model;
+    if (ratingsFile != null) {
+      model = new BookCrossingDataModel(ratingsFile, false);
+    } else {
+      model = new BookCrossingDataModel(false);
+    }
+
+    BookCrossingRecommenderBuilder builder = new BookCrossingRecommenderBuilder();
+    double score = evaluator.evaluate(builder, null, model, 0.9, 0.3);
+    log.info(String.valueOf(score));
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README
new file mode 100644
index 0000000..9244fe3
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/bookcrossing/README
@@ -0,0 +1,9 @@
+Code works with BookCrossing data set, which is not included in this distribution but is downloadable from
+http://www.informatik.uni-freiburg.de/~cziegler/BX/
+Data set originated from:
+Improving Recommendation Lists Through Topic Diversification,
+ Cai-Nicolas Ziegler, Sean M. McNee, Joseph A. Konstan, Georg Lausen;
+ Proceedings of the 14th International World Wide Web Conference (WWW '05), May 10-14, 2005, Chiba, Japan.
+ To appear.
\ No newline at end of file

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java
new file mode 100644
index 0000000..033daa2
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/EmailUtility.java
@@ -0,0 +1,104 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import java.io.IOException;
+import java.util.regex.Pattern;
+public final class EmailUtility {
+
+  /** Configuration key under which the mappers look up the field separator. */
+  public static final String SEPARATOR = "separator";
+  // NOTE(review): the remaining names look like configuration keys as well -- confirm with the job setup.
+  public static final String MSG_IDS_PREFIX = "msgIdsPrefix";
+  public static final String FROM_PREFIX = "fromPrefix";
+  public static final String MSG_ID_DIMENSION = "msgIdDim";
+  public static final String FROM_INDEX = "fromIdx";
+  public static final String REFS_INDEX = "refsIdx";
+
+  /** Shared result for "no references". */
+  private static final String[] EMPTY = new String[0];
+  /** Matches "mailto:", angle brackets, square brackets and the literal "=20". */
+  private static final Pattern ADDRESS_CLEANUP = Pattern.compile("mailto:|<|>|\\[|\\]|\\=20");
+  private static final Pattern ANGLE_BRACES = Pattern.compile("<|>");
+  private static final Pattern SPACE_OR_CLOSE_ANGLE = Pattern.compile(">|\\s+");
+  /** Matches zero or more whitespace characters. */
+  public static final Pattern WHITESPACE = Pattern.compile("\\s*");
+
+  private EmailUtility() {
+  }
+
+  /**
+   * Strip off some spurious characters that make it harder to dedup.
+   *
+   * @param address the raw sender text
+   * @return {@code address} with every {@code ADDRESS_CLEANUP} match removed
+   */
+  public static String cleanUpEmailAddress(CharSequence address) {
+    //do some cleanup to normalize some things, like: Key: karthik ananth <***@gmail.com>: Value: 178
+    //Key: karthik ananth [mailto:***@gmail.com]=20: Value: 179
+    //TODO: is there more to clean up here?
+    return ADDRESS_CLEANUP.matcher(address).replaceAll("");
+  }
+
+  /**
+   * Fills the two dictionaries from the sequence files returned by
+   * {@link HadoopUtil#getCachedFiles(Configuration)}. A file whose name starts with
+   * {@code fromPrefix} goes into {@code fromDictionary}; otherwise, one whose name starts with
+   * {@code msgIdPrefix} goes into {@code msgIdDictionary}; any other file is ignored.
+   *
+   * @param conf configuration used to locate and read the cached files
+   * @param fromPrefix file-name prefix of the sender dictionary files
+   * @param fromDictionary receives key string to int id entries for senders
+   * @param msgIdPrefix file-name prefix of the message-id dictionary files
+   * @param msgIdDictionary receives key string to int id entries for message ids
+   * @throws IOException if the cached files cannot be listed or read
+   */
+  public static void loadDictionaries(Configuration conf, String fromPrefix,
+                                      OpenObjectIntHashMap<String> fromDictionary,
+                                      String msgIdPrefix,
+                                      OpenObjectIntHashMap<String> msgIdDictionary) throws IOException {
+    Path[] localFiles = HadoopUtil.getCachedFiles(conf);
+    FileSystem fs = FileSystem.getLocal(conf);
+    for (Path dictionaryFile : localFiles) {
+      String fileName = dictionaryFile.getName();
+      // key is word value is id
+      OpenObjectIntHashMap<String> dictionary = null;
+      if (fileName.startsWith(fromPrefix)) {
+        dictionary = fromDictionary;
+      } else if (fileName.startsWith(msgIdPrefix)) {
+        dictionary = msgIdDictionary;
+      }
+      if (dictionary == null) {
+        continue;
+      }
+      // Use a separate variable for the qualified path instead of reassigning the loop variable.
+      Path qualifiedFile = fs.makeQualified(dictionaryFile);
+      for (Pair<Writable, IntWritable> record
+          : new SequenceFileIterable<Writable, IntWritable>(qualifiedFile, true, conf)) {
+        dictionary.put(record.getFirst().toString(), record.getSecond().get());
+      }
+    }
+  }
+
+  /**
+   * Splits a raw references value on whitespace runs and '>' characters, then removes any
+   * remaining angle brackets from each piece.
+   *
+   * @param rawRefs the raw references text; may be {@code null} or empty
+   * @return the cleaned pieces, or an empty array when {@code rawRefs} is {@code null} or empty
+   */
+  public static String[] parseReferences(CharSequence rawRefs) {
+    if (rawRefs == null || rawRefs.length() == 0) {
+      return EMPTY;
+    }
+    String[] splits = SPACE_OR_CLOSE_ANGLE.split(rawRefs);
+    for (int i = 0; i < splits.length; i++) {
+      splits[i] = ANGLE_BRACES.matcher(splits[i]).replaceAll("");
+    }
+    return splits;
+  }
+
+  /**
+   * Hadoop job counters used by the e-mail example mappers.
+   * NOTE(review): this enum was empty here although {@code NO_FROM_ADDRESS} is referenced by
+   * FromEmailToDictionaryMapper; other constants used outside this view may also belong
+   * here -- confirm.
+   */
+  public enum Counters {
+    /** Incremented by FromEmailToDictionaryMapper when a value does not contain the separator. */
+    NO_FROM_ADDRESS
+  }
+}

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java
new file mode 100644
index 0000000..5cd308d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/FromEmailToDictionaryMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VarIntWritable;
+import java.io.IOException;
+/**
+ * Assumes the input is in the format created by {@link org.apache.mahout.text.SequenceFilesFromMailArchives}
+ */
+public final class FromEmailToDictionaryMapper extends Mapper<Text, Text, Text, VarIntWritable> {
+ private String separator;
+ /**
+  * Runs the superclass setup, then stores the value configured under
+  * {@code EmailUtility.SEPARATOR} in the {@code separator} field for later use by {@code map}.
+  *
+  * @param context task context whose configuration supplies the separator
+  */
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+   super.setup(context);
+   this.separator = context.getConfiguration().get(EmailUtility.SEPARATOR);
+ }
+ /**
+  * Emits {@code (sender, 1)} for each record that carries a usable sender.
+  * The sender is the part of {@code value} before the first occurrence of the configured separator,
+  * passed through {@code EmailUtility.cleanUpEmailAddress}. Records without a separator, and records
+  * whose cleaned sender fully matches {@code EmailUtility.WHITESPACE}, only increment the
+  * {@code NO_FROM_ADDRESS} counter.
+  *
+  * @param key not read by this method
+  * @param value record text with the From address before the separator
+  * @param context receives the output pair and the counter updates
+  */
+ @Override
+ protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
+   //From is in the value
+   String text = value.toString();
+   int sepAt = text.indexOf(separator);
+   if (sepAt < 0) {
+     // no separator, hence no From field to extract
+     context.getCounter(EmailUtility.Counters.NO_FROM_ADDRESS).increment(1);
+     return;
+   }
+   //do some cleanup to normalize some things, like: Key: karthik ananth <***@gmail.com>: Value: 178
+   //Key: karthik ananth [mailto:***@gmail.com]=20: Value: 179
+   //TODO: is there more to clean up here?
+   String sender = EmailUtility.cleanUpEmailAddress(text.substring(0, sepAt));
+   if (EmailUtility.WHITESPACE.matcher(sender).matches()) {
+     context.getCounter(EmailUtility.Counters.NO_FROM_ADDRESS).increment(1);
+     return;
+   }
+   context.write(new Text(sender), new VarIntWritable(1));
+ }

diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java
new file mode 100644
index 0000000..72fcde9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/email/MailToDictionaryReducer.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.example.email;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VarIntWritable;
+import java.io.IOException;
+/**
+ * Key: the string id
+ * Value: the count
+ * Out Key: the string id
+ * Out Value: the sum of the counts
+ */
+public final class MailToDictionaryReducer extends Reducer<Text, VarIntWritable, Text, VarIntWritable> {
+ /**
+  * Adds up every count received for {@code key} and writes one {@code (key, total)} pair.
+  *
+  * @param key the string id
+  * @param values the counts emitted for {@code key}
+  * @param context receives the summed output pair
+  */
+ @Override
+ protected void reduce(Text key, Iterable<VarIntWritable> values, Context context)
+     throws IOException, InterruptedException {
+   int total = 0;
+   for (VarIntWritable count : values) {
+     total = total + count.get();
+   }
+   // the output key is a new Text built from the incoming key, as in the original
+   Text outKey = new Text(key);
+   VarIntWritable outValue = new VarIntWritable(total);
+   context.write(outKey, outValue);
+ }
2018-06-27 13:14:28 UTC
diff --git a/examples/src/test/resources/wdbc/wdbc.data b/examples/src/test/resources/wdbc/wdbc.data
deleted file mode 100644
index 8885375..0000000
--- a/examples/src/test/resources/wdbc/wdbc.data
+++ /dev/null
@@ -1,569 +0,0 @@

