
2011-05-07

Mahout RandomForest Driver Implementation - Large-Scale Distributed Machine Learning and Classification -

Apache Mahout is a large-scale distributed data mining / machine learning library that runs on Hadoop.

Random Forest is an algorithm that achieves high-accuracy classification on large-scale data.


I implemented several Drivers using Mahout so that Random Forest can be trained and applied for classification in a large-scale distributed setting, as easily as running it in R.

This article describes how to run them and how they are implemented.


  • org.mahoutjp.df.ForestDriver
    • A Driver that runs everything from distributed training of a Random Forest through distributed classification, output of the classification results, and accuracy evaluation.
  • org.mahoutjp.df.ForestClassificationDriver
    • A Driver that uses a built Forest Model to run distributed classification, output the classification results, and evaluate accuracy.

Both Drivers are implemented so that they can be run in a distributed fashion with a single command.

Alongside command-line execution, public static ForestDriver.run(...) and ForestClassificationDriver.run(...) methods are provided, so the Drivers can also be invoked simply from code.
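
For example, the whole pipeline can be invoked from code roughly as follows (a minimal sketch, assuming the imports used in the source below; the paths and parameter values simply mirror the KDD example shown later):

Configuration conf = new Configuration();
// data descriptor tokens: count followed by type, as with the -dsc option
List<String> descriptor = Arrays.asList(
        "N", "4", "C", "2", "N", "C", "4", "N", "C", "8", "N", "2", "C", "19", "N", "L");
ForestDriver.run(conf,
        "testdata/kdd/KDDTrain.arff",  // training data (-d)
        descriptor,                    // data descriptor (-dsc)
        "testdata/kdd/KDD.info",       // descriptor output (-ds)
        "testdata/kdd/forest",         // forest output (-fo)
        7,                             // m: variables selected at each tree-node (-sl)
        0L,                            // seed (-sd)
        500,                           // number of trees (-t)
        true,                          // partial data implementation (-p)
        true,                          // estimate the out-of-bag error (-oob)
        "testdata/kdd/KDDTest",        // test data (-td)
        "testdata/kdd/predictions");   // prediction output (-po)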

The implementation also resolves various issues found in Mahout's example source code, such as output results that could not be read.


Introduction to Random Forest

Random Forest is an algorithm that classifies even large-scale data with higher accuracy and speed than other classifiers.

Introductory material:


org.mahoutjp.df.ForestDriver: Distributed Training and Classification

How to Run

Runs everything from distributed training of the Forest Model through distributed classification, output of the classification results, and accuracy evaluation.

$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
org.mahoutjp.df.ForestDriver \
- ... (Options)
Execution Options

The following options can be specified.

Options
  --data (-d) path                                   Data path 
  --descriptor (-dsc) descriptor [descriptor ...]    data descriptor (N:Numerical, C:Categorical, L:Label)
  --descriptor_out (-ds) file path                   Path to generated descriptor file 
  --selection (-sl) m                                Number of variables to select randomly at each tree-node
  --nbtrees (-t) nbtrees                             Number of trees to grow
  --forest_out (-fo) dir path                        Output path of Decision Forest
  -oob                                               Optional, estimate the out-of-bag error
  --seed (-sd) seed                                  Optional, seed value used to initialise the Random number generator
  --partial (-p)                                     Optional, use the Partial Data implementation
  --testdata (-td) dir path                          Test data path
  --predout (-po) dir path                           Path to generated prediction output file
  --help (-h)                                        Print out help
Data Format

Input format:

dataid, attribute1, attribute2, ... (comma sep)

Output format:

dataid, predictionIndex, realLabelIndex (tab sep)
Example
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
org.mahoutjp.df.ForestDriver \
-Dmapred.max.split.size=5074231 \
-d testdata/kdd/KDDTrain.arff \
-ds testdata/kdd/KDD.info \
-fo testdata/kdd/forest \
-dsc N 4 C 2 N C 4 N C 8 N 2 C 19 N L \
-oob \
-sl 7 \
-p \
-t 500 \
-td testdata/kdd/KDDTest \
-po testdata/kdd/predictions

※Shorthand notation for the Data Descriptor: a count followed by a type stands for that many attributes of the type (N:Numerical, C:Categorical, L:Label, I:Ignored). For example, "N N N I N N C C L I I I I I" can be written as "3 N I N N 2 C L 5 I".
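
Internally, this shorthand is expanded by Mahout's DescriptorUtils.generateDescriptor, the same call that createDescriptor uses in the source below. A minimal sketch (the token list is illustrative):

// Sketch: expanding the count-Type shorthand with DescriptorUtils
// (the same call used in createDescriptor below); throws DescriptorException on bad input
List<String> tokens = Arrays.asList("3", "N", "I", "N", "N", "2", "C", "L", "5", "I");
String descriptor = DescriptorUtils.generateDescriptor(tokens);
// descriptor now represents the expanded form: N N N I N N C C L I I I I I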

Data:

NSL-KDD data

The data-definition lines (lines starting with @) were removed, and a dataid was prepended as the first column.
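
For reference, a minimal preprocessing sketch (the file names are placeholders for the raw NSL-KDD files):

import java.io.*;

// Sketch: drop the '@'-prefixed data-definition lines and
// prepend a sequential dataid as the first column.
public class AddDataId {
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader("KDDTrain_raw.arff"));
        PrintWriter out = new PrintWriter(new FileWriter("KDDTrain.arff"));
        long dataid = 1000001;
        String line;
        while ((line = in.readLine()) != null) {
            if (line.isEmpty() || line.startsWith("@")) {
                continue; // skip data-definition lines and blanks
            }
            out.println(dataid++ + "," + line);
        }
        in.close();
        out.close();
    }
}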

Training data:

dataid, attribute1, attribute2, ...

1000001,0,"tcp","ftp_data","SF",491,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,0,0,1,0,0,150,25,0.17,0.03,0.17,0,0,0,0.05,0,"normal"
1000002,0,"udp","other","SF",146,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,1,0,0,0,0,0.08,0.15,0,255,1,0,0.6,0.88,0,0,0,0,0,"normal"
1000003,0,"tcp","private","S0",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,123,6,1,1,0,0,0.05,0.07,0,255,26,0.1,0.05,0,0,1,1,0,0,"anomaly"
...

Test data:

testdata/kdd/KDDTest/part-00000

2000001,0,"tcp","private","REJ",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,229,10,0,0,1,1,0.04,0.06,0,255,10,0.04,0.06,0,0,0,0,1,1,"anomaly"
2000002,0,"tcp","private","REJ",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,1,0,0,1,1,0.01,0.06,0,255,1,0,0.06,0,0,0,0,1,1,"anomaly"
2000003,2,"tcp","ftp_data","SF",12983,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,1,0,0,134,86,0.61,0.04,0.61,0.02,0,0,0,0,"normal"
...
Output:
  • testdata/kdd/predictions/part-00000, part-00001, part-00002

dataid, predictionIndex, realLabelIndex

testdata/kdd/predictions/part-00000

2000001	1	1
2000002	1	1
2000003	0	0
...
Execution Log (excerpt)

The run proceeds all the way through accuracy evaluation.

11/05/07 21:32:09 INFO df.ForestDriver: Generating the descriptor...
11/05/07 21:32:09 INFO df.ForestDriver: generating the descriptor dataset...
11/05/07 21:32:12 INFO df.ForestDriver: storing the dataset description
11/05/07 21:32:13 INFO df.ForestDriver: Partial Mapred
11/05/07 21:32:13 INFO df.ForestDriver: Building the forest...
...
11/05/07 21:35:24 INFO df.ForestDriver: Build Time: 0h 3m 11s 360
11/05/07 21:35:26 INFO df.ForestDriver: oob error estimate : 0.0016114564232017972
11/05/07 21:35:26 INFO df.ForestDriver: Storing the forest in: testdata/kdd/forest/forest.seq
...
11/05/07 21:36:00 INFO df.ForestClassifier: # of data which cannot be predicted: 0
11/05/07 21:36:00 INFO df.ForestClassificationDriver: =======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances          :      18013       79.9015%
Incorrectly Classified Instances        :       4531       20.0985%
Total Classified Instances              :      22544

=======================================================
Confusion Matrix
-------------------------------------------------------
a       b       <--Classified as
9453    4273     |  13726       a     = "normal"
258     8560     |  8818        b     = "anomaly"
Default Category: unknown: 2

org.mahoutjp.df.ForestClassificationDriver: Distributed Classification

Runs only the distributed classification: using a built Forest Model, it performs distributed classification, outputs the classification results, and evaluates accuracy.

How to Run
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
org.mahoutjp.df.ForestClassificationDriver \
-...(Options)
Execution Options
  --testdata (-td) dir path       Test data path
  --dataset (-ds) file path       Dataset path
  --model (-m) dir path           Path to the Decision Forest
  --predout (-po) dir path        Path to generated predictions file
  --analyze (-a)                  Analyze Results
  --help (-h)                     Print out help
Example
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
org.mahoutjp.df.ForestClassificationDriver \
-td testdata/kdd/KDDTest \
-ds testdata/kdd/KDD.info \
-m testdata/kdd/forest \
-po testdata/kdd/predictions \
-a
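
As with ForestDriver, the same run can be started from code via the public static run(...) (a minimal sketch mirroring the options above):

Configuration conf = new Configuration();
ForestClassificationDriver.run(conf,
        "testdata/kdd/forest",        // built Decision Forest (-m)
        "testdata/kdd/KDDTest",       // test data (-td)
        "testdata/kdd/KDD.info",      // dataset descriptor (-ds)
        "testdata/kdd/predictions",   // prediction output (-po)
        true);                        // analyze results (-a)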


Source Code

Overview:

  • Distributed training and classification of Random Forest
    • org.mahoutjp.df.ForestDriver
  • Distributed classification with a Forest Model
    • org.mahoutjp.df.ForestClassificationDriver
    • org.mahoutjp.df.ForestClassifier

org.mahoutjp.df.ForestDriver

package org.mahoutjp.df;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.ErrorEstimate;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.callback.ForestPredictions;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.DescriptorException;
import org.apache.mahout.df.data.DescriptorUtils;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.mapreduce.inmem.InMemBuilder;
import org.apache.mahout.df.mapreduce.partial.PartialBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Random Forest Driver
 * 
 * @author hamadakoichi
 */
public class ForestDriver extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(ForestDriver.class);

  @Override
  public int run(String[] args) throws Exception {

		DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
		ArgumentBuilder abuilder = new ArgumentBuilder();
		GroupBuilder gbuilder = new GroupBuilder();

		Option dataOpt = obuilder.withLongName("data").withShortName("d")
				.withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create())
				.withDescription("Data file path")
				.create();

		Option descriptorOpt = obuilder
				.withLongName("descriptor").withShortName("dsc")
				.withRequired(true)
				.withArgument(abuilder.withName("descriptor").withMinimum(1).create())
				.withDescription("data descriptor file path").create();

		Option descPathOpt = obuilder.withLongName("descriptor_out")
				.withShortName("ds")
				.withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create())
				.withDescription("Path to generated descriptor file").create();

		Option selectionOpt = obuilder
				.withLongName("selection").withShortName("sl")
				.withRequired(true)
				.withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create())
				.withDescription("Number of variables to select randomly at each tree-node").create();

		Option oobOpt = obuilder.withShortName("oob").withRequired(false)
				.withDescription("Optional, estimate the out-of-bag error").create();

		Option seedOpt = obuilder
				.withLongName("seed").withShortName("sd")
				.withRequired(false)
				.withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
				.withDescription("Optional, seed value used to initialise the Random number generator").create();

		Option partialOpt = obuilder.withLongName("partial").withShortName("p")
				.withRequired(false).withDescription("Optional, use the Partial Data implementation").create();

		Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t")
				.withRequired(true).withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
				.withDescription("Number of trees to grow").create();

		Option forestOutOpt = obuilder.withLongName("forest_out").withShortName("fo")
				.withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create())
				.withDescription("Output path of Decision Forest").create();

		
		//For Predictions
		Option testDataOpt = obuilder.withLongName("testdata").withShortName("td")
				.withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create())
				.withDescription("Test data path").create();

		Option predictionOutOpt = obuilder.withLongName("predout").withShortName("po")
				.withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()).
				withDescription("Path to generated prediction output file").create();		

		Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();

		Group group = gbuilder.withName("Options").withOption(oobOpt)
				.withOption(dataOpt).withOption(descriptorOpt).withOption(descPathOpt)
				.withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt)
				.withOption(nbtreesOpt).withOption(forestOutOpt)
				.withOption(testDataOpt).withOption(predictionOutOpt)
				.withOption(helpOpt).create();
    
		try {
			Parser parser = new Parser();
			parser.setGroup(group);
			CommandLine cmdLine = parser.parse(args);

			if (cmdLine.hasOption("help")) {
				CommandLineUtil.printHelp(group);
				return -1;
			}

			// Forest Parameter
			boolean isPartial = cmdLine.hasOption(partialOpt);
			boolean isOob = cmdLine.hasOption(oobOpt);
			String dataName = cmdLine.getValue(dataOpt).toString();
			List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
			String descriptorName = cmdLine.getValue(descPathOpt).toString();
			String forestName = cmdLine.getValue(forestOutOpt).toString();
			int m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
			int nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
			Long seed = null; // optional; when null, the random number generator self-seeds (see buildForest)
			if (cmdLine.hasOption(seedOpt)) {
				seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
			}

			// Classification Parameters
			String testdataName =  cmdLine.getValue(testDataOpt).toString();
			String predictionOutName =  cmdLine.getValue(predictionOutOpt).toString();
			

			log.debug("data : {}", dataName);
			log.debug("descriptor: {}", descriptor);
			log.debug("descriptorOut: {}", descriptorName);
			log.debug("forestOut    : {}", forestName);
			log.debug("m : {}", m);
			log.debug("seed : {}", seed);
			log.debug("nbtrees : {}", nbTrees);
			log.debug("isPartial : {}", isPartial);
			log.debug("isOob : {}", isOob);
			log.debug("testData : {}", testdataName);
			log.debug("predictionOut : {}", predictionOutName);

			// Execute
			run(getConf(), dataName, descriptor, descriptorName, forestName, m, 
					seed, nbTrees, isPartial, isOob, testdataName, predictionOutName);
			
		} catch (OptionException e) {
			log.error("Exception", e);
			CommandLineUtil.printHelp(group);
			return -1;
		}
		return 0;
    }

    public static void run(Configuration conf, String dataPathName, List<String> description,
		String descriptorPathName, String forestPathName, int m, Long seed, int nbTrees,
		boolean isPartial, boolean isOob, 
		String testdataPathName, String predictionOutPathName)
    		throws DescriptorException, IOException, ClassNotFoundException, InterruptedException {
   
	  	// Create Descriptor
		Path dataPath = validateInput(dataPathName);
		Path descriptorPath = validateOutput(descriptorPathName);
	  	createDescriptor(conf, dataPath, description, descriptorPath);
		
		// Build Forest
		Path forestPath = validateOutput(forestPathName);
		buildForest(conf, dataPath, description,
				descriptorPath, forestPath, m, seed, nbTrees,
				isPartial, isOob);
		
		// Predict
		boolean analyze = true;
		ForestClassificationDriver.run(conf, forestPathName, testdataPathName, 
				descriptorPathName, predictionOutPathName, analyze);
    }	  
	/**
	 * Create Descriptor
	 */	
	private static void createDescriptor(Configuration conf, Path dataPath, List<String> description,
			Path outPath) throws DescriptorException, IOException {
		
		log.info("Generating the descriptor...");
		String descriptor = DescriptorUtils.generateDescriptor(description);
		log.info("generating the descriptor dataset...");
		Dataset dataset = generateDataset(descriptor, dataPath);
		log.info("storing the dataset description");

		DFUtils.storeWritable(conf, outPath, dataset);
	}

	/**
	 * Generate DataSet
	 */	
	private static Dataset generateDataset(String descriptor, Path dataPath)
			throws IOException, DescriptorException {
		
		FileSystem fs = dataPath.getFileSystem(new Configuration());		
		Path[] files = DFUtils.listOutputFiles(fs, dataPath);
		return DataLoader.generateDataset(descriptor, fs, files[0]);
	}

	/**
	 * Build Forest
	 */	
	private static void buildForest(Configuration conf, Path dataPath, List<String> description,
			Path descriptorPath, Path forestPath, int m, Long seed, int nbTrees,
			boolean isPartial, boolean isOob)
			throws IOException, ClassNotFoundException,InterruptedException {

		FileSystem ofs = forestPath.getFileSystem(conf);
		if (ofs.exists(forestPath)) {
			log.error("Forest Output Path already exists");
			return;
		}

		DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
		treeBuilder.setM(m);
		Dataset dataset = Dataset.load(conf, descriptorPath);

		ForestPredictions callback = isOob ? new ForestPredictions(dataset
				.nbInstances(), dataset.nblabels()) : null;

		Builder forestBuilder;

		if (isPartial) {
			log.info("Partial Mapred");
			forestBuilder = new PartialBuilder(treeBuilder, dataPath,
					descriptorPath, seed, conf);
		} else {
			log.info("InMemory Mapred");
			forestBuilder = new InMemBuilder(treeBuilder, dataPath,
					descriptorPath, seed, conf);
		}

		forestBuilder.setOutputDirName(forestPath.getName());

		log.info("Building the forest...");
		long time = System.currentTimeMillis();

		DecisionForest forest = forestBuilder.build(nbTrees, callback);

		time = System.currentTimeMillis() - time;
		log.info("Build Time: {}", DFUtils.elapsedTime(time));

		if (isOob) {
			Random rng;
			if (seed != null) {
				rng = RandomUtils.getRandom(seed);
			} else {
				rng = RandomUtils.getRandom();
			}

			FileSystem fs = dataPath.getFileSystem(conf);
			int[] labels = Data.extractLabels(dataset, fs, dataPath);

			log.info("oob error estimate : "
					+ ErrorEstimate.errorRate(labels, callback
							.computePredictions(rng)));
		}
		
		// store the forest as forest.seq under the forest output path
		// (the classification Driver's --model option takes this directory)
		Path forestoutPath = new Path(forestPath, "forest.seq");
		log.info("Storing the forest in: " + forestoutPath);
		DFUtils.storeWritable(conf, forestoutPath, forest);
	}

	/**
	 * Load data
	 */	
	protected static Data loadData(Configuration conf, Path dataPath,
			Dataset dataset) throws IOException {
		log.info("Loading the data...");
		FileSystem fs = dataPath.getFileSystem(conf);
		Data data = DataLoader.loadData(dataset, fs, dataPath);
		log.info("Data Loaded");

		return data;
	}

	/**
	 * Convert Collections to a String List
	 */	
	private static List<String> convert(Collection<?> values) {
		List<String> list = new ArrayList<String>(values.size());
		for (Object value : values) {
			list.add(value.toString());
		}
		return list;
	}

	/**
	 * Validation of the Output Path
	 */	
	private static Path validateOutput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (fs.exists(path)) {
			throw new IllegalStateException(path.toString() + " already exists");
		}
		return path;
	}  
	

	/**
	 * Validation of the Input Path
	 */
	private static Path validateInput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (!fs.exists(path)) {
			throw new IllegalArgumentException(path.toString() + " does not exist");
		}
		return path;		
	}

	public static void main(String[] args) throws Exception {
		    ToolRunner.run(new Configuration(), new ForestDriver(), args);
	}
}

org.mahoutjp.df.ForestClassificationDriver

package org.mahoutjp.df;

import java.io.IOException;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Classification Mapreduce using a built Decision Forest
 * 
 * @author hamadakoichi
 */
public class ForestClassificationDriver extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(ForestClassificationDriver.class);

  @Override
  public int run(String[] args) throws Exception {

		DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
		ArgumentBuilder abuilder = new ArgumentBuilder();
		GroupBuilder gbuilder = new GroupBuilder();
		
		Option testDataOpt = obuilder.withLongName("testdata").withShortName("td")
				.withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create())
				.withDescription("Test data path").create();

		Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds")
				.withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create())
				.withDescription("Dataset path").create();

		Option modelOpt = obuilder.withLongName("model").withShortName("m")
				.withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create())
				.withDescription("Path to the Decision Forest").create();

		Option predictionOutOpt = obuilder.withLongName("predout").withShortName("po")
				.withRequired(false).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create())
				.withDescription("Path to generated predictions file").create();

		Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a")
				.withRequired(false).create();

		Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();

		Group group = gbuilder.withName("Options").withOption(testDataOpt)
				.withOption(datasetOpt).withOption(modelOpt).withOption(predictionOutOpt)
				.withOption(analyzeOpt).withOption(helpOpt).create();

    
		try {
			Parser parser = new Parser();
			parser.setGroup(group);
			CommandLine cmdLine = parser.parse(args);

			if (cmdLine.hasOption("help")) {
				CommandLineUtil.printHelp(group);
				return -1;
			}
			
			//Parameters
			String testDataName = cmdLine.getValue(testDataOpt).toString();
			String descriptorName = cmdLine.getValue(datasetOpt).toString();
			String forestName = cmdLine.getValue(modelOpt).toString();
			String predictionName = (cmdLine.hasOption(predictionOutOpt)) ? cmdLine
					.getValue(predictionOutOpt).toString() : null;
			boolean analyze = cmdLine.hasOption(analyzeOpt);
			
			log.debug("inout     : {}", testDataName);
			log.debug("descriptor: {}", descriptorName);
			log.debug("forest    : {}", forestName);
			log.debug("prediction: {}", predictionName);
			log.debug("analyze   : {}", analyze);
			
			//Execute Classification
			log.info("Execute the Mapreduce Classification ...");
			run(getConf(), forestName, testDataName, descriptorName, predictionName, analyze);

		} catch (OptionException e) {
			log.warn(e.toString(), e);
			CommandLineUtil.printHelp(group);
			return -1;
		}
		return 0;
    }

  
    public static void run(Configuration conf, String forestPathName, 
    		String testDataPathName, String descriptorPathName, String predictionPathName, boolean analyze)
    	throws IOException, ClassNotFoundException, InterruptedException{

		// Classify data
		Path testDataPath = validateInput(testDataPathName);
		Path descriptorPath = validateInput(descriptorPathName);
		
		Path forestPath = validateInput(forestPathName);
		Path predictionPath = validateOutput(conf, predictionPathName);

		ForestClassifier classifier = new ForestClassifier(conf, forestPath, testDataPath,
				descriptorPath, predictionPath, analyze);
		classifier.run();

		// Analyze Results
		if (analyze) {
			log.info(classifier.getAnalyzer().summarize());
		}
    }	  

	/**
	 * Validation of the Output Path
	 */
	private static Path validateOutput(Configuration conf, String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(conf);
		if (fs.exists(path)) {
			throw new IllegalStateException(path.toString() + " already exists");
		}
		return path;
	}

	/**
	 * Validation of the Input Path
	 */
	private static Path validateInput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (!fs.exists(path)) {
			throw new IllegalArgumentException(path.toString() + " does not exist");
		}
		return path;		
	}

	public static void main(String[] args) throws Exception {
		ToolRunner.run(new Configuration(), new ForestClassificationDriver(), args);
	}
}

org.mahoutjp.df.ForestClassifier

package org.mahoutjp.df;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.data.DataConverter;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.common.RandomUtils;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.LongWritable;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;
import java.net.URI;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;

/**
 * Mapreduce implementation of Classifier with Forest
 * 
 * @author hamadakoichi
 */
public class ForestClassifier {

	private static final Logger log = LoggerFactory
			.getLogger(ForestClassifier.class);
	private final Path forestPath;
	private final Path inputPath;
	private final Path datasetPath;
	private final Configuration conf;

	private final ResultAnalyzer analyzer;
	private final Dataset dataset;
	private final Path outputPath;
	
	public ForestClassifier(Configuration conf, Path forestPath, Path inputPath, Path datasetPath,
			Path outputPath, boolean analyze)
			throws IOException {
		this.forestPath = forestPath;
		this.inputPath = inputPath;
		this.datasetPath = datasetPath;
		this.conf = conf;

		if (analyze) {
			dataset = Dataset.load(conf, datasetPath);
			analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()),
					"unknown");
		} else {
			dataset = null;
			analyzer = null;
		}
		this.outputPath = outputPath;
	}

	/**
	 * Configure the Classification Job
	 */
	private void configureClassifyJob(Job job) throws IOException {

		job.setJarByClass(ForestClassifier.class);

		FileInputFormat.setInputPaths(job, inputPath);
		FileOutputFormat.setOutputPath(job, outputPath);

		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(Text.class);

		job.setMapperClass(ClassifyMapper.class);
		job.setNumReduceTasks(0); // Classification Mapper Only

		job.setInputFormatClass(ClassifyTextInputFormat.class);
		job.setOutputFormatClass(TextOutputFormat.class);
	}

	public void run() throws IOException, ClassNotFoundException,
			InterruptedException {
		FileSystem fs = FileSystem.get(conf);

		if (fs.exists(outputPath)) {
			throw new IOException(outputPath + " already exists");
		}

		// put the dataset
		log.info("Adding the dataset to the DistributedCache");
		DistributedCache.addCacheFile(datasetPath.toUri(), conf);

		// load the forest
		log.info("Adding the forest to the DistributedCache");
		DistributedCache.addCacheFile(forestPath.toUri(), conf);

		// Classification
		Job cjob = new Job(conf, "Decision Forest classification");
		log.info("Configuring the Classification Job...");
		configureClassifyJob(cjob);

		log.info("Running the Classification Job...");
		if (!cjob.waitForCompletion(true)) {
			log.error("Classification Job failed!");
			return;
		}
		
		// Analyze Results
		if (analyzer != null) {
			analyzeOutput(cjob);
		}
	}
	
	public ResultAnalyzer getAnalyzer() {
		return analyzer;
	}

	/**
	 * Analyze the Classification Results
	 * @param job
	 */
	private void analyzeOutput(Job job) throws IOException {
		Configuration conf = job.getConfiguration();
		Integer prediction;
		Integer realLabel;
		
		FileSystem fs = outputPath.getFileSystem(conf);
		Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);

		int cnt_cnp = 0;
		
		for (Path path : outfiles) {
			FSDataInputStream input = fs.open(path);
			Scanner scanner = new Scanner(input);

			while (scanner.hasNext()) {
				String line = scanner.nextLine();
				if (line.isEmpty()) {
					continue;
				}

				// id, predict, realLabel with \t sep
				String[] tmp = line.split("\t", -1);
				prediction = Integer.parseInt(tmp[1]);
				realLabel = Integer.parseInt(tmp[2]);
				
				if(prediction == -1) {
					// label cannot be predicted
					cnt_cnp++;
				}
				else{					
					if (analyzer != null) {
						analyzer.addInstance(dataset.getLabel(prediction), new ClassifierResult(dataset
								.getLabel(realLabel), 1.0));
					}
				}
			}
		}
		log.info("# of data which cannot be predicted: " + cnt_cnp);
	}

	/**
	 * Text Input Format: each file is processed by a single Mapper.
	 */
	public static class ClassifyTextInputFormat extends TextInputFormat {
		@Override
		protected boolean isSplitable(JobContext jobContext, Path path) {
			return false;
		}
	}

	/**
	 * Classification Mapper.
	 */
	public static class ClassifyMapper extends
			Mapper<LongWritable, Text, LongWritable, Text> {

		private DataConverter converter;
		private DecisionForest forest;
		private final Random rng = RandomUtils.getRandom();

		private final Text val = new Text();

		@Override
		protected void setup(Context context) throws IOException,
				InterruptedException {
			super.setup(context);

			Configuration conf = context.getConfiguration();
			URI[] files = DistributedCache.getCacheFiles(conf);

			if ((files == null) || (files.length < 2)) {
				throw new IOException(
						"not enough paths in the DistributedCache");
			}

			Dataset dataset = Dataset.load(conf, new Path(files[0].getPath()));
			converter = new DataConverter(dataset);
			forest = DecisionForest.load(conf, new Path(files[1].getPath()));
			if (forest == null) {
				throw new InterruptedException("DecisionForest not found!");
			}
		}

		@Override
		protected void map(LongWritable key, Text value, Context context)
				throws IOException, InterruptedException {

			String line = value.toString();
			if (!line.isEmpty()) {
				String[] idVal = line.split(",", -1);
				Integer id = Integer.parseInt(idVal[0]);

				Instance instance = converter.convert(id, line);
				int prediction = forest.classify(rng, instance);

				// key:id 
				key.set(instance.getId());

				// val: prediction, originalLabel (with tab sep)
				StringBuffer sb = new StringBuffer();
				sb.append(Integer.toString(prediction));
				sb.append("\t");
				sb.append(instance.getLabel());
				val.set(sb.toString());

				context.write(key, val);
			}
		}
	}
}

Related Materials