徒然なる日々を送るソフトウェアデベロッパーの記録(2)

技術上思ったことや感じたことを気ままに記録していくブログです。さくらから移設しました。

Naive Bayes を Java で実装してみた

今回も元ネタはこの本を使います。やっぱり Weka で実行すると
するりと実行できてしまって頭に残らないので、Java で実装する
ことを考えます。

フリーソフトではじめる機械学習入門

フリーソフトではじめる機械学習入門

package com.minosys.aplication;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * 単純ベイジアン検出器
 *
 * L(D) = arg[i] max P(\Omega_i) * ΠP(x_j|Omega_i)
 * ただし arg[i] max は引数が最大となるクラス値 i を取り出すことを示す。
 * P(A) は集合 A の事前確率、P(x_j|A) はサンプルの入力ベクトルを
 * x = {x_1, x_2, ..., x_n } と表現した場合に j 番目の属性が x_j となる
 * 事後確率を表すものとする。
 *
 * @author minosys
 *
 */
public class NaiveBayes implements Classifier {
	public static class NumValueWithProbability {
		int count;
		double probability;
		public String toString() {
			return count + "(" + probability + ")";
		}
	}
	// 頻度および出現確率は [判定値, 属性名, 属性値] のマップとなる
	Map<String, Map<String, Map<String, NumValueWithProbability>>> concept;
	Map<String, Double> zappMap;

	/**
	 * 判定属性が特定の判定値を持つ状況下で、指定した属性が出現する回数+Nを返す。
	 * ゼロ頻度問題を回避するため、各 count 値は +1 されていることに注意。
	 * Σ_{aname} C_{aname} は判定値の場合の和とはならず、判定値の取り得る
	 * 数だけ大きい。
	 * @param zvalue
	 * @param aname
	 * @return
	 */
	private int calcTotalCount(String zvalue, String aname) {
		int count = 0;
		Map<String, NumValueWithProbability> map = concept.get(zvalue).get(aname);
		for (Iterator<Map.Entry<String, NumValueWithProbability>> i = map.entrySet().iterator(); i.hasNext(); ) {
			count += i.next().getValue().count;
		}
		return count;
	}

	/**
	 * 各属性に対する出現確率を更新する。
	 * 必ず analyze() の最後に呼び出すこと。
	 */
	private void updateProbability() {
		int totalCount = 0;
		zappMap = new HashMap<String, Double>();

		// 各判定クラスに対して
		for (Iterator<Map.Entry<String, Map<String, Map<String, NumValueWithProbability>>>> i = concept.entrySet().iterator();
				i.hasNext(); ) {
			Map.Entry<String, Map<String, Map<String, NumValueWithProbability>>> e1 = i.next();
			Map<String, Map<String, NumValueWithProbability>> map = e1.getValue();
			int count = 0;

			// 各属性に対して
			for (Iterator<Map.Entry<String, Map<String, NumValueWithProbability>>> j = map.entrySet().iterator(); j.hasNext(); ) {
				Map.Entry<String, Map<String, NumValueWithProbability>> e2 = j.next();
				int total = calcTotalCount(e1.getKey(), e2.getKey()); // 属性の取り得る値の集合の要素数 + 属性数
				assert(total > 0);

				// 各属性値に対して
				for (Iterator<Map.Entry<String, NumValueWithProbability>> k = e2.getValue().entrySet().iterator(); k.hasNext(); ) {
					Map.Entry<String, NumValueWithProbability> e3 = k.next();
					count += e3.getValue().count - 1; // 出現数
					e3.getValue().probability = (double)e3.getValue().count / (double) total;
				}
			}
			totalCount += count;
			zappMap.put(e1.getKey(), (double)count);
		}

		if (totalCount > 0) {
			// 判定値の出現割合を計算する
			for (Iterator<String> x = concept.keySet().iterator(); x.hasNext(); ) {
				String s = x.next();
				zappMap.put(s, zappMap.get(s) / (double)totalCount);
			}
		}
	}

	@Override
	public void analyze(GenericData data) {
		// TODO 自動生成されたメソッド・スタブ
		concept = new HashMap<String, Map<String, Map<String, NumValueWithProbability>>>();

		// マトリックスを初期化する
		String zname = data.attributeNameList.get(data.lastIindex());
		for (Iterator<String> zp = data.attributes.get(zname).iterator(); zp.hasNext(); ) {
			String z = zp.next();
			Map<String, Map<String, NumValueWithProbability>> amap = new HashMap<String, Map<String, NumValueWithProbability>>();
			for (Iterator<Map.Entry<String, HashSet<String>>> akp = data.attributes.entrySet().iterator(); akp.hasNext(); ) {
				Map.Entry<String, HashSet<String>> e2 = akp.next();
				if (e2.getKey().equals(zname)) {
					// 判定名は含めない
					continue;
				}
				Map<String, NumValueWithProbability> vmap = new HashMap<String, NumValueWithProbability>();
				for (Iterator<String> avp = e2.getValue().iterator(); avp.hasNext(); ) {
					NumValueWithProbability np = new NumValueWithProbability();
					np.count = 1;
					String av = avp.next();
					vmap.put(av, np);
				}
				amap.put(e2.getKey(), vmap);
			}
			concept.put(z, amap);
		}

		// 各属性名に対して
		for (Iterator<Map.Entry<String, Integer>> i = data.attributeNameIndex.entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, Integer> e = i.next();

			// 属性名が判定名である場合は計算対象から外す(二重計算するのを回避するため)
			if (e.getValue() == data.lastIindex()) {
				continue;
			}

			// 各行を分類する
			for (int j = 0; j < data.data.size(); ++j) {
				String aval = data.get(j, e.getValue()); // 属性値
				String zval = data.get(j, data.lastIindex()); // 判定値
				NumValueWithProbability np = concept.get(zval).get(e.getKey()).get(aval);
				++np.count;
			}
		}

		// 出現確率を計算する
		updateProbability();
	}

	@Override
	public ValueWithProbability decide(List<String> slist,
			Map<String, Integer> amap) {
		double maxval = 0.0;
		String maxname = null;
		Map<String, Double> pr2 = new HashMap<String, Double>();
		double px1 = 0.0;

		// P(x1|z_y)*...*P(z_y) を計算する
		for (Iterator<Map.Entry<String, Double>> i = zappMap.entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, Double> e = i.next();
			double pr = e.getValue();
			for (Iterator<Map.Entry<String, Integer>> j = amap.entrySet().iterator(); j.hasNext(); ) {
				Map.Entry<String, Integer> e2 = j.next();
				if (concept.get(e.getKey()).containsKey(e2.getKey())) {
					pr = pr * concept.get(e.getKey()).get(e2.getKey()).get(slist.get(e2.getValue())).probability;
				}
			}
			px1 = px1 + pr;
			pr2.put(e.getKey(), pr);
		}

		// P(z_y|x1) を Bayes 定理から計算する
		for (Iterator<Map.Entry<String, Double>> j = pr2.entrySet().iterator(); j.hasNext(); ) {
			Map.Entry<String, Double> e = j.next();
			double pr = e.getValue() / px1;
			if (pr > maxval) {
				maxval = pr;
				maxname = e.getKey();
			}
		}
		if (maxname != null) {
			ValueWithProbability v = new ValueWithProbability(maxname, maxval);
			return v;
		}
		return null;
	}

}

NaiveBayes インスタンスに含まれる concept のキーは <クラス値>,
<特徴量名称>, <特徴値> の組となります。
またクラス値が現れる事前確率を zappMap に記録しています。

ゼロ頻度問題を回避するため、concept の値は 1 からカウントするようにしています。
weather.nominal.arff に対してプログラムを実行した結果は以下の通り。

no(0.35714285714285715) yes(0.6428571428571429) 
--------------------
outlook
rainy | 3 4 
overcast | 1 5 
sunny | 4 3 
temperature
mild | 3 5 
cool | 2 4 
hot | 3 3 
humidity
normal | 2 7 
high | 5 4 
windy
TRUE | 4 4 
FALSE | 3 7 

<< make a decision for the first 10 samples >>
[sunny, hot, high, FALSE, no]:no(0.687969069820332)
[sunny, hot, high, TRUE, no]:no(0.8372543592582342)
[overcast, hot, high, FALSE, yes]:yes(0.7514719978091196)
[rainy, mild, high, FALSE, yes]:yes(0.573353879907018)
[rainy, cool, normal, FALSE, yes]:yes(0.8758578235790337)
[rainy, cool, normal, TRUE, no]:yes(0.7514719978091196)
[overcast, cool, normal, TRUE, yes]:yes(0.9189551239115875)
[sunny, mild, high, FALSE, no]:no(0.5695010982114841)
[sunny, cool, normal, FALSE, yes]:yes(0.7987358616101131)
[rainy, mild, normal, FALSE, yes]:yes(0.854638487208009)

1例のみ正しくない判定がなされるが、これは Weka の結果と同じ。
お手本と算出された確率が微妙に違うんだが、まあいいか。